blockwise_gemm_pipeline_xdlops_base.hpp Source File

blockwise_gemm_pipeline_xdlops_base.hpp Source File#

Composable Kernel: blockwise_gemm_pipeline_xdlops_base.hpp Source File
blockwise_gemm_pipeline_xdlops_base.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11
12namespace ck {
13
14template <index_t BlockSize,
15 typename ADataType,
16 typename BDataType,
17 typename ComputeDataType,
18 typename AccDataType,
19 typename ATileDesc,
20 typename BTileDesc,
21 typename AMmaTileDesc,
22 typename BMmaTileDesc,
23 index_t ABlockTransferSrcScalarPerVector,
24 index_t BBlockTransferSrcScalarPerVector,
25 index_t MPerBlock,
26 index_t NPerBlock,
27 index_t KPerBlock,
28 index_t MPerXDL,
29 index_t NPerXDL,
30 index_t MRepeat,
31 index_t NRepeat,
32 index_t KPack,
33 bool TransposeC = false>
35{
// Compile-time index constants Number<0>..Number<3>; used throughout this listing to index
// tuples (e.g. wave_idx[I0]) and to query descriptor lengths (e.g. GetLength(I0)).
36 static constexpr auto I0 = Number<0>{};
37 static constexpr auto I1 = Number<1>{};
38 static constexpr auto I2 = Number<2>{};
39 static constexpr auto I3 = Number<3>{};
40
42
43 // Wave layout of the block. MWaves / NWaves are the block tile extents divided by the
// per-wave extents (Repeat * PerXDL). WaveSize is computed as BlockSize / MWaves / NWaves
// rather than read from the HIP-provided "WarpSize", which would return 32 on RDNA GPUs.
// NOTE(review): the previous comment said "Hardcode to 64", but no literal 64 appears here;
// confirm the intended wave size against the callers.
44 static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
45 static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
46 static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
47
// Lengths read from the A/B tile descriptors: dimension 0 is stored as *_K0 and, for A,
// dimension 2 is stored as A_K1.
48 static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
49 static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
50 static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
// B_K1 is read from dimension 3 when BTileDesc has 4 dimensions, otherwise from dimension 2.
51 static constexpr index_t B_K1 =
52 BTileDesc{}.GetLength(Number < BTileDesc{}.GetNumOfDimension() == 4 ? 3 : 2 > {});
53
56
59
// Both MMA K strides are set to KPack.
60 static constexpr index_t AMmaKStride = KPack;
61 static constexpr index_t BMmaKStride = KPack;
62
// KPerThread is KPerBlock divided by xdlops_gemm.K0PerXdlops; KRepeat is the number of
// KPack-sized pieces in KPerThread (integer division); KPerInnerLoop equals KPack.
63 static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
64 static constexpr index_t KRepeat = KPerThread / KPack;
65 static constexpr index_t KPerInnerLoop = KPack;
66
// KGroup: evaluated once at compile time by an immediately-invoked lambda. It is 2 for the
// 16x16 tile with KPerXdlops == 128 and for the 32x32 tile with KPerXdlops == 64, and 1
// otherwise. Both the M and N extents of the XDL tile are checked; the earlier code compared
// MPerXDL twice and never looked at NPerXDL.
// NOTE(review): the `else` below pairs with a branch head that is not visible in this
// listing; confirm its condition against the full source.
67 static constexpr index_t KGroup = []() {
69 // On gfx950, there is an mfma that requires 32 f8 elements as input,
70 // split into 2 groups of 16 f8 elements.
71 // The 2 groups are not contiguous in the B preshuffled layout,
72 // and we do not want them to be contiguous in the B preshuffled layout
73 // because a memory instruction can only read 16 f8 elements at a time.
74 return ((MPerXDL == 16 && NPerXDL == 16 && xdlops_gemm.KPerXdlops == 128) ||
75 (MPerXDL == 32 && NPerXDL == 32 && xdlops_gemm.KPerXdlops == 64))
76 ? 2
77 : 1;
78 else
79 return 1;
80 }();
81
84 MPerBlock,
85 NPerBlock,
86 KPerBlock,
87 ABlockTransferSrcScalarPerVector,
88 BBlockTransferSrcScalarPerVector,
89 A_K1,
90 B_K1,
91 A_K1,
92 B_K1,
93 MRepeat,
94 NRepeat,
95 MPerXDL,
96 NPerXDL,
97 xdlops_gemm.KPerXdlops>;
98
// Checked only when __HIP_DEVICE_COMPILE__ is defined: KPerThread must be an exact multiple
// of KPack, so that KRepeat = KPerThread / KPack has no remainder.
99#if defined(__HIP_DEVICE_COMPILE__)
100 static_assert(KPerThread % KPack == 0,
101 "Wrong KPack setting; try increasing KPerThread or decreasing KPack");
102#endif
103
105 AccDataType,
106 MRepeat * NRepeat,
107 xdlops_gemm.GetRegSizePerXdlops(),
108 true>
110
// Returns a reference to the c_thread_buf_ member.
111 __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
112
// Maps the calling thread's id (ThisThreadBlock::GetThreadId()) through a single-stage tensor
// adaptor and returns the resulting bottom index. Callers in this listing read element I0 as
// the wave id along M and element I1 as the wave id along N.
// NOTE(review): the adaptor's transform arguments are not visible in this listing; confirm
// the exact index layout against the full source.
113 __device__ static auto GetWaveIdx()
114 {
115 const index_t thread_id = ThisThreadBlock::GetThreadId();
116
117 constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
121
122 return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
123 }
124
// Returns the 4-element origin tuple for this thread's A data:
// (0, wave id along M, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]),
// where xdlops_a_idx comes from xdlops_gemm.CalculateAThreadOriginDataIndex().
125 __device__ static auto CalculateAThreadOriginDataIndex()
126 {
127 const auto wave_idx = GetWaveIdx();
128
// Element I0 of the wave index is the M-direction wave id.
129 const auto waveId_m = wave_idx[I0];
130
131 const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
132
// The last component scales the xdlops K index by KPerThread.
133 return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
134 }
135
// 6-element variant of the A origin: returns
// (0, wave id along M, xdlops_a_idx[I1], 0, xdlops_a_idx[I0], 0).
// Unlike the 4-element version, xdlops_a_idx[I0] is not multiplied by KPerThread here.
136 __device__ static auto CalculateAThreadOriginDataIndex6D()
137 {
138 const auto wave_idx = GetWaveIdx();
139
140 const auto waveId_m = wave_idx[I0];
141
142 const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
143
144 return make_tuple(0, waveId_m, xdlops_a_idx[I1], 0, xdlops_a_idx[I0], 0);
145 }
146
// Returns the 4-element origin tuple for this thread's B data:
// (0, wave id along N, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]),
// where xdlops_b_idx comes from xdlops_gemm.CalculateBThreadOriginDataIndex().
147 __device__ static auto CalculateBThreadOriginDataIndex()
148 {
149 const auto wave_idx = GetWaveIdx();
150
// Element I1 of the wave index is the N-direction wave id.
151 const auto waveId_n = wave_idx[I1];
152
153 const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
154
155 return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
156 }
157
// Computes this thread's C origin as a 2-element tuple (c_thread_m, c_thread_n).
// m0 / n0 are combined with the wave ids and with the block origin returned by
// xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); each triple is passed through a
// single-stage adaptor and element I0 of the bottom index is taken.
// NOTE(review): the function-name line and the adaptor transform arguments are not visible
// in this listing; confirm them against the full source.
158 template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
159 __device__ static auto
161 {
162 const auto wave_idx = GetWaveIdx();
163
164 const auto waveId_m = wave_idx[I0];
165 const auto waveId_n = wave_idx[I1];
166
167 const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
168
169 constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
173
174 constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
178
// M coordinate from (m0, wave id along M, blk_idx[I0]).
179 const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
180 make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
// N coordinate from (n0, wave id along N, blk_idx[I1]).
181 const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
182 make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
183
184 return make_tuple(c_thread_m, c_thread_n);
185 }
186
// 8-element variant of the C origin. Returns, without any adaptor mapping,
// (m0, n0, wave id along M, wave id along N, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]),
// where blk_idx comes from xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i).
// NOTE(review): the function-name line is not visible in this listing.
187 template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
188 __device__ static auto
190 {
191 const auto wave_idx = GetWaveIdx();
192
193 const auto waveId_m = wave_idx[I0];
194 const auto waveId_n = wave_idx[I1];
195
196 const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
197
198 return make_tuple(
199 m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
200 }
201
203
// Constructor: initialises a_thread_copy_ and b_thread_copy_ from the given A and B origins.
// The static_asserts in the body are compiled only when __HIP_DEVICE_COMPILE__ is defined.
// NOTE(review): the signature line and the head of the second static_assert are not visible
// in this listing.
221 __host__ __device__
224 : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
225 {
226#if defined(__HIP_DEVICE_COMPILE__)
// Both MMA tile descriptors must be fully known at compile time.
227 static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
228 "wrong! Desc should be known at compile-time");
229
231 "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
232
// The block tile must divide evenly into per-wave tiles in both M and N.
233 static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
234 "wrong!");
235#endif
236 }
237
238 // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
// Uses the lengths tuple (MRepeat, NRepeat, 1, 1, N, M0, M1, M2), where M0, M1, M2 and N
// are the four values from xdlops_gemm.GetCM0M1M2NThreadBlkLengths(). Note that N precedes
// M0..M2 here, the opposite order from the non-transposed variant.
// NOTE(review): the call that consumes this tuple is not visible in this listing.
239 __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
240 {
241 constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
242
243 constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
244 constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
245 constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
246 constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
247
249 make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
250 }
251
252 // XDL output supporting C_xdl = A_xdl * B_xdl
// Uses the lengths tuple (MRepeat, NRepeat, 1, 1, M0, M1, M2, N), where M0, M1, M2 and N
// are the four values from xdlops_gemm.GetCM0M1M2NThreadBlkLengths().
// NOTE(review): the call that consumes this tuple is not visible in this listing.
253 __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
254 {
255 constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
256
257 constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
258 constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
259 constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
260 constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
261
263 make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
264 }
265
// Same lengths as the non-transposed thread descriptor, with one extra leading dimension of
// length 1: (1, MRepeat, NRepeat, 1, 1, M0, M1, M2, N).
// NOTE(review): the call that consumes this tuple is not visible in this listing.
266 __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
267 {
268 constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
269
270 constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
271 constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
272 constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
273 constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
274
276 make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
277 }
278
279 // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
// Builds an intermediate block descriptor (its last length is NPerXDL) and returns the result
// of passing it to xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4.
// NOTE(review): most of the intermediate descriptor's construction is not visible in this
// listing.
280 __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
281 {
282 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
288 Number<NPerXDL>{}));
289
290 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
291 }
292
293 // XDL output supporting C_xdl = A_xdl * B_xdl
// Builds an intermediate block descriptor (its last length is NPerXDL) and returns the result
// of passing it to xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2.
// NOTE(review): most of the intermediate descriptor's construction is not visible in this
// listing.
294 __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
295 {
296 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
302 Number<NPerXDL>{}));
303
304 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
305 }
306
// G-prefixed variant: builds an intermediate block descriptor (its last length is NPerXDL)
// and returns the result of passing it to
// xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2.
// NOTE(review): most of the intermediate descriptor's construction is not visible in this
// listing.
307 __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
308 {
309 constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
316 Number<NPerXDL>{}));
317
318 return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
319 c_block_desc_g_m0_n0_m1_n1_m2_n2);
320 }
321
// Transforms a 2-D (M, N) grid descriptor by unmerging M into
// (M / (MWaves * MPerXDL), MWaves, MPerXDL) and N into (N / (NWaves * NPerXDL), NWaves, NPerXDL),
// then returns xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2 of the result.
// c_grid_desc_m_n: descriptor whose lengths at I0 / I1 are read as M / N.
// NOTE(review): the leading unmerge lengths use integer division; presumably M and N are
// multiples of MWaves * MPerXDL and NWaves * NPerXDL respectively. Confirm with callers.
// NOTE(review): the dimension-mapping arguments of transform_tensor_descriptor are not
// visible in this listing.
322 template <typename CGridDesc_M_N>
323 __host__ __device__ static constexpr auto
324 MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
325 {
326 const auto M = c_grid_desc_m_n.GetLength(I0);
327 const auto N = c_grid_desc_m_n.GetLength(I1);
328
329 const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
330 c_grid_desc_m_n,
331 make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
332 make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
335
336 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
337 }
338
// 3-D (G, M, N) variant: reads G, M, N from dimensions I0, I1, I2 of the input descriptor,
// unmerges M into (M / (MWaves * MPerXDL), MWaves, MPerXDL) and N into
// (N / (NWaves * NPerXDL), NWaves, NPerXDL), then returns
// xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2 of the result.
// NOTE(review): the leading unmerge lengths use integer division; presumably M and N are
// multiples of MWaves * MPerXDL and NWaves * NPerXDL respectively. Confirm with callers.
// NOTE(review): the transform applied to G and the dimension-mapping arguments are not
// visible in this listing.
339 template <typename CGridDesc_G_M_N>
340 __host__ __device__ static constexpr auto
341 MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
342 {
343 const auto G = c_grid_desc_g_m_n.GetLength(I0);
344 const auto M = c_grid_desc_g_m_n.GetLength(I1);
345 const auto N = c_grid_desc_g_m_n.GetLength(I2);
346
347 const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
348 c_grid_desc_g_m_n,
350 make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
351 make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
354
355 return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
356 c_grid_desc_g_m0_n0_m1_n1_m2_n2);
357 }
// Returns the c_thread_desc_ member.
358 __host__ __device__ static constexpr auto GetCThreadDesc() { return c_thread_desc_; }
// Default-constructed instances of the A and B MMA tile descriptor types; their decltype is
// used as the source-descriptor type of the thread-copy aliases further down.
359 static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
360 static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
361
362 protected:
363 // M1, N1 as double buffer index
364 // Read buffer + Compute buffer
365 // A[M0, M1, M2, KPack]
370
371 // B[N0, N1, N2, KPack]
376
377 // C[M, N, NumRegXdlops]
379 make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
380
383 decltype(a_block_desc_m0_m1_m2_k),
384 decltype(a_thread_desc_),
387 3,
388 A_K1,
389 A_K1>;
390
393 decltype(b_block_desc_n0_n1_n2_k),
394 decltype(b_thread_desc_),
397 3,
398 B_K1,
399 B_K1>;
400
403};
404
405} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
f8_fnuz_t f8_t
Definition amd_ck_fp8.hpp:1762
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
typename conditional< predicate, X, Y >::type conditional_t
Definition utility/functional.hpp:115
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmXdlops_pipeline_base.
Definition blockwise_gemm_pipeline_xdlops_base.hpp:222
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:280
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:239
conditional_t< std::is_same< ComputeDataType, ck::tf32_t >::value, float, ComputeDataType > ComputeDataTypeBuf
Definition blockwise_gemm_pipeline_xdlops_base.hpp:57
static __device__ auto CalculateBThreadOriginDataIndex()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:147
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeDataTypeBuf, decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, KPack >, Sequence< 0, 1, 2, 3 >, 3, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_xdlops_base.hpp:381
static __device__ auto GetWaveIdx()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:113
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:266
__host__ static __device__ constexpr auto GetCThreadDesc()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:358
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:294
static __device__ auto CalculateAThreadOriginDataIndex()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:125
static __device__ auto CalculateAThreadOriginDataIndex6D()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:136
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:253
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition blockwise_gemm_pipeline_xdlops_base.hpp:41
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops > HotLoopInstList
Definition blockwise_gemm_pipeline_xdlops_base.hpp:82
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:111
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:160
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:189
ThreadwiseTensorSliceTransfer_v4< BDataType, ComputeDataTypeBuf, decltype(b_block_desc_n0_n1_n2_k), decltype(b_thread_desc_), Sequence< 1, 1, 1, KPack >, Sequence< 0, 1, 2, 3 >, 3, B_K1, B_K1 > BThreadCopy
Definition blockwise_gemm_pipeline_xdlops_base.hpp:391
decltype(CalculateAThreadOriginDataIndex()) Tuple4
Definition blockwise_gemm_pipeline_xdlops_base.hpp:202
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:341
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:307
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:324
Definition blockwise_gemm_pipeline_xdlops.hpp:34
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:75
static __device__ constexpr index_t GetNumOfThread()
Definition thread_group.hpp:15
static __device__ index_t GetThreadId()
Definition thread_group.hpp:19
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition xdlops_gemm.hpp:1821