device_moe_mx_gemm_bns.hpp Source File

device_moe_mx_gemm_bns.hpp Source File

Composable Kernel: device_moe_mx_gemm_bns.hpp Source File
device_moe_mx_gemm_bns.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
24template <typename ALayout,
25 typename BLayout,
26 typename DsLayout,
27 typename CLayout,
28 typename ADataType,
29 typename AScaleDataType,
30 typename BDataType,
31 typename BScaleDataType,
32 typename DsDataType,
33 typename CDataType,
34 typename GemmAccDataType,
35 typename CShuffleDataType,
36 typename AElementwiseOperation,
37 typename BElementwiseOperation,
38 typename CElementwiseOperation,
39 GemmSpecialization GemmSpec,
40 index_t ScaleBlockSize,
41 index_t BlockSize,
42 index_t MPerBlock,
43 index_t NPerBlock,
44 index_t KPerBlock,
45 index_t AK1,
46 index_t BK1,
47 index_t MPerXDL,
48 index_t NPerXDL,
49 index_t MXdlPerWave,
50 index_t NXdlPerWave,
51 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
52 typename ABlockTransferThreadClusterArrangeOrder,
53 typename ABlockTransferSrcAccessOrder,
54 index_t ABlockTransferSrcVectorDim,
55 index_t ABlockTransferSrcScalarPerVector,
56 index_t ABlockTransferDstScalarPerVector_AK1,
57 bool ABlockLdsExtraM,
58 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
59 typename BBlockTransferThreadClusterArrangeOrder,
60 typename BBlockTransferSrcAccessOrder,
61 index_t BBlockTransferSrcVectorDim,
62 index_t BBlockTransferSrcScalarPerVector,
63 index_t BBlockTransferDstScalarPerVector_BK1,
64 bool BBlockLdsExtraN,
65 index_t CShuffleMXdlPerWavePerShuffle,
66 index_t CShuffleNXdlPerWavePerShuffle,
67 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
68 typename CDEShuffleBlockTransferScalarPerVectors,
71 index_t ActivationOP = 0,
72 bool NSwizzle = false,
73 bool IsInputGemm = true,
74 bool MulRoutedWeight = true,
75 typename IndexType = index_t,
76 typename ComputeTypeA = ADataType,
77 typename ComputeTypeB = BDataType>
79 BLayout,
80 DsLayout,
81 CLayout,
82 ADataType,
83 AScaleDataType,
84 BDataType,
85 BScaleDataType,
86 DsDataType,
87 CDataType,
88 ScaleBlockSize,
89 AElementwiseOperation,
90 BElementwiseOperation,
91 CElementwiseOperation>
92{
94 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
95 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
96 static constexpr index_t NumDTensor = DsDataType::Size();
97 template <index_t NXdlPerWave_>
100 BLayout,
101 DsLayout,
102 CLayout,
103 ADataType,
104 AScaleDataType,
105 BDataType,
106 BScaleDataType,
107 GemmAccDataType,
108 CShuffleDataType,
109 DsDataType,
110 CDataType,
111 AElementwiseOperation,
112 BElementwiseOperation,
113 CElementwiseOperation,
114 GemmSpec,
115 ScaleBlockSize,
116 BlockSize,
117 MPerBlock,
118 NPerBlock,
119 KPerBlock,
120 AK1,
121 BK1,
122 MPerXDL,
123 NPerXDL,
124 MXdlPerWave,
125 NXdlPerWave_,
126 ABlockTransferThreadClusterLengths_AK0_M_AK1,
127 ABlockTransferThreadClusterArrangeOrder,
128 ABlockTransferSrcAccessOrder,
129 ABlockTransferSrcVectorDim,
130 ABlockTransferSrcScalarPerVector,
131 ABlockTransferDstScalarPerVector_AK1,
132 false,
133 ABlockLdsExtraM,
134 BBlockTransferThreadClusterLengths_BK0_N_BK1,
135 BBlockTransferThreadClusterArrangeOrder,
136 BBlockTransferSrcAccessOrder,
137 BBlockTransferSrcVectorDim,
138 BBlockTransferSrcScalarPerVector,
139 BBlockTransferDstScalarPerVector_BK1,
140 false,
141 BBlockLdsExtraN,
142 CShuffleMXdlPerWavePerShuffle,
143 CShuffleNXdlPerWavePerShuffle,
144 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
145 CDEShuffleBlockTransferScalarPerVectors,
146 BlkGemmPipeSched,
147 BlkGemmPipelineVer,
148 ActivationOP,
149 NSwizzle,
150 IsInputGemm,
151 MulRoutedWeight,
152 IndexType,
153 ComputeTypeA,
154 ComputeTypeB>;
157
// Public argument type, taken from the wave64 gridwise GEMM instantiation. When
// get_warp_size() != 64, IsSupportedArgument reinterpret_casts it to GridwiseGemm32::Argument.
158 using Argument = typename GridwiseGemm64::Argument;
159
162
163 int GetPreShuffleParameters() override { return NPerXDL; }
164
165 // Invoker
166 struct Invoker : public BaseInvoker
167 {
168 template <typename GridwiseGemm>
169 float RunImp(const typename GridwiseGemm::Argument& arg,
170 const StreamConfig& stream_config = StreamConfig{})
171 {
172 if(stream_config.log_level_ > 0)
173 {
174 arg.Print();
175 }
176
177 if(!GridwiseGemm::CheckValidity(arg))
178 {
179 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
180 }
181
182 index_t gdx, gdy, gdz;
183 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
184
185 float ave_time = 0;
186
187 index_t k_grain = arg.KBatch * KPerBlock;
188 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
189
190 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
191
192 const auto RunKernel = [&](const auto& kernel) {
193 if(stream_config.flush_cache)
194 {
195
196 std::array<std::size_t, NumDTensor> DsSize;
197
198 auto arg_ = arg;
199
200 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
201 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
202 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
203 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
204
205 auto size_a_buffer =
206 a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
207 auto size_b_buffer =
208 b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
209
210 const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
211 arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
212
213 static_for<0, NumDTensor, 1>{}([&](auto i) {
214 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
215 DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
216 });
217 ck::utility::RotatingMemWrapperMultiD<typename GridwiseGemm::Argument,
218 DsDataType>
219 rotating_mem(arg_,
220 stream_config.rotating_count,
221 size_a_buffer,
222 size_b_buffer,
223 DsSize);
224 rotating_mem.Print();
225
226 auto run_flush_cache = [&]() {
227 // flush icache
229 // rotating mem
230 rotating_mem.Next();
231 // clear c mem
232 if(arg_.KBatch > 1)
233 hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
234 0,
235 arg_.M * arg_.N * sizeof(CDataType),
236 stream_config.stream_id_));
237 };
238
240 stream_config,
241 run_flush_cache,
242 kernel,
243 dim3(gdx, gdy, gdz),
244 dim3(BlockSize),
245 0,
246 arg_);
247 }
248 else
249 {
250 if(arg.KBatch > 1)
251 hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
252 0,
253 arg.M * arg.N * sizeof(CDataType),
254 stream_config.stream_id_));
255
256 ave_time = launch_and_time_kernel(
257 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
258 }
259 };
260
261 // TODO: Check if this is the right algorithm for minimum_occupancy
262 constexpr index_t minimum_occupancy =
263 BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
264 ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
265 MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
266 ? 2
267 : 1
268 : 2;
269
270 constexpr auto MemoryDataOp =
272 if(has_main_k_block_loop)
273 {
274 // Tail number always full
275 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
276 {
277 const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
278 true,
279 MemoryDataOp,
280 minimum_occupancy,
282 RunKernel(kernel);
283 }
284 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
285 {
286 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
287 {
288 const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
289 true,
290 MemoryDataOp,
291 minimum_occupancy,
293 RunKernel(kernel);
294 }
295 else
296 {
297 const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
298 true,
299 MemoryDataOp,
300 minimum_occupancy,
302 RunKernel(kernel);
303 }
304 }
305 else
306 {
307 throw std::runtime_error("todo: only v1 & v3 support now");
308 }
309 }
310 else
311 {
312 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
313 {
314 const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
315 false,
316 MemoryDataOp,
317 minimum_occupancy,
319 RunKernel(kernel);
320 }
321 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
322 {
323 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
324 {
325 const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
326 false,
327 MemoryDataOp,
328 minimum_occupancy,
330 RunKernel(kernel);
331 }
332 else
333 {
334 const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
335 false,
336 MemoryDataOp,
337 minimum_occupancy,
339 RunKernel(kernel);
340 }
341 }
342 }
343
344 return ave_time;
345 }
346
348
349 // polymorphic
350 float Run(const BaseArgument* p_arg,
351 const StreamConfig& stream_config = StreamConfig{}) override
352 {
353 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
354 }
355 };
356
/// Compile-time validity check for this instance's template parameters.
/// This is currently a placeholder that accepts every instantiation.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool every_configuration_accepted = true;
    return every_configuration_accepted;
}
362
363 static bool IsSupportedArgument(const Argument& arg)
364 {
365 // only impl kbatch 1 now
366 if(arg.KBatch > 1)
367 {
368 return false;
369 }
371 {
372 return false;
373 }
374 if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
375 {
376 return false;
377 }
378
379 if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
380 GemmSpec == GemmSpecialization::NKPadding ||
381 GemmSpec == GemmSpecialization::MNKPadding ||
382 GemmSpec == GemmSpecialization::KPadding))
383 {
384 return false;
385 }
386 if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
387 {
388 return false;
389 }
390
391 if(get_warp_size() == 64)
392 {
393 if constexpr(NXdlPerWave64 > 0)
394 {
396 }
397 }
398 else
399 {
400 if constexpr(NXdlPerWave32 > 0)
401 {
403 reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
404 }
405 }
406 return false;
407 }
408
409 // polymorphic
410 bool IsSupportedArgument(const BaseArgument* p_arg) override
411 {
412 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
413 }
414
415 static auto MakeArgument(const void* p_sorted_token_ids,
416 const void* p_sorted_expert_ids,
417 const void* p_max_token_id,
418 const void* p_a,
419 const void* p_a_scale,
420 const void* p_b,
421 const void* p_b_scale,
422 std::array<const void*, NumDTensor> p_ds,
423 void* p_c,
424 index_t NumTokens,
425 index_t TopK,
426 index_t M,
427 index_t N,
428 index_t K,
429 index_t StrideA,
430 index_t StrideScaleA,
431 index_t StrideB,
432 index_t StrideScaleB,
433 std::array<index_t, NumDTensor> StrideDs,
434 index_t StrideC,
435 index_t KBatch,
436 AElementwiseOperation a_element_op,
437 BElementwiseOperation b_element_op,
438 CElementwiseOperation c_element_op)
439 {
440 return Argument{static_cast<const index_t*>(p_sorted_token_ids),
441 static_cast<const index_t*>(p_sorted_expert_ids),
442 static_cast<const index_t*>(p_max_token_id),
443 static_cast<const ADataType*>(p_a),
444 static_cast<const AScaleDataType*>(p_a_scale),
445 static_cast<const BDataType*>(p_b),
446 static_cast<const BScaleDataType*>(p_b_scale),
447 p_ds,
448 static_cast<CDataType*>(p_c),
449 NumTokens,
450 TopK,
451 M,
452 N,
453 K,
454 StrideA,
455 StrideScaleA,
456 StrideB,
457 StrideScaleB,
458 StrideDs,
459 StrideC,
460 KBatch,
461 a_element_op,
462 b_element_op,
463 c_element_op};
464 }
465
466 static auto MakeInvoker() { return Invoker{}; }
467
468 // polymorphic
469 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
470 const void* p_a_scale,
471 const void* p_b,
472 const void* p_b_scale,
473 std::array<const void*, NumDTensor> p_ds,
474 void* p_c,
475 index_t M,
476 index_t N,
477 index_t K,
478 index_t StrideA,
479 index_t StrideScaleA,
480 index_t StrideB,
481 index_t StrideScaleB,
482 std::array<ck::index_t, NumDTensor> StrideDs,
483 index_t StrideC,
484 index_t KBatch,
485 AElementwiseOperation a_element_op,
486 BElementwiseOperation b_element_op,
487 CElementwiseOperation c_element_op) override
488 {
489 return std::make_unique<Argument>(nullptr,
490 nullptr,
491 nullptr,
492 static_cast<const ADataType*>(p_a),
493 static_cast<const AScaleDataType*>(p_a_scale),
494 static_cast<const BDataType*>(p_b),
495 static_cast<const BScaleDataType*>(p_b_scale),
496 p_ds,
497 static_cast<CDataType*>(p_c),
498 M, // randoms set, no use
499 0,
500 M,
501 N,
502 K,
503 StrideA,
504 StrideScaleA,
505 StrideB,
506 StrideScaleB,
507 StrideDs,
508 StrideC,
509 KBatch,
510 a_element_op,
511 b_element_op,
512 c_element_op);
513 }
514
515 // polymorphic
516 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
517 {
518 return std::make_unique<Invoker>(Invoker{});
519 }
520
521 // polymorphic
522 std::string GetTypeString() const override
523 {
524 auto str = std::stringstream();
525
526 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
529
530 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
536
537 // clang-format off
538 str << "DeviceMoeGEmmMx"
539 << "<"
540 << getGemmSpecializationString(GemmSpec) << ", "
541 << std::string(ALayout::name)[0]
542 << std::string(BLayout::name)[0]
543 << std::string(CLayout::name)[0]
544 << ">"
545 << " BlkSize: "
546 << BlockSize << ", "
547 << "BlkTile: "
548 << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
549 << "WaveTile: "
550 << MPerXDL<<"x"<<NPerXDL << ", "
551 << "WaveMap: "
552 << MXdlPerWave<<"x" << NXdlPerWave<<", "
553 << "VmemReadVec: "
554 << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
555 << "BlkGemmPipelineScheduler: "
556 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
557 << "BlkGemmPipelineVersion: "
558 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
559 << "BlkGemmPipelinePrefetchStages: "
560 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
561 // clang-format on
562
563 return str.str();
564 }
565};
566
567} // namespace device
568} // namespace tensor_operation
569} // namespace ck
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
Definition ck.hpp:268
__global__ void kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
Definition gridwise_moe_mx_gemm_bns.hpp:48
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Full
Definition blkgemmpipe_scheduler.hpp:49
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
constexpr index_t packed_size_v
Definition data_type.hpp:411
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
Definition gridwise_moe_mx_gemm_bns.hpp:179
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_gemm_multiple_d.hpp:167
Definition device_moe_mx_gemm_bns.hpp:167
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_moe_mx_gemm_bns.hpp:350
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_moe_mx_gemm_bns.hpp:169
Definition device_moe_mx_gemm_bns.hpp:92
static constexpr index_t BPackedSize
Definition device_moe_mx_gemm_bns.hpp:161
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_moe_mx_gemm_bns.hpp:94
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_moe_mx_gemm_bns.hpp:155
std::string GetTypeString() const override
Definition device_moe_mx_gemm_bns.hpp:522
static auto MakeInvoker()
Definition device_moe_mx_gemm_bns.hpp:466
static constexpr index_t APackedSize
Definition device_moe_mx_gemm_bns.hpp:160
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_moe_mx_gemm_bns.hpp:156
GridwiseMoeGemmMXBNS< ALayout, BLayout, DsLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ActivationOP, NSwizzle, IsInputGemm, MulRoutedWeight, IndexType, ComputeTypeA, ComputeTypeB > GridwiseGemmBase
Definition device_moe_mx_gemm_bns.hpp:98
typename GridwiseGemm64::Argument Argument
Definition device_moe_mx_gemm_bns.hpp:158
static auto MakeArgument(const void *p_sorted_token_ids, const void *p_sorted_expert_ids, const void *p_max_token_id, const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t NumTokens, index_t TopK, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideScaleA, index_t StrideB, index_t StrideScaleB, std::array< index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_moe_mx_gemm_bns.hpp:415
static constexpr auto NXdlPerWave32
Definition device_moe_mx_gemm_bns.hpp:95
static bool IsSupportedArgument(const Argument &arg)
Definition device_moe_mx_gemm_bns.hpp:363
int GetPreShuffleParameters() override
Definition device_moe_mx_gemm_bns.hpp:163
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_moe_mx_gemm_bns.hpp:516
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_moe_mx_gemm_bns.hpp:410
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideScaleA, index_t StrideB, index_t StrideScaleB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_moe_mx_gemm_bns.hpp:469
static constexpr index_t NumDTensor
Definition device_moe_mx_gemm_bns.hpp:96
static constexpr bool IsValidCompilationParameter()
Definition device_moe_mx_gemm_bns.hpp:357
Definition flush_cache.hpp:174