Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning][].

## [0.x.x] - 2024-xx-xx

- Add `return_background` as argument to `get_centroids` and `get_element_instances` #621

## [0.2.1] - 2024-07-04

### Minor
Expand Down
8 changes: 7 additions & 1 deletion src/spatialdata/_core/centroids.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def _validate_coordinate_system(e: SpatialElement, coordinate_system: str) -> No
def get_centroids(
e: SpatialElement,
coordinate_system: str = "global",
return_background: bool = False,
) -> DaskDataFrame:
"""
Get the centroids of the geometries contained in a SpatialElement, as a new Points element.
Expand All @@ -45,6 +46,8 @@ def get_centroids(
The SpatialElement. Only points, shapes (circles, polygons and multipolygons) and labels are supported.
coordinate_system
The coordinate system in which the centroids are computed.
return_background
If True, the centroid of the background label (0) is included in the output.

Notes
-----
Expand All @@ -69,7 +72,7 @@ def _get_centroids_for_axis(xdata: xr.DataArray, axis: str) -> pd.DataFrame:
-------
pd.DataFrame
A DataFrame containing one column, named after "axis", with the centroids of the labels along that axis.
The index of the DataFrame is the collection of label values, sorted ascendingly.
The index of the DataFrame is the collection of label values, sorted in ascending order.
"""
centroids: dict[int, float] = defaultdict(float)
for i in xdata[axis]:
Expand All @@ -95,6 +98,7 @@ def _get_centroids_for_axis(xdata: xr.DataArray, axis: str) -> pd.DataFrame:
def _(
e: DataArray | DataTree,
coordinate_system: str = "global",
return_background: bool = False,
) -> DaskDataFrame:
"""Get the centroids of a Labels element (2D or 3D)."""
model = get_model(e)
Expand All @@ -110,6 +114,8 @@ def _(
for axis in get_axes_names(e):
dfs.append(_get_centroids_for_axis(e, axis))
df = pd.concat(dfs, axis=1)
if not return_background and 0 in df.index:
df = df.drop(index=0) # drop the background label
t = get_transformation(e, coordinate_system)
centroids = PointsModel.parse(df, transformations={coordinate_system: t})
return transform(centroids, to_coordinate_system=coordinate_system)
Expand Down
12 changes: 10 additions & 2 deletions src/spatialdata/_core/query/relational_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def _filter_table_by_element_names(table: AnnData | None, element_names: str | l
@singledispatch
def get_element_instances(
element: SpatialElement,
return_background: bool = False,
) -> pd.Index:
"""
Get the instances (index values) of the SpatialElement.
Expand All @@ -94,6 +95,8 @@ def get_element_instances(
----------
element
The SpatialElement.
return_background
If True, the background label (0) is included in the output.

Returns
-------
Expand All @@ -106,6 +109,7 @@ def get_element_instances(
@get_element_instances.register(DataTree)
def _(
element: DataArray | DataTree,
return_background: bool = False,
) -> pd.Index:
model = get_model(element)
assert model in [Labels2DModel, Labels3DModel], "Expected a `Labels` element. Found an `Image` instead."
Expand All @@ -119,7 +123,10 @@ def _(
xdata = next(iter(v))
# can be slow
instances = da.unique(xdata.data).compute()
return pd.Index(np.sort(instances))
index = pd.Index(np.sort(instances))
if not return_background and 0 in index:
return index.drop(0) # drop the background label
return index


@get_element_instances.register(GeoDataFrame)
Expand Down Expand Up @@ -568,7 +575,8 @@ def join_spatialelement_table(
both the SpatialElement and table.

For Points and Shapes elements every valid join for argument how is supported. For Labels elements only
the ``'left'`` and ``'right_exclusive'`` joins are supported.
the ``'left'`` and ``'right_exclusive'`` joins are supported.
For Labels, the background label (0) is not included in the output and it will not be returned.

Parameters
----------
Expand Down
8 changes: 5 additions & 3 deletions tests/core/query/test_relational_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,14 +699,15 @@ def test_labels_table_joins(full_sdata):
table_name="table",
how="left",
)
assert all(table.obs["instance_id"] == range(100))

assert all(table.obs["instance_id"] == range(1, 100))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here the labels has values 0,100 but the 0 element is background, so is effectively filtered in the left/right joins when get_element_instances is called on labels. Hence the indices of the instances are different.


full_sdata["table"].obs["instance_id"] = list(reversed(range(100)))

element_dict, table = join_spatialelement_table(
sdata=full_sdata, spatial_element_names="labels2d", table_name="table", how="left", match_rows="left"
)
assert all(table.obs["instance_id"] == range(100))
assert all(table.obs["instance_id"] == range(1, 100))

with pytest.warns(UserWarning, match="Element type"):
join_spatialelement_table(
Expand All @@ -724,7 +725,8 @@ def test_labels_table_joins(full_sdata):
sdata=full_sdata, spatial_element_names="labels2d", table_name="table", how="right_exclusive"
)
assert element_dict["labels2d"] is None
assert table is None
assert len(table) == 1
assert all(table.obs["instance_id"] == 0) # the background value, which is filtered out effectively


def test_points_table_joins(full_sdata):
Expand Down
14 changes: 11 additions & 3 deletions tests/core/test_centroids.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@ def test_get_centroids_shapes(shapes, coordinate_system: str, shapes_name: str):
@pytest.mark.parametrize("coordinate_system", ["global", "aligned"])
@pytest.mark.parametrize("is_multiscale", [False, True])
@pytest.mark.parametrize("is_3d", [False, True])
def test_get_centroids_labels(labels, coordinate_system: str, is_multiscale: bool, is_3d: bool):
@pytest.mark.parametrize("return_background", [False, True])
def test_get_centroids_labels(
labels, coordinate_system: str, is_multiscale: bool, is_3d: bool, return_background: bool
):
scale_factors = [2] if is_multiscale else None
if is_3d:
model = Labels3DModel
Expand All @@ -124,6 +127,8 @@ def test_get_centroids_labels(labels, coordinate_system: str, is_multiscale: boo
},
index=[0, 1, 2],
)
if not return_background:
expected_centroids = expected_centroids.drop(index=0)
else:
array = np.array(
[
Expand All @@ -145,11 +150,14 @@ def test_get_centroids_labels(labels, coordinate_system: str, is_multiscale: boo

if coordinate_system == "aligned":
set_transformation(element, transformation=affine, to_coordinate_system=coordinate_system)
centroids = get_centroids(element, coordinate_system=coordinate_system)
centroids = get_centroids(element, coordinate_system=coordinate_system, return_background=return_background)

labels_indices = get_element_instances(element)
labels_indices = get_element_instances(element, return_background=return_background)
assert np.array_equal(centroids.index.values, labels_indices)

if not return_background:
assert 0 not in centroids.index

if coordinate_system == "global":
assert np.array_equal(centroids.compute().values, expected_centroids.values)
else:
Expand Down