Skip to content

Commit c5bddb3

Browse files
Some cleanup
1 parent e7face2 commit c5bddb3

File tree

4 files changed

+169
-135
lines changed

4 files changed

+169
-135
lines changed

dpbench/benchmarks/deformable_convolution/deformable_convolution_initialize.py

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,45 +5,49 @@
55

66
def initialize(
77
batch,
8-
in_channels,
9-
in_height,
10-
in_width,
11-
out_channels,
12-
out_height,
13-
out_width,
14-
kernel_height,
15-
kernel_width,
16-
stride_y,
17-
stride_x,
18-
dilation_y,
19-
dilation_x,
20-
pad_y,
21-
pad_x,
8+
in_chw,
9+
out_chw,
10+
kernel_hw,
11+
stride_hw,
12+
dilation_hw,
13+
pad_hw,
2214
groups,
2315
deformable_groups,
24-
dtype,
2516
seed,
17+
types_dict,
2618
):
2719
import numpy as np
2820
import numpy.random as default_rng
2921

22+
dtype: np.dtype = types_dict["float"]
23+
3024
default_rng.seed(seed)
3125

26+
input_size = [batch] + in_chw # nchw
27+
output_size = [batch] + out_chw # nchw
28+
offset_size = kernel_hw + [2, out_chw[1], out_chw[2]] # kh, kw, 2, oh, ow
29+
weights_size = [out_chw[0], in_chw[0]] + kernel_hw # oc, ic, kh, kw
30+
bias_size = out_chw[0] # oc
31+
tmp_size = [
32+
in_chw[0],
33+
kernel_hw[0],
34+
kernel_hw[1],
35+
out_chw[1],
36+
out_chw[2],
37+
] # ic, kh, kw, oh, ow
38+
39+
input = default_rng.random(input_size).astype(dtype)
40+
output = np.empty(output_size, dtype=dtype)
41+
offset = 2 * default_rng.random(offset_size).astype("float32") - 1
42+
weights = default_rng.random(weights_size).astype(dtype)
43+
bias = default_rng.random(bias_size).astype(dtype)
44+
tmp = np.empty(tmp_size, dtype=dtype)
45+
3246
return (
33-
default_rng.random((batch, in_channels, in_height, in_width)).astype(
34-
dtype
35-
),
36-
np.zeros((batch, out_channels, out_height, out_width)).astype(dtype),
37-
2
38-
* default_rng.random(
39-
(kernel_height, kernel_width, 2, out_height, out_width)
40-
).astype(dtype)
41-
- 1,
42-
np.ones(
43-
(out_channels, in_channels, kernel_height, kernel_width)
44-
).astype(dtype),
45-
default_rng.random(out_channels).astype(dtype),
46-
np.zeros(
47-
(in_channels, kernel_height, kernel_width, out_height, out_width)
48-
).astype(dtype),
47+
input,
48+
output,
49+
offset,
50+
weights,
51+
bias,
52+
tmp,
4953
)

dpbench/benchmarks/deformable_convolution/deformable_convolution_numba_mlir_p.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -108,19 +108,16 @@ def deformable_convolution_b1(
108108

109109

110110
@njit(parallel=True, gpu_fp64_truncate="auto")
111-
def deformable_convolution(
111+
def jdeformable_convolution(
112112
input,
113113
output,
114114
offset,
115115
weights,
116116
bias,
117117
tmp,
118-
stride_y,
119-
stride_x,
120-
pad_y,
121-
pad_x,
122-
dilation_y,
123-
dilation_x,
118+
stride,
119+
pad,
120+
dilation,
124121
groups,
125122
deformable_groups,
126123
):
@@ -133,9 +130,37 @@ def deformable_convolution(
133130
weights,
134131
bias,
135132
tmp,
136-
(stride_y, stride_x),
137-
(pad_y, pad_x),
138-
(dilation_y, dilation_x),
133+
stride,
134+
pad,
135+
dilation,
139136
groups,
140137
deformable_groups,
141138
)
139+
140+
141+
# Plain-Python wrapper around the @njit-compiled jdeformable_convolution.
# It forwards every argument unchanged except stride_hw, pad_hw and
# dilation_hw, which are converted to tuples first. Nothing is returned.
def deformable_convolution(
142+
input,
143+
output,
144+
offset,
145+
weights,
146+
bias,
147+
tmp,
148+
stride_hw,
149+
pad_hw,
150+
dilation_hw,
151+
groups,
152+
deformable_groups,
153+
):
154+
jdeformable_convolution(
155+
input,
156+
output,
157+
offset,
158+
weights,
159+
bias,
160+
tmp,
161+
# The three *_hw arguments are converted with tuple() before the call.
# NOTE(review): presumably because the jitted function expects tuples
# rather than lists - confirm against the caller.
tuple(stride_hw),
162+
tuple(pad_hw),
163+
tuple(dilation_hw),
164+
groups,
165+
deformable_groups,
166+
)

dpbench/benchmarks/deformable_convolution/deformable_convolution_sycl_native_ext/deformable_convolution_sycl/impl.cpp

Lines changed: 70 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include "oneapi/mkl.hpp"
1515
#include "utils.hpp"
1616

17+
#include <pybind11/stl.h>
18+
1719
using namespace sycl;
1820
namespace py = pybind11;
1921

@@ -65,15 +67,15 @@ __attribute__((always_inline)) DataType bilinear(const DataType *input,
6567
return result / 2;
6668
}
6769

68-
class deform;
70+
template <class DataType> class deform;
6971

7072
template <class DataType>
7173
inline auto deform_input(cl::sycl::queue &queue,
7274
const DataType *input,
7375
const Shape3D in_shape,
7476
DataType *output,
7577
const Shape5D out_shape,
76-
const DataType *offset,
78+
const float *offset,
7779
int stride_y,
7880
int stride_x,
7981
int pad_y,
@@ -94,7 +96,7 @@ inline auto deform_input(cl::sycl::queue &queue,
9496

9597
auto wsize =
9698
sycl::range<3>(in_channels * k_height * k_width, out_height, out_width);
97-
return queue.parallel_for<deform>(wsize, [=](sycl::id<3> idx) {
99+
return queue.parallel_for<deform<DataType>>(wsize, [=](sycl::id<3> idx) {
98100
auto ckhkw = static_cast<int>(idx[0]);
99101
auto h = static_cast<int>(idx[1]);
100102
auto w = static_cast<int>(idx[2]);
@@ -127,19 +129,19 @@ inline auto deform_input(cl::sycl::queue &queue,
127129
});
128130
}
129131

130-
class fill_output;
132+
template <class DataType> class fill_output;
131133

132134
template <class DataType>
133135
auto output_fill_with_bias(cl::sycl::queue &queue,
134136
DataType *output,
135137
const Shape3D out_shape,
136-
DataType *bias)
138+
const DataType *bias)
137139
{
138140
auto out_c = out_shape[CHW::C];
139141
auto out_h = out_shape[CHW::H];
140142
auto out_w = out_shape[CHW::W];
141143

142-
return queue.parallel_for<fill_output>(
144+
return queue.parallel_for<fill_output<DataType>>(
143145
sycl::range<3>(out_c, out_h, out_w), [=](sycl::id<3> idx) {
144146
auto c = static_cast<int>(idx[0]);
145147
auto h = static_cast<int>(idx[1]);
@@ -158,9 +160,9 @@ void deformable_convolution_b1_impl(cl::sycl::queue &queue,
158160
const Shape3D out_shape,
159161
DataType *tmp,
160162
const float *offset,
161-
DataType *weights,
163+
const DataType *weights,
162164
const Shape4D weights_shape,
163-
DataType *bias,
165+
const DataType *bias,
164166
int stride_y,
165167
int stride_x,
166168
int pad_y,
@@ -210,9 +212,9 @@ void deformable_convolution_impl(cl::sycl::queue &queue,
210212
const Shape4D out_shape,
211213
DataType *tmp,
212214
const float *offset,
213-
DataType *weights,
215+
const DataType *weights,
214216
const Shape4D weights_shape,
215-
DataType *bias,
217+
const DataType *bias,
216218
int stride_y,
217219
int stride_x,
218220
int pad_y,
@@ -247,25 +249,45 @@ void deformable_convolution_impl(cl::sycl::queue &queue,
247249
}
248250
}
249251

252+
// Returns true when every array passed in has the USM_ARRAY_C_CONTIGUOUS flag
// set. On the first array without it, a message is written to std::cerr and
// false is returned. All arguments are collected into a vector of
// dpctl::tensor::usm_ndarray, so each one must be convertible to that type.
template <typename... Args> bool ensure_compatibility(const Args &...args)
253+
{
254+
std::vector<dpctl::tensor::usm_ndarray> arrays = {args...};
255+
258+
// The former unused `auto arr = arrays.at(0);` local was dropped: it was
// shadowed by the loop variable below and threw for an empty argument pack.
for (auto &arr : arrays) {
259+
if (!(arr.get_flags() & (USM_ARRAY_C_CONTIGUOUS))) {
260+
std::cerr << "All arrays need to be C contiguous.\n";
261+
return false;
262+
}
263+
}
264+
return true;
265+
}
266+
250267
// Host-side entry point for the deformable convolution. On the new side of
// this diff, stride, padding and dilation arrive as {h, w} vectors.
// It validates the arrays, reads the shapes, unpacks the vectors and dispatches
// to deformable_convolution_impl<float> or <double> based on input's typenum.
// Throws std::runtime_error when ensure_compatibility() rejects the arrays,
// when the arrays' typenums differ, or when the element type is neither
// UAR_FLOAT nor UAR_DOUBLE.
void deformable_convolution(dpctl::tensor::usm_ndarray input,
251268
dpctl::tensor::usm_ndarray output,
252269
dpctl::tensor::usm_ndarray offset,
253270
dpctl::tensor::usm_ndarray weights,
254271
dpctl::tensor::usm_ndarray bias,
255272
dpctl::tensor::usm_ndarray tmp,
256-
int stride_y,
257-
int stride_x,
258-
int pad_y,
259-
int pad_x,
260-
int dilation_y,
261-
int dilation_x,
273+
std::vector<int> stride_hw,
274+
std::vector<int> pad_hw,
275+
std::vector<int> dilation_hw,
262276
int groups,
263277
int deformable_groups)
264278
{
265279
auto queue = input.get_queue();
266280

267-
if (input.get_typenum() != UAR_FLOAT) {
268-
throw std::runtime_error("Expected a single precision FP array.");
281+
if (!ensure_compatibility(input, output, offset, weights, bias, tmp))
282+
throw std::runtime_error("Input arrays are not acceptable.");
283+
284+
// NOTE(review): offset is required to share input's typenum here, but the
// dispatch below always reads it with get_data<float>(). With double inputs
// the two requirements conflict - confirm the intended offset dtype.
if (input.get_typenum() != output.get_typenum() or
285+
input.get_typenum() != offset.get_typenum() or
286+
input.get_typenum() != weights.get_typenum() or
287+
input.get_typenum() != bias.get_typenum() or
288+
input.get_typenum() != tmp.get_typenum())
289+
{
290+
throw std::runtime_error("All arrays must have the same precision");
269291
}
270292

271293
int batch = input.get_shape(0);
@@ -281,17 +303,39 @@ void deformable_convolution(dpctl::tensor::usm_ndarray input,
281303
int kernel_height = weights.get_shape(2);
282304
int kernel_width = weights.get_shape(3);
283305

306+
// Index 0 is used as the y/height component and index 1 as the x/width one.
// NOTE(review): the vectors are indexed without checking that they hold two
// elements.
auto stride_y = stride_hw[0];
307+
auto stride_x = stride_hw[1];
308+
309+
auto pad_y = pad_hw[0];
310+
auto pad_x = pad_hw[1];
311+
312+
// Dilation is taken from dilation_hw. It was previously read from pad_hw,
// which left the dilation_hw argument unused.
auto dilation_y = dilation_hw[0];
313+
auto dilation_x = dilation_hw[1];
314+
284315
auto input_shape = Shape4D({batch, in_channels, in_height, in_width});
285316
auto output_shape = Shape4D({batch, out_channels, out_height, out_width});
286317
auto weights_shape =
287318
Shape4D({out_channels, in_channels, kernel_height, kernel_width});
288319

289-
deformable_convolution_impl(
290-
queue, input.get_data<float>(), input_shape, output.get_data<float>(),
291-
output_shape, tmp.get_data<float>(), offset.get_data<float>(),
292-
weights.get_data<float>(), weights_shape, bias.get_data<float>(),
293-
stride_y, stride_x, pad_y, pad_x, dilation_y, dilation_x, groups,
294-
deformable_groups);
320+
// Local helper macro: one call to the implementation for element type typ.
// Every array is read as typ except offset, which is always read as float.
#define dispatch_dc(typ) \
321+
deformable_convolution_impl<typ>( \
322+
queue, input.get_data<typ>(), input_shape, output.get_data<typ>(), \
323+
output_shape, tmp.get_data<typ>(), offset.get_data<float>(), \
324+
weights.get_data<typ>(), weights_shape, bias.get_data<typ>(), \
325+
stride_y, stride_x, pad_y, pad_x, dilation_y, dilation_x, groups, \
326+
deformable_groups)
327+
328+
if (input.get_typenum() == UAR_FLOAT) {
329+
dispatch_dc(float);
330+
}
331+
else if (input.get_typenum() == UAR_DOUBLE) {
332+
dispatch_dc(double);
333+
}
334+
else {
335+
throw std::runtime_error("Unsupported type");
336+
}
337+
338+
#undef dispatch_dc
295339
}
296340

297341
PYBIND11_MODULE(_deformable_convolution_sycl, m)
@@ -301,8 +345,7 @@ PYBIND11_MODULE(_deformable_convolution_sycl, m)
301345
m.def("deformable_convolution", &deformable_convolution,
302346
"Defromable convolution", py::arg("input"), py::arg("output"),
303347
py::arg("offset"), py::arg("weights"), py::arg("bias"),
304-
py::arg("tmp"), py::arg("stride_y"), py::arg("stride_x"),
305-
py::arg("pad_y"), py::arg("pad_x"), py::arg("dilation_y"),
306-
py::arg("dilation_x"), py::arg("groups"),
348+
py::arg("tmp"), py::arg("stride_hw"), py::arg("pad_hw"),
349+
py::arg("dilation_hw"), py::arg("groups"),
307350
py::arg("deformable_groups"));
308351
}

0 commit comments

Comments
 (0)