dynamic_buffer.hpp Source File

dynamic_buffer.hpp Source File#

Composable Kernel: dynamic_buffer.hpp Source File
dynamic_buffer.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck/ck.hpp"
8#include "enable_if.hpp"
10#if __clang_major__ >= 20
12#else
14#endif
17
18namespace ck {
19
20// T may be a scalar or a vector type
21// X may be a scalar or a vector type
22// T and X have the same scalar type
23// X contains one or more T
24template <AddressSpaceEnum BufferAddressSpace,
25 typename T,
26 typename ElementSpaceSize,
27 bool InvalidElementUseNumericalZeroValue,
29 typename IndexType = index_t>
31{
32 using type = T;
33
35 ElementSpaceSize element_space_size_;
37
38 // XXX: PackedSize semantics for pk_i4_t differ from those of the other packed types.
39 // Objects of f4x2_pk_t and f6_pk_t are counted as 1 element, while
40 // objects of pk_i4_t are counted as 2 elements. Therefore, element_space_size_ for pk_i4_t must
41 // be divided by 2 to correctly represent the number of addressable elements.
42 static constexpr index_t PackedSize = []() {
44 return 2;
45 else
46 return 1;
47 }();
48
// Constructs a buffer from a raw data pointer and an element space size.
// The initializer list sets only p_data_ and element_space_size_;
// invalid_element_value_ is not mentioned by this overload.
// NOTE(review): whether invalid_element_value_ has a default member initializer is not
// visible in this listing -- confirm against the member declaration.
49 __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
50 : p_data_{p_data}, element_space_size_{element_space_size}
51 {
52 }
53
// Constructs a buffer from a raw data pointer, an element space size, and an explicit
// invalid-element value. All three are stored as-is in p_data_, element_space_size_ and
// invalid_element_value_.
// In the non-buffer-addressing path of Get(), invalid_element_value_ is what gets returned
// (as X{invalid_element_value_}) for an invalid element when
// InvalidElementUseNumericalZeroValue is false.
54 __host__ __device__ constexpr DynamicBuffer(T* p_data,
55 ElementSpaceSize element_space_size,
56 T invalid_element_value)
57 : p_data_{p_data},
58 element_space_size_{element_space_size},
59 invalid_element_value_{invalid_element_value}
60 {
61 }
62
// Returns the BufferAddressSpace template parameter of this buffer type.
// Static and constexpr, so it can be used in `if constexpr` conditions and static_asserts,
// as the other member functions of this class do.
63 __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace()
64 {
65 return BufferAddressSpace;
66 }
67
// Read-only element access: returns a const reference to p_data_[i].
// No validity flag and no range check against element_space_size_ is applied here.
68 __host__ __device__ constexpr const T& operator[](IndexType i) const { return p_data_[i]; }
69
// Mutable element access: returns a non-const reference to p_data_[i].
// No validity flag and no range check against element_space_size_ is applied here.
70 __host__ __device__ constexpr T& operator()(IndexType i) { return p_data_[i]; }
71
72 template <typename X,
73 bool DoTranspose = false,
77 bool>::type = false>
78 __host__ __device__ constexpr auto Get(IndexType i, bool is_valid_element) const
79 {
80 // X contains multiple T
81 constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
82
83 constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
84
85 static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
86 "wrong! X should contain multiple T");
87
88#if CK_USE_AMD_BUFFER_LOAD
89 bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t);
90#else
91 bool constexpr use_amd_buffer_addressing = false;
92#endif
93
94 if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing &&
95 !DoTranspose)
96 {
97 constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
98
99 if constexpr(InvalidElementUseNumericalZeroValue)
100 {
102 t_per_x,
103 coherence>(
104 p_data_, i, is_valid_element, element_space_size_ / PackedSize);
105 }
106 else
107 {
109 t_per_x,
110 coherence>(
111 p_data_,
112 i,
113 is_valid_element,
116 }
117 }
118 else if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && DoTranspose)
119 {
120#ifdef __gfx12__
121 return amd_global_load_transpose_to_vgpr(p_data_ + i);
122#else
123 static_assert(!DoTranspose, "load-with-transpose only supported on gfx12+");
124#endif
125 }
126 else
127 {
128 if(is_valid_element)
129 {
130#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
131 X tmp;
132
133 __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
134
135 return tmp;
136#else
138#endif
139 }
140 else
141 {
142 if constexpr(InvalidElementUseNumericalZeroValue)
143 {
144 return X{0};
145 }
146 else
147 {
148 return X{invalid_element_value_};
149 }
150 }
151 }
152 }
153
154 template <InMemoryDataOperationEnum Op,
155 typename X,
158 bool>::type = false>
159 __host__ __device__ void Update(IndexType i, bool is_valid_element, const X& x)
160 {
161 if constexpr(Op == InMemoryDataOperationEnum::Set)
162 {
163 this->template Set<X>(i, is_valid_element, x);
164 }
165 else if constexpr(Op == InMemoryDataOperationEnum::AtomicAdd)
166 {
167 this->template AtomicAdd<X>(i, is_valid_element, x);
168 }
169 else if constexpr(Op == InMemoryDataOperationEnum::AtomicMax)
170 {
171 this->template AtomicMax<X>(i, is_valid_element, x);
172 }
173 else if constexpr(Op == InMemoryDataOperationEnum::Add)
174 {
175 auto tmp = this->template Get<X>(i, is_valid_element);
176 using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
177 // handle bfloat addition
178 if constexpr(is_same_v<scalar_t, bhalf_t>)
179 {
180 if constexpr(is_scalar_type<X>::value)
181 {
182 // Scalar type
183 auto result =
185 this->template Set<X>(i, is_valid_element, result);
186 }
187 else
188 {
189 // Vector type
190 constexpr auto vector_size = scalar_type<remove_cvref_t<X>>::vector_size;
191 const vector_type<scalar_t, vector_size> a_vector{tmp};
192 const vector_type<scalar_t, vector_size> b_vector{x};
193 static_for<0, vector_size, 1>{}([&](auto idx) {
194 auto result = type_convert<scalar_t>(
195 type_convert<float>(a_vector.template AsType<scalar_t>()[idx]) +
196 type_convert<float>(b_vector.template AsType<scalar_t>()[idx]));
197 this->template Set<scalar_t>(i + idx, is_valid_element, result);
198 });
199 }
200 }
201 else
202 {
203 this->template Set<X>(i, is_valid_element, x + tmp);
204 }
205 }
206 }
207
208 template <typename DstBuffer, index_t NumElemsPerThread>
209 __host__ __device__ void DirectCopyToLds(DstBuffer& dst_buf,
210 IndexType src_offset,
211 IndexType dst_offset,
212 bool is_valid_element) const
213 {
214 // Copy data from global to LDS memory using direct loads.
215 static_assert(GetAddressSpace() == AddressSpaceEnum::Global,
216 "Source data must come from a global memory buffer.");
217 static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
218 "Destination data must be stored in an LDS memory buffer.");
219
221 src_offset,
222 dst_buf.p_data_,
223 dst_offset,
224 is_valid_element,
226 }
227
228 template <typename X,
232 bool>::type = false>
233 __host__ __device__ void Set(IndexType i, bool is_valid_element, const X& x)
234 {
235 // X contains multiple T
236 constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
237
238 constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
239
240 static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
241 "wrong! X should contain multiple T");
242
243#if CK_USE_AMD_BUFFER_LOAD
244 bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t);
245#else
246 bool constexpr use_amd_buffer_addressing = false;
247#endif
248
249#if CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
250 bool constexpr workaround_int8_ds_write_issue = true;
251#else
252 bool constexpr workaround_int8_ds_write_issue = false;
253#endif
254
255 if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing)
256 {
257 constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
258
259 amd_buffer_store<remove_cvref_t<T>, t_per_x, coherence>(
260 x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
261 }
262 else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds &&
264 workaround_int8_ds_write_issue)
265 {
266 if(is_valid_element)
267 {
268 // HACK: the compiler would lower IR "store<i8, 16> address_space(3)" into inefficient
269 // ISA, so we try to make the compiler emit IR "store<i32, 4>", which would be lowered to
270 // ds_write_b128
271 // TODO: remove this after compiler fix
272 static_assert((is_same<remove_cvref_t<T>, int8_t>::value &&
288 "wrong! not implemented for this combination, please add "
289 "implementation");
290
293 {
294 // HACK: cast pointer of x is bad
295 // TODO: remove this after compiler fix
298 }
299 else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
301 {
302 // HACK: cast pointer of x is bad
303 // TODO: remove this after compiler fix
306 }
307 else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
309 {
310 // HACK: cast pointer of x is bad
311 // TODO: remove this after compiler fix
314 }
315 else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
317 {
318 // HACK: cast pointer of x is bad
319 // TODO: remove this after compiler fix
322 }
323 else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
325 {
326 // HACK: cast pointer of x is bad
327 // TODO: remove this after compiler fix
330 }
331 else if constexpr(is_same<remove_cvref_t<T>, int8x4_t>::value &&
333 {
334 // HACK: cast pointer of x is bad
335 // TODO: remove this after compiler fix
338 }
339 else if constexpr(is_same<remove_cvref_t<T>, int8x8_t>::value &&
341 {
342 // HACK: cast pointer of x is bad
343 // TODO: remove this after compiler fix
346 }
347 else if constexpr(is_same<remove_cvref_t<T>, int8x16_t>::value &&
349 {
350 // HACK: cast pointer of x is bad
351 // TODO: remove this after compiler fix
354 }
355 }
356 }
357 else
358 {
359 if(is_valid_element)
360 {
361#if 0
362 X tmp = x;
363
364 __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
365#else
366 // if(i >= 2169041600)
368#endif
369 }
370 }
371 }
372
373 template <typename X,
376 bool>::type = false>
377 __host__ __device__ void AtomicAdd(IndexType i, bool is_valid_element, const X& x)
378 {
379 using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
380
381 // X contains multiple T
382 constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
383
384 constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
385
386 static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
387 "wrong! X should contain multiple T");
388
389 static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem");
390
391#if CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
392 bool constexpr use_amd_buffer_addressing =
395 (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
396 (is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
397#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
398 bool constexpr use_amd_buffer_addressing =
399 sizeof(IndexType) <= sizeof(int32_t) && is_same_v<remove_cvref_t<scalar_t>, int32_t>;
400#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
401 bool constexpr use_amd_buffer_addressing =
402 sizeof(IndexType) <= sizeof(int32_t) &&
404 (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
405 (is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0));
406#else
407 bool constexpr use_amd_buffer_addressing = false;
408#endif
409
410 if constexpr(use_amd_buffer_addressing)
411 {
412 constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
413
415 x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
416 }
417 else
418 {
419 if(is_valid_element)
420 {
422 }
423 }
424 }
425
426 template <typename X,
429 bool>::type = false>
430 __host__ __device__ void AtomicMax(IndexType i, bool is_valid_element, const X& x)
431 {
432 // X contains multiple T
433 constexpr IndexType scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
434
435 constexpr IndexType scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;
436
437 static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
438 "wrong! X should contain multiple T");
439
440 static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem");
441
442#if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
443 using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
444 bool constexpr use_amd_buffer_addressing =
445 sizeof(IndexType) <= sizeof(int32_t) && is_same_v<remove_cvref_t<scalar_t>, double>;
446#else
447 bool constexpr use_amd_buffer_addressing = false;
448#endif
449
450 if constexpr(use_amd_buffer_addressing)
451 {
452 constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
453
455 x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
456 }
457 else if(is_valid_element)
458 {
460 }
461 }
462
// Compile-time type tag: always returns false for DynamicBuffer.
463 __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
464
// Compile-time type tag: always returns true for DynamicBuffer.
465 __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
466};
467
468template <AddressSpaceEnum BufferAddressSpace,
470 typename T,
471 typename ElementSpaceSize>
472__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
473{
475 p, element_space_size};
476}
477
478template <AddressSpaceEnum BufferAddressSpace,
480 typename T,
481 typename ElementSpaceSize>
482__host__ __device__ constexpr auto make_long_dynamic_buffer(T* p,
483 ElementSpaceSize element_space_size)
484{
486 p, element_space_size};
487}
488
489template <
490 AddressSpaceEnum BufferAddressSpace,
492 typename T,
493 typename ElementSpaceSize,
494 typename X,
496__host__ __device__ constexpr auto
497make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value)
498{
500 p, element_space_size, invalid_element_value};
501}
502
503} // namespace ck
Definition ck.hpp:268
ushort bhalf_t
Definition data_type.hpp:30
__device__ void amd_buffer_store(const typename vector_type_maker< T, N >::type::type src_thread_data, T *p_dst_wave, const index_t dst_thread_element_offset, const bool dst_thread_element_valid, const index_t dst_element_space_size)
Definition utility/amd_buffer_addressing.hpp:894
__device__ void amd_direct_load_global_to_lds(const T *global_base_ptr, const index_t global_offset, T *lds_base_ptr, const index_t lds_offset, const bool is_valid, const index_t src_element_space_size)
Definition utility/amd_buffer_addressing.hpp:1015
__device__ void amd_buffer_atomic_max(const typename vector_type_maker< T, N >::type::type src_thread_data, T *p_dst_wave, const index_t dst_thread_element_offset, const bool dst_thread_element_valid, const index_t dst_element_space_size)
Definition utility/amd_buffer_addressing.hpp:974
int32_t index_t
Definition ck.hpp:299
typename vector_type< int8_t, 8 >::type int8x8_t
Definition dtype_vector.hpp:2178
AmdBufferCoherenceEnum
Definition utility/amd_buffer_addressing.hpp:295
@ DefaultCoherence
Definition utility/amd_buffer_addressing.hpp:296
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
@ AtomicMax
Definition ck.hpp:280
@ AtomicAdd
Definition ck.hpp:279
@ Add
Definition ck.hpp:281
constexpr bool is_native_type()
Definition data_type.hpp:203
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
typename vector_type< int8_t, 4 >::type int8x4_t
Definition dtype_vector.hpp:2177
_Float16 half_t
Definition data_type.hpp:31
__device__ vector_type_maker< T, N >::type::type amd_buffer_load_invalid_element_return_customized_value(const T *p_src_wave, index_t src_thread_element_offset, bool src_thread_element_valid, index_t src_element_space_size, T customized_value)
Definition utility/amd_buffer_addressing.hpp:865
AddressSpaceEnum
Definition amd_address_space.hpp:15
@ Lds
Definition amd_address_space.hpp:18
@ Global
Definition amd_address_space.hpp:17
__host__ __device__ PY c_style_pointer_cast(PX p_x)
Definition c_style_pointer_cast.hpp:15
typename vector_type< int8_t, 16 >::type int8x16_t
Definition dtype_vector.hpp:2179
__device__ X atomic_max(X *p_dst, const X &x)
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__device__ void amd_buffer_atomic_add(const typename vector_type_maker< T, N >::type::type src_thread_data, T *p_dst_wave, const index_t dst_thread_element_offset, const bool dst_thread_element_valid, const index_t dst_element_space_size)
Definition utility/amd_buffer_addressing.hpp:928
std::enable_if< B, T > enable_if
Definition enable_if.hpp:24
__host__ __device__ constexpr auto make_long_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:482
constexpr bool is_same_v
Definition type.hpp:283
__device__ vector_type_maker< T, N >::type::type amd_buffer_load_invalid_element_return_zero(const T *p_src_wave, index_t src_thread_element_offset, bool src_thread_element_valid, index_t src_element_space_size)
Definition utility/amd_buffer_addressing.hpp:829
__device__ X atomic_add(X *p_dst, const X &x)
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
typename vector_type< int8_t, 2 >::type int8x2_t
Definition dtype_vector.hpp:2176
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
signed int int32_t
Definition stdint.h:123
signed char int8_t
Definition stdint.h:121
Definition dynamic_buffer.hpp:31
ElementSpaceSize element_space_size_
Definition dynamic_buffer.hpp:35
__host__ __device__ constexpr DynamicBuffer(T *p_data, ElementSpaceSize element_space_size, T invalid_element_value)
Definition dynamic_buffer.hpp:54
T invalid_element_value_
Definition dynamic_buffer.hpp:36
static constexpr index_t PackedSize
Definition dynamic_buffer.hpp:42
__host__ __device__ void Update(IndexType i, bool is_valid_element, const X &x)
Definition dynamic_buffer.hpp:159
__host__ static __device__ constexpr AddressSpaceEnum GetAddressSpace()
Definition dynamic_buffer.hpp:63
T * p_data_
Definition dynamic_buffer.hpp:34
__host__ __device__ void AtomicAdd(IndexType i, bool is_valid_element, const X &x)
Definition dynamic_buffer.hpp:377
__host__ __device__ constexpr const T & operator[](IndexType i) const
Definition dynamic_buffer.hpp:68
__host__ __device__ constexpr auto Get(IndexType i, bool is_valid_element) const
Definition dynamic_buffer.hpp:78
__host__ static __device__ constexpr bool IsStaticBuffer()
Definition dynamic_buffer.hpp:463
__host__ __device__ void DirectCopyToLds(DstBuffer &dst_buf, IndexType src_offset, IndexType dst_offset, bool is_valid_element) const
Definition dynamic_buffer.hpp:209
__host__ __device__ void Set(IndexType i, bool is_valid_element, const X &x)
Definition dynamic_buffer.hpp:233
__host__ static __device__ constexpr bool IsDynamicBuffer()
Definition dynamic_buffer.hpp:465
T type
Definition dynamic_buffer.hpp:32
__host__ __device__ constexpr DynamicBuffer(T *p_data, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:49
__host__ __device__ constexpr T & operator()(IndexType i)
Definition dynamic_buffer.hpp:70
__host__ __device__ void AtomicMax(IndexType i, bool is_valid_element, const X &x)
Definition dynamic_buffer.hpp:430
Definition type.hpp:177
static constexpr bool value
Definition data_type.hpp:219
Definition data_type.hpp:187
Definition data_type.hpp:39
Definition functional2.hpp:33
Definition dtype_vector.hpp:10