device_grouped_gemm_multiple_d_dl.hpp Source File

device_grouped_gemm_multiple_d_dl.hpp Source File

Composable Kernel: device_grouped_gemm_multiple_d_dl.hpp Source File
device_grouped_gemm_multiple_d_dl.hpp
Go to the documentation of this file.
1#pragma once
2// SPDX-License-Identifier: MIT
3// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
4
5#pragma once
6
7#include <iostream>
8#include <sstream>
9
11#include "ck/utility/env.hpp"
21
22namespace ck {
23namespace tensor_operation {
24namespace device {
25
// Grouped-GEMM kernel for the DL path: a single launch serves every group.
// Each workgroup maps its 1-D block id to the group whose block range
// [BlockStart_, BlockEnd_) contains it, then calls GridwiseGemm::Run with that
// group's pointers, descriptors and block-to-E-tile map.
//
// gemm_descs_const   : array of `group_count` GemmDesc entries, passed in the
//                      CK_CONSTANT_ADDRESS_SPACE address space.
// group_count        : number of entries in gemm_descs_const.
// a/b/cde_element_op : elementwise operators forwarded unchanged to GridwiseGemm::Run.
// HasMainKBlockLoop / HasDoubleTailKBlockLoop : compile-time K-loop variant, forwarded
//                      to GridwiseGemm::Run as integral_constant tags.
// On targets not listed in the #if below, the body compiles to a no-op.
26template <typename GridwiseGemm,
27 typename GemmDesc,
28 typename AElementwiseOperation,
29 typename BElementwiseOperation,
30 typename CDEElementwiseOperation,
31 bool HasMainKBlockLoop,
32 bool HasDoubleTailKBlockLoop>
33__global__ void
34#if CK_USE_LAUNCH_BOUNDS
36#endif
37 kernel_grouped_gemm_multiple_d_dl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
38 const index_t group_count,
39 const AElementwiseOperation a_element_op,
40 const BElementwiseOperation b_element_op,
41 const CDEElementwiseOperation cde_element_op)
42{
43#if(defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx103__) || \
44 defined(__gfx11__) || defined(__gfx94__) || defined(__gfx12__))
// Workgroup-shared (__shared__) scratch buffer, sized at compile time by the gridwise GEMM.
45 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
46
47 const index_t block_id = get_block_1d_id();
48
// View the constant-address-space buffer as an array of GemmDesc.
49 const auto gemm_desc_ptr =
50 reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
51
// Binary search for the group with BlockStart_ <= block_id < BlockEnd_.
// This assumes the ranges are ordered by BlockStart_ and contiguous.
// NOTE(review): with these updates `left` never exceeds `right`, so `left <= right`
// is always true. If block_id lies outside every range, the loop does not exit.
// Correctness therefore relies on the launch grid exactly covering the union of the
// group ranges.
52 index_t left = 0;
53 index_t right = group_count;
54 index_t group_id = index_t((left + right) / 2);
55 while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ &&
56 block_id < gemm_desc_ptr[group_id].BlockEnd_)) &&
57 left <= right)
58 {
59 if(block_id < gemm_desc_ptr[group_id].BlockStart_)
60 {
61 right = group_id;
62 }
63 else
64 {
65 left = group_id;
66 }
67 group_id = index_t((left + right) / 2);
68 }
69
// Run the gridwise GEMM for the selected group, using its own tensors, descriptors
// and block-to-E-tile map.
70 GridwiseGemm::Run(gemm_desc_ptr[group_id].a_ptr_,
71 gemm_desc_ptr[group_id].b_ptr_,
72 gemm_desc_ptr[group_id].ds_ptr_,
73 gemm_desc_ptr[group_id].e_ptr_,
74 p_shared,
75 a_element_op,
76 b_element_op,
77 cde_element_op,
78 gemm_desc_ptr[group_id].a_grid_desc_k0_m0_m1_k1_,
79 gemm_desc_ptr[group_id].b_grid_desc_k0_n0_n1_k1_,
80 gemm_desc_ptr[group_id].ds_grid_desc_m0_m10_m11_n0_n10_n11_,
81 gemm_desc_ptr[group_id].e_grid_desc_m0_m10_m11_n0_n10_n11_,
82 gemm_desc_ptr[group_id].block_2_etile_map_,
83 integral_constant<bool, HasMainKBlockLoop>{},
84 integral_constant<bool, HasDoubleTailKBlockLoop>{});
85#else
// Unsupported target: assigning to `ignore` only marks the parameters as used.
86 ignore = gemm_descs_const;
87 ignore = group_count;
88 ignore = a_element_op;
89 ignore = b_element_op;
90 ignore = cde_element_op;
91#endif
92}
93
// Device operator: grouped GEMM with NumDTensor auxiliary D tensors, built on the DL
// gridwise GEMM (GridwiseGemmDlMultipleD_km_kn_mn).
// For every group it multiplies A by B and writes E. The group's D tensors and
// CDEElementwiseOperation are passed to GridwiseGemm::Run alongside (presumably
// E = cde_op(A * B, D0, D1, ...) — confirm in the gridwise GEMM).
// All groups are processed by a single kernel launch; see Argument and Invoker.
94template <typename ALayout,
95 typename BLayout,
96 typename DsLayout,
97 typename ELayout,
98 typename ADataType,
99 typename BDataType,
100 typename AccDataType,
101 typename DsDataType,
102 typename EDataType,
103 typename AElementwiseOperation,
104 typename BElementwiseOperation,
105 typename CDEElementwiseOperation,
106 GemmSpecialization GemmSpec,
107 index_t BlockSize,
108 index_t MPerBlock,
109 index_t NPerBlock,
110 index_t K0PerBlock,
111 index_t K1,
112 index_t M1PerThread,
113 index_t N1PerThread,
114 index_t KPerThread,
115 typename M1N1ThreadClusterM1Xs,
116 typename M1N1ThreadClusterN1Xs,
117 typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
118 typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
119 typename ABlockTransferThreadClusterArrangeOrder,
120 typename ABlockTransferSrcAccessOrder,
121 typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
122 typename ABlockTransferSrcVectorTensorContiguousDimOrder,
123 typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
124 typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
125 typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
126 typename BBlockTransferThreadClusterArrangeOrder,
127 typename BBlockTransferSrcAccessOrder,
128 typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
129 typename BBlockTransferSrcVectorTensorContiguousDimOrder,
130 typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
131 typename CThreadTransferSrcDstAccessOrder,
132 index_t CThreadTransferSrcDstVectorDim,
133 index_t CThreadTransferDstScalarPerVector,
137 bool> = false>
138struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
139 BLayout,
140 DsLayout,
141 ELayout,
142 ADataType,
143 BDataType,
144 DsDataType,
145 EDataType,
146 AElementwiseOperation,
147 BElementwiseOperation,
148 CDEElementwiseOperation>
149{
150 using DeviceOp = DeviceGroupedGemmMultipleD_Dl;
// Number of auxiliary D tensors, taken from the DsDataType tuple.
151 static constexpr index_t NumDTensor = DsDataType::Size();
152
// Compile-time dimension indices used with GetLength() and descriptor transforms.
153 static constexpr auto I0 = Number<0>{};
154 static constexpr auto I1 = Number<1>{};
155 static constexpr auto I2 = Number<2>{};
156 static constexpr auto I3 = Number<3>{};
157 static constexpr auto I4 = Number<4>{};
158 static constexpr auto I5 = Number<5>{};
159
// K1 as a compile-time Number.
160 static constexpr auto K1Number = Number<K1>{};
161
// Builds A's grid descriptor with output dims (K0, M, K1).
// The raw M x K view comes from ALayout and StrideA. Raw dim 1 (K) is mapped to output
// dims 0 and 2, with K0 = K / K1. Raw dim 0 (M) becomes output dim 1.
// With GemmSpecialization::MNPadding, M is right-padded up to a multiple of MPerBlock.
// Precondition (asserted): K is divisible by K1.
162 static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
163 {
164 assert(K % K1 == 0);
165
166 const index_t K0 = K / K1;
167
// Raw M x K view: RowMajor has unit stride along K, ColumnMajor has unit stride along M.
168 const auto a_grid_desc_m_k = [&]() {
169 if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
170 {
171 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
172 }
173 else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
174 {
175 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
176 }
177 }();
178
179 if constexpr(GemmSpec == GemmSpecialization::MNPadding)
180 {
// Extra rows needed to round M up to a multiple of MPerBlock (0 if already aligned).
181 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
182
184 a_grid_desc_m_k,
186 make_right_pad_transform(M, PadM)),
187 make_tuple(Sequence<1>{}, Sequence<0>{}),
188 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
189 }
190 else
191 {
193 a_grid_desc_m_k,
196 make_tuple(Sequence<1>{}, Sequence<0>{}),
197 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
198 }
199 }
200
// Builds B's grid descriptor with output dims (K0, N, K1).
// The raw K x N view comes from BLayout and StrideB. Raw dim 0 (K) is mapped to output
// dims 0 and 2, with K0 = K / K1. Raw dim 1 (N) becomes output dim 1.
// With GemmSpecialization::MNPadding, N is right-padded up to a multiple of NPerBlock.
// Precondition (asserted): K is divisible by K1.
201 static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
202 {
203 assert(K % K1 == 0);
204
205 const index_t K0 = K / K1;
206
// Raw K x N view: RowMajor has unit stride along N, ColumnMajor has unit stride along K.
207 const auto b_grid_desc_k_n = [&]() {
208 if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
209 {
210 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
211 }
212 else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
213 {
214 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
215 }
216 }();
217
218 if constexpr(GemmSpec == GemmSpecialization::MNPadding)
219 {
// Extra columns needed to round N up to a multiple of NPerBlock (0 if already aligned).
220 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
221
223 b_grid_desc_k_n,
225 make_right_pad_transform(N, PadN)),
226 make_tuple(Sequence<0>{}, Sequence<1>{}),
227 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
228 }
229 else
230 {
232 b_grid_desc_k_n,
235 make_tuple(Sequence<0>{}, Sequence<1>{}),
236 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
237 }
238 }
239
// Builds an M x N grid descriptor for a tensor with layout ELay.
// It is used for E, and also for each D tensor (with that D tensor's layout and stride).
// With GemmSpecialization::MNPadding, PadM and PadN round M and N up to multiples of
// MPerBlock and NPerBlock (presumably applied as right-pad transforms). The dims keep
// their order (0 -> 0, 1 -> 1).
240 template <typename ELay>
241 static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
242 {
// Raw M x N view: RowMajor has unit stride along N, ColumnMajor has unit stride along M.
243 const auto c_grid_desc_m_n = [&]() {
244 if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
245 {
246 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1));
247 }
248 else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
249 {
250 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE));
251 }
252 }();
253
254 if constexpr(GemmSpec == GemmSpecialization::MNPadding)
255 {
256 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
257 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
258
260 c_grid_desc_m_n,
262 make_tuple(Sequence<0>{}, Sequence<1>{}),
263 make_tuple(Sequence<0>{}, Sequence<1>{}));
264 }
265 else
266 {
267
269 c_grid_desc_m_n,
271 make_tuple(Sequence<0>{}, Sequence<1>{}),
272 make_tuple(Sequence<0>{}, Sequence<1>{}));
273 }
274 }
275
276 static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
277 const std::array<index_t, NumDTensor>& NRaws,
278 const std::array<index_t, NumDTensor>& DsStride)
279 {
280 return generate_tuple(
281 [&](auto i) {
282 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
283
284 return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
285 },
286 Number<NumDTensor>{});
287 }
288
// Concrete descriptor types, taken via decltype from the builders above.
// The placeholder arguments are never evaluated; only the resulting types matter.
289 using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
290 using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
291 using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {}));
292 using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
293
294 // GridwiseGemm
// E is written with InMemoryDataOperationEnum::Set (presumably a plain store, not an accumulate).
// NOTE(review): only ADataType is passed as the A/B element type; BDataType does not appear
// in this argument list, so this presumably assumes ADataType == BDataType — confirm.
295 using GridwiseGemm =
296 GridwiseGemmDlMultipleD_km_kn_mn<BlockSize,
297 ADataType,
298 AccDataType,
299 DsDataType,
300 EDataType,
301 AElementwiseOperation,
302 BElementwiseOperation,
303 CDEElementwiseOperation,
304 InMemoryDataOperationEnum::Set,
305 AGridDesc_K0_M_K1,
306 BGridDesc_K0_N_K1,
307 EGridDesc_M_N,
308 MPerBlock,
309 NPerBlock,
310 K0PerBlock,
311 K1,
312 M1PerThread,
313 N1PerThread,
314 KPerThread,
315 M1N1ThreadClusterM1Xs,
316 M1N1ThreadClusterN1Xs,
317 ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
318 ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
319 ABlockTransferThreadClusterArrangeOrder,
320 ABlockTransferSrcAccessOrder,
321 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
322 ABlockTransferSrcVectorTensorContiguousDimOrder,
323 ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
324 BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
325 BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
326 BBlockTransferThreadClusterArrangeOrder,
327 BBlockTransferSrcAccessOrder,
328 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
329 BBlockTransferSrcVectorTensorContiguousDimOrder,
330 BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
331 CThreadTransferSrcDstAccessOrder,
332 CThreadTransferSrcDstVectorDim,
333 CThreadTransferDstScalarPerVector>;
334
// Block/thread-wise copy descriptor types, derived by the gridwise GEMM from the
// problem-level descriptor types above.
335 using AGridDesc_K0_M0_M1_K1 =
336 decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
337 using BGridDesc_K0_N0_N1_K1 =
338 decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
339 using DsGridDesc_M0_M10_M11_N0_N10_N11 =
340 decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{}));
341 using EGridDesc_M0_M10_M11_N0_N10_N11 =
342 decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{}));
343
344 struct GroupedGemmBlock2ETileMap
345 {
346 using Block2ETileMap =
347 remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{}))>;
348
349 GroupedGemmBlock2ETileMap()
350 {
351 block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{});
352 BlockStart_ = -1;
353 }
354
355 GroupedGemmBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, ck::index_t BlockStart)
356 {
357 block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(e_grid_desc_m_n);
358 BlockStart_ = BlockStart;
359 }
360
361 template <typename TopIdx>
362 __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
363 {
364 return block_2_etile_map_.CalculateBottomIndex(
365 make_multi_index(idx_top[I0] - BlockStart_));
366 }
367
368 // it's actually E-Tile
369 template <typename CTileIdx, typename CTileDim>
370 __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
371 const CTileDim& c_tile_dim) const
372 {
373 return block_2_etile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
374 }
375
376 __host__ bool CheckValidity(const EGridDesc_M_N& e_grid_desc_m_n) const
377 {
378 return block_2_etile_map_.CheckValidity(e_grid_desc_m_n);
379 }
380
381 Block2ETileMap block_2_etile_map_;
382 ck::index_t BlockStart_;
383 };
384
// Per-group kernel argument.
// Argument builds these by positional aggregate initialization, so member order matters.
// The Invoker copies the array byte-wise (hipMemcpyAsync) into the device workspace.
// kernel_grouped_gemm_multiple_d_dl then reads it through its GemmDesc template parameter,
// which the Invoker instantiates with GemmKernelArg. Presumably this type must therefore
// stay trivially copyable.
385 struct GemmKernelArg
386 {
387 // pointers
388 const ADataType* a_ptr_;
389 const BDataType* b_ptr_;
390 typename GridwiseGemm::DsGridPointer ds_ptr_;
391 EDataType* e_ptr_;
392
393 // tensor descriptors for problem definition
394 AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
395 BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
396 DsGridDesc_M_N ds_grid_desc_m_n_;
397 EGridDesc_M_N e_grid_desc_m_n_;
398
399 // tensor descriptors for block/thread-wise copy
400 AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_;
401 BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_;
402 DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_;
403 EGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_;
404
405 // block-to-e-tile map
406 GroupedGemmBlock2ETileMap block_2_etile_map_;
// Flat workgroup-id range [BlockStart_, BlockEnd_) owned by this group;
// the kernel binary-searches these ranges.
407 ck::index_t BlockStart_, BlockEnd_;
408 };
409
410 // Argument
411 struct Argument : public BaseArgument
412 {
413 Argument(std::vector<const void*>& p_As,
414 std::vector<const void*>& p_Bs,
415 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
416 std::vector<void*>& p_Es,
417 std::vector<GemmDesc>& gemm_descs,
418 AElementwiseOperation a_element_op,
419 BElementwiseOperation b_element_op,
420 CDEElementwiseOperation cde_element_op)
421 : a_element_op_{a_element_op},
422 b_element_op_{b_element_op},
423 cde_element_op_{cde_element_op},
424 gemm_kernel_host_args_{nullptr}
425 {
426 grid_size_ = 0;
427
428 group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
429
430 if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
431 group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
432 group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
433 {
434 throw std::runtime_error("wrong! group_count_ != p_As/b/c.size");
435 }
436
437 gemm_desc_kernel_arg_.reserve(group_count_);
438
439 skipped_group_count_ = 0;
440
441 for(std::size_t i = 0; i < gemm_descs.size(); i++)
442 {
443 const index_t M = gemm_descs[i].M_;
444 const index_t N = gemm_descs[i].N_;
445 const index_t K = gemm_descs[i].K_;
446
447 a_mtx_mraw_kraw_.emplace_back(M, K);
448 b_mtx_nraw_kraw_.emplace_back(N, K);
449
450 if(M == 0)
451 {
452 skipped_group_count_++;
453 continue;
454 }
455
456 const index_t StrideA = gemm_descs[i].stride_A_;
457 const index_t StrideB = gemm_descs[i].stride_B_;
458 const index_t StrideE = gemm_descs[i].stride_C_;
459
460 typename GridwiseGemm::DsGridPointer p_ds_grid{};
461 DsGridDesc_M_N ds_grid_desc_m_n;
462
463 static_for<0, NumDTensor, 1>{}([&](auto j) {
464 using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
465 using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
466
467 p_ds_grid(j) = static_cast<const DDataType*>(p_Ds[i][j]);
468 ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
469 M, N, gemm_descs[i].stride_Ds_[j]);
470 });
471
472 // tensor descriptors for problem definiton
473 const auto a_grid_desc_k0_m_k1 =
474 DeviceOp::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
475 const auto b_grid_desc_k0_n_k1 =
476 DeviceOp::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
477 const auto e_grid_desc_m_n =
478 DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
479
480 if(GridwiseGemm::CheckValidity(
481 a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, e_grid_desc_m_n))
482 {
483
484 const index_t grid_size_grp =
485 GroupedGemmBlock2ETileMap(e_grid_desc_m_n, 0)
486 .block_2_etile_map_.CalculateGridSize(e_grid_desc_m_n);
487
488 const index_t BlockStart = grid_size_;
489 const index_t BlockEnd = grid_size_ + grid_size_grp;
490
491 grid_size_ += grid_size_grp;
492
493 // block-to-e-tile map
494 const auto block_2_etile_map =
495 GroupedGemmBlock2ETileMap(e_grid_desc_m_n, BlockStart);
496
497 // tensor descriptors for block/thread-wise copy
498 const auto a_grid_desc_k0_m0_m1_k1 =
499 GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1);
500 const auto b_grid_desc_k0_n0_n1_k1 =
501 GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1);
502 const auto ds_grid_desc_m0_m10_m11_n0_n10_n11 =
503 GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(ds_grid_desc_m_n);
504 const auto e_grid_desc_m0_m10_m11_n0_n10_n11 =
505 GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(e_grid_desc_m_n);
506
507 gemm_desc_kernel_arg_.push_back(
508 GemmKernelArg{static_cast<const ADataType*>(p_As[i]),
509 static_cast<const BDataType*>(p_Bs[i]),
510 p_ds_grid,
511 static_cast<EDataType*>(p_Es[i]),
512 a_grid_desc_k0_m_k1,
513 b_grid_desc_k0_n_k1,
514 ds_grid_desc_m_n,
515 e_grid_desc_m_n,
516 a_grid_desc_k0_m0_m1_k1,
517 b_grid_desc_k0_n0_n1_k1,
518 ds_grid_desc_m0_m10_m11_n0_n10_n11,
519 e_grid_desc_m0_m10_m11_n0_n10_n11,
520 block_2_etile_map,
521 BlockStart,
522 BlockEnd});
523 }
524 }
525 }
526
527 // private:
528 index_t group_count_;
529 index_t skipped_group_count_;
530
531 // TODO: A,B element op is unused since gridwise_gemm_dl_v1r3 does NOT support prologue
532 // for the time being.
533 AElementwiseOperation a_element_op_;
534 BElementwiseOperation b_element_op_;
535 CDEElementwiseOperation cde_element_op_;
536
537 std::vector<GemmKernelArg> gemm_desc_kernel_arg_;
538 std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
539 std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
540
541 index_t grid_size_;
542 void* gemm_kernel_host_args_;
543 };
544
545 // Invoker
// Launches the grouped kernel for an Argument.
// All launched groups must agree on the K-loop variant (main loop / double tail). That
// common variant selects which kernel instantiation is launched.
546 struct Invoker : public BaseInvoker
547 {
548 using Argument = DeviceOp::Argument;
549
// Validates every stored group, uploads the GemmKernelArg array to arg.p_workspace_,
// and launches one kernel with arg.grid_size_ workgroups of BlockSize threads.
// cpy_stream / cpy_event are optional. When both are given, the kernel arguments are
// copied from the caller-registered host buffer (gemm_kernel_host_args_) on cpy_stream,
// and the host blocks on cpy_event before launching.
// Returns the float produced by the kernel-launch helper (presumably elapsed time when
// stream_config enables timing).
// Throws std::runtime_error for an invalid group, for groups with differing K-loop
// variants, or when a copy stream/event is given without a host buffer.
550 float Run(const Argument& arg,
551 const StreamConfig& stream_config = StreamConfig{},
552 hipStream_t cpy_stream = nullptr,
553 hipEvent_t cpy_event = nullptr)
554 {
// The first stored group defines the K-loop variant that every other group must match.
// NOTE(review): element 0 is read unconditionally. If every group was skipped (M == 0),
// gemm_desc_kernel_arg_ is empty and this read is out of bounds — TODO guard.
555 auto K0 = arg.gemm_desc_kernel_arg_[0].a_grid_desc_k0_m_k1_.GetLength(I0);
556 bool all_has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
557 bool all_has_double_tail_k_block_loop =
558 GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
559
560 for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
561 {
// Optional diagnostics, enabled through the CK_LOGGING environment variable.
562 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
563 {
564 std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
565 << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0)
566 << ", "
567 << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I1)
568 << ", "
569 << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2)
570 << "}" << std::endl;
571
572 std::cout << ", arg.b_grid_desc_k0_n_k1_{"
573 << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I0)
574 << ", "
575 << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1)
576 << ", "
577 << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2)
578 << "}" << std::endl;
579
580 std::cout << ", arg.e_grid_desc_m_n_{ "
581 << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
582 << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
583 << std::endl;
584 }
585
// Re-check each stored group (the Argument constructor already dropped groups failing this).
586 if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
587 arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
588 arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_))
589 {
590 throw std::runtime_error(
591 "wrong! GridwiseGemmDlMultipleD_km_kn_mn has invalid setting");
592 }
593
// Compare this group's K-loop variant with the first group's; a single kernel
// instantiation must serve all groups.
594 K0 = arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
595 bool not_all_has_main_k_block_loop_same =
596 all_has_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(K0);
597 bool not_all_has_double_tail_k_block_loop_same =
598 all_has_double_tail_k_block_loop xor
599 GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
600
601 if(not_all_has_main_k_block_loop_same or not_all_has_double_tail_k_block_loop_same)
602 {
603 std::ostringstream err;
604 err << "Not all gemms have same value for [main|double_tail]_k_block_loop! in "
605 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
606 throw std::runtime_error(err.str());
607 }
608 }
609
610 // If the user provides copy stream and copy event, we assume that they're also
611 // responsible for providing allocated host memory (eg. pinned) which
612 // would be used to copy kernel arguments to the device.
613 if(cpy_stream && cpy_event)
614 {
615 if(arg.gemm_kernel_host_args_ == nullptr)
616 {
617 std::ostringstream err;
618 err << "No memory has been allocated for gemm kernel host args "
619 << "when providing the copy stream and copy event! In " << __FILE__ << ":"
620 << __LINE__ << ", in function: " << __func__;
621 throw std::runtime_error(err.str());
622 }
// NOTE(review): each HIP call's hipError_t is passed to hipGetErrorString and the
// resulting string is discarded, so failures of these calls are not reported.
// NOTE(review): this copies group_count_ entries, but only gemm_desc_kernel_arg_.size()
// entries (skipped groups excluded) are meaningful. The kernel is given that smaller
// count, so the trailing bytes are unused.
623 hipGetErrorString(hipMemcpyAsync(arg.p_workspace_,
624 arg.gemm_kernel_host_args_,
625 arg.group_count_ * sizeof(GemmKernelArg),
626 hipMemcpyHostToDevice,
627 cpy_stream));
628 hipGetErrorString(hipEventRecord(cpy_event, cpy_stream));
629 hipGetErrorString(hipEventSynchronize(cpy_event));
630 }
631 else // In this case CK owns memory allocated on host.
632 {
// Asynchronous copy on stream_config.stream_id_. The launch below uses the same
// stream_config, so stream ordering presumably makes the data visible to the kernel.
// NOTE(review): the returned hipError_t is discarded here as well.
633 hipGetErrorString(
634 hipMemcpyAsync(arg.p_workspace_,
635 arg.gemm_desc_kernel_arg_.data(),
636 arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg),
637 hipMemcpyHostToDevice,
638 stream_config.stream_id_));
639 }
640
// Instantiates the kernel for the given compile-time K-loop flags and launches it with
// arg.grid_size_ workgroups, BlockSize threads and 0 bytes of dynamic LDS. The group
// count passed to the kernel is gemm_desc_kernel_arg_.size().
641 auto launch_kernel = [&](auto has_main_k_block_loop,
642 auto has_double_tail_k_block_loop) {
643 constexpr bool has_main_loop = has_main_k_block_loop.value;
644 constexpr bool has_double_loop = has_double_tail_k_block_loop.value;
645
646 const auto kernel = kernel_grouped_gemm_multiple_d_dl<GridwiseGemm,
647 GemmKernelArg,
648 AElementwiseOperation,
649 BElementwiseOperation,
650 CDEElementwiseOperation,
651 has_main_loop,
652 has_double_loop>;
653
655 stream_config,
656 kernel,
657 dim3(arg.grid_size_),
658 dim3(BlockSize),
659 0,
661 arg.gemm_desc_kernel_arg_.size(),
662 arg.a_element_op_,
663 arg.b_element_op_,
664 arg.cde_element_op_);
665 };
666
// Map the runtime flags onto one of the four compile-time kernel instantiations.
667 if(all_has_main_k_block_loop && all_has_double_tail_k_block_loop)
668 {
669 return launch_kernel(integral_constant<bool, true>{},
670 integral_constant<bool, true>{});
671 }
672 else if(all_has_main_k_block_loop && !all_has_double_tail_k_block_loop)
673 {
674 return launch_kernel(integral_constant<bool, true>{},
675 integral_constant<bool, false>{});
676 }
677 else if(!all_has_main_k_block_loop && all_has_double_tail_k_block_loop)
678 {
679 return launch_kernel(integral_constant<bool, false>{},
680 integral_constant<bool, true>{});
681 }
682 else
683 {
684 return launch_kernel(integral_constant<bool, false>{},
685 integral_constant<bool, false>{});
686 }
687 }
688
689 // polymorphic
// Type-erased overload. It always uses the default cpy_stream/cpy_event (nullptr), i.e.
// the CK-owned copy path.
// NOTE(review): the dynamic_cast result is dereferenced without a null check.
690 float Run(const BaseArgument* p_arg,
691 const StreamConfig& stream_config = StreamConfig{}) override
692 {
693 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
694 }
695 };
696
// Returns true only when all three conditions hold:
//   - every non-empty group produced a kernel argument in the Argument constructor;
//   - the current device is one this DL implementation is enabled for;
//   - every stored group still passes GridwiseGemm::CheckValidity.
// NOTE(review): if every group was skipped (M == 0), this can return true even though
// Invoker::Run then reads gemm_desc_kernel_arg_[0] from an empty vector.
697 static bool IsSupportedArgument(const Argument& arg)
698 {
// Stored plus skipped groups must account for all groups. A shortfall means some
// non-empty group was rejected by CheckValidity in the Argument constructor.
699 if((ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) +
700 arg.skipped_group_count_) != arg.group_count_)
701 {
702 return false;
703 }
704
// Device gate: gfx906, XDL-capable devices, and further architectures listed in this
// condition (presumably via the is_gfx103/gfx11/gfx12_supported helpers).
705 if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
707 {
708 for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
709 {
710 if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
711 arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
712 arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_))
713 {
714 return false;
715 }
716 }
717 return true;
718 }
719 else
720 {
721 return false;
722 }
723 }
724
725 // polymorphic
726 bool IsSupportedArgument(const BaseArgument* p_arg) override
727 {
728 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
729 }
730
731 static auto MakeArgument(std::vector<const void*>& p_As,
732 std::vector<const void*>& p_Bs,
733 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
734 std::vector<void*>& p_Es,
735 std::vector<GemmDesc> gemm_descs,
736 AElementwiseOperation a_element_op,
737 BElementwiseOperation b_element_op,
738 CDEElementwiseOperation cde_element_op)
739 {
740 return Argument{
741 p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, cde_element_op};
742 }
743
744 static auto MakeInvoker() { return Invoker{}; }
745
746 // polymorphic
747 std::unique_ptr<BaseArgument>
748 MakeArgumentPointer(std::vector<const void*>& p_As,
749 std::vector<const void*>& p_Bs,
750 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
751 std::vector<void*>& p_Es,
752 std::vector<GemmDesc>& gemm_descs,
753 AElementwiseOperation a_element_op,
754 BElementwiseOperation b_element_op,
755 CDEElementwiseOperation cde_element_op) override
756 {
757 return std::make_unique<Argument>(
758 p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, cde_element_op);
759 }
760
761 // polymorphic
762 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
763 {
764 return std::make_unique<Invoker>(Invoker{});
765 }
766
767 // polymorphic
768 std::string GetTypeString() const override
769 {
770 auto str = std::stringstream();
771
772 // clang-format off
773 str << "DeviceGroupedGemmMultipleD_Dl"
774 << "<"
775 << BlockSize << ", "
776 << MPerBlock << ", "
777 << NPerBlock << ", "
778 << K0PerBlock << ", "
779 << K1 << ", "
780 << M1PerThread << ", "
781 << N1PerThread << ", "
782 << KPerThread
783 << getGemmSpecializationString(GemmSpec)
784 << ">";
785 // clang-format on
786
787 return str.str();
788 }
789
790 size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
791 {
792 return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmKernelArg);
793 }
794
795 size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
796 {
797 return GetWorkSpaceSize(p_arg);
798 }
799
800 size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); }
801
802 void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
803 {
804 return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
805 }
806
807 //----------------------------------------------------------------------------------------------
815 void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const
816 {
817 Argument* pArg_ = dynamic_cast<Argument*>(p_arg);
818 if(!pArg_)
819 {
820 throw std::runtime_error("Failed to cast argument pointer!");
821 }
822
823 pArg_->gemm_kernel_host_args_ = p_host_kernel_args;
824 std::copy(pArg_->gemm_desc_kernel_arg_.begin(),
825 pArg_->gemm_desc_kernel_arg_.end(),
826 static_cast<GemmKernelArg*>(pArg_->gemm_kernel_host_args_));
827 }
// End of struct DeviceGroupedGemmMultipleD_Dl.
828};
829
830} // namespace device
831} // namespace tensor_operation
832} // namespace ck
#define CK_CONSTANT_ADDRESS_SPACE
Definition ck.hpp:23
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition amd_address_space.hpp:35
bool is_xdl_supported()
Definition host_utility/device_prop.hpp:68
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
bool is_gfx103_supported()
Definition host_utility/device_prop.hpp:120
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition amd_address_space.hpp:24
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
Definition device_grouped_gemm.hpp:99
Definition device_grouped_gemm.hpp:80
#define CK_ENV(name)
Definition utility/env.hpp:129