add_rmsnorm2d_rdquant_fwd_kernel.hpp Source File

add_rmsnorm2d_rdquant_fwd_kernel.hpp Source File#

Composable Kernel: add_rmsnorm2d_rdquant_fwd_kernel.hpp Source File
add_rmsnorm2d_rdquant_fwd_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11// host side args
12// X = A + B, Y = Rmsnorm2d(X), QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale)
14{
15 const void* p_a; // [m ,n], input, fp16/bf16
16 const void* p_b; // [m ,n], input, fp16/bf16
17 const void* p_gamma; // [1, n], gamma, prec same as input
18
19 void* p_x; // [m, n], output, p_a + p_b, fp16/bf16
20 void* p_yscale; // [m, 1], output, rowwise quant scale (amax / 127) of reuslt of rmsnorm2d(x)
21 void* p_qy; // [m, n], output, result of quant tensor of rmsnorm2d(x) int8
22
23 float epsilon;
24
27 index_t stride; // row_stride
28};
29
30// TODO: Extract some type to wrapper class
31template <typename Pipeline_>
33{
35 using Problem = typename Pipeline::Problem;
36
44
45 static constexpr bool kSaveX = Problem::kSaveX;
46
47 static constexpr index_t Block_M = Problem::BlockShape::Block_M;
48 static constexpr index_t Block_N = Problem::BlockShape::Block_N;
49 static constexpr bool kPadM = false; // always no need to pad along M
50 static constexpr bool kPadN = Problem::kPadN;
51 static constexpr bool kThreePass = Problem::kThreePass;
52
53 static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
54 static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
55 static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
56 static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
57
58 static constexpr auto I0 = number<0>{};
59 static constexpr auto I1 = number<1>{};
60
61 struct Kargs
62 {
63 const void* p_a;
64 const void* p_b;
65 const void* p_gamma;
66
67 void* p_x;
68 void* p_yscale;
69 void* p_qy;
70
71 float epsilon;
72
75 index_t stride; // row_stride
76 };
78
79 CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
80 {
81 return Kargs{hargs.p_a,
82 hargs.p_b,
83 hargs.p_gamma,
84 hargs.p_x,
85 hargs.p_yscale,
86 hargs.p_qy,
87 hargs.epsilon,
88 hargs.m,
89 hargs.n,
90 hargs.stride};
91 }
92
93 CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
94 {
95 return dim3(integer_divide_ceil(hargs.m, Block_M));
96 }
97
98 CK_TILE_HOST static constexpr auto BlockSize()
99 {
100 return is_wave32() ? Problem::BlockShape::template GetBlockSize<true>()
101 : Problem::BlockShape::template GetBlockSize<false>();
102 }
103
104 // clang-format off
105 template <typename T> struct t2s;
106 template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
107 template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
108 template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
109 template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
110 template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
111 // clang-format on
112
113 // in byte
114 CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
115
116 CK_TILE_HOST static std::string GetName()
117 {
118 // clang-format off
119 using S_ = typename Problem::BlockShape;
120 auto surfix = [&] () {
121 std::string n;
122 if (kPadN) n += "_pn";
123 if (kSaveX) n += "_x";
124 if (kThreePass) n += "_2p";
125 return n; }();
126
127 #define _SS_ std::string
128 #define _TS_ std::to_string
129 return _SS_("add_rmsnorm2d_rdquant_fwd_") + _SS_(t2s<XDataType>::name) + "_" +
130 _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
131 _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
132 _SS_(Pipeline::name) + surfix;
133 #undef _SS_
134 #undef _TS_
135 // clang-format on
136 }
137
139 {
140 const auto iM = get_block_id() * Block_M;
141
142 const auto a_window = [&]() {
144 static_cast<const ADataType*>(kargs.p_a),
145 make_tuple(kargs.m, kargs.n),
146 make_tuple(kargs.stride, 1),
148 number<1>{});
149
150 const auto tmp2_ = pad_tensor_view(
152 return make_tile_window(
153 tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
154 }();
155
156 const auto b_window = [&]() {
158 static_cast<const BDataType*>(kargs.p_b),
159 make_tuple(kargs.m, kargs.n),
160 make_tuple(kargs.stride, 1),
162 number<1>{});
163
164 const auto tmp2_ = pad_tensor_view(
166 return make_tile_window(
167 tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
168 }();
169
170 const auto gamma_window = [&]() {
172 static_cast<const GammaDataType*>(kargs.p_gamma),
173 make_tuple(kargs.n),
174 make_tuple(1),
176 number<1>{});
177
178 const auto tmp2_ =
180
181 return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
182 }();
183
184 auto x_window = [&]() {
185 if constexpr(kSaveX)
186 {
187 const auto tmp2_ = [&]() {
189 static_cast<XDataType*>(kargs.p_x),
190 make_tuple(kargs.m, kargs.n),
191 make_tuple(kargs.stride, 1),
193 number<1>{});
194
195 return pad_tensor_view(tmp_,
198 }();
199 return make_tile_window(
200 tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
201 }
202 else
204 }();
205
206 auto yscale_window = [&]() {
208 static_cast<YScaleDataType*>(kargs.p_yscale),
209 make_tuple(kargs.m),
210 make_tuple(1),
211 number<1>{});
212
214 return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM});
215 }();
216
217 auto qy_window = [&]() {
219 static_cast<QYDataType*>(kargs.p_qy),
220 make_tuple(kargs.m, kargs.n),
221 make_tuple(kargs.stride, 1),
223 number<1>{});
224
225 auto tmp2_ = pad_tensor_view(
227 return make_tile_window(
228 tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
229 }();
230
231 __shared__ char smem[GetSmemSize()];
232
233 Pipeline{}(a_window,
234 b_window,
235 gamma_window,
236 x_window,
237 yscale_window,
238 qy_window,
239 static_cast<const ComputeDataType>(kargs.epsilon),
240 kargs.n,
241 smem);
242 }
243};
244
245} // namespace ck_tile
#define _TS_
#define _SS_
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
bfloat16_t bf16_t
Definition bfloat16.hpp:113
_Float16 fp16_t
Definition half.hpp:110
_BitInt(8) fp8_t
Definition float8.hpp:204
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths &window_lengths)
Definition null_tile_window.hpp:66
unsigned _BitInt(8) bf8_t
Definition float8.hpp:206
CK_TILE_DEVICE index_t get_block_id()
Definition arch.hpp:119
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST bool is_wave32()
Definition arch.hpp:72
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:62
void * p_qy
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:69
index_t n
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:74
const void * p_b
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:64
void * p_yscale
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:68
const void * p_a
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:63
index_t stride
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:75
float epsilon
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:71
index_t m
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:73
const void * p_gamma
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:65
void * p_x
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:67
static constexpr const char * name
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:108
static constexpr const char * name
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:110
static constexpr const char * name
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:107
static constexpr const char * name
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:109
static constexpr const char * name
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:106
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:105
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:14
float epsilon
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:23
const void * p_a
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:15
void * p_x
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:19
const void * p_b
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:16
const void * p_gamma
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:17
index_t stride
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:27
void * p_yscale
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:20
void * p_qy
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:21
index_t m
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:25
index_t n
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:26
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:33
remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:40
static constexpr bool kPadN
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:50
static constexpr bool kThreePass
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:51
static constexpr bool kSaveX
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:45
static constexpr index_t Vector_N
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:54
remove_cvref_t< typename Problem::BDataType > BDataType
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:38
static constexpr index_t Block_M
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:47
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:138
static CK_TILE_HOST std::string GetName()
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:116
static constexpr index_t Block_N
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:48
static constexpr auto I1
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:59
remove_cvref_t< typename Problem::ADataType > ADataType
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:37
static constexpr index_t kBlockSize
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:56
remove_cvref_t< Pipeline_ > Pipeline
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:34
static constexpr bool kPadM
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:49
typename Pipeline::Problem Problem
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:35
static CK_TILE_HOST constexpr Kargs MakeKargs(const Hargs &hargs)
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:79
static constexpr index_t Repeat_N
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:55
static CK_TILE_HOST constexpr auto BlockSize()
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:98
static CK_TILE_HOST constexpr auto GridSize(const Hargs &hargs)
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:93
static constexpr index_t ThreadPerWarp_N
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:53
static constexpr auto I0
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:58
AddRmsnorm2dRdquantFwdHostArgs Hargs
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:77
remove_cvref_t< typename Problem::YScaleDataType > YScaleDataType
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:42
remove_cvref_t< typename Problem::GammaDataType > GammaDataType
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:39
remove_cvref_t< typename Problem::QYDataType > QYDataType
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:43
remove_cvref_t< typename Problem::XDataType > XDataType
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:41
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition add_rmsnorm2d_rdquant_fwd_kernel.hpp:114
Definition tile/core/container/sequence.hpp:49