Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 41 additions & 7 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import scanpy as sc
import spatialdata as sd
from anndata import AnnData
Expand Down Expand Up @@ -39,6 +40,7 @@
_get_cs_contents,
_get_extent,
_maybe_set_colors,
_mpl_ax_contains_elements,
_multiscale_to_image,
_prepare_cmap_norm,
_prepare_params_plot,
Expand Down Expand Up @@ -454,6 +456,7 @@ def show(
fig: Figure | None = None,
title: None | str | Sequence[str] = None,
share_extent: bool = True,
pad_extent: int = 0,
ax: Axes | Sequence[Axes] | None = None,
return_ax: bool = False,
save: None | str | Path = None,
Expand Down Expand Up @@ -523,6 +526,14 @@ def show(
# Simplicstic solution: If the images are multiscale, just use the first
sdata = _multiscale_to_image(sdata)

# get original axis extent for later comparison
x_min_orig, x_max_orig = (np.inf, -np.inf)
y_min_orig, y_max_orig = (np.inf, -np.inf)

if isinstance(ax, Axes) and _mpl_ax_contains_elements(ax):
x_min_orig, x_max_orig = ax.get_xlim()
y_max_orig, y_min_orig = ax.get_ylim() # (0, 0) is top-left

# handle coordinate system
coordinate_systems = sdata.coordinate_systems if coordinate_systems is None else coordinate_systems
if isinstance(coordinate_systems, str):
Expand All @@ -532,12 +543,38 @@ def show(
if cs not in sdata.coordinate_systems:
raise ValueError(f"Unknown coordinate system '{cs}', valid choices are: {sdata.coordinate_systems}")

# Check if user specified only certain elements to be plotted
cs_contents = _get_cs_contents(sdata)
elements_to_be_rendered = []
for cmd, params in render_cmds.items():
if cmd == "render_images" and cs_contents.query(f"cs == '{cs}'")["has_images"][0]: # noqa: SIM114
if params.elements is not None:
elements_to_be_rendered += (
[params.elements] if isinstance(params.elements, str) else params.elements
)
elif cmd == "render_shapes" and cs_contents.query(f"cs == '{cs}'")["has_shapes"][0]: # noqa: SIM114
if params.elements is not None:
elements_to_be_rendered += (
[params.elements] if isinstance(params.elements, str) else params.elements
)
elif cmd == "render_points" and cs_contents.query(f"cs == '{cs}'")["has_points"][0]: # noqa: SIM114
if params.elements is not None:
elements_to_be_rendered += (
[params.elements] if isinstance(params.elements, str) else params.elements
)
elif cmd == "render_labels" and cs_contents.query(f"cs == '{cs}'")["has_labels"][0]: # noqa: SIM102
if params.elements is not None:
elements_to_be_rendered += (
[params.elements] if isinstance(params.elements, str) else params.elements
)

extent = _get_extent(
sdata=sdata,
has_images="render_images" in render_cmds,
has_labels="render_labels" in render_cmds,
has_points="render_points" in render_cmds,
has_shapes="render_shapes" in render_cmds,
elements=elements_to_be_rendered,
coordinate_systems=coordinate_systems,
)

Expand Down Expand Up @@ -585,7 +622,6 @@ def show(
)

# go through tree
cs_contents = _get_cs_contents(sdata)
for i, cs in enumerate(coordinate_systems):
sdata = self._copy()
# properly transform all elements to the current coordinate system
Expand Down Expand Up @@ -693,12 +729,10 @@ def show(
]
):
# If the axis already has limits, only expand them but not overwrite
x_min, x_max = ax.get_xlim()
y_min, y_max = ax.get_ylim()
x_min = min(x_min, extent[cs][0])
x_max = max(x_max, extent[cs][1])
y_min = min(y_min, extent[cs][2])
y_max = max(y_max, extent[cs][3])
x_min = min(x_min_orig, extent[cs][0]) - pad_extent
x_max = max(x_max_orig, extent[cs][1]) + pad_extent
y_min = min(y_min_orig, extent[cs][2]) - pad_extent
y_max = max(y_max_orig, extent[cs][3]) + pad_extent
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_max, y_min) # (0, 0) is top-left

Expand Down
41 changes: 33 additions & 8 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,12 @@ def _get_cs_contents(sdata: sd.SpatialData) -> pd.DataFrame:

def _get_extent(
sdata: sd.SpatialData,
coordinate_systems: None | str | Sequence[str] = None,
coordinate_systems: Sequence[str] | str | None = None,
has_images: bool = True,
has_labels: bool = True,
has_points: bool = True,
has_shapes: bool = True,
elements: Iterable[Any] | None = None,
share_extent: bool = False,
) -> dict[str, tuple[int, int, int, int]]:
"""Return the extent of all elements in their respective coordinate systems.
Expand All @@ -191,16 +192,18 @@ def _get_extent(
----------
sdata
The sd.SpatialData object to retrieve the extent from
images
has_images
Flag indicating whether to consider images when calculating the extent
labels
has_labels
Flag indicating whether to consider labels when calculating the extent
points
has_points
Flag indicating whether to consider points when calculating the extent
shapes
Flag indicating whether to consider shaoes when calculating the extent
img_transformations
List of transformations already applied to the images
has_shapes
Flag indicating whether to consider shapes when calculating the extent
elements
Optional list of element names to be considered. When None, all are used.
share_extent
Flag indicating whether to use the same extent for all coordinate systems

Returns
-------
Expand All @@ -212,6 +215,12 @@ def _get_extent(
cs_mapping = _get_coordinate_system_mapping(sdata)
cs_contents = _get_cs_contents(sdata)

if elements is None: # to shut up ruff
elements = []

if not isinstance(elements, list):
raise ValueError(f"Invalid type of `elements`: {type(elements)}, expected `list`.")

if coordinate_systems is not None:
if isinstance(coordinate_systems, str):
coordinate_systems = [coordinate_systems]
Expand All @@ -220,6 +229,8 @@ def _get_extent(

for cs_name, element_ids in cs_mapping.items():
extent[cs_name] = {}
if len(elements) > 0:
element_ids = [e for e in element_ids if e in elements]

def _get_extent_after_transformations(element: Any, cs_name: str) -> Sequence[int]:
tmp = element.copy()
Expand Down Expand Up @@ -1127,3 +1138,17 @@ def _robust_transform(element: Any, cs: str) -> Any:
raise ValueError("Unable to transform element.") from e

return element


def _mpl_ax_contains_elements(ax: Axes) -> bool:
"""Check if any objects have been plotted on the axes object.

While extracting the extent, we need to know if the axes object has just been
initialised and therefore has extent (0, 1), (0,1) or if it has been plotted on
and therefore has a different extent.

Based on: https://stackoverflow.com/a/71966295
"""
return (
len(ax.lines) > 0 or len(ax.collections) > 0 or len(ax.images) > 0 or len(ax.patches) > 0 or len(ax.tables) > 0
)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/_images/Labels_can_render_labels.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed tests/_images/Labels_labels.png
Binary file not shown.
Binary file added tests/_images/Points_can_render_points.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Shapes_can_render_circles.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Shapes_can_render_polygons.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/_images/Show_pad_extent_adds_padding.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
56 changes: 56 additions & 0 deletions tests/pl/test_get_extent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import matplotlib
import scanpy as sc
import spatialdata_plot # noqa: F401
from spatialdata import SpatialData

from tests.conftest import PlotTester, PlotTesterMeta

sc.pl.set_rcParams_defaults()
sc.set_figure_params(dpi=40, color_map="viridis")
matplotlib.use("agg") # same as GitHub action runner
_ = spatialdata_plot

# WARNING:
# 1. all classes must both subclass PlotTester and use metaclass=PlotTesterMeta
# 2. tests which produce a plot must be prefixed with `test_plot_`
# 3. if the tolerance needs to be changed, don't prefix the function with `test_plot_`, but with something else
# the comp. function can be accessed as `self.compare(<your_filename>, tolerance=<your_tolerance>)`
# ".png" is appended to <your_filename>, no need to set it


class TestExtent(PlotTester, metaclass=PlotTesterMeta):
def test_plot_extent_of_img_full_canvas(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_images(elements="blobs_image").pl.show()

def test_plot_extent_of_points_partial_canvas(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_points().pl.show()

def test_plot_extent_of_partial_canvas_on_full_canvas(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_images(elements="blobs_image").pl.render_points().pl.show()

def test_plot_extent_calculation_respects_element_selection_circles(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(elements="blobs_circles").pl.show()

def test_plot_extent_calculation_respects_element_selection_polygons(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(elements="blobs_polygons").pl.show()

def test_plot_extent_calculation_respects_element_selection_circles_and_polygons(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(elements=["blobs_circles", "blobs_polygons"]).pl.show()

def test_plot_extent_of_img_is_correct_after_spatial_query(self, sdata_blobs: SpatialData):
cropped_blobs = sdata_blobs.pp.get_elements(["blobs_image"]).query.bounding_box(
axes=["x", "y"], min_coordinate=[100, 100], max_coordinate=[400, 400], target_coordinate_system="global"
)
cropped_blobs.pl.render_images().pl.show()

def test_plot_extent_of_polygons_is_correct_after_spatial_query(self, sdata_blobs: SpatialData):
cropped_blobs = sdata_blobs.pp.get_elements(["blobs_polygons"]).query.bounding_box(
axes=["x", "y"], min_coordinate=[100, 100], max_coordinate=[400, 400], target_coordinate_system="global"
)
cropped_blobs.pl.render_shapes().pl.show()

def test_plot_extent_of_polygons_on_img_is_correct_after_spatial_query(self, sdata_blobs: SpatialData):
cropped_blobs = sdata_blobs.pp.get_elements(["blobs_image", "blobs_polygons"]).query.bounding_box(
axes=["x", "y"], min_coordinate=[100, 100], max_coordinate=[400, 400], target_coordinate_system="global"
)
cropped_blobs.pl.render_images().pl.render_shapes().pl.show()
2 changes: 1 addition & 1 deletion tests/pl/test_render_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@


class TestLabels(PlotTester, metaclass=PlotTesterMeta):
def test_plot_labels(self, sdata_blobs: SpatialData):
def test_plot_can_render_labels(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_labels(elements="blobs_labels").pl.show()
2 changes: 1 addition & 1 deletion tests/pl/test_render_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@


class TestPoints(PlotTester, metaclass=PlotTesterMeta):
def test_plot_points(self, sdata_blobs: SpatialData):
def test_plot_can_render_points(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_points(elements="blobs_points").pl.show()
4 changes: 2 additions & 2 deletions tests/pl/test_render_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

class TestShapes(PlotTester, metaclass=PlotTesterMeta):
def test_plot_can_render_circles(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_circles").pl.show()
sdata_blobs.pl.render_shapes(elements="blobs_circles").pl.show()

def test_plot_can_render_polygons(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_polygons").pl.show()
sdata_blobs.pl.render_shapes(elements="blobs_polygons").pl.show()
23 changes: 23 additions & 0 deletions tests/pl/test_show.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import matplotlib
import scanpy as sc
import spatialdata_plot # noqa: F401
from spatialdata import SpatialData

from tests.conftest import PlotTester, PlotTesterMeta

sc.pl.set_rcParams_defaults()
sc.set_figure_params(dpi=40, color_map="viridis")
matplotlib.use("agg") # same as GitHub action runner
_ = spatialdata_plot

# WARNING:
# 1. all classes must both subclass PlotTester and use metaclass=PlotTesterMeta
# 2. tests which produce a plot must be prefixed with `test_plot_`
# 3. if the tolerance needs to be changed, don't prefix the function with `test_plot_`, but with something else
# the comp. function can be accessed as `self.compare(<your_filename>, tolerance=<your_tolerance>)`
# ".png" is appended to <your_filename>, no need to set it


class TestShow(PlotTester, metaclass=PlotTesterMeta):
def test_plot_pad_extent_adds_padding(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_images(elements="blobs_image").pl.show(pad_extent=100)