block_gemm_areg_bsmem_creg_v2.hpp Source File

block_gemm_areg_bsmem_creg_v2.hpp Source File#

Composable Kernel: block_gemm_areg_bsmem_creg_v2.hpp Source File
block_gemm_areg_bsmem_creg_v2.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11// A is a block distributed tensor
12// B is a block window on shared memory
13// C is a block distributed tensor
14template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV2DefaultPolicy>
16{
23
24 static constexpr index_t kBlockSize = Problem::kBlockSize;
25
26 // C += A * B
27 template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
28 CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
29 const ABlockTensorTmp& a_block_tensor_tmp,
30 const BBlockWindowTmp& b_block_window_tmp) const
31 {
32 static_assert(
33 std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
34 std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
35 std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
36 "wrong!");
37
38 constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
39 constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
40 constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
41
42 static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
43 KPerBlock == BlockGemmShape::kK,
44 "wrong!");
45
46 constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
47
48 using WG = remove_cvref_t<decltype(config.template at<0>())>;
49
50 constexpr index_t MWarp = config.template at<1>();
51 constexpr index_t NWarp = config.template at<2>();
52
53 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
54 constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
55 constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
56
57 constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
58 constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
59
60 const index_t iNWarp = get_warp_id() % NWarp;
61
62 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
69
70 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
71 c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
72
73 // construct A-block-tensor from A-block-tensor-tmp
74 // FIXME: need a method to check that a_block_tensor and a_block_tensor_tmp have equivalent
75 // distribution
78
79 a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
80
81 // construct B-warp-window
82 auto b_warp_window_tmp = make_tile_window(
83 b_block_window_tmp.get_bottom_tensor_view(),
85 b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
86 make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
87
88#if 0 // FIXME: using array will cause register spill
89 array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
90 {b_warp_window_tmp}};
91
92 for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
93 {
94 for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
95 {
96 move_tile_window(b_warp_windows(nIter)(kIter),
97 {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
98 }
99 }
100#else
102 statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
103 NIterPerWarp>
104 b_warp_windows;
105
106 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
107 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
108 b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
109
110 move_tile_window(b_warp_windows(nIter)(kIter),
111 {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
112 });
113 });
114#endif
115
116 // check C-block-distribution
117 static_assert(
118 std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
119 remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
120 .get_static_tile_distribution_encoding())>>,
121 "wrong!");
122
123 using AWarpDstr = typename WG::AWarpDstr;
124 using CWarpDstr = typename WG::CWarpDstr;
125
126 using AWarpTensor = typename WG::AWarpTensor;
127 using CWarpTensor = typename WG::CWarpTensor;
128
129 constexpr auto a_warp_y_lengths =
130 to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
131 constexpr auto c_warp_y_lengths =
132 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
133
134 constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
135 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
136
137 // hot loop:
138 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
139 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
140 // read B warp tensor from B Block window
141 const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
142
143 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
144 // read A warp tensor from A block tensor
145 AWarpTensor a_warp_tensor;
146
147 a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
148 merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
149 merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
150
151 // read C warp tensor from C block tensor
152 CWarpTensor c_warp_tensor;
153
154 c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
155 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
156 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
157
158 // warp GEMM
159 WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
160 // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
161
162 // write C warp tensor into C block tensor
163 c_block_tensor.set_y_sliced_thread_data(
164 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
165 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
166 c_warp_tensor.get_thread_buffer());
167 });
168 });
169 });
170 }
171
172 template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
174 {
175 constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
176
177 using WG = remove_cvref_t<decltype(config.template at<0>())>;
178
179 constexpr index_t MWarp = config.template at<1>();
180 constexpr index_t NWarp = config.template at<2>();
181
182 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
183 constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
184
185 constexpr auto a_block_outer_dstr_encoding =
192
193 constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
194 a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
195
196 return make_static_tile_distribution(a_block_dstr_encode);
197 }
198
199 CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
200 {
201 constexpr index_t MPerBlock = BlockGemmShape::kM;
202 constexpr index_t NPerBlock = BlockGemmShape::kN;
203
204 constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
205
206 using WG = remove_cvref_t<decltype(config.template at<0>())>;
207
208 constexpr index_t MWarp = config.template at<1>();
209 constexpr index_t NWarp = config.template at<2>();
210
211 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
212 constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
213 // constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
214
215 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
222
223 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
224 c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
225 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
226 auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
227 return c_block_tensor;
228 }
229
230 // C = A * B: convenience overload that creates the C block tile itself and returns it by value
231 template <typename ABlockTensorTmp, typename BBlockWindowTmp>
232 CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
233 const BBlockWindowTmp& b_block_window_tmp) const
234 {
235 auto c_block_tensor = MakeCBlockTile(); // NOTE(review): no explicit clear of this tile is visible here -- confirm it starts zeroed
236 operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); // delegates to the accumulating (C += A * B) overload
237 return c_block_tensor;
238 }
239};
240
241} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
array< index_t, N > multi_index
Definition tile/core/container/multi_index.hpp:17
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
tuple_array< T, N > statically_indexed_array
Definition tile/core/container/statically_indexed_array.hpp:16
Definition block_gemm_areg_bsmem_creg_v2.hpp:16
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition block_gemm_areg_bsmem_creg_v2.hpp:22
static CK_TILE_DEVICE constexpr auto MakeABlockTileDistribution()
Definition block_gemm_areg_bsmem_creg_v2.hpp:173
remove_cvref_t< Policy_ > Policy
Definition block_gemm_areg_bsmem_creg_v2.hpp:18
remove_cvref_t< typename Problem::ADataType > ADataType
Definition block_gemm_areg_bsmem_creg_v2.hpp:19
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
Definition block_gemm_areg_bsmem_creg_v2.hpp:199
remove_cvref_t< typename Problem::BDataType > BDataType
Definition block_gemm_areg_bsmem_creg_v2.hpp:20
CK_TILE_DEVICE void operator()(CBlockTensor &c_block_tensor, const ABlockTensorTmp &a_block_tensor_tmp, const BBlockWindowTmp &b_block_window_tmp) const
Definition block_gemm_areg_bsmem_creg_v2.hpp:28
static constexpr index_t kBlockSize
Definition block_gemm_areg_bsmem_creg_v2.hpp:24
remove_cvref_t< typename Problem::CDataType > CDataType
Definition block_gemm_areg_bsmem_creg_v2.hpp:21
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp &a_block_tensor_tmp, const BBlockWindowTmp &b_block_window_tmp) const
Definition block_gemm_areg_bsmem_creg_v2.hpp:232
remove_cvref_t< Problem_ > Problem
Definition block_gemm_areg_bsmem_creg_v2.hpp:17
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
Definition tile/core/utility/functional.hpp:43
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192