25template <
typename ALayout,
32 typename CShuffleDataType,
35 typename AElementwiseOperation,
36 typename BElementwiseOperation,
37 typename CDEElementwiseOperation,
49 typename ABlockTransferThreadClusterLengths_K0_M_K1,
50 typename ABlockTransferThreadClusterArrangeOrder,
51 typename ABlockTransferSrcAccessOrder,
55 bool ABlockLdsAddExtraM,
56 typename BBlockTransferThreadClusterLengths_K0_N_K1,
57 typename BBlockTransferThreadClusterArrangeOrder,
58 typename BBlockTransferSrcAccessOrder,
62 bool BBlockLdsAddExtraN,
63 index_t CShuffleMRepeatPerShuffle,
64 index_t CShuffleNRepeatPerShuffle,
65 typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
66 index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
77 AElementwiseOperation,
78 BElementwiseOperation,
79 CDEElementwiseOperation>
94 static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
95 static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
96 static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
116 const auto a_grid_desc_m_k = [&]() {
119 const auto a_grid_desc_mraw_kraw =
122 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
126 const auto a_grid_desc_mraw_kraw =
129 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
133 const auto M = a_grid_desc_m_k.GetLength(
I0);
134 const auto K = a_grid_desc_m_k.GetLength(
I1);
150 constexpr auto A_KRow = 2;
152 const auto A_KWmma = K /
WmmaK;
154 const auto M0 = M / MPerBlock;
170 const auto b_grid_desc_n_k = [&]() {
173 const auto b_grid_desc_nraw_kraw =
176 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
180 const auto b_grid_desc_nraw_kraw =
183 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
187 const auto N = b_grid_desc_n_k.GetLength(
I0);
188 const auto K = b_grid_desc_n_k.GetLength(
I1);
204 constexpr auto B_KRow = 2;
206 const auto B_KWmma = K /
WmmaK;
208 const auto N0 = N / NPerBlock;
222 template <
typename ELayout_>
225 const auto e_grid_desc_mraw_nraw = [&]() {
238 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
242 const std::array<index_t, NumDTensor>& Ns,
243 const std::array<index_t, NumDTensor>& DsStride)
275 AElementwiseOperation,
276 BElementwiseOperation,
277 CDEElementwiseOperation,
290 ABlockTransferThreadClusterLengths_K0_M_K1,
291 ABlockTransferThreadClusterArrangeOrder,
292 ABlockTransferSrcAccessOrder,
293 ABlockTransferSrcVectorDim,
294 ABlockTransferSrcScalarPerVector,
295 ABlockTransferDstScalarPerVector_K1,
299 BBlockTransferThreadClusterLengths_K0_N_K1,
300 BBlockTransferThreadClusterArrangeOrder,
301 BBlockTransferSrcAccessOrder,
302 BBlockTransferSrcVectorDim,
303 BBlockTransferSrcScalarPerVector,
304 BBlockTransferDstScalarPerVector_K1,
308 CShuffleMRepeatPerShuffle,
309 CShuffleNRepeatPerShuffle,
310 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
311 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
320 const void* p_b_grid,
321 std::array<const void*, NumDTensor> p_ds_grid,
328 std::array<index_t, NumDTensor> StrideDs,
332 AElementwiseOperation a_element_op,
333 BElementwiseOperation b_element_op,
334 CDEElementwiseOperation cde_element_op)
335 :
p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
336 p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
338 p_e_grid_{static_cast<EDataType*>(p_e_grid)},
362 p_ds_grid_(i) =
static_cast<const DDataType*
>(p_ds_grid[i]);
435 throw std::runtime_error(
436 "wrong! GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting");
442 const auto K = [&]() {
454 auto launch_kernel = [&](
auto has_main_k_block_loop) {
467 AElementwiseOperation,
468 BElementwiseOperation,
469 CDEElementwiseOperation,
471 has_main_k_block_loop>;
506 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
537 if(arg.
KRaw_ % ABlockTransferSrcScalarPerVector != 0)
545 if(arg.
MRaw_ % ABlockTransferSrcScalarPerVector != 0)
558 if(arg.
KRaw_ % BBlockTransferSrcScalarPerVector != 0)
566 if(arg.
NRaw_ % BBlockTransferSrcScalarPerVector != 0)
578 bool all_valid =
true;
598 if(arg.
NRaw_ % CDEShuffleBlockTransferScalarPerVector_NPerBlock != 0)
624 std::array<const void*, NumDTensor> p_ds,
631 std::array<ck::index_t, NumDTensor> StrideDs,
633 AElementwiseOperation a_element_op,
634 BElementwiseOperation b_element_op,
635 CDEElementwiseOperation cde_element_op)
656 std::unique_ptr<BaseArgument>
659 std::array<const void*, NumDTensor> p_ds,
666 std::array<ck::index_t, NumDTensor> StrideDs,
668 AElementwiseOperation a_element_op,
669 BElementwiseOperation b_element_op,
670 CDEElementwiseOperation cde_element_op)
override
672 return std::make_unique<Argument>(p_a,
695 return std::make_unique<Invoker>(
Invoker{});
701 auto str = std::stringstream();
703 std::map<LoopScheduler, std::string> LoopSchedToString{
706 std::map<PipelineVersion, std::string> PipelineVersionToString{{
PipelineVersion::v1,
"v1"},
710 str <<
"DeviceGemmMultipleD_Wmma_CShuffle"
727 << NumPrefetch <<
", "
729 << LoopSchedToString[LoopSched] <<
", "
730 <<
"PipelineVersion: "
731 << PipelineVersionToString[PipelineVer];
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__global__ void kernel_gemm_mupltipe_d_wmma_cshuffle(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const AGridDesc a_grid_desc, const BGridDesc b_grid_desc, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:225
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:326
ck::GridwiseGemmMultipleD_Wmma< ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AGridDesc, BGridDesc, DsGridDesc_M_N, EGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer >::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock remove_cvref_t< decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))> EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:891
ck::GridwiseGemmMultipleD_Wmma< ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AGridDesc, BGridDesc, DsGridDesc_M_N, EGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer >::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock remove_cvref_t< decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))> DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:888
ck::GridwiseGemmMultipleD_Wmma< ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AGridDesc, BGridDesc, DsGridDesc_M_N, EGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer >::DsGridPointer decltype(MakeDsGridPointer()) DsGridPointer
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:896
ck::GridwiseGemmMultipleD_Wmma< ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AGridDesc, BGridDesc, DsGridDesc_M_N, EGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer >::DefaultBlock2CTileMap remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(EGridDesc_M_N{}, 1, 1))> DefaultBlock2CTileMap
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:894
ck::GridwiseGemmMultipleD_Wmma< ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AGridDesc, BGridDesc, DsGridDesc_M_N, EGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer >::CalculateHasMainKBlockLoop __host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:809
ck::GridwiseGemmMultipleD_Wmma< ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AGridDesc, BGridDesc, DsGridDesc_M_N, EGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer >::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock __host__ static __device__ constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N_ &e_grid_desc_m_n)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:819
ck::GridwiseGemmMultipleD_Wmma< ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AGridDesc, BGridDesc, DsGridDesc_M_N, EGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer >::MakeDefaultBlock2CTileMap __host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const EGridDesc_M_N &e_grid_desc_m_n, index_t, index_t)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:850
ck::GridwiseGemmMultipleD_Wmma< ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AGridDesc, BGridDesc, DsGridDesc_M_N, EGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer >::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock __host__ static __device__ constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N_ &ds_grid_desc_m_n)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:840
ck::GridwiseGemmMultipleD_Wmma< ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AGridDesc, BGridDesc, DsGridDesc_M_N, EGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer >::CheckValidity __host__ static __device__ constexpr bool CheckValidity(const AGridDesc &a_grid_desc, const BGridDesc &b_grid_desc, const EGridDesc_M_N &e_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:608
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition tensor_operation/gpu/device/tensor_layout.hpp:31
Definition tensor_operation/gpu/device/tensor_layout.hpp:26
Definition device_base.hpp:197
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:318
const BDataType * p_b_grid_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:390
GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:405
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:397
GridwiseOp::DsGridPointer p_ds_grid_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:391
index_t KRaw_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:419
const ADataType * p_a_grid_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:389
CDEElementwiseOperation cde_element_op_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:414
AGridDesc a_grid_desc
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:395
AElementwiseOperation a_element_op_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:412
Argument(const void *p_a_grid, const void *p_b_grid, std::array< const void *, NumDTensor > p_ds_grid, void *p_e_grid, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, index_t M01, index_t N01, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:319
EDataType * p_e_grid_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:392
index_t MRaw_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:417
index_t M01_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:408
BGridDesc b_grid_desc
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:396
index_t N01_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:409
BElementwiseOperation b_element_op_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:413
GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:400
EGridDesc_M_N e_grid_desc_m_n_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:398
index_t NRaw_
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:418
GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:402
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:424
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:503
DeviceOp::Argument Argument
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:425
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:427
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:80
static constexpr auto AEnableLds_auto
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:98
decltype(MakeAGridDescriptor(1, 1, 1)) AGridDesc
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:255
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:657
static constexpr auto MWaves
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:94
static constexpr auto I4
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:88
static constexpr auto BEnableLds_auto
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:100
static constexpr auto K1Number
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:92
static constexpr auto BEnableLds
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:108
static constexpr auto AEnableLds
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:107
static constexpr auto AEnableLds_manu
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:104
static constexpr auto NWaves
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:95
decltype(MakeBGridDescriptor(1, 1, 1)) BGridDesc
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:256
static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:114
static constexpr auto I3
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:87
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:622
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:516
static constexpr auto I0
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:84
static auto MakeDsGridDescriptor_M_N(const std::array< index_t, NumDTensor > &Ms, const std::array< index_t, NumDTensor > &Ns, const std::array< index_t, NumDTensor > &DsStride)
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:241
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:617
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N({}, {}, {}))> DsGridDesc_M_N
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:257
static constexpr auto I2
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:86
static constexpr auto BEnableLds_manu
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:105
static constexpr auto I5
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:89
static constexpr auto I6
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:90
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:693
static constexpr auto matrix_padder
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:110
decltype(MakeEGridDescriptor_M_N< ELayout >(1, 1, 1)) EGridDesc_M_N
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:258
static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:168
GridwiseGemmMultipleD_Wmma< ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AGridDesc, BGridDesc, DsGridDesc_M_N, EGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer > GridwiseOp
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:261
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:82
static constexpr auto WmmaK
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:96
std::string GetTypeString() const override
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:699
DeviceGemmMultipleD_Wmma_CShuffle DeviceOp
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:81
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:510
static auto MakeInvoker()
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:690
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:223
static constexpr auto I1
Definition device_gemm_multiple_d_wmma_cshuffle.hpp:85
Definition device_gemm_multiple_d.hpp:36
Definition matrix_padder.hpp:180