@@ -105,6 +105,42 @@ template <typename argT1, typename argT2, typename resT> struct AddFunctor
105105 tmp);
106106 }
107107 }
108+
109+ template <int vec_sz>
110+ sycl::vec<resT, vec_sz> operator ()(const sycl::vec<argT1, vec_sz> &in1,
111+ const argT2 &in2) const
112+ {
113+ auto tmp = in1 + in2;
114+ if constexpr (std::is_same_v<resT,
115+ typename decltype (tmp)::element_type>)
116+ {
117+ return tmp;
118+ }
119+ else {
120+ using dpctl::tensor::type_utils::vec_cast;
121+
122+ return vec_cast<resT, typename decltype (tmp)::element_type, vec_sz>(
123+ tmp);
124+ }
125+ }
126+
127+ template <int vec_sz>
128+ sycl::vec<resT, vec_sz>
129+ operator ()(const argT1 &in1, const sycl::vec<argT2, vec_sz> &in2) const
130+ {
131+ auto tmp = in1 + in2;
132+ if constexpr (std::is_same_v<resT,
133+ typename decltype (tmp)::element_type>)
134+ {
135+ return tmp;
136+ }
137+ else {
138+ using dpctl::tensor::type_utils::vec_cast;
139+
140+ return vec_cast<resT, typename decltype (tmp)::element_type, vec_sz>(
141+ tmp);
142+ }
143+ }
108144};
109145
110146template <typename argT1,
@@ -393,6 +429,126 @@ struct AddContigRowContigMatrixBroadcastFactory
393429 }
394430};
395431
432+ template <typename argT1,
433+ typename argT2,
434+ typename resT,
435+ unsigned int vec_sz = 4 ,
436+ unsigned int n_vecs = 2 ,
437+ bool enable_sg_loadstore = true >
438+ using AddScalarContigArrayFunctor =
439+ elementwise_common::BinaryScalarContigArrayFunctor<
440+ argT1,
441+ argT2,
442+ resT,
443+ AddFunctor<argT1, argT2, resT>,
444+ vec_sz,
445+ n_vecs,
446+ enable_sg_loadstore>;
447+
448+ template <typename argT1,
449+ typename argT2,
450+ typename resT,
451+ unsigned int vec_sz = 4 ,
452+ unsigned int n_vecs = 2 ,
453+ bool enable_sg_loadstore = true >
454+ using AddContigArrayScalarFunctor =
455+ elementwise_common::BinaryContigArrayScalarFunctor<
456+ argT1,
457+ argT2,
458+ resT,
459+ AddFunctor<argT1, argT2, resT>,
460+ vec_sz,
461+ n_vecs,
462+ enable_sg_loadstore>;
463+
/// @brief Declaration-only class template; it is passed as the kernel-name
///        template argument by `add_scalar_contig_array_impl`.
template <typename argT1, typename argT2, typename resT,
          unsigned int vec_sz, unsigned int n_vecs>
class add_scalar_contig_array_kernel;
470+
471+ template <typename argTy1, typename argTy2>
472+ sycl::event
473+ add_scalar_contig_array_impl (sycl::queue &exec_q,
474+ size_t nelems,
475+ const char *arg1_p,
476+ ssize_t arg1_offset,
477+ const char *arg2_p,
478+ ssize_t arg2_offset,
479+ char *res_p,
480+ ssize_t res_offset,
481+ const std::vector<sycl::event> &depends = {})
482+ {
483+ return elementwise_common::binary_scalar_contig_array_impl<
484+ argTy1, argTy2, AddOutputType, AddScalarContigArrayFunctor,
485+ add_scalar_contig_array_kernel>(exec_q, nelems, arg1_p, arg1_offset,
486+ arg2_p, arg2_offset, res_p, res_offset,
487+ depends);
488+ }
489+
/// @brief Declaration-only class template; it is passed as the kernel-name
///        template argument by `add_contig_array_scalar_impl`.
template <typename argT1, typename argT2, typename resT,
          unsigned int vec_sz, unsigned int n_vecs>
class add_contig_array_scalar_kernel;
496+
497+ template <typename argTy1, typename argTy2>
498+ sycl::event
499+ add_contig_array_scalar_impl (sycl::queue &exec_q,
500+ size_t nelems,
501+ const char *arg1_p,
502+ ssize_t arg1_offset,
503+ const char *arg2_p,
504+ ssize_t arg2_offset,
505+ char *res_p,
506+ ssize_t res_offset,
507+ const std::vector<sycl::event> &depends = {})
508+ {
509+ return elementwise_common::binary_contig_array_scalar_impl<
510+ argTy1, argTy2, AddOutputType, AddContigArrayScalarFunctor,
511+ add_contig_array_scalar_kernel>(exec_q, nelems, arg1_p, arg1_offset,
512+ arg2_p, arg2_offset, res_p, res_offset,
513+ depends);
514+ }
515+
516+ template <typename fnT, typename T1, typename T2>
517+ struct AddScalarContigArrayFactory
518+ {
519+ fnT get ()
520+ {
521+ if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
522+ void >)
523+ {
524+ fnT fn = nullptr ;
525+ return fn;
526+ }
527+ else {
528+ fnT fn = add_scalar_contig_array_impl<T1, T2>;
529+ return fn;
530+ }
531+ }
532+ };
533+
534+ template <typename fnT, typename T1, typename T2>
535+ struct AddContigArrayScalarFactory
536+ {
537+ fnT get ()
538+ {
539+ if constexpr (std::is_same_v<typename AddOutputType<T1, T2>::value_type,
540+ void >)
541+ {
542+ fnT fn = nullptr ;
543+ return fn;
544+ }
545+ else {
546+ fnT fn = add_contig_array_scalar_impl<T1, T2>;
547+ return fn;
548+ }
549+ }
550+ };
551+
396552template <typename argT, typename resT> struct AddInplaceFunctor
397553{
398554
@@ -409,6 +565,12 @@ template <typename argT, typename resT> struct AddInplaceFunctor
409565 {
410566 res += in;
411567 }
568+
569+ template <int vec_sz>
570+ void operator ()(sycl::vec<resT, vec_sz> &res, const argT &in)
571+ {
572+ res += in;
573+ }
412574};
413575
414576template <typename argT,
@@ -606,6 +768,58 @@ struct AddInplaceRowMatrixBroadcastFactory
606768 }
607769};
608770
771+ template <typename argT,
772+ typename resT,
773+ unsigned int vec_sz = 4 ,
774+ unsigned int n_vecs = 2 ,
775+ bool enable_sg_loadstore = true >
776+ using AddInplaceScalarContigFunctor =
777+ elementwise_common::BinaryInplaceScalarContigFunctor<
778+ argT,
779+ resT,
780+ AddInplaceFunctor<argT, resT>,
781+ vec_sz,
782+ n_vecs,
783+ enable_sg_loadstore>;
784+
/// @brief Declaration-only class template; it is passed as the kernel-name
///        template argument by `add_inplace_scalar_contig_impl`.
template <typename argT, typename resT,
          unsigned int vec_sz, unsigned int n_vecs>
class add_inplace_scalar_contig_kernel;
790+
791+ template <typename argTy, typename resTy>
792+ sycl::event
793+ add_inplace_scalar_contig_impl (sycl::queue &exec_q,
794+ size_t nelems,
795+ const char *arg_p,
796+ ssize_t arg_offset,
797+ char *res_p,
798+ ssize_t res_offset,
799+ const std::vector<sycl::event> &depends = {})
800+ {
801+ return elementwise_common::binary_inplace_scalar_contig_impl<
802+ argTy, resTy, AddInplaceScalarContigFunctor,
803+ add_inplace_scalar_contig_kernel>(exec_q, nelems, arg_p, arg_offset,
804+ res_p, res_offset, depends);
805+ }
806+
807+ template <typename fnT, typename T1, typename T2>
808+ struct AddInplaceScalarContigFactory
809+ {
810+ fnT get ()
811+ {
812+ if constexpr (!AddInplaceTypePairSupport<T1, T2>::is_defined) {
813+ fnT fn = nullptr ;
814+ return fn;
815+ }
816+ else {
817+ fnT fn = add_inplace_scalar_contig_impl<T1, T2>;
818+ return fn;
819+ }
820+ }
821+ };
822+
609823} // namespace add
610824} // namespace kernels
611825} // namespace tensor
0 commit comments