File tree Expand file tree Collapse file tree 2 files changed +24
-4
lines changed
Expand file tree Collapse file tree 2 files changed +24
-4
lines changed Original file line number Diff line number Diff line change @@ -425,8 +425,7 @@ def unique_inverse(x):
425425 )
426426 _manager .add_event_pair (ht_ev , sub_ev )
427427
428- inv_dt = dpt .int64 if x .size > dpt .iinfo (dpt .int32 ).max else dpt .int32
429- inv = dpt .empty_like (x , dtype = inv_dt , order = "C" )
428+ inv = dpt .empty_like (x , dtype = ind_dt , order = "C" )
430429 ht_ev , ssl_ev = _searchsorted_left (
431430 hay = unique_vals ,
432431 needles = x ,
@@ -608,8 +607,7 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
608607 )
609608 _manager .add_event_pair (ht_ev , sub_ev )
610609
611- inv_dt = dpt .int64 if x .size > dpt .iinfo (dpt .int32 ).max else dpt .int32
612- inv = dpt .empty_like (x , dtype = inv_dt , order = "C" )
610+ inv = dpt .empty_like (x , dtype = ind_dt , order = "C" )
613611 ht_ev , ssl_ev = _searchsorted_left (
614612 hay = unique_vals ,
615613 needles = x ,
Original file line number Diff line number Diff line change @@ -321,3 +321,25 @@ def test_set_functions_compute_follows_data():
321321 assert ind .sycl_queue == q
322322 assert inv_ind .sycl_queue == q
323323 assert uc .sycl_queue == q
324+
325+
326+ def test_gh_1738 ():
327+ get_queue_or_skip ()
328+
329+ ones = dpt .ones (10 , dtype = "i8" )
330+ iota = dpt .arange (10 , dtype = "i8" )
331+
332+ assert ones .device == iota .device
333+
334+ dpt_info = dpt .__array_namespace_info__ ()
335+ ind_dt = dpt_info .default_dtypes (device = ones .device )["indexing" ]
336+
337+ dt = dpt .unique_inverse (ones ).inverse_indices .dtype
338+ assert dt == ind_dt
339+ dt = dpt .unique_all (ones ).inverse_indices .dtype
340+ assert dt == ind_dt
341+
342+ dt = dpt .unique_inverse (iota ).inverse_indices .dtype
343+ assert dt == ind_dt
344+ dt = dpt .unique_all (iota ).inverse_indices .dtype
345+ assert dt == ind_dt
You can’t perform that action at this time.
0 commit comments