BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > Struct Template Reference
GPU kernel for batched tensor contraction operations.
#include <batched_contraction_kernel.hpp>
Public Types
| using | Problem = ck_tile::remove_cvref_t<Problem_> | Tensor contraction problem specification. |
| using | ADataType | Data type for input tensor A. |
| using | BDataType | Data type for input tensor B. |
| using | DsDataType | Data types for the auxiliary input tensors D. |
| using | EDataType | Data type for output tensor E. |
| using | TilePartitioner | Tile partitioning strategy for workload distribution. |
| using | GemmPipeline = ck_tile::remove_cvref_t<GemmPipeline_> | GEMM computation pipeline. |
| using | EpiloguePipeline | Epilogue pipeline for post-GEMM operations. |
| using | UniversalGemmKernel | Underlying GEMM kernel to which the contraction is delegated. |
| using | KernelArgs | Kernel argument structure. |
Public Member Functions
| CK_TILE_DEVICE void | operator() (const KernelArgs &kargs) const | Device-side entry point of the kernel. |
Static Public Member Functions
| static CK_TILE_HOST constexpr auto | GetKernelName () | Returns the kernel name for debugging and profiling purposes. |
| static CK_TILE_HOST constexpr bool | IsSupportedArguments (const KernelArgs &kargs) | Validates whether the given kernel arguments are supported. |
| static CK_TILE_HOST constexpr ck_tile::index_t | GetSmemSize () | Returns the shared memory size required by the kernel. |
| static CK_TILE_HOST constexpr auto | GetBlockSize () | Returns the GPU block size for kernel launch. |
| static CK_TILE_HOST constexpr auto | GridSize (const KernelArgs &kargs) | Returns the launch grid dimensions for the given kernel arguments. |
| static CK_TILE_HOST constexpr KernelArgs | MakeKernelArgs (const BatchedContractionHostArgs< NumDTensor > &host_args) | Builds the kernel arguments from host-side arguments. |
Static Public Attributes
| static constexpr ck_tile::index_t | NumDimG = Problem::NumDimG | Number of batch dimensions. |
| static constexpr ck_tile::index_t | NumDimM | Number of M (output row) dimensions. |
| static constexpr ck_tile::index_t | NumDimN | Number of N (output column) dimensions. |
| static constexpr ck_tile::index_t | NumDimK | Number of K (contraction) dimensions. |
| static constexpr ck_tile::index_t | NumDTensor | Number of auxiliary input D tensors. |
| static constexpr ck_tile::index_t | kBlockSize | GPU block size inherited from the GEMM kernel. |
Detailed Description
struct BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >
- Overview
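- This kernel performs batched tensor contractions using the underlying UniversalGemmKernel. The number of batch (G), row (M), column (N), and contraction (K) dimensions is fixed at compile time by Problem_ (NumDimG, NumDimM, NumDimN, NumDimK), and independent batch instances are processed in parallel. Each batch instance computes E = epilogue_op(contraction(A, B), D0, D1, ...), where D0, D1, ... are the NumDTensor auxiliary input tensors.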
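The per-batch operation above can be written as a plain CPU reference loop. This is a minimal sketch, not ck_tile code: it assumes the G, M, N, and K dimension groups have already been flattened to single extents, that all tensors are packed row-major as A[G][M][K], B[G][K][N], and D0/E[G][M][N], and that epilogue_op simply adds one D tensor. The function name reference_batched_contraction is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for E = epilogue_op(contraction(A, B), D0).
// Assumptions (not taken from ck_tile): dimension groups are already flattened,
// tensors are packed row-major as A[G][M][K], B[G][K][N], D0[G][M][N], E[G][M][N],
// and epilogue_op(acc, d0) = acc + d0.
std::vector<float> reference_batched_contraction(const std::vector<float>& a,
                                                 const std::vector<float>& b,
                                                 const std::vector<float>& d0,
                                                 std::size_t G, std::size_t M,
                                                 std::size_t N, std::size_t K)
{
    std::vector<float> e(G * M * N);
    for (std::size_t g = 0; g < G; ++g)          // batches are independent
        for (std::size_t m = 0; m < M; ++m)
            for (std::size_t n = 0; n < N; ++n)
            {
                float acc = 0.0f;                // contraction over K
                for (std::size_t k = 0; k < K; ++k)
                    acc += a[(g * M + m) * K + k] * b[(g * K + k) * N + n];
                const std::size_t idx = (g * M + m) * N + n;
                e[idx] = acc + d0[idx];          // epilogue fuses D0
            }
    return e;
}
```

For example, with G = 2, M = N = 1, K = 2, A = {1, 2, 3, 4}, B = {5, 6, 7, 8}, and D0 = {10, 20}, the two batches yield 1*5 + 2*6 + 10 = 27 and 3*7 + 4*8 + 20 = 73.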
- Template Parameters
- Problem_: Tensor contraction problem specification defining data types and dimensions.
- TilePartitioner_: Tile partitioning strategy for workload distribution.
- GemmPipeline_: GEMM computation pipeline for the core matrix operations.
- EpiloguePipeline_: Epilogue pipeline for post-GEMM operations and tensor fusion.
Member Typedef Documentation
◆ ADataType
| using BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::ADataType |
Data type for input tensor A.
◆ BDataType
| using BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::BDataType |
Data type for input tensor B.
◆ DsDataType
| using BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::DsDataType |
Data types for the auxiliary input tensors D.
◆ EDataType
| using BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::EDataType |
Data type for output tensor E.
◆ EpiloguePipeline
| using BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::EpiloguePipeline |
Epilogue pipeline for post-GEMM operations.
◆ GemmPipeline
| using BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GemmPipeline = ck_tile::remove_cvref_t<GemmPipeline_> |
GEMM computation pipeline.
◆ KernelArgs
| using BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::KernelArgs |
Kernel argument structure.
◆ Problem
| using BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::Problem = ck_tile::remove_cvref_t<Problem_> |
Tensor contraction problem specification.
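The Problem and GemmPipeline typedefs strip const, volatile, and reference qualifiers from the template arguments, so callers may pass qualified types. The sketch below reproduces that effect with standard C++14 traits rather than ck_tile::remove_cvref_t; MyProblem and KernelSketch are hypothetical names used only for illustration.

```cpp
#include <type_traits>

// Standard-library equivalent of a remove_cvref_t helper (C++14 traits).
template <typename T>
using remove_cvref_t = std::remove_cv_t<std::remove_reference_t<T>>;

struct MyProblem {};  // hypothetical stand-in for a real problem specification

template <typename Problem_>
struct KernelSketch  // hypothetical: mirrors only the Problem typedef
{
    using Problem = remove_cvref_t<Problem_>;
};

// Every qualified spelling of the argument yields the same Problem type.
static_assert(std::is_same<KernelSketch<const MyProblem&>::Problem, MyProblem>::value, "");
static_assert(std::is_same<KernelSketch<volatile MyProblem&&>::Problem, MyProblem>::value, "");
```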
◆ TilePartitioner
| using BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::TilePartitioner |
Tile partitioning strategy for workload distribution.
◆ UniversalGemmKernel
| using BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::UniversalGemmKernel |
Underlying GEMM kernel to which the contraction is delegated.
Member Function Documentation
◆ GetBlockSize()
| static CK_TILE_HOST constexpr auto GetBlockSize () | inline static constexpr |
Returns the GPU block size for kernel launch.
- Returns
- 3D block dimensions for GPU kernel execution
◆ GetKernelName()
| static CK_TILE_HOST constexpr auto GetKernelName () | inline static constexpr |
Returns the kernel name for debugging and profiling purposes.
- Returns
- Constant string identifier for this kernel
◆ GetSmemSize()
| static CK_TILE_HOST constexpr ck_tile::index_t GetSmemSize () | inline static constexpr |
Returns the shared memory size required by the kernel.
- Returns
- Shared memory size in bytes
The value is delegated to the underlying GEMM kernel's shared memory requirement.
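GetSmemSize forwards to the underlying GEMM kernel, and kBlockSize is inherited from it. The following sketch shows that forwarding pattern with hypothetical stand-in types (Dim3, InnerGemmKernel, ContractionWrapper); it also assumes, without confirmation from this page, that GetBlockSize forwards in the same way. The 256-thread block and 64 KiB values are illustrative only.

```cpp
#include <cassert>

struct Dim3 { unsigned x, y, z; };  // hypothetical 3D launch dimensions

struct InnerGemmKernel  // hypothetical stand-in for UniversalGemmKernel
{
    static constexpr int kBlockSize = 256;                      // illustrative value
    static constexpr int GetSmemSize() { return 64 * 1024; }    // bytes, illustrative
    static constexpr Dim3 GetBlockSize() { return Dim3{static_cast<unsigned>(kBlockSize), 1u, 1u}; }
};

template <typename Inner>
struct ContractionWrapper  // hypothetical: mirrors only the outer kernel's static interface
{
    // Launch parameters are forwarded unchanged from the wrapped GEMM kernel.
    static constexpr int kBlockSize = Inner::kBlockSize;
    static constexpr int GetSmemSize() { return Inner::GetSmemSize(); }
    static constexpr Dim3 GetBlockSize() { return Inner::GetBlockSize(); }
};
```

Forwarding keeps the wrapper's launch configuration consistent with whatever GEMM kernel it wraps, with no duplicated constants.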
◆ GridSize()
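| static CK_TILE_HOST constexpr auto GridSize (const KernelArgs &kargs) | inline static constexpr |
Returns the launch grid dimensions for the given kernel arguments.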
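This page does not document how GridSize maps work onto the grid. One common scheme for batched GEMM-style kernels, shown here purely as an assumption, assigns one workgroup per M x N output tile in x and one grid row per flattened batch in y. The tile sizes and the names Dim3, ceil_div, and grid_size_sketch are hypothetical; the real mapping is determined by TilePartitioner_ and the kernel source.

```cpp
#include <cassert>
#include <cstdint>

struct Dim3 { std::uint32_t x, y, z; };  // hypothetical 3D grid dimensions

// Hypothetical tile sizes; a real TilePartitioner_ supplies its own.
constexpr std::int64_t kTileM = 128;
constexpr std::int64_t kTileN = 128;

constexpr std::int64_t ceil_div(std::int64_t a, std::int64_t b) { return (a + b - 1) / b; }

// Assumed scheme: x covers the M x N output tiles of one batch, y covers the
// flattened batch count, z is left at 1 (no split-K).
constexpr Dim3 grid_size_sketch(std::int64_t M_total, std::int64_t N_total, std::int64_t G_total)
{
    return Dim3{static_cast<std::uint32_t>(ceil_div(M_total, kTileM) * ceil_div(N_total, kTileN)),
                static_cast<std::uint32_t>(G_total), 1u};
}
```

Under these assumptions, grid_size_sketch(300, 256, 4) gives x = 3 * 2 = 6 tiles and y = 4 batches.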
◆ IsSupportedArguments()
| static CK_TILE_HOST constexpr bool IsSupportedArguments (const KernelArgs &kargs) | inline static constexpr |
Validates whether the given kernel arguments are supported.
- Parameters
- kargs: Kernel arguments to validate
- Returns
- True if arguments are supported, false otherwise
Checks that the underlying GEMM kernel supports the arguments and that the batch dimensions are valid.
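IsSupportedArguments combines two checks: valid batch dimensions and support by the underlying GEMM kernel. A minimal sketch of that structure follows. ArgsSketch, gemm_supports, and the positivity rules are hypothetical stand-ins, since neither the KernelArgs layout nor the GEMM kernel's own conditions are documented here.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical argument struct; the real KernelArgs layout is not documented on this page.
template <int NumDimG>
struct ArgsSketch
{
    std::array<std::int64_t, NumDimG> g_lengths;  // batch dimension extents
    std::int64_t M_total, N_total, K_total;       // flattened GEMM sizes
};

// Hypothetical stand-in for the underlying GEMM kernel's own check.
inline bool gemm_supports(std::int64_t M, std::int64_t N, std::int64_t K)
{
    return M > 0 && N > 0 && K > 0;
}

// Mirrors the two documented checks: batch dimensions first, then GEMM support.
template <int NumDimG>
bool is_supported_sketch(const ArgsSketch<NumDimG>& args)
{
    for (std::int64_t len : args.g_lengths)
        if (len <= 0) return false;  // assumed rule: every batch dimension must be non-empty
    return gemm_supports(args.M_total, args.N_total, args.K_total);
}
```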
◆ MakeKernelArgs()
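| static CK_TILE_HOST constexpr KernelArgs MakeKernelArgs (const BatchedContractionHostArgs< NumDTensor > &host_args) | inline static constexpr |
Builds the kernel arguments from the host-side arguments host_args.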
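MakeKernelArgs converts BatchedContractionHostArgs into KernelArgs. One plausible part of such a conversion, assumed here rather than confirmed by this page, is collapsing each dimension group (G, M, N, K) into a single extent so the problem maps onto a batched GEMM; that collapse is only valid for packed layouts. HostArgsSketch, FlatArgsSketch, product, and make_kernel_args_sketch are hypothetical names.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical host-side description: one extent per dimension in each group.
// The real BatchedContractionHostArgs<NumDTensor> fields are not shown on this page.
template <int NumDimG, int NumDimM, int NumDimN, int NumDimK>
struct HostArgsSketch
{
    std::array<std::int64_t, NumDimG> g;
    std::array<std::int64_t, NumDimM> m;
    std::array<std::int64_t, NumDimN> n;
    std::array<std::int64_t, NumDimK> k;
};

struct FlatArgsSketch { std::int64_t G, M, N, K; };  // hypothetical flattened sizes

template <std::size_t N>
std::int64_t product(const std::array<std::int64_t, N>& dims)
{
    std::int64_t p = 1;
    for (std::int64_t d : dims) p *= d;  // multiply all extents in the group
    return p;
}

// Collapses each group to one extent, so (assuming packed tensors) the problem
// becomes G independent (M x K) * (K x N) GEMMs.
template <int NumDimG, int NumDimM, int NumDimN, int NumDimK>
FlatArgsSketch make_kernel_args_sketch(const HostArgsSketch<NumDimG, NumDimM, NumDimN, NumDimK>& h)
{
    return FlatArgsSketch{product(h.g), product(h.m), product(h.n), product(h.k)};
}
```

With G = {4}, M = {8, 16}, N = {32}, and K = {2, 64}, the flattened sizes are G = 4, M = 128, N = 32, and K = 128.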
◆ operator()()
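| CK_TILE_DEVICE void operator() (const KernelArgs &kargs) const | inline |
Device-side entry point of the kernel; executes the batched contraction described by kargs.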
Member Data Documentation
◆ kBlockSize
| static constexpr ck_tile::index_t kBlockSize | static constexpr |
GPU block size inherited from GEMM kernel.
◆ NumDimG
| static constexpr ck_tile::index_t NumDimG = Problem::NumDimG | static constexpr |
Number of batch dimensions.
◆ NumDimK
| static constexpr ck_tile::index_t NumDimK | static constexpr |
Number of K (contraction) dimensions.
◆ NumDimM
| static constexpr ck_tile::index_t NumDimM | static constexpr |
Number of M (output row) dimensions.
◆ NumDimN
| static constexpr ck_tile::index_t NumDimN | static constexpr |
Number of N (output column) dimensions.
◆ NumDTensor
| static constexpr ck_tile::index_t NumDTensor | static constexpr |
Number of auxiliary input D tensors.
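NumDTensor counts the auxiliary D tensors that epilogue_op receives. The sketch below uses a C++17 fold expression to illustrate an epilogue that adds any number of D values to the contraction result. The summing rule and the names add_all_epilogue and num_d_tensors are assumptions; the actual operation is supplied by EpiloguePipeline_.

```cpp
#include <cassert>

// Hypothetical epilogue_op: adds every auxiliary D value to the contraction result.
template <typename... Ds>
constexpr float add_all_epilogue(float acc, Ds... d)
{
    return (acc + ... + static_cast<float>(d));  // C++17 binary left fold; empty pack returns acc
}

// Plays the role of NumDTensor: the number of D operands passed to the epilogue.
template <typename... Ds>
constexpr int num_d_tensors = sizeof...(Ds);
```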
The documentation for this struct was generated from the following file: