blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp Source File

blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp Source File
blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Compute optimized pipeline
11// GlobalPrefetchStages: 2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 2
14// LocalSharedMemoryBuffer: 2
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeDataType,
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MScaleBlock,
32 index_t NScaleBlock,
33 index_t KScaleBlock,
34 index_t MPerXDL,
35 index_t NPerXDL,
36 index_t MRepeat,
37 index_t NRepeat,
38 index_t KPacks>
42
43template <index_t BlockSize,
44 typename ADataType,
45 typename BDataType,
46 typename ComputeDataType,
47 typename AccDataType,
48 typename ATileDesc,
49 typename BTileDesc,
50 typename AMmaTileDesc,
51 typename BMmaTileDesc,
52 index_t ABlockTransferSrcScalarPerVector,
53 index_t BBlockTransferSrcScalarPerVector,
54 index_t MPerBlock,
55 index_t NPerBlock,
56 index_t KPerBlock,
57 index_t MScaleBlock,
58 index_t NScaleBlock,
59 index_t KScaleBlock,
60 index_t MPerXDL,
61 index_t NPerXDL,
62 index_t MRepeat,
63 index_t NRepeat,
64 index_t KPack
65 // ,bool TransposeC //disable transposec right now...
66 >
68 BlockSize,
69 ADataType,
70 BDataType,
71 ComputeDataType,
72 AccDataType,
73 ATileDesc,
74 BTileDesc,
75 AMmaTileDesc,
76 BMmaTileDesc,
77 ABlockTransferSrcScalarPerVector,
78 BBlockTransferSrcScalarPerVector,
79 MPerBlock,
80 NPerBlock,
81 KPerBlock,
82 MScaleBlock,
83 NScaleBlock,
84 KScaleBlock,
85 MPerXDL,
86 NPerXDL,
87 MRepeat,
88 NRepeat,
89 KPack>
91 ADataType,
92 BDataType,
93 ComputeDataType,
94 AccDataType,
95 ATileDesc,
96 BTileDesc,
97 AMmaTileDesc,
98 BMmaTileDesc,
99 ABlockTransferSrcScalarPerVector,
100 BBlockTransferSrcScalarPerVector,
101 MPerBlock,
102 NPerBlock,
103 KPerBlock,
104 MPerXDL,
105 NPerXDL,
106 MRepeat,
107 NRepeat,
108 KPack,
109 true>
110
111{
113 ADataType,
114 BDataType,
115 ComputeDataType,
116 AccDataType,
117 ATileDesc,
118 BTileDesc,
119 AMmaTileDesc,
120 BMmaTileDesc,
121 ABlockTransferSrcScalarPerVector,
122 BBlockTransferSrcScalarPerVector,
123 MPerBlock,
124 NPerBlock,
125 KPerBlock,
126 MPerXDL,
127 NPerXDL,
128 MRepeat,
129 NRepeat,
130 KPack,
131 true>;
132 using Base::A_K1;
133 using Base::B_K1;
134 using Base::I0;
135 using Base::I1;
136 using Base::I2;
137 using Base::KGroup;
138 using Base::KRepeat;
139 using Base::xdlops_gemm;
140 using typename Base::HotLoopInstList;
141
154 using Base::MWaves;
155 using Base::WaveSize;
156
157 static constexpr index_t PrefetchStages = 2;
158 static constexpr index_t LocalPrefetchStages = 2;
159 static constexpr index_t PrefillStages = 1;
160 static constexpr index_t GlobalBufferNum = 1;
161 static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;
162
163 template <typename TileDesc_M0_M1_M2_K>
164 __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
165 {
166 constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
167 constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
168 constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
169 constexpr index_t K2 = KPack / KGroup;
170 constexpr index_t K1 = WaveSize / NPerXDL;
171 constexpr index_t K0 = KRepeat * KGroup;
172
174 TileDesc_M0_M1_M2_K{},
182 }
183
184 static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
186
187 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
188 {
189 return num_loop > PrefetchStages;
190 }
191
192 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
193 {
194 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
195 }
196
197 __device__ static constexpr auto HotLoopScheduler()
198 {
199 // A/B split schedule
200 // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
201 constexpr auto num_ds_read_inst_a =
202 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
205
206 constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
207
208 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
209 constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves;
210
211 static_assert(num_buffer_load_inst_a == num_ds_write_inst_a);
212
213 constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
214 constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
215
216 constexpr auto ds_read_a_issue_cycle =
217 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
218 constexpr auto ds_read_a_mfma_rate =
219 math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle);
220
221 // constexpr auto num_dsread_a_mfma =
222 // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
223
224 constexpr auto num_total_stages = MRepeat;
225
226 // Group num_mfma_perstage num_ds_read_a_perstage
227 // since we want to reuse a local register buffer
228 constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
229 constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
230
231 constexpr auto num_ds_read_a_mfma_perstage =
232 math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
233
234 constexpr auto buffer_load_perstage_more =
235 math::integer_divide_ceil((num_buffer_load_inst_a + num_buffer_load_inst_b),
236 (num_total_stages - (LocalPrefetchStages - 1)));
237 constexpr auto buffer_load_perstage_less =
238 math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b),
239 (num_total_stages - (LocalPrefetchStages - 1)));
240
241 constexpr auto buffer_load_stages_more =
242 (num_buffer_load_inst_a + num_buffer_load_inst_b) -
243 math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b),
244 (num_total_stages - (LocalPrefetchStages - 1))) *
245 ((num_total_stages - (LocalPrefetchStages - 1)));
246
247 constexpr auto buffer_load_b_stages =
248 buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b
249 ? num_buffer_load_inst_b / buffer_load_perstage_more
250 : (buffer_load_stages_more +
251 (num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) /
252 buffer_load_perstage_less);
253
254 constexpr auto buffer_load_a_stages =
255 num_total_stages - (LocalPrefetchStages - 1) - buffer_load_b_stages;
256
257 constexpr auto buffer_load_issue_point_b = 0;
258 constexpr auto buffer_load_issue_point_interval_more =
259 num_mfma_perstage / buffer_load_perstage_more
260 ? num_mfma_perstage / buffer_load_perstage_more
261 : 1;
262 constexpr auto buffer_load_issue_point_interval_less =
263 num_mfma_perstage / buffer_load_perstage_less
264 ? num_mfma_perstage / buffer_load_perstage_less
265 : 1;
266 constexpr auto ds_write_issue_point = 0;
267 constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0;
268
269 // B global read
271 static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
272 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
273
274 if constexpr(((i < buffer_load_stages_more) &&
275 (imfma % buffer_load_issue_point_interval_more ==
276 buffer_load_issue_point_b)) ||
277 ((i >= buffer_load_stages_more) &&
278 (imfma % buffer_load_issue_point_interval_less ==
279 buffer_load_issue_point_b)))
280 {
281 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
282 }
283
284 if constexpr((imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage - 1)) &&
285 (imfma < (num_mfma_perstage - 1)))
286 {
287 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
288 }
289 __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma
290 // __builtin_amdgcn_sched_group_barrier(0x1000, 4, 0); // v_fmac
291 });
292 // Scale load, 1B
293 if constexpr(i.value == 0)
294 {
295 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
296 }
297 // Scale load, 1A
298 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
299 // __builtin_amdgcn_sched_barrier(0);
300 });
301
302 // A global read + A local write
304 static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
305 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
306 if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
307 (imfma % buffer_load_issue_point_interval_more ==
308 ds_write_issue_point)) ||
309 (((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
310 (imfma % buffer_load_issue_point_interval_less ==
311 ds_write_issue_point)))
312 {
313 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
314 }
315 if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
316 (imfma % buffer_load_issue_point_interval_more ==
317 buffer_load_issue_point_a)) ||
318 (((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
319 (imfma % buffer_load_issue_point_interval_less ==
320 buffer_load_issue_point_a)))
321 {
322 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
323 }
324 if constexpr((imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage - 1)) &&
325 (imfma < (num_mfma_perstage - 1)))
326 {
327 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
328 }
329 __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma
330 // __builtin_amdgcn_sched_group_barrier(0x1000, 4, 0); // v_fmac
331 });
332 // Scale load, 1A
333 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
334 // __builtin_amdgcn_sched_barrier(0);
335 });
336
337 // lds synchronization, prefetch next loop local A
338 static_for<0, (LocalPrefetchStages - 1), 1>{}([&](auto i) {
339 ignore = i;
340 static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
341 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
342
343 if constexpr((imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage - 1)) &&
344 (imfma < (num_mfma_perstage - 1)))
345 {
346 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
347 }
348 __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma
349 // __builtin_amdgcn_sched_group_barrier(0x1000, 4, 0); // v_fmac
350 });
351 // Scale load, 1A
352 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
353 // __builtin_amdgcn_sched_barrier(0);
354 });
355 }
356
357 template <bool HasMainLoop,
358 int NumKBlockPerScale,
359 TailNumber TailNum,
360 typename AGridDesc,
361 typename ABlockDesc,
362 typename ABlockTransfer,
363 typename AGridBuffer,
364 typename ABlockBuffer,
365 typename ABlockTransferStep,
366 typename BGridDesc,
367 typename BBlockDesc,
368 typename BBlockTransfer,
369 typename BGridBuffer,
370 typename BBlockBuffer,
371 typename BBlockTransferStep,
372 typename CScaleThreadDesc,
373 typename CThreadBuffer,
374 typename AScaleGridBuffer,
375 typename AScaleGridDesc,
376 typename AScaleThreadDesc,
377 typename AScaleThreadTransfer,
378 typename AScaleThreadTransferStep,
379 typename BScaleGridBuffer,
380 typename BScaleGridDesc,
381 typename BScaleThreadDesc,
382 typename BScaleThreadTransfer,
383 typename BScaleThreadTransferStep>
384 __device__ void Run(
385 // ABlockCopy
386 const AGridDesc& a_grid_desc,
387 const ABlockDesc& a_block_desc,
388 ABlockTransfer& a_blockwise_copy,
389 const AGridBuffer& a_grid_buf,
390 ABlockBuffer& a_block_buf,
391 const ABlockTransferStep& a_block_copy_step,
392 // BBlockCopy
393 const BGridDesc& b_grid_desc,
394 const BBlockDesc& b_block_desc,
395 BBlockTransfer& b_blockwise_copy,
396 const BGridBuffer& b_grid_buf,
397 BBlockBuffer& b_block_buf,
398 const BBlockTransferStep& b_block_copy_step,
399 // CThread
400 const CScaleThreadDesc& c_scale_thread_desc,
401 CThreadBuffer& c_thread_buf,
402 // AScaleThreadCopy
403 const AScaleGridDesc& a_scale_grid_desc,
404 const AScaleThreadDesc& a_scale_thread_desc,
405 AScaleThreadTransfer& a_scale_thread_copy,
406 const AScaleGridBuffer& a_scale_grid_buf,
407 const AScaleThreadTransferStep& a_scale_thread_copy_step,
408 // BScaleThreadCopy
409 const BScaleGridDesc& b_scale_grid_desc,
410 const BScaleThreadDesc& b_scale_thread_desc,
411 BScaleThreadTransfer& b_scale_thread_copy,
412 const BScaleGridBuffer& b_scale_grid_buf,
413 const BScaleThreadTransferStep& b_scale_thread_copy_step,
414 // num_loop
415 index_t num_loop) const
416 {
417 ignore = b_block_desc;
418 ignore = b_block_buf;
419 __builtin_amdgcn_sched_barrier(0);
420 static_assert(CScaleThreadDesc{}.GetLength(Number<0>{}) == 1,
421 "Pipeline v3 only support scaleblocksliceK=1");
422 static_assert(CScaleThreadDesc{}.GetLength(Number<2>{}) == 1,
423 "Pipeline v3 only support scaleblocksliceN=1");
424 // assume kperblock = scaleblockk
426 a_thread_desc_.GetElementSpaceSize());
428 b_thread_desc_.GetElementSpaceSize());
429
430 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
431 constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
432
434 a_scale_thread_desc.GetElementSpaceSize());
436 b_scale_thread_desc.GetElementSpaceSize());
438 c_scale_thread_desc.GetElementSpaceSize());
439
440 StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
441 StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
442 // StaticallyIndexedArray<decltype(c_scale_thread_buf), Number<2>{}> c_scale_thread_bufs;
443
444 // Global prefetch A1 B1, AScale1 BScale1
445 b_blockwise_copy.Run(b_grid_desc,
446 b_grid_buf,
448 b_block_origin_idx,
449 b_thread_bufs(I0));
450 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
451
452 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
453 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
454 __builtin_amdgcn_sched_barrier(0);
455
456 static_for<0, MRepeat, 1>{}([&](auto m0) {
457 a_scale_thread_copy.Run(a_scale_grid_desc,
458 a_scale_grid_buf,
459 a_scale_thread_desc,
460 make_tuple(m0, I0),
461 a_scale_thread_bufs(I0));
462 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
463 a_scale_thread_copy_step.At(Number<0>{}));
464 });
465
466 if constexpr(NumKBlockPerScale == 1)
467 {
468 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
469 a_scale_thread_copy_step.At(Number<2>{}));
470 }
471 else
472 {
473 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
474 a_scale_thread_copy_step.At(Number<1>{}));
475 }
476
477 b_scale_thread_copy.Run(b_scale_grid_desc,
478 b_scale_grid_buf,
479 b_scale_thread_desc,
480 make_tuple(I0, I0),
481 b_scale_thread_bufs(I0));
482
483 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
484
485 static_for<0, MRepeat, 1>{}([&](auto m0) {
486 c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0];
487 });
488
489 // Local prefill A1
490 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
491
492 // Global prefetch A2, AScale2 BScale2
493 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
494 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
495
496#if 1
497 static_for<0, MRepeat, 1>{}([&](auto m0) {
498 a_scale_thread_copy.Run(a_scale_grid_desc,
499 a_scale_grid_buf,
500 a_scale_thread_desc,
501 make_tuple(m0, I0),
502 a_scale_thread_bufs(I0));
503 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
504 a_scale_thread_copy_step.At(Number<0>{}));
505 });
506
507 if constexpr(NumKBlockPerScale == 1)
508 {
509 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
510 a_scale_thread_copy_step.At(Number<2>{}));
511 }
512 else
513 {
514 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
515 a_scale_thread_copy_step.At(Number<1>{}));
516 }
517
518 b_scale_thread_copy.Run(b_scale_grid_desc,
519 b_scale_grid_buf,
520 b_scale_thread_desc,
521 make_tuple(I0, I0),
522 b_scale_thread_bufs(I0));
523
524 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
525#endif
526 // Initialize C
527 c_thread_buf.Clear();
528
529 // Double register buffer for non-scaled gemm computation
530 // 1. Reduce register pressure
531 // 2. Decouple the dependency between mfma instruction and scale-fma instruction following.
533 AccDataType,
534 2,
535 xdlops_gemm.GetRegSizePerXdlops(),
536 true>
537 c_thread_buf_per_scale;
538
539 // Local prefetch A1
542 static_for<0, KRepeat, 1>{}([&](auto k0) {
543 static_for<0, KGroup, 1>{}([&](auto kg0) {
544 a_thread_copy_.Run(
547 a_block_buf.At(I0),
550 a_thread_buf);
551 });
552 });
553 });
554
555#if 1
556 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
557 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
558 .template AsType<AccDataType>()(Number<t>{}) = 0;
559 });
560
561 // Fill first mfma buffer
562 static_for<0, KRepeat, 1>{}([&](auto k0) {
565
566 static_for<0, KPack, 1>{}([&](auto ik) {
567 a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_buf
568 [Number<a_thread_desc_.CalculateOffset(make_tuple(I0, I0, I0, k0, I0, ik))>{}];
569 b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_bufs
570 [I0][Number<b_thread_desc_.CalculateOffset(make_tuple(I0, I0, k0, ik))>{}];
571 });
572
573 using mfma_input_type =
574 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
575
576 xdlops_gemm.template Run<>(a_thread_vec.template AsType<mfma_input_type>(),
577 b_thread_vec.template AsType<mfma_input_type>(),
578 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
579 });
580#endif
581 __builtin_amdgcn_sched_barrier(0);
582
583 // main body
584 if constexpr(HasMainLoop)
585 {
586 index_t i = 0;
587 do
588 {
589 auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
590 b_blockwise_copy.Run(b_grid_desc,
591 b_grid_buf,
593 b_block_origin_idx,
594 b_thread_bufs(local_read_buf));
595 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
596
597 b_scale_thread_copy.Run(b_scale_grid_desc,
598 b_scale_grid_buf,
599 b_scale_thread_desc,
600 make_tuple(I0, I0),
601 b_scale_thread_bufs(local_read_buf));
602
603 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
604 b_scale_thread_copy_step);
605
606 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
607 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
608 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
609
610 static_for<0, MRepeat, 1>{}([&](auto m0) {
611 a_scale_thread_copy.Run(a_scale_grid_desc,
612 a_scale_grid_buf,
613 a_scale_thread_desc,
614 make_tuple(m0, I0),
615 a_scale_thread_bufs(local_read_buf));
616 a_scale_thread_copy.MoveSrcSliceWindow(
617 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
618 });
619
620 if constexpr(NumKBlockPerScale == 1)
621 {
622 a_scale_thread_copy.MoveSrcSliceWindow(
623 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
624 }
625 else
626 {
627 a_scale_thread_copy.MoveSrcSliceWindow(
628 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
629 }
630
631 static_for<0, MRepeat, 1>{}([&](auto m0) {
632 vector_type<AccDataType, 2> c_scale_thread_vec;
633 c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
634 c_scale_thread_buf[m0];
635 c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
636 c_scale_thread_buf[m0];
637
638 static_for<0, NRepeat, 1>{}([&](auto n0) {
639 constexpr auto mfma_buf_offset =
640 ((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops();
641 constexpr auto scale_buf_offset =
642 ((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops();
643
644 constexpr auto a_local_buf_offset =
645 ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat;
646 constexpr auto b_local_buf_offset =
647 ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat;
648 constexpr auto b_local_buf_id =
649 Number<mfma_reg_buf ^
650 ((m0 * NRepeat + n0 + 1) / (MRepeat * NRepeat))>{};
651
652 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
653 c_thread_buf_per_scale
654 .GetVectorTypeReference(Number<mfma_buf_offset>{})
655 .template AsType<AccDataType>()(Number<t>{}) = 0;
656 });
657
658 static_for<0, KRepeat, 1>{}([&](auto k0) {
661
662 static_for<0, KPack, 1>{}([&](auto ik) {
663 a_thread_vec.template AsType<ComputeDataType>()(ik) =
664 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
665 make_tuple((a_local_buf_offset +
666 HotloopLocalBufSwitch * mfma_reg_buf) %
667 2,
668 I0,
669 I0,
670 k0,
671 I0,
672 ik))>{}];
673 b_thread_vec.template AsType<ComputeDataType>()(ik) =
674 b_thread_bufs
675 [b_local_buf_id][Number<b_thread_desc_.CalculateOffset(
676 make_tuple(b_local_buf_offset, I0, k0, ik))>{}];
677 });
678
679 using mfma_input_type =
680 typename vector_type<ComputeDataType,
681 xdlops_gemm.K1PerXdlops>::type;
682
683 xdlops_gemm.template Run<>(
684 a_thread_vec.template AsType<mfma_input_type>(),
685 b_thread_vec.template AsType<mfma_input_type>(),
686 c_thread_buf_per_scale.GetVectorTypeReference(
688 });
689
690 constexpr index_t c_offset =
691 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
692
693 static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
694 using pk_fma_type = typename vector_type<AccDataType, 2>::type;
695
696 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
697 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
698 c_thread_buf_per_scale
699 .GetVectorTypeReference(Number<scale_buf_offset>{})
700 .template AsType<pk_fma_type>()[t],
701 c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
702 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
703 .template AsType<pk_fma_type>()[t]);
704 });
705 });
706
707 // We have to 1 stage early sync the lds for workaround the compiler
708 // limitation
709 if constexpr(m0.value == (MRepeat - LocalPrefetchStages - 1))
710 {
712 }
713
714 constexpr auto lds_buf = m0.value >= (MRepeat - LocalPrefetchStages)
715 ? local_read_buf
716 : mfma_reg_buf;
717
718 static_for<0, KRepeat, 1>{}([&](auto k0) {
719 static_for<0, KGroup, 1>{}([&](auto kg0) {
720 a_thread_copy_.Run(
722 make_tuple(Number<(m0 + 2) % MRepeat>{},
723 I0,
724 I0,
726 I0,
727 I0),
728 a_block_buf.At(Number<lds_buf>{}),
731 HotloopLocalBufSwitch * mfma_reg_buf) %
732 2>{},
733 I0,
734 I0,
735 k0,
736 I0,
737 Number<kg0 * KPack / KGroup>{}),
738 a_thread_buf);
739 });
740 });
741 });
742
743 static_for<0, MRepeat, 1>{}([&](auto m0) {
744 c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] *
745 b_scale_thread_bufs[mfma_reg_buf][I0];
746 });
747
748 // We need new compiler to enable this feature
749 // HotLoopScheduler();
750 // __builtin_amdgcn_sched_barrier(0);
751 };
752
753 LoopFunc(I0, I1);
754 LoopFunc(I1, I0);
755
756 i += 2;
757 } while(i < (num_loop - 2));
758 }
759
760 // tail
761 if constexpr(TailNum == TailNumber::Even)
762 {
763 b_blockwise_copy.Run(b_grid_desc,
764 b_grid_buf,
766 b_block_origin_idx,
767 b_thread_bufs(I1));
768 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
769
770 static_for<0, MRepeat, 1>{}([&](auto m0) {
771 vector_type<AccDataType, 2> c_scale_thread_vec;
772 c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
773 c_scale_thread_buf[m0];
774 c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
775 c_scale_thread_buf[m0];
776
777 static_for<0, NRepeat, 1>{}([&](auto n0) {
778 constexpr auto mfma_buf_offset =
779 ((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops();
780 constexpr auto scale_buf_offset =
781 ((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops();
782
783 constexpr auto a_local_buf_offset =
784 ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat;
785 constexpr auto b_local_buf_offset =
786 ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat;
787
788 constexpr auto b_local_buf_id =
789 Number<0 ^ ((m0 * NRepeat + n0 + 1) / (MRepeat * NRepeat))>{};
790
791 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
792 c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{})
793 .template AsType<AccDataType>()(Number<t>{}) = 0;
794 });
795 static_for<0, KRepeat, 1>{}([&](auto k0) {
798
799 static_for<0, KPack, 1>{}([&](auto ik) {
800 a_thread_vec.template AsType<ComputeDataType>()(ik) =
801 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
802 make_tuple(a_local_buf_offset % 2, I0, I0, k0, I0, ik))>{}];
803 b_thread_vec.template AsType<ComputeDataType>()(ik) =
804 b_thread_bufs[b_local_buf_id][Number<b_thread_desc_.CalculateOffset(
805 make_tuple(b_local_buf_offset, I0, k0, ik))>{}];
806 });
807
808 using mfma_input_type =
809 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
810
811 xdlops_gemm.template Run<>(a_thread_vec.template AsType<mfma_input_type>(),
812 b_thread_vec.template AsType<mfma_input_type>(),
813 c_thread_buf_per_scale.GetVectorTypeReference(
815 });
816
817 constexpr index_t c_offset =
818 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
819
820 static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
821 using pk_fma_type = typename vector_type<AccDataType, 2>::type;
822
823 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
824 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
825 c_thread_buf_per_scale
826 .GetVectorTypeReference(Number<scale_buf_offset>{})
827 .template AsType<pk_fma_type>()[t],
828 c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
829 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
830 .template AsType<pk_fma_type>()[t]);
831 });
832 });
833
834 if constexpr(m0.value == (MRepeat - LocalPrefetchStages))
835 {
837 }
838
839 constexpr auto lds_buf = m0.value >= (MRepeat - LocalPrefetchStages) ? I1 : I0;
840
841 static_for<0, KRepeat, 1>{}([&](auto k0) {
842 static_for<0, KGroup, 1>{}([&](auto kg0) {
843 a_thread_copy_.Run(
845 make_tuple(Number<(m0 + LocalPrefetchStages) % MRepeat>{},
846 I0,
847 I0,
849 I0,
850 I0),
851 a_block_buf.At(Number<lds_buf>{}),
854 I0,
855 I0,
856 k0,
857 I0,
858 Number<kg0 * KPack / KGroup>{}),
859 a_thread_buf);
860 });
861 });
862 });
863
864 // HotLoopScheduler();
865
866 static_for<0, MRepeat, 1>{}([&](auto m0) {
867 c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0];
868 });
869
870 static_for<0, MRepeat, 1>{}([&](auto m0) {
871 vector_type<AccDataType, 2> c_scale_thread_vec;
872 c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
873 c_scale_thread_buf[m0];
874 c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
875 c_scale_thread_buf[m0];
876
877 static_for<0, NRepeat, 1>{}([&](auto n0) {
878 constexpr auto mfma_buf_offset =
879 ((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops();
880 constexpr auto scale_buf_offset =
881 ((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops();
882
883 constexpr auto a_local_buf_offset =
884 ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat;
885 constexpr auto b_local_buf_offset =
886 ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat;
887
888 if constexpr(!((m0 == (MRepeat - 1)) && (n0 == (NRepeat - 1))))
889 {
890 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
891 c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{})
892 .template AsType<AccDataType>()(Number<t>{}) = 0;
893 });
894 static_for<0, KRepeat, 1>{}([&](auto k0) {
897
898 static_for<0, KPack, 1>{}([&](auto ik) {
899 a_thread_vec.template AsType<ComputeDataType>()(ik) =
900 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
901 make_tuple((a_local_buf_offset + HotloopLocalBufSwitch) % 2,
902 I0,
903 I0,
904 k0,
905 I0,
906 ik))>{}];
907 b_thread_vec.template AsType<ComputeDataType>()(ik) =
908 b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
909 make_tuple(b_local_buf_offset, I0, k0, ik))>{}];
910 });
911
912 using mfma_input_type =
913 typename vector_type<ComputeDataType,
914 xdlops_gemm.K1PerXdlops>::type;
915
916 xdlops_gemm.template Run<>(
917 a_thread_vec.template AsType<mfma_input_type>(),
918 b_thread_vec.template AsType<mfma_input_type>(),
919 c_thread_buf_per_scale.GetVectorTypeReference(
921 });
922 }
923
924 constexpr index_t c_offset =
925 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
926
927 static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
928 using pk_fma_type = typename vector_type<AccDataType, 2>::type;
929
930 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
931 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
932 c_thread_buf_per_scale
933 .GetVectorTypeReference(Number<scale_buf_offset>{})
934 .template AsType<pk_fma_type>()[t],
935 c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
936 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
937 .template AsType<pk_fma_type>()[t]);
938 });
939 });
940
941 if constexpr(m0.value < (MRepeat - LocalPrefetchStages))
942 {
943 static_for<0, KRepeat, 1>{}([&](auto k0) {
944 static_for<0, KGroup, 1>{}([&](auto kg0) {
945 a_thread_copy_.Run(
948 I0,
949 I0,
951 I0,
952 I0),
953 a_block_buf.At(I1),
957 2>{},
958 I0,
959 I0,
960 k0,
961 I0,
962 Number<kg0 * KPack / KGroup>{}),
963 a_thread_buf);
964 });
965 });
966 }
967 });
968 // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
969 // latency
970 // __builtin_amdgcn_sched_barrier(0);
971 }
972 else
973 {
974 static_for<0, MRepeat, 1>{}([&](auto m0) {
975 vector_type<AccDataType, 2> c_scale_thread_vec;
976 c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
977 c_scale_thread_buf[m0];
978 c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
979 c_scale_thread_buf[m0];
980
981 static_for<0, NRepeat, 1>{}([&](auto n0) {
982 constexpr auto mfma_buf_offset =
983 ((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops();
984 constexpr auto scale_buf_offset =
985 ((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops();
986
987 constexpr auto a_local_buf_offset =
988 ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat;
989 constexpr auto b_local_buf_offset =
990 ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat;
991
992 if constexpr(!((m0 == (MRepeat - 1)) && (n0 == (NRepeat - 1))))
993 {
994 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
995 c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{})
996 .template AsType<AccDataType>()(Number<t>{}) = 0;
997 });
998 static_for<0, KRepeat, 1>{}([&](auto k0) {
1001
1002 static_for<0, KPack, 1>{}([&](auto ik) {
1003 a_thread_vec.template AsType<ComputeDataType>()(ik) =
1004 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1005 make_tuple(a_local_buf_offset % 2, I0, I0, k0, I0, ik))>{}];
1006 b_thread_vec.template AsType<ComputeDataType>()(ik) =
1007 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
1008 make_tuple(b_local_buf_offset, I0, k0, ik))>{}];
1009 });
1010
1011 using mfma_input_type =
1012 typename vector_type<ComputeDataType,
1013 xdlops_gemm.K1PerXdlops>::type;
1014
1015 xdlops_gemm.template Run<>(
1016 a_thread_vec.template AsType<mfma_input_type>(),
1017 b_thread_vec.template AsType<mfma_input_type>(),
1018 c_thread_buf_per_scale.GetVectorTypeReference(
1020 });
1021 }
1022
1023 constexpr index_t c_offset =
1024 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
1025
1026 static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
1027 using pk_fma_type = typename vector_type<AccDataType, 2>::type;
1028
1029 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
1030 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
1031 c_thread_buf_per_scale
1032 .GetVectorTypeReference(Number<scale_buf_offset>{})
1033 .template AsType<pk_fma_type>()[t],
1034 c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
1035 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
1036 .template AsType<pk_fma_type>()[t]);
1037 });
1038 });
1039
1040 if constexpr(m0.value < (MRepeat - LocalPrefetchStages))
1041 {
1042 static_for<0, KRepeat, 1>{}([&](auto k0) {
1043 static_for<0, KGroup, 1>{}([&](auto kg0) {
1044 a_thread_copy_.Run(
1046 make_tuple(
1048 a_block_buf.At(I0),
1051 I0,
1052 I0,
1053 k0,
1054 I0,
1055 Number<kg0 * KPack / KGroup>{}),
1056 a_thread_buf);
1057 });
1058 });
1059 }
1060 });
1061 }
1062 }
1063
1064 protected:
1065 // MRepeat MWave MLane KRepeat KLane KPack
1066 // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack
1067 // Reduce the vgpr usage here.
1070
1072 ComputeDataType,
1074 decltype(a_thread_desc_),
1075 Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
1077 5,
1078 A_K1,
1079 A_K1>;
1080
1082
1085
1086 static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
1087
1089};
1090
1091} // namespace ck
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Vgpr
Definition amd_address_space.hpp:20
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops > HotLoopInstList
Definition blockwise_gemm_pipeline_xdlops_base.hpp:82
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeDataType, decltype(a_block_desc_m0_m1_m2_k0_k1_k2), decltype(a_thread_desc_), Sequence< 1, 1, 1, 1, 1, KPack/KGroup >, Sequence< 0, 1, 2, 3, 4, 5 >, 5, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp:1071
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const CScaleThreadDesc &c_scale_thread_desc, CThreadBuffer &c_thread_buf, const AScaleGridDesc &a_scale_grid_desc, const AScaleThreadDesc &a_scale_thread_desc, AScaleThreadTransfer &a_scale_thread_copy, const AScaleGridBuffer &a_scale_grid_buf, const AScaleThreadTransferStep &a_scale_thread_copy_step, const BScaleGridDesc &b_scale_grid_desc, const BScaleThreadDesc &b_scale_thread_desc, BScaleThreadTransfer &b_scale_thread_copy, const BScaleGridBuffer &b_scale_grid_buf, const BScaleThreadTransferStep &b_scale_thread_copy_step, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp:384
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack, true > Base
Definition blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp:112
Definition blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp:40
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:75
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition functional2.hpp:33
Definition dtype_vector.hpp:10