gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp Source File

gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp Source File

Composable Kernel: gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp Source File
gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7#include "ck/utility/env.hpp"
17
18namespace ck {
19
20// Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer pipelines in the same
21// kernel function. Blockers:
22// 1. Two separate declarations of __shared__ pointers are the key to making sure data accesses operate on
23// two LDS chunks.
24// 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. AB and C may not use the same LDS
25// buffer when we declare __shared__ inside blkgemmpipe.
//
// kernel_gemm_xdl_cshuffle_v3_b_preshuffle: single-LDS-buffer kernel entry point.
// Declares one __shared__ buffer of GridwiseGemm::GetSharedMemoryNumberOfByte() bytes, builds the
// split-K offsets for this blockIdx.z, and forwards to GridwiseGemm::Run. The A and C pointers are
// advanced by the split offsets. The B pointer is passed without an offset; instead k_id and the
// full K (Kt) are forwarded (the Run overload that consumes them is not visible in this extract).
// The body is compiled only for gfx9 / gfx12 targets; on other targets it only consumes karg.
26template <typename GridwiseGemm,
27 bool HasMainKBlockLoop,
28 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
29 index_t MinimumOccupancy = 1,
31__global__ void
32#if CK_USE_LAUNCH_BOUNDS
33__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
34#endif
35 // __attribute__((amdgpu_waves_per_eu(1, 1)))
36 kernel_gemm_xdl_cshuffle_v3_b_preshuffle(typename GridwiseGemm::Argument karg)
37{
38#if defined(__gfx9__) || defined(__gfx12__)
39 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
40 {
41 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
42
43 // Full K needed for matrix B. SplitKBatchOffset (below) narrows karg.K to this split's
// K range, so the original value has to be captured before it is constructed.
44 const index_t Kt = karg.K;
45
46 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
47
// NOTE(review): num_k_per_block is computed from karg.K *after* SplitKBatchOffset has narrowed it.
// For the last split, karg.K is the remainder K - KRead * (KBatch - 1), so k_id is only consistent
// across splits when that remainder equals KRead. Confirm this is intended.
48 const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K);
49 const index_t k_id = blockIdx.z * num_k_per_block;
50
51 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
52 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
53 karg.p_b_grid,
54 karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
55 p_shared,
56 karg,
57 k_id,
58 Kt);
59 }
60#else
61 ignore = karg;
62#endif // end of #if defined(__gfx9__) || defined(__gfx12__)
63}
64
// kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds: double-LDS-buffer kernel entry point.
// Same flow as the single-buffer kernel, but declares two separate __shared__ buffers (each
// GridwiseGemm::GetSharedMemoryNumberOfByte() bytes) and forwards both to GridwiseGemm::Run_2Lds.
// The body is compiled only for gfx9 / gfx12 targets; on other targets it only consumes karg.
65template <typename GridwiseGemm,
66 bool HasMainKBlockLoop,
67 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
68 index_t MinimumOccupancy = 1,
70__global__ void
71#if CK_USE_LAUNCH_BOUNDS
72__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
73#endif
74 // __attribute__((amdgpu_waves_per_eu(1, 1)))
75 kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds(typename GridwiseGemm::Argument karg)
76{
77#if defined(__gfx9__) || defined(__gfx12__)
78 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
79 {
80 // Passing two LDS pointers is the key to telling the compiler that ds_read/write
81 // operate on different LDS chunks at the same time without an ordering dependency.
82 __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
83 __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
84
85 // Full K needed for matrix B. SplitKBatchOffset (below) narrows karg.K to this split's
// K range, so the original value has to be captured before it is constructed.
86 const index_t Kt = karg.K;
87
88 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
89
// NOTE(review): as in the single-buffer kernel, num_k_per_block uses the already-narrowed karg.K;
// for the last split this is the remainder, so k_id is only uniform when K splits evenly. Confirm.
90 const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K);
91 const index_t k_id = blockIdx.z * num_k_per_block;
92
93 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
94 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
95 karg.p_b_grid,
96 karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
97 p_shared_0,
98 p_shared_1,
99 karg,
100 k_id,
101 Kt);
102 }
103#else
104 ignore = karg;
105#endif // end of #if defined(__gfx9__) || defined(__gfx12__)
106}
107
108template <typename ALayout,
109 typename BLayout,
110 typename CLayout,
111 typename ADataType,
112 typename BDataType,
113 typename AccDataType,
114 typename CShuffleDataType,
115 typename CDataType,
116 typename AElementwiseOperation,
117 typename BElementwiseOperation,
118 typename CElementwiseOperation,
120 index_t BlockSize,
121 index_t MPerBlock,
122 index_t NPerBlock,
123 index_t KPerBlock,
124 index_t AK1Value,
125 index_t BK1Value,
126 index_t MPerXdl,
127 index_t NPerXdl,
128 index_t MXdlPerWave,
129 index_t NXdlPerWave,
130 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
131 typename ABlockTransferThreadClusterArrangeOrder,
132 typename ABlockTransferSrcAccessOrder,
133 index_t ABlockTransferSrcVectorDim,
134 index_t ABlockTransferSrcScalarPerVector,
135 index_t ABlockTransferDstScalarPerVector_AK1,
136 bool AThreadTransferSrcResetCoordinateAfterRun,
137 index_t ABlockLdsExtraM,
138 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
139 typename BBlockTransferThreadClusterArrangeOrder,
140 typename BBlockTransferSrcAccessOrder,
141 index_t BBlockTransferSrcVectorDim,
142 index_t BBlockTransferSrcScalarPerVector,
143 index_t BBlockTransferDstScalarPerVector_BK1,
144 bool BThreadTransferSrcResetCoordinateAfterRun,
145 index_t BBlockLdsExtraN,
146 index_t CShuffleMXdlPerWavePerShuffle,
147 index_t CShuffleNXdlPerWavePerShuffle,
148 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
149 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
152 typename ComputeTypeA = CDataType,
153 typename ComputeTypeB = ComputeTypeA,
154 bool PermuteA = false,
155 bool PermuteB = false>
157{
// Compile-time index constants used to address descriptor dimensions.
158 static constexpr auto I0 = Number<0>{};
159 static constexpr auto I1 = Number<1>{};
160 static constexpr auto I2 = Number<2>{};
161 static constexpr auto I3 = Number<3>{};
162 static constexpr auto I4 = Number<4>{};
163 static constexpr auto I5 = Number<5>{};
164 static constexpr auto I6 = Number<6>{};
165 static constexpr auto I7 = Number<7>{};
166
167 // K0/K1 tile extents as compile-time Number<> constants (K0 = KPerBlock / K1).
168 static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
169 static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
170 static constexpr auto AK1Number = Number<AK1Value>{};
171 static constexpr auto BK1Number = Number<BK1Value>{};
172
173 // Use a single-rate mfma instruction for this special case: A (f8_t) * B (pk_i4_t).
174 // See example gemm_xdl_fp8_pk_i4_bpreshuffle_v3
175 // TODO: explore optimization opportunity by using new mfma instructions on gfx950
176 static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
177 static constexpr bool is_single_rate_mfma = true;
178 static constexpr auto is_scale_mfma = false;
// MFMA instruction selection. One template argument line is not visible in this extract
// (presumably is_single_rate_mfma; confirm).
179 static constexpr auto mfma = MfmaSelector<ComputeTypeA,
180 MPerXdl,
181 NPerXdl,
182 ComputeTypeA,
184 is_scale_mfma>{};
// KPack: K elements per MFMA step, at least lcm(AK1, BK1) and at least the selected
// instruction's k_per_blk.
185 static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk);
// KLane = (K per XDL op) / (K1 per XDL op).
186 static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops();
187
// KRepeat: number of (KLane * KPack)-sized K chunks in one KPerBlock tile.
188 static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
189 static constexpr index_t NLane = NPerXdl;
// NWave: waves along N in one block tile, NPerBlock / (NPerXdl * NXdlPerWave).
190 static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
191
193
// Elements per storage unit of ADataType / BDataType: 2 for packed types, otherwise 1.
// The selecting `if constexpr` lines are not visible here (presumably a pk_i4_t check; confirm).
// These divide element offsets and byte sizes elsewhere (e.g. the A split-K offset and LDS size).
194 static constexpr index_t APackedSize = []() {
196 return 2;
197 else
198 return 1;
199 }();
200
201 static constexpr index_t BPackedSize = []() {
203 return 2;
204 else
205 return 1;
206 }();
207
208 __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
209 {
210 return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
211 }
212
213 __host__ static auto CalculateMPadded(index_t M)
214 {
215 return math::integer_least_multiple(M, MPerBlock);
216 }
217
218 __host__ static auto CalculateNPadded(index_t N)
219 {
220 return math::integer_least_multiple(N, NPerBlock);
221 }
222
// N0 extent of the preshuffled B layout for a given N. The return expression is not
// visible in this extract; TODO confirm its definition before relying on it.
223 __host__ __device__ static auto CalculateBN0Shuffled(index_t N)
224 {
226 }
227
// K0 extent of the preshuffled B layout for a given K. The kernels multiply it by
// blockIdx.z to form k_id. The return expression is not visible in this extract; TODO confirm.
228 __host__ __device__ static auto CalculateBK0Shuffled(index_t K)
229 {
231 }
232
233 __host__ static auto CalculateKPadded(index_t K)
234 {
235 return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
236 }
237
238 __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
239 {
240 auto K_t = K_Batch * KPerBlock;
241 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
242 }
243
244 __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
245 {
246 auto K_t = K_Batch * KPerBlock;
247 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
248 }
249
250 __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
251 {
252 auto K_t = K_Batch * KPerBlock;
253 return (K + K_t - 1) / K_t * KPerBlock;
254 }
255
256 __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
257 {
258 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
259 auto K_t = K_Batch * KReadVec;
260 return (K + K_t - 1) / K_t * KReadVec;
261 }
262
263 __host__ static auto CalculateMBlock(index_t M)
264 {
265 return math::integer_divide_ceil(M, MPerBlock);
266 }
267
268 __host__ static auto CalculateNBlock(index_t N)
269 {
270 return math::integer_divide_ceil(N, NPerBlock);
271 }
272
// Re-views a K0 x MN x K1 block descriptor as the MMA tile descriptor used by the blockwise
// GEMM, parameterised by XDL tiles per wave (MNXdlPerWave), waves (MNWaves) and XDL tile
// size (MNPerXdl). K0 and K1 are read from dimensions 0 and 2 of the input descriptor.
// NOTE(review): the transform applied here is on lines not visible in this extract.
273 template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
274 __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
275 {
276 constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
277 constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
278
280 TileDesc_K0_MN_K1{},
286 }
287
// Builds the global (AK0, M, AK1) view of A from its raw M x K layout, right-padding M to MPad
// and/or K to KPad depending on GemmSpec.
// NOTE(review): the layout conditions and the unmerge/pass-through transform arguments are on
// lines not visible in this extract.
288 __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
289 index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
290 {
// Raw M x K descriptor. The first branch uses strides (StrideA, 1), so K is contiguous; the
// second uses (1, StrideA), so M is contiguous (presumably row- vs column-major A; confirm).
291 const auto a_grid_desc_mraw_kraw = [&]() {
293 {
294 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
295 }
297 {
298 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
299 }
300 }();
301
302 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
303
304 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
305 GemmSpec == GemmSpecialization::MNKPadding)
306 {
307 // pad both M and K
308 const auto a_grid_desc_m_k =
309 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
311 make_right_pad_transform(K, KPad - K)),
314
315 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
316 a_grid_desc_m_k,
321
322 return a_grid_desc_ak0_m_ak1;
323 }
324 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
325 GemmSpec == GemmSpecialization::MNPadding)
326 {
327 // pad M, but not K
328 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
329 a_grid_desc_mraw_kraw,
331 make_right_pad_transform(M, MPad - M)),
334
335 return a_grid_desc_ak0_m_ak1;
336 }
337 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
338 GemmSpec == GemmSpecialization::NKPadding)
339 {
340 // pad K, but not M
341 const auto a_grid_desc_m_k = transform_tensor_descriptor(
342 a_grid_desc_mraw_kraw,
346
347 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
348 a_grid_desc_m_k,
353
354 return a_grid_desc_ak0_m_ak1;
355 }
356 else
357 {
358 // not pad M or K
359 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
360 a_grid_desc_mraw_kraw,
365
366 return a_grid_desc_ak0_m_ak1;
367 }
368 }
369
// Global descriptor for preshuffled B with lengths (N0 / NWave, NWave, K0, WaveSize * KPack) and
// fully packed strides (the last dimension is contiguous). WaveSize is BlockSize / (MWave * NWave),
// i.e. threads per wave, assuming MWave * NWave waves per block.
// NOTE(review): the opening `return make_naive_tensor_descriptor(` line is not visible in this extract.
370 __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
371 {
372 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
373 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
// Innermost run: one KPack chunk per thread of a wave.
374 constexpr index_t NkSwizzleNumber = Number<WaveSize * KPack>{};
376 make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
377 make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
378 }
379
// Builds the global (BK0, N, BK1) view of B from its raw N x K layout, right-padding N to NPad
// and/or K to KPad depending on GemmSpec. The static_assert (partially visible) rejects padding
// for pk_i4_t. With no padding and PermuteB set, B is read as a packed
// [K / KPerBlock, N, KPerBlock / BK1, BK1] tensor.
// NOTE(review): the layout conditions and several transform arguments are on lines not visible here.
380 __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
381 index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
382 {
// Raw N x K descriptor. The first branch uses strides (1, StrideB), so N is contiguous; the
// second uses (StrideB, 1), so K is contiguous (presumably row- vs column-major B; confirm).
383 const auto b_grid_desc_nraw_kraw = [&]() {
385 {
386 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
387 }
389 {
390 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
391 }
392 }();
393
394 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
395
397 GemmSpec != GemmSpecialization::Default),
398 "pk_i4_t does not support padding");
399
400 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
401 GemmSpec == GemmSpecialization::MNKPadding)
402 {
403 // pad both N and K
404 const auto b_grid_desc_n_k =
405 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
407 make_right_pad_transform(K, KPad - K)),
410
411 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
412 b_grid_desc_n_k,
417
418 return b_grid_desc_bk0_n_bk1;
419 }
420 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
421 GemmSpec == GemmSpecialization::MNPadding)
422 {
423 // pad N, but not K
424 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
425 b_grid_desc_nraw_kraw,
427 make_right_pad_transform(N, NPad - N)),
430
431 return b_grid_desc_bk0_n_bk1;
432 }
433 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
434 GemmSpec == GemmSpecialization::MKPadding)
435 {
436 // pad K, but not N
437 const auto b_grid_desc_n_k = transform_tensor_descriptor(
438 b_grid_desc_nraw_kraw,
442
443 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
444 b_grid_desc_n_k,
449
450 return b_grid_desc_bk0_n_bk1;
451 }
452 else
453 {
454 if constexpr(!PermuteB)
455 {
456 // not pad N or K
457 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
458 b_grid_desc_nraw_kraw,
463
464 return b_grid_desc_bk0_n_bk1;
465 }
466 else
467 {
468 // Pre-shuffled Weight
469 // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1]
// NOTE(review): the K extent is derived from StrideB (BK0_ = StrideB / BK1Value), not from K;
// confirm callers pass the full K as StrideB in this mode.
470 constexpr index_t BK01 = KPerBlock / BK1Value;
471 const index_t BK0_ = StrideB / BK1Value;
472 const index_t BK00 = BK0_ / BK01;
473
474 const auto b_grid_desc_bk00_n_bk01_bk1_permute =
475 make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value));
476
477 const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
478 b_grid_desc_bk00_n_bk01_bk1_permute,
484
485 return b_grid_desc_bk0_n_bk1_permute;
486 }
487 }
488 }
489
490 template <typename ABlockDesc_AK0_M_AK1>
491 __host__ __device__ static constexpr auto
492 MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
493 {
494 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
495
496 return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
497 }
498
499 template <typename BBlockDesc_BK0_N_BK1>
500 __host__ __device__ static constexpr auto
501 MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
502 {
503 // constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
504
505 // return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
506
507 return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
508 }
509
// Global M x N descriptor for C, always right-padded to MPad x NPad.
// NOTE(review): the signature/parameter line and the layout conditions are not visible in this
// extract. The first layout branch uses strides (StrideC, 1) (N contiguous), the second (1, StrideC).
510 __host__ __device__ static auto
512 {
513 const auto c_grid_desc_mraw_nraw = [&]() {
515 {
516 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
517 }
519 {
520 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
521 }
522 }();
523
524 // pad M and N
525 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
527 make_right_pad_transform(N, NPad - N)),
// Disabled GemmSpec-dependent padding. With it compiled out, M and N are padded regardless of GemmSpec.
530#if 0
531 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
532
533 if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
534 GemmSpec == GemmSpecialization::MNKPadding)
535 {
536 // pad M and N
537 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
539 make_right_pad_transform(N, NPad - N)),
542 }
543 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
544 GemmSpec == GemmSpecialization::MKPadding)
545 {
546 // pad M, but not N
548 c_grid_desc_mraw_nraw,
552 }
553 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
554 GemmSpec == GemmSpecialization::NKPadding)
555 {
556 // pad N, but not M
558 c_grid_desc_mraw_nraw,
562 }
563 else
564 {
565 // not pad M or N
566 return c_grid_desc_mraw_nraw;
567 }
568#endif
569 }
570
// Host-side GEMM problem description: raw sizes and strides, the split-K factor, and derived
// quantities. KRead, KPadded, AK0 and BK0 are per-split values computed from K and KBatch.
// NOTE(review): some initializers and the member declarations are not visible in this extract.
571 struct Problem
572 {
573 __host__ Problem(index_t M_,
574 index_t N_,
575 index_t K_,
576 index_t StrideA_,
577 index_t StrideB_,
578 index_t StrideC_,
579 index_t KBatch_)
580 : M{M_},
581 N{N_},
582 K{K_},
583 StrideA{StrideA_},
584 StrideB{StrideB_},
585 StrideC{StrideC_},
586 KBatch{KBatch_},
589 KRead{CalculateKRead(K_, KBatch_)},
590 KPadded{CalculateKPadded(K_, KBatch_)},
591 AK0{CalculateAK0Padded(K_, KBatch_)},
592 BK0{CalculateBK0Padded(K_, KBatch_)},
595 {
596 }
597
// Prints every field on one line to std::cout (std::endl also flushes).
598 __host__ void Print() const
599 {
600 std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
601 << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
602 << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
603 << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
604 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
605 << "NBlock: " << NBlock << "}" << std::endl;
606 }
607
623 };
624
625 // Argument: Problem plus the A/B/C device pointers and the split-K reduction mode.
// NOTE(review): the struct header line and one member declaration are not visible in this extract.
627 {
628 __host__ Argument(const ADataType* p_a_grid_,
629 const BDataType* p_b_grid_,
630 CDataType* p_c_grid_,
631 index_t M_,
632 index_t N_,
633 index_t K_,
634 index_t StrideA_,
635 index_t StrideB_,
636 index_t StrideC_,
637 index_t k_batch_,
638 bool is_reduce_ = false)
639 : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
640 p_a_grid{p_a_grid_},
641 p_b_grid{p_b_grid_},
642 p_c_grid{p_c_grid_},
643 is_reduce(is_reduce_)
644 {
645 }
646
// True when K is split (KBatch > 1) and partial results go to separate C slices for a later reduce.
647 __host__ __device__ inline bool IsReduceAdd() const
648 {
649 return (Problem::KBatch > 1) && is_reduce;
650 }
651
// True when K is split (KBatch > 1) and partial results are accumulated into C atomically.
652 __host__ __device__ inline bool IsAtomicAdd() const
653 {
654 return (Problem::KBatch > 1) && (!is_reduce);
655 }
656
657 const ADataType* p_a_grid;
658 const BDataType* p_b_grid;
659 CDataType* p_c_grid;
661 };
662
664 {
// Per-block split-K bookkeeping, computed from blockIdx.z:
//  - a_k_split_offset: element offset of this split's K range in A;
//  - karg.K is narrowed IN PLACE to this split's K length (KRead for all but the last split,
//    the remainder for the last one);
//  - c_reduce_offset: offset to this split's M*N slice of C when reduce-add is enabled, else 0.
// NOTE(review): the struct header, the layout conditions and the member declarations are not
// visible in this extract.
665
666 __device__ SplitKBatchOffset(Argument& karg)
667 {
// Presumably the K-contiguous A layout: offset by KRead elements, scaled down for packed types.
669 {
670 a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
671 }
// Otherwise K advances by StrideA per element.
673 {
674 a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
675 }
676
677 if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
678 {
679 karg.K = karg.KRead;
680 }
681 else
682 {
683 karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
684 }
685
686 if(karg.IsReduceAdd())
687 {
688 c_reduce_offset = blockIdx.z * karg.M * karg.N;
689 }
690 else
691 {
692 c_reduce_offset = 0;
693 }
694 }
695
698 };
699
// LDS descriptor for the A block tile (AK0 x MPerBlock x AK1), the destination of A's blockwise copy.
// Three variants:
//  - ABlockLdsExtraM: padded layout;
//  - a branch whose condition is not visible here, built with a permuting transform;
//  - ColumnMajor A: K0 rows folded (kfold) and M split/paired (mpair).
// NOTE(review): many transform argument lines are not visible in this extract.
700 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
701 {
702 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
703 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
704 // A matrix in LDS memory, dst of blockwise copy
705 if constexpr(ABlockLdsExtraM)
706 {
710 }
711 // The xor tensor transformation requests extra VGPRs and can cause register spills
712 // in some cases.
714 {
715 constexpr auto a_lds_block_desc =
718
719 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
720 a_lds_block_desc,
726
727 return a_lds_block_desc_permuted;
728 }
729 else // ColumnMajor A
730 {
731 // The kfold and mpair dimensions are not always required.
732 // More dimensions in merge_transform make it harder for the compiler to generate
733 // immediate (immarg) offsets.
734 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
735 constexpr auto M1 = MPerBlock / M0;
736
737 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
738 constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
739 constexpr auto KThreadRead = WaveSize / MPerXdl;
740 constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
741
// kfold: how many AK1 * M0-element rows fit in 128 bytes (1 if a single row already exceeds 128 bytes).
742 constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
743 ? 1
744 : 128 / (AK1Number * M0 * sizeof(ADataType));
745 constexpr auto KThreadReadPerm =
746 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
747 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
748 : KThreadRead;
749
// mpair: how many AK1 * MPerXdl-element rows fit in 128 bytes, clamped to [1, M0].
750 // 1<=mpair<=n0
751 constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
752 ? 1
753 : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
754 ? M0
755 : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
756
757 constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
761 Number<kfold * M0 / mpair>{},
763 AK1Number));
764
765 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
766 a_lds_block_desc,
771 make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
778
779 constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
780 a_lds_block_desc_permuted,
789 Sequence<1>{},
790 Sequence<2>{},
791 Sequence<3>{},
792 Sequence<4>{},
793 Sequence<5>{}),
795 Sequence<2>{},
798 Sequence<6>{},
799 Sequence<7>{}));
800
801 constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
802 a_lds_block_desc_unmerged,
805 Number<KThreadWrite / kfold / KThreadReadPerm>{},
813
814 return a_lds_block_desc_ak0_m_ak1;
815 }
816 }
817
// Destination descriptor of the threadwise B copy in Run (B is not staged through the A LDS
// buffer). NOTE(review): most of the descriptor's lengths are on lines not visible in this extract.
818 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
819 {
820 // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
823 I1,
825 Number<BK1Value>{})); // NOTE(review): confirm BK1Value equals KPack here
826 }
827
829 {
// Descriptor for one C-shuffle stage in LDS. NOTE(review): the signature line and most lengths
// are not visible in this extract.
830 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
831
832 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
836 I1,
838
839 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
840 }
841
844 BlkGemmPipelineVer,
845 BlkGemmPipeSched,
846 BlockSize,
847 ADataType,
848 BDataType,
849 ComputeTypeA,
850 AccDataType,
857 ABlockTransferSrcScalarPerVector,
858 BBlockTransferSrcScalarPerVector,
859 MPerBlock,
860 NPerBlock,
861 KPerBlock,
862 MPerXdl,
863 NPerXdl,
864 MXdlPerWave,
865 NXdlPerWave,
866 KPack>())>;
// Tail of the blockwise GEMM pipeline type chosen from BlkGemmPipelineVer / BlkGemmPipeSched and
// the tile parameters (presumably the BlockwiseGemmPipe alias used in CheckValidity and Run; its
// opening lines are not visible in this extract).
867
// Bytes of LDS required per block: the larger of the aligned A tile and the C-shuffle tile.
// The two share one allocation, so the size is their max, not their sum.
868 __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
869 {
870 // LDS allocation for A only (B is not sized here): be careful of alignment
871 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
872
873 // lds max alignment
874 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
875
876 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
877 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
878
879 // LDS allocation for C shuffle in LDS
880 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
882
883 constexpr auto c_block_size =
884 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
885
// A's byte size is divided by APackedSize because packed A types store several elements per unit.
886 return math::max(a_block_space_size_aligned * sizeof(ADataType) / APackedSize,
887 c_block_size * sizeof(CShuffleDataType));
888 }
889
891
892 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// Host-side check of whether this configuration can run karg. Returns false for unsupported
// shapes; when CK_LOGGING is enabled most failures also print a reason to std::cout.
// NOTE(review): the guard conditions of several checks (padding specialisations, layouts, the
// split-K mode) are on lines not visible in this extract.
893 __host__ static constexpr bool CheckValidity(const Argument& karg)
894 {
895 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
896 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
897 "Invalid tuning param!");
898
899 if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0)
900 {
901 return false;
902 }
903
909 {
910 if(!(karg.M % MPerBlock == 0))
911 {
912 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
913 {
914 std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
915 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
916 << std::endl;
917 }
918 return false;
919 }
920 }
921
927 {
928 if(!(karg.N % NPerBlock == 0))
929 {
930 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
931 {
932 std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
933 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
934 << std::endl;
935 }
936 return false;
937 }
938 }
939
944 {
945
// Unpadded K: every split must cover whole KPerBlock tiles.
946 auto K_t = karg.KBatch * KPerBlock;
947 if(!(karg.K % K_t == 0))
948 {
949 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
950 {
951 std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
952 << karg.K << " " << __FILE__ << ":" << __LINE__
953 << ", in function: " << __func__ << std::endl;
954 }
955 return false;
956 }
957 }
958 else
959 {
// Padded K: reject when the first KBatch - 1 splits (each KReadPadSplited long) already cover
// all of K, which would leave the last split with no work.
960 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
961 auto K_t = karg.KBatch * KReadVec;
962 auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
963 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
964 {
965 return false;
966 }
967 }
968
// Vector-access checks: the contiguous dimension of A, B and C must be a multiple of the
// corresponding ScalarPerVector.
970 {
971 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
972 {
973 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
974 {
975 std::cout << "Arg K (" << karg.K
976 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
977 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
978 << __LINE__ << ", in function: " << __func__ << std::endl;
979 }
980 return false;
981 }
982 }
983 else
984 {
985 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
986 {
987 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
988 {
989 std::cout << "Arg M (" << karg.M
990 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
991 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
992 << __LINE__ << ", in function: " << __func__ << std::endl;
993 }
994 return false;
995 }
996 }
997
999 {
1000 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1001 {
1002 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1003 {
1004 std::cout << "Arg N (" << karg.N
1005 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1006 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1007 << __LINE__ << ", in function: " << __func__ << std::endl;
1008 }
1009 return false;
1010 }
1011 }
1012 else
1013 {
1014 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1015 {
1016 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1017 {
1018 std::cout << "Arg K (" << karg.K
1019 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1020 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1021 << __LINE__ << ", in function: " << __func__ << std::endl;
1022 }
1023 return false;
1024 }
1025 }
1026
1028 {
1029 if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1030 {
1031 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1032 {
1033 std::cout << "Arg N (" << karg.N
1034 << ") value is not a multiple of "
1035 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1036 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1037 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1038 << std::endl;
1039 }
1040 return false;
1041 }
1042 }
1043 else
1044 {
1045 if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1046 {
1047 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1048 {
1049 std::cout << "Arg M (" << karg.M
1050 << ") value is not a multiple of "
1051 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1052 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1053 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1054 << std::endl;
1055 }
1056 return false;
1057 }
1058 }
1059
1064 {
1065 if(!karg.IsReduceAdd())
1066 {
// NOTE(review): this message is printed whenever CK_LOGGING is on and reduce-add is off, even
// when KBatch <= 1 and the check passes. Consider logging only inside the `KBatch > 1` branch.
1067 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1068 {
1069 std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__
1070 << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1071 }
1072 if(karg.KBatch > 1)
1073 {
1074 return false;
1075 }
1076 }
1077 }
1078
1079 // check gridwise gemm pipeline: the per-split K loop count must exceed the prefetch stages
1080 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1081
1082 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1083 {
1084 return false;
1085 }
1086
1087 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1088 return true;
1089 }
1090
1091 __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1092 {
1093 const index_t num_loop = K / KPerBlock;
1094
1095 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1096 }
1097
1098 __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1099 {
1100 const index_t num_loop = K / KPerBlock;
1101
1102 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1103 }
1104
// Re-views an M x N C grid descriptor as (MBlock, MPerBlock, NBlock, NPerBlock).
// NOTE(review): the transform arguments are on lines not visible in this extract (presumably
// unmerge transforms of M and N; confirm).
1105 template <typename CGridDesc>
1106 __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1107 const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1108 {
1109 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1110 c_grid_desc_m_n,
1115
1116 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1117 }
1118
1119 // return block_id to C matrix tile idx (m0, n0) mapping
1120 // if arch = gfx942
// NOTE(review): the active Block2CTileMap alias is not visible in this extract; the line below is a
// commented-out alternative mapping.
1122 // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
1123
1124 template <typename AGridDesc_AK0_M_K1,
1125 typename BGridDesc_BPreshuffled,
1126 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1127 bool HasMainKBlockLoop,
1128 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1129 TailNumber TailNum = TailNumber::Odd>
1130 __device__ static void Run(const ADataType* p_a_grid,
1131 const BDataType* p_b_grid,
1132 CDataType* p_c_grid,
1133 void* p_shared,
1134 const Problem& problem,
1135 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1136 const BGridDesc_BPreshuffled& b_grid_desc_bpreshuffled,
1137 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1138 c_grid_desc_mblock_mperblock_nblock_nperblock,
1139 const index_t k_id)
1140 {
1141 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1142 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1143 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1144 p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1146 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1147
1148 const AElementwiseOperation a_element_op{};
1149 // const BElementwiseOperation b_element_op{};
1150 const CElementwiseOperation c_element_op{};
1151
1152 // divide block work by [M, N]
1153 const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1154
1155 const auto block_work_idx =
1156 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1157
1158 if(!block_2_ctile_map.ValidCTileIndex(
1159 block_work_idx,
1160 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1161 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1162 {
1163 return;
1164 }
1165
1166 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1167 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1168
1169 // HACK: this force m/n_block_data_idx_on_grid into SGPR
1170 const index_t m_block_data_idx_on_grid =
1171 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1172
1173 // N0, K0, Blocksize*KPack
1174 const index_t n_block_data_idx_on_grid =
1175 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1176
1177 // A matrix in LDS memory, dst of blockwise copy
1178 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1179
1180 // B matrix in LDS memory, dst of blockwise copy
1181 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1182
1183 // A matrix blockwise copy
1184 auto a_blockwise_copy =
1186 AElementwiseOperation,
1190 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1191 ABlockTransferThreadClusterArrangeOrder,
1192 ADataType,
1193 ADataType,
1194 decltype(a_grid_desc_ak0_m_ak1),
1195 decltype(a_block_desc_ak0_m_ak1),
1196 ABlockTransferSrcAccessOrder,
1198 ABlockTransferSrcVectorDim,
1199 2,
1200 ABlockTransferSrcScalarPerVector,
1201 ABlockTransferDstScalarPerVector_AK1,
1202 1,
1203 1,
1204 AThreadTransferSrcResetCoordinateAfterRun,
1205 true,
1206 BlockwiseGemmPipe::GlobalBufferNum>(
1207 a_grid_desc_ak0_m_ak1,
1208 make_multi_index(0, m_block_data_idx_on_grid, 0),
1209 a_element_op,
1210 a_block_desc_ak0_m_ak1,
1211 make_multi_index(0, 0, 0),
1213
1214 // B matrix threadwise copy, using threadwiseTensorSliceTransfer_v2
1216 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1217
1218 auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
1219 BDataType,
1220 BDataType,
1221 decltype(b_grid_desc_bpreshuffled),
1222 decltype(b_block_desc_bk0_n_bk1),
1225 3,
1226 BBlockTransferSrcScalarPerVector,
1227 BThreadTransferSrcResetCoordinateAfterRun,
1228 true>(b_grid_desc_bpreshuffled,
1229 make_multi_index(n_block_data_idx_on_grid,
1231 k_id,
1232 KPack * (get_thread_local_1d_id() % WarpSize)));
1233
1234 // LDS allocation for A and B: be careful of alignment
1235
1236 // Cast after lds
1238 static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1239
1240 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1241 constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
1242
1243 // Blockwise GEMM pipeline
1244 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1245 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1246 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1247
1248 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1249 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1250 KPerBlock);
1251
1252 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
1253 a_block_desc_ak0_m_ak1,
1254 a_blockwise_copy,
1255 a_grid_buf,
1256 a_block_buf,
1257 a_block_slice_copy_step,
1258 b_grid_desc_bpreshuffled,
1259 b_blockwise_copy,
1260 b_grid_buf,
1261 b_block_buf,
1262 b_block_slice_copy_step,
1263 c_thread_buf,
1264 num_k_block_main_loop);
1265
1266 // shuffle C and write out
1267 {
1268 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1269 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1270 "wrong!");
1271
1272 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1273
1274 // TODO: hacky, fix it!
1275 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1276 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1277
1278 // TODO: hacky, fix it!
1279 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1280 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1281 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1282
1283 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1284 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1285 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1286 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1287 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1288 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1289 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1290 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1291
1292 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1294
1295 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1296 static_cast<CShuffleDataType*>(p_shared),
1297 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1298
1299 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1300 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1301 make_tuple(
1304 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1305 M1, // M1 = MWave
1306 M2, // M2 * M3 * M4 = MPerXdl
1307 M3,
1308 M4)),
1311 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1312 N1, // N1 = NWave
1313 N2))), // N2 = NPerXdl
1315 make_tuple(
1317
1318 // calculate origin of thread output tensor on global memory
1319 // blockwise GEMM c matrix starting index
1320 const auto c_thread_mtx_on_block =
1321 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1322
1323 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1324 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1325
1326 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1328 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1331
1332 const auto m_thread_data_on_block_idx =
1333 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1334 make_multi_index(m_thread_data_on_block));
1335
1336 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1341
1342 const auto n_thread_data_on_block_idx =
1343 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1344 make_multi_index(n_thread_data_on_block));
1345
1346 // shuffle: threadwise copy C from VGPR to LDS
1347 auto c_thread_copy_vgpr_to_lds =
1349 CShuffleDataType,
1350 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1351 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1353 Sequence<CShuffleMXdlPerWavePerShuffle,
1354 CShuffleNXdlPerWavePerShuffle,
1355 I1,
1356 I1,
1357 M2,
1358 I1,
1359 M4,
1360 I1>,
1362 7,
1363 1,
1365 1,
1366 true>{
1367 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1369 0,
1370 m_thread_data_on_block_idx[I1],
1371 n_thread_data_on_block_idx[I1],
1372 m_thread_data_on_block_idx[I2],
1373 m_thread_data_on_block_idx[I3],
1374 m_thread_data_on_block_idx[I4],
1375 n_thread_data_on_block_idx[I2]),
1377
1378 // shuffle: blockwise copy C from LDS to global
1379 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1380 ThisThreadBlock, // ThreadGroup
1381 CElementwiseOperation, // ElementwiseOperation,
1382 CGlobalMemoryDataOperation, // DstInMemOp,
1383 Sequence<1,
1384 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1385 1,
1386 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1387 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1388 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1389 CShuffleDataType, // typename SrcData,
1390 CDataType, // typename DstData,
1391 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1392 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1393 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1394 3, // index_t VectorDim,
1395 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1396 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1397 false> // bool ThreadTransferDstResetCoordinateAfterRun>
1398 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1399 make_multi_index(0, 0, 0, 0),
1400 c_grid_desc_mblock_mperblock_nblock_nperblock,
1401 make_multi_index(block_m_id, 0, block_n_id, 0),
1402 c_element_op};
1403
1404 // space filling curve for threadwise C in VGPR
1405 constexpr auto sfc_c_vgpr =
1408 Sequence<CShuffleMXdlPerWavePerShuffle,
1409 CShuffleNXdlPerWavePerShuffle,
1410 1,
1411 1,
1412 M2,
1413 1,
1414 M4,
1415 1>>{};
1416
1417 // space filling curve for shuffled blockwise C in global mem
1418 constexpr auto sfc_c_global =
1421 Sequence<1,
1422 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1423 1,
1424 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1425
1426 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1427
1428 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1429
1430 static_for<0, num_access, 1>{}([&](auto access_id) {
1431 // make sure it's safe to write to LDS
1433
1434 // each thread write its data from VGPR to LDS
1435 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1436 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1437 c_thread_buf,
1438 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1439 c_shuffle_block_buf);
1440
1441 // make sure it's safe to read from LDS
1443
1444 // each block copy its data from LDS to global
1445 c_shuffle_block_copy_lds_to_global.Run(
1446 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1447 c_shuffle_block_buf,
1448 c_grid_desc_mblock_mperblock_nblock_nperblock,
1449 c_grid_buf);
1450
1451 if constexpr(access_id < num_access - 1)
1452 {
1453 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1454
1455 // move on C
1456 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1457 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1458 }
1459 });
1460 }
1461 }
1462
// Convenience overload for the single-LDS path: builds the A, preshuffled-B and C grid
// descriptors from `problem`, then forwards everything to the descriptor-templated Run.
//
// p_a_grid / p_b_grid / p_c_grid: global pointers for A, preshuffled B and C. The kernel at
//   the top of this file offsets A and C by the SplitKBatchOffset before calling, while B
//   is passed without an offset.
// p_shared: shared-memory scratch buffer (presumably the kernel's __shared__ array).
// problem: sizes, paddings, strides and MBlock/NBlock counts used to build the descriptors.
// k_id: only forwarded here. In the kernel it is computed as
//   blockIdx.z * CalculateBK0Shuffled(karg.K).
// Kt: K used to size the preshuffled B descriptor (via CalculateBK0Shuffled). The kernel
//   reads it before constructing SplitKBatchOffset, which takes the argument by non-const
//   reference, so Kt is the full K. A's descriptor is built from problem.K instead.
1463 template <bool HasMainKBlockLoop,
1464 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1465 TailNumber TailNum = TailNumber::Odd>
1466 __device__ static void Run(const ADataType* p_a_grid,
1467 const BDataType* p_b_grid,
1468 CDataType* p_c_grid,
1469 void* p_shared,
1470 const Problem& problem,
1471 const index_t k_id,
1472 const index_t Kt)
1473 {
// B is described in its preshuffled (N0, K0, ...) form: N0 derives from problem.N and K0
// from Kt (not problem.K).
1474 index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1475 index_t BK0Shuffled = CalculateBK0Shuffled(Kt);
1476 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1477 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1478 const auto b_grid_desc_bpreshuffled =
1479 MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1480 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
1481 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
// NOTE(review): this rendered listing omits source line 1483, which holds the opening of
// the call that splits c_grid_desc_m_n into (MBlock, MPerBlock, NBlock, NPerBlock).
1482 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1484 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1485
// Forward to the descriptor-templated overload; the descriptor types become its template
// arguments.
1486 Run<decltype(a_grid_desc_ak0_m_ak1),
1487 decltype(b_grid_desc_bpreshuffled),
1488 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1489 HasMainKBlockLoop,
1490 CGlobalMemoryDataOperation,
1491 TailNum>(p_a_grid,
1492 p_b_grid,
1493 p_c_grid,
1494 p_shared,
1495 problem,
1496 a_grid_desc_ak0_m_ak1,
1497 b_grid_desc_bpreshuffled,
1498 c_grid_desc_mblock_mperblock_nblock_nperblock,
1499 k_id);
1500 }
1501
// Descriptor-templated GEMM body for the two-LDS-buffer path.
// The A tile is staged in two LDS buffers:
//   - a_block_buf_ping, backed by p_shared_0;
//   - a_block_buf_pong, backed by p_shared_1.
// They are handed to the blockwise pipeline together as a tuple, as are the two B buffers.
// The C-shuffle epilogue then reuses p_shared_0 as its staging buffer.
//
// p_a_grid / p_b_grid / p_c_grid: pointers wrapped into Global dynamic buffers below.
// p_shared_0 / p_shared_1: the two LDS regions backing the ping/pong A buffers.
// problem: only problem.M and problem.N are read here, to build the block-to-C-tile map.
// a_grid_desc_ak0_m_ak1: A viewed as (AK0, M, AK1). Its AK0 * AK1 extent determines the
//   number of KPerBlock iterations.
// b_grid_desc_bpreshuffled: preshuffled B, read with a threadwise (per-thread) copy.
// c_grid_desc_mblock_mperblock_nblock_nperblock: C split into MBlock x NBlock tiles.
// k_id: K starting coordinate placed in the origin of B's read window.
//
// NOTE(review): this rendered listing drops a number of source lines (gaps in the
// numbering, e.g. 1524, 1564, 1566-1568, 1596, 1598, 1819, 1829). The code shown here is
// therefore incomplete; consult the original header for the exact statements.
1502 template <typename AGridDesc_AK0_M_K1,
1503 typename BGridDesc_BPreshuffled,
1504 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1505 bool HasMainKBlockLoop,
1506 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1507 TailNumber TailNum = TailNumber::Odd>
1508 __device__ static void Run_2Lds(const ADataType* p_a_grid,
1509 const BDataType* p_b_grid,
1510 CDataType* p_c_grid,
1511 void* p_shared_0,
1512 void* p_shared_1,
1513 const Problem& problem,
1514 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1515 const BGridDesc_BPreshuffled& b_grid_desc_bpreshuffled,
1516 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1517 c_grid_desc_mblock_mperblock_nblock_nperblock,
1518 const index_t k_id)
1519 {
1520 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1521 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1522 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1523 p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1525 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1526
1527 const AElementwiseOperation a_element_op{};
1528 // const BElementwiseOperation b_element_op{};
1529 const CElementwiseOperation c_element_op{};
1530
1531 // divide block work by [M, N]
// The third constructor argument (4) is interpreted by Block2CTileMap.
// NOTE(review): consider naming this magic constant.
1532 const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1533
1534 const auto block_work_idx =
1535 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1536
// Blocks whose tile index falls outside the (MBlock, NBlock) extent of C exit immediately.
1537 if(!block_2_ctile_map.ValidCTileIndex(
1538 block_work_idx,
1539 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1540 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1541 {
1542 return;
1543 }
1544
1545 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1546 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1547
1548 // HACK: this forces m/n_block_data_idx_on_grid into SGPRs
1549 const index_t m_block_data_idx_on_grid =
1550 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1551
1552 // N0, K0, Blocksize*KPack
// The N offset is expressed in units of NXdlPerWave (the preshuffled N0 dimension), not
// in elements.
1553 const index_t n_block_data_idx_on_grid =
1554 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1555
1556 // A matrix in LDS memory, dst of blockwise copy
1557 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1558
1559 // B matrix in LDS memory, dst of blockwise copy
1560 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1561
1562 // A matrix blockwise copy
// The final template argument (2) is the global-buffer count; this differs from the
// single-LDS path, which passes BlockwiseGemmPipe::GlobalBufferNum.
1563 auto a_blockwise_copy =
1565 AElementwiseOperation,
1569 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1570 ABlockTransferThreadClusterArrangeOrder,
1571 ADataType,
1572 ADataType,
1573 decltype(a_grid_desc_ak0_m_ak1),
1574 decltype(a_block_desc_ak0_m_ak1),
1575 ABlockTransferSrcAccessOrder,
1577 ABlockTransferSrcVectorDim,
1578 2,
1579 ABlockTransferSrcScalarPerVector,
1580 ABlockTransferDstScalarPerVector_AK1,
1581 1,
1582 1,
1583 AThreadTransferSrcResetCoordinateAfterRun,
1584 true,
1585 2>(
1586 a_grid_desc_ak0_m_ak1,
1587 make_multi_index(0, m_block_data_idx_on_grid, 0),
1588 a_element_op,
1589 a_block_desc_ak0_m_ak1,
1590 make_multi_index(0, 0, 0),
1592
1593 // B matrix blockwise copy
1594 // Thread-wise copy
1595 // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
// NOTE(review): the opening lines of the b_block_buf_ping / b_block_buf_pong definitions
// (1596, 1598) are not shown in this listing, so their memory space is not visible here.
1597 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1599 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1600 auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
1601
// Each thread reads B directly from global memory. The last origin coordinate offsets
// every lane by KPack elements within its wave (thread id modulo WarpSize).
1602 auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
1603 BDataType,
1604 BDataType,
1605 decltype(b_grid_desc_bpreshuffled),
1606 decltype(b_block_desc_bk0_n_bk1),
1609 3,
1610 BBlockTransferSrcScalarPerVector,
1611 BThreadTransferSrcResetCoordinateAfterRun,
1612 true>(b_grid_desc_bpreshuffled,
1613 make_multi_index(n_block_data_idx_on_grid,
1615 k_id,
1616 KPack * (get_thread_local_1d_id() % WarpSize)));
1617
1618 // LDS allocation for A and B: be careful of alignment
// NOTE(review): in the visible code only the A ping/pong buffers are backed by
// p_shared_0 / p_shared_1.
1619 auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1620 static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1621
1622 auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1623 static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1624
1625 auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
1626
// Per-iteration window steps:
//   - A advances by KPerBlock / AK1 along AK0;
//   - B advances by KRepeat along its third dimension.
1627 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1628 constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
1629
1630 // Blockwise GEMM pipeline
1631 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1632 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1633 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1634
// Number of KPerBlock-sized K tiles covered by A's descriptor: (AK0 * AK1) / KPerBlock.
1635 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1636 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1637 KPerBlock);
1638
1639 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
1640 a_block_desc_ak0_m_ak1,
1641 a_blockwise_copy,
1642 a_grid_buf,
1643 a_block_bufs,
1644 a_block_slice_copy_step,
1645 b_grid_desc_bpreshuffled,
1646 b_blockwise_copy,
1647 b_grid_buf,
1648 b_block_bufs,
1649 b_block_slice_copy_step,
1650 c_thread_buf,
1651 num_k_block_main_loop);
1652
1653 // shuffle C and write out
// Epilogue: accumulators in c_thread_buf are written to an LDS staging tile, then copied
// cooperatively by the whole block to global C, one shuffle slice per access.
1654 {
1655 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1656 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1657 "wrong!");
1658
1659 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1660
1661 // TODO: hacky, fix it!
1662 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1663 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1664
1665 // TODO: hacky, fix it!
1666 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1667 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1668 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1669
1670 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1671 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1672 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1673 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1674 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1675 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1676 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1677 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1678
1679 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1681
// The C staging buffer aliases p_shared_0, the same LDS region used for a_block_buf_ping
// during the main loop.
1682 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1683 static_cast<CShuffleDataType*>(p_shared_0),
1684 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1685
1686 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1687 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1688 make_tuple(
1691 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1692 M1, // M1 = MWave
1693 M2, // M2 * M3 * M4 = MPerXdl
1694 M3,
1695 M4)),
1698 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1699 N1, // N1 = NWave
1700 N2))), // N2 = NPerXdl
1702 make_tuple(
1704
1705 // calculate origin of thread output tensor on global memory
1706 // blockwise GEMM c matrix starting index
1707 const auto c_thread_mtx_on_block =
1708 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1709
1710 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1711 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1712
// Decompose this thread's flat M / N origin into the (M0..M4) and (N0..N2) coordinates
// used by the LDS-side descriptor.
1713 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1715 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1718
1719 const auto m_thread_data_on_block_idx =
1720 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1721 make_multi_index(m_thread_data_on_block));
1722
1723 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1728
1729 const auto n_thread_data_on_block_idx =
1730 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1731 make_multi_index(n_thread_data_on_block));
1732
1733 // shuffle: threadwise copy C from VGPR to LDS
1734 auto c_thread_copy_vgpr_to_lds =
1736 CShuffleDataType,
1737 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1738 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1740 Sequence<CShuffleMXdlPerWavePerShuffle,
1741 CShuffleNXdlPerWavePerShuffle,
1742 I1,
1743 I1,
1744 M2,
1745 I1,
1746 M4,
1747 I1>,
1749 7,
1750 1,
1752 1,
1753 true>{
1754 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1756 0,
1757 m_thread_data_on_block_idx[I1],
1758 n_thread_data_on_block_idx[I1],
1759 m_thread_data_on_block_idx[I2],
1760 m_thread_data_on_block_idx[I3],
1761 m_thread_data_on_block_idx[I4],
1762 n_thread_data_on_block_idx[I2]),
1764
1765 // shuffle: blockwise copy C from LDS to global
// The destination window starts at this block's (block_m_id, block_n_id) tile.
1766 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1767 ThisThreadBlock, // ThreadGroup
1768 CElementwiseOperation, // ElementwiseOperation,
1769 CGlobalMemoryDataOperation, // DstInMemOp,
1770 Sequence<1,
1771 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1772 1,
1773 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1774 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1775 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1776 CShuffleDataType, // typename SrcData,
1777 CDataType, // typename DstData,
1778 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1779 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1780 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1781 3, // index_t VectorDim,
1782 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1783 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1784 false> // bool ThreadTransferDstResetCoordinateAfterRun>
1785 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1786 make_multi_index(0, 0, 0, 0),
1787 c_grid_desc_mblock_mperblock_nblock_nperblock,
1788 make_multi_index(block_m_id, 0, block_n_id, 0),
1789 c_element_op};
1790
1791 // space filling curve for threadwise C in VGPR
1792 constexpr auto sfc_c_vgpr =
1795 Sequence<CShuffleMXdlPerWavePerShuffle,
1796 CShuffleNXdlPerWavePerShuffle,
1797 1,
1798 1,
1799 M2,
1800 1,
1801 M4,
1802 1>>{};
1803
1804 // space filling curve for shuffled blockwise C in global mem
1805 constexpr auto sfc_c_global =
1808 Sequence<1,
1809 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1810 1,
1811 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1812
// Both curves must visit the same number of slices so that each VGPR->LDS write is paired
// with exactly one LDS->global copy.
1813 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1814
1815 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1816
1817 static_for<0, num_access, 1>{}([&](auto access_id) {
1818 // make sure it's safe to write to LDS
// NOTE(review): source line 1819 is missing from this listing. Given the comment, it is
// presumably an LDS barrier (block_sync_lds()). The same applies to line 1829 below.
1820
1821 // each thread write its data from VGPR to LDS
1822 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1823 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1824 c_thread_buf,
1825 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1826 c_shuffle_block_buf);
1827
1828 // make sure it's safe to read from LDS
1830
1831 // each block copy its data from LDS to global
1832 c_shuffle_block_copy_lds_to_global.Run(
1833 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1834 c_shuffle_block_buf,
1835 c_grid_desc_mblock_mperblock_nblock_nperblock,
1836 c_grid_buf);
1837
// Advance the global write window to the next slice, except after the final access.
1838 if constexpr(access_id < num_access - 1)
1839 {
1840 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1841
1842 // move on C
1843 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1844 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1845 }
1846 });
1847 }
1848 }
1849
// Convenience overload for the two-LDS-buffer path: builds the A, preshuffled-B and C grid
// descriptors from `problem`, then forwards to the descriptor-templated Run_2Lds.
//
// p_a_grid / p_b_grid / p_c_grid: global pointers for A, preshuffled B and C.
// p_shared_0 / p_shared_1: the two LDS regions used as the ping/pong A buffers downstream.
// problem: sizes, paddings, strides and MBlock/NBlock counts used to build the descriptors.
// k_id: forwarded unchanged (used as a K coordinate of B's read-window origin downstream).
// Kt: K used to size the preshuffled B descriptor (via CalculateBK0Shuffled), whereas A's
//   descriptor is built from problem.K.
1850 template <bool HasMainKBlockLoop,
1851 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1852 TailNumber TailNum = TailNumber::Odd>
1853 __device__ static void Run_2Lds(const ADataType* p_a_grid,
1854 const BDataType* p_b_grid,
1855 CDataType* p_c_grid,
1856 void* p_shared_0,
1857 void* p_shared_1,
1858 const Problem& problem,
1859 const index_t k_id,
1860 const index_t Kt)
1861 {
// Preshuffled B extents: N0 from problem.N, K0 from Kt.
1862 index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1863 index_t BK0Shuffled = CalculateBK0Shuffled(Kt);
1864 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1865 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1866 const auto b_grid_desc_bpreshuffled =
1867 MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1868 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
1869 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1870
// NOTE(review): this rendered listing omits source line 1872, which holds the opening of
// the call that splits c_grid_desc_m_n into (MBlock, MPerBlock, NBlock, NPerBlock).
1871 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1873 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1874
// Forward to the descriptor-templated overload; the descriptor types become its template
// arguments.
1875 Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
1876 decltype(b_grid_desc_bpreshuffled),
1877 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1878 HasMainKBlockLoop,
1879 CGlobalMemoryDataOperation,
1880 TailNum>(p_a_grid,
1881 p_b_grid,
1882 p_c_grid,
1883 p_shared_0,
1884 p_shared_1,
1885 problem,
1886 a_grid_desc_ak0_m_ak1,
1887 b_grid_desc_bpreshuffled,
1888 c_grid_desc_mblock_mperblock_nblock_nperblock,
1889 k_id);
1890 }
1891};
1892
1893} // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
ushort bhalf_t
Definition data_type.hpp:30
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__device__ index_t get_warp_local_1d_id()
Definition get_id.hpp:45
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
_Float16 half_t
Definition data_type.hpp:31
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
__global__ void kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:75
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
constexpr auto BlockGemmBPreshufflePipeline_Selector()
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp:41
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__global__ void kernel_gemm_xdl_cshuffle_v3_b_preshuffle(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:36
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
unsigned int uint32_t
Definition stdint.h:126
signed int int32_t
Definition stdint.h:123
Definition block_to_ctile_map.hpp:271
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition block_to_ctile_map.hpp:283
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:627
__host__ __device__ bool IsReduceAdd() const
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:647
bool is_reduce
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:660
__host__ __device__ bool IsAtomicAdd() const
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:652
const ADataType * p_a_grid
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:657
CDataType * p_c_grid
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:659
const BDataType * p_b_grid
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:658
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t k_batch_, bool is_reduce_=false)
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:628
index_t N
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:609
index_t KBatch
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:614
index_t KPadded
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:618
index_t NPadded
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:616
__host__ void Print() const
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:598
index_t M
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:608
index_t MPadded
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:615
index_t KRead
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:617
index_t AK0
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:619
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t KBatch_)
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:573
index_t K
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:610
index_t MBlock
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:621
index_t BK0
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:620
index_t StrideC
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:613
index_t StrideA
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:611
index_t NBlock
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:622
index_t StrideB
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:612
index_t a_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:696
index_t c_reduce_offset
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:697
__device__ SplitKBatchOffset(Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:666
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:157
static __device__ void Run_2Lds(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BPreshuffled &b_grid_desc_bpreshuffled, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock, const index_t k_id)
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:1508
remove_cvref_t< decltype(BlockGemmBPreshufflePipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ADataType, BDataType, ComputeTypeA, GemmAccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack >())> BlockwiseGemmPipe
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:842
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition type.hpp:177
Definition data_type.hpp:187
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129