device_gemm_xdl_layernorm_cshuffle.hpp Source File#
device_gemm_xdl_layernorm_cshuffle.hpp
Go to the documentation of this file.
25// together, given the condition that the GEMM extent N (of M, N, K) is spanned by a single workgroup. For example,
28// Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first. Because non-c-shuffle
32// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta)
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
__global__ void kernel_gemm_layernorm_xdl_cshuffle_v1(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const FloatC0 *__restrict__ p_c0_bias_grid, const FloatC0 *__restrict__ p_c0_add_grid, const FloatC0 *__restrict__ p_c0_gamma_grid, const FloatC0 *__restrict__ p_c0_beta_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const AccElementwiseOperation acc_element_op, const CElementwiseOperation c_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp:41
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp:160
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp:268
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))> DefaultBlock2CTileMap
Definition gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp:373
Definition utility/sequence.hpp:43
Definition device_base.hpp:197
BaseArgument()=default
BaseInvoker()=default
BaseOperator()=default
Definition device_gemm_xdl_layernorm_cshuffle.hpp:445
C0GridDesc_N c0_grid_desc_n_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:493
CGridDesc_M_N c_grid_desc_m_n_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:492
BElementwiseOperation b_element_op_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:496
CElementwiseOperation c_element_op_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:498
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:490
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, const C0DataType *p_c0_grid_add, const C0DataType *p_c0_grid_bias, const C0DataType *p_c0_grid_gamma, const C0DataType *p_c0_grid_beta, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_xdl_layernorm_cshuffle.hpp:446
const C0DataType * p_c0_grid_gamma_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:488
AccElementwiseOperation acc_element_op_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:497
const BDataType * p_b_grid_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:484
const C0DataType * p_c0_grid_add_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:487
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:491
const C0DataType * p_c0_grid_beta_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:489
const ADataType * p_a_grid_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:483
AElementwiseOperation a_element_op_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:495
const C0DataType * p_c0_grid_bias_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:486
CDataType * p_c_grid_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:485
Block2CTileMap block_2_ctile_map_
Definition device_gemm_xdl_layernorm_cshuffle.hpp:494
Definition device_gemm_xdl_layernorm_cshuffle.hpp:503
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_layernorm_cshuffle.hpp:507
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl_layernorm_cshuffle.hpp:631
DeviceOp::Argument Argument
Definition device_gemm_xdl_layernorm_cshuffle.hpp:504
Definition device_gemm_xdl_layernorm_cshuffle.hpp:81
std::string GetTypeString() const override
Definition device_gemm_xdl_layernorm_cshuffle.hpp:762
decltype(MakeGridDescriptor_N(1)) C0GridDesc_N
Definition device_gemm_xdl_layernorm_cshuffle.hpp:384
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl_layernorm_cshuffle.hpp:85
static auto MakeInvoker()
Definition device_gemm_xdl_layernorm_cshuffle.hpp:716
std::unique_ptr< BaseInvoker > MakeInvokerPointer()
Definition device_gemm_xdl_layernorm_cshuffle.hpp:756
decltype(MakeCGridDescriptor_M_N(1, 1, 1)) CGridDesc_M_N
Definition device_gemm_xdl_layernorm_cshuffle.hpp:383
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_xdl_layernorm_cshuffle.hpp:438
static auto MakeGridDescriptor_N(index_t NRaw)
Definition device_gemm_xdl_layernorm_cshuffle.hpp:356
static constexpr auto I1
Definition device_gemm_xdl_layernorm_cshuffle.hpp:89
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl_layernorm_cshuffle.hpp:86
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, const C0DataType *p_c0_bias, const C0DataType *p_c0_add, const C0DataType *p_c0_gamma, const C0DataType *p_c0_beta, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_xdl_layernorm_cshuffle.hpp:679
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, const void *p_c0_bias, const void *p_c0_add, const void *p_c0_gamma, const void *p_c0_beta, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, AccElementwiseOperation acc_element_op, CElementwiseOperation c_element_op, index_t=1)
Definition device_gemm_xdl_layernorm_cshuffle.hpp:718
static constexpr auto I0
Definition device_gemm_xdl_layernorm_cshuffle.hpp:88
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl_layernorm_cshuffle.hpp:674
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_gemm_xdl_layernorm_cshuffle.hpp:195
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_xdl_layernorm_cshuffle.hpp:644
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
Definition device_gemm_xdl_layernorm_cshuffle.hpp:298
decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)) BGridDesc_BK0_N_BK1
Definition device_gemm_xdl_layernorm_cshuffle.hpp:382
typename GridwiseGemm64::DefaultBlock2CTileMap Block2CTileMap
Definition device_gemm_xdl_layernorm_cshuffle.hpp:441
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_gemm_xdl_layernorm_cshuffle.hpp:92
GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, C0DataType, ReduceAccDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, NXdlPerWave_, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadCopySrcDstScalarPerVector_NPerBlock, LoopSched > GridwiseGemmBase
Definition device_gemm_xdl_layernorm_cshuffle.hpp:388
static constexpr auto I2
Definition device_gemm_xdl_layernorm_cshuffle.hpp:90
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_xdl_layernorm_cshuffle.hpp:638
DeviceGemmLayerNorm_Xdl_CShuffle DeviceOp
Definition device_gemm_xdl_layernorm_cshuffle.hpp:82
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_xdl_layernorm_cshuffle.hpp:439
decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)) AGridDesc_AK0_M_AK1
Definition device_gemm_xdl_layernorm_cshuffle.hpp:381