Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,6 @@ spatialdata-sandbox

# version file
_version.py

# other
node_modules/
129 changes: 97 additions & 32 deletions src/spatialdata/_core/query/spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def _get_bounding_box_corners_in_intrinsic_coordinates(
target_coordinate_system
The coordinate system the bounding box is defined in.

Returns ------- All the corners of the bounding box in the intrinsic coordinate system of the element. The shape
Returns
-------
All the corners of the bounding box in the intrinsic coordinate system of the element. The shape
is (2, 4) when axes has 2 spatial dimensions, and (2, 8) when axes has 3 spatial dimensions.

The axes of the intrinsic coordinate system.
Expand All @@ -73,6 +75,12 @@ def _get_bounding_box_corners_in_intrinsic_coordinates(
# get the transformation from the element's intrinsic coordinate system
# to the query coordinate space
transform_to_query_space = get_transformation(element, to_coordinate_system=target_coordinate_system)
m_without_c, input_axes_without_c, output_axes_without_c = _get_axes_of_tranformation(
element, target_coordinate_system
)
axes, min_coordinate, max_coordinate = _adjust_bounding_box_to_real_axes(
axes, min_coordinate, max_coordinate, output_axes_without_c
)

# get the coordinates of the bounding box corners
bounding_box_corners = get_bounding_box_corners(
Expand Down Expand Up @@ -155,7 +163,7 @@ def _bounding_box_mask_points(
min_coordinate: list[Number] | ArrayLike,
max_coordinate: list[Number] | ArrayLike,
) -> da.Array:
"""Compute a mask that is true for the points inside of an axis-aligned bounding box..
"""Compute a mask that is true for the points inside an axis-aligned bounding box.

Parameters
----------
Expand All @@ -164,23 +172,26 @@ def _bounding_box_mask_points(
axes
The axes that min_coordinate and max_coordinate refer to.
min_coordinate
The upper left hand corner of the bounding box (i.e., minimum coordinates
along all dimensions).
The upper left hand corner of the bounding box (i.e., minimum coordinates along all dimensions).
max_coordinate
The lower right hand corner of the bounding box (i.e., the maximum coordinates
along all dimensions
The lower right hand corner of the bounding box (i.e., the maximum coordinates along all dimensions).

Returns
-------
The mask for the points inside of the bounding box.
The mask for the points inside the bounding box.
"""
element_axes = get_axes_names(points)
min_coordinate = _parse_list_into_array(min_coordinate)
max_coordinate = _parse_list_into_array(max_coordinate)
in_bounding_box_masks = []
for axis_index, axis_name in enumerate(axes):
if axis_name not in element_axes:
continue
min_value = min_coordinate[axis_index]
in_bounding_box_masks.append(points[axis_name].gt(min_value).to_dask_array(lengths=True))
for axis_index, axis_name in enumerate(axes):
if axis_name not in element_axes:
continue
max_value = max_coordinate[axis_index]
in_bounding_box_masks.append(points[axis_name].lt(max_value).to_dask_array(lengths=True))
in_bounding_box_masks = da.stack(in_bounding_box_masks, axis=-1)
Expand Down Expand Up @@ -269,6 +280,77 @@ def _(
return SpatialData(**new_elements, table=table)


def _get_axes_of_tranformation(
element: SpatialElement, target_coordinate_system: str
) -> tuple[ArrayLike, tuple[str, ...], tuple[str, ...]]:
"""
Get the transformation matrix and the transformation's axes (ignoring `c`).

The transformation is the one from the element's intrinsic coordinate system to the query coordinate space.
Note that the axes which specify the query shape are not necessarily the same as the axes that are output of the
transformation

Parameters
----------
element
SpatialData element to be transformed.
target_coordinate_system
The target coordinate system for the transformation.

Returns
-------
m_without_c
The transformation from the element's intrinsic coordinate system to the query coordinate space, without the
"c" axis.
input_axes_without_c
The axes of the element's intrinsic coordinate system, without the "c" axis.
output_axes_without_c
The axes of the query coordinate system, without the "c" axis.

"""
from spatialdata.transformations import get_transformation

transform_to_query_space = get_transformation(element, to_coordinate_system=target_coordinate_system)
assert isinstance(transform_to_query_space, BaseTransformation)
m = _get_affine_for_element(element, transform_to_query_space)
input_axes_without_c = tuple([ax for ax in m.input_axes if ax != "c"])
output_axes_without_c = tuple([ax for ax in m.output_axes if ax != "c"])
m_without_c = m.to_affine_matrix(input_axes=input_axes_without_c, output_axes=output_axes_without_c)
return m_without_c, input_axes_without_c, output_axes_without_c


def _adjust_bounding_box_to_real_axes(
axes: tuple[str, ...],
min_coordinate: ArrayLike,
max_coordinate: ArrayLike,
output_axes_without_c: tuple[str, ...],
) -> tuple[tuple[str, ...], ArrayLike, ArrayLike]:
"""
Adjust the bounding box to the real axes of the transformation.

The bounding box is defined by the user and it's axes may not coincide with the axes of the transformation.
"""
if set(axes) != set(output_axes_without_c):
axes_only_in_bb = set(axes) - set(output_axes_without_c)
axes_only_in_output = set(output_axes_without_c) - set(axes)

# let's remove from the bounding box whose axes that are not in the output axes (e.g. querying 2D points with a
# 3D bounding box)
indices_to_remove_from_bb = [axes.index(ax) for ax in axes_only_in_bb]
axes = tuple([ax for ax in axes if ax not in axes_only_in_bb])
min_coordinate = np.delete(min_coordinate, indices_to_remove_from_bb)
max_coordinate = np.delete(max_coordinate, indices_to_remove_from_bb)

# if there are axes in the output axes that are not in the bounding box, we need to add them to the bounding box
# with a range that includes everything (e.g. querying 3D points with a 2D bounding box)
for ax in axes_only_in_output:
axes = axes + (ax,)
M = np.finfo(np.float32).max - 1
min_coordinate = np.append(min_coordinate, -M)
max_coordinate = np.append(max_coordinate, M)
return axes, min_coordinate, max_coordinate


@bounding_box_query.register(SpatialImage)
@bounding_box_query.register(MultiscaleSpatialImage)
def _(
Expand All @@ -282,7 +364,6 @@ def _(

Notes
-----
_____
See https://github.com/scverse/spatialdata/pull/151 for a detailed overview of the logic of this code,
and for the cases the comments refer to.
"""
Expand All @@ -299,15 +380,10 @@ def _(
max_coordinate=max_coordinate,
)

# get the transformation from the element's intrinsic coordinate system to the query coordinate space
transform_to_query_space = get_transformation(image, to_coordinate_system=target_coordinate_system)
assert isinstance(transform_to_query_space, BaseTransformation)
m = _get_affine_for_element(image, transform_to_query_space)
input_axes_without_c = tuple([ax for ax in m.input_axes if ax != "c"])
output_axes_without_c = tuple([ax for ax in m.output_axes if ax != "c"])
m_without_c = m.to_affine_matrix(input_axes=input_axes_without_c, output_axes=output_axes_without_c)
m_without_c, input_axes_without_c, output_axes_without_c = _get_axes_of_tranformation(
image, target_coordinate_system
)
m_without_c_linear = m_without_c[:-1, :-1]

transform_dimension = np.linalg.matrix_rank(m_without_c_linear)
transform_coordinate_length = len(output_axes_without_c)
data_dim = len(input_axes_without_c)
Expand Down Expand Up @@ -335,24 +411,13 @@ def _(
error_message = (
f"This case is not supported (data with dimension"
f"{data_dim} but transformation with rank {transform_dimension}."
f"Please open a GitHub issue if you want to discuss a case."
f"Please open a GitHub issue if you want to discuss a use case."
)
raise ValueError(error_message)

if set(axes) != set(output_axes_without_c):
if set(axes).issubset(output_axes_without_c):
logger.warning(
f"The element has axes {output_axes_without_c}, but the query has axes {axes}. Excluding the element "
f"from the query result. In the future we can add support for this case. If you are interested, "
f"please open a GitHub issue."
)
return None
error_messeage = (
f"Invalid case. The bounding box axes are {axes},"
f"the spatial axes in {target_coordinate_system} are"
f"{output_axes_without_c}"
)
raise ValueError(error_messeage)
axes, min_coordinate, max_coordinate = _adjust_bounding_box_to_real_axes(
axes, min_coordinate, max_coordinate, output_axes_without_c
)

spatial_transform = Affine(m_without_c, input_axes=input_axes_without_c, output_axes=output_axes_without_c)
spatial_transform_bb_axes = Affine(
Expand All @@ -369,7 +434,7 @@ def _(
)
else:
assert case == 2
# TODO: we need to intersect the plane in the extrinsic coordiante system with the 3D bounding box. The
# TODO: we need to intersect the plane in the extrinsic coordinate system with the 3D bounding box. The
# vertices of this polygons needs to be transformed to the intrinsic coordinate system
raise NotImplementedError(
"Case 2 (the transformation is embedding 2D data in the 3D space, is not "
Expand Down
Loading