1414#include " oneapi/mkl.hpp"
1515#include " utils.hpp"
1616
17+ #include < pybind11/stl.h>
18+
1719using namespace sycl ;
1820namespace py = pybind11;
1921
@@ -65,15 +67,15 @@ __attribute__((always_inline)) DataType bilinear(const DataType *input,
6567 return result / 2 ;
6668}
6769
68- class deform ;
70+ template < class DataType > class deform ;
6971
7072template <class DataType >
7173inline auto deform_input (cl::sycl::queue &queue,
7274 const DataType *input,
7375 const Shape3D in_shape,
7476 DataType *output,
7577 const Shape5D out_shape,
76- const DataType *offset,
78+ const float *offset,
7779 int stride_y,
7880 int stride_x,
7981 int pad_y,
@@ -94,7 +96,7 @@ inline auto deform_input(cl::sycl::queue &queue,
9496
9597 auto wsize =
9698 sycl::range<3 >(in_channels * k_height * k_width, out_height, out_width);
97- return queue.parallel_for <deform>(wsize, [=](sycl::id<3 > idx) {
99+ return queue.parallel_for <deform<DataType> >(wsize, [=](sycl::id<3 > idx) {
98100 auto ckhkw = static_cast <int >(idx[0 ]);
99101 auto h = static_cast <int >(idx[1 ]);
100102 auto w = static_cast <int >(idx[2 ]);
@@ -127,19 +129,19 @@ inline auto deform_input(cl::sycl::queue &queue,
127129 });
128130}
129131
130- class fill_output ;
132+ template < class DataType > class fill_output ;
131133
132134template <class DataType >
133135auto output_fill_with_bias (cl::sycl::queue &queue,
134136 DataType *output,
135137 const Shape3D out_shape,
136- DataType *bias)
138+ const DataType *bias)
137139{
138140 auto out_c = out_shape[CHW::C];
139141 auto out_h = out_shape[CHW::H];
140142 auto out_w = out_shape[CHW::W];
141143
142- return queue.parallel_for <fill_output>(
144+ return queue.parallel_for <fill_output<DataType> >(
143145 sycl::range<3 >(out_c, out_h, out_w), [=](sycl::id<3 > idx) {
144146 auto c = static_cast <int >(idx[0 ]);
145147 auto h = static_cast <int >(idx[1 ]);
@@ -158,9 +160,9 @@ void deformable_convolution_b1_impl(cl::sycl::queue &queue,
158160 const Shape3D out_shape,
159161 DataType *tmp,
160162 const float *offset,
161- DataType *weights,
163+ const DataType *weights,
162164 const Shape4D weights_shape,
163- DataType *bias,
165+ const DataType *bias,
164166 int stride_y,
165167 int stride_x,
166168 int pad_y,
@@ -210,9 +212,9 @@ void deformable_convolution_impl(cl::sycl::queue &queue,
210212 const Shape4D out_shape,
211213 DataType *tmp,
212214 const float *offset,
213- DataType *weights,
215+ const DataType *weights,
214216 const Shape4D weights_shape,
215- DataType *bias,
217+ const DataType *bias,
216218 int stride_y,
217219 int stride_x,
218220 int pad_y,
@@ -247,25 +249,45 @@ void deformable_convolution_impl(cl::sycl::queue &queue,
247249 }
248250}
249251
252+ template <typename ... Args> bool ensure_compatibility (const Args &...args)
253+ {
254+ std::vector<dpctl::tensor::usm_ndarray> arrays = {args...};
255+
256+ auto arr = arrays.at (0 );
257+
258+ for (auto &arr : arrays) {
259+ if (!(arr.get_flags () & (USM_ARRAY_C_CONTIGUOUS))) {
260+ std::cerr << " All arrays need to be C contiguous.\n " ;
261+ return false ;
262+ }
263+ }
264+ return true ;
265+ }
266+
250267void deformable_convolution (dpctl::tensor::usm_ndarray input,
251268 dpctl::tensor::usm_ndarray output,
252269 dpctl::tensor::usm_ndarray offset,
253270 dpctl::tensor::usm_ndarray weights,
254271 dpctl::tensor::usm_ndarray bias,
255272 dpctl::tensor::usm_ndarray tmp,
256- int stride_y,
257- int stride_x,
258- int pad_y,
259- int pad_x,
260- int dilation_y,
261- int dilation_x,
273+ std::vector<int > stride_hw,
274+ std::vector<int > pad_hw,
275+ std::vector<int > dilation_hw,
262276 int groups,
263277 int deformable_groups)
264278{
265279 auto queue = input.get_queue ();
266280
267- if (input.get_typenum () != UAR_FLOAT) {
268- throw std::runtime_error (" Expected a single precision FP array." );
281+ if (!ensure_compatibility (input, output, offset, weights, bias, tmp))
282+ throw std::runtime_error (" Input arrays are not acceptable." );
283+
284+ if (input.get_typenum () != output.get_typenum () or
285+ input.get_typenum () != offset.get_typenum () or
286+ input.get_typenum () != weights.get_typenum () or
287+ input.get_typenum () != bias.get_typenum () or
288+ input.get_typenum () != tmp.get_typenum ())
289+ {
290+ throw std::runtime_error (" All arrays must have the same precision" );
269291 }
270292
271293 int batch = input.get_shape (0 );
@@ -281,17 +303,39 @@ void deformable_convolution(dpctl::tensor::usm_ndarray input,
281303 int kernel_height = weights.get_shape (2 );
282304 int kernel_width = weights.get_shape (3 );
283305
306+ auto stride_y = stride_hw[0 ];
307+ auto stride_x = stride_hw[1 ];
308+
309+ auto pad_y = pad_hw[0 ];
310+ auto pad_x = pad_hw[1 ];
311+
312+ auto dilation_y = pad_hw[0 ];
313+ auto dilation_x = pad_hw[1 ];
314+
284315 auto input_shape = Shape4D ({batch, in_channels, in_height, in_width});
285316 auto output_shape = Shape4D ({batch, out_channels, out_height, out_width});
286317 auto weights_shape =
287318 Shape4D ({out_channels, in_channels, kernel_height, kernel_width});
288319
289- deformable_convolution_impl (
290- queue, input.get_data <float >(), input_shape, output.get_data <float >(),
291- output_shape, tmp.get_data <float >(), offset.get_data <float >(),
292- weights.get_data <float >(), weights_shape, bias.get_data <float >(),
293- stride_y, stride_x, pad_y, pad_x, dilation_y, dilation_x, groups,
294- deformable_groups);
320+ #define dispatch_dc (typ ) \
321+ deformable_convolution_impl<typ>( \
322+ queue, input.get_data <typ>(), input_shape, output.get_data <typ>(), \
323+ output_shape, tmp.get_data <typ>(), offset.get_data <float >(), \
324+ weights.get_data <typ>(), weights_shape, bias.get_data <typ>(), \
325+ stride_y, stride_x, pad_y, pad_x, dilation_y, dilation_x, groups, \
326+ deformable_groups)
327+
328+ if (input.get_typenum () == UAR_FLOAT) {
329+ dispatch_dc (float );
330+ }
331+ else if (input.get_typenum () == UAR_DOUBLE) {
332+ dispatch_dc (double );
333+ }
334+ else {
335+ throw std::runtime_error (" Unsupported type" );
336+ }
337+
338+ #undef dispatch_dc
295339}
296340
297341PYBIND11_MODULE (_deformable_convolution_sycl, m)
@@ -301,8 +345,7 @@ PYBIND11_MODULE(_deformable_convolution_sycl, m)
301345 m.def (" deformable_convolution" , &deformable_convolution,
302346 " Defromable convolution" , py::arg (" input" ), py::arg (" output" ),
303347 py::arg (" offset" ), py::arg (" weights" ), py::arg (" bias" ),
304- py::arg (" tmp" ), py::arg (" stride_y" ), py::arg (" stride_x" ),
305- py::arg (" pad_y" ), py::arg (" pad_x" ), py::arg (" dilation_y" ),
306- py::arg (" dilation_x" ), py::arg (" groups" ),
348+ py::arg (" tmp" ), py::arg (" stride_hw" ), py::arg (" pad_hw" ),
349+ py::arg (" dilation_hw" ), py::arg (" groups" ),
307350 py::arg (" deformable_groups" ));
308351}
0 commit comments