tensor_adaptor.hpp Source File

tensor_adaptor.hpp Source File#

Composable Kernel: tensor_adaptor.hpp Source File
tile/core/tensor/tensor_adaptor.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
14
15namespace ck_tile {
16
17// Transforms: Tuple<transforms...>
18// LowerDimensionHiddenIdss : Tuple<Sequence<...>, ...>
19// UpperDimensionHiddenIdss : Tuple<Sequence<...>, ...>
20// BottomDimensionHiddenIds : Sequence<...>
21// TopDimensionHiddenIds : Sequence<...>
22template <typename Transforms,
23 typename LowerDimensionHiddenIdss,
24 typename UpperDimensionHiddenIdss,
25 typename BottomDimensionHiddenIds,
26 typename TopDimensionHiddenIds>
28{
30 {
31 return Transforms::size();
32 }
33
34 CK_TILE_HOST_DEVICE constexpr const auto& get_transforms() const { return transforms_; }
35
37 {
38 return LowerDimensionHiddenIdss{};
39 }
40
42 {
43 return UpperDimensionHiddenIdss{};
44 }
45
47 {
48 return BottomDimensionHiddenIds{};
49 }
50
52 {
53 return TopDimensionHiddenIds{};
54 }
55
56 CK_TILE_HOST_DEVICE static constexpr auto initialize_element_size(const Transforms& transforms)
57 {
58 const auto lengths = generate_tuple(
59 [&](auto idim_top) {
60 constexpr index_t idim_hidden = TopDimensionHiddenIds::at(idim_top);
61
63
64 constexpr index_t itran = tmp[number<0>{}];
65 constexpr index_t idim_up = tmp[number<1>{}];
66 constexpr bool found = tmp[number<2>{}];
67
68 static_assert(found == true,
69 "wrong! not found matching transformation and upper-dimension");
70
71 const auto length =
72 transforms[number<itran>{}].get_upper_lengths()[number<idim_up>{}];
73
74 return length;
75 },
77
78 // TODO: make container_reduce support tuple of number and index_t
79 return container_reduce(lengths, multiplies{}, number<1>{});
80 }
81
82 template <index_t IDimHidden>
83 CK_TILE_HOST_DEVICE static constexpr auto
85 {
86 // FIXME: length of bottom dimension is not known, since info about lower dim length are not
87 // saved in transformation
88 static_assert(IDimHidden >= ndim_bottom_, "wrong! not implemented");
89
90 index_t itran_found = 0;
91 index_t idim_up_found = 0;
92 bool found = false;
93
94 static_for<0, ntransform_, 1>{}([&](auto itran) {
95 constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];
96
97 static_for<0, up_dim_ids.size(), 1>{}([&](auto idim_up) {
98 if constexpr(up_dim_ids[idim_up] == IDimHidden)
99 {
100 itran_found = itran;
101 idim_up_found = idim_up;
102 found = true;
103 }
104 });
105 });
106
107 return make_tuple(itran_found, idim_up_found, found);
108 }
109
111 {
112 return BottomDimensionHiddenIds::size();
113 }
114
116 {
117 return TopDimensionHiddenIds::size();
118 }
119
121 {
122 constexpr auto all_low_dim_ids =
123 unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
124 LowerDimensionHiddenIdss{});
125
126 constexpr auto all_up_dim_ids =
127 unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
128 UpperDimensionHiddenIdss{});
129
130 constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
131
132 using unique_sort_all_dim_ids = typename sequence_unique_sort<decltype(all_dim_ids),
134 equal<index_t>>::type;
135
136 return unique_sort_all_dim_ids::size();
137 }
138
143
147
148 // may be index_t or number<>
149 using ElementSize = remove_cv_t<decltype(initialize_element_size(Transforms{}))>;
150
151 public:
152 CK_TILE_HOST_DEVICE constexpr tensor_adaptor() = default;
153
154 CK_TILE_HOST_DEVICE constexpr tensor_adaptor(const Transforms& transforms)
155 : transforms_{transforms}, element_size_{initialize_element_size(transforms)}
156 {
157 static_assert(Transforms::size() == ntransform_ &&
158 LowerDimensionHiddenIdss::size() == ntransform_ &&
159 UpperDimensionHiddenIdss::size() == ntransform_,
160 "wrong! inconsistent # of transformations");
161
162 // TODO check dependency of dimensions is valid
163 }
164
165 CK_TILE_HOST_DEVICE constexpr auto get_element_size() const { return element_size_; }
166
167 // FIXME: this logic is wrong when getting bottom dimension lengths
168 template <index_t IDimHidden>
170 {
171 static_assert(IDimHidden >= 0 && IDimHidden < ndim_hidden_, "wrong! out of range");
172
174
175 constexpr index_t itran = tmp[number<0>{}];
176 constexpr index_t idim_up = tmp[number<1>{}];
177 constexpr bool found = tmp[number<2>{}];
178
179 static_assert(found == true,
180 "wrong! not found matching transformation and upper-dimension");
181
182 return transforms_[number<itran>{}].get_upper_lengths()[number<idim_up>{}];
183 }
184
185 template <index_t IDimTop>
187 {
188 return get_hidden_dimension_length(TopDimensionHiddenIds::at(idim_top));
189 }
190
191#if 0
192 // FIXME: get_hidden_dimension_length is wrong when getting bottom dimension lengths
193 template <index_t IDimBottom>
195 get_bottom_dimension_length(number<IDimBottom> idim_bottom) const
196 {
197 return get_hidden_dimension_length(TopDimensionHiddenIds::at(idim_bottom));
198 }
199#endif
200
202 {
203 return generate_tuple([&](auto i) { return get_top_dimension_length(i); },
205 }
206
207#if 0
208 // FIXME: get_hidden_dimension_length is wrong when getting bottom dimension lengths
209 CK_TILE_HOST_DEVICE constexpr auto GetBottomDimensionLengths() const
210 {
211 return generate_tuple([&](auto i) { return get_bottom_dimension_length(i); },
213 }
214#endif
215
216 template <typename TopIdx>
217 CK_TILE_HOST_DEVICE constexpr auto calculate_bottom_index(const TopIdx& idx_top) const
218 {
219 static_assert(TopIdx::size() == TopDimensionHiddenIds::size(),
220 "wrong! # of dimension inconsistent");
221
222 constexpr index_t ntransform = get_num_of_transform();
223 constexpr index_t ndim_hidden = get_num_of_hidden_dimension();
224
225 multi_index<ndim_hidden> idx_hidden;
226
227 // initialize uppest index
229
230 // calculate hidden index
231 static_for<ntransform, 0, -1>{}([&](auto itran_p1) {
232 auto itran = itran_p1 - number<1>{};
233 const auto& tran = get_transforms().at(itran);
234 constexpr auto dims_low = get_lower_dimension_hidden_idss().at(itran);
235 constexpr auto dims_up = get_upper_dimension_hidden_idss().at(itran);
236
237 const auto idx_up = get_container_subset(idx_hidden, dims_up);
238
239 multi_index<dims_low.size()> idx_low;
240
241 tran.calculate_lower_index(idx_low, idx_up);
242
243 set_container_subset(idx_hidden, dims_low, idx_low);
244 });
245
246 return get_container_subset(idx_hidden, BottomDimensionHiddenIds{});
247 }
248
249 CK_TILE_HOST_DEVICE static constexpr bool is_static()
250 {
251 bool is_known = true;
252
253 static_for<0, Transforms::size(), 1>{}([&](auto i) {
254 is_known &= remove_cvref_t<decltype(Transforms{}[i])>::is_known_at_compile_time();
255 });
256
258 }
259
260 CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }
261
262 template <index_t Internal = 0>
264 const array<index_t, ndim_hidden_>& guaranteed_vector_lengths,
265 const array<index_t, ndim_hidden_>& guaranteed_vector_strides)
266 {
267 auto vector_lengths = guaranteed_vector_lengths;
268 auto vector_strides = guaranteed_vector_strides;
269
270 static_for<0,
271 Internal ? std::min(Internal, get_num_of_transform()) : get_num_of_transform(),
272 1>{}([&](auto itran) {
273 constexpr auto low_dims = get_lower_dimension_hidden_idss().at(itran);
274 constexpr auto up_dims = get_upper_dimension_hidden_idss().at(itran);
275
276 const auto up_guaranteed_vector_lengths =
277 get_container_subset(guaranteed_vector_lengths, up_dims);
278 const auto up_guaranteed_vector_strides =
279 get_container_subset(guaranteed_vector_strides, up_dims);
280
281 // only need type of transform
282 auto [up_vector_lengths, up_vector_strides] =
283 Transforms{}.at(itran).calculate_upper_dimension_safe_vector_length_strides(
284 get_container_subset(vector_lengths, low_dims),
285 get_container_subset(vector_strides, low_dims));
286
287 if constexpr(up_dims.size() > 0)
288 {
289 for(index_t i = 0; i < up_dims.size(); ++i)
290 {
291 up_vector_lengths(i) = (up_guaranteed_vector_lengths[i] != -1)
292 ? up_guaranteed_vector_lengths[i]
293 : up_vector_lengths[i];
294
295 up_vector_strides(i) = (up_guaranteed_vector_strides[i] != -1)
296 ? up_guaranteed_vector_strides[i]
297 : up_vector_strides[i];
298 }
299 }
300
301 set_container_subset(vector_lengths, up_dims, up_vector_lengths);
302 set_container_subset(vector_strides, up_dims, up_vector_strides);
303 });
304 if constexpr(Internal > 0)
305 {
306 return make_tuple(vector_lengths, vector_strides);
307 }
308 else
309 {
310 constexpr auto top_dims = TopDimensionHiddenIds{};
311 return make_tuple(get_container_subset(vector_lengths, top_dims),
312 get_container_subset(vector_strides, top_dims));
313 }
314 }
315
316 private:
317 Transforms transforms_;
318 ElementSize element_size_;
319};
320
321template <typename Transforms,
322 typename LowerDimensionHiddenIdss,
323 typename UpperDimensionHiddenIdss,
324 typename BottomDimensionHiddenIds,
325 typename TopDimensionHiddenIds>
326CK_TILE_HOST_DEVICE static void print(const tensor_adaptor<Transforms,
327 LowerDimensionHiddenIdss,
328 UpperDimensionHiddenIdss,
329 BottomDimensionHiddenIds,
330 TopDimensionHiddenIds>& adaptor)
331{
332 printf("tensor_adaptor{\n");
333 printf(" transforms: [");
334 print(adaptor.get_transforms());
335 printf("],\n");
336
337 printf(" LowerDimensionHiddenIds: [");
338 print(LowerDimensionHiddenIdss{});
339 printf("],\n");
340
341 printf(" UpperDimensionHiddenIds: [");
342 print(UpperDimensionHiddenIdss{});
343 printf("],\n");
344
345 printf(" BottomDimensionHiddenIds: [");
346 print(BottomDimensionHiddenIds{});
347 printf("],\n");
348
349 //
350 printf(" TopDimensionHiddenIds: [");
351 print(TopDimensionHiddenIds{});
352 printf("]\n}\n");
353}
354
355// Transforms: Tuple<transforms...>
356// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
357// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
358template <typename Transforms, typename LowerDimensionOldTopIdss, typename UpperDimensionNewTopIdss>
359CK_TILE_HOST_DEVICE constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms,
360 LowerDimensionOldTopIdss,
361 UpperDimensionNewTopIdss)
362{
363 constexpr index_t ntransform = Transforms::size();
364
365 static_assert(LowerDimensionOldTopIdss::size() == ntransform &&
366 UpperDimensionNewTopIdss::size() == ntransform,
367 "wrong!");
368
369 // sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss
370 constexpr auto all_low_dim_old_top_ids = unpack(
371 [](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionOldTopIdss{});
372
373 constexpr auto all_up_dim_new_top_ids = unpack(
374 [](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{});
375
376 static_assert(is_valid_sequence_map<decltype(all_low_dim_old_top_ids)>::value &&
377 is_valid_sequence_map<decltype(all_up_dim_new_top_ids)>::value,
378 "wrong!");
379
380 constexpr index_t ndim_old_top = all_low_dim_old_top_ids.size();
381 constexpr index_t ndim_new_top = all_up_dim_new_top_ids.size();
382
383 // low_dim_hidden_idss
384 constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{};
385
386 // up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom
387 constexpr auto up_dim_hidden_idss = generate_tuple(
388 [](auto itran) { return UpperDimensionNewTopIdss{}[itran] + number<ndim_old_top>{}; },
390
391 // bottom_dim_hidden_ids
392 constexpr auto bottom_dim_hidden_ids =
394
395 // top_dim_hidden_ids
396 constexpr auto top_dim_hidden_ids =
398
400 remove_cvref_t<decltype(low_dim_hidden_idss)>,
401 remove_cvref_t<decltype(up_dim_hidden_idss)>,
402 remove_cvref_t<decltype(bottom_dim_hidden_ids)>,
403 remove_cvref_t<decltype(top_dim_hidden_ids)>>{transforms};
404}
405
406// TODO: How to fix this? It uses a struct instead of a lambda because a lambda
407// doesn't have a constructor, and it is defined outside the scope where it is used
408// (transform_tensor_adaptor) because a template cannot be defined inside a function
409// template
410template <typename NewTransforms>
412{
413 template <typename I>
414 CK_TILE_HOST_DEVICE constexpr auto operator()(I) const
415 {
416 using Tran = remove_reference_t<decltype(NewTransforms{}.at(I{}))>;
417 return number<Tran::get_num_of_upper_dimension()>{};
418 }
419};
420
421template <typename OldTensorAdaptor,
422 typename NewTransforms,
423 typename NewLowerDimensionOldTopIdss,
424 typename NewUpperDimensionNewTopIdss>
425CK_TILE_HOST_DEVICE constexpr auto
426transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor,
427 const NewTransforms& new_transforms,
428 NewLowerDimensionOldTopIdss,
429 NewUpperDimensionNewTopIdss)
430{
431 // sanity check
432 {
433 static_assert(NewTransforms::size() == NewLowerDimensionOldTopIdss::size() &&
434 NewTransforms::size() == NewUpperDimensionNewTopIdss::size(),
435 "wrong! inconsitent number of transform");
436
437 constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
438 NewLowerDimensionOldTopIdss{});
439
440 constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
441 NewUpperDimensionNewTopIdss{});
442
443 static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
444 is_valid_sequence_map<decltype(all_new_top_ids)>::value,
445 "wrong!");
446 }
447
448 // lower dimension's hidden idss
449 // convert lower dimension top idss (tuple of sequences) to hidden idss (tuple of
450 // sequences)
451 constexpr auto low_dim_hidden_idss = transform_tuples(
452 // convert lower dimension top ids (a sequence) to hidden ids (a sequence)
453 [](auto low_dim_top_ids) constexpr {
454 return transform_sequences(
455 // convert lower dimension top id to hidden id
456 [](auto low_dim_top_id) constexpr {
457 return OldTensorAdaptor::get_top_dimension_hidden_ids()[low_dim_top_id];
458 },
459 low_dim_top_ids);
460 },
461 NewLowerDimensionOldTopIdss{});
462
463 constexpr index_t num_new_transform = NewTransforms::size();
464
465 // upper dimension's hidden idss
466 constexpr index_t old_hidden_dim_number = OldTensorAdaptor::get_num_of_hidden_dimension();
467
468 constexpr auto up_dim_numbers =
470
471 constexpr auto up_dim_numbers_scan = merge_sequences(
473
474 constexpr auto up_dim_hidden_idss = generate_tuple(
475 [old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
476 return
477 typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
478 old_hidden_dim_number + up_dim_numbers_scan[i + 1],
479 1>::type{};
480 },
482
483 // new top dimension's hidden ids
484 constexpr auto unordered_new_top_dim_hidden_ids =
485 unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
486
487 constexpr auto new_top_dim_unordered2ordered = unpack(
488 [](auto... xs) constexpr { return merge_sequences(xs...); }, NewUpperDimensionNewTopIdss{});
489
490 constexpr auto new_top_dim_hidden_ids =
491 unordered_new_top_dim_hidden_ids.reorder_old_to_new(new_top_dim_unordered2ordered);
492
493 // put everything together
494 const auto all_transforms =
495 container_concat(old_tensor_adaptor.get_transforms(), new_transforms);
496
497 constexpr auto all_low_dim_hidden_idss =
498 container_concat(OldTensorAdaptor::get_lower_dimension_hidden_idss(), low_dim_hidden_idss);
499
500 constexpr auto all_up_dim_hidden_idss =
501 container_concat(OldTensorAdaptor::get_upper_dimension_hidden_idss(), up_dim_hidden_idss);
502
503 return tensor_adaptor<
504 remove_cvref_t<decltype(all_transforms)>,
505 remove_cvref_t<decltype(all_low_dim_hidden_idss)>,
506 remove_cvref_t<decltype(all_up_dim_hidden_idss)>,
507 remove_cvref_t<decltype(OldTensorAdaptor::get_bottom_dimension_hidden_ids())>,
508 remove_cvref_t<decltype(new_top_dim_hidden_ids)>>{all_transforms};
509}
510
511template <typename TensorAdaptor0, typename TensorAdaptor1>
512CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0,
513 const TensorAdaptor1& adaptor1)
514{
515 static_assert(TensorAdaptor0::get_num_of_top_dimension() ==
516 TensorAdaptor1::get_num_of_bottom_dimension(),
517 "wrong!");
518
519 // all_transforms = transform0 + transform1
520 const auto all_transforms =
521 container_concat(adaptor0.get_transforms(), adaptor1.get_transforms());
522
523 // shift
524 constexpr index_t adaptor0_max_hidden_id = [&]() {
525 index_t adaptor0_max_hidden_id_ = numeric<index_t>::min();
526
527 static_for<0, TensorAdaptor0::get_num_of_transform(), 1>{}([&](auto itran) {
528 constexpr index_t ndim_low =
529 TensorAdaptor0{}.get_transforms()[itran].get_num_of_lower_dimension();
530
531 static_for<0, ndim_low, 1>{}([&](auto idim_low) {
532 adaptor0_max_hidden_id_ =
533 max(adaptor0_max_hidden_id_,
534 TensorAdaptor0::get_lower_dimension_hidden_idss()[itran][idim_low].value);
535 });
536
537 constexpr index_t ndim_up =
538 TensorAdaptor0{}.get_transforms()[itran].get_num_of_upper_dimension();
539
540 static_for<0, ndim_up, 1>{}([&](auto idim_up) {
541 adaptor0_max_hidden_id_ =
542 max(adaptor0_max_hidden_id_,
543 TensorAdaptor0::get_upper_dimension_hidden_idss()[itran][idim_up].value);
544 });
545 });
546
547 return adaptor0_max_hidden_id_;
548 }();
549
550 constexpr index_t adaptor1_min_hidden_id = [&]() {
551 index_t adaptor1_min_hidden_id_ = numeric<index_t>::max();
552
553 static_for<0, TensorAdaptor1::get_num_of_transform(), 1>{}([&](auto itran) {
554 constexpr index_t ndim_low =
555 TensorAdaptor1{}.get_transforms()[itran].get_num_of_lower_dimension();
556
557 // get the min of all lower dimensions, but not bottom dimensions (because their ids
558 // will be matched with the top ids from adaptor0)
559 static_for<0, ndim_low, 1>{}([&](auto idim_low) {
560 constexpr index_t low_dim_hidden_id =
561 TensorAdaptor1::get_lower_dimension_hidden_idss()[itran][idim_low].value;
562
563 bool is_bottom_dim = false;
564 static_for<0, TensorAdaptor1::get_num_of_bottom_dimension(), 1>{}([&](auto i) {
565 if constexpr(low_dim_hidden_id ==
566 TensorAdaptor1::get_bottom_dimension_hidden_ids()[i])
567 {
568 is_bottom_dim = true;
569 }
570 });
571
572 if(!is_bottom_dim)
573 {
574 adaptor1_min_hidden_id_ = min(adaptor1_min_hidden_id_, low_dim_hidden_id);
575 }
576 });
577
578 constexpr index_t ndim_up =
579 TensorAdaptor1{}.get_transforms()[itran].get_num_of_upper_dimension();
580
581 // get the min of all upper dimensions
582 static_for<0, ndim_up, 1>{}([&](auto idim_up) {
583 adaptor1_min_hidden_id_ =
584 min(adaptor1_min_hidden_id_,
585 TensorAdaptor1::get_upper_dimension_hidden_idss()[itran][idim_up].value);
586 });
587 });
588
589 return adaptor1_min_hidden_id_;
590 }();
591
592 constexpr index_t adaptor1_hidden_id_shift =
593 adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id;
594
595 constexpr index_t ndim_bottom_1 = TensorAdaptor1::get_num_of_bottom_dimension();
596
597 // all_low_dim_hidden_idss =
598 // low_dim_hidden_idss_0 + match_hidden_id_for_1(shift_hidden_id_for_1(low_dim_hidden_idss_1))
599 constexpr auto low_dim_hidden_idss_1 = generate_tuple(
600 // generate sequence of ids for a transform
601 [&](auto itran) {
602 constexpr auto ndim_low_1 =
603 TensorAdaptor1::get_lower_dimension_hidden_idss()[itran].size();
604
605 constexpr auto low_dim_hidden_ids_1 =
606 TensorAdaptor1::get_lower_dimension_hidden_idss()[itran];
607
608 // sequence in, sequence out
609 constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr {
610 auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);
611
612 // shift hidden id so every dim id is unique
613 static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
614 low_dim_hidden_ids_1_mod_(idim_low_1) += adaptor1_hidden_id_shift;
615 });
616
617 // match hidden id
618 static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
619 static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) {
620 // if this low dim is bottom dim, then do id matching
621 if constexpr(low_dim_hidden_ids_1[idim_low_1] ==
622 TensorAdaptor1::get_bottom_dimension_hidden_ids()
623 [idim_bottom_1])
624 {
625 low_dim_hidden_ids_1_mod_(idim_low_1) =
626 TensorAdaptor0::get_top_dimension_hidden_ids()[idim_bottom_1];
627 }
628 });
629 });
630
631 return low_dim_hidden_ids_1_mod_;
632 }();
633
635 [&](auto i) constexpr { return number<low_dim_hidden_ids_1_mod[i]>{}; },
637 },
638 number<TensorAdaptor1::get_num_of_transform()>{});
639
640 constexpr auto all_low_dim_hidden_idss =
641 container_concat(TensorAdaptor0::get_lower_dimension_hidden_idss(), low_dim_hidden_idss_1);
642
643 // all_up_dim_hidden_idss =
644 // up_dim_hidden_idss_0 + shift_hidden_id_for_1(up_dim_hidden_idss_1)
645 constexpr auto up_dim_hidden_idss_1 = generate_tuple(
646 // generate sequence of ids for a transform
647 [&](auto itran) {
648 constexpr auto ndim_up_1 =
649 TensorAdaptor1::get_upper_dimension_hidden_idss()[itran].size();
650
651 constexpr auto up_dim_hidden_ids_1 =
652 TensorAdaptor1::get_upper_dimension_hidden_idss()[itran];
653
654 // sequence in, constexpr tuple out
655 constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr {
656 auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);
657
658 // shift hidden id
659 static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) {
660 up_dim_hidden_ids_1_mod_(idim_up_1) += adaptor1_hidden_id_shift;
661 });
662
663 return up_dim_hidden_ids_1_mod_;
664 }();
665
666 // constexpr tuple to sequence
668 [&](auto i) constexpr { return number<up_dim_hidden_ids_1_mod[i]>{}; },
670 },
671 number<TensorAdaptor1::get_num_of_transform()>{});
672
673 constexpr auto all_up_dim_hidden_idss =
674 container_concat(TensorAdaptor0::get_upper_dimension_hidden_idss(), up_dim_hidden_idss_1);
675
676 // bottom_dim_hidden_ids = bottom_dim_hidden_ids_0
677 constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::get_bottom_dimension_hidden_ids();
678
679 // top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1)
680 constexpr auto top_dim_hidden_ids =
681 TensorAdaptor1::get_top_dimension_hidden_ids() + number<adaptor1_hidden_id_shift>{};
682
683 // put everything together
684 return tensor_adaptor<remove_cvref_t<decltype(all_transforms)>,
685 remove_cvref_t<decltype(all_low_dim_hidden_idss)>,
686 remove_cvref_t<decltype(all_up_dim_hidden_idss)>,
687 remove_cvref_t<decltype(bottom_dim_hidden_ids)>,
688 remove_cvref_t<decltype(top_dim_hidden_ids)>>{all_transforms};
689}
690
691template <typename X,
692 typename... Xs,
693 typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
694CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
695{
697}
698
699} // namespace ck_tile
700
// Macro function
// Construct a constexpr tensor_adaptor from a constexpr encoding.
//
// The macro expands to an immediately-invoked lambda that captures
// `encoded_tensor_adaptor` by copy and returns the adaptor. Transform parameters
// are decoded from each transform's meta data with a running `pos` cursor and are
// handed to the make_*_transform factories as index_t / array<index_t, N> values.
//
// encoded_tensor_adaptor is a Tuple of the following objects:
//    1. encoded transforms (array of fixed size). Each encoded transform is a Tuple of:
//       1.1 name (coord_transform_enum)
//       1.2 meta data for the constructor of the transform
//       1.3 num of lower dimensions (index_t)
//       1.4 lower dimension Ids (array of fixed size)
//       1.5 num of upper dimensions (index_t)
//       1.6 upper dimension Ids (array of fixed size)
//    2. num of transforms (index_t)
//    3. encoded bottom dimension Ids (array of fixed size)
//    4. num of bottom dimensions (index_t)
//    5. encoded top dimension Ids (array of fixed size)
//    6. num of top dimensions (index_t)
//
// Meta data is read in this order, per transform name:
//    pass_through : low_len
//    pad          : low_len, left_pad, right_pad
//    embed        : up_lens[num_up_dim], coefficients[num_up_dim]
//    merge        : low_lens[num_low_dim]
//    unmerge      : up_lens[num_up_dim]
//    replicate    : up_lens[num_up_dim]
// Any other transform name is rejected by the static_assert.
//
// `num_transform` is only used as a template argument inside the helper lambdas,
// so it is not listed in their capture lists.
#define CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor)                             \
    [encoded_tensor_adaptor]() {                                                                   \
        using namespace ck_tile;                                                                   \
                                                                                                   \
        constexpr auto encoded_transforms  = encoded_tensor_adaptor.template at<0>();              \
        constexpr index_t num_transform    = encoded_tensor_adaptor.template at<1>();              \
        constexpr auto encoded_bottom_dims = encoded_tensor_adaptor.template at<2>();              \
        constexpr index_t num_bottom_dim   = encoded_tensor_adaptor.template at<3>();              \
        constexpr auto encoded_top_dims    = encoded_tensor_adaptor.template at<4>();              \
        constexpr index_t num_top_dim      = encoded_tensor_adaptor.template at<5>();              \
                                                                                                   \
        constexpr auto trans = [&encoded_transforms]() {                                           \
            return generate_tuple(                                                                 \
                [&encoded_transforms](auto i) constexpr {                                          \
                    constexpr auto name        = encoded_transforms[i].template at<0>();           \
                    constexpr auto meta_data   = encoded_transforms[i].template at<1>();           \
                    constexpr auto num_low_dim = encoded_transforms[i].template at<2>();           \
                    constexpr auto num_up_dim  = encoded_transforms[i].template at<4>();           \
                                                                                                   \
                    static_assert(name == coord_transform_enum::pass_through ||                    \
                                      name == coord_transform_enum::pad ||                         \
                                      name == coord_transform_enum::embed ||                       \
                                      name == coord_transform_enum::merge ||                       \
                                      name == coord_transform_enum::unmerge ||                     \
                                      name == coord_transform_enum::replicate,                     \
                                  "");                                                             \
                                                                                                   \
                    if constexpr(name == coord_transform_enum::pass_through)                       \
                    {                                                                              \
                        index_t pos  = 0;                                                          \
                        auto low_len = meta_data.template pop<index_t>(pos);                       \
                                                                                                   \
                        return make_pass_through_transform(low_len);                               \
                    }                                                                              \
                    else if constexpr(name == coord_transform_enum::pad)                           \
                    {                                                                              \
                        index_t pos    = 0;                                                        \
                        auto low_len   = meta_data.template pop<index_t>(pos);                     \
                        auto left_pad  = meta_data.template pop<index_t>(pos);                     \
                        auto right_pad = meta_data.template pop<index_t>(pos);                     \
                                                                                                   \
                        return make_pad_transform(low_len, left_pad, right_pad);                   \
                    }                                                                              \
                    else if constexpr(name == coord_transform_enum::embed)                         \
                    {                                                                              \
                        index_t pos  = 0;                                                          \
                        auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos);    \
                        auto coefficients =                                                        \
                            meta_data.template pop<array<index_t, num_up_dim>>(pos);               \
                                                                                                   \
                        return make_embed_transform(up_lens, coefficients);                        \
                    }                                                                              \
                    else if constexpr(name == coord_transform_enum::merge)                         \
                    {                                                                              \
                        index_t pos   = 0;                                                         \
                        auto low_lens = meta_data.template pop<array<index_t, num_low_dim>>(pos);  \
                                                                                                   \
                        return make_merge_transform(low_lens);                                     \
                    }                                                                              \
                    else if constexpr(name == coord_transform_enum::unmerge)                       \
                    {                                                                              \
                        index_t pos  = 0;                                                          \
                        auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos);    \
                                                                                                   \
                        return make_unmerge_transform(up_lens);                                    \
                    }                                                                              \
                    else if constexpr(name == coord_transform_enum::replicate)                     \
                    {                                                                              \
                        index_t pos  = 0;                                                          \
                        auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos);    \
                                                                                                   \
                        return make_replicate_transform(up_lens);                                  \
                    }                                                                              \
                },                                                                                 \
                number<num_transform>{});                                                          \
        }();                                                                                       \
                                                                                                   \
        constexpr auto low_dim_idss = [&encoded_transforms]() {                                    \
            return generate_tuple(                                                                 \
                [&encoded_transforms](auto i) {                                                    \
                    constexpr auto num_low_dim = encoded_transforms[i].template at<2>();           \
                    constexpr auto low_dims    = encoded_transforms[i].template at<3>();           \
                                                                                                   \
                    return TO_SEQUENCE(low_dims, num_low_dim);                                     \
                },                                                                                 \
                number<num_transform>());                                                          \
        }();                                                                                       \
                                                                                                   \
        constexpr auto up_dim_idss = [&encoded_transforms] {                                       \
            return generate_tuple(                                                                 \
                [&encoded_transforms](auto i) {                                                    \
                    constexpr auto num_up_dim = encoded_transforms[i].template at<4>();            \
                    constexpr auto up_dims    = encoded_transforms[i].template at<5>();            \
                                                                                                   \
                    return TO_SEQUENCE(up_dims, num_up_dim);                                       \
                },                                                                                 \
                number<num_transform>());                                                          \
        }();                                                                                       \
                                                                                                   \
        constexpr auto bottom_dim_ids = TO_SEQUENCE(encoded_bottom_dims, num_bottom_dim);          \
        constexpr auto top_dim_ids    = TO_SEQUENCE(encoded_top_dims, num_top_dim);                \
                                                                                                   \
        return tensor_adaptor<remove_cvref_t<decltype(trans)>,                                     \
                              remove_cvref_t<decltype(low_dim_idss)>,                              \
                              remove_cvref_t<decltype(up_dim_idss)>,                               \
                              remove_cvref_t<decltype(bottom_dim_ids)>,                            \
                              remove_cvref_t<decltype(top_dim_ids)>>{trans};                       \
    }()
824
// Macro function
// Construct a static tensor_adaptor from a constexpr encoding.
//
// The macro expands to an immediately-invoked lambda that captures
// `encoded_tensor_adaptor` by copy and returns the adaptor. Every transform
// parameter is read from the meta data at a fixed byte offset with get<T>(offset)
// and passed to the make_*_transform factories as number<> constants (or tuples of
// them via TO_TUPLE_OF_NUMBER).
//
// encoded_tensor_adaptor is a Tuple of the following objects:
//    1. encoded transforms (array of fixed size). Each encoded transform is a Tuple of:
//       1.1 name (coord_transform_enum)
//       1.2 meta data for the constructor of the transform
//       1.3 num of lower dimensions (index_t)
//       1.4 lower dimension Ids (array of fixed size)
//       1.5 num of upper dimensions (index_t)
//       1.6 upper dimension Ids (array of fixed size)
//    2. num of transforms (index_t)
//    3. encoded bottom dimension Ids (array of fixed size)
//    4. num of bottom dimensions (index_t)
//    5. encoded top dimension Ids (array of fixed size)
//    6. num of top dimensions (index_t)
//
// Meta data layout, per transform name (offsets in bytes):
//    pass_through : low_len @ 0
//    pad          : low_len @ 0, left_pad @ sizeof(low_len),
//                   right_pad @ sizeof(low_len) + sizeof(left_pad)
//    embed        : up_lens[num_up_dim] @ 0, coefficients[num_up_dim] @ sizeof(up_lens)
//    merge        : low_lens[num_low_dim] @ 0
//    unmerge      : up_lens[num_up_dim] @ 0
//    replicate    : up_lens[num_up_dim] @ 0
// Any other transform name is rejected by the static_assert.
//
// NOTE: right_pad is read with get<index_t>(offset) like every other field here.
// pop<T>() is the cursor-style reader used with a mutable `pos` variable in
// CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING and presumably takes that cursor by
// reference, so it is not meant to be fed a computed offset expression.
#define CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor)                      \
    [encoded_tensor_adaptor]() {                                                                   \
        using namespace ck_tile;                                                                   \
                                                                                                   \
        constexpr auto encoded_transforms  = encoded_tensor_adaptor.template at<0>();              \
        constexpr index_t num_transform    = encoded_tensor_adaptor.template at<1>();              \
        constexpr auto encoded_bottom_dims = encoded_tensor_adaptor.template at<2>();              \
        constexpr index_t num_bottom_dim   = encoded_tensor_adaptor.template at<3>();              \
        constexpr auto encoded_top_dims    = encoded_tensor_adaptor.template at<4>();              \
        constexpr index_t num_top_dim      = encoded_tensor_adaptor.template at<5>();              \
                                                                                                   \
        constexpr auto trans = [&encoded_transforms]() {                                           \
            return generate_tuple(                                                                 \
                [&encoded_transforms](auto i) constexpr {                                          \
                    constexpr auto name        = encoded_transforms[i].template at<0>();           \
                    constexpr auto meta_data   = encoded_transforms[i].template at<1>();           \
                    constexpr auto num_low_dim = encoded_transforms[i].template at<2>();           \
                    constexpr auto num_up_dim  = encoded_transforms[i].template at<4>();           \
                                                                                                   \
                    static_assert(name == coord_transform_enum::pass_through ||                    \
                                      name == coord_transform_enum::pad ||                         \
                                      name == coord_transform_enum::embed ||                       \
                                      name == coord_transform_enum::merge ||                       \
                                      name == coord_transform_enum::unmerge ||                     \
                                      name == coord_transform_enum::replicate,                     \
                                  "");                                                             \
                                                                                                   \
                    if constexpr(name == coord_transform_enum::pass_through)                       \
                    {                                                                              \
                        constexpr index_t low_len = meta_data.template get<index_t>(0);            \
                                                                                                   \
                        return make_pass_through_transform(number<low_len>{});                     \
                    }                                                                              \
                    else if constexpr(name == coord_transform_enum::pad)                           \
                    {                                                                              \
                        constexpr index_t low_len = meta_data.template get<index_t>(0);            \
                                                                                                   \
                        constexpr index_t left_pad =                                               \
                            meta_data.template get<index_t>(sizeof(low_len));                      \
                                                                                                   \
                        constexpr index_t right_pad =                                              \
                            meta_data.template get<index_t>(sizeof(low_len) + sizeof(left_pad));   \
                                                                                                   \
                        return make_pad_transform(                                                 \
                            number<low_len>{}, number<left_pad>{}, number<right_pad>{});           \
                    }                                                                              \
                    else if constexpr(name == coord_transform_enum::embed)                         \
                    {                                                                              \
                        constexpr auto up_lens =                                                   \
                            meta_data.template get<array<index_t, num_up_dim>>(0);                 \
                                                                                                   \
                        constexpr auto coefficients =                                              \
                            meta_data.template get<array<index_t, num_up_dim>>(sizeof(up_lens));   \
                                                                                                   \
                        return make_embed_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim),       \
                                                    TO_TUPLE_OF_NUMBER(coefficients, num_up_dim)); \
                    }                                                                              \
                    else if constexpr(name == coord_transform_enum::merge)                         \
                    {                                                                              \
                        constexpr auto low_lens =                                                  \
                            meta_data.template get<array<index_t, num_low_dim>>(0);                \
                                                                                                   \
                        return make_merge_transform(TO_TUPLE_OF_NUMBER(low_lens, num_low_dim));    \
                    }                                                                              \
                    else if constexpr(name == coord_transform_enum::unmerge)                       \
                    {                                                                              \
                        constexpr auto up_lens =                                                   \
                            meta_data.template get<array<index_t, num_up_dim>>(0);                 \
                                                                                                   \
                        return make_unmerge_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim));    \
                    }                                                                              \
                    else if constexpr(name == coord_transform_enum::replicate)                     \
                    {                                                                              \
                        constexpr auto up_lens =                                                   \
                            meta_data.template get<array<index_t, num_up_dim>>(0);                 \
                                                                                                   \
                        return make_replicate_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim));  \
                    }                                                                              \
                },                                                                                 \
                number<num_transform>{});                                                          \
        }();                                                                                       \
                                                                                                   \
        constexpr auto low_dim_idss = [&encoded_transforms]() {                                    \
            return generate_tuple(                                                                 \
                [&encoded_transforms](auto i) {                                                    \
                    constexpr auto num_low_dim = encoded_transforms[i].template at<2>();           \
                    constexpr auto low_dims    = encoded_transforms[i].template at<3>();           \
                                                                                                   \
                    return TO_SEQUENCE(low_dims, num_low_dim);                                     \
                },                                                                                 \
                number<num_transform>());                                                          \
        }();                                                                                       \
                                                                                                   \
        constexpr auto up_dim_idss = [&encoded_transforms] {                                       \
            return generate_tuple(                                                                 \
                [&encoded_transforms](auto i) {                                                    \
                    constexpr auto num_up_dim = encoded_transforms[i].template at<4>();            \
                    constexpr auto up_dims    = encoded_transforms[i].template at<5>();            \
                                                                                                   \
                    return TO_SEQUENCE(up_dims, num_up_dim);                                       \
                },                                                                                 \
                number<num_transform>());                                                          \
        }();                                                                                       \
                                                                                                   \
        constexpr auto bottom_dim_ids = TO_SEQUENCE(encoded_bottom_dims, num_bottom_dim);          \
        constexpr auto top_dim_ids    = TO_SEQUENCE(encoded_top_dims, num_top_dim);                \
                                                                                                   \
        return tensor_adaptor<remove_cvref_t<decltype(trans)>,                                     \
                              remove_cvref_t<decltype(low_dim_idss)>,                              \
                              remove_cvref_t<decltype(up_dim_idss)>,                               \
                              remove_cvref_t<decltype(bottom_dim_ids)>,                            \
                              remove_cvref_t<decltype(top_dim_ids)>>{trans};                       \
    }()
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
typename std::remove_reference< T >::type remove_reference_t
Definition type_traits.hpp:15
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container &x, Reduce reduce, Init init, number< IBegin >=number< 0 >{}, number< IEnd >=number< Container::size()>{}, number< IStep >=number< 1 >{})
Definition tile/core/container/container_helper.hpp:198
CK_TILE_HOST_DEVICE constexpr auto generate_sequence_v2(F &&f, number< N >)
Definition tile/core/container/sequence.hpp:1045
CK_TILE_HOST_DEVICE constexpr auto transform_sequences(F f, sequence< Xs... >)
Definition tile/core/container/sequence.hpp:832
CK_TILE_HOST_DEVICE constexpr auto to_multi_index(const T &x)
Definition tile/core/container/multi_index.hpp:33
CK_TILE_HOST_DEVICE constexpr auto container_concat(const X &x, const Ys &... ys)
Definition tile/core/container/container_helper.hpp:363
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_adaptor(const OldTensorAdaptor &old_tensor_adaptor, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_adaptor.hpp:426
CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X &x)
Definition tile/core/container/tuple.hpp:505
CK_TILE_HOST_DEVICE constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_adaptor.hpp:359
CK_TILE_HOST_DEVICE constexpr void set_container_subset(array< T, N > &y, sequence< Is... > picks, const array< T, sizeof...(Is)> &x)
Definition tile/core/container/container_helper.hpp:420
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
array< index_t, N > multi_index
Definition tile/core/container/multi_index.hpp:17
CK_TILE_HOST_DEVICE constexpr auto generate_sequence(F, number< N >)
Definition tile/core/container/sequence.hpp:1037
CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition tile/core/container/container_helper.hpp:389
CK_TILE_HOST_DEVICE constexpr auto unpack(F &&f, X &&x)
Definition tile/core/utility/functional.hpp:200
typename std::remove_cv< T >::type remove_cv_t
Definition type_traits.hpp:18
CK_TILE_HOST_DEVICE constexpr auto inclusive_scan_sequence(Seq, Reduce, number< Init >)
Definition tile/core/container/sequence.hpp:870
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0 &adaptor0, const TensorAdaptor1 &adaptor1)
Definition tile/core/tensor/tensor_adaptor.hpp:512
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Definition tile/core/container/sequence.hpp:287
typename std::conditional< kHasContent, type0, type1 >::type type
Definition tile/core/container/sequence.hpp:302
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
Definition tile/core/numeric/math.hpp:329
static constexpr bool value
Definition type_traits.hpp:77
Definition tile/core/container/sequence.hpp:670
Definition tile/core/tensor/tensor_adaptor.hpp:412
CK_TILE_HOST_DEVICE constexpr auto operator()(I) const
Definition tile/core/tensor/tensor_adaptor.hpp:414
Definition tile/core/numeric/math.hpp:371
Definition tile/core/numeric/math.hpp:98
static CK_TILE_HOST_DEVICE constexpr T max()
Definition tile/core/numeric/numeric.hpp:26
static CK_TILE_HOST_DEVICE constexpr T min()
Definition tile/core/numeric/numeric.hpp:20
Definition tile/core/numeric/math.hpp:50
Definition tile/core/container/sequence.hpp:593
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tile/core/tensor/tensor_adaptor.hpp:28
static CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_safe_vector_length_strides(const array< index_t, ndim_hidden_ > &guaranteed_vector_lengths, const array< index_t, ndim_hidden_ > &guaranteed_vector_strides)
Definition tile/core/tensor/tensor_adaptor.hpp:263
static CK_TILE_HOST_DEVICE constexpr bool is_known_at_compile_time()
Definition tile/core/tensor/tensor_adaptor.hpp:260
multi_index< ndim_hidden_ > HiddenIndex
Definition tile/core/tensor/tensor_adaptor.hpp:144
CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_lengths() const
Definition tile/core/tensor/tensor_adaptor.hpp:201
static CK_TILE_HOST_DEVICE constexpr auto get_bottom_dimension_hidden_ids()
Definition tile/core/tensor/tensor_adaptor.hpp:46
static CK_TILE_HOST_DEVICE constexpr bool is_static()
Definition tile/core/tensor/tensor_adaptor.hpp:249
static CK_TILE_HOST_DEVICE constexpr auto get_upper_dimension_hidden_idss()
Definition tile/core/tensor/tensor_adaptor.hpp:41
CK_TILE_HOST_DEVICE constexpr auto get_hidden_dimension_length(number< IDimHidden >) const
Definition tile/core/tensor/tensor_adaptor.hpp:169
static CK_TILE_HOST_DEVICE constexpr auto get_lower_dimension_hidden_idss()
Definition tile/core/tensor/tensor_adaptor.hpp:36
static CK_TILE_HOST_DEVICE constexpr index_t get_num_of_hidden_dimension()
Definition tile/core/tensor/tensor_adaptor.hpp:120
static CK_TILE_HOST_DEVICE constexpr index_t get_num_of_top_dimension()
Definition tile/core/tensor/tensor_adaptor.hpp:115
static CK_TILE_HOST_DEVICE constexpr index_t get_num_of_transform()
Definition tile/core/tensor/tensor_adaptor.hpp:29
CK_TILE_HOST_DEVICE constexpr auto calculate_bottom_index(const TopIdx &idx_top) const
Definition tile/core/tensor/tensor_adaptor.hpp:217
CK_TILE_HOST_DEVICE constexpr const auto & get_transforms() const
Definition tile/core/tensor/tensor_adaptor.hpp:34
CK_TILE_HOST_DEVICE constexpr auto get_element_size() const
Definition tile/core/tensor/tensor_adaptor.hpp:165
multi_index< ndim_bottom_ > BottomIndex
Definition tile/core/tensor/tensor_adaptor.hpp:145
static CK_TILE_HOST_DEVICE constexpr auto initialize_element_size(const Transforms &transforms)
Definition tile/core/tensor/tensor_adaptor.hpp:56
static CK_TILE_HOST_DEVICE constexpr auto get_transform_and_its_upper_dimension(number< IDimHidden >)
Definition tile/core/tensor/tensor_adaptor.hpp:84
CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_length(number< IDimTop > idim_top) const
Definition tile/core/tensor/tensor_adaptor.hpp:186
remove_cv_t< decltype(initialize_element_size(Transforms{}))> ElementSize
Definition tile/core/tensor/tensor_adaptor.hpp:149
CK_TILE_HOST_DEVICE constexpr tensor_adaptor(const Transforms &transforms)
Definition tile/core/tensor/tensor_adaptor.hpp:154
static CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_hidden_ids()
Definition tile/core/tensor/tensor_adaptor.hpp:51
multi_index< ndim_top_ > TopIndex
Definition tile/core/tensor/tensor_adaptor.hpp:146
CK_TILE_HOST_DEVICE constexpr tensor_adaptor()=default
static CK_TILE_HOST_DEVICE constexpr index_t get_num_of_bottom_dimension()
Definition tile/core/tensor/tensor_adaptor.hpp:110