smfmac_xdlops_gemm.hpp Source File

smfmac_xdlops_gemm.hpp Source File

Composable Kernel: smfmac_xdlops_gemm.hpp Source File
smfmac_xdlops_gemm.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7#include "ck/utility/math.hpp"
9
10namespace ck {
11
19
20template <SmfmacInstr instr>
22
23template <>
25{
26 static constexpr index_t group_size = 4;
27 static constexpr index_t num_groups_per_blk = 1;
28 static constexpr index_t num_regs_per_blk = 4;
29 static constexpr index_t num_threads_per_blk = 16;
30 static constexpr index_t wave_size = 64;
31 static constexpr index_t num_input_blks = 4;
32 static constexpr index_t num_output_blks = 1;
33 static constexpr index_t m_per_blk = 16;
34 static constexpr index_t n_per_blk = 16;
35 static constexpr index_t k_per_blk = 8;
36 static constexpr bool is_k_reduction = true;
37
38 template <index_t MPerXdlops,
39 index_t NPerXdlops,
40 index_t idx_part,
41 class FloatA,
42 class FloatB,
43 class FloatC>
44 __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
45 {
47 a, b, idx, reg_c);
48 }
49};
50
51template <>
53{
54 static constexpr index_t group_size = 4;
55 static constexpr index_t num_groups_per_blk = 4;
56 static constexpr index_t num_regs_per_blk = 16;
57 static constexpr index_t num_threads_per_blk = 32;
58 static constexpr index_t wave_size = 64;
59 static constexpr index_t num_input_blks = 2;
60 static constexpr index_t num_output_blks = 1;
61 static constexpr index_t m_per_blk = 32;
62 static constexpr index_t n_per_blk = 32;
63 static constexpr index_t k_per_blk = 16;
64 static constexpr bool is_k_reduction = true;
65
66 template <index_t MPerXdlops,
67 index_t NPerXdlops,
68 index_t idx_part,
69 class FloatA,
70 class FloatB,
71 class FloatC>
72 __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
73 {
75 a, b, idx, reg_c);
76 }
77};
78
79template <>
81{
82 static constexpr index_t group_size = 4;
83 static constexpr index_t num_groups_per_blk = 1;
84 static constexpr index_t num_regs_per_blk = 4;
85 static constexpr index_t num_threads_per_blk = 16;
86 static constexpr index_t wave_size = 64;
87 static constexpr index_t num_input_blks = 4;
88 static constexpr index_t num_output_blks = 1;
89 static constexpr index_t m_per_blk = 16;
90 static constexpr index_t n_per_blk = 16;
91 static constexpr index_t k_per_blk = 8;
92 static constexpr bool is_k_reduction = true;
93
94 template <index_t MPerXdlops,
95 index_t NPerXdlops,
96 index_t idx_part,
97 class FloatA,
98 class FloatB,
99 class FloatC>
100 __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
101 {
103 a, b, idx, reg_c);
104 }
105};
106
107template <>
109{
110 static constexpr index_t group_size = 4;
111 static constexpr index_t num_groups_per_blk = 4;
112 static constexpr index_t num_regs_per_blk = 16;
113 static constexpr index_t num_threads_per_blk = 32;
114 static constexpr index_t wave_size = 64;
115 static constexpr index_t num_input_blks = 2;
116 static constexpr index_t num_output_blks = 1;
117 static constexpr index_t m_per_blk = 32;
118 static constexpr index_t n_per_blk = 32;
119 static constexpr index_t k_per_blk = 16;
120 static constexpr bool is_k_reduction = true;
121
122 template <index_t MPerXdlops,
123 index_t NPerXdlops,
124 index_t idx_part,
125 class FloatA,
126 class FloatB,
127 class FloatC>
128 __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
129 {
131 a, b, idx, reg_c);
132 }
133};
134
135template <typename base_type,
136 index_t MPerXdlops,
137 index_t NPerXdlops,
138 typename additional_type = base_type>
140{
141 template <typename base_type_,
142 index_t MPerXdlops_,
143 index_t NPerXdlops_,
144 typename additional_type_ = base_type_>
145 static constexpr auto GetSmfmac();
146
147 template <>
148 static constexpr auto GetSmfmac<half_t, 16, 16>()
149 {
151 }
152
153 template <>
154 static constexpr auto GetSmfmac<half_t, 32, 32>()
155 {
157 }
158
159 template <>
160 static constexpr auto GetSmfmac<bhalf_t, 16, 16>()
161 {
163 }
164
165 template <>
166 static constexpr auto GetSmfmac<bhalf_t, 32, 32>()
167 {
169 }
170
173
174 __host__ __device__ constexpr SmfmacSelector()
175 {
176 static_assert(selected_smfmac.group_size * selected_smfmac.num_groups_per_blk ==
177 selected_smfmac.num_regs_per_blk,
178 "wrong! num_regs_per_blk");
179
180 static_assert(selected_smfmac.num_threads_per_blk == selected_smfmac.n_per_blk,
181 "n_per_blk != num_threads_per_blk");
182
183 static_assert(selected_smfmac.num_regs_per_blk * selected_smfmac.num_input_blks ==
184 selected_smfmac.m_per_blk,
185 "m_per_blk != num_input_blks * num_regs_per_blk");
186
187 static_assert(selected_smfmac.num_output_blks == selected_smfmac.num_input_blks ||
188 selected_smfmac.num_output_blks == 1,
189 "incorrect num_output_blks");
190
191 static_assert(selected_smfmac.num_regs_per_blk * selected_smfmac.wave_size ==
192 selected_smfmac.m_per_blk * selected_smfmac.n_per_blk,
193 "num_regs_per_blk incorrect");
194
195 static_assert(selected_smfmac.is_k_reduction ||
196 (selected_smfmac.num_input_blks == selected_smfmac.num_output_blks),
197 "is_k_reduction wrong!");
198 }
199
200 static constexpr index_t GetKPerXdlops()
201 {
202 return (selected_smfmac.is_k_reduction ? selected_smfmac.num_input_blks : 1) *
203 selected_smfmac.k_per_blk;
204 }
205
206 static constexpr index_t GetK1PerXdlops() { return selected_smfmac.k_per_blk; }
207};
208
209template <typename base_type,
210 index_t MPerXdlops,
211 index_t NPerXdlops,
212 index_t KPack,
213 typename additional_type = base_type>
215{
216 static constexpr auto I0 = Number<0>{};
217 static constexpr auto I1 = Number<1>{};
218 static constexpr auto I2 = Number<2>{};
219 static constexpr auto I3 = Number<3>{};
220 static constexpr auto I4 = Number<4>{};
221 static constexpr auto I5 = Number<5>{};
222
225
226 __device__ static constexpr index_t GetNumBlks() { return smfmac_instr.num_output_blks; }
227
228 __device__ static constexpr index_t GetNumXdlops()
229 {
230 return MPerXdlops * NPerXdlops /
231 (smfmac_instr.m_per_blk * smfmac_instr.n_per_blk * smfmac_instr.num_output_blks);
232 }
233
234 __host__ __device__ constexpr SparseXdlopsGemm()
235 {
236 static_assert(NPerXdlops == 16 || NPerXdlops == 32,
237 "Only support GemmNPerXdlops == 16 or 32 for smfmac xdlops");
238
239 static_assert(MPerXdlops == 16 || MPerXdlops == 32,
240 "Only support GemmMPerXdlops == 16 or 32 for smfmac xdlops");
241
242 static_assert(KPack % smfmac_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk");
243 }
244
245 // XDL output supporting C = A * B
246 // M2_N2 -> M2_M3_M4_N2
247 template <typename CDesc_M0_N0_M1_N1_M2_N2>
248 __host__ __device__ static constexpr auto
249 MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
250 {
251 const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
252 const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
253 const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
254 const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
255
257 c_desc_m0_n0_m1_n1_m2_n2,
263 Number<smfmac_instr.num_input_blks>{},
264 Number<smfmac_instr.group_size>{})),
267 Sequence<1>{},
268 Sequence<2>{},
269 Sequence<3>{},
270 Sequence<4>{},
271 Sequence<5>{}),
273 Sequence<1>{},
274 Sequence<2>{},
275 Sequence<3>{},
277 Sequence<7>{}));
278 }
279
280 template <typename CDesc_G_M0_N0_M1_N1_M2_N2>
281 __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
282 const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
283 {
284 const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0);
285 const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1);
286 const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2);
287 const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3);
288 const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4);
289
291 c_desc_g_m0_n0_m1_n1_m2_n2,
298 smfmac_instr.num_input_blks,
299 smfmac_instr.group_size)),
300 make_pass_through_transform(smfmac_instr.num_threads_per_blk)),
302 Sequence<1>{},
303 Sequence<2>{},
304 Sequence<3>{},
305 Sequence<4>{},
306 Sequence<5>{},
307 Sequence<6>{}),
309 Sequence<1>{},
310 Sequence<2>{},
311 Sequence<3>{},
312 Sequence<4>{},
314 Sequence<8>{}));
315 }
316
317 __device__ static constexpr index_t GetRegSizePerXdlops()
318 {
319 return MPerXdlops * NPerXdlops / smfmac_instr.wave_size;
320 }
321
322 __device__ static constexpr index_t GetWaveSize() { return smfmac_instr.wave_size; }
323
324 template <class FloatA, class FloatB, class Idx, class FloatC>
325 __device__ void
326 Run(const FloatA& p_a_wave, const FloatB& p_b_wave, const Idx& idx, FloatC& p_c_thread) const
327 {
329 "base base_type must be half or bfloat16!");
330
331 static_for<0, KPack / smfmac_instr.k_per_blk, 1>{}([&](auto k) {
332 smfmac_instr.template run<MPerXdlops, NPerXdlops, k % 4>(
333 p_a_wave[k], p_b_wave[k], idx[k / 4], p_c_thread);
334 });
335 }
336
337 __device__ static auto GetLaneId() { return get_thread_local_1d_id() % smfmac_instr.wave_size; }
338
339 __device__ static auto GetBlkIdx()
340 {
341 const auto laneId = GetLaneId();
342
343 constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
345 make_tuple(1, smfmac_instr.num_input_blks, smfmac_instr.num_threads_per_blk))),
348
349 const auto blk_idx =
350 threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));
351
352 const auto blk_id = blk_idx[I1];
353 const auto blk_td = blk_idx[I2];
354
355 return make_tuple(blk_id, blk_td);
356 }
357
358 __host__ __device__ static auto CalculateAThreadOriginDataIndex()
359 {
360 const auto laneId = GetLaneId();
361 const auto blk_idx = GetBlkIdx();
362
363 const auto blk_id = blk_idx[I0];
364 const auto blk_td = blk_idx[I1];
365
366 if constexpr(smfmac_instr.is_k_reduction)
367 {
368 return make_tuple(blk_id, blk_td);
369 }
370 else
371 {
372 return make_tuple(0, laneId);
373 }
374 }
375
376 __host__ __device__ static auto CalculateBThreadOriginDataIndex()
377 {
378 const auto laneId = GetLaneId();
379 const auto blk_idx = GetBlkIdx();
380
381 const auto blk_id = blk_idx[I0];
382 const auto blk_td = blk_idx[I1];
383
384 if constexpr(smfmac_instr.is_k_reduction)
385 {
386 return make_tuple(blk_id, blk_td);
387 }
388 else
389 {
390 return make_tuple(0, laneId);
391 }
392 }
393
394 __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
395 {
396 const auto blk_idx = GetBlkIdx();
397
398 const auto blk_id = blk_idx[I0];
399 const auto blk_td = blk_idx[I1];
400
401 index_t n_offset = blk_i * smfmac_instr.n_per_blk + blk_td;
402 index_t m_offset = xdlops_i * smfmac_instr.m_per_blk + blk_id * smfmac_instr.group_size;
403
404 return CIndex{m_offset, n_offset};
405 }
406
407 __device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */)
408 {
409 const auto blk_idx = GetBlkIdx();
410
411 const auto blk_id = blk_idx[I0];
412 const auto blk_td = blk_idx[I1];
413
414 return CIndex4D{I0, blk_id, I0, blk_td};
415 }
416
419
420 static constexpr auto smfmac_instr = smfmac.selected_smfmac;
421
422 static constexpr auto KPerXdlops = smfmac.GetKPerXdlops();
423 static constexpr auto K1PerXdlops = smfmac.GetK1PerXdlops();
424 static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
425
426 __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
427 {
428 return make_tuple(
430 }
431};
432
433} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
SmfmacInstr
Definition smfmac_xdlops_gemm.hpp:13
@ smfmac_f32_16x16x32bf16
Definition smfmac_xdlops_gemm.hpp:16
@ smfmac_f32_16x16x32f16
Definition smfmac_xdlops_gemm.hpp:14
@ smfmac_f32_32x32x16f16
Definition smfmac_xdlops_gemm.hpp:15
@ smfmac_f32_32x32x16bf16
Definition smfmac_xdlops_gemm.hpp:17
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition utility/sequence.hpp:43
Definition smfmac_xdlops_gemm.hpp:140
__host__ __device__ constexpr SmfmacSelector()
Definition smfmac_xdlops_gemm.hpp:174
static constexpr index_t GetKPerXdlops()
Definition smfmac_xdlops_gemm.hpp:200
static constexpr auto GetSmfmac()
static constexpr auto selected_smfmac
Definition smfmac_xdlops_gemm.hpp:171
static constexpr index_t GetK1PerXdlops()
Definition smfmac_xdlops_gemm.hpp:206
static __device__ auto GetLaneId()
Definition smfmac_xdlops_gemm.hpp:337
static __device__ constexpr index_t GetNumBlks()
Definition smfmac_xdlops_gemm.hpp:226
static constexpr auto K0PerXdlops
Definition smfmac_xdlops_gemm.hpp:424
static constexpr auto I2
Definition smfmac_xdlops_gemm.hpp:218
static __device__ CIndex4D GetBeginOfThreadBlk4D(index_t, index_t)
Definition smfmac_xdlops_gemm.hpp:407
static __device__ constexpr index_t GetWaveSize()
Definition smfmac_xdlops_gemm.hpp:322
MultiIndex< 4 > CIndex4D
Definition smfmac_xdlops_gemm.hpp:224
static constexpr auto I0
Definition smfmac_xdlops_gemm.hpp:216
static constexpr auto I5
Definition smfmac_xdlops_gemm.hpp:221
__device__ void Run(const FloatA &p_a_wave, const FloatB &p_b_wave, const Idx &idx, FloatC &p_c_thread) const
Definition smfmac_xdlops_gemm.hpp:326
static constexpr auto K1PerXdlops
Definition smfmac_xdlops_gemm.hpp:423
static constexpr auto I1
Definition smfmac_xdlops_gemm.hpp:217
__host__ static __device__ constexpr auto GetCM0M1M2NThreadBlkLengths()
Definition smfmac_xdlops_gemm.hpp:426
static __device__ auto GetBlkIdx()
Definition smfmac_xdlops_gemm.hpp:339
MultiIndex< 2 > CIndex
Definition smfmac_xdlops_gemm.hpp:223
__host__ __device__ constexpr SparseXdlopsGemm()
Definition smfmac_xdlops_gemm.hpp:234
static constexpr auto I4
Definition smfmac_xdlops_gemm.hpp:220
static __device__ CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
Definition smfmac_xdlops_gemm.hpp:394
__host__ static __device__ auto CalculateAThreadOriginDataIndex()
Definition smfmac_xdlops_gemm.hpp:358
static constexpr auto smfmac_instr
Definition smfmac_xdlops_gemm.hpp:420
static constexpr auto smfmac
Definition smfmac_xdlops_gemm.hpp:417
static constexpr auto KPerXdlops
Definition smfmac_xdlops_gemm.hpp:422
static __device__ constexpr index_t GetNumXdlops()
Definition smfmac_xdlops_gemm.hpp:228
__host__ static __device__ constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition smfmac_xdlops_gemm.hpp:249
static __device__ constexpr index_t GetRegSizePerXdlops()
Definition smfmac_xdlops_gemm.hpp:317
__host__ static __device__ auto CalculateBThreadOriginDataIndex()
Definition smfmac_xdlops_gemm.hpp:376
static constexpr auto I3
Definition smfmac_xdlops_gemm.hpp:219
__host__ static __device__ constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_G_M0_N0_M1_N1_M2_N2 &c_desc_g_m0_n0_m1_n1_m2_n2)
Definition smfmac_xdlops_gemm.hpp:281
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition amd_smfmac.hpp:34
Definition amd_smfmac.hpp:10
Definition amd_smfmac.hpp:78
Definition amd_smfmac.hpp:56
static constexpr index_t m_per_blk
Definition smfmac_xdlops_gemm.hpp:89
static constexpr index_t group_size
Definition smfmac_xdlops_gemm.hpp:82
static constexpr index_t num_threads_per_blk
Definition smfmac_xdlops_gemm.hpp:85
static constexpr index_t num_input_blks
Definition smfmac_xdlops_gemm.hpp:87
static constexpr bool is_k_reduction
Definition smfmac_xdlops_gemm.hpp:92
__device__ void run(const FloatA &a, const FloatB &b, const index_t &idx, FloatC &reg_c) const
Definition smfmac_xdlops_gemm.hpp:100
static constexpr index_t num_groups_per_blk
Definition smfmac_xdlops_gemm.hpp:83
static constexpr index_t num_regs_per_blk
Definition smfmac_xdlops_gemm.hpp:84
static constexpr index_t k_per_blk
Definition smfmac_xdlops_gemm.hpp:91
static constexpr index_t wave_size
Definition smfmac_xdlops_gemm.hpp:86
static constexpr index_t num_output_blks
Definition smfmac_xdlops_gemm.hpp:88
static constexpr index_t n_per_blk
Definition smfmac_xdlops_gemm.hpp:90
static constexpr index_t m_per_blk
Definition smfmac_xdlops_gemm.hpp:33
static constexpr index_t num_threads_per_blk
Definition smfmac_xdlops_gemm.hpp:29
static constexpr index_t group_size
Definition smfmac_xdlops_gemm.hpp:26
static constexpr index_t num_output_blks
Definition smfmac_xdlops_gemm.hpp:32
static constexpr bool is_k_reduction
Definition smfmac_xdlops_gemm.hpp:36
static constexpr index_t wave_size
Definition smfmac_xdlops_gemm.hpp:30
static constexpr index_t num_groups_per_blk
Definition smfmac_xdlops_gemm.hpp:27
static constexpr index_t k_per_blk
Definition smfmac_xdlops_gemm.hpp:35
__device__ void run(const FloatA &a, const FloatB &b, const index_t &idx, FloatC &reg_c) const
Definition smfmac_xdlops_gemm.hpp:44
static constexpr index_t n_per_blk
Definition smfmac_xdlops_gemm.hpp:34
static constexpr index_t num_input_blks
Definition smfmac_xdlops_gemm.hpp:31
static constexpr index_t num_regs_per_blk
Definition smfmac_xdlops_gemm.hpp:28
static constexpr index_t n_per_blk
Definition smfmac_xdlops_gemm.hpp:118
static constexpr index_t group_size
Definition smfmac_xdlops_gemm.hpp:110
static constexpr index_t num_groups_per_blk
Definition smfmac_xdlops_gemm.hpp:111
static constexpr bool is_k_reduction
Definition smfmac_xdlops_gemm.hpp:120
static constexpr index_t k_per_blk
Definition smfmac_xdlops_gemm.hpp:119
static constexpr index_t num_regs_per_blk
Definition smfmac_xdlops_gemm.hpp:112
static constexpr index_t num_output_blks
Definition smfmac_xdlops_gemm.hpp:116
static constexpr index_t num_threads_per_blk
Definition smfmac_xdlops_gemm.hpp:113
static constexpr index_t num_input_blks
Definition smfmac_xdlops_gemm.hpp:115
__device__ void run(const FloatA &a, const FloatB &b, const index_t &idx, FloatC &reg_c) const
Definition smfmac_xdlops_gemm.hpp:128
static constexpr index_t wave_size
Definition smfmac_xdlops_gemm.hpp:114
static constexpr index_t m_per_blk
Definition smfmac_xdlops_gemm.hpp:117
static constexpr bool is_k_reduction
Definition smfmac_xdlops_gemm.hpp:64
static constexpr index_t m_per_blk
Definition smfmac_xdlops_gemm.hpp:61
static constexpr index_t group_size
Definition smfmac_xdlops_gemm.hpp:54
static constexpr index_t n_per_blk
Definition smfmac_xdlops_gemm.hpp:62
static constexpr index_t wave_size
Definition smfmac_xdlops_gemm.hpp:58
static constexpr index_t num_threads_per_blk
Definition smfmac_xdlops_gemm.hpp:57
static constexpr index_t num_output_blks
Definition smfmac_xdlops_gemm.hpp:60
static constexpr index_t num_input_blks
Definition smfmac_xdlops_gemm.hpp:59
static constexpr index_t num_groups_per_blk
Definition smfmac_xdlops_gemm.hpp:55
static constexpr index_t num_regs_per_blk
Definition smfmac_xdlops_gemm.hpp:56
__device__ void run(const FloatA &a, const FloatB &b, const index_t &idx, FloatC &reg_c) const
Definition smfmac_xdlops_gemm.hpp:72
static constexpr index_t k_per_blk
Definition smfmac_xdlops_gemm.hpp:63
Definition smfmac_xdlops_gemm.hpp:21
Definition functional2.hpp:33