gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp Source File

gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp Source File#

Composable Kernel: gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp Source File
gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
18
19namespace ck {
20template <typename FloatAB,
21 typename FloatGemmAcc,
22 typename FloatCShuffle,
23 typename DsDataType,
24 typename FloatE,
25 typename FloatReduceAcc,
26 typename RsDataType,
27 typename AElementwiseOperation,
28 typename BElementwiseOperation,
29 typename CDEElementwiseOperation,
30 typename QsElementwiseOperation,
31 typename RsElementwiseOperation,
32 typename ThreadReduceOperations,
33 InMemoryDataOperationEnum EGlobalMemoryDataOperation,
34 typename RsGlobalMemoryDataOperation,
35 typename AGridDesc_M_K,
36 typename BGridDesc_N_K,
37 typename EGridDesc_M_N,
38 typename RGridDesc_M,
39 index_t NumGemmKPrefetchStage,
40 index_t BlockSize,
41 index_t MPerBlock,
42 index_t NPerBlock,
43 index_t KPerBlock,
44 index_t AK1Value,
45 index_t BK1Value,
46 index_t MPerXdl,
47 index_t NPerXdl,
48 index_t MXdlPerWave,
49 index_t NXdlPerWave,
50 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
51 typename ABlockTransferThreadClusterArrangeOrder,
52 typename ABlockTransferSrcAccessOrder,
53 index_t ABlockTransferSrcVectorDim,
54 index_t ABlockTransferSrcScalarPerVector,
55 index_t ABlockTransferDstScalarPerVector_AK1,
56 bool AThreadTransferSrcResetCoordinateAfterRun,
57 index_t ABlockLdsExtraM,
58 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
59 typename BBlockTransferThreadClusterArrangeOrder,
60 typename BBlockTransferSrcAccessOrder,
61 index_t BBlockTransferSrcVectorDim,
62 index_t BBlockTransferSrcScalarPerVector,
63 index_t BBlockTransferDstScalarPerVector_BK1,
64 bool BThreadTransferSrcResetCoordinateAfterRun,
65 index_t BBlockLdsExtraN,
66 index_t CShuffleMXdlPerWavePerShuffle,
67 index_t CShuffleNXdlPerWavePerShuffle,
68 typename CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
69 index_t CDEReduceThreadTransferScalarPerVector_NPerBlock,
70 index_t RThreadTransferDstScalarPerVector_MPerBlock,
71 LoopScheduler LoopSched,
74{
75 static constexpr index_t NumDTensor = DsDataType::Size();
76 static constexpr index_t NumRTensor = RsDataType::Size();
77
78 static constexpr auto I0 = Number<0>{};
79 static constexpr auto I1 = Number<1>{};
80 static constexpr auto I2 = Number<2>{};
81 static constexpr auto I3 = Number<3>{};
82 static constexpr auto I4 = Number<4>{};
83 static constexpr auto I5 = Number<5>{};
84 static constexpr auto I6 = Number<6>{};
85 static constexpr auto I7 = Number<7>{};
86
87 // K1 should be Number<...>
88 static constexpr auto AK1 = Number<AK1Value>{};
89 static constexpr auto BK1 = Number<BK1Value>{};
90 static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
91 static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
92
94
97
98 __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
99 {
100 // A matrix in LDS memory, dst of blockwise copy
104 }
105
106 __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
107 {
108 // B matrix in LDS memory, dst of blockwise copy
112 }
113
114 __host__ __device__ static constexpr auto
116 {
117 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
118 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
119
120 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
124 I1,
126
127 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
128 }
129
130 // ck::Tuple<const T0DataType*, const T1DataType*, ...>
// Builds a tuple with one pointer per element type of Ts (Ts::Size() entries).
// Entry i has type `const T*` when isConst is true (the default) and `T*`
// otherwise, where T is the i-th element type of Ts with cv-qualifiers and
// references removed. Every pointer in the returned tuple is nullptr.
// NOTE(review): since all values are null, this looks intended to provide the
// pointer-tuple *type* / a default value rather than usable addresses - confirm
// against the (not shown here) DsGridPointer / RsGridPointer aliases.
131 template <typename Ts, bool isConst = true>
132 static constexpr auto MakeTsGridPointer()
133 {
134 return generate_tuple(
135 [&](auto i) {
136 using T = remove_cvref_t<tuple_element_t<i.value, Ts>>; // i is a compile-time index
137 if constexpr(isConst)
138 return static_cast<const T*>(nullptr);
139 else
140 return static_cast<T*>(nullptr);
141 },
142 Number<Ts::Size()>{});
143 }
144
// Returns the number of bytes of LDS (shared memory) this kernel needs.
// Two requirements are computed and the larger one is returned:
//   1. A tile + B tile: each element-space size is rounded up to a multiple of
//      lcm(AK1, BK1), the two are added, and the sum is scaled by sizeof(FloatAB).
//   2. C-shuffle tile: its element-space size scaled by sizeof(FloatCShuffle).
// The maximum (not the sum) is used; Run() builds the C-shuffle buffer at the
// start of p_shared, i.e. it overlays the region used for the A/B tiles.
145 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
146 {
147 // LDS allocation for A and B: be careful of alignment
148 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
149 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
150
151 // lds max alignment
// (expressed in elements, not bytes: it is applied to element-space sizes below)
152 constexpr auto max_lds_align = math::lcm(AK1, BK1);
153
154 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
155 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
156
157 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
158 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
159
160 // LDS allocation for C shuffle in LDS
161 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
163
164 constexpr auto c_block_size =
165 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
166
// A and B share one element type (FloatAB), so a single sizeof covers both tiles.
167 return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
168 sizeof(FloatAB),
169 c_block_size * sizeof(FloatCShuffle));
170 }
171
172 // A desc for source in blockwise copy
// Derives a new descriptor from the 2-D (M, K) descriptor of A by calling
// transform_tensor_descriptor on it. M is read from dimension 0 and K from
// dimension 1 of the input; AK0 is computed as K / AK1.
// NOTE(review): K / AK1 is an integer division and nothing in this function
// checks that K is a multiple of AK1 - presumably guaranteed by the caller
// (CheckValidity only tests K % KPerBlock); confirm.
173 __host__ __device__ static constexpr auto
174 MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
175 {
176 const auto M = a_grid_desc_m_k.GetLength(I0);
177 const auto K = a_grid_desc_m_k.GetLength(I1);
178
179 const auto AK0 = K / AK1; // number of AK1-wide groups along K
180
181 return transform_tensor_descriptor(a_grid_desc_m_k,
186 }
187
188 // B desc for source in blockwise copy
// Derives a new descriptor from the 2-D (N, K) descriptor of B by calling
// transform_tensor_descriptor on it. N is read from dimension 0 and K from
// dimension 1 of the input; BK0 is computed as K / BK1.
// NOTE(review): K / BK1 is an integer division and nothing in this function
// checks that K is a multiple of BK1 - presumably guaranteed by the caller
// (CheckValidity only tests K % KPerBlock); confirm.
189 __host__ __device__ static constexpr auto
190 MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
191 {
192 const auto N = b_grid_desc_n_k.GetLength(I0);
193 const auto K = b_grid_desc_n_k.GetLength(I1);
194
195 const auto BK0 = K / BK1; // number of BK1-wide groups along K
196
197 return transform_tensor_descriptor(b_grid_desc_n_k,
202 }
203
205
206 // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
// Checks whether this kernel configuration can handle the given problem.
// Parameters:
//   a_grid_desc_m_k   - 2-D descriptor of A; supplies M (dim 0) and K (dim 1)
//   b_grid_desc_n_k   - 2-D descriptor of B; supplies N (dim 0)
//   e_grid_desc_m_n   - 2-D descriptor of E; must have lengths (M, N)
//   r_grid_desc_m     - descriptor of a reduction output; dim 0 must equal M
//   block_2_etile_map - block-to-tile map; its own CheckValidity is consulted
// Returns true only if every runtime check below passes; the tuning-parameter
// and rank requirements are enforced at compile time via static_assert.
// NOTE(review): K is taken from A only; B's K length (dim 1 of
// b_grid_desc_n_k) is never compared against it here - confirm the caller
// guarantees they match.
207 template <typename Block2ETileMap>
208 __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
209 const BGridDesc_N_K& b_grid_desc_n_k,
210 const EGridDesc_M_N& e_grid_desc_m_n,
211 const RGridDesc_M& r_grid_desc_m,
212 const Block2ETileMap& block_2_etile_map)
213 {
// The block tile must split evenly into (XdlPerWave * PerXdl) pieces in both M and N.
214 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
215 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
216 "Invalid tuning param!");
217
218 static_assert(AGridDesc_M_K::GetNumOfDimension() == 2);
219 static_assert(BGridDesc_N_K::GetNumOfDimension() == 2);
220 static_assert(EGridDesc_M_N::GetNumOfDimension() == 2);
221
222 const auto M = a_grid_desc_m_k.GetLength(I0);
223 const auto N = b_grid_desc_n_k.GetLength(I0);
224 const auto K = a_grid_desc_m_k.GetLength(I1);
225
// E must be exactly M x N.
226 if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
227 return false;
228
// Problem sizes must be whole multiples of the block tile; partial tiles are rejected.
229 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
230 return false;
231
// The reduction output must have one entry per row of the GEMM result.
232 if(M != r_grid_desc_m.GetLength(I0))
233 return false;
234
235 // check gridwise gemm pipeline
236 const auto num_k_loop = K / KPerBlock; // exact: K % KPerBlock == 0 was verified above
237
238 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
239 {
240 return false;
241 }
242
243 if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
244 {
245 return false;
246 }
247
248 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
249 return true;
250 }
251
// Converts the K extent into a count of KPerBlock-sized iterations
// (integer division) and forwards that count to
// GridwiseGemmPipe::CalculateHasMainLoop, returning its result.
// NOTE(review): presumably the result selects the HasMainKBlockLoop template
// argument of Run() - confirm against the caller.
252 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
253 {
254 const index_t num_loop = K / KPerBlock; // truncates if K is not a multiple of KPerBlock
255
256 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
257 }
258
// Derives a block-tiled descriptor from the 2-D (M, N) descriptor of E via
// transform_tensor_descriptor. M and N are read from dimensions 0 and 1 of the
// input; the tile counts are MBlock = M / MPerBlock and NBlock = N / NPerBlock.
// Run() reads dimensions 0 and 2 of the result as the tile-index limits passed
// to ValidCTileIndex.
// The divisions are integer divisions; CheckValidity rejects sizes that are
// not whole multiples of MPerBlock / NPerBlock.
259 __host__ __device__ static constexpr auto
260 MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
261 {
262 const auto M = e_grid_desc_m_n.GetLength(I0);
263 const auto N = e_grid_desc_m_n.GetLength(I1);
264
265 const auto MBlock = M / MPerBlock; // number of block tiles along M
266 const auto NBlock = N / NPerBlock; // number of block tiles along N
267
268 const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
269 e_grid_desc_m_n,
274
275 return e_grid_desc_mblock_mperblock_nblock_nperblock;
276 }
277
// Derives a block-tiled descriptor from the descriptor of a reduction output
// via transform_tensor_descriptor. M is read from dimension 0 of the input and
// the tile count is MBlock = M / MPerBlock (integer division; CheckValidity
// requires this length to equal the GEMM M, which must be a multiple of MPerBlock).
278 __host__ __device__ static constexpr auto
279 MakeRGridDescriptor_MBlock_MPerBlock(const RGridDesc_M& r_grid_desc_m)
280 {
281 const auto M = r_grid_desc_m.GetLength(I0);
282 const auto MBlock = M / MPerBlock; // number of block tiles along M
283
284 const auto r_grid_desc_mblock_mperblock = transform_tensor_descriptor(
285 r_grid_desc_m,
289
290 return r_grid_desc_mblock_mperblock;
291 }
292
293 // return block_id to E matrix tile idx (m0, n0) mapping
294 __host__ __device__ static constexpr auto
295 MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
296 {
298 e_grid_desc_m_n);
299 }
300
302 remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
304 remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
307 EGridDesc_M_N{}))>;
308
309 // Support 2 dimensions in the future, not only M
312
314 remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
315
318
319 template <bool HasMainKBlockLoop,
320 typename AGridDesc_AK0_M_AK1,
321 typename BGridDesc_BK0_N_BK1,
322 typename Block2ETileMap>
323 __device__ static void
324 Run(const FloatAB* __restrict__ p_a_grid,
325 const FloatAB* __restrict__ p_b_grid,
326 DsGridPointer p_ds_grid,
327 FloatE* __restrict__ p_e_grid,
328 RsGridPointer p_rs_grid,
329 void* __restrict__ p_shared,
330 const AElementwiseOperation& a_element_op,
331 const BElementwiseOperation& b_element_op,
332 const CDEElementwiseOperation& cde_element_op,
333 const QsElementwiseOperation& qs_element_op,
334 const RsElementwiseOperation& rs_element_op,
335 const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
336 const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
338 NumDTensor>&
339 ds_grid_desc_mblock_mperblock_nblock_nperblock, // FIXME: Ds desc may be of different
341 e_grid_desc_mblock_mperblock_nblock_nperblock,
343 NumRTensor>&
344 rs_grid_desc_mblock_mperblock, // FIXME: Rs desc may be of different
345 const Block2ETileMap& block_2_etile_map)
346 {
347 // FIXME - Share code with other gemm kernel
348 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
349 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
350
351 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
352 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
353
354 const auto ds_grid_buf = generate_tuple(
355 [&](auto i) {
357 p_ds_grid[i],
358 ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
359 },
361
363 p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
364
365 auto rs_grid_buf = generate_tuple(
366 [&](auto i) {
368 p_rs_grid(i), rs_grid_desc_mblock_mperblock[i].GetElementSpaceSize());
369 },
371
372 // divide block work by [M, N]
373 const auto block_work_idx =
374 block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
375
376 if(!block_2_etile_map.ValidCTileIndex(
377 block_work_idx,
378 make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
379 e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
380 {
381 return;
382 }
383
384 // HACK: this forces m/n_block_data_idx_on_grid into SGPR
385 const index_t m_block_data_idx_on_grid =
386 __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
387
388 const index_t n_block_data_idx_on_grid =
389 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
390
391 // lds max alignment
392 constexpr auto max_lds_align = math::lcm(AK1, BK1);
393
394 // A matrix in LDS memory, dst of blockwise copy
395 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
396
397 // B matrix in LDS memory, dst of blockwise copy
398 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
399
400 // A matrix blockwise copy
401 auto a_blockwise_copy =
403 AElementwiseOperation,
407 ABlockTransferThreadClusterLengths_AK0_M_AK1,
408 ABlockTransferThreadClusterArrangeOrder,
409 FloatAB,
410 FloatAB,
411 decltype(a_grid_desc_ak0_m_ak1),
412 decltype(a_block_desc_ak0_m_ak1),
413 ABlockTransferSrcAccessOrder,
415 ABlockTransferSrcVectorDim,
416 2,
417 ABlockTransferSrcScalarPerVector,
418 ABlockTransferDstScalarPerVector_AK1,
419 1,
420 1,
421 AThreadTransferSrcResetCoordinateAfterRun,
422 true,
423 NumGemmKPrefetchStage>(
424 a_grid_desc_ak0_m_ak1,
425 make_multi_index(0, m_block_data_idx_on_grid, 0),
426 a_element_op,
427 a_block_desc_ak0_m_ak1,
428 make_multi_index(0, 0, 0),
430
431 // B matrix blockwise copy
432 auto b_blockwise_copy =
434 BElementwiseOperation,
438 BBlockTransferThreadClusterLengths_BK0_N_BK1,
439 BBlockTransferThreadClusterArrangeOrder,
440 FloatAB,
441 FloatAB,
442 decltype(b_grid_desc_bk0_n_bk1),
443 decltype(b_block_desc_bk0_n_bk1),
444 BBlockTransferSrcAccessOrder,
446 BBlockTransferSrcVectorDim,
447 2,
448 BBlockTransferSrcScalarPerVector,
449 BBlockTransferDstScalarPerVector_BK1,
450 1,
451 1,
452 BThreadTransferSrcResetCoordinateAfterRun,
453 true,
454 NumGemmKPrefetchStage>(
455 b_grid_desc_bk0_n_bk1,
456 make_multi_index(0, n_block_data_idx_on_grid, 0),
457 b_element_op,
458 b_block_desc_bk0_n_bk1,
459 make_multi_index(0, 0, 0),
461
462 // GEMM definition
463 // c_mtx += transpose(a_mtx) * b_mtx
464 // a_mtx[K0PerBlock, MPerBlock] is in LDS
465 // b_mtx[K0PerBlock, NPerBlock] is in LDS
466 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
467 // register
468 // sanity check
469 constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
470 constexpr bool is_single_rate_mfma =
472 lcm_AK1_BK1 <= 4) ||
473 (is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
475 lcm_AK1_BK1 < 32))
476 ? true
477 : false;
478 constexpr auto is_scale_mfma = false;
479 constexpr index_t KPack = math::max(
480 lcm_AK1_BK1,
482 selected_mfma.k_per_blk);
483
485 BlockSize,
486 FloatAB,
487 FloatAB,
488 FloatGemmAcc,
489 decltype(a_block_desc_ak0_m_ak1),
490 decltype(b_block_desc_bk0_n_bk1),
491 MPerXdl,
492 NPerXdl,
493 MXdlPerWave,
494 NXdlPerWave,
495 KPack,
496 LoopSched>();
497
498 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
499
500 // LDS allocation for A and B: be careful of alignment
501 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
502 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
503
505 static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
506
508 static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
509 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
510
511 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
512 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
513
514 // gridwise GEMM pipeline
515 const auto gridwise_gemm_pipeline =
517
518 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
519 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
520 KPerBlock);
521
522 gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
523 a_block_desc_ak0_m_ak1,
524 a_blockwise_copy,
525 a_grid_buf,
526 a_block_buf,
527 a_block_slice_copy_step,
528 b_grid_desc_bk0_n_bk1,
529 b_block_desc_bk0_n_bk1,
530 b_blockwise_copy,
531 b_grid_buf,
532 b_block_buf,
533 b_block_slice_copy_step,
534 blockwise_gemm,
535 c_thread_buf,
536 num_k_block_main_loop);
537
538 // shuffle C + Ds + reduction + write out
539 {
540 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
541 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
542 "wrong!");
543
544 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
545 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
546
547 // TODO: hacky, fix it!
548 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
549 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
550
551 // TODO: hacky, fix it!
552 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
553 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
554 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
555
556 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
557 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
558 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
559 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
560 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
561 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
562 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
563 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
564
565 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
567
568 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
569 static_cast<FloatCShuffle*>(p_shared),
570 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
571
572 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
573 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
577 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
578 M1, // M1 = MWave
579 M2, // M2 * M3 * M4 = MPerXdl
580 M3,
581 M4)),
584 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
585 N1, // N1 = NWave
586 N2))), // N2 = NPerXdl
590
591 // calculate origin of thread output tensor on global memory
592 // blockwise GEMM c matrix starting index
593 const auto c_thread_mtx_on_block =
594 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
595
596 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
597 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
598
599 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
601 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
604
605 const auto m_thread_data_on_block_idx =
606 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
607 make_multi_index(m_thread_data_on_block));
608
609 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
614
615 const auto n_thread_data_on_block_idx =
616 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
617 make_multi_index(n_thread_data_on_block));
618
619 // shuffle: threadwise copy C from VGPR to LDS
620 auto c_thread_copy_vgpr_to_lds =
622 FloatCShuffle,
623 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
624 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
626 Sequence<CShuffleMXdlPerWavePerShuffle,
627 CShuffleNXdlPerWavePerShuffle,
628 I1,
629 I1,
630 M2,
631 I1,
632 M4,
633 I1>,
635 7,
636 1,
638 1,
639 true>{
640 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
642 0,
643 m_thread_data_on_block_idx[I1],
644 n_thread_data_on_block_idx[I1],
645 m_thread_data_on_block_idx[I2],
646 m_thread_data_on_block_idx[I3],
647 m_thread_data_on_block_idx[I4],
648 n_thread_data_on_block_idx[I2]),
650
651 // space filling curve for threadwise C in VGPR
652 constexpr auto sfc_c_vgpr =
655 Sequence<CShuffleMXdlPerWavePerShuffle,
656 CShuffleNXdlPerWavePerShuffle,
657 1,
658 1,
659 M2,
660 1,
661 M4,
662 1>>{};
663
664 // space filling curve for shuffled blockwise C in global mem
665 constexpr auto sfc_der_global =
668 Sequence<1,
669 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
670 1,
671 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
672
673 // TODO: this should be implemented as a blockwise reduction
674 // LDS c_reduce_block_desc_mperblock_nperblock
675 constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor(
676 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
680 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)),
683 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))),
686
687 static_assert(CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I0) *
688 CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I1) ==
689 BlockSize,
690 "wrong!");
691
692 static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) %
693 CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I0) ==
694 0 &&
695 (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) %
696 CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I1) ==
697 0,
698 "wrong!");
699
700 constexpr index_t mreduce_per_thread =
701 (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) /
702 CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I0);
703
704 constexpr index_t nreduce_per_thread =
705 (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) /
706 CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I1);
707
708 constexpr auto c_reduce_thread_lengths_mperblock_nperblock =
710
711 // VGPR cde_reduce_thread_desc_mperblock_nperblock
712 constexpr auto cde_reduce_thread_desc_mperblock_nperblock =
715
716 constexpr auto r_thread_desc_mperblock =
718
719 constexpr auto r_thread_desc_mblock_mperblock =
721
723 cde_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
724
725 // reduce: threadwise copy from LDS to VGPR
726 constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor(
727 CDRThreadTransferClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{});
728
729 const auto c_reduce_thread_cluster_idx =
730 c_reduce_thread_cluster_desc.CalculateBottomIndex(
732
733 const auto c_reduce_thread_data_idx_begin =
734 c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock;
735
736 // To apply D0, D1, ... and reduction.
737 // Copy c shuffle from LDS back to VGPR
738 auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
739 FloatCShuffle,
740 FloatReduceAcc,
741 decltype(c_reduce_block_desc_mperblock_nperblock),
742 decltype(cde_reduce_thread_desc_mperblock_nperblock),
743 decltype(c_reduce_thread_lengths_mperblock_nperblock),
745 1,
746 CDEReduceThreadTransferScalarPerVector_NPerBlock,
747 1,
748 true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};
749
750 // Copy result of reduction back from VGPR to global
751 auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple(
752 [&](auto I) {
753 auto p_r_grid = p_rs_grid[I];
754 auto r_element_op = rs_element_op[I];
755 auto r_grid_desc_mblock_mperblock = rs_grid_desc_mblock_mperblock[I];
756
758 FloatReduceAcc,
759 remove_pointer_t<decltype(p_r_grid)>,
760 decltype(r_thread_desc_mblock_mperblock),
761 decltype(r_grid_desc_mblock_mperblock),
762 decltype(r_element_op),
765 1,
766 RThreadTransferDstScalarPerVector_MPerBlock,
767 RsGlobalMemoryDataOperation::At(I),
768 1,
769 false>{r_grid_desc_mblock_mperblock,
770 make_multi_index(block_work_idx[I0], // mblock
771 c_reduce_thread_data_idx_begin[I0]), // mperblock
772 r_element_op};
773 },
775
776 // D0, D1, ..., Dn
777 constexpr auto cde_reduce_thread_desc_I1_mperblock_I1_nperblock =
780
781 // FIXME: Decrease usage of VGPR
782 // Apply pointwise lambda function from multi-source (Global and LDS) into VGPR
783 auto ds_thread_buf = generate_tuple(
784 [&](auto) {
786 cde_reduce_thread_desc_I1_mperblock_I1_nperblock.GetElementSpaceSize());
787 },
789
790 // Copy D0, D1, ..., Dn from global to VGPR
791 auto ds_thread_copy_global_to_vgpr = generate_tuple(
792 [&](auto I) {
793 using DDataType = remove_cvref_t<tuple_element_t<I.value, DsDataType>>;
795 DDataType,
796 FloatReduceAcc,
797 decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I]),
798 decltype(cde_reduce_thread_desc_I1_mperblock_I1_nperblock),
801 3,
802 CDEReduceThreadTransferScalarPerVector_NPerBlock,
803 1,
804 true>(ds_grid_desc_mblock_mperblock_nblock_nperblock[I],
806 I0,
807 m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
808 I0,
809 n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]));
810 },
812
813 auto e_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
814 FloatReduceAcc,
815 FloatE,
816 decltype(cde_reduce_thread_desc_I1_mperblock_I1_nperblock),
817 decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
820 Sequence<0, 1, 2, 3>, // DimAccessOrder
821 3, // DstVectorDim
822 CDEReduceThreadTransferScalarPerVector_NPerBlock,
824 1,
825 true>{
826 e_grid_desc_mblock_mperblock_nblock_nperblock,
828 m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
829 I0,
830 n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]),
832
833 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
834
835 static_assert(num_access == sfc_der_global.GetNumOfAccess(), "wrong!");
836
837 static_for<0, num_access, 1>{}([&](auto access_id) {
838 // make sure it's safe to write to LDS (the VGPR -> LDS shuffle copy follows)
840
841 // each thread shuffle data from VGPR to LDS
842 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
843 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
844 c_thread_buf,
845 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
846 c_shuffle_block_buf);
847
848 // make sure it's safe to read from LDS (the LDS -> VGPR copy follows)
850
851 // Get shuffle data from LDS to VGPR
852 c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
853 c_shuffle_block_buf,
854 cde_reduce_thread_desc_mperblock_nperblock,
855 make_tuple(I0, I0),
856 e_thread_buf);
857
858 // Global read D0, D1, ...
859 static_for<0, NumDTensor, 1>{}([&](auto Id) {
860 auto& d_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(Id);
861 d_thread_copy_global_to_vgpr.Run(
862 ds_grid_desc_mblock_mperblock_nblock_nperblock[Id],
863 ds_grid_buf[Id],
864 cde_reduce_thread_desc_I1_mperblock_I1_nperblock,
865 make_tuple(I0, I0, I0, I0),
866 ds_thread_buf(Id));
867
868 if constexpr(access_id < num_access - 1)
869 {
870 // move on D0, D1, ...
871 constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id);
872 d_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
873 ds_grid_desc_mblock_mperblock_nblock_nperblock[Id], de_global_step);
874 }
875 });
876
877 // cde_element_op(e, c, d0, d1, ...);
878 static_for<0, cde_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
879 [&](auto i) {
880 const auto c_ds_src_data_refs = concat_tuple_of_reference(
881 tie(e_thread_buf[i]),
883 [&](auto Id) -> const auto& { return ds_thread_buf[Id][i]; },
885 auto e_dst_data_refs = tie(e_thread_buf(i));
886 unpack2(cde_element_op, e_dst_data_refs, c_ds_src_data_refs);
887 });
888
889 // Global write E
890 e_thread_copy_vgpr_to_global.Run(cde_reduce_thread_desc_I1_mperblock_I1_nperblock,
891 make_tuple(I0, I0, I0, I0),
892 e_thread_buf,
893 e_grid_desc_mblock_mperblock_nblock_nperblock,
894 e_grid_buf);
895
896 if constexpr(access_id < num_access - 1)
897 {
898 // move on E
899 constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id);
900 e_thread_copy_vgpr_to_global.MoveDstSliceWindow(
901 e_grid_desc_mblock_mperblock_nblock_nperblock, de_global_step);
902 }
903
904 // reduction
905 static_for<0, NumRTensor, 1>{}([&](auto Ir) {
907 r_thread_desc_mperblock.GetElementSpaceSize());
908
909 auto& reduce_thread_copy_vgpr_to_global =
910 reduce_tuple_thread_copy_vgpr_to_global(Ir);
911
912 using ThreadReduceOperation =
913 remove_cvref_t<decltype(ThreadReduceOperations{}[Ir])>;
914
915 using ThreadwiseReduce =
916 ThreadwiseReduction<FloatReduceAcc,
917 decltype(cde_reduce_thread_desc_mperblock_nperblock),
918 decltype(r_thread_desc_mperblock),
919 ThreadReduceOperation,
920 false>;
921
922 // threadwise reduction
923 const auto reduce_identityVal =
924 ThreadReduceOperation::template GetIdentityValue<FloatReduceAcc>();
926 [&](auto I) { r_thread_buf(I) = reduce_identityVal; });
929 constexpr auto offset =
930 Number<cde_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
931 make_tuple(im, in))>{};
932
933 qs_element_op[Ir](e_thread_buf(offset), e_thread_buf(offset));
934 });
935 });
936 ThreadwiseReduce::Reduce(e_thread_buf, r_thread_buf);
937
938 // gridwise reduction
939 reduce_thread_copy_vgpr_to_global.Run(r_thread_desc_mblock_mperblock,
940 make_tuple(I0, I0),
941 r_thread_buf,
942 rs_grid_desc_mblock_mperblock[Ir],
943 rs_grid_buf(Ir));
944
945 if constexpr(access_id < num_access - 1)
946 {
947 // move on R0, R1, ...
948 constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id);
949 reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
950 rs_grid_desc_mblock_mperblock[Ir],
951 make_tuple(de_global_step[I0], de_global_step[I1]));
952 }
953 });
954 }); // copy c, d, e + reduction
955
956 } // shuffle C + Ds + reduction + write out
957 } // Run
958};
959
960} // namespace ck
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
Definition blockwise_gemm_xdlops.hpp:620
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
constexpr auto GridwiseGemmPipeline_Selector()
Definition gridwise_gemm_pipeline_selector.hpp:31
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto unpack2(F &&f, X &&x, Y &&y)
Definition functional4.hpp:55
typename remove_pointer< T >::type remove_pointer_t
Definition type.hpp:300
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
Definition block_to_ctile_map.hpp:261
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:74
static __device__ void Run(const ADataType *__restrict__ p_a_grid, const ADataType *__restrict__ p_b_grid, DsGridPointer p_ds_grid, EDataType *__restrict__ p_e_grid, RsGridPointer p_rs_grid, void *__restrict__ p_shared, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op, const QsElementwiseOperation &qs_element_op, const RsElementwiseOperation &rs_element_op, const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const StaticallyIndexedArray< EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, NumDTensor > &ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, const StaticallyIndexedArray< RGridDescriptor_MBlock_MPerBlock, NumRTensor > &rs_grid_desc_mblock_mperblock, const Block2ETileMap &block_2_etile_map)
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:324
__host__ static __device__ constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N &e_grid_desc_m_n)
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:260
remove_cvref_t< decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))> EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:305
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_M_K &a_grid_desc_m_k, const BGridDesc_N_K &b_grid_desc_n_k, const EGridDesc_M_N &e_grid_desc_m_n, const RGridDesc_M &r_grid_desc_m, const Block2ETileMap &block_2_etile_map)
Definition gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp:208
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition reduction_functions_threadwise.hpp:23
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340