gridwise_normalization_bwd_gamma_beta.hpp Source File

gridwise_normalization_bwd_gamma_beta.hpp Source File

Composable Kernel: gridwise_normalization_bwd_gamma_beta.hpp Source File
gridwise_normalization_bwd_gamma_beta.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
10
11namespace ck {
12
13// dgamma = reduce_sum(dy * (x - mean) * inv_std)
14// dbeta = reduce_sum(dy)
15template <typename DYDataType,
16 typename XDataType,
17 typename MeanInvStdDataType,
18 typename ComputeDataType,
19 typename DGammaDataType,
20 typename DBetaDataType,
21 typename GridDesc_M_K,
22 typename GridDesc_M,
23 index_t BlockSize,
24 index_t MThreadClusterSize,
25 index_t KThreadClusterSize,
26 index_t MThreadSliceSize,
27 index_t KThreadSliceSize,
28 index_t DYSrcVectorDim,
29 index_t DYSrcVectorSize,
30 index_t XSrcVectorDim,
31 index_t XSrcVectorSize,
32 index_t MeanInvStdSrcVectorDim,
33 index_t MeanInvStdSrcVectorSize,
34 index_t DGammaDstVectorSize,
35 index_t DBetaDstVectorSize>
37{
38 // if we just check ThreadSliceSize % VectorSize == 0, the performance may be poor (coalesce)
39 static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) ||
40 (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)),
41 "Invalid thread slice sizes and/or dy vector sizes configuration, please check!");
42
43 static_assert(((XSrcVectorDim == 0 && MThreadSliceSize == XSrcVectorSize) ||
44 (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)),
45 "Invalid thread slice sizes and/or x vector sizes configuration, please check!");
46
47 // do not force SliceSize == MeanInvStdSrcVectorSize for groupnorm
48 static_assert(
49 ((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) ||
50 (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0)),
51 "Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!");
52
53 static_assert(MThreadSliceSize == DGammaDstVectorSize && MThreadSliceSize == DBetaDstVectorSize,
54 "Invalid thread slice sizes and/or dx vector sizes configuration, please check!");
55
57
60
63
66
68
69 static constexpr auto thread_cluster_desc =
71
74
77
78 static constexpr auto thread_buffer_desc_m =
80
82
84 BlockSize,
88 true>;
89
90 static constexpr auto I0 = Number<0>{};
91 static constexpr auto I1 = Number<1>{};
92
93 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
94 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
95
96 __device__ static void Run(const GridDesc_M_K& dy_grid_desc_m_k,
97 const GridDesc_M_K& x_grid_desc_m_k,
98 const GridDesc_M_K& mean_grid_desc_m_k,
99 const GridDesc_M_K& inv_std_grid_desc_m_k,
100 const GridDesc_M& dgamma_grid_desc_m,
101 const GridDesc_M& dbeta_grid_desc_m,
102 index_t num_k_block_tile_iteration,
103 const DYDataType* const __restrict__ p_dy_global,
104 const XDataType* const __restrict__ p_x_global,
105 const MeanInvStdDataType* const __restrict__ p_mean_global,
106 const MeanInvStdDataType* const __restrict__ p_inv_std_global,
107 DGammaDataType* const __restrict__ p_dgamma_global,
108 DBetaDataType* const __restrict__ p_dbeta_global)
109 {
110 // LDS
111 __shared__ ComputeDataType p_reduce_work_buffer[BlockSize];
112
113 auto reduce_work_buf =
114 make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
115
116 // Global
117 const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
118 p_dy_global, dy_grid_desc_m_k.GetElementSpaceSize());
119
120 const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
121 p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
122
123 const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
124 p_mean_global, mean_grid_desc_m_k.GetElementSpaceSize());
125
126 const auto inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
127 p_inv_std_global, inv_std_grid_desc_m_k.GetElementSpaceSize());
128
129 auto dgamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
130 p_dgamma_global, dgamma_grid_desc_m.GetElementSpaceSize());
131
132 auto dbeta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
133 p_dbeta_global, dbeta_grid_desc_m.GetElementSpaceSize());
134
135 // VGPR
136 auto dy_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
137 ComputeDataType,
138 MThreadSliceSize * KThreadSliceSize,
139 true>{};
140
141 auto x_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
142 ComputeDataType,
143 MThreadSliceSize * KThreadSliceSize,
144 true>{};
145
146 auto mean_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
147 ComputeDataType,
148 MThreadSliceSize * KThreadSliceSize,
149 true>{};
150
151 auto inv_std_thread_buf = StaticBuffer<AddressSpaceEnum::Vgpr,
152 ComputeDataType,
153 MThreadSliceSize * KThreadSliceSize,
154 true>{};
155
156 auto dgamma_thread_buf =
158
159 auto dbeta_thread_buf =
161
162 const index_t thread_local_id = get_thread_local_1d_id();
163 const index_t block_global_id = get_block_1d_id();
164
165 const auto thread_cluster_idx =
166 thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
167
168 const auto thread_m_cluster_id = thread_cluster_idx[I0];
169 const auto thread_k_cluster_id = thread_cluster_idx[I1];
170
171 // IO
172 auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<DYDataType,
173 ComputeDataType,
174 GridDesc_M_K,
175 decltype(thread_buffer_desc_m_k),
178 DYSrcVectorDim,
179 DYSrcVectorSize,
180 1,
181 true>(
182 dy_grid_desc_m_k,
183 make_multi_index(block_global_id * M_BlockTileSize +
184 thread_m_cluster_id * MThreadSliceSize,
185 thread_k_cluster_id * KThreadSliceSize));
186
187 auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
188 ComputeDataType,
189 GridDesc_M_K,
190 decltype(thread_buffer_desc_m_k),
193 XSrcVectorDim,
194 XSrcVectorSize,
195 1,
196 true>(
197 x_grid_desc_m_k,
198 make_multi_index(block_global_id * M_BlockTileSize +
199 thread_m_cluster_id * MThreadSliceSize,
200 thread_k_cluster_id * KThreadSliceSize));
201
202 auto threadwise_mean_load =
203 ThreadwiseTensorSliceTransfer_v2<MeanInvStdDataType,
204 ComputeDataType,
205 GridDesc_M_K,
206 decltype(thread_buffer_desc_m_k),
209 MeanInvStdSrcVectorDim,
210 MeanInvStdSrcVectorSize,
211 1,
212 true>(
213 mean_grid_desc_m_k,
214 make_multi_index(block_global_id * M_BlockTileSize +
215 thread_m_cluster_id * MThreadSliceSize,
216 thread_k_cluster_id * KThreadSliceSize));
217
218 auto threadwise_inv_std_load =
219 ThreadwiseTensorSliceTransfer_v2<MeanInvStdDataType,
220 ComputeDataType,
221 GridDesc_M_K,
222 decltype(thread_buffer_desc_m_k),
225 MeanInvStdSrcVectorDim,
226 MeanInvStdSrcVectorSize,
227 1,
228 true>(
229 inv_std_grid_desc_m_k,
230 make_multi_index(block_global_id * M_BlockTileSize +
231 thread_m_cluster_id * MThreadSliceSize,
232 thread_k_cluster_id * KThreadSliceSize));
233
234 auto threadwise_dgamma_store =
236 DGammaDataType,
237 decltype(thread_buffer_desc_m),
238 GridDesc_M,
242 0,
243 DGammaDstVectorSize,
245 1,
246 true>(
247 dgamma_grid_desc_m,
248 make_multi_index(block_global_id * M_BlockTileSize +
249 thread_m_cluster_id * MThreadSliceSize),
250 PassThroughOp{});
251
252 auto threadwise_dbeta_store =
254 DBetaDataType,
255 decltype(thread_buffer_desc_m),
256 GridDesc_M,
260 0,
261 DBetaDstVectorSize,
263 1,
264 true>(
265 dbeta_grid_desc_m,
266 make_multi_index(block_global_id * M_BlockTileSize +
267 thread_m_cluster_id * MThreadSliceSize),
268 PassThroughOp{});
269
271 dgamma_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
272 dbeta_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
273 });
274
275 constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
276
277 for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
278 {
279 threadwise_dy_load.Run(dy_grid_desc_m_k,
280 dy_global_val_buf,
282 make_tuple(I0, I0),
283 dy_thread_buf);
284
285 threadwise_x_load.Run(x_grid_desc_m_k,
286 x_global_val_buf,
288 make_tuple(I0, I0),
289 x_thread_buf);
290
291 threadwise_mean_load.Run(mean_grid_desc_m_k,
292 mean_global_val_buf,
294 make_tuple(I0, I0),
295 mean_thread_buf);
296
297 threadwise_inv_std_load.Run(inv_std_grid_desc_m_k,
298 inv_std_global_val_buf,
300 make_tuple(I0, I0),
301 inv_std_thread_buf);
302
303 threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k);
304 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
305 threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, thread_copy_fwd_step_m_k);
306 threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k,
307 thread_copy_fwd_step_m_k);
308
310 constexpr auto offset_m =
311 Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};
312
314 constexpr auto offset_m_k =
315 Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};
316
317 dgamma_thread_buf(offset_m) +=
318 dy_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] *
319 (x_thread_buf[offset_m_k] - mean_thread_buf[offset_m_k]);
320
321 dbeta_thread_buf(offset_m) += dy_thread_buf[offset_m_k];
322 });
323 });
324 }
325
327 if constexpr(I > 0)
329
330 BlockwiseSumReduce::Reduce(reduce_work_buf, dbeta_thread_buf(I));
332 BlockwiseSumReduce::Reduce(reduce_work_buf, dgamma_thread_buf(I));
333 });
334
335 if(thread_k_cluster_id == 0)
336 {
337 threadwise_dgamma_store.Run(thread_buffer_desc_m,
338 make_tuple(I0),
339 dgamma_thread_buf,
340 dgamma_grid_desc_m,
341 dgamma_global_val_buf);
342
343 threadwise_dbeta_store.Run(thread_buffer_desc_m,
344 make_tuple(I0),
345 dbeta_thread_buf,
346 dbeta_grid_desc_m,
347 dbeta_global_val_buf);
348 }
349 }
350};
351
352} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
Definition gridwise_normalization_bwd_gamma_beta.hpp:37
typename conditional< XSrcVectorDim==0, Sequence< 1, 0 >, Sequence< 0, 1 > >::type XThreadBufferDimAccessOrder
Definition gridwise_normalization_bwd_gamma_beta.hpp:61
typename conditional< DYSrcVectorDim==0, Sequence< 1, 0 >, Sequence< 0, 1 > >::type DYThreadBufferDimAccessOrder
Definition gridwise_normalization_bwd_gamma_beta.hpp:58
DYThreadBufferDimAccessOrder ThreadClusterArrangeOrder
Definition gridwise_normalization_bwd_gamma_beta.hpp:67
PartitionedBlockwiseReduction< ComputeDataType, BlockSize, ThreadClusterLengths_M_K, ThreadClusterArrangeOrder, reduce::Add, true > BlockwiseSumReduce
Definition gridwise_normalization_bwd_gamma_beta.hpp:83
Sequence< MThreadSliceSize, KThreadSliceSize > ThreadBufferLengths_M_K
Definition gridwise_normalization_bwd_gamma_beta.hpp:72
static constexpr auto thread_cluster_desc
Definition gridwise_normalization_bwd_gamma_beta.hpp:69
static constexpr auto I0
Definition gridwise_normalization_bwd_gamma_beta.hpp:90
tensor_operation::element_wise::PassThrough PassThroughOp
Definition gridwise_normalization_bwd_gamma_beta.hpp:81
static constexpr index_t M_BlockTileSize
Definition gridwise_normalization_bwd_gamma_beta.hpp:93
static constexpr auto I1
Definition gridwise_normalization_bwd_gamma_beta.hpp:91
Sequence< MThreadClusterSize, KThreadClusterSize > ThreadClusterLengths_M_K
Definition gridwise_normalization_bwd_gamma_beta.hpp:56
typename conditional< MeanInvStdSrcVectorDim==0, Sequence< 1, 0 >, Sequence< 0, 1 > >::type MeanInvStdThreadBufferDimAccessOrder
Definition gridwise_normalization_bwd_gamma_beta.hpp:64
static constexpr index_t K_BlockTileSize
Definition gridwise_normalization_bwd_gamma_beta.hpp:94
static __device__ void Run(const GridDesc_M_K &dy_grid_desc_m_k, const GridDesc_M_K &x_grid_desc_m_k, const GridDesc_M_K &mean_grid_desc_m_k, const GridDesc_M_K &inv_std_grid_desc_m_k, const GridDesc_M &dgamma_grid_desc_m, const GridDesc_M &dbeta_grid_desc_m, index_t num_k_block_tile_iteration, const DYDataType *const __restrict__ p_dy_global, const XDataType *const __restrict__ p_x_global, const MeanInvStdDataType *const __restrict__ p_mean_global, const MeanInvStdDataType *const __restrict__ p_inv_std_global, DGammaDataType *const __restrict__ p_dgamma_global, DBetaDataType *const __restrict__ p_dbeta_global)
Definition gridwise_normalization_bwd_gamma_beta.hpp:96
Sequence< MThreadSliceSize > ThreadBufferLengths_M
Definition gridwise_normalization_bwd_gamma_beta.hpp:73
static constexpr auto thread_buffer_desc_m
Definition gridwise_normalization_bwd_gamma_beta.hpp:78
static constexpr auto thread_buffer_desc_m_k
Definition gridwise_normalization_bwd_gamma_beta.hpp:75
Definition reduction_functions_blockwise.hpp:28
static __device__ void Reduce(BufferType &work_buffer, ComputeDataType &in_out_value)
Definition reduction_functions_blockwise.hpp:44
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition threadwise_tensor_slice_transfer.hpp:39
__device__ void Run(const SrcDesc &, const SrcSliceOriginIdx &, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition threadwise_tensor_slice_transfer.hpp:66
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/functional.hpp:100
Definition reduction_operator.hpp:37
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340