blockwise_gemm_xdlops.hpp Source File

blockwise_gemm_xdlops.hpp Source File#

Composable Kernel: blockwise_gemm_xdlops.hpp Source File
blockwise_gemm_xdlops.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11
12namespace ck {
13
// Free helper taking a tile descriptor type TileDesc_K0_MN_K1 plus three compile-time
// sizes (MNXdlPerWave, MNWaves, MNPerXdl). The function argument is unnamed: only its
// type is used, and lengths are read from a default-constructed TileDesc_K0_MN_K1.
// NOTE(review): the statement that builds the return value is only partly visible in
// this listing (source lines 21 and 23-27 are elided), so how the template sizes, K0 and
// K1 are combined cannot be confirmed from here; judging by the name, the result is
// presumably a 4-D [MN0, MN1, MN2, K] descriptor -- confirm against the real header.
14template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
15__host__ __device__ static constexpr auto
16MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
17{
// K0 and K1 are the lengths of dimensions 0 and 2 of the tile descriptor.
18 constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
19 constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
20
22 TileDesc_K0_MN_K1{},
28}
29
30template <index_t BlockSize,
31 typename FloatA,
32 typename FloatB,
33 typename FloatAcc,
34 typename AK0MK1BlockDesc,
35 typename BK0NK1BlockDesc,
36 index_t MPerXDL,
37 index_t NPerXDL,
38 index_t MRepeat,
39 index_t NRepeat,
40 index_t KPack,
41 typename ComputeTypeA = FloatA,
42 typename ComputeTypeB = FloatB>
44{
45 static constexpr auto I0 = Number<0>{};
46 static constexpr auto I1 = Number<1>{};
47 static constexpr auto I2 = Number<2>{};
48 static constexpr auto I3 = Number<3>{};
49
51
56
57 static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
58 static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
59 static constexpr index_t KPerBlock =
60 BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
61
62 static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
63 static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
64 static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
65 static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
66
67 static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
68 static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
69 static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
70
71 static constexpr auto xdlops_gemm =
72 XdlopsGemm<ComputeTypeA, MPerXDL, NPerXDL, KPack, ComputeTypeB, false, false>{};
73
74 static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
75
76 StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
77 FloatAcc,
78 MRepeat * NRepeat,
79 xdlops_gemm.GetRegSizePerXdlops(),
80 true>
82
// Returns a non-const reference to the member c_thread_buf_.
83 __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
84
// Returns the index produced by pushing this thread's id
// (ThisThreadBlock::GetThreadId()) through threadid_to_wave_idx_adaptor.
// Callers in this struct read element I0 of the result as the wave id along M and
// element I1 as the wave id along N.
// NOTE(review): the adaptor's transform arguments (source lines 90-92) are elided in
// this listing, so the exact thread-id decomposition cannot be confirmed from here.
85 __device__ static auto GetWaveIdx()
86 {
87 const index_t thread_id = ThisThreadBlock::GetThreadId();
88
89 constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
93
94 return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
95 }
96
// Returns this thread's A origin as a 4-component tuple:
//   (0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0])
// where waveId_m is element I0 of GetWaveIdx() and xdlops_a_idx is the result of
// xdlops_gemm.CalculateAThreadOriginDataIndex().
97 __device__ static auto CalculateAThreadOriginDataIndex()
98 {
99 const auto wave_idx = GetWaveIdx();
100
101 const auto waveId_m = wave_idx[I0];
102
103 const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
104
// The last component scales xdlops_a_idx[I0] by KPerThread; the first is fixed at 0.
105 return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
106 }
107
// Returns this thread's B origin as a 4-component tuple:
//   (0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0])
// where waveId_n is element I1 of GetWaveIdx() and xdlops_b_idx is the result of
// xdlops_gemm.CalculateBThreadOriginDataIndex().
108 __device__ static auto CalculateBThreadOriginDataIndex()
109 {
110 const auto wave_idx = GetWaveIdx();
111
112 const auto waveId_n = wave_idx[I1];
113
114 const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
115
// The last component scales xdlops_b_idx[I0] by KPerThread; the first is fixed at 0.
116 return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
117 }
118
119 template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
120 __device__ static auto
122 {
123 const auto wave_idx = GetWaveIdx();
124
125 const auto waveId_m = wave_idx[I0];
126 const auto waveId_n = wave_idx[I1];
127
128 const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
129
130 constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
134
135 constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
139
140 const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
141 make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
142 const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
143 make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
144
145 return make_tuple(c_thread_m, c_thread_n);
146 }
147
// Returns an 8-component tuple for this thread's C origin:
//   (Number<m0>, Number<n0>, waveId_m, waveId_n, blk_idx[I0..I3])
// where waveId_m / waveId_n are elements I0 / I1 of GetWaveIdx() and blk_idx is the
// result of xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i).
// NOTE(review): the line carrying the function name and parameter list (source line
// 150) is elided in this listing; m0, n0, xdlops_i and blk_i are the template values.
148 template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
149 __device__ static auto
151 {
152 const auto wave_idx = GetWaveIdx();
153
154 const auto waveId_m = wave_idx[I0];
155 const auto waveId_n = wave_idx[I1];
156
157 const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
158
// m0 and n0 are returned as compile-time Number<> constants; the rest are runtime values.
159 return make_tuple(Number<m0>{},
160 Number<n0>{},
161 waveId_m,
162 waveId_n,
163 blk_idx[I0],
164 blk_idx[I1],
165 blk_idx[I2],
166 blk_idx[I3]);
167 }
168
170 {
171 static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
172 BK0NK1BlockDesc::IsKnownAtCompileTime(),
173 "wrong! Desc should be known at compile-time");
174
176 "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
177
178 static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
179 "wrong!");
181 {
183 "ComputeTypeA and ComputeTypeB must be same when one of them is tf32");
184 }
185 }
186
// Builds the per-thread C descriptor from the length tuple
//   (MRepeat, NRepeat, 1, 1, M0, M1, M2, N)
// where (M0, M1, M2, N) are the four entries of
// xdlops_gemm.GetCM0M1M2NThreadBlkLengths() and I1 is Number<1>.
// NOTE(review): the line that starts the return statement (source line 196) is elided
// in this listing, so the descriptor factory being called cannot be confirmed here.
187 __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
188 {
189 constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
190
191 constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
192 constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
193 constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
194 constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
195
197 make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
198 }
199
// Same length tuple as the non-G thread descriptor, with one extra leading length of 1:
//   (1, MRepeat, NRepeat, 1, 1, M0, M1, M2, N)
// where (M0, M1, M2, N) come from xdlops_gemm.GetCM0M1M2NThreadBlkLengths().
// NOTE(review): the line that starts the return statement (source line 209) is elided
// in this listing, so the descriptor factory being called cannot be confirmed here.
200 __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
201 {
202 constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
203
204 constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
205 constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
206 constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
207 constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
208
210 make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
211 }
212
213 __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
214 {
215 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
221 Number<NPerXDL>{}));
222
223 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
224 }
225
226 __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
227 {
228 constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
235 Number<NPerXDL>{}));
236
237 return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
238 c_block_desc_g_m0_n0_m1_n1_m2_n2);
239 }
240
// Turns a 2-D C grid descriptor (lengths M = dim 0, N = dim 1) into the descriptor
// returned by xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2.
// M is unmerged into (M / (MWaves * MPerXDL), MWaves, MPerXDL) and N into
// (N / (NWaves * NPerXDL), NWaves, NPerXDL) before being handed to xdlops_gemm.
// NOTE(review): the quotients are plain divisions and no divisibility check is visible
// in this function -- presumably callers guarantee M and N are multiples of
// MWaves * MPerXDL and NWaves * NPerXDL; confirm. The dimension-id arguments of
// transform_tensor_descriptor (source lines 252-253) are elided in this listing.
241 template <typename CGridDesc_M_N>
242 __host__ __device__ static constexpr auto
243 MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
244 {
245 const auto M = c_grid_desc_m_n.GetLength(I0);
246 const auto N = c_grid_desc_m_n.GetLength(I1);
247
248 const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
249 c_grid_desc_m_n,
250 make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
251 make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
254
255 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
256 }
257
// Turns a 3-D C grid descriptor (lengths G = dim 0, M = dim 1, N = dim 2) into the
// descriptor returned by xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2.
// M is unmerged into (M / (MWaves * MPerXDL), MWaves, MPerXDL) and N into
// (N / (NWaves * NPerXDL), NWaves, NPerXDL).
// NOTE(review): G is read here but the line that uses it (source line 268) and the
// dimension-id arguments (source lines 271-272) are elided in this listing; G is
// presumably forwarded unchanged -- confirm. No divisibility check on M or N is
// visible in this function.
258 template <typename CGridDesc_G_M_N>
259 __host__ __device__ static constexpr auto
260 MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
261 {
262 const auto G = c_grid_desc_g_m_n.GetLength(I0);
263 const auto M = c_grid_desc_g_m_n.GetLength(I1);
264 const auto N = c_grid_desc_g_m_n.GetLength(I2);
265
266 const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
267 c_grid_desc_g_m_n,
269 make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
270 make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
273
274 return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
275 c_grid_desc_g_m0_n0_m1_n1_m2_n2);
276 }
277
289
301
304
305 template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
306 __device__ void Run(const ABlockBuffer& a_block_buf,
307 const BBlockBuffer& b_block_buf,
308 CThreadBuffer& c_thread_buf) const
309 {
311 a_thread_desc_.GetElementSpaceSize());
313 b_thread_desc_.GetElementSpaceSize());
314
315 static_for<0, MRepeat, 1>{}([&](auto m0) {
316 // read A
318 make_tuple(m0, I0, I0, I0),
319 a_block_buf,
321 make_tuple(I0, I0, I0, I0),
322 a_thread_buf);
323
324 static_for<0, NRepeat, 1>{}([&](auto n0) {
325 // read B
327 make_tuple(n0, I0, I0, I0),
328 b_block_buf,
330 make_tuple(I0, I0, I0, I0),
331 b_thread_buf);
332
336
337 static_for<0, KPack, 1>{}([&](auto i) {
338 a_thread_vec.template AsType<ElementDataTypeA>()(i) = a_thread_buf
339 [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
340 b_thread_vec.template AsType<ElementDataTypeB>()(i) = b_thread_buf
341 [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
342 });
343
344 using mfma_input_type_a =
345 typename vector_type<ElementDataTypeA, xdlops_gemm.K1PerXdlops>::type;
346 using mfma_input_type_b =
347 typename vector_type<ElementDataTypeB, xdlops_gemm.K1PerXdlops>::type;
348
349 constexpr index_t c_offset =
350 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
351
352 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type_a>(),
353 b_thread_vec.template AsType<mfma_input_type_b>(),
354 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
355 });
356 });
357 });
358 }
359
360 protected:
361 // A[M0, M1, M2, KPerThread]
362 static constexpr auto a_thread_desc_ =
364
365 // B[N0, N1, N2, KPerThread]
366 static constexpr auto b_thread_desc_ =
368
369 // C[M, N, NumRegXdlops]
371 make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
372
375 decltype(a_block_desc_m0_m1_m2_k),
376 decltype(a_thread_desc_),
379 3,
380 A_K1,
381 A_K1>;
382
385 decltype(b_block_desc_n0_n1_n2_k),
386 decltype(b_thread_desc_),
389 3,
390 B_K1,
391 B_K1>;
392
395};
396
397// Note: To facilitate the inter-wave loop scheduler, we need to explicitly set the macro
398// CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1, as a few intrinsics are not yet available in
399// the latest ROCm release. For unsupported compilers, the inter-wave loop scheduler falls back
400// to the default loop scheduler, which is given by the macro CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0.
401template <index_t BlockSize,
402 typename FloatA,
403 typename FloatB,
404 typename FloatAcc,
405 typename AK0MK1BlockDesc,
406 typename BK0NK1BlockDesc,
407 index_t MPerXDL,
408 index_t NPerXDL,
409 index_t MRepeat,
410 index_t NRepeat,
411 index_t KPack,
412 typename ComputeTypeA = FloatA,
413 typename ComputeTypeB = FloatB,
417 FloatA,
418 FloatB,
419 FloatAcc,
420 AK0MK1BlockDesc,
421 BK0NK1BlockDesc,
422 MPerXDL,
423 NPerXDL,
424 MRepeat,
425 NRepeat,
426 KPack,
427 ComputeTypeA,
428 ComputeTypeB>
429{
431 FloatA,
432 FloatB,
433 FloatAcc,
434 AK0MK1BlockDesc,
435 BK0NK1BlockDesc,
436 MPerXDL,
437 NPerXDL,
438 MRepeat,
439 NRepeat,
440 KPack,
441 ComputeTypeA,
442 ComputeTypeB>;
443
444#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
446 using Base::A_K1;
448 using Base::B_K1;
453 using Base::I0;
454 using Base::I1;
455 using Base::KPerThread;
456 using Base::xdlops_gemm;
457
458 using ElementDataTypeA =
460 using ElementDataTypeB =
462
463 static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
464
465 // 2-wave optimized blockwise gemm
466 template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
467 __device__ void Run(const ABlockBuffer& a_block_buf,
468 const BBlockBuffer& b_block_buf,
469 CThreadBuffer& c_thread_buf) const
470 {
472 a_thread_desc_.GetElementSpaceSize());
474 b_thread_desc_.GetElementSpaceSize());
475
477 static_for<0, MRepeat, 1>{}([&](auto m0) {
478 // read A
480 make_tuple(m0, I0, I0, k),
481 a_block_buf,
483 make_tuple(m0, I0, I0, I0),
484 a_thread_buf);
485 });
486 static_for<0, NRepeat, 1>{}([&](auto n0) {
487 // read B
489 make_tuple(n0, I0, I0, k),
490 b_block_buf,
492 make_tuple(n0, I0, I0, I0),
493 b_thread_buf);
494 });
495 __builtin_amdgcn_sched_barrier(0);
496 // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster except the
497 // first, as we can shorten the non-MAC cluster a bit and there's no observable negative
498 // impact. The desired effect is waves in a workgroup executing MAC in sync. This avoids
499 // some out-of-sync waves hijacking MAC resources from other workgroups and reducing the
500 // chance of latency hiding by waiting for the rest of the workgroup at the eventual
501 // sync point.
502 if constexpr(k.value != 0 || KPerInnerLoop == KPerThread)
503 {
504#ifdef __gfx12__
505 asm volatile("\
506 s_barrier_signal -1 \n \
507 s_barrier_wait -1 \
508 " ::);
509#else
510 asm volatile("s_barrier" ::);
511#endif
512 __builtin_amdgcn_sched_barrier(0);
513 }
514 static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
515 static_for<0, MRepeat, 1>{}([&](auto m0) {
516 static_for<0, NRepeat, 1>{}([&](auto n0) {
517 vector_type<ElementDataTypeA, KPack> a_thread_vec;
518 vector_type<ElementDataTypeB, KPack> b_thread_vec;
519
520 static_for<0, KPack, 1>{}([&](auto i) {
521 a_thread_vec.template AsType<ElementDataTypeA>()(i) =
522 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
523 make_tuple(m0, 0, 0, k_ + i))>{}];
524 b_thread_vec.template AsType<ElementDataTypeB>()(i) =
525 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
526 make_tuple(n0, 0, 0, k_ + i))>{}];
527 });
528
529 using mfma_input_type_a =
530 typename vector_type<ElementDataTypeA, xdlops_gemm.K1PerXdlops>::type;
531 using mfma_input_type_b =
532 typename vector_type<ElementDataTypeB, xdlops_gemm.K1PerXdlops>::type;
533
534 constexpr index_t c_offset =
535 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
536
537 // The block_sync_lds() here performs double duty:
538 // A) safeguard against data hazard, because the barrier from blockwise_gemm
539 // is moved here; B) reduce VMEM FIFO congestion by applying small delays to
540 // different wavefronts. It is performed near the end of the MAC cluster to
541 // minimize the lgkmcnt penalty.
542 if constexpr(k.value == KPerThread - KPerInnerLoop &&
543 k_.value == KPerInnerLoop - KPack && m0.value == MRepeat - 1 &&
544 n0.value == NRepeat - 1)
545 {
546 __builtin_amdgcn_sched_barrier(0);
548 __builtin_amdgcn_sched_barrier(0);
549 }
550
551 // TODO: insert setprio in a more precise manner, since we
552 // could have more than one MFMA instruction in a single call
553 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type_a>(),
554 b_thread_vec.template AsType<mfma_input_type_b>(),
555 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
556 if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
557 {
558 __builtin_amdgcn_sched_barrier(0);
559 __builtin_amdgcn_s_setprio(1);
560 __builtin_amdgcn_sched_barrier(0);
561 }
562 });
563 });
564 });
565 __builtin_amdgcn_sched_barrier(0);
566 __builtin_amdgcn_s_setprio(0);
567 __builtin_amdgcn_sched_barrier(0);
568 });
569 }
570
571 protected:
572 // A[M0, M1, M2, KPerInnerLoop]
575
576 // B[N0, N1, N2, KPerInnerLoop]
579
580 using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
582 decltype(a_block_desc_m0_m1_m2_k),
583 decltype(a_thread_desc_),
584 Sequence<1, 1, 1, KPerInnerLoop>,
585 Sequence<0, 1, 2, 3>,
586 3,
587 A_K1,
588 A_K1>;
589
590 using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
592 decltype(b_block_desc_n0_n1_n2_k),
593 decltype(b_thread_desc_),
594 Sequence<1, 1, 1, KPerInnerLoop>,
595 Sequence<0, 1, 2, 3>,
596 3,
597 B_K1,
598 B_K1>;
599
602
603#endif // #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
604};
605
606template <index_t BlockSize,
607 typename FloatA,
608 typename FloatB,
609 typename FloatAcc,
610 typename AK0MK1BlockDesc,
611 typename BK0NK1BlockDesc,
612 index_t MPerXDL,
613 index_t NPerXDL,
614 index_t MRepeat,
615 index_t NRepeat,
616 index_t KPack,
617 LoopScheduler LoopSched,
618 typename ComputeTypeA = FloatA,
619 typename ComputeTypeB = FloatB>
621{
622 if constexpr(LoopSched == LoopScheduler::Default)
623 {
625 FloatA,
626 FloatB,
627 FloatAcc,
628 AK0MK1BlockDesc,
629 BK0NK1BlockDesc,
630 MPerXDL,
631 NPerXDL,
632 MRepeat,
633 NRepeat,
634 KPack,
635 ComputeTypeA,
636 ComputeTypeB>{};
637 }
638 else if constexpr(LoopSched == LoopScheduler::Interwave)
639 {
641 BlockSize,
642 FloatA,
643 FloatB,
644 FloatAcc,
645 AK0MK1BlockDesc,
646 BK0NK1BlockDesc,
647 MPerXDL,
648 NPerXDL,
649 MRepeat,
650 NRepeat,
651 KPack,
652 ComputeTypeA,
653 ComputeTypeB,
655 }
656};
657
667
668template <
669 index_t BlockSize,
670 typename FloatAB,
671 typename FloatAcc,
672 typename ATileDesc,
673 typename BTileDesc,
674 typename AMmaTileDesc,
675 typename BMmaTileDesc,
676 index_t MPerBlock,
677 index_t NPerBlock,
678 index_t KPerBlock,
679 index_t MPerXDL,
680 index_t NPerXDL,
681 index_t MRepeat,
682 index_t NRepeat,
683 index_t KPack,
684 bool TransposeC = false,
685 index_t AMmaKStride =
686 KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops,
687 index_t BMmaKStride =
690{
691 static constexpr auto I0 = Number<0>{};
692 static constexpr auto I1 = Number<1>{};
693 static constexpr auto I2 = Number<2>{};
694 static constexpr auto I3 = Number<3>{};
695
697
698 static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
699 static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
700 static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
701
702 static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
703 static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
704 static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
705 static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
706
709
710 static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
711
712 static_assert(KPerThread % KPack == 0,
713 "Wrong KPack setting; try increasing KPerThread or decreasing KPack");
714
716 FloatAcc,
717 MRepeat * NRepeat,
718 xdlops_gemm.GetRegSizePerXdlops(),
719 true>
721
// Returns a non-const reference to the member c_thread_buf_.
722 __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
723
// Returns the index produced by pushing this thread's id
// (ThisThreadBlock::GetThreadId()) through threadid_to_wave_idx_adaptor.
// Callers in this struct read element I0 of the result as the wave id along M and
// element I1 as the wave id along N.
// NOTE(review): the adaptor's transform arguments (source lines 729-731) are elided in
// this listing, so the exact thread-id decomposition cannot be confirmed from here.
724 __device__ static auto GetWaveIdx()
725 {
726 const index_t thread_id = ThisThreadBlock::GetThreadId();
727
728 constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
732
733 return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
734 }
735
// Returns this thread's A origin as a 4-component tuple:
//   (0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0])
// where waveId_m is element I0 of GetWaveIdx() and xdlops_a_idx is the result of
// xdlops_gemm.CalculateAThreadOriginDataIndex().
736 __device__ static auto CalculateAThreadOriginDataIndex()
737 {
738 const auto wave_idx = GetWaveIdx();
739
740 const auto waveId_m = wave_idx[I0];
741
742 const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
743
// In this struct the last component is scaled by KPack (not KPerThread).
744 return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
745 }
746
// Returns this thread's B origin as a 4-component tuple:
//   (0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0])
// where waveId_n is element I1 of GetWaveIdx() and xdlops_b_idx is the result of
// xdlops_gemm.CalculateBThreadOriginDataIndex().
747 __device__ static auto CalculateBThreadOriginDataIndex()
748 {
749 const auto wave_idx = GetWaveIdx();
750
751 const auto waveId_n = wave_idx[I1];
752
753 const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
754
// In this struct the last component is scaled by KPack (not KPerThread).
755 return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
756 }
757
758 template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
759 __device__ static auto
761 {
762 const auto wave_idx = GetWaveIdx();
763
764 const auto waveId_m = wave_idx[I0];
765 const auto waveId_n = wave_idx[I1];
766
767 const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
768
769 constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
773
774 constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
778
779 const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
780 make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
781 const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
782 make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
783
784 return make_tuple(c_thread_m, c_thread_n);
785 }
786
// Returns an 8-component tuple for this thread's C origin:
//   (m0, n0, waveId_m, waveId_n, blk_idx[I0..I3])
// where waveId_m / waveId_n are elements I0 / I1 of GetWaveIdx() and blk_idx is the
// result of xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i).
// m0 and n0 are placed in the tuple as plain index_t template values here, not wrapped
// in Number<>.
// NOTE(review): the line carrying the function name and parameter list (source line
// 789) is elided in this listing.
787 template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
788 __device__ static auto
790 {
791 const auto wave_idx = GetWaveIdx();
792
793 const auto waveId_m = wave_idx[I0];
794 const auto waveId_n = wave_idx[I1];
795
796 const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
797
798 return make_tuple(
799 m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
800 }
801
803
806 : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
807 {
808#if defined(__HIP_DEVICE_COMPILE__)
809 static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
810 "wrong! Desc should be known at compile-time");
811
813 "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
814
815 static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
816 "wrong!");
817#endif
818 }
819
820 // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
// Builds the per-thread C descriptor for the transposed output layout from the tuple
//   (MRepeat, NRepeat, 1, 1, N, M0, M1, M2)
// where (M0, M1, M2, N) are the four entries of
// xdlops_gemm.GetCM0M1M2NThreadBlkLengths(). Note the order: N comes before M0, M1, M2,
// whereas the non-transposed variant in this struct uses (..., M0, M1, M2, N).
// NOTE(review): the line that starts the return statement (source line 830) is elided
// in this listing, so the descriptor factory being called cannot be confirmed here.
821 __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
822 {
823 constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
824
825 constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
826 constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
827 constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
828 constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
829
831 make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
832 }
833
834 // XDL output supporting C_xdl = A_xdl * B_xdl
// Builds the per-thread C descriptor from the length tuple
//   (MRepeat, NRepeat, 1, 1, M0, M1, M2, N)
// where (M0, M1, M2, N) are the four entries of
// xdlops_gemm.GetCM0M1M2NThreadBlkLengths() and I1 is Number<1>.
// NOTE(review): the line that starts the return statement (source line 844) is elided
// in this listing, so the descriptor factory being called cannot be confirmed here.
835 __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
836 {
837 constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
838
839 constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
840 constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
841 constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
842 constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
843
845 make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
846 }
847
// Same length tuple as the non-G thread descriptor, with one extra leading length of 1:
//   (1, MRepeat, NRepeat, 1, 1, M0, M1, M2, N)
// where (M0, M1, M2, N) come from xdlops_gemm.GetCM0M1M2NThreadBlkLengths().
// NOTE(review): the line that starts the return statement (source line 857) is elided
// in this listing, so the descriptor factory being called cannot be confirmed here.
848 __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
849 {
850 constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
851
852 constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
853 constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
854 constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
855 constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
856
858 make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
859 }
860
861 // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
862 __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
863 {
864 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
870 Number<NPerXDL>{}));
871
872 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
873 }
874
875 // XDL output supporting C_xdl = A_xdl * B_xdl
876 __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
877 {
878 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
884 Number<NPerXDL>{}));
885
886 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
887 }
888
889 __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
890 {
891 constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
898 Number<NPerXDL>{}));
899
900 return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
901 c_block_desc_g_m0_n0_m1_n1_m2_n2);
902 }
903
// Turns a 2-D C grid descriptor (lengths M = dim 0, N = dim 1) into the descriptor
// returned by xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2.
// M is unmerged into (M / (MWaves * MPerXDL), MWaves, MPerXDL) and N into
// (N / (NWaves * NPerXDL), NWaves, NPerXDL) before being handed to xdlops_gemm.
// NOTE(review): the quotients are plain divisions and no divisibility check is visible
// in this function -- presumably callers guarantee M and N are multiples of
// MWaves * MPerXDL and NWaves * NPerXDL; confirm. The dimension-id arguments of
// transform_tensor_descriptor (source lines 915-916) are elided in this listing.
904 template <typename CGridDesc_M_N>
905 __host__ __device__ static constexpr auto
906 MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
907 {
908 const auto M = c_grid_desc_m_n.GetLength(I0);
909 const auto N = c_grid_desc_m_n.GetLength(I1);
910
911 const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
912 c_grid_desc_m_n,
913 make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
914 make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
917
918 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
919 }
920
// Turns a 3-D C grid descriptor (lengths G = dim 0, M = dim 1, N = dim 2) into the
// descriptor returned by xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2.
// M is unmerged into (M / (MWaves * MPerXDL), MWaves, MPerXDL) and N into
// (N / (NWaves * NPerXDL), NWaves, NPerXDL).
// NOTE(review): G is read here but the line that uses it (source line 931) and the
// dimension-id arguments (source lines 934-935) are elided in this listing; G is
// presumably forwarded unchanged -- confirm. No divisibility check on M or N is
// visible in this function.
921 template <typename CGridDesc_G_M_N>
922 __host__ __device__ static constexpr auto
923 MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
924 {
925 const auto G = c_grid_desc_g_m_n.GetLength(I0);
926 const auto M = c_grid_desc_g_m_n.GetLength(I1);
927 const auto N = c_grid_desc_g_m_n.GetLength(I2);
928
929 const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
930 c_grid_desc_g_m_n,
932 make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
933 make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
936
937 return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
938 c_grid_desc_g_m0_n0_m1_n1_m2_n2);
939 }
940
941 static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
942 static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
943
944 template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
945 __device__ void Run(const ABlockBuffer& a_block_buf,
946 const BBlockBuffer& b_block_buf,
947 CThreadBuffer& c_thread_buf) const
948 {
950 a_thread_desc_.GetElementSpaceSize());
952 b_thread_desc_.GetElementSpaceSize());
953
954 static_for<0, KPerThread / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
955 static_for<0, MRepeat, 1>{}([&](auto m0) {
956 // read A
959 a_block_buf,
961 make_tuple(I0, I0, I0, I0),
962 a_thread_buf);
963
964 static_for<0, NRepeat, 1>{}([&](auto n0) {
965 // read B
968 b_block_buf,
970 make_tuple(I0, I0, I0, I0),
971 b_thread_buf);
972 vector_type<FloatAB, KPack> a_thread_vec;
973 vector_type<FloatAB, KPack> b_thread_vec;
974
975 static_for<0, KPack, 1>{}([&](auto i) {
976 a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
977 [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i))>{}];
978 b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
979 [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i))>{}];
980 });
981
982 using mfma_input_type =
983 typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
984
985 constexpr index_t c_offset =
986 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
987
988 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
989 b_thread_vec.template AsType<mfma_input_type>(),
990 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
991 });
992 });
993 });
994 }
995
996 protected:
997 // A[M0, M1, M2, KPack]
998 static constexpr auto a_thread_desc_ =
1000
1001 // B[N0, N1, N2, KPack]
1002 static constexpr auto b_thread_desc_ =
1004
1005 // C[M, N, NumRegXdlops]
1007 make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
1008
1010 FloatAB,
1011 decltype(a_block_desc_m0_m1_m2_k),
1012 decltype(a_thread_desc_),
1015 3,
1016 A_K1,
1017 A_K1>;
1018
1020 FloatAB,
1021 decltype(b_block_desc_n0_n1_n2_k),
1022 decltype(b_thread_desc_),
1025 3,
1026 B_K1,
1027 B_K1>;
1028
1031};
1032
1033} // namespace ck
#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
Definition ck.hpp:209
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
Definition blockwise_gemm_xdlops.hpp:620
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
typename conditional< predicate, X, Y >::type conditional_t
Definition utility/functional.hpp:115
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition blockwise_gemm_smfmac_xdlops.hpp:44
static constexpr index_t KPerBlock
Definition blockwise_gemm_smfmac_xdlops.hpp:58
__host__ static __device__ constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
Definition blockwise_gemm_xdlops.hpp:278
static constexpr index_t A_K1
Definition blockwise_gemm_smfmac_xdlops.hpp:63
__host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
Definition blockwise_gemm_xdlops.hpp:169
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_xdlops.hpp:83
static constexpr auto I2
Definition blockwise_gemm_smfmac_xdlops.hpp:47
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_xdlops.hpp:260
conditional_t< is_same_v< ComputeTypeB, ck::tf32_t >, float, ComputeTypeB > ElementDataTypeB
Definition blockwise_gemm_xdlops.hpp:54
static __device__ auto CalculateBThreadOriginDataIndex()
Definition blockwise_gemm_xdlops.hpp:108
static constexpr index_t WaveSize
Definition blockwise_gemm_smfmac_xdlops.hpp:54
__host__ static __device__ constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
Definition blockwise_gemm_xdlops.hpp:290
static constexpr index_t KPerThread
Definition blockwise_gemm_smfmac_xdlops.hpp:69
static constexpr index_t B_K1
Definition blockwise_gemm_smfmac_xdlops.hpp:64
static __device__ auto CalculateAThreadOriginDataIndex()
Definition blockwise_gemm_xdlops.hpp:97
static constexpr index_t MPerBlock
Definition blockwise_gemm_smfmac_xdlops.hpp:56
StaticBufferTupleOfVector< AddressSpaceEnum::Vgpr, FloatAcc, MRepeat *NRepeat, xdlops_gemm.GetRegSizePerXdlops(), true > c_thread_buf_
Definition blockwise_gemm_smfmac_xdlops.hpp:76
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_xdlops.hpp:226
static constexpr auto b_block_desc_n0_n1_n2_k
Definition blockwise_gemm_smfmac_xdlops.hpp:295
static constexpr index_t NPerBlock
Definition blockwise_gemm_smfmac_xdlops.hpp:57
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_xdlops.hpp:150
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_xdlops.hpp:200
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_xdlops.hpp:213
static constexpr auto I0
Definition blockwise_gemm_smfmac_xdlops.hpp:45
ThreadwiseTensorSliceTransfer_v4< FloatB, ComputeTypeB, decltype(b_block_desc_n0_n1_n2_k), decltype(b_thread_desc_), Sequence< 1, 1, 1, KPerThread >, Sequence< 0, 1, 2, 3 >, 3, B_K1, B_K1 > BThreadCopy
Definition blockwise_gemm_smfmac_xdlops.hpp:440
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition blockwise_gemm_smfmac_xdlops.hpp:50
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_xdlops.hpp:243
static constexpr auto a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_smfmac_xdlops.hpp:294
static constexpr index_t NWaves
Definition blockwise_gemm_smfmac_xdlops.hpp:53
ThreadwiseTensorSliceTransfer_v4< FloatA, ComputeTypeA, decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, KPerThread >, Sequence< 0, 1, 2, 3 >, 3, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_smfmac_xdlops.hpp:430
static constexpr auto xdlops_gemm
Definition blockwise_gemm_smfmac_xdlops.hpp:66
conditional_t< is_same_v< ComputeTypeA, ck::tf32_t >, float, ComputeTypeA > ElementDataTypeA
Definition blockwise_gemm_xdlops.hpp:52
static constexpr index_t B_K0
Definition blockwise_gemm_smfmac_xdlops.hpp:62
static constexpr index_t A_K0
Definition blockwise_gemm_smfmac_xdlops.hpp:61
__device__ void Run(const ABlockBuffer &a_block_buf, const BBlockBuffer &b_block_buf, CThreadBuffer &c_thread_buf) const
Definition blockwise_gemm_xdlops.hpp:306
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_xdlops.hpp:121
static constexpr auto I3
Definition blockwise_gemm_smfmac_xdlops.hpp:48
static constexpr auto I1
Definition blockwise_gemm_smfmac_xdlops.hpp:46
static __device__ auto GetWaveIdx()
Definition blockwise_gemm_xdlops.hpp:85
static constexpr index_t MWaves
Definition blockwise_gemm_smfmac_xdlops.hpp:52
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_xdlops.hpp:187
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_xdlops.hpp:923
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_xdlops.hpp:722
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_xdlops.hpp:876
ThreadwiseTensorSliceTransfer_v4< FloatAB, FloatAB, decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, KPack >, Sequence< 0, 1, 2, 3 >, 3, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_xdlops.hpp:1009
static constexpr index_t A_K0
Definition blockwise_gemm_xdlops.hpp:702
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_xdlops.hpp:835
static constexpr auto xdlops_gemm
Definition blockwise_gemm_xdlops.hpp:707
static constexpr index_t A_K1
Definition blockwise_gemm_xdlops.hpp:704
static constexpr auto b_thread_desc_
Definition blockwise_gemm_xdlops.hpp:1002
static __device__ auto GetWaveIdx()
Definition blockwise_gemm_xdlops.hpp:724
static constexpr auto I1
Definition blockwise_gemm_xdlops.hpp:692
static constexpr index_t NWaves
Definition blockwise_gemm_xdlops.hpp:699
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition blockwise_gemm_xdlops.hpp:942
__device__ void Run(const ABlockBuffer &a_block_buf, const BBlockBuffer &b_block_buf, CThreadBuffer &c_thread_buf) const
Definition blockwise_gemm_xdlops.hpp:945
__host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Definition blockwise_gemm_xdlops.hpp:804
static constexpr index_t B_K0
Definition blockwise_gemm_xdlops.hpp:703
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_xdlops.hpp:760
static constexpr auto I2
Definition blockwise_gemm_xdlops.hpp:693
decltype(CalculateAThreadOriginDataIndex()) Tuple4
Definition blockwise_gemm_xdlops.hpp:802
static constexpr auto a_thread_desc_
Definition blockwise_gemm_xdlops.hpp:998
static constexpr auto c_thread_desc_
Definition blockwise_gemm_xdlops.hpp:1006
static constexpr auto I3
Definition blockwise_gemm_xdlops.hpp:694
static constexpr index_t WaveSize
Definition blockwise_gemm_xdlops.hpp:700
static __device__ auto CalculateAThreadOriginDataIndex()
Definition blockwise_gemm_xdlops.hpp:736
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_xdlops.hpp:862
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_xdlops.hpp:821
static __device__ auto CalculateBThreadOriginDataIndex()
Definition blockwise_gemm_xdlops.hpp:747
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_xdlops.hpp:889
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_xdlops.hpp:941
ThreadwiseTensorSliceTransfer_v4< FloatAB, FloatAB, decltype(b_block_desc_n0_n1_n2_k), decltype(b_thread_desc_), Sequence< 1, 1, 1, KPack >, Sequence< 0, 1, 2, 3 >, 3, B_K1, B_K1 > BThreadCopy
Definition blockwise_gemm_xdlops.hpp:1019
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition blockwise_gemm_xdlops.hpp:696
static constexpr index_t MWaves
Definition blockwise_gemm_xdlops.hpp:698
static constexpr index_t B_K1
Definition blockwise_gemm_xdlops.hpp:705
static constexpr index_t KPerThread
Definition blockwise_gemm_xdlops.hpp:710
AThreadCopy a_thread_copy_
Definition blockwise_gemm_xdlops.hpp:1029
StaticBufferTupleOfVector< AddressSpaceEnum::Vgpr, FloatAcc, MRepeat *NRepeat, xdlops_gemm.GetRegSizePerXdlops(), true > c_thread_buf_
Definition blockwise_gemm_xdlops.hpp:720
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_xdlops.hpp:789
static constexpr auto I0
Definition blockwise_gemm_xdlops.hpp:691
BThreadCopy b_thread_copy_
Definition blockwise_gemm_xdlops.hpp:1030
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_xdlops.hpp:848
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_xdlops.hpp:906
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1< BlockSize, FloatA, FloatB, FloatAcc, AK0MK1BlockDesc, BK0NK1BlockDesc, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack, ComputeTypeA, ComputeTypeB > Base
Definition blockwise_gemm_xdlops.hpp:430
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:75
static __device__ constexpr index_t GetNumOfThread()
Definition thread_group.hpp:15
static __device__ index_t GetThreadId()
Definition thread_group.hpp:19
Definition threadwise_tensor_slice_transfer.hpp:1260
__device__ void Run(const SrcDesc &, const SrcRefToOriginDisplacement &, const SrcBuffer &src_buf, const DstDesc &, const DstOriginIdx &, DstBuffer &dst_buf) const
Definition threadwise_tensor_slice_transfer.hpp:1293
Definition xdlops_gemm.hpp:1821
static constexpr auto K0PerXdlops
Definition xdlops_gemm.hpp:2201
Definition functional2.hpp:33
Definition dtype_vector.hpp:10