@@ -50,6 +50,14 @@ namespace tensor
5050namespace kernels
5151{
5252
53+ template <typename ReductionOpT, typename T> struct can_use_reduce_over_group
54+ {
55+ static constexpr bool value =
56+ sycl::has_known_identity<ReductionOpT, T>::value &&
57+ !std::is_same_v<T, std::int64_t > && !std::is_same_v<T, std::uint64_t > &&
58+ !std::is_same_v<ReductionOpT, sycl::multiplies<T>>;
59+ };
60+
5361template <typename argT,
5462 typename outT,
5563 typename ReductionOp,
@@ -477,7 +485,8 @@ sycl::event reduction_over_group_with_atomics_strided_impl(
477485 sycl::range<1 >{iter_nelems * reduction_groups * wg};
478486 auto localRange = sycl::range<1 >{wg};
479487
480- if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
488+ if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
489+ {
481490 using KernelName = class reduction_over_group_with_atomics_krn <
482491 argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
483492 ReductionIndexerT>;
@@ -618,7 +627,8 @@ sycl::event reduction_axis1_over_group_with_atomics_contig_impl(
618627 sycl::range<1 >{iter_nelems * reduction_groups * wg};
619628 auto localRange = sycl::range<1 >{wg};
620629
621- if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
630+ if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
631+ {
622632 using KernelName =
623633 class reduction_axis1_over_group_with_atomics_contig_krn <
624634 argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
@@ -717,7 +727,8 @@ sycl::event reduction_axis0_over_group_with_atomics_contig_impl(
717727 sycl::range<1 >{iter_nelems * reduction_groups * wg};
718728 auto localRange = sycl::range<1 >{wg};
719729
720- if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
730+ if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
731+ {
721732 using KernelName =
722733 class reduction_axis0_over_group_with_atomics_contig_krn <
723734 argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
@@ -1007,10 +1018,12 @@ sycl::event reduction_over_group_temps_strided_impl(
10071018 sycl::range<1 >{iter_nelems * reduction_groups * wg};
10081019 auto localRange = sycl::range<1 >{wg};
10091020
1010- if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
1021+ if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
1022+ {
10111023 using KernelName = class reduction_over_group_temps_krn <
10121024 argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
10131025 ReductionIndexerT>;
1026+
10141027 cgh.parallel_for <KernelName>(
10151028 sycl::nd_range<1 >(globalRange, localRange),
10161029 ReductionOverGroupNoAtomicFunctor<
@@ -1026,6 +1039,7 @@ sycl::event reduction_over_group_temps_strided_impl(
10261039 using KernelName = class custom_reduction_over_group_temps_krn <
10271040 argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
10281041 ReductionIndexerT, SlmT>;
1042+
10291043 cgh.parallel_for <KernelName>(
10301044 sycl::nd_range<1 >(globalRange, localRange),
10311045 CustomReductionOverGroupNoAtomicFunctor<
@@ -1062,68 +1076,67 @@ sycl::event reduction_over_group_temps_strided_impl(
10621076 partially_reduced_tmp + reduction_groups * iter_nelems;
10631077 }
10641078
1065- const sycl::event &first_reduction_ev =
1066- exec_q. submit ([&](sycl::handler &cgh) {
1067- cgh.depends_on (depends);
1079+ const sycl::event &first_reduction_ev = exec_q. submit ([&](sycl::handler
1080+ &cgh) {
1081+ cgh.depends_on (depends);
10681082
1069- using InputIndexerT =
1070- dpctl::tensor::offset_utils::StridedIndexer;
1071- using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
1072- using InputOutputIterIndexerT =
1073- dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
1074- InputIndexerT, ResIndexerT>;
1075- using ReductionIndexerT =
1076- dpctl::tensor::offset_utils::StridedIndexer;
1083+ using InputIndexerT = dpctl::tensor::offset_utils::StridedIndexer;
1084+ using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
1085+ using InputOutputIterIndexerT =
1086+ dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
1087+ InputIndexerT, ResIndexerT>;
1088+ using ReductionIndexerT =
1089+ dpctl::tensor::offset_utils::StridedIndexer;
10771090
1078- // Only 2*iter_nd entries describing shape and strides of
1079- // iterated dimensions of input array from
1080- // iter_shape_and_strides are going to be accessed by
1081- // inp_indexer
1082- InputIndexerT inp_indexer (iter_nd, iter_arg_offset,
1083- iter_shape_and_strides);
1084- ResIndexerT noop_tmp_indexer{};
1091+ // Only 2*iter_nd entries describing shape and strides of
1092+ // iterated dimensions of input array from
1093+ // iter_shape_and_strides are going to be accessed by
1094+ // inp_indexer
1095+ InputIndexerT inp_indexer (iter_nd, iter_arg_offset,
1096+ iter_shape_and_strides);
1097+ ResIndexerT noop_tmp_indexer{};
10851098
1086- InputOutputIterIndexerT in_out_iter_indexer{inp_indexer,
1087- noop_tmp_indexer};
1088- ReductionIndexerT reduction_indexer{
1089- red_nd, reduction_arg_offset, reduction_shape_stride};
1099+ InputOutputIterIndexerT in_out_iter_indexer{inp_indexer,
1100+ noop_tmp_indexer};
1101+ ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset,
1102+ reduction_shape_stride};
10901103
1091- auto globalRange =
1092- sycl::range<1 >{iter_nelems * reduction_groups * wg};
1093- auto localRange = sycl::range<1 >{wg};
1104+ auto globalRange =
1105+ sycl::range<1 >{iter_nelems * reduction_groups * wg};
1106+ auto localRange = sycl::range<1 >{wg};
10941107
1095- if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
1096- using KernelName = class reduction_over_group_temps_krn <
1108+ if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
1109+ {
1110+ using KernelName = class reduction_over_group_temps_krn <
1111+ argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1112+ ReductionIndexerT>;
1113+ cgh.parallel_for <KernelName>(
1114+ sycl::nd_range<1 >(globalRange, localRange),
1115+ ReductionOverGroupNoAtomicFunctor<
10971116 argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1098- ReductionIndexerT>;
1099- cgh.parallel_for <KernelName>(
1100- sycl::nd_range<1 >(globalRange, localRange),
1101- ReductionOverGroupNoAtomicFunctor<
1102- argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1103- ReductionIndexerT>(
1104- arg_tp, partially_reduced_tmp, ReductionOpT (),
1105- identity_val, in_out_iter_indexer,
1106- reduction_indexer, reduction_nelems, iter_nelems,
1107- preferrered_reductions_per_wi));
1108- }
1109- else {
1110- using SlmT = sycl::local_accessor<resTy, 1 >;
1111- SlmT local_memory = SlmT (localRange, cgh);
1112- using KernelName =
1113- class custom_reduction_over_group_temps_krn <
1114- argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1115- ReductionIndexerT, SlmT>;
1116- cgh.parallel_for <KernelName>(
1117- sycl::nd_range<1 >(globalRange, localRange),
1118- CustomReductionOverGroupNoAtomicFunctor<
1119- argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1120- ReductionIndexerT, SlmT>(
1121- arg_tp, partially_reduced_tmp, ReductionOpT (),
1122- identity_val, in_out_iter_indexer,
1123- reduction_indexer, local_memory, reduction_nelems,
1124- iter_nelems, preferrered_reductions_per_wi));
1125- }
1126- });
1117+ ReductionIndexerT>(
1118+ arg_tp, partially_reduced_tmp, ReductionOpT (),
1119+ identity_val, in_out_iter_indexer, reduction_indexer,
1120+ reduction_nelems, iter_nelems,
1121+ preferrered_reductions_per_wi));
1122+ }
1123+ else {
1124+ using SlmT = sycl::local_accessor<resTy, 1 >;
1125+ SlmT local_memory = SlmT (localRange, cgh);
1126+ using KernelName = class custom_reduction_over_group_temps_krn <
1127+ argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1128+ ReductionIndexerT, SlmT>;
1129+ cgh.parallel_for <KernelName>(
1130+ sycl::nd_range<1 >(globalRange, localRange),
1131+ CustomReductionOverGroupNoAtomicFunctor<
1132+ argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
1133+ ReductionIndexerT, SlmT>(
1134+ arg_tp, partially_reduced_tmp, ReductionOpT (),
1135+ identity_val, in_out_iter_indexer, reduction_indexer,
1136+ local_memory, reduction_nelems, iter_nelems,
1137+ preferrered_reductions_per_wi));
1138+ }
1139+ });
11271140
11281141 size_t remaining_reduction_nelems = reduction_groups;
11291142
@@ -1165,7 +1178,8 @@ sycl::event reduction_over_group_temps_strided_impl(
11651178 auto globalRange =
11661179 sycl::range<1 >{iter_nelems * reduction_groups_ * wg};
11671180 auto localRange = sycl::range<1 >{wg};
1168- if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
1181+ if constexpr (can_use_reduce_over_group<ReductionOpT,
1182+ resTy>::value) {
11691183 using KernelName = class reduction_over_group_temps_krn <
11701184 resTy, resTy, ReductionOpT, InputOutputIterIndexerT,
11711185 ReductionIndexerT>;
@@ -1240,7 +1254,8 @@ sycl::event reduction_over_group_temps_strided_impl(
12401254 sycl::range<1 >{iter_nelems * reduction_groups * wg};
12411255 auto localRange = sycl::range<1 >{wg};
12421256
1243- if constexpr (su_ns::IsSyclOp<resTy, ReductionOpT>::value) {
1257+ if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
1258+ {
12441259 using KernelName = class reduction_over_group_temps_krn <
12451260 argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
12461261 ReductionIndexerT>;
@@ -2564,7 +2579,8 @@ sycl::event search_reduction_over_group_temps_strided_impl(
25642579 sycl::range<1 >{iter_nelems * reduction_groups * wg};
25652580 auto localRange = sycl::range<1 >{wg};
25662581
2567- if constexpr (su_ns::IsSyclOp<argTy, ReductionOpT>::value) {
2582+ if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
2583+ {
25682584 using KernelName = class search_reduction_over_group_temps_krn <
25692585 argTy, resTy, ReductionOpT, IndexOpT,
25702586 InputOutputIterIndexerT, ReductionIndexerT, true , true >;
@@ -2663,7 +2679,8 @@ sycl::event search_reduction_over_group_temps_strided_impl(
26632679 sycl::range<1 >{iter_nelems * reduction_groups * wg};
26642680 auto localRange = sycl::range<1 >{wg};
26652681
2666- if constexpr (su_ns::IsSyclOp<argTy, ReductionOpT>::value) {
2682+ if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
2683+ {
26672684 using KernelName = class search_reduction_over_group_temps_krn <
26682685 argTy, resTy, ReductionOpT, IndexOpT,
26692686 InputOutputIterIndexerT, ReductionIndexerT, true , false >;
@@ -2743,7 +2760,8 @@ sycl::event search_reduction_over_group_temps_strided_impl(
27432760 auto globalRange =
27442761 sycl::range<1 >{iter_nelems * reduction_groups_ * wg};
27452762 auto localRange = sycl::range<1 >{wg};
2746- if constexpr (su_ns::IsSyclOp<argTy, ReductionOpT>::value) {
2763+ if constexpr (can_use_reduce_over_group<ReductionOpT,
2764+ resTy>::value) {
27472765 using KernelName =
27482766 class search_reduction_over_group_temps_krn <
27492767 argTy, resTy, ReductionOpT, IndexOpT,
@@ -2826,7 +2844,8 @@ sycl::event search_reduction_over_group_temps_strided_impl(
28262844 sycl::range<1 >{iter_nelems * reduction_groups * wg};
28272845 auto localRange = sycl::range<1 >{wg};
28282846
2829- if constexpr (su_ns::IsSyclOp<argTy, ReductionOpT>::value) {
2847+ if constexpr (can_use_reduce_over_group<ReductionOpT, resTy>::value)
2848+ {
28302849 using KernelName = class search_reduction_over_group_temps_krn <
28312850 argTy, resTy, ReductionOpT, IndexOpT,
28322851 InputOutputIterIndexerT, ReductionIndexerT, false , true >;
0 commit comments