| /* { dg-do run { target { ! ia32 } } } */ |
| /* { dg-require-effective-target amx_tile } */ |
| /* { dg-require-effective-target amx_bf16 } */ |
| /* { dg-options "-O2 -mamx-tile -mamx-bf16" } */ |
| #include <immintrin.h> |
| |
| #define AMX_BF16 |
| #define DO_TEST test_amx_bf16_dpbf16ps |
| void test_amx_bf16_dpbf16ps (); |
| #include "amx-check.h" |
| |
| /* Transformation functions between bf16/float */ |
/* Convert a float to bf16 by keeping only the upper 16 bits of its
   32-bit object representation.  The lower 16 bits are discarded
   (truncation, no rounding).  */
static uint16_t make_bf16 (float f)
{
  union
  {
    float as_float;
    uint32_t as_bits;
  } pun;

  pun.as_float = f;
  /* After shifting a 32-bit value right by 16 only 16 bits remain.  */
  return (uint16_t) (pun.as_bits >> 16);
}
| |
/* Widen a bf16 value to float: the 16 bits become the upper half of
   the 32-bit representation and the lower half is zero.  */
static float make_f32 (uint16_t bf)
{
  union
  {
    uint32_t as_bits;
    float as_float;
  } pun;

  pun.as_bits = bf;
  pun.as_bits <<= 16;
  return pun.as_float;
}
| |
/* Fill BUF with 16 rows of 32 bf16 values (1024 bytes in total).
   The element at row r, column c is the bf16 form of
   16.1f * r + 3.4f * c.  */
void init_bf16_max_tile_buffer (uint8_t *buf)
{
  uint16_t *out = (uint16_t *) buf;
  int row = 0;

  while (row < 16)
    {
      int col;

      /* Elements are written in row-major order, so a running output
	 pointer visits exactly the slots row * 32 + col.  */
      for (col = 0; col < 32; col++)
	{
	  float value = 16.1f * row + 3.4f * col;
	  *out++ = make_bf16 (value);
	}
      row++;
    }
}
| |
| void calc_matrix_dpbf16ps (__tile *dst, __tile *src1, __tile *src2) |
| { |
| uint16_t *src1_buf = (uint16_t *)src1->buf; |
| uint16_t *src2_buf = (uint16_t *)src2->buf; |
| float *dst_buf = (float *)dst->buf; |
| |
| int M = src1->rows; |
| int N = src1->colsb / 4; |
| int K = src2->colsb / 4; |
| int i, j, k, t; |
| |
| for (i = 0; i < M; i++) |
| for (j = 0; j < N; j++) |
| for (k = 0; k < K; k++) |
| for (t = 0; t < 2; t+=2) |
| { |
| dst_buf[i * N + k] += |
| (make_f32(src1_buf[i * 2 * N + 2 * j + t]) * |
| make_f32(src2_buf[j * 2 * K + 2 * k + t])) + |
| (make_f32(src1_buf[i * 2 * N + 2 * j + t + 1]) * |
| make_f32(src2_buf[j * 2 * K + 2 * k + t + 1])); |
| } |
| |
| } |
| |
| void test_amx_bf16_dpbf16ps () |
| { |
| __tilecfg_u cfg; |
| __tile dst, dst_ref, src1, src2; |
| uint8_t tmp_dst_buf[1024]; |
| |
| init_bf16_max_tile_buffer (tmp_dst_buf); |
| |
| init_tile_config (&cfg); |
| init_tile_reg_and_src_with_buffer (1, dst, tmp_dst_buf); |
| init_tile_reg_and_src_with_buffer (2, src1, tmp_dst_buf); |
| init_tile_reg_and_src_with_buffer (3, src2, tmp_dst_buf); |
| |
| calc_matrix_dpbf16ps (&dst, &src1, &src2); |
| |
| _tile_dpbf16ps (1, 2, 3); |
| _tile_stored (1, dst_ref.buf, _STRIDE); |
| |
| if (!check_float_tile_register (&dst_ref, &dst)) |
| abort(); |
| } |