@@ -52,55 +52,16 @@ def _default_reduction_dtype(inp_dt, q):
5252 return res_dt
5353
5454
55- def sum (x , axis = None , dtype = None , keepdims = False ):
56- """sum(x, axis=None, dtype=None, keepdims=False)
57-
58- Calculates the sum of the input array `x`.
59-
60- Args:
61- x (usm_ndarray):
62- input array.
63- axis (Optional[int, Tuple[int,...]]):
64- axis or axes along which sums must be computed. If a tuple
65- of unique integers, sums are computed over multiple axes.
66- If `None`, the sum if computed over the entire array.
67- Default: `None`.
68- dtype (Optional[dtype]):
69- data type of the returned array. If `None`, the default data
70- type is inferred from the "kind" of the input array data type.
71- * If `x` has a real-valued floating-point data type,
72- the returned array will have the default real-valued
73- floating-point data type for the device where input
74- array `x` is allocated.
75- * If x` has signed integral data type, the returned array
76- will have the default signed integral type for the device
77- where input array `x` is allocated.
78- * If `x` has unsigned integral data type, the returned array
79- will have the default unsigned integral type for the device
80- where input array `x` is allocated.
81- * If `x` has a complex-valued floating-point data typee,
82- the returned array will have the default complex-valued
83- floating-pointer data type for the device where input
84- array `x` is allocated.
85- * If `x` has a boolean data type, the returned array will
86- have the default signed integral type for the device
87- where input array `x` is allocated.
88- If the data type (either specified or resolved) differs from the
89- data type of `x`, the input array elements are cast to the
90- specified data type before computing the sum. Default: `None`.
91- keepdims (Optional[bool]):
92- if `True`, the reduced axes (dimensions) are included in the result
93- as singleton dimensions, so that the returned array remains
94- compatible with the input arrays according to Array Broadcasting
95- rules. Otherwise, if `False`, the reduced axes are not included in
96- the returned array. Default: `False`.
97- Returns:
98- usm_ndarray:
99- an array containing the sums. If the sum was computed over the
100- entire array, a zero-dimensional array is returned. The returned
101- array has the data type as described in the `dtype` parameter
102- description above.
103- """
55+ def _reduction_over_axis (
56+ x ,
57+ axis ,
58+ dtype ,
59+ keepdims ,
60+ _reduction_fn ,
61+ _dtype_supported ,
62+ _default_reduction_type_fn ,
63+ _identity = None ,
64+ ):
10465 if not isinstance (x , dpt .usm_ndarray ):
10566 raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
10667 nd = x .ndim
@@ -116,29 +77,36 @@ def sum(x, axis=None, dtype=None, keepdims=False):
11677 q = x .sycl_queue
11778 inp_dt = x .dtype
11879 if dtype is None :
119- res_dt = _default_reduction_dtype (inp_dt , q )
80+ res_dt = _default_reduction_type_fn (inp_dt , q )
12081 else :
12182 res_dt = dpt .dtype (dtype )
12283 res_dt = _to_device_supported_dtype (res_dt , q .sycl_device )
12384
12485 res_usm_type = x .usm_type
12586 if x .size == 0 :
126- if keepdims :
127- res_shape = res_shape + (1 ,) * red_nd
128- inv_perm = sorted (range (nd ), key = lambda d : perm [d ])
129- res_shape = tuple (res_shape [i ] for i in inv_perm )
130- return dpt .zeros (
131- res_shape , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
132- )
87+ if _identity is None :
88+ raise ValueError ("reduction does not support zero-size arrays" )
89+ else :
90+ if keepdims :
91+ res_shape = res_shape + (1 ,) * red_nd
92+ inv_perm = sorted (range (nd ), key = lambda d : perm [d ])
93+ res_shape = tuple (res_shape [i ] for i in inv_perm )
94+ return dpt .full (
95+ res_shape ,
96+ _identity ,
97+ dtype = res_dt ,
98+ usm_type = res_usm_type ,
99+ sycl_queue = q ,
100+ )
133101 if red_nd == 0 :
134102 return dpt .astype (x , res_dt , copy = False )
135103
136104 host_tasks_list = []
137- if ti . _sum_over_axis_dtype_supported (inp_dt , res_dt , res_usm_type , q ):
105+ if _dtype_supported (inp_dt , res_dt , res_usm_type , q ):
138106 res = dpt .empty (
139107 res_shape , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
140108 )
141- ht_e , _ = ti ._sum_over_axis (
109+ ht_e , _ = _reduction_fn (  # use the kernel passed in as a parameter
142110 src = arr2 , trailing_dims_to_reduce = red_nd , dst = res , sycl_queue = q
143111 )
144112 host_tasks_list .append (ht_e )
@@ -152,7 +120,7 @@ def sum(x, axis=None, dtype=None, keepdims=False):
152120 tmp = dpt .empty (
153121 res_shape , dtype = tmp_dt , usm_type = res_usm_type , sycl_queue = q
154122 )
155- ht_e_tmp , r_e = ti ._sum_over_axis (
123+ ht_e_tmp , r_e = _reduction_fn (  # use the kernel passed in as a parameter
156124 src = arr2 , trailing_dims_to_reduce = red_nd , dst = tmp , sycl_queue = q
157125 )
158126 host_tasks_list .append (ht_e_tmp )
@@ -173,6 +141,67 @@ def sum(x, axis=None, dtype=None, keepdims=False):
173141 return res
174142
175143
144+ def sum (x , axis = None , dtype = None , keepdims = False ):
145+ """sum(x, axis=None, dtype=None, keepdims=False)
146+
147+ Calculates the sum of the input array `x`.
148+
149+ Args:
150+ x (usm_ndarray):
151+ input array.
152+ axis (Optional[int, Tuple[int,...]]):
153+ axis or axes along which sums must be computed. If a tuple
154+ of unique integers, sums are computed over multiple axes.
155+ If `None`, the sum is computed over the entire array.
156+ Default: `None`.
157+ dtype (Optional[dtype]):
158+ data type of the returned array. If `None`, the default data
159+ type is inferred from the "kind" of the input array data type.
160+ * If `x` has a real-valued floating-point data type,
161+ the returned array will have the default real-valued
162+ floating-point data type for the device where input
163+ array `x` is allocated.
164+ * If `x` has signed integral data type, the returned array
165+ will have the default signed integral type for the device
166+ where input array `x` is allocated.
167+ * If `x` has unsigned integral data type, the returned array
168+ will have the default unsigned integral type for the device
169+ where input array `x` is allocated.
170+ * If `x` has a complex-valued floating-point data type,
171+ the returned array will have the default complex-valued
172+ floating-point data type for the device where input
173+ array `x` is allocated.
174+ * If `x` has a boolean data type, the returned array will
175+ have the default signed integral type for the device
176+ where input array `x` is allocated.
177+ If the data type (either specified or resolved) differs from the
178+ data type of `x`, the input array elements are cast to the
179+ specified data type before computing the sum. Default: `None`.
180+ keepdims (Optional[bool]):
181+ if `True`, the reduced axes (dimensions) are included in the result
182+ as singleton dimensions, so that the returned array remains
183+ compatible with the input arrays according to Array Broadcasting
184+ rules. Otherwise, if `False`, the reduced axes are not included in
185+ the returned array. Default: `False`.
186+ Returns:
187+ usm_ndarray:
188+ an array containing the sums. If the sum was computed over the
189+ entire array, a zero-dimensional array is returned. The returned
190+ array has the data type as described in the `dtype` parameter
191+ description above.
192+ """
193+ return _reduction_over_axis (
194+ x ,
195+ axis ,
196+ dtype ,
197+ keepdims ,
198+ ti ._sum_over_axis ,
199+ ti ._sum_over_axis_dtype_supported ,
200+ _default_reduction_dtype ,
201+ _identity = 0 ,
202+ )
203+
204+
176205def _comparison_over_axis (x , axis , keepdims , _reduction_fn ):
177206 if not isinstance (x , dpt .usm_ndarray ):
178207 raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
0 commit comments