matmul_internal.m4 16 KB


  `/* Internal MATMUL worker: store the matrix product of A and B in RETARRAY.

   If RETARRAY has no storage yet (base_addr is NULL), its bounds are set
   from the extents of A and B and it is allocated here.  Otherwise, when
   compile_options.bounds_check is set, its extents are checked against
   those of A and B and a mismatch raises a runtime error.

   A nonzero TRY_BLAS allows the product to be handed to GEMM, provided
   the strides permit it and xcount * ycount * count exceeds BLAS_LIMIT
   cubed.  TRY_BLAS values 2 and 4, when set, make the call pass "C" as
   the transpose flag for A and B respectively.  */
void
'matmul_name` ('rtype` * const restrict retarray,
        'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
        int blas_limit, blas_call gemm)
{
  const 'rtype_name` * restrict abase;
  const 'rtype_name` * restrict bbase;
  'rtype_name` * restrict dest;

  index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
  index_type x, y, n, count, xcount, ycount;

  assert (GFC_DESCRIPTOR_RANK (a) == 2
          || GFC_DESCRIPTOR_RANK (b) == 2);

/* C[xcount,ycount] = A[xcount, count] * B[count,ycount]

   Either A or B (but not both) can be rank 1:

   o One-dimensional argument A is implicitly treated as a row matrix
     dimensioned [1,count], so xcount=1.

   o One-dimensional argument B is implicitly treated as a column matrix
     dimensioned [count, 1], so ycount=1.
*/

  if (retarray->base_addr == NULL)
    {
      /* The result has no storage yet: derive its bounds from the
         arguments and allocate it.  */
      if (GFC_DESCRIPTOR_RANK (a) == 1)
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
        }
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
        }
      else
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);

          GFC_DIMENSION_SET(retarray->dim[1], 0,
                            GFC_DESCRIPTOR_EXTENT(b,1) - 1,
                            GFC_DESCRIPTOR_EXTENT(retarray,0));
        }

      retarray->base_addr
        = xmallocarray (size0 ((array_t *) retarray), sizeof ('rtype_name`));
      retarray->offset = 0;
    }
  else if (unlikely (compile_options.bounds_check))
    {
      /* The caller supplied the result array: verify that its extents
         agree with those implied by the arguments.  */
      index_type ret_extent, arg_extent;

      if (GFC_DESCRIPTOR_RANK (a) == 1)
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 1 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);
        }
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 1 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);
        }
      else
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 1 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);

          arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 2 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);
        }
    }
'
sinclude(`matmul_asm_'rtype_code`.m4')dnl
`
  if (GFC_DESCRIPTOR_RANK (retarray) == 1)
    {
      /* One-dimensional result may be addressed in the code below
         either as a row or a column matrix. We want both cases to
         work. */
      rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
    }
  else
    {
      rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
      rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
    }


  if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      /* Treat it as a row matrix A[1,count]. */
      axstride = GFC_DESCRIPTOR_STRIDE(a,0);
      aystride = 1;

      xcount = 1;
      count = GFC_DESCRIPTOR_EXTENT(a,0);
    }
  else
    {
      axstride = GFC_DESCRIPTOR_STRIDE(a,0);
      aystride = GFC_DESCRIPTOR_STRIDE(a,1);

      count = GFC_DESCRIPTOR_EXTENT(a,1);
      xcount = GFC_DESCRIPTOR_EXTENT(a,0);
    }

  /* The inner dimensions must agree; a mismatch is tolerated only when
     neither extent is positive.  */
  if (count != GFC_DESCRIPTOR_EXTENT(b,0))
    {
      if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
        runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
                       "in dimension 1: is %ld, should be %ld",
                       (long int) GFC_DESCRIPTOR_EXTENT(b,0), (long int) count);
    }

  if (GFC_DESCRIPTOR_RANK (b) == 1)
    {
      /* Treat it as a column matrix B[count,1] */
      bxstride = GFC_DESCRIPTOR_STRIDE(b,0);

      /* bystride should never be used for 1-dimensional b.
         The value is only used for calculation of the
         memory by the buffer.  */
      bystride = 256;
      ycount = 1;
    }
  else
    {
      bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
      bystride = GFC_DESCRIPTOR_STRIDE(b,1);
      ycount = GFC_DESCRIPTOR_EXTENT(b,1);
    }

  abase = a->base_addr;
  bbase = b->base_addr;
  dest = retarray->base_addr;

  /* Now that everything is set up, we perform the multiplication
     itself.  */

#define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
#define min(a,b) ((a) <= (b) ? (a) : (b))
#define max(a,b) ((a) >= (b) ? (a) : (b))

  /* Hand the work to the external GEMM when requested, when the layout
     is one GEMM can address, and when the problem is large enough.  The
     size test is done in float so that the triple product cannot
     overflow an integer type.  */
  if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
      && (bxstride == 1 || bystride == 1)
      && (((float) xcount) * ((float) ycount) * ((float) count)
          > POW3(blas_limit)))
    {
      const int m = xcount, n = ycount, k = count, ldc = rystride;
      const 'rtype_name` one = 1, zero = 0;
      const int lda = (axstride == 1) ? aystride : axstride,
                ldb = (bxstride == 1) ? bystride : bxstride;

      if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
        {
          assert (gemm != NULL);
          const char *transa, *transb;
          if (try_blas & 2)
            transa = "C";
          else
            transa = axstride == 1 ? "N" : "T";

          if (try_blas & 4)
            transb = "C";
          else
            transb = bxstride == 1 ? "N" : "T";

          gemm (transa, transb , &m,
                &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
                &ldc, 1, 1);
          return;
        }
    }

  if (rxstride == 1 && axstride == 1 && bxstride == 1
      && GFC_DESCRIPTOR_RANK (b) != 1)
    {
      /* This block of code implements a tuned matmul, derived from
         Superscalar GEMM-based level 3 BLAS,  Beta version 0.1

               Bo Kagstrom and Per Ling
               Department of Computing Science
               Umea University
               S-901 87 Umea, Sweden

         from netlib.org, translated to C, and modified for matmul.m4.  */

      const 'rtype_name` *a, *b;
      'rtype_name` *c;
      const index_type m = xcount, n = ycount, k = count;

      /* System generated locals */
      index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
                 i1, i2, i3, i4, i5, i6;

      /* Local variables */
      'rtype_name` f11, f12, f21, f22, f31, f32, f41, f42,
                 f13, f14, f23, f24, f33, f34, f43, f44;
      index_type i, j, l, ii, jj, ll;
      index_type isec, jsec, lsec, uisec, ujsec, ulsec;
      'rtype_name` *t1;

      a = abase;
      b = bbase;
      c = retarray->base_addr;

      /* Parameter adjustments: bias each base pointer so that the
         arrays can be addressed with indices starting at 1.  */
      c_dim1 = rystride;
      c_offset = 1 + c_dim1;
      c -= c_offset;
      a_dim1 = aystride;
      a_offset = 1 + a_dim1;
      a -= a_offset;
      b_dim1 = bystride;
      b_offset = 1 + b_dim1;
      b -= b_offset;

      /* Empty c first.  */
      for (j=1; j<=n; j++)
        for (i=1; i<=m; i++)
          c[i + j * c_dim1] = ('rtype_name`)0;

      /* Early exit if possible */
      if (m == 0 || n == 0 || k == 0)
        return;

      /* Adjust size of t1 to what is needed.  */
      index_type t1_dim, a_sz;
      if (aystride == 1)
        a_sz = rystride;
      else
        a_sz = a_dim1;

      /* The packing loops below never index past a full 256 x 256
         tile, so the buffer is capped at that size.  */
      t1_dim = a_sz * 256 + b_dim1;
      if (t1_dim > 65536)
        t1_dim = 65536;

      /* Use the checked allocator, as for the result array above, so
         that an allocation failure is diagnosed instead of leading to
         stores through a null pointer in the packing loops.  */
      t1 = xmallocarray (t1_dim, sizeof ('rtype_name`));

      /* Start turning the crank.  The product is computed tile by tile:
         up to 512 columns of C (jj), up to 256 elements of the inner
         dimension (ll) and up to 256 rows of C (ii) at a time.  */
      i1 = n;
      for (jj = 1; jj <= i1; jj += 512)
        {
          /* Computing MIN */
          i2 = 512;
          i3 = n - jj + 1;
          jsec = min(i2,i3);
          ujsec = jsec - jsec % 4;
          i2 = k;
          for (ll = 1; ll <= i2; ll += 256)
            {
              /* Computing MIN */
              i3 = 256;
              i4 = k - ll + 1;
              lsec = min(i3,i4);
              ulsec = lsec - lsec % 2;

              i3 = m;
              for (ii = 1; ii <= i3; ii += 256)
                {
                  /* Computing MIN */
                  i4 = 256;
                  i5 = m - ii + 1;
                  isec = min(i4,i5);
                  uisec = isec - isec % 2;

                  /* Copy the current tile of A into t1, transposed, so
                     that each row of the tile occupies 256 consecutive
                     slots (the shift by 8); the constant 257 removes
                     the two 1-based offsets.  Done in 2 x 2 steps.  */
                  i4 = ll + ulsec - 1;
                  for (l = ll; l <= i4; l += 2)
                    {
                      i5 = ii + uisec - 1;
                      for (i = ii; i <= i5; i += 2)
                        {
                          t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
                                        a[i + l * a_dim1];
                          t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
                                        a[i + (l + 1) * a_dim1];
                          t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
                                        a[i + 1 + l * a_dim1];
                          t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
                                        a[i + 1 + (l + 1) * a_dim1];
                        }
                      /* Odd number of rows: copy the last row.  */
                      if (uisec < isec)
                        {
                          t1[l - ll + 1 + (isec << 8) - 257] =
                                    a[ii + isec - 1 + l * a_dim1];
                          t1[l - ll + 2 + (isec << 8) - 257] =
                                    a[ii + isec - 1 + (l + 1) * a_dim1];
                        }
                    }
                  /* Odd inner extent: copy the last column of the
                     tile for every row.  */
                  if (ulsec < lsec)
                    {
                      i4 = ii + isec - 1;
                      for (i = ii; i<= i4; ++i)
                        {
                          t1[lsec + ((i - ii + 1) << 8) - 257] =
                                    a[i + (ll + lsec - 1) * a_dim1];
                        }
                    }

                  /* Main update: 4 x 4 blocks of C held in the sixteen
                     accumulators f11 .. f44.  */
                  uisec = isec - isec % 4;
                  i4 = jj + ujsec - 1;
                  for (j = jj; j <= i4; j += 4)
                    {
                      i5 = ii + uisec - 1;
                      for (i = ii; i <= i5; i += 4)
                        {
                          f11 = c[i + j * c_dim1];
                          f21 = c[i + 1 + j * c_dim1];
                          f12 = c[i + (j + 1) * c_dim1];
                          f22 = c[i + 1 + (j + 1) * c_dim1];
                          f13 = c[i + (j + 2) * c_dim1];
                          f23 = c[i + 1 + (j + 2) * c_dim1];
                          f14 = c[i + (j + 3) * c_dim1];
                          f24 = c[i + 1 + (j + 3) * c_dim1];
                          f31 = c[i + 2 + j * c_dim1];
                          f41 = c[i + 3 + j * c_dim1];
                          f32 = c[i + 2 + (j + 1) * c_dim1];
                          f42 = c[i + 3 + (j + 1) * c_dim1];
                          f33 = c[i + 2 + (j + 2) * c_dim1];
                          f43 = c[i + 3 + (j + 2) * c_dim1];
                          f34 = c[i + 2 + (j + 3) * c_dim1];
                          f44 = c[i + 3 + (j + 3) * c_dim1];
                          i6 = ll + lsec - 1;
                          for (l = ll; l <= i6; ++l)
                            {
                              f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                              f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                              f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                              f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                            }
                          c[i + j * c_dim1] = f11;
                          c[i + 1 + j * c_dim1] = f21;
                          c[i + (j + 1) * c_dim1] = f12;
                          c[i + 1 + (j + 1) * c_dim1] = f22;
                          c[i + (j + 2) * c_dim1] = f13;
                          c[i + 1 + (j + 2) * c_dim1] = f23;
                          c[i + (j + 3) * c_dim1] = f14;
                          c[i + 1 + (j + 3) * c_dim1] = f24;
                          c[i + 2 + j * c_dim1] = f31;
                          c[i + 3 + j * c_dim1] = f41;
                          c[i + 2 + (j + 1) * c_dim1] = f32;
                          c[i + 3 + (j + 1) * c_dim1] = f42;
                          c[i + 2 + (j + 2) * c_dim1] = f33;
                          c[i + 3 + (j + 2) * c_dim1] = f43;
                          c[i + 2 + (j + 3) * c_dim1] = f34;
                          c[i + 3 + (j + 3) * c_dim1] = f44;
                        }
                      /* Rows left over after the 4-row blocks, still
                         four columns at a time.  */
                      if (uisec < isec)
                        {
                          i5 = ii + isec - 1;
                          for (i = ii + uisec; i <= i5; ++i)
                            {
                              f11 = c[i + j * c_dim1];
                              f12 = c[i + (j + 1) * c_dim1];
                              f13 = c[i + (j + 2) * c_dim1];
                              f14 = c[i + (j + 3) * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + (j + 1) * b_dim1];
                                  f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + (j + 2) * b_dim1];
                                  f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + (j + 3) * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                              c[i + (j + 1) * c_dim1] = f12;
                              c[i + (j + 2) * c_dim1] = f13;
                              c[i + (j + 3) * c_dim1] = f14;
                            }
                        }
                    }
                  /* Columns left over after the 4-column blocks, one
                     column at a time.  */
                  if (ujsec < jsec)
                    {
                      i4 = jj + jsec - 1;
                      for (j = jj + ujsec; j <= i4; ++j)
                        {
                          i5 = ii + uisec - 1;
                          for (i = ii; i <= i5; i += 4)
                            {
                              f11 = c[i + j * c_dim1];
                              f21 = c[i + 1 + j * c_dim1];
                              f31 = c[i + 2 + j * c_dim1];
                              f41 = c[i + 3 + j * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
                                          257] * b[l + j * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                              c[i + 1 + j * c_dim1] = f21;
                              c[i + 2 + j * c_dim1] = f31;
                              c[i + 3 + j * c_dim1] = f41;
                            }
                          i5 = ii + isec - 1;
                          for (i = ii + uisec; i <= i5; ++i)
                            {
                              f11 = c[i + j * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + j * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                            }
                        }
                    }
                }
            }
        }
      free(t1);
      return;
    }
  else if (rxstride == 1 && aystride == 1 && bxstride == 1)
    {
      /* A is traversed with unit stride along the inner dimension, so
         each result element is a plain dot product.  */
      if (GFC_DESCRIPTOR_RANK (a) != 1)
        {
          const 'rtype_name` *restrict abase_x;
          const 'rtype_name` *restrict bbase_y;
          'rtype_name` *restrict dest_y;
          'rtype_name` s;

          for (y = 0; y < ycount; y++)
            {
              bbase_y = &bbase[y*bystride];
              dest_y = &dest[y*rystride];
              for (x = 0; x < xcount; x++)
                {
                  abase_x = &abase[x*axstride];
                  s = ('rtype_name`) 0;
                  for (n = 0; n < count; n++)
                    s += abase_x[n] * bbase_y[n];
                  dest_y[x] = s;
                }
            }
        }
      else
        {
          const 'rtype_name` *restrict bbase_y;
          'rtype_name` s;

          for (y = 0; y < ycount; y++)
            {
              bbase_y = &bbase[y*bystride];
              s = ('rtype_name`) 0;
              for (n = 0; n < count; n++)
                s += abase[n*axstride] * bbase_y[n];
              dest[y*rystride] = s;
            }
        }
    }
  else if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      /* Rank-1 A with arbitrary strides.  */
      const 'rtype_name` *restrict bbase_y;
      'rtype_name` s;

      for (y = 0; y < ycount; y++)
        {
          bbase_y = &bbase[y*bystride];
          s = ('rtype_name`) 0;
          for (n = 0; n < count; n++)
            s += abase[n*axstride] * bbase_y[n*bxstride];
          dest[y*rxstride] = s;
        }
    }
  else if (axstride < aystride)
    {
      /* The first dimension of A has the smaller stride: clear the
         result, then accumulate with x as the innermost loop.  */
      for (y = 0; y < ycount; y++)
        for (x = 0; x < xcount; x++)
          dest[x*rxstride + y*rystride] = ('rtype_name`)0;

      for (y = 0; y < ycount; y++)
        for (n = 0; n < count; n++)
          for (x = 0; x < xcount; x++)
            /* dest[x,y] += a[x,n] * b[n,y] */
            dest[x*rxstride + y*rystride] +=
                                        abase[x*axstride + n*aystride] *
                                        bbase[n*bxstride + y*bystride];
    }
  else
    {
      /* General fallback: one strided dot product per result element.  */
      const 'rtype_name` *restrict abase_x;
      const 'rtype_name` *restrict bbase_y;
      'rtype_name` *restrict dest_y;
      'rtype_name` s;

      for (y = 0; y < ycount; y++)
        {
          bbase_y = &bbase[y*bystride];
          dest_y = &dest[y*rystride];
          for (x = 0; x < xcount; x++)
            {
              abase_x = &abase[x*axstride];
              s = ('rtype_name`) 0;
              for (n = 0; n < count; n++)
                s += abase_x[n*aystride] * bbase_y[n*bxstride];
              dest_y[x*rxstride] = s;
            }
        }
    }
}
#undef POW3
#undef min
#undef max
'