device_pool_fwd.hpp Source File

device_pool_fwd.hpp Source File

Composable Kernel: device_pool_fwd.hpp Source File
device_pool_fwd.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <vector>
7
10
11namespace ck {
12namespace tensor_operation {
13namespace device {
14
// Interface for a forward pooling device operation: every member shown below
// is pure virtual (= 0), so concrete implementations must override them.
//
// NOTE(review): this rendered listing skips original line 24, which presumably
// holds the struct declaration (Doxygen places the definition at
// device_pool_fwd.hpp:25). Its name and base class are not visible here.
//
// Template parameters. Meanings beyond the names are assumptions; TODO confirm
// them against the implementations:
//   InOutRank     - presumably the rank shared by the input and output tensors.
//   WindowRank    - presumably the rank of the pooling window.
//   InDataType / OutDataType / IndexDataType
//                 - element types of the input, output and index tensors.
//   InLayout / OutLayout
//                 - layout tags for the input and output tensors.
//   ReduceOpId    - a ReduceTensorOp value, presumably selecting the reduction
//                   the pooling applies.
//   OutputIndex   - compile-time flag, presumably "also produce an index tensor".
15template <index_t InOutRank,
16 index_t WindowRank,
17 typename InDataType,
18 typename OutDataType,
19 typename IndexDataType,
20 typename InLayout,
21 typename OutLayout,
22 ReduceTensorOp ReduceOpId,
23 bool OutputIndex>
25{
    // Builds the type-erased argument object for one pooling call and returns it
    // as std::unique_ptr<BaseArgument>. All shape/stride/padding descriptors are
    // taken by value as std::vector<ck::index_t>.
    //
    // p_in_dev          - read-only input buffer (const void*); the name suggests
    //                     device memory.
    // p_out_dev         - output buffer.
    // p_out_indices_dev - index output buffer. Presumably only written when
    //                     OutputIndex is true; TODO confirm.
    // input_n_c_wis_lengths / output_n_c_wos_lengths / window_xs_lengths
    //                   - per-dimension sizes of the input, output and window.
    // input_n_c_wis_stride / output_n_c_wis_stride / indices_n_c_wis_stride
    //                   - per-dimension strides of the input, output and index
    //                     tensors.
    //   NOTE(review): the output and index stride parameters use the "wis" suffix,
    //   while the output lengths use "wos". This looks like a copy slip in the
    //   names only; confirm before relying on the naming.
    // window_xs_strides / window_xs_dilations / input_left_pads / input_right_pads
    //                   - window stride, dilation and padding. Presumably one
    //                     entry per window dimension; TODO confirm.
    // pooling_dims      - presumably the indices of the dimensions being pooled;
    //                     TODO confirm.
26 virtual std::unique_ptr<BaseArgument>
27 MakeArgumentPointer(const void* p_in_dev,
28 void* p_out_dev,
29 void* p_out_indices_dev,
30 std::vector<ck::index_t> input_n_c_wis_lengths,
31 std::vector<ck::index_t> window_xs_lengths,
32 std::vector<ck::index_t> output_n_c_wos_lengths,
33 std::vector<ck::index_t> input_n_c_wis_stride,
34 std::vector<ck::index_t> output_n_c_wis_stride,
35 std::vector<ck::index_t> indices_n_c_wis_stride,
36 std::vector<ck::index_t> window_xs_strides,
37 std::vector<ck::index_t> window_xs_dilations,
38 std::vector<ck::index_t> input_left_pads,
39 std::vector<ck::index_t> input_right_pads,
40 std::vector<ck::index_t> pooling_dims) = 0;
41
    // Returns a std::unique_ptr<BaseInvoker>. Presumably this is the object that
    // runs the operation on an argument built by MakeArgumentPointer; TODO confirm.
42 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
43};
44
45} // namespace device
46} // namespace tensor_operation
47} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
ReduceTensorOp
Definition reduction_enums.hpp:9
Definition device_pool_fwd.hpp:25
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in_dev, void *p_out_dev, void *p_out_indices_dev, std::vector< ck::index_t > input_n_c_wis_lengths, std::vector< ck::index_t > window_xs_lengths, std::vector< ck::index_t > output_n_c_wos_lengths, std::vector< ck::index_t > input_n_c_wis_stride, std::vector< ck::index_t > output_n_c_wis_stride, std::vector< ck::index_t > indices_n_c_wis_stride, std::vector< ck::index_t > window_xs_strides, std::vector< ck::index_t > window_xs_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, std::vector< ck::index_t > pooling_dims)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0