block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp Source File

block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp Source File

Composable Kernel: block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp Source File
block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
10
11namespace ck_tile {
12
13// This pipeline is qkv all located in LDS
14template <typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
16{
31
34 static constexpr bool kQLoadOnce = true; // true if q_tile loads the whole block length (hdim) at once
35 static_assert(kQLoadOnce == Policy::QLoadOnce);
36
37 static constexpr index_t kBlockSize = Problem::kBlockSize;
38
39 static constexpr index_t kM0 = BlockFmhaShape::kM0;
40 static constexpr index_t kN0 = BlockFmhaShape::kN0;
41 static constexpr index_t kK0 = BlockFmhaShape::kK0;
42 static constexpr index_t kN1 = BlockFmhaShape::kN1;
43 static constexpr index_t kK1 = BlockFmhaShape::kK1;
44 static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
45 static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
46
47 static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
48
49 static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
50 static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
51 static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
52 static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
53 static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
54 static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
55 static constexpr auto BiasEnum = Problem::BiasEnum;
56 static constexpr bool kStoreLSE = Problem::kStoreLSE;
57 static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
58 static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
59
60 static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
61 (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
64
65 // last dimension vector length used to create the tensor view (and to decide the buffer_load vector length)
66 // ... together with the tensor distribution. The tensor distribution should be able to override this
67 static constexpr index_t kAlignmentQ =
68 kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
69 static constexpr index_t kAlignmentK =
70 kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
71 static constexpr index_t kAlignmentV = []() {
72 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
73 return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
74 else
75 return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
76 }();
77
78 static constexpr index_t kAlignmentOacc =
79 kPadHeadDimV ? 1 : Policy::template GetAlignmentOacc<Problem>();
80
81 static constexpr index_t kAlignmentBias =
82 kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
83
84 static constexpr index_t kBlockPerCu = []() {
85 if constexpr(Problem::kBlockPerCu != -1)
86 return Problem::kBlockPerCu;
87 else
88 {
89 if constexpr(kQKHeaddim <= 32)
90 {
91 return 2;
92 }
93 else if constexpr(kQKHeaddim <= 64)
94 {
95 return 3;
96 }
97 else if constexpr(kQKHeaddim <= 128)
98 {
100 return 1;
101 else
102 return 2;
103 }
104 else if constexpr(kQKHeaddim <= 256)
105 {
106 return 1;
107 }
108 else
109 {
110 return 1;
111 }
112 }
113 }();
114
115 static constexpr const char* name = "qr";
116
118 {
119 return Policy::template GetSmemSize<Problem>();
120 }
121
122 template <typename QDramBlockWindowTmp,
123 typename KDramBlockWindowLengths,
124 typename KPageBlockNavigator,
125 typename VDramBlockWindowLengths,
126 typename VPageBlockNavigator,
127 typename BiasDramBlockWindowTmp,
128 typename LSEaccDramBlockWindowTmp,
129 typename QElementFunction,
130 typename KElementFunction,
131 typename VElementFunction,
132 typename BiasElementFunction,
133 typename LSEaccElementFunction,
134 typename SAccElementFunction,
135 typename PComputeElementFunction,
136 typename OAccElementFunction,
137 typename PositionEncoding,
138 typename AttentionVariantParams,
139 typename BlockIndices>
141 operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
142 const QElementFunction& q_element_func,
143 const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
144 const KPageBlockNavigator& k_page_block_navigator,
145 const KElementFunction& k_element_func,
146 const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
147 const VPageBlockNavigator& v_page_block_navigator,
148 const VElementFunction& v_element_func,
149 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
150 const BiasElementFunction& bias_element_func,
151 LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
152 const LSEaccElementFunction& lse_acc_element_func,
153 const SAccElementFunction& s_acc_element_func,
154 const PComputeElementFunction& p_compute_element_func,
155 const OAccElementFunction& o_acc_element_func,
156 index_t num_splits,
157 index_t i_split,
158 FmhaMask mask,
159 PositionEncoding position_encoding,
160 float scale_s,
161 const AttentionVariant& variant,
162 const AttentionVariantParams& variant_params,
163 const BlockIndices& block_indices,
164 index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
165 void* smem_ptr) const
166 {
167 static_assert(
168 std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
169 std::is_same_v<KDataType, remove_cvref_t<typename KPageBlockNavigator::DataType>> &&
170 std::is_same_v<VDataType, remove_cvref_t<typename VPageBlockNavigator::DataType>>,
171 "wrong!");
172
173 static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
174 kN0 == KDramBlockWindowLengths{}[number<0>{}] &&
175 kK0 == KDramBlockWindowLengths{}[number<1>{}] &&
176 kN1 == VDramBlockWindowLengths{}[number<0>{}] &&
177 kK1 == VDramBlockWindowLengths{}[number<1>{}] &&
178 kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
179 kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
180 "wrong!");
181
182 // K tile in LDS
183 KDataType* k_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
184 static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQ<Problem>()));
186 k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
187 auto k_lds_window =
189
190 // V tile in LDS
192 reinterpret_cast<VDataType*>(smem_ptr),
193 Policy::template MakeVLdsBlockDescriptor<Problem>());
194 auto v_lds_window = make_tile_window(
195 v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
196
197 // Block GEMM
198 constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
199 constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
200
201 auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
202 q_dram_block_window_tmp.get_window_lengths(),
203 q_dram_block_window_tmp.get_window_origin(),
204 Policy::template MakeQRegTileDistribution<Problem>());
205
206 auto q = load_tile(q_dram_window);
207
208 using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
209 auto s_acc = SaccBlockTileType{};
210
211 // reduction function for softmax
212 const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
213 const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
214
215 // infer Sacc, S, P, M, L, Oacc type
216 using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));
217
218 using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
219 SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
220
221 using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
222
223 // init Oacc, M, L
224 auto o_acc = OaccBlockTileType{};
225 auto m = MLBlockTileType{};
226 auto l = MLBlockTileType{};
227
228 clear_tile(o_acc);
230 clear_tile(l);
231
232 const auto q_origin = q_dram_window.get_window_origin();
233 const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
234 q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
235
236 // check early exit if no work to do
237 if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
238 {
239 const index_t logical_num_total_loop =
240 integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0);
241 if(logical_num_total_loop <= 0)
242 {
243 if constexpr(kStoreLSE)
244 {
245 auto lse_acc =
246 make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
247
249
250 store_tile(lse_acc_dram_window_tmp,
251 tile_elementwise_in(lse_acc_element_func, lse_acc));
252 }
253
254 // Note: o_acc has already been cleared at this point, so return it as-is
255 // Note: q was loaded but no fence was issued; ignore it.
256 return o_acc;
257 }
258 }
259
260 const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
261 const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
262 // make sure the first tile is completely located in page-block (page-block size should be
263 // divisible by kN0)
264 // relationship between each *_start variables: aligned_physical_seqlen_k_start <=
265 // physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start
266 const index_t aligned_physical_seqlen_k_start =
267 [&, physical_seqlen_k_start_ = physical_seqlen_k_start] {
268 if constexpr(kIsPagedKV)
269 {
270 return kN0 * integer_divide_floor(physical_seqlen_k_start_, kN0);
271 }
272 else
273 {
274 return physical_seqlen_k_start_;
275 }
276 }();
277 const index_t num_total_loop =
278 integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
279
280 auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
281 k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
282
283 const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
284 auto bias_dram_window =
285 make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
286 bias_dram_block_window_tmp.get_window_lengths(),
287 {bias_origin.at(number<0>{}),
288 logical_seqlen_k_start - (physical_seqlen_k_start -
289 aligned_physical_seqlen_k_start)}, // M/N
290 Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
291
292 auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
293 v_dram_block_window_lengths,
294 {0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
295 Policy::template MakeVDramTileDistribution<Problem>());
296
297 auto q_tile = tile_elementwise_in(q_element_func, q);
298
299 // prefetch K tile
300 index_t i_total_loops = 0;
301 constexpr index_t k0_loops = kQKHeaddim / kK0;
302 constexpr index_t k1_loops = kN0 / kK1;
303
304 static_assert(2 <= k0_loops);
305 static_assert(1 <= k1_loops);
306 do
307 {
308 // STAGE 1, QK gemm
309 auto k_dram_window = make_tile_window(
310 k_dram_block_window,
311 Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
312 // load
313
314 auto k_block_tile = load_tile(k_dram_window);
315 {
316 // moving k_dram_window is an in-page-block operation, so there is
317 // no need to invoke k_page_block_navigator.move_tile_window() here.
318 move_tile_window(k_dram_window, {0, kK0});
319 clear_tile(s_acc); // initialize C
320 store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
321 k_block_tile = load_tile(k_dram_window);
322 }
323 auto physical_next_block_id_k =
324 amd_wave_read_first_lane(k_page_block_navigator.prefetch_table_id(
325 i_page_block_k, k_dram_block_window, {kN0, 0}));
326 auto physical_next_block_id_v = amd_wave_read_first_lane(
327 v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1}));
328
330 {
331 __builtin_amdgcn_sched_barrier(
332 0); // prevent from messing up the order of global loads
333 }
334 const auto bias_tile = load_tile(bias_dram_window); // load bias tile
336 {
337 __builtin_amdgcn_sched_barrier(
338 0); // prevent from messing up the order of global loads
339 }
340
341 if constexpr(k0_loops > 2)
342 {
343 static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
345 gemm_0(s_acc,
346 get_slice_tile(q_tile,
347 sequence<0, i_k0 * kK0>{},
348 sequence<kM0, (i_k0 + 1) * kK0>{}),
349 k_lds_window);
351 move_tile_window(k_dram_window, {0, kK0});
352
354 k_lds_window,
355 tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
356 k_block_tile = load_tile(k_dram_window); // global read i + 2
357 });
358 }
359
360 const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
361 { // tail
363 gemm_0(s_acc,
364 get_slice_tile(q_tile,
365 sequence<0, (k0_loops - 2) * kK0>{},
366 sequence<kM0, (k0_loops - 1) * kK0>{}),
367 k_lds_window);
369
370 store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
372
373 gemm_0(s_acc,
374 get_slice_tile(q_tile,
375 sequence<0, (k0_loops - 1) * kK0>{},
376 sequence<kM0, k0_loops * kK0>{}),
377 k_lds_window);
378 }
379
380 // STAGE 2, scale_s, add bias, mask, softmax
382 {
383 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
384 tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
386 [&](auto& x, const auto& y) {
387#if !CK_TILE_FMHA_FWD_FAST_EXP2
388 x += type_convert<SaccDataType>(bias_element_func(y));
389#else
391 type_convert<SaccDataType>(bias_element_func(y));
392#endif
393 },
394 s_acc,
395 bias_tile);
396 }
397 else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
398 {
399 const auto k_origin = k_page_block_navigator.to_global_window_origin(
400 i_page_block_k, k_dram_block_window.get_window_origin());
401 constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
402 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
403 sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
404 sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
405 const auto tile_idx = get_x_indices_from_distributed_indices(
406 s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
407
408 const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
409 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
410 constexpr auto i_j_idx = make_tuple(idx0, idx1);
411
412 s_acc(i_j_idx) *= scale_s;
413 // position_encoding accepts only logical coordinates, so convert the column here
414 position_encoding.update(s_acc(i_j_idx), row, col - kv_l2p_offset);
415 });
416 });
417 }
418 else
419 {
420 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
421 if constexpr(kHasLogitsSoftCap)
422 {
423 auto apply_logits_transform =
424 [&variant, &variant_params, &block_indices](auto& x) {
425 x = variant.LogitsTransform(variant_params,
426 variant.QueryTransform(variant_params, x),
427 block_indices.batch_idx,
428 block_indices.qo_head_idx,
429 block_indices.kv_head_idx);
430 };
431#if !CK_TILE_FMHA_FWD_FAST_EXP2
432 tile_elementwise_inout(apply_logits_transform, s_acc);
433#else
434 tile_elementwise_inout(apply_logits_transform, s_acc);
435#endif
436 }
437 else
438 {
439#if !CK_TILE_FMHA_FWD_FAST_EXP2
440 tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
441#endif
442 }
443 }
444 move_tile_window(bias_dram_window, {0, kN0});
445
447 if constexpr(kHasUnevenSplits)
448 {
449 const auto k_origin = k_page_block_navigator.to_global_window_origin(
450 i_page_block_k, k_dram_block_window.get_window_origin());
452 s_acc,
454 [&,
455 physical_seqlen_k_start_ = physical_seqlen_k_start,
456 physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
457 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
458 if constexpr(kIsPagedKV)
459 {
460 return col < physical_seqlen_k_start_ || physical_seqlen_k_end_ <= col;
461 }
462 else
463 {
464 return physical_seqlen_k_end_ <= col;
465 }
466 });
467 }
468
469 if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
470 {
471 const auto k_origin = k_page_block_navigator.to_global_window_origin(
472 i_page_block_k, k_dram_block_window.get_window_origin());
473 // mask accepts only logical coordinates, so convert here
474 bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
475 k_origin.at(number<0>{}) - kv_l2p_offset,
476 number<kM0>{},
477 number<kN0>{});
478 if(need_perpixel_check)
479 {
481 s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
482 const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
483 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
484 return mask.IsOutOfBound(row, col - kv_l2p_offset);
485 });
486 }
487 }
488
489 const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
491 s,
492 sequence<1>{},
493 f_max,
494 -numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
496
497 const auto m_old = m; // m{j-1}
499 [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
500
502 s.get_tile_distribution()); // Pcompute{j}
503
504 static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
508 FmhaMask::IsMasking)
509 {
512 : raw_m;
513 }
514 else
515 {
516 return raw_m;
517 }
518 };
519
520 constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
521 sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
522 constexpr auto i_idx = make_tuple(idx0);
523#if CK_TILE_FMHA_FWD_FAST_EXP2
524 auto row_max = scale_s * get_validated_m(m[i_idx]);
525#endif
526 sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
527 constexpr auto i_j_idx = make_tuple(idx0, idx1);
528#if CK_TILE_FMHA_FWD_FAST_EXP2
531 {
532 p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
533 }
534 else
535 {
536 if constexpr(kHasLogitsSoftCap)
537 {
538 p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
539 }
540 else
541 {
542 p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
543 }
544 }
545#else
546 p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
547#endif
548 });
549 });
550
552 p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
553
555 // l{j}, Oacc{j}
556 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
557 sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
558 constexpr auto i_idx = make_tuple(idx0);
559#if CK_TILE_FMHA_FWD_FAST_EXP2
560 const auto tmp = [&]() {
563 {
564 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
565 }
566 else
567 {
568 if constexpr(kHasLogitsSoftCap)
569 {
570
571 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
572 }
573 else
574 {
575 auto row_max = scale_s * get_validated_m(m[i_idx]);
576 return exp2(scale_s * m_old[i_idx] - row_max);
577 }
578 }
579 }();
580#else
581 const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
582#endif
583 l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
584 sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
585 constexpr auto i_j_idx = make_tuple(idx0, idx1);
586 // FIXME: this uses a different equation from the FA v2 paper,
587 // but produces the correct result.
588 // Is the equation wrong?
589 o_acc(i_j_idx) *= tmp;
590 });
591 });
592
594 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
595 {
597 Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
598 shuffle_tile(v_shuffle_tmp, v_prefetch);
600 v_lds_window,
601 tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
602 }
603 else
604 {
605 store_tile(v_lds_window,
606 tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
607 }
608 i_page_block_v = v_page_block_navigator.move_tile_window(
609 i_page_block_v, v_dram_window, {0, kK1}, physical_next_block_id_v);
610
611 const auto p =
612 cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
613
614 // STAGE 3, KV gemm
615 if constexpr(k1_loops > 1)
616 {
617 static_for<0, k1_loops - 1, 1>{}([&,
618 &i_page_block_v_ = i_page_block_v,
619 &v_dram_window_ = v_dram_window](auto i_k1) {
620 auto physical_next_block_id_v_ =
621 amd_wave_read_first_lane(v_page_block_navigator.prefetch_table_id(
622 i_page_block_v_, v_dram_window_, {0, kK1}));
623 const auto v = load_tile(v_dram_window_); // load next v
625 gemm_1(o_acc,
627 p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
628 v_lds_window);
630 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
631 {
633 Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
634 shuffle_tile(v_shuffle_tmp, v);
635 store_tile(v_lds_window,
636 tile_elementwise_in(v_element_func,
637 v_shuffle_tmp)); // store the prefetch
638 }
639 else
640 {
641 store_tile(v_lds_window,
642 tile_elementwise_in(v_element_func, v)); // store next v
643 }
644 i_page_block_v_ = v_page_block_navigator.move_tile_window(
645 i_page_block_v_, v_dram_window_, {0, kK1}, physical_next_block_id_v_);
646 });
647 }
648 // move K tile windows
649 i_page_block_k = k_page_block_navigator.move_tile_window(
650 i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k);
651 // tail
652 {
654 gemm_1(o_acc,
655 get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
656 v_lds_window);
658 }
659 } while(++i_total_loops < num_total_loop);
660
661 if constexpr(kStoreLSE)
662 {
663 // store lse acc
664 auto lse_acc = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
665
666 constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans();
667 sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
668 constexpr auto i_idx = make_tuple(idx0);
669#if CK_TILE_FMHA_FWD_FAST_EXP2
672 {
673 lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
674 }
675 else
676 {
677 if constexpr(kHasLogitsSoftCap)
678 {
679 lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
680 }
681 else
682 {
683 lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
684 }
685 }
686#else
687 lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
688#endif
689 });
690
691 store_tile(lse_acc_dram_window_tmp, tile_elementwise_in(lse_acc_element_func, lse_acc));
692 }
693
694 // finally, O
695 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
696
697 sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
698 constexpr auto i_idx = make_tuple(idx0);
699 const auto tmp = [&]() {
701 FmhaMask::IsMasking)
702 {
703 return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
704 }
705 else
706 return 1 / l[i_idx];
707 }();
708 sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
709 constexpr auto i_j_idx = make_tuple(idx0, idx1);
710 o_acc(i_j_idx) *= tmp;
711 });
712 });
713
714 o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
715
716 return o_acc;
717 }
718
719 template <typename QDramBlockWindowTmp,
720 typename KDramBlockWindowLengths,
721 typename KPageBlockNavigator,
722 typename VDramBlockWindowLengths,
723 typename VPageBlockNavigator,
724 typename BiasDramBlockWindowTmp,
725 typename LSEaccDramBlockWindowTmp,
726 typename PositionEncoding,
727 typename AttentionVariantParams,
728 typename BlockIndices>
730 operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
731 const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
732 const KPageBlockNavigator& k_page_block_navigator,
733 const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
734 const VPageBlockNavigator& v_page_block_navigator,
735 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
736 LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile
737 index_t num_splits,
738 index_t i_split,
739 FmhaMask mask,
740 PositionEncoding position_encoding,
741 float scale_s,
742 const AttentionVariant& variant,
743 const AttentionVariantParams& variant_params,
744 const BlockIndices& block_indices,
745 index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
746 void* smem_ptr) const
747 {
748 return operator()(q_dram_block_window_tmp,
749 identity{},
750 k_dram_block_window_lengths,
751 k_page_block_navigator,
752 identity{},
753 v_dram_block_window_lengths,
754 v_page_block_navigator,
755 identity{},
756 bias_dram_block_window_tmp,
757 identity{},
758 lse_acc_dram_block_window_tmp,
759 identity{},
760 identity{},
761 identity{},
762 identity{},
763 num_splits,
764 i_split,
765 mask,
766 position_encoding,
767 scale_s,
768 variant,
769 variant_params,
770 block_indices,
771 kv_l2p_offset,
772 smem_ptr);
773 }
774};
775
776} // namespace ck_tile
#define CK_TILE_FMHA_FWD_FAST_EXP2
Definition config.hpp:234
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition bfloat16.hpp:428
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition tile_elementwise.hpp:95
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_HOST_DEVICE constexpr auto integer_divide_floor(X x, Y y)
Definition tile/core/numeric/math.hpp:143
CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition static_distributed_tensor.hpp:159
CK_TILE_DEVICE constexpr auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition slice_tile.hpp:23
@ ALIBI
Definition block_attention_bias_enum.hpp:15
@ NO_BIAS
Definition block_attention_bias_enum.hpp:13
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={}, bool_constant< CrossWarp >={})
Definition block_reduce.hpp:21
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
constexpr T log2e_v
Definition tile/core/numeric/math.hpp:488
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_DEVICE void shuffle_tile(OutTensor &out, const InTensor &in)
Definition shuffle_tile.hpp:154
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_ &acc_tensor, const InDistributedTensor_ &in_tensor, sequence< InReduceDims... >, const ReduceFunc &reduce_func)
Definition block_reduce.hpp:191
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition bfloat16.hpp:419
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition static_distributed_tensor.hpp:175
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition bfloat16.hpp:425
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:16
static constexpr index_t kM0
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:39
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const QElementFunction &q_element_func, const KDramBlockWindowLengths &k_dram_block_window_lengths, const KPageBlockNavigator &k_page_block_navigator, const KElementFunction &k_element_func, const VDramBlockWindowLengths &v_dram_block_window_lengths, const VPageBlockNavigator &v_page_block_navigator, const VElementFunction &v_element_func, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, const BiasElementFunction &bias_element_func, LSEaccDramBlockWindowTmp &lse_acc_dram_window_tmp, const LSEaccElementFunction &lse_acc_element_func, const SAccElementFunction &s_acc_element_func, const PComputeElementFunction &p_compute_element_func, const OAccElementFunction &o_acc_element_func, index_t num_splits, index_t i_split, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, index_t kv_l2p_offset, void *smem_ptr) const
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:141
static constexpr bool kPadHeadDimQ
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:52
static constexpr index_t kN0
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:40
remove_cvref_t< Policy_ > Policy
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:18
remove_cvref_t< typename Problem::BiasDataType > BiasDataType
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:24
remove_cvref_t< typename Problem::AttentionVariant > AttentionVariant
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:29
remove_cvref_t< typename Problem::QDataType > QDataType
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:19
static constexpr bool kIsGroupMode
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:49
remove_cvref_t< typename Problem::PDataType > PDataType
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:26
static constexpr index_t kK0
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:41
static constexpr bool kHasUnevenSplits
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:58
remove_cvref_t< Problem_ > Problem
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:17
remove_cvref_t< typename Problem::OaccDataType > OaccDataType
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:27
remove_cvref_t< typename Problem::VDataType > VDataType
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:21
static constexpr index_t kBlockPerCu
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:84
remove_cvref_t< typename Problem::SMPLComputeDataType > SMPLComputeDataType
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:23
static constexpr bool kQLoadOnce
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:34
static constexpr index_t kN1
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:42
static constexpr bool kPadSeqLenK
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:51
static constexpr index_t kQKHeaddim
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:44
remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:30
static constexpr index_t kAlignmentQ
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:67
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:117
static constexpr bool kStoreLSE
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:56
static constexpr bool kIsPagedKV
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:57
static constexpr index_t kAlignmentBias
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:81
static constexpr index_t kSubQKHeaddim
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:45
remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:32
static constexpr index_t kAlignmentK
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:69
remove_cvref_t< typename Problem::ODataType > ODataType
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:28
static constexpr bool kHasLogitsSoftCap
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:54
remove_cvref_t< typename Problem::SaccDataType > SaccDataType
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:22
static constexpr auto BiasEnum
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:55
static constexpr index_t kAlignmentV
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:71
remove_cvref_t< typename BlockFmhaShape::VLayout > VLayout
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:33
static constexpr bool kPadSeqLenQ
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:50
static constexpr index_t kBlockSize
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:37
static constexpr index_t kK1
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:43
static constexpr bool kPadHeadDimV
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:53
remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:25
remove_cvref_t< typename Problem::KDataType > KDataType
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:20
static constexpr const char * name
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:115
static constexpr index_t kAlignmentOacc
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:78
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowLengths &k_dram_block_window_lengths, const KPageBlockNavigator &k_page_block_navigator, const VDramBlockWindowLengths &v_dram_block_window_lengths, const VPageBlockNavigator &v_page_block_navigator, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, LSEaccDramBlockWindowTmp &lse_acc_dram_block_window_tmp, index_t num_splits, index_t i_split, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, index_t kv_l2p_offset, void *smem_ptr) const
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:730
Definition tile/core/utility/functional.hpp:86
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition tile/core/container/sequence.hpp:49
#define C_LOG2E
Definition tile/core/numeric/math.hpp:469