@@ -3401,6 +3401,125 @@ struct LogSumExpOverAxis0TempsContigFactory
 
 // Argmax and Argmin
 
+/* Sequential search reduction */
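+// Single work-item argmin/argmax-style reduction; dispatched below whenever
+// the reduction dimension is shorter than the chosen work-group size.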
+
+template <typename argT,
+          typename outT,
+          typename ReductionOp,
+          typename IdxReductionOp,
+          typename InputOutputIterIndexerT,
+          typename InputRedIndexerT>
+struct SequentialSearchReduction
+{
+private:
+    const argT *inp_ = nullptr;
+    outT *out_ = nullptr;
+    ReductionOp reduction_op_;
+    argT identity_;
+    IdxReductionOp idx_reduction_op_;
+    outT idx_identity_;
+    InputOutputIterIndexerT inp_out_iter_indexer_;
+    InputRedIndexerT inp_reduced_dims_indexer_;
+    size_t reduction_max_gid_ = 0;
+
+public:
+    SequentialSearchReduction(const argT *inp,
+                              outT *res,
+                              ReductionOp reduction_op,
+                              const argT &identity_val,
+                              IdxReductionOp idx_reduction_op,
+                              const outT &idx_identity_val,
+                              InputOutputIterIndexerT arg_res_iter_indexer,
+                              InputRedIndexerT arg_reduced_dims_indexer,
+                              size_t reduction_size)
+        : inp_(inp), out_(res), reduction_op_(reduction_op),
+          identity_(identity_val), idx_reduction_op_(idx_reduction_op),
+          idx_identity_(idx_identity_val),
+          inp_out_iter_indexer_(arg_res_iter_indexer),
+          inp_reduced_dims_indexer_(arg_reduced_dims_indexer),
+          reduction_max_gid_(reduction_size)
+    {
+    }
+
+    void operator()(sycl::id<1> id) const
+    {
+
+        auto const &inp_out_iter_offsets_ = inp_out_iter_indexer_(id[0]);
+        const py::ssize_t &inp_iter_offset =
+            inp_out_iter_offsets_.get_first_offset();
+        const py::ssize_t &out_iter_offset =
+            inp_out_iter_offsets_.get_second_offset();
+
+        argT red_val(identity_);
+        outT idx_val(idx_identity_);
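+        // Scan the reduced dimension serially, tracking the running extremum
+        // in red_val and the index where it occurred in idx_val.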
+        for (size_t m = 0; m < reduction_max_gid_; ++m) {
+            const py::ssize_t inp_reduction_offset =
+                inp_reduced_dims_indexer_(m);
+            const py::ssize_t inp_offset =
+                inp_iter_offset + inp_reduction_offset;
+
+            argT val = inp_[inp_offset];
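+            // On a tie, indices are combined with IdxReductionOp; presumably
+            // this is instantiated as sycl::minimum so that the first
+            // occurrence wins, but that is up to the caller.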
+            if (val == red_val) {
+                idx_val = idx_reduction_op_(idx_val, static_cast<outT>(m));
+            }
+            else {
+                if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
+                    using dpctl::tensor::type_utils::is_complex;
+                    if constexpr (is_complex<argT>::value) {
+                        using dpctl::tensor::math_utils::less_complex;
+                        // less_complex always returns false for NaNs, so check
+                        if (less_complex<argT>(val, red_val) ||
+                            std::isnan(std::real(val)) ||
+                            std::isnan(std::imag(val)))
+                        {
+                            red_val = val;
+                            idx_val = static_cast<outT>(m);
+                        }
+                    }
+                    else if constexpr (std::is_floating_point_v<argT>) {
+                        if (val < red_val || std::isnan(val)) {
+                            red_val = val;
+                            idx_val = static_cast<outT>(m);
+                        }
+                    }
+                    else {
+                        if (val < red_val) {
+                            red_val = val;
+                            idx_val = static_cast<outT>(m);
+                        }
+                    }
+                }
+                else if constexpr (su_ns::IsMaximum<argT, ReductionOp>::value) {
+                    using dpctl::tensor::type_utils::is_complex;
+                    if constexpr (is_complex<argT>::value) {
+                        using dpctl::tensor::math_utils::greater_complex;
+                        if (greater_complex<argT>(val, red_val) ||
+                            std::isnan(std::real(val)) ||
+                            std::isnan(std::imag(val)))
+                        {
+                            red_val = val;
+                            idx_val = static_cast<outT>(m);
+                        }
+                    }
+                    else if constexpr (std::is_floating_point_v<argT>) {
+                        if (val > red_val || std::isnan(val)) {
+                            red_val = val;
+                            idx_val = static_cast<outT>(m);
+                        }
+                    }
+                    else {
+                        if (val > red_val) {
+                            red_val = val;
+                            idx_val = static_cast<outT>(m);
+                        }
+                    }
+                }
+            }
+        }
+        out_[out_iter_offset] = idx_val;
+    }
+};
+
 /* = Search reduction using reduce_over_group*/
 
 template <typename argT,
@@ -3799,6 +3918,14 @@ typedef sycl::event (*search_strided_impl_fn_ptr)(
     py::ssize_t,
     const std::vector<sycl::event> &);
 
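+// SYCL kernel name type for the sequential strided search kernel added below.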
+template <typename T1,
+          typename T2,
+          typename T3,
+          typename T4,
+          typename T5,
+          typename T6>
+class search_seq_strided_krn;
+
 template <typename T1,
           typename T2,
           typename T3,
@@ -3820,6 +3947,14 @@ template <typename T1,
           bool b2>
 class custom_search_over_group_temps_strided_krn;
 
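+// Kernel name type for the contiguous variant of the sequential search kernel.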
+template <typename T1,
+          typename T2,
+          typename T3,
+          typename T4,
+          typename T5,
+          typename T6>
+class search_seq_contig_krn;
+
 template <typename T1,
           typename T2,
           typename T3,
@@ -4019,6 +4154,36 @@ sycl::event search_over_group_temps_strided_impl(
     const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
     size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
 
+    if (reduction_nelems < wg) {
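+        // The whole reduced axis is shorter than one work-group: launch one
+        // work-item per output element and reduce sequentially, skipping the
+        // group-reduction machinery and its temporaries.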
+        sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
+            cgh.depends_on(depends);
+
+            using InputOutputIterIndexerT =
+                dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
+            using ReductionIndexerT =
+                dpctl::tensor::offset_utils::StridedIndexer;
+
+            InputOutputIterIndexerT in_out_iter_indexer{
+                iter_nd, iter_arg_offset, iter_res_offset,
+                iter_shape_and_strides};
+            ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset,
+                                                reduction_shape_stride};
+
+            cgh.parallel_for<class search_seq_strided_krn<
+                argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT,
+                ReductionIndexerT>>(
+                sycl::range<1>(iter_nelems),
+                SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
+                                          InputOutputIterIndexerT,
+                                          ReductionIndexerT>(
+                    arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
+                    idx_identity_val, in_out_iter_indexer, reduction_indexer,
+                    reduction_nelems));
+        });
+
+        return comp_ev;
+    }
+
     constexpr size_t preferred_reductions_per_wi = 4;
     // max_max_wg prevents running out of resources on CPU
     size_t max_wg =
@@ -4419,6 +4584,39 @@ sycl::event search_axis1_over_group_temps_contig_impl(
     const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
     size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
 
+    if (reduction_nelems < wg) {
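+        // Same sequential shortcut as in the strided implementation, here for
+        // reductions along the last (contiguous) axis.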
+        sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
+            cgh.depends_on(depends);
+
+            using InputIterIndexerT =
+                dpctl::tensor::offset_utils::Strided1DIndexer;
+            using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
+            using InputOutputIterIndexerT =
+                dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
+                    InputIterIndexerT, NoOpIndexerT>;
+            using ReductionIndexerT = NoOpIndexerT;
+
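+            // Each of the iter_nelems output elements reduces one contiguous
+            // row of length reduction_nelems; consecutive rows start
+            // reduction_nelems elements apart.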
+            InputOutputIterIndexerT in_out_iter_indexer{
+                InputIterIndexerT{0, static_cast<py::ssize_t>(iter_nelems),
+                                  static_cast<py::ssize_t>(reduction_nelems)},
+                NoOpIndexerT{}};
+            ReductionIndexerT reduction_indexer{};
+
+            cgh.parallel_for<class search_seq_contig_krn<
+                argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT,
+                ReductionIndexerT>>(
+                sycl::range<1>(iter_nelems),
+                SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
+                                          InputOutputIterIndexerT,
+                                          ReductionIndexerT>(
+                    arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
+                    idx_identity_val, in_out_iter_indexer, reduction_indexer,
+                    reduction_nelems));
+        });
+
+        return comp_ev;
+    }
+
     constexpr size_t preferred_reductions_per_wi = 8;
     // max_max_wg prevents running out of resources on CPU
     size_t max_wg =
@@ -4801,6 +4999,43 @@ sycl::event search_axis0_over_group_temps_contig_impl(
     const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
     size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
 
+    if (reduction_nelems < wg) {
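+        // Same sequential shortcut, specialized for the axis-0 contiguous
+        // case.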
+        sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
+            cgh.depends_on(depends);
+
+            using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
+            using InputOutputIterIndexerT =
+                dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
+                    NoOpIndexerT, NoOpIndexerT>;
+            using ReductionIndexerT =
+                dpctl::tensor::offset_utils::Strided1DIndexer;
+
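+            // Axis-0 layout: the m-th element of the reduced axis for output
+            // position i sits at offset i + m * iter_nelems.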
+            InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{},
+                                                        NoOpIndexerT{}};
+            ReductionIndexerT reduction_indexer{
+                0, static_cast<py::ssize_t>(reduction_nelems),
+                static_cast<py::ssize_t>(iter_nelems)};
+
+            using KernelName =
+                class search_seq_contig_krn<argTy, resTy, ReductionOpT,
+                                            IndexOpT, InputOutputIterIndexerT,
+                                            ReductionIndexerT>;
+
+            sycl::range<1> iter_range{iter_nelems};
+
+            cgh.parallel_for<KernelName>(
+                iter_range,
+                SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
+                                          InputOutputIterIndexerT,
+                                          ReductionIndexerT>(
+                    arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
+                    idx_identity_val, in_out_iter_indexer, reduction_indexer,
+                    reduction_nelems));
+        });
+
+        return comp_ev;
+    }
+
     constexpr size_t preferred_reductions_per_wi = 8;
     // max_max_wg prevents running out of resources on CPU
     size_t max_wg =