Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/spatialdata/_core/query/relational_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def _right_join_spatialelement_table(
element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
) -> tuple[dict[str, Any], AnnData]:
if match_rows == "left":
warnings.warn("Matching rows ``'left'`` is not supported for ``'right'`` join.", UserWarning, stacklevel=2)
warnings.warn("Matching rows 'left' is not supported for 'right' join.", UserWarning, stacklevel=2)
regions, region_column_name, instance_key = get_table_keys(table)
groups_df = table.obs.groupby(by=region_column_name)
for element_type, name_element in element_dict.items():
Expand All @@ -300,7 +300,7 @@ def _right_join_spatialelement_table(
element_indices = element.index
else:
warnings.warn(
f"Element type `labels` not supported for left exclusive join. Skipping `{name}`",
f"Element type `labels` not supported for 'right' join. Skipping `{name}`",
UserWarning,
stacklevel=2,
)
Expand Down Expand Up @@ -331,7 +331,7 @@ def _inner_join_spatialelement_table(
element_indices = element.index
else:
warnings.warn(
f"Element type `labels` not supported for left exclusive join. Skipping `{name}`",
f"Element type `labels` not supported for 'inner' join. Skipping `{name}`",
UserWarning,
stacklevel=2,
)
Expand Down Expand Up @@ -389,7 +389,7 @@ def _left_join_spatialelement_table(
element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
) -> tuple[dict[str, Any], AnnData]:
if match_rows == "right":
warnings.warn("Matching rows ``'right'`` is not supported for ``'left'`` join.", UserWarning, stacklevel=2)
warnings.warn("Matching rows 'right' is not supported for 'left' join.", UserWarning, stacklevel=2)
regions, region_column_name, instance_key = get_table_keys(table)
groups_df = table.obs.groupby(by=region_column_name)
joined_indices = None
Expand Down
58 changes: 36 additions & 22 deletions src/spatialdata/_core/spatialdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,41 +159,55 @@ def __init__(

self._query = QueryManager(self)

def validate_table_in_spatialdata(self, data: AnnData) -> None:
def validate_table_in_spatialdata(self, table: AnnData) -> None:
"""
Validate the presence of the annotation target of a SpatialData table in the SpatialData object.

This method validates a table in the SpatialData object to ensure that if annotation metadata is present, the
annotation target (SpatialElement) is present in the SpatialData object. Otherwise, a warning is raised.
annotation target (SpatialElement) is present in the SpatialData object, the dtypes of the instance key column
in the table and the annotation target do not match. Otherwise, a warning is raised.

Parameters
----------
data
table
The table potentially annotating a SpatialElement

Raises
------
UserWarning
If the table is annotating elements not present in the SpatialData object.
UserWarning
The dtypes of the instance key column in the table and the annotation target do not match.
"""
TableModel().validate(data)
element_names = [
element_name for element_type, element_name, _ in self._gen_elements() if element_type != "tables"
]
if TableModel.ATTRS_KEY in data.uns:
attrs = data.uns[TableModel.ATTRS_KEY]
regions = (
attrs[TableModel.REGION_KEY]
if isinstance(attrs[TableModel.REGION_KEY], list)
else [attrs[TableModel.REGION_KEY]]
)
# TODO: check throwing error
if not all(element_name in element_names for element_name in regions):
warnings.warn(
"The table is annotating an/some element(s) not present in the SpatialData object",
UserWarning,
stacklevel=2,
)
TableModel().validate(table)
if TableModel.ATTRS_KEY in table.uns:
region, _, instance_key = get_table_keys(table)
region = region if isinstance(region, list) else [region]
for r in region:
element = self.get(r)
if element is None:
warnings.warn(
f"The table is annotating {r!r}, which is not present in the SpatialData object.",
UserWarning,
stacklevel=2,
)
else:
if isinstance(element, (SpatialImage, MultiscaleSpatialImage)):
dtype = element.dtype
else:
dtype = element.index.dtype
if dtype != table.obs[instance_key].dtype:
warnings.warn(
(
f"Table instance_key column ({instance_key}) has a dtype "
f"({table.obs[instance_key].dtype}) that does not match the dtype of the indices of "
f"the annotated element ({dtype}). Please note in the case of int16 vs int32 or "
"similar cases may be tolerated in downstream methods, but it is recommended to make "
"the dtypes match."
),
UserWarning,
stacklevel=2,
)

@staticmethod
def from_elements_dict(elements_dict: dict[str, SpatialElement | AnnData]) -> SpatialData:
Expand Down Expand Up @@ -417,7 +431,7 @@ def set_table_annotates_spatialelement(
table = self.tables[table_name]
element_names = {element[1] for element in self._gen_elements()}
if region not in element_names:
raise ValueError(f"Annotation target '{region}' not present as SpatialElement in " f"SpatialData object.")
raise ValueError(f"Annotation target '{region}' not present as SpatialElement in SpatialData object.")

if table.uns.get(TableModel.ATTRS_KEY):
self._change_table_annotation_target(table, region, region_key, instance_key)
Expand Down
48 changes: 41 additions & 7 deletions tests/core/operations/test_spatialdata_operations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import math
import warnings

import numpy as np
import pytest
Expand All @@ -11,13 +12,7 @@
from spatialdata._core.spatialdata import SpatialData
from spatialdata._utils import _assert_spatialdata_objects_seem_identical, _assert_tables_seem_identical
from spatialdata.datasets import blobs
from spatialdata.models import (
Image2DModel,
Labels2DModel,
PointsModel,
ShapesModel,
TableModel,
)
from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, ShapesModel, TableModel, get_table_keys
from spatialdata.transformations.operations import get_transformation, set_transformation
from spatialdata.transformations.transformations import (
Affine,
Expand Down Expand Up @@ -417,3 +412,42 @@ def test_transform_to_data_extent(full_sdata: SpatialData, maintain_positioning:
assert are_extents_equal(
data_extent_before, data_extent_after, atol=3
), f"data_extent_before: {data_extent_before}, data_extent_after: {data_extent_after} for element {element}"


def test_validate_table_in_spatialdata(full_sdata):
table = full_sdata["table"]
region, region_key, _ = get_table_keys(table)
assert region == "labels2d"

# no warnings
with warnings.catch_warnings():
warnings.simplefilter("error")
full_sdata.validate_table_in_spatialdata(table)

# dtype mismatch
full_sdata.labels["labels2d"] = Labels2DModel.parse(full_sdata.labels["labels2d"].astype("int16"))
with pytest.warns(UserWarning, match="that does not match the dtype of the indices of the annotated element"):
full_sdata.validate_table_in_spatialdata(table)

# region not found
del full_sdata.labels["labels2d"]
with pytest.warns(UserWarning, match="in the SpatialData object"):
full_sdata.validate_table_in_spatialdata(table)

table.obs[region_key] = "points_0"
full_sdata.set_table_annotates_spatialelement("table", region="points_0")

# no warnings
with warnings.catch_warnings():
warnings.simplefilter("error")
full_sdata.validate_table_in_spatialdata(table)

# dtype mismatch
full_sdata.points["points_0"].index = full_sdata.points["points_0"].index.astype("int16")
with pytest.warns(UserWarning, match="that does not match the dtype of the indices of the annotated element"):
full_sdata.validate_table_in_spatialdata(table)

# region not found
del full_sdata.points["points_0"]
with pytest.warns(UserWarning, match="in the SpatialData object"):
full_sdata.validate_table_in_spatialdata(table)
12 changes: 5 additions & 7 deletions tests/io/test_multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_change_annotation_target(self, full_sdata, region_key, instance_key, er
def test_set_table_nonexisting_target(self, full_sdata):
with pytest.raises(
ValueError,
match="Annotation target 'non_existing' not present as SpatialElement in " "SpatialData object.",
match="Annotation target 'non_existing' not present as SpatialElement in SpatialData object.",
):
full_sdata.set_table_annotates_spatialelement("table", "non_existing")

Expand Down Expand Up @@ -150,9 +150,8 @@ def test_single_table(self, tmp_path: str, region: str):
}

if region == "non_existing":
with pytest.warns(
UserWarning, match=r"The table is annotating an/some element\(s\) not present in the SpatialData object"
):
# annotation target not present in the SpatialData object
with pytest.warns(UserWarning, match=r", which is not present in the SpatialData object"):
SpatialData(
shapes=shapes_dict,
tables={"shape_annotate": table},
Expand Down Expand Up @@ -189,9 +188,8 @@ def test_paired_elements_tables(self, tmp_path: str):
table = _get_table(region="poly")
table2 = _get_table(region="multipoly")
table3 = _get_table(region="non_existing")
with pytest.warns(
UserWarning, match=r"The table is annotating an/some element\(s\) not present in the SpatialData object"
):
# annotation target not present in the SpatialData object
with pytest.warns(UserWarning, match=r", which is not present in the SpatialData object"):
SpatialData(
shapes={"poly": test_shapes["poly"], "multipoly": test_shapes["multipoly"]},
table={"poly_annotate": table, "multipoly_annotate": table3},
Expand Down