// Compile-time index constants (Number<0> and Number<1> instances).
// Elsewhere in this listing they are used as I0 / I1, e.g. PreviousSeqT::At(I0),
// TupleReduce<I0.value, ...> and the comparison Size() == I1.
28 static constexpr auto I0 = Number<0>{};
29 static constexpr auto I1 = Number<1>{};
// Builds a tuple with one generated entry per top-level element of the shape
// type: the tuple length is Number<Tuple<Ts...>::Size()>.
// The visible lines use only the type Tuple<Ts...>, not the value of `shape`,
// which is why the parameter is marked [[maybe_unused]].
// NOTE(review): the generator lambda is only partly present in this listing.
// The one visible line branches on whether UnrolledDescriptorType is known at
// compile time -- presumably to pick a runtime vs. compile-time zero index;
// confirm against the complete file.
37 template <
typename... Ts>
38 __host__ __device__
constexpr static auto
39 GenerateDefaultIdxsTuple([[maybe_unused]]
const Tuple<Ts...>&
shape)
41 return generate_tuple(
43 if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
54 Number<Tuple<Ts...>::Size()>{});
// Computes, at compile time, the Sequence of lower-dimension ids that belongs to
// top-level shape element number Idx::value.
//
// Idx == 0, element is a tuple:
//   merge_nelems is the Size() of the unrolled (flattened) element; the result
//   is arithmetic_sequence_gen<0, merge_nelems, 1>::type, reversed.
// Idx > 0:
//   the result type for Idx - 1 is obtained by instantiating this function
//   recursively inside decltype; its element at position I0, plus one, becomes
//   next_seq_val. A tuple element then yields the reversed sequence generated
//   from <next_seq_val, next_seq_val + merge_nelems, 1>; any other element
//   yields Sequence<next_seq_val>.
//
// `shape` is only used through its type in the visible lines ([[maybe_unused]]).
// NOTE(review): braces, the `else` keywords and the non-tuple branch for
// Idx == 0 are missing from this listing. Presumably element I0 of the previous
// (reversed) sequence is its largest id -- confirm against the full source.
66 template <
typename Idx,
typename... Ts>
67 __host__ __device__
constexpr static auto
68 GenerateLowerDim([[maybe_unused]]
const Tuple<Ts...>&
shape)
70 if constexpr(Idx::value == 0)
72 if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>
::value)
// Number of leaf entries after flattening the nested element.
75 constexpr index_t merge_nelems =
decltype(UnrollNestedTuple(
76 tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
77 using LowerDimsSequence =
78 typename arithmetic_sequence_gen<0, merge_nelems, 1>::type;
79 return LowerDimsSequence::Reverse();
// Recursive case: derive the starting id from the sequence produced for Idx - 1.
90 using PreviousSeqT =
decltype(GenerateLowerDim<Number<Idx::value - 1>>(Tuple<Ts...>{}));
91 const auto next_seq_val = PreviousSeqT::At(I0) + 1;
92 if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>
::value)
94 constexpr index_t merge_nelems =
decltype(UnrollNestedTuple(
95 tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
96 using LowerDimsSequence =
97 typename arithmetic_sequence_gen<next_seq_val, next_seq_val + merge_nelems, 1>::
99 return LowerDimsSequence::Reverse();
// Non-tuple element: it occupies exactly one lower dimension.
103 return Sequence<next_seq_val>{};
// Recursively restructures `shape` against the nesting of `idx`.
//
// Visible steps:
//   1. A compile-time check on whether the idx tuple type is nested
//      (the body of the non-nested case is not present in this listing).
//   2. A new tuple with Tuple<IdxDims...>::Size() entries is generated; for each
//      position i the code tests whether the idx element is itself a tuple, and
//      one visible branch wraps shape.At(i) in a one-element tuple.
//   3. The function calls itself on UnrollNestedTuple<0, 1>(aligned_shape) and
//      UnrollNestedTuple<0, 1>(idx).
// NOTE(review): presumably UnrollNestedTuple<0, 1> removes a single nesting
// level per recursion step, and step 1 is the recursion terminator -- confirm.
119 template <
typename... ShapeDims,
typename... IdxDims>
120 __host__ __device__
constexpr static auto AlignShapeToIdx(
const Tuple<ShapeDims...>&
shape,
121 const Tuple<IdxDims...>& idx)
123 if constexpr(!IsNestedTuple(Tuple<IdxDims...>{}))
133 auto aligned_shape = generate_tuple(
135 if constexpr(is_detected<is_tuple,
136 tuple_element_t<i, Tuple<IdxDims...>>>
::value)
142 return make_tuple(
shape.At(i));
145 Number<Tuple<IdxDims...>::Size()>{});
148 return AlignShapeToIdx(UnrollNestedTuple<0, 1>(aligned_shape),
149 UnrollNestedTuple<0, 1>(idx));
// Transforms `desc` with a single merge transform whose upper dimension is
// Sequence<0>, i.e. the result has one upper dimension.
//
//   merge_elems : the flattened `shape`, reversed (TupleReverse).
//   lower_dims  : one reversed arithmetic sequence generated from
//                 <0, merge_elems.Size(), 1>.
//   upper_dims  : a single Sequence<0>.
//
// When UnrolledDescriptorType is not known at compile time the merge is built
// with make_merge_transform; the other visible call uses
// make_merge_transform_v1_carry_check.
// NOTE(review): the `else`, the braces and part of the second
// transform_tensor_descriptor call are missing from this listing.
160 template <
typename... ShapeDims,
typename DescriptorToMerge>
161 __host__ __device__
constexpr static auto MakeMerge1d(
const Tuple<ShapeDims...>&
shape,
162 const DescriptorToMerge& desc)
165 const auto merge_elems = TupleReverse(UnrollNestedTuple(
shape));
167 using MergeElemsSequence =
typename arithmetic_sequence_gen<0, merge_elems.Size(), 1>::type;
168 const auto lower_dims = make_tuple(MergeElemsSequence::Reverse());
169 const auto upper_dims = make_tuple(Sequence<0>{});
171 if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
173 return transform_tensor_descriptor(
174 desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
181 return transform_tensor_descriptor(
183 make_tuple(make_merge_transform_v1_carry_check(merge_elems)),
// Builds a transformed descriptor from `desc` with one transform per top-level
// element of `shape`.
//
// Per element i (tuple of size Tuple<ShapeDims...>::Size()):
//   * shape element is a tuple and idx element is not:
//       a merge transform over the flattened, reversed shape.At(i)
//       (make_merge_transform for runtime descriptors, otherwise
//       make_merge_transform_v1_carry_check).
//   * shape element is not a tuple but idx element is:
//       rejected by the condition paired with the message
//       "Wrong Idx for layout()".
//   * otherwise: make_pass_through_transform(shape.At(i)).
//
// lower_dims[i] comes from GenerateLowerDim<Number<i>>(shape) and upper_dims[i]
// is Sequence<i>. `idxs` is used only through its type in the visible lines
// ([[maybe_unused]]).
// NOTE(review): the lambda header, braces, `else` keywords and the opening of
// the assertion are missing from this listing.
201 template <
typename... ShapeDims,
typename... IdxDims,
typename DescriptorToMerge>
202 __host__ __device__
constexpr static auto
203 CreateMergedDescriptor(
const Tuple<ShapeDims...>&
shape,
204 [[maybe_unused]]
const Tuple<IdxDims...>& idxs,
205 DescriptorToMerge& desc)
207 const auto transforms = generate_tuple(
// Nested shape element addressed by a scalar index: merge it into one dimension.
210 if constexpr(is_detected<is_tuple,
211 tuple_element_t<i, Tuple<ShapeDims...>>>
::value &&
212 !is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>
::value)
216 const auto merge_elems = TupleReverse(UnrollNestedTuple(
shape.At(i)));
217 if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
219 return make_merge_transform(merge_elems);
226 return make_merge_transform_v1_carry_check(merge_elems);
// A nested index for a scalar shape element is not accepted.
233 !(!is_detected<is_tuple, tuple_element_t<i, Tuple<ShapeDims...>>>
::value &&
234 is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>
::value),
235 "Wrong Idx for layout()");
237 return make_pass_through_transform(
shape.At(i));
240 Number<Tuple<ShapeDims...>::Size()>{});
242 const auto lower_dims =
243 generate_tuple([&](
auto i) {
return GenerateLowerDim<Number<i>>(
shape); },
244 Number<Tuple<ShapeDims...>::Size()>{});
245 const auto upper_dims = generate_tuple([&](
auto i) {
return Sequence<i.value>{}; },
246 Number<Tuple<ShapeDims...>::Size()>{});
248 return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
// Descriptor1dType: the cv/ref-stripped type returned by MakeMerge1d when called
// with default-constructed Shape and UnrolledDescriptorType objects.
251 using Descriptor1dType =
252 remove_cvref_t<
decltype(MakeMerge1d(Shape{}, UnrolledDescriptorType{}))>;
// DefaultIdxsTupleType: the cv/ref-stripped type returned by
// GenerateDefaultIdxsTuple for a default-constructed Shape.
253 using DefaultIdxsTupleType = remove_cvref_t<
decltype(GenerateDefaultIdxsTuple(Shape{}))>;
// NOTE(review): the line holding this function's name and its first parameter is
// missing from this listing. Other lines call TransformDesc(shape, idxs,
// descriptor) with a matching argument order, so this is presumably
// TransformDesc -- confirm against the full source.
//
// Visible behavior:
//   * idxs has exactly one element: return MakeMerge1d(shape, naive_descriptor).
//   * otherwise: the shape and idxs ranks must match (static_assert); the shape
//     is aligned to idxs with AlignShapeToIdx, and CreateMergedDescriptor is
//     called with the aligned shape, the flattened idxs and naive_descriptor.
268 template <
typename... ShapeDims,
typename... IdxDims>
269 __host__ __device__
constexpr static auto
271 const Tuple<IdxDims...>& idxs,
272 const UnrolledDescriptorType& naive_descriptor)
274 if constexpr(Tuple<IdxDims...>::Size() == I1)
277 return MakeMerge1d(
shape, naive_descriptor);
285 static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
286 "Idx rank and Shape rank must be the same (except 1d).");
288 const auto aligned_shape = AlignShapeToIdx(
shape, idxs);
290 return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idxs), naive_descriptor);
// NOTE(review): tail of a type-alias expression whose beginning is missing from
// this listing. The arguments (Shape, DefaultIdxsTupleType,
// UnrolledDescriptorType) match a TransformDesc call, so this is presumably the
// definition of MergedNestsDescriptorType used by the data members -- confirm.
295 Shape{}, DefaultIdxsTupleType{}, UnrolledDescriptorType{}))>;
// Fragment of an accessor whose signature is not visible in this listing:
// it forwards the element space size reported by unrolled_descriptor_.
299 return unrolled_descriptor_.GetElementSpaceSize();
// Constructor (the line with its name and first parameter, `shape`, is missing
// from this listing).
// Stores the given descriptor in unrolled_descriptor_ and the shape in shape_.
// When UnrolledDescriptorType is NOT known at compile time, the two derived
// descriptors are computed here from the stored runtime values:
//   descriptor_1d_           <- MakeMerge1d(shape_, unrolled_descriptor_)
//   merged_nests_descriptor_ <- TransformDesc(shape_, DefaultIdxsTupleType{},
//                                             unrolled_descriptor_)
// NOTE(review): for compile-time descriptors these members are left as
// default-constructed in the visible code -- presumably sufficient because their
// types carry all information; confirm.
311 const UnrolledDescriptorType& unnested_descriptor)
312 : unrolled_descriptor_(unnested_descriptor), shape_(
shape)
315 if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
317 descriptor_1d_ = MakeMerge1d(shape_, unrolled_descriptor_);
318 merged_nests_descriptor_ =
319 TransformDesc(shape_, DefaultIdxsTupleType{}, unrolled_descriptor_);
// Compile-time offset computation, parameterized on the index type Idxs
// (the signature lines are missing from this listing; the assertion message
// refers to it as the "Compiletime operator").
// Requires UnrolledDescriptorType to be known at compile time. The offset is
// derived purely from types: TransformDesc is evaluated on default-constructed
// Shape / Idxs / UnrolledDescriptorType objects, and CalculateOffset is called
// on a default-constructed instance with the default-constructed flattened Idxs.
329 template <
typename Idxs>
332 static_assert(remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime(),
333 "Compiletime operator used on runtime layout.");
334 using TransformedDesc =
decltype(
TransformDesc(Shape{}, Idxs{}, UnrolledDescriptorType{}));
335 using UnrolledIdx =
decltype(UnrollNestedTuple(Idxs{}));
336 return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
// Runtime offset computation for the index tuple `Idx`; returns index_t.
// Three cases, selected at compile time from the index tuple's type:
//   1. flat tuple with one element          -> descriptor_1d_
//   2. flat tuple with Shape::Size() entries -> merged_nests_descriptor_
//   3. anything else                         -> a descriptor built on the fly
//      with TransformDesc(shape_, Idx, unrolled_descriptor_)
// Cases 2 and 3 pass the flattened index (UnrollNestedTuple) to
// CalculateOffset; case 1 passes Idx unchanged.
// NOTE(review): the `else` and braces of the last case are missing from this
// listing.
345 template <
typename... Ts>
346 __host__ __device__ index_t
operator()(
const Tuple<Ts...>& Idx)
const
348 if constexpr(!IsNestedTuple(Tuple<Ts...>{}) && Tuple<Ts...>::Size() == 1)
351 return descriptor_1d_.CalculateOffset(Idx);
353 else if constexpr(!IsNestedTuple(Tuple<Ts...>{}) && Tuple<Ts...>::Size() == Shape::Size())
356 return merged_nests_descriptor_.CalculateOffset(UnrollNestedTuple(Idx));
361 const auto transformed_desc =
TransformDesc(shape_, Idx, unrolled_descriptor_);
362 return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
// Length lookup for top-level shape element IDim (the signature line is missing
// from this listing; another line calls GetLength<i>(), so this is presumably
// GetLength -- confirm).
// When the element is itself a tuple, it is flattened and all of its entries
// are multiplied together with TupleReduce.
// NOTE(review): the branch for a non-tuple element is not visible here.
372 template <index_t IDim>
375 const auto elem = shape_.At(Number<IDim>{});
376 if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>
::value)
378 const auto unrolled_element = UnrollNestedTuple(elem);
379 return TupleReduce<I0.value, unrolled_element.Size()>(
380 [](
auto x,
auto y) {
return x * y; }, unrolled_element);
// Fragment of a member whose signature is not visible in this listing:
// flattens the whole shape_ and reduces it with a multiplying lambda via
// TupleReduce over the range <I0.value, unrolled_shape.Size()>.
// NOTE(review): the final argument of the TupleReduce call is missing here;
// presumably it is unrolled_shape -- confirm.
395 const auto unrolled_shape = UnrollNestedTuple(shape_);
396 return TupleReduce<I0.value, unrolled_shape.Size()>([](
auto x,
auto y) {
return x * y; },
// Accessor: returns a const reference to the stored shape_ member.
405 __host__ __device__
constexpr const Shape&
GetShape()
const {
return shape_; }
// Fragment of a member whose signature is not visible in this listing:
// builds a tuple with Shape::Size() entries, where entry i is GetLength<i>().
414 return generate_tuple([&](
auto i) {
return GetLength<i>(); }, Number<Shape::Size()>{});
// Fragment of a member whose signature is not visible in this listing:
// returns the default index tuple generated for shape_.
424 return GenerateDefaultIdxsTuple(shape_);
// Accessor returning a const reference to merged_nests_descriptor_
// (the line with the function name is missing from this listing).
436 __host__ __device__
constexpr const MergedNestsDescriptorType&
439 return merged_nests_descriptor_;
// Fragment of an accessor whose signature is not visible in this listing:
// returns descriptor_1d_.
451 return descriptor_1d_;
// Fragment of an accessor whose signature is not visible in this listing:
// returns unrolled_descriptor_.
463 return unrolled_descriptor_;
// Data members.
// unrolled_descriptor_: the descriptor received by the constructor; the input
// to MakeMerge1d / TransformDesc elsewhere in this listing.
470 UnrolledDescriptorType unrolled_descriptor_;
// descriptor_1d_: used for single-element index access; assigned in the
// constructor for runtime descriptors.
472 Descriptor1dType descriptor_1d_;
// merged_nests_descriptor_: used for flat indices with Shape::Size() entries;
// assigned in the constructor for runtime descriptors.
// NOTE(review): the declaration of shape_, which the constructor and accessors
// use, is not visible in this listing.
474 MergedNestsDescriptorType merged_nests_descriptor_;