device_batched_gemm_multiple_d_gemm_multiple_d.hpp Source File

device_batched_gemm_multiple_d_gemm_multiple_d.hpp Source File

Composable Kernel: device_batched_gemm_multiple_d_gemm_multiple_d.hpp Source File
device_batched_gemm_multiple_d_gemm_multiple_d.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <vector>
8
9#include "device_base.hpp"
10
11namespace ck {
12namespace tensor_operation {
13namespace device {
14
15template <typename A0Layout,
16 typename B0Layout,
17 typename D0sLayout,
18 typename B1Layout,
19 typename D1sLayout,
20 typename E1Layout,
21 typename A0DataType,
22 typename B0DataType,
23 typename D0sDataType,
24 typename B1DataType,
25 typename D1sDataType,
26 typename E1DataType,
27 typename A0ElementwiseOperation,
28 typename B0ElementwiseOperation,
29 typename CDE0ElementwiseOperation,
30 typename B1ElementwiseOperation,
31 typename CDE1ElementwiseOperation>
33{
34 static constexpr index_t NumD0Tensor = D0sDataType::Size();
35 static constexpr index_t NumD1Tensor = D1sDataType::Size();
36
37 virtual std::unique_ptr<BaseArgument>
38 MakeArgumentPointer(const void* p_a0,
39 const void* p_b0,
40 std::array<const void*, NumD0Tensor> p_d0s,
41 const void* p_b1,
42 std::array<const void*, NumD1Tensor> p_d1s,
43 void* p_e1,
48 ck::index_t Batch,
49 ck::index_t StrideA0,
50 ck::index_t StrideB0,
51 std::array<ck::index_t, NumD0Tensor> StrideD0s,
52 ck::index_t StrideB1,
53 std::array<ck::index_t, NumD1Tensor> StrideD1s,
54 ck::index_t StrideE1,
55 ck::index_t BatchStrideA0,
56 ck::index_t BatchStrideB0,
57 std::array<ck::index_t, NumD0Tensor> BatchStrideD0s,
58 ck::index_t BatchStrideB1,
59 std::array<ck::index_t, NumD1Tensor> BatchStrideD1s,
60 ck::index_t BatchStrideE1,
61 A0ElementwiseOperation a0_element_op,
62 B0ElementwiseOperation b0_element_op,
63 CDE0ElementwiseOperation cde0_element_op,
64 B1ElementwiseOperation b1_element_op,
65 CDE1ElementwiseOperation cde1_element_op) = 0;
66
67 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
68};
69
70} // namespace device
71} // namespace tensor_operation
72} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_batched_gemm_multiple_d_gemm_multiple_d.hpp:33
static constexpr index_t NumD1Tensor
Definition device_batched_gemm_multiple_d_gemm_multiple_d.hpp:35
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a0, const void *p_b0, std::array< const void *, NumD0Tensor > p_d0s, const void *p_b1, std::array< const void *, NumD1Tensor > p_d1s, void *p_e1, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t O, ck::index_t Batch, ck::index_t StrideA0, ck::index_t StrideB0, std::array< ck::index_t, NumD0Tensor > StrideD0s, ck::index_t StrideB1, std::array< ck::index_t, NumD1Tensor > StrideD1s, ck::index_t StrideE1, ck::index_t BatchStrideA0, ck::index_t BatchStrideB0, std::array< ck::index_t, NumD0Tensor > BatchStrideD0s, ck::index_t BatchStrideB1, std::array< ck::index_t, NumD1Tensor > BatchStrideD1s, ck::index_t BatchStrideE1, A0ElementwiseOperation a0_element_op, B0ElementwiseOperation b0_element_op, CDE0ElementwiseOperation cde0_element_op, B1ElementwiseOperation b1_element_op, CDE1ElementwiseOperation cde1_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
static constexpr index_t NumD0Tensor
Definition device_batched_gemm_multiple_d_gemm_multiple_d.hpp:34