gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp Source File

gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp Source File

Composable Kernel: gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp Source File
gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
19
20namespace ck {
21
22// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta)
//
// Kernel entry point for the fused GEMM + layernorm. Each workgroup statically allocates
// GridwiseGemm::GetSharedMemoryNumberOfByte() bytes of LDS and forwards every argument to
// GridwiseGemm::Run<HasMainKBlockLoop>, which performs both the GEMM and the layernorm epilogue.
// On targets not listed in the #if below, the body only assigns each argument to `ignore` and
// does no work.
//
// Per the annotations on the pointer parameters: p_c_grid and p_c0_add_grid are MxN, while
// p_c0_bias_grid, p_c0_gamma_grid and p_c0_beta_grid are 1xN vectors broadcast along M.
23template <typename GridwiseGemm,
24 typename FloatAB,
25 typename FloatC,
26 typename FloatC0,
27 typename AElementwiseOperation,
28 typename BElementwiseOperation,
29 typename AccElementwiseOperation,
30 typename CElementwiseOperation,
31 typename AGridDesc_AK0_M_AK1,
32 typename BGridDesc_BK0_N_BK1,
33 typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
34 typename C0GridDescriptor_NBlock_NPerBlock,
35 typename Block2CTileMap,
36 bool HasMainKBlockLoop>
37__global__ void
38#if CK_USE_LAUNCH_BOUNDS
40#endif
42 const FloatAB* __restrict__ p_a_grid,
43 const FloatAB* __restrict__ p_b_grid,
44 FloatC* __restrict__ p_c_grid, // MxN
45 const FloatC0* __restrict__ p_c0_bias_grid, // 1xN
46 const FloatC0* __restrict__ p_c0_add_grid, // MxN
47 const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN
48 const FloatC0* __restrict__ p_c0_beta_grid, // 1xN
49 const AElementwiseOperation a_element_op,
50 const BElementwiseOperation b_element_op,
51 const AccElementwiseOperation acc_element_op,
52 const CElementwiseOperation c_element_op,
53 const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
54 const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
55 const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
56 c_grid_desc_mblock_mperblock_nblock_nperblock,
57 const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock,
58 const Block2CTileMap block_2_ctile_map)
59{
60#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
61 defined(__gfx12__)
// Parameter combinations rejected at compile time by IsValidCompilationParameter<>() produce
// an empty kernel body: there is no else branch.
62 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
63 {
// One LDS allocation per workgroup. Run() reuses it for the A/B tiles during the GEMM and
// afterwards for the C-shuffle tile plus the reduction workspace.
64 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
65
66 // TODO ANT: separate into MMA + Epilogue
67 GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
68 p_b_grid,
69 p_c_grid,
70 p_c0_bias_grid,
71 p_c0_add_grid,
72 p_c0_gamma_grid,
73 p_c0_beta_grid,
74 p_shared,
75 a_element_op,
76 b_element_op,
77 acc_element_op,
78 c_element_op,
79 a_grid_desc_ak0_m_ak1,
80 b_grid_desc_bk0_n_bk1,
81 c_grid_desc_mblock_mperblock_nblock_nperblock,
82 c0_grid_desc_nblock_nperblock,
83 block_2_ctile_map);
84 }
85 // TODO ANT: Run layernorm epilogue here
86#else
// Unsupported target: the kernel does no work; each argument is only assigned to `ignore`.
87 ignore = p_a_grid;
88 ignore = p_b_grid;
89 ignore = p_c_grid;
90 ignore = p_c0_bias_grid;
91 ignore = p_c0_add_grid;
92 ignore = p_c0_gamma_grid;
93 ignore = p_c0_beta_grid;
94 ignore = a_element_op;
95 ignore = b_element_op;
96 ignore = acc_element_op;
97 ignore = c_element_op;
98 ignore = a_grid_desc_ak0_m_ak1;
99 ignore = b_grid_desc_bk0_n_bk1;
100 ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
101 ignore = c0_grid_desc_nblock_nperblock;
102 ignore = block_2_ctile_map;
103#endif // end of if (defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || defined(__gfx12__))
104}
105
106// The GEMM + Layernorm implementation is a specialized kernel that fuses both layers together,
107// provided the GEMM's N extent (of M, N, K) is spanned by a single workgroup. For example, a
108// kernel configured with NPerBlock = 128 can operate on any GEMM size with N <= 128.
// Gridwise fused GEMM + layernorm. The template parameters fix the data types, the element-wise
// operators, the tile sizes (MPerBlock x NPerBlock x KPerBlock), the XDL layout
// (MPerXdl/NPerXdl, MXdlPerWave/NXdlPerWave), the A/B global->LDS copy layouts, the C-shuffle
// layout and the layernorm reduction thread layout.
109template <typename FloatAB,
110 typename FloatGemmAcc,
111 typename FloatCShuffle,
112 typename FloatC,
113 typename FloatC0,
114 typename FloatReduceAcc, // Data type after shuffle; bias/add, mean/variance, gamma/beta math is done in this type
115 typename AElementwiseOperation,
116 typename BElementwiseOperation,
117 typename AccElementwiseOperation,
118 typename CElementwiseOperation,
119 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
120 typename AGridDesc_AK0_M_AK1,
121 typename BGridDesc_BK0_N_BK1,
122 typename CGridDesc_M_N,
123 typename C0GridDesc_N,
124 index_t NumGemmKPrefetchStage,
125 index_t BlockSize,
126 index_t MPerBlock,
127 index_t NPerBlock,
128 index_t KPerBlock,
129 index_t AK1Value,
130 index_t BK1Value,
131 index_t MPerXdl,
132 index_t NPerXdl,
133 index_t MXdlPerWave,
134 index_t NXdlPerWave,
135 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
136 typename ABlockTransferThreadClusterArrangeOrder,
137 typename ABlockTransferSrcAccessOrder,
138 index_t ABlockTransferSrcVectorDim,
139 index_t ABlockTransferSrcScalarPerVector,
140 index_t ABlockTransferDstScalarPerVector_AK1,
141 bool AThreadTransferSrcResetCoordinateAfterRun,
142 index_t ABlockLdsExtraM,
143 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
144 typename BBlockTransferThreadClusterArrangeOrder,
145 typename BBlockTransferSrcAccessOrder,
146 index_t BBlockTransferSrcVectorDim,
147 index_t BBlockTransferSrcScalarPerVector,
148 index_t BBlockTransferDstScalarPerVector_BK1,
149 bool BThreadTransferSrcResetCoordinateAfterRun,
150 index_t BBlockLdsExtraN,
          // CheckValidity() static_asserts CShuffleNXdlPerWavePerShuffle == NXdlPerWave, which
          // makes every C-shuffle slice span the full NPerBlock columns.
151 index_t CShuffleMXdlPerWavePerShuffle,
152 index_t CShuffleNXdlPerWavePerShuffle,
153 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
154 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          // Thread layout of the layernorm reduction; Run() static_asserts that its M * N
          // product equals BlockSize.
155 typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
156 index_t CReduceThreadCopySrcDstScalarPerVector_NPerBlock,
157 LoopScheduler LoopSched,
160{
    // Compile-time dimension indices used with descriptor GetLength()/index tuples.
161 static constexpr auto I0 = Number<0>{};
162 static constexpr auto I1 = Number<1>{};
163 static constexpr auto I2 = Number<2>{};
164 static constexpr auto I3 = Number<3>{};
165 static constexpr auto I4 = Number<4>{};
166 static constexpr auto I5 = Number<5>{};
167 static constexpr auto I6 = Number<6>{};
168 static constexpr auto I7 = Number<7>{};
169
170 // K1 should be Number<...>
    // AK0/BK0 = number of AK1/BK1-wide K vectors per KPerBlock. This is integer division, so
    // KPerBlock is assumed to be divisible by AK1Value and BK1Value (no static_assert for this
    // is visible here — TODO confirm).
171 static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
172 static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
173 static constexpr auto AK1 = Number<AK1Value>{};
174 static constexpr auto BK1 = Number<BK1Value>{};
175
177
180
// Returns the LDS tensor descriptor of the A tile, indexed as (AK0PerBlock, MPerBlock, AK1)
// per the function name. GetSharedMemoryNumberOfByte() and Run() size the A region of LDS from
// this descriptor's GetElementSpaceSize().
// NOTE(review): whether the layout is padded (ABlockLdsExtraM is a template parameter) should
// be confirmed at the descriptor construction.
181 __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
182 {
183 // A matrix in LDS memory, dst of blockwise copy
187 }
188
// Returns the LDS tensor descriptor of the B tile, indexed as (BK0PerBlock, NPerBlock, BK1)
// per the function name. GetSharedMemoryNumberOfByte() and Run() size the B region of LDS
// (placed after the aligned A region) from this descriptor's GetElementSpaceSize().
// NOTE(review): whether the layout is padded (BBlockLdsExtraN is a template parameter) should
// be confirmed at the descriptor construction.
189 __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
190 {
191 // B matrix in LDS memory, dst of blockwise copy
195 }
196
// Builds the LDS descriptor of one C-shuffle slice, returned under the name
// c_shuffle_block_desc_mblock_mperblock_nblock_nperblock. It is presumably the descriptor that
// GetSharedMemoryNumberOfByte() and Run() obtain under that same variable name — TODO confirm.
// MWave / NWave are the numbers of waves along M and N inside the workgroup
// (block size divided by the per-wave XDL footprint).
197 __host__ __device__ static constexpr auto
199 {
200 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
201 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
202
203 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
207 I1,
209
210 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
211 }
212
// Returns the number of LDS bytes one workgroup needs (used to size the kernel's __shared__
// buffer). Run() places both the A/B tiles and the C-shuffle tile at the start of p_shared,
// and the GEMM pipeline finishes before the C shuffle starts. The two regions therefore alias,
// and the result is the max of:
//   (A + B) * sizeof(FloatAB), with each tile rounded up to a multiple of lcm(AK1, BK1)
//   elements; and
//   the C-shuffle tile rounded up to 16 bytes, plus BlockSize FloatReduceAcc entries of
//   workspace for the blockwise layernorm reduction (placed right after the C-shuffle tile).
213 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
214 {
215 // LDS allocation for A and B: be careful of alignment
216 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
217 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
218
219 // lds max alignment
220 constexpr auto max_lds_align = math::lcm(AK1, BK1);
221
222 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
223 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
224
225 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
226 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
227
228 // LDS allocation for C shuffle in LDS
229 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
231
232 // Align 16 bytes (maximum LDS read/write width)
233 constexpr auto c_block_size_aligned =
235 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() *
236 sizeof(FloatCShuffle),
237 16) /
238 sizeof(FloatCShuffle);
239
240 // LDS allocation for reduction workspace
241 constexpr index_t c_lds_workspace_size = BlockSize;
242
// max(), not sum: the A/B region and the C-shuffle + workspace region are used in
// different phases of Run() and share the same LDS bytes.
243 return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
244 sizeof(FloatAB),
245 c_block_size_aligned * sizeof(FloatCShuffle) +
246 c_lds_workspace_size * sizeof(FloatReduceAcc));
247 }
248
249 template <
250 InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
251 __device__ static bool constexpr IsValidCompilationParameter()
252 {
253 return ck::tensor_operation::device::IsValidGemmCompilationParameter<
254 BlockSize,
255 MPerBlock,
256 NPerBlock,
257 MPerXdl,
258 NPerXdl,
259 MXdlPerWave,
260 NXdlPerWave,
261 FloatC,
262 CGlobalMemoryDataOperation>();
263 }
264
265 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
266 template <typename Block2CTileMap>
267 __host__ __device__ static constexpr bool
268 CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
269 const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
270 const CGridDesc_M_N& c_grid_desc_m_n,
271 const Block2CTileMap& block_2_ctile_map)
272 {
273 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
274 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
275 "Invalid tuning param!");
276
277 const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
278 const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
279 const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
280
281 if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
282 return false;
283
284 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
285 return false;
286
287 // in order to reduce N dim without elaborate sync across CUs in single kernel, one
288 // workgroup must span the entire N extent
289 if(math::integer_divide_ceil(N, NPerBlock) > 1)
290 {
291 return false;
292 }
293
294 // static check: all waves in the workgroups combined must cover whole N extent in order
295 // to have efficient N-dim reduction
296 static_assert(CShuffleNXdlPerWavePerShuffle == NXdlPerWave,
297 "condition not met for efficient layernorm");
298
299 // check gridwise gemm pipeline
300 const auto num_k_loop = K / KPerBlock;
301
302 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
303 {
304 return false;
305 }
306
307 if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
308 {
309 return false;
310 }
311
312 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
313 return true;
314 }
315
316 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
317 {
318 const index_t num_loop = K / KPerBlock;
319
320 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
321 }
322
// Views the M x N C grid descriptor as (MBlock, MPerBlock, NBlock, NPerBlock), per the
// result's name. MBlock = M / MPerBlock and NBlock = N / NPerBlock use truncating division;
// CheckValidity() rejects M or N that are not multiples of the block sizes.
// Run() also uses the resulting descriptor to address the MxN c0_add tensor.
323 __host__ __device__ static constexpr auto
324 MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
325 {
326 const auto M = c_grid_desc_m_n.GetLength(I0);
327 const auto N = c_grid_desc_m_n.GetLength(I1);
328
329 const auto MBlock = M / MPerBlock;
330 const auto NBlock = N / NPerBlock;
331
332 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
333 c_grid_desc_m_n,
338
339 return c_grid_desc_mblock_mperblock_nblock_nperblock;
340 }
341
342 // for bias, beta, gamma
// Views the 1-D (N) descriptor shared by the bias, gamma and beta vectors as
// (NBlock, NPerBlock), per the result's name, with NBlock = N / NPerBlock (truncating).
// NOTE(review): CheckValidity() only checks N % NPerBlock == 0 for the C grid, not for this
// descriptor; it is assumed to have the same N.
// Run() further broadcasts this view along M.
343 __host__ __device__ static constexpr auto
344 MakeC0GridDescriptor_NBlock_NPerBlock(const C0GridDesc_N& c0_grid_desc_n)
345 {
346 const auto N = c0_grid_desc_n.GetLength(I0);
347 const auto NBlock = N / NPerBlock;
348
349 const auto c0_grid_desc_nblock_nperblock = transform_tensor_descriptor(
350 c0_grid_desc_n,
354
355 return c0_grid_desc_nblock_nperblock;
356 }
357
358 // return block_id to C matrix tile idx (m0, n0) mapping
// Builds the default Block2CTileMap for the given C grid. Run() uses such a map through
// CalculateBottomIndex() (workgroup id -> tile index) and ValidCTileIndex(); CheckValidity()
// uses it through CheckValidity().
359 __host__ __device__ static constexpr auto
360 MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
361 {
363 c_grid_desc_m_n);
364 }
365
368 CGridDesc_M_N{}))>;
369
371 remove_cvref_t<decltype(MakeC0GridDescriptor_NBlock_NPerBlock(C0GridDesc_N{}))>;
372
374 remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
// The type aliases above apply decltype to the Make* helpers with default-constructed input
// descriptors. They therefore name the exact descriptor / block-to-C-tile-map types those
// helpers return.
375
// Device-side body executed by every workgroup (called from the kernel entry point):
//   1. Maps get_block_1d_id() to a C tile (m0, n0) and returns early for invalid tiles.
//   2. Runs the XDL GEMM pipeline over K, leaving the MPerBlock x NPerBlock accumulator tile in
//      registers (c_thread_buf).
//   3. For each C-shuffle slice:
//      - writes the accumulators to LDS;
//      - reads them back in the reduction thread layout as FloatReduceAcc;
//      - computes acc_element_op(acc + bias) + add;
//      - normalizes every row by its mean and variance over N (dividing by NRaw);
//      - applies gamma and beta;
//      - writes the slice to C through c_element_op.
// p_shared must point to at least GetSharedMemoryNumberOfByte() bytes of LDS.
376 template <bool HasMainKBlockLoop, typename Block2CTileMap>
377 __device__ static void
378 Run(const FloatAB* __restrict__ p_a_grid,
379 const FloatAB* __restrict__ p_b_grid,
380 FloatC* __restrict__ p_c_grid,
381 const FloatC0* __restrict__ p_c0_bias_grid, // 1xN
382 const FloatC0* __restrict__ p_c0_add_grid, // MxN
383 const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN
384 const FloatC0* __restrict__ p_c0_beta_grid, // 1xN
385 void* __restrict__ p_shared,
386 const AElementwiseOperation& a_element_op,
387 const BElementwiseOperation& b_element_op,
388 const AccElementwiseOperation& acc_element_op,
389 const CElementwiseOperation& c_element_op,
390 const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
391 const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
393 c_grid_desc_mblock_mperblock_nblock_nperblock,
394 const C0GridDescriptor_NBlock_NPerBlock& c0_grid_desc_nblock_nperblock,
395 const Block2CTileMap& block_2_ctile_map)
396 {
397 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
398 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
399 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
400 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
402 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
403 auto c0_bias_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
404 p_c0_bias_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize());
405 // Note: c0_add is of same layout as c so we don't declare new c0_add_desc here
407 p_c0_add_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
408 auto c0_gamma_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
409 p_c0_gamma_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize());
410 auto c0_beta_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
411 p_c0_beta_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize());
412
413 // divide block work by [M, N]
414 const auto block_work_idx =
415 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
416
// Workgroups whose tile index falls outside the C grid have nothing to compute.
417 if(!block_2_ctile_map.ValidCTileIndex(
418 block_work_idx,
419 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
420 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
421 {
422 return;
423 }
424
425 // HACK: this force m/n_block_data_idx_on_grid into SGPR
426 const index_t m_block_data_idx_on_grid =
427 __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
428
429 const index_t n_block_data_idx_on_grid =
430 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
431
432 // lds max alignment
433 constexpr auto max_lds_align = math::lcm(AK1, BK1);
434
435 // A matrix in LDS memory, dst of blockwise copy
436 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
437
438 // B matrix in LDS memory, dst of blockwise copy
439 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
440
441 // A matrix blockwise copy
442 auto a_blockwise_copy =
444 AElementwiseOperation,
448 ABlockTransferThreadClusterLengths_AK0_M_AK1,
449 ABlockTransferThreadClusterArrangeOrder,
450 FloatAB,
451 FloatAB,
452 decltype(a_grid_desc_ak0_m_ak1),
453 decltype(a_block_desc_ak0_m_ak1),
454 ABlockTransferSrcAccessOrder,
456 ABlockTransferSrcVectorDim,
457 2,
458 ABlockTransferSrcScalarPerVector,
459 ABlockTransferDstScalarPerVector_AK1,
460 1,
461 1,
462 AThreadTransferSrcResetCoordinateAfterRun,
463 true,
464 NumGemmKPrefetchStage>(
465 a_grid_desc_ak0_m_ak1,
466 make_multi_index(0, m_block_data_idx_on_grid, 0),
467 a_element_op,
468 a_block_desc_ak0_m_ak1,
469 make_multi_index(0, 0, 0),
471
472 // B matrix blockwise copy
473 auto b_blockwise_copy =
475 BElementwiseOperation,
479 BBlockTransferThreadClusterLengths_BK0_N_BK1,
480 BBlockTransferThreadClusterArrangeOrder,
481 FloatAB,
482 FloatAB,
483 decltype(b_grid_desc_bk0_n_bk1),
484 decltype(b_block_desc_bk0_n_bk1),
485 BBlockTransferSrcAccessOrder,
487 BBlockTransferSrcVectorDim,
488 2,
489 BBlockTransferSrcScalarPerVector,
490 BBlockTransferDstScalarPerVector_BK1,
491 1,
492 1,
493 BThreadTransferSrcResetCoordinateAfterRun,
494 true,
495 NumGemmKPrefetchStage>(
496 b_grid_desc_bk0_n_bk1,
497 make_multi_index(0, n_block_data_idx_on_grid, 0),
498 b_element_op,
499 b_block_desc_bk0_n_bk1,
500 make_multi_index(0, 0, 0),
502
503 // GEMM definition
504 // c_mtx += transpose(a_mtx) * b_mtx
505 // a_mtx[K0PerBlock, MPerBlock] is in LDS
506 // b_mtx[K0PerBlock, NPerBlock] is in LDS
507 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
508 // register
509 // sanity check
510 constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
511 constexpr bool is_single_rate_mfma =
513 lcm_AK1_BK1 <= 4) ||
514 (is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
516 lcm_AK1_BK1 < 32))
517 ? true
518 : false;
519 constexpr auto is_scale_mfma = false;
520 constexpr index_t KPack = math::max(
521 lcm_AK1_BK1,
523 selected_mfma.k_per_blk);
524
526 BlockSize,
527 FloatAB,
528 FloatAB,
529 FloatGemmAcc,
530 decltype(a_block_desc_ak0_m_ak1),
531 decltype(b_block_desc_bk0_n_bk1),
532 MPerXdl,
533 NPerXdl,
534 MXdlPerWave,
535 NXdlPerWave,
536 KPack,
537 LoopSched>();
538
539 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
540
541 // LDS allocation for A and B: be careful of alignment
542 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
543 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
544
546 static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
547
549 static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
550 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
551
552 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
553 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
554
555 // gridwise GEMM pipeline
556 const auto gridwise_gemm_pipeline =
558
559 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
560 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
561 KPerBlock);
562
563 gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
564 a_block_desc_ak0_m_ak1,
565 a_blockwise_copy,
566 a_grid_buf,
567 a_block_buf,
568 a_block_slice_copy_step,
569 b_grid_desc_bk0_n_bk1,
570 b_block_desc_bk0_n_bk1,
571 b_blockwise_copy,
572 b_grid_buf,
573 b_block_buf,
574 b_block_slice_copy_step,
575 blockwise_gemm,
576 c_thread_buf,
577 num_k_block_main_loop);
578
579 // shuffle C and write out
580 {
581 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
582 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
583 "wrong!");
584
585 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
586 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
587
588 // TODO: hacky, fix it!
589 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
590 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
591
592 // TODO: hacky, fix it!
593 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
594 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
595 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
596
597 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
598 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
599 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
600 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
601 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
602 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
603 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
604 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
605
606 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
608
609 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
610 static_cast<FloatCShuffle*>(p_shared),
611 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
612
613 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
614 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
618 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
619 M1, // M1 = MWave
620 M2, // M2 * M3 * M4 = MPerXdl
621 M3,
622 M4)),
625 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
626 N1, // N1 = NWave
627 N2))), // N2 = NPerXdl
631
632 // calculate origin of thread output tensor on global memory
633 // blockwise GEMM c matrix starting index
634 const auto c_thread_mtx_on_block =
635 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
636
637 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
638 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
639
640 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
642 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
645
646 const auto m_thread_data_on_block_idx =
647 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
648 make_multi_index(m_thread_data_on_block));
649
650 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
655
656 const auto n_thread_data_on_block_idx =
657 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
658 make_multi_index(n_thread_data_on_block));
659
660 // shuffle: threadwise copy C from VGPR to LDS
661 auto c_thread_copy_vgpr_to_lds =
663 FloatCShuffle,
664 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
665 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
667 Sequence<CShuffleMXdlPerWavePerShuffle,
668 CShuffleNXdlPerWavePerShuffle,
669 I1,
670 I1,
671 M2,
672 I1,
673 M4,
674 I1>,
676 7,
677 1,
679 1,
680 true>{
681 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
683 0,
684 m_thread_data_on_block_idx[I1],
685 n_thread_data_on_block_idx[I1],
686 m_thread_data_on_block_idx[I2],
687 m_thread_data_on_block_idx[I3],
688 m_thread_data_on_block_idx[I4],
689 n_thread_data_on_block_idx[I2]),
691
692 // shuffle: blockwise copy C from LDS to global
693 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
694 ThisThreadBlock, // ThreadGroup
695 CElementwiseOperation, // ElementwiseOperation,
696 CGlobalMemoryDataOperation, // DstInMemOp,
697 Sequence<1,
698 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
699 1,
700 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
701 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
702 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
703 FloatCShuffle, // typename SrcData,
704 FloatC, // typename DstData,
705 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
706 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
707 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
708 3, // index_t VectorDim,
709 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
710 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
711 false> // bool ThreadTransferDstResetCoordinateAfterRun>
712 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
713 make_multi_index(0, 0, 0, 0),
714 c_grid_desc_mblock_mperblock_nblock_nperblock,
715 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
716 c_element_op};
717
718 const auto NBlock = c0_grid_desc_nblock_nperblock.GetLength(I0);
719
720 // for broadcasting bias, beta, gamma
721 const auto c0_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
722 c0_grid_desc_nblock_nperblock,
726 make_pass_through_transform(NPerBlock)),
729
730 // LDS c_reduce_block_desc_mperblock_nperblock
731 constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor(
732 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
736 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)),
739 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))),
742
743 static_assert(CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) *
744 CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) ==
745 BlockSize,
746 "wrong!");
747
748 static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) %
749 CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) ==
750 0 &&
751 (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) %
752 CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) ==
753 0,
754 "wrong!");
755
756 constexpr index_t mreduce_per_thread =
757 (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) /
758 CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0);
759
760 constexpr index_t nreduce_per_thread =
761 (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) /
762 CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1);
763
764 constexpr auto c_reduce_thread_lengths_mperblock_nperblock =
766
767 // pytorch default
768 // https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
769 static constexpr FloatReduceAcc epsilon = 1e-5;
770
771 // VGPR c_reduce_thread_desc_mperblock_nperblock
772 constexpr auto c_reduce_thread_desc_mperblock_nperblock =
775
776 constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
779
780 // VGPR d_reduce_thread_desc_mperblock
781 constexpr auto d_reduce_thread_desc_mperblock =
783
784 // TODO: this should be implemented as a blockwise reduction
786 c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
787
789 c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
790
791 // Align 16 bytes (maximum LDS read/write width)
792 constexpr auto c_block_size_aligned =
794 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() *
795 sizeof(FloatCShuffle),
796 16) /
797 sizeof(FloatCShuffle);
798
// Blockwise-reduction workspace: BlockSize FloatReduceAcc values in LDS, right after the
// 16-byte-aligned C-shuffle tile (GetSharedMemoryNumberOfByte() reserves this space).
799 auto d_reduce_work_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
800 reinterpret_cast<FloatReduceAcc*>(static_cast<FloatCShuffle*>(p_shared) +
801 c_block_size_aligned),
802 BlockSize);
803
804 // Sum thread workspace
806 d_reduce_thread_desc_mperblock.GetElementSpaceSize());
807
808 // Squared sum thread workspace
810 d_reduce_thread_desc_mperblock.GetElementSpaceSize());
811
812 // reduce: threadwise copy from LDS to VGPR
813 constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor(
814 CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{});
815
816 const auto c_reduce_thread_cluster_idx =
817 c_reduce_thread_cluster_desc.CalculateBottomIndex(
819
820 const auto c_reduce_thread_data_idx_begin =
821 c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock;
822
823 auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
824 FloatCShuffle,
825 FloatReduceAcc,
826 decltype(c_reduce_block_desc_mperblock_nperblock),
827 decltype(c_reduce_thread_desc_mperblock_nperblock),
828 decltype(c_reduce_thread_lengths_mperblock_nperblock),
830 1,
831 CReduceThreadCopySrcDstScalarPerVector_NPerBlock,
832 1,
833 true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};
834
835 auto c_reduce_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
836 FloatReduceAcc,
837 FloatCShuffle,
838 decltype(c_reduce_thread_desc_mperblock_nperblock),
839 decltype(c_reduce_block_desc_mperblock_nperblock),
841 decltype(c_reduce_thread_lengths_mperblock_nperblock),
843 1,
844 CReduceThreadCopySrcDstScalarPerVector_NPerBlock,
846 1,
847 true>{c_reduce_block_desc_mperblock_nperblock,
848 c_reduce_thread_data_idx_begin,
850
851 auto c0_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
852 FloatC0,
853 FloatC0,
854 decltype(c0_grid_desc_mblock_mperblock_nblock_nperblock),
855 decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
858 3,
859 CReduceThreadCopySrcDstScalarPerVector_NPerBlock,
860 1,
861 true>(c0_grid_desc_mblock_mperblock_nblock_nperblock,
862 make_multi_index(block_work_idx[I0],
863 c_reduce_thread_data_idx_begin[I0],
864 block_work_idx[I1],
865 c_reduce_thread_data_idx_begin[I1]));
866
867 // Note: c0_add is of same layout as c so we don't declare new c0_add_desc here
868 auto c0_add_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
869 FloatC0,
870 FloatC0,
871 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
872 decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
875 3,
876 CReduceThreadCopySrcDstScalarPerVector_NPerBlock,
877 1,
878 true>(c_grid_desc_mblock_mperblock_nblock_nperblock,
879 make_multi_index(block_work_idx[I0],
880 c_reduce_thread_data_idx_begin[I0],
881 block_work_idx[I1],
882 c_reduce_thread_data_idx_begin[I1]));
883
884 // space filling curve for threadwise C in VGPR
885 constexpr auto sfc_c_vgpr =
888 Sequence<CShuffleMXdlPerWavePerShuffle,
889 CShuffleNXdlPerWavePerShuffle,
890 1,
891 1,
892 M2,
893 1,
894 M4,
895 1>>{};
896
897 // space filling curve for shuffled blockwise C in global mem
898 constexpr auto sfc_c_global =
901 Sequence<1,
902 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
903 1,
904 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
905
906 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
907
908 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
909
// Assume CheckValidity()'s static_asserts hold: CShuffleNXdlPerWavePerShuffle == NXdlPerWave,
// and NPerBlock is a multiple of NXdlPerWave * NPerXdl. Then each slice below spans
// CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl == NPerBlock columns. Since N also fits in
// one NPerBlock tile, every row of a slice is complete, which is what allows the per-row
// mean/variance to be computed one slice at a time.
910 static_for<0, num_access, 1>{}([&](auto access_id) {
911 // make sure it's safe to write to LDS
913
914 // each thread write its data from VGPR to LDS
915 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
916 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
917 c_thread_buf,
918 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
919 c_shuffle_block_buf);
920
922
923 // load from LDS and global, add bias
924 c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
925 c_shuffle_block_buf,
926 c_reduce_thread_desc_mperblock_nperblock,
927 make_tuple(I0, I0),
928 c_reduce_thread_buf);
929
930 c0_thread_copy_global_to_vgpr.Run(
931 c0_grid_desc_mblock_mperblock_nblock_nperblock,
932 c0_bias_grid_buf,
933 c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
934 make_tuple(I0, I0, I0, I0),
935 c0_thread_buf);
936
937 static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
938 [&](auto i) {
939 FloatReduceAcc out;
940 acc_element_op(out,
941 c_reduce_thread_buf(i) +
942 static_cast<FloatReduceAcc>(c0_thread_buf(i)));
943 c_reduce_thread_buf(i) = out; // acc_element_op(acc + bias)
944 });
945
946 c0_add_thread_copy_global_to_vgpr.Run(
947 c_grid_desc_mblock_mperblock_nblock_nperblock,
948 c0_add_grid_buf,
949 c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
950 make_tuple(I0, I0, I0, I0),
951 c0_thread_buf);
952
953 static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
954 [&](auto i) {
955 c_reduce_thread_buf(i) +=
956 static_cast<FloatReduceAcc>(c0_thread_buf(i)); // add
957 });
958
959 // layernorm
960 {
961 using ThreadwiseReduceD0 =
962 ThreadwiseReduction<FloatReduceAcc,
963 decltype(c_reduce_thread_desc_mperblock_nperblock),
964 decltype(d_reduce_thread_desc_mperblock),
966 false>;
967 using ThreadwiseReduceD1 =
968 ThreadwiseReduction<FloatReduceAcc,
969 decltype(c_reduce_thread_desc_mperblock_nperblock),
970 decltype(d_reduce_thread_desc_mperblock),
972 false>;
973
974 const auto d0_zeroVal =
975 ThreadwiseReduceD0::Op::template GetIdentityValue<FloatReduceAcc>();
976 const auto d1_zeroVal =
977 ThreadwiseReduceD1::Op::template GetIdentityValue<FloatReduceAcc>();
979 [&](auto i) { d0_thread_buf(i) = d0_zeroVal; });
981 [&](auto i) { d1_thread_buf(i) = d1_zeroVal; });
982
983 // reduce sum in VGPR
984 ThreadwiseReduceD0::Reduce(c_reduce_thread_buf, d0_thread_buf);
985
986 // reduce squared sum in VGPR
987 ThreadwiseReduceD1::Reduce(c_reduce_thread_buf, d1_thread_buf);
988
989 // reduce within workgroup
990 using BlockwiseReduce = PartitionedBlockwiseReduction<
991 FloatReduceAcc,
992 BlockSize,
993 CReduceThreadClusterLengths_MPerBlock_NPerBlock, // ThreadClusterLengths_M_K
994 Sequence<1, 0>, // ThreadClusterArrangeOrder
996 false>;
997
1000 BlockwiseReduce::Reduce(d_reduce_work_buf,
1001 d0_thread_buf(i)); // blockwise reduced sum
1003 BlockwiseReduce::Reduce(d_reduce_work_buf,
1004 d1_thread_buf(i)); // blockwise reduced squared sum
1005 });
1006
// NOTE(review): the sums above run over every column of the slice, while the mean below
// divides by NRaw (read from the C descriptor's first transform). If the C descriptor is
// padded along N, this is only correct when padded columns are zero after
// acc_element_op(acc + bias) + add — confirm.
1007 // normalize
1008 const index_t NRaw =
1009 c_grid_desc_mblock_mperblock_nblock_nperblock.GetTransforms()[I0]
1010 .GetUpperLengths()[I1]; // TODO: proper handle
1011
1014 constexpr auto dst_offset =
1015 Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
1016 make_tuple(im, in))>{};
1017
1018 constexpr auto src_offset =
1019 Number<d_reduce_thread_desc_mperblock.CalculateOffset(
1020 make_tuple(im))>{};
1021
1022 FloatReduceAcc avg_sum = d0_thread_buf(src_offset) / NRaw;
1023 FloatReduceAcc avg_squared_sum = d1_thread_buf(src_offset) / NRaw;
1024
// Variance is formed in a single pass as E[x^2] - E[x]^2. NOTE(review): this can lose
// precision through cancellation when |mean| is large relative to the standard deviation.
1025 FloatReduceAcc numerator = c_reduce_thread_buf(dst_offset) - avg_sum;
1026 FloatReduceAcc divisor = epsilon + avg_squared_sum - avg_sum * avg_sum;
1027 FloatReduceAcc divisor_sqrt;
1028 tensor_operation::element_wise::UnarySqrt{}(divisor_sqrt, divisor);
1029
1030 c_reduce_thread_buf(dst_offset) = numerator / divisor_sqrt;
1031 });
1032 });
1033
// The same c0 copy object, with its source window unchanged, is reused for bias, gamma and
// beta; the window is only advanced after the slice has been written out.
1034 // scaling
1035 c0_thread_copy_global_to_vgpr.Run(
1036 c0_grid_desc_mblock_mperblock_nblock_nperblock,
1037 c0_gamma_grid_buf,
1038 c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
1039 make_tuple(I0, I0, I0, I0),
1040 c0_thread_buf);
1041
1042 static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
1043 [&](auto i) {
1044 c_reduce_thread_buf(i) *=
1045 static_cast<FloatReduceAcc>(c0_thread_buf(i)); // * gamma
1046 });
1047
1048 c0_thread_copy_global_to_vgpr.Run(
1049 c0_grid_desc_mblock_mperblock_nblock_nperblock,
1050 c0_beta_grid_buf,
1051 c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
1052 make_tuple(I0, I0, I0, I0),
1053 c0_thread_buf);
1054
1055 static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
1056 [&](auto i) {
1057 c_reduce_thread_buf(i) +=
1058 static_cast<FloatReduceAcc>(c0_thread_buf(i)); // + beta
1059 });
1060
1062
1063 c_reduce_thread_copy_vgpr_to_lds.Run(c_reduce_thread_desc_mperblock_nperblock,
1064 make_tuple(I0, I0),
1065 c_reduce_thread_buf,
1066 c_reduce_block_desc_mperblock_nperblock,
1067 c_shuffle_block_buf);
1068
1069 } // end layernorm
1070
1072
1073 // each block copy its data from LDS to global
1074 c_shuffle_block_copy_lds_to_global.Run(
1075 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1076 c_shuffle_block_buf,
1077 c_grid_desc_mblock_mperblock_nblock_nperblock,
1078 c_grid_buf);
1079
1080 if constexpr(access_id < num_access - 1)
1081 {
1082 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1083
1084 // move on C
1085 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1086 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1087
1088 // move on C0
1089 c0_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
1090 c0_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1091
1092 // move on C0_add
1093 c0_add_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
1094 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1095 }
1096 });
1097 }
1098 }
1099};
1100
1101} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
Definition blockwise_gemm_xdlops.hpp:620
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
constexpr auto GridwiseGemmPipeline_Selector()
Definition gridwise_gemm_pipeline_selector.hpp:31
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_insert_transform(const UpperIndex &up_idx)
Definition multi_index_transform_helper.hpp:157
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__global__ void kernel_gemm_layernorm_xdl_cshuffle_v1(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const FloatC0 *__restrict__ p_c0_bias_grid, const FloatC0 *__restrict__ p_c0_add_grid, const FloatC0 *__restrict__ p_c0_gamma_grid, const FloatC0 *__restrict__ p_c0_beta_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const AccElementwiseOperation acc_element_op, const CElementwiseOperation c_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp:41
Definition block_to_ctile_map.hpp:261
Definition gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp:160
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp:268
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp:366
static __device__ void Run(const ADataType *__restrict__ p_a_grid, const ADataType *__restrict__ p_b_grid, CDataType *__restrict__ p_c_grid, const C0DataType *__restrict__ p_c0_bias_grid, const C0DataType *__restrict__ p_c0_add_grid, const C0DataType *__restrict__ p_c0_gamma_grid, const C0DataType *__restrict__ p_c0_beta_grid, void *__restrict__ p_shared, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const AccElementwiseOperation &acc_element_op, const CElementwiseOperation &c_element_op, const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock, const C0GridDescriptor_NBlock_NPerBlock &c0_grid_desc_nblock_nperblock, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp:378
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition reduction_functions_blockwise.hpp:28
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition reduction_functions_threadwise.hpp:23
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
__device__ void Run(const SrcDesc &src_desc, const SrcBuffer &src_buf, const DstDesc &, const DstSliceOriginIdx &, DstBuffer &dst_buf)
Definition threadwise_tensor_slice_transfer.hpp:276
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &src_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer.hpp:389
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition reduction_operator.hpp:37
Definition reduction_operator.hpp:87
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:797