device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp Source File

device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp Source File

Composable Kernel: device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp Source File
device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11
12namespace ck {
13namespace tensor_operation {
14namespace device {
15
16//
17// @brief Device Convolution operation.
18// @note This structure is deprecated (left for backwards compatibility). Please use
19// DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle.
20// Supports:
21// @li Forward convolution with up to 3 spatial dimensions
22// @li Input tensor in GNWC data format
23// @li Weight tensor in GKXC data format
24// @li Output tensor in GNWK data format
25//
26// 1D:
27// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
28// 2D:
29// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
30// 3D:
31// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
32//
33template <index_t NDimSpatial,
34 typename ALayout,
35 typename BLayout,
36 typename DsLayout,
37 typename ELayout,
38 typename ADataType,
39 typename BDataType,
40 typename AccDataType,
41 typename CShuffleDataType,
42 typename DsDataType,
43 typename EDataType,
44 typename AElementwiseOperation,
45 typename BElementwiseOperation,
46 typename CDEElementwiseOperation,
47 ConvolutionForwardSpecialization ConvForwardSpecialization,
48 GemmSpecialization GemmSpec,
49 index_t NumGemmKPrefetchStage,
50 index_t BlockSize,
51 index_t MPerBlock,
52 index_t NPerBlock,
53 index_t KPerBlock,
54 index_t AK1,
55 index_t BK1,
56 index_t MPerXDL,
57 index_t NPerXDL,
58 index_t MXdlPerWave,
59 index_t NXdlPerWave,
60 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
61 typename ABlockTransferThreadClusterArrangeOrder,
62 typename ABlockTransferSrcAccessOrder,
63 index_t ABlockTransferSrcVectorDim,
64 index_t ABlockTransferSrcScalarPerVector,
65 index_t ABlockTransferDstScalarPerVector_AK1,
66 index_t ABlockLdsExtraM,
67 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
68 typename BBlockTransferThreadClusterArrangeOrder,
69 typename BBlockTransferSrcAccessOrder,
70 index_t BBlockTransferSrcVectorDim,
71 index_t BBlockTransferSrcScalarPerVector,
72 index_t BBlockTransferDstScalarPerVector_BK1,
73 index_t BBlockLdsExtraN,
74 index_t CShuffleMXdlPerWavePerShuffle,
75 index_t CShuffleNXdlPerWavePerShuffle,
76 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
77 index_t CDEBlockTransferScalarPerVector_NPerBlock,
78 typename AComputeDataType =
79 decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
81 ADataType>()), // ComputeType is InputType by default (first
82 // in tuple for MultiAB), unpack if tuple was
83 // passed
84 typename BComputeDataType = AComputeDataType,
87 NDimSpatial,
88 ALayout,
89 BLayout,
90 DsLayout,
91 ELayout,
92 ADataType,
93 BDataType,
94 AccDataType,
95 CShuffleDataType,
96 DsDataType,
97 EDataType,
98 AElementwiseOperation,
99 BElementwiseOperation,
100 CDEElementwiseOperation,
101 ConvForwardSpecialization,
102 GemmSpec,
103 NumGemmKPrefetchStage,
104 BlockSize,
105 MPerBlock,
106 NPerBlock,
107 KPerBlock,
108 AK1,
109 BK1,
110 MPerXDL,
111 NPerXDL,
112 MXdlPerWave,
113 NXdlPerWave,
114 ABlockTransferThreadClusterLengths_AK0_M_AK1,
115 ABlockTransferThreadClusterArrangeOrder,
116 ABlockTransferSrcAccessOrder,
117 ABlockTransferSrcVectorDim,
118 ABlockTransferSrcScalarPerVector,
119 ABlockTransferDstScalarPerVector_AK1,
120 ABlockLdsExtraM,
121 BBlockTransferThreadClusterLengths_BK0_N_BK1,
122 BBlockTransferThreadClusterArrangeOrder,
123 BBlockTransferSrcAccessOrder,
124 BBlockTransferSrcVectorDim,
125 BBlockTransferSrcScalarPerVector,
126 BBlockTransferDstScalarPerVector_BK1,
127 BBlockLdsExtraN,
128 CShuffleMXdlPerWavePerShuffle,
129 CShuffleNXdlPerWavePerShuffle,
130 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
131 CDEBlockTransferScalarPerVector_NPerBlock,
132 AComputeDataType,
133 BComputeDataType,
134 LoopSched>;
135
136} // namespace device
137} // namespace tensor_operation
138} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
ConvolutionForwardSpecialization
Definition convolution_forward_specialization.hpp:15
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
LoopScheduler
Definition loop_scheduler.hpp:15
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:325