@@ -144,7 +144,7 @@ def _reduction_over_axis(
144144def sum (x , axis = None , dtype = None , keepdims = False ):
145145 """sum(x, axis=None, dtype=None, keepdims=False)
146146
147- Calculates the sum of the input array `x`.
147+ Calculates the sum of elements in the input array `x`.
148148
149149 Args:
150150 x (usm_ndarray):
@@ -202,6 +202,67 @@ def sum(x, axis=None, dtype=None, keepdims=False):
202202 )
203203
204204
205+ def prod (x , axis = None , dtype = None , keepdims = False ):
206+ """prod(x, axis=None, dtype=None, keepdims=False)
207+
208+ Calculates the product of elements in the input array `x`.
209+
210+ Args:
211+ x (usm_ndarray):
212+ input array.
213+ axis (Optional[int, Tuple[int,...]]):
214+ axis or axes along which sums must be computed. If a tuple
215+ of unique integers, sums are computed over multiple axes.
216+ If `None`, the sum is computed over the entire array.
217+ Default: `None`.
218+ dtype (Optional[dtype]):
219+ data type of the returned array. If `None`, the default data
220+ type is inferred from the "kind" of the input array data type.
221+ * If `x` has a real-valued floating-point data type,
222+ the returned array will have the default real-valued
223+ floating-point data type for the device where input
224+ array `x` is allocated.
225+ * If x` has signed integral data type, the returned array
226+ will have the default signed integral type for the device
227+ where input array `x` is allocated.
228+ * If `x` has unsigned integral data type, the returned array
229+ will have the default unsigned integral type for the device
230+ where input array `x` is allocated.
231+ * If `x` has a complex-valued floating-point data typee,
232+ the returned array will have the default complex-valued
233+ floating-pointer data type for the device where input
234+ array `x` is allocated.
235+ * If `x` has a boolean data type, the returned array will
236+ have the default signed integral type for the device
237+ where input array `x` is allocated.
238+ If the data type (either specified or resolved) differs from the
239+ data type of `x`, the input array elements are cast to the
240+ specified data type before computing the sum. Default: `None`.
241+ keepdims (Optional[bool]):
242+ if `True`, the reduced axes (dimensions) are included in the result
243+ as singleton dimensions, so that the returned array remains
244+ compatible with the input arrays according to Array Broadcasting
245+ rules. Otherwise, if `False`, the reduced axes are not included in
246+ the returned array. Default: `False`.
247+ Returns:
248+ usm_ndarray:
249+ an array containing the products. If the product was computed over
250+ the entire array, a zero-dimensional array is returned. The returned
251+ array has the data type as described in the `dtype` parameter
252+ description above.
253+ """
254+ return _reduction_over_axis (
255+ x ,
256+ axis ,
257+ dtype ,
258+ keepdims ,
259+ ti ._prod_over_axis ,
260+ ti ._prod_over_axis_dtype_supported ,
261+ _default_reduction_dtype ,
262+ _identity = 0 ,
263+ )
264+
265+
205266def _comparison_over_axis (x , axis , keepdims , _reduction_fn ):
206267 if not isinstance (x , dpt .usm_ndarray ):
207268 raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
0 commit comments