Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ and this project adheres to [Semantic Versioning][].

## [0.2.3] - tbd

### Changed

- All parameters are now provided for a single element. If element in pl.render is None then this value will be broadcasted

### Fixed

- Fix color assignment for NaN values (#257)
Expand Down
69 changes: 0 additions & 69 deletions src/spatialdata_plot/_utils.py

This file was deleted.

178 changes: 88 additions & 90 deletions src/spatialdata_plot/pl/basic.py

Large diffs are not rendered by default.

278 changes: 135 additions & 143 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import warnings
from collections import abc
from copy import copy
from typing import Union, cast
from typing import Union

import dask
import geopandas as gpd
Expand Down Expand Up @@ -45,8 +45,6 @@
_multiscale_to_spatial_image,
_normalize,
_rasterize_if_necessary,
_return_list_list_str_none,
_return_list_str_none,
_set_color_source_vec,
to_hex,
)
Expand Down Expand Up @@ -524,162 +522,156 @@ def _render_labels(
legend_params: LegendParams,
rasterize: bool,
) -> None:
elements = render_params.elements
element_table_mapping = cast(dict[str, str], render_params.element_table_mapping)
palettes = _return_list_list_str_none(render_params.palette)
colors = _return_list_str_none(render_params.color)
groups = _return_list_list_str_none(render_params.groups)
element = render_params.element
table_name = render_params.table_name
palette = render_params.palette
color = render_params.color
groups = render_params.groups
scale = render_params.scale

if render_params.outline is False:
render_params.outline_alpha = 0

sdata_filt = sdata.filter_by_coordinate_system(
coordinate_system=coordinate_system,
filter_tables=any(value is not None for value in element_table_mapping.values()),
filter_tables=bool(table_name),
)

if elements is None:
elements = list(sdata_filt.labels.keys())

for i, e in enumerate(elements):
label = sdata_filt.labels[e]
extent = get_extent(label, coordinate_system=coordinate_system)
scale = render_params.scale[i] if isinstance(render_params.scale, list) else render_params.scale
color = colors[i]

# get best scale out of multiscale label
if isinstance(label, MultiscaleSpatialImage):
label = _multiscale_to_spatial_image(
multiscale_image=label,
dpi=fig_params.fig.dpi,
width=fig_params.fig.get_size_inches()[0],
height=fig_params.fig.get_size_inches()[1],
scale=scale,
is_label=True,
)
# rasterize spatial image if necessary to speed up performance
if rasterize:
label = _rasterize_if_necessary(
image=label,
dpi=fig_params.fig.dpi,
width=fig_params.fig.get_size_inches()[0],
height=fig_params.fig.get_size_inches()[1],
coordinate_system=coordinate_system,
extent=extent,
)
label = sdata_filt.labels[element]
extent = get_extent(label, coordinate_system=coordinate_system)

table_name = mapping.get(e) if isinstance((mapping := element_table_mapping), dict) else None
if table_name is None:
instance_id = np.unique(label)
table = None
else:
regions, region_key, instance_key = get_table_keys(sdata[table_name])
table = sdata[table_name][sdata[table_name].obs[region_key].isin([e])]

# get instance id based on subsetted table
instance_id = table.obs[instance_key].values

trans = get_transformation(label, get_all=True)[coordinate_system]
affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))
trans = mtransforms.Affine2D(matrix=affine_trans)
trans_data = trans + ax.transData

color_source_vector, color_vector, categorical = _set_color_source_vec(
sdata=sdata_filt,
element=label,
element_name=e,
value_to_plot=color,
groups=groups[i], # if isinstance(groups, list) else None,
palette=palettes[i],
na_color=render_params.cmap_params.na_color,
cmap_params=render_params.cmap_params,
table_name=cast(str, table_name),
# get best scale out of multiscale label
if isinstance(label, MultiscaleSpatialImage):
label = _multiscale_to_spatial_image(
multiscale_image=label,
dpi=fig_params.fig.dpi,
width=fig_params.fig.get_size_inches()[0],
height=fig_params.fig.get_size_inches()[1],
scale=scale,
is_label=True,
)
# rasterize spatial image if necessary to speed up performance
if rasterize:
label = _rasterize_if_necessary(
image=label,
dpi=fig_params.fig.dpi,
width=fig_params.fig.get_size_inches()[0],
height=fig_params.fig.get_size_inches()[1],
coordinate_system=coordinate_system,
extent=extent,
)

if (render_params.fill_alpha != render_params.outline_alpha) and render_params.contour_px is not None:
# First get the labels infill and plot them
labels_infill = _map_color_seg(
seg=label.values,
cell_id=instance_id,
color_vector=color_vector,
color_source_vector=color_source_vector,
cmap_params=render_params.cmap_params,
seg_erosionpx=None,
seg_boundaries=render_params.outline,
na_color=render_params.cmap_params.na_color,
)
if table_name is None:
instance_id = np.unique(label)
table = None
else:
regions, region_key, instance_key = get_table_keys(sdata[table_name])
table = sdata[table_name][sdata[table_name].obs[region_key].isin([element])]

# Then overlay the contour
labels_contour = _map_color_seg(
seg=label.values,
cell_id=instance_id,
color_vector=color_vector,
color_source_vector=color_source_vector,
cmap_params=render_params.cmap_params,
seg_erosionpx=render_params.contour_px,
seg_boundaries=render_params.outline,
na_color=render_params.cmap_params.na_color,
)
# get instance id based on subsetted table
instance_id = table.obs[instance_key].values

_cax = ax.imshow(
labels_contour,
rasterized=True,
cmap=None if categorical else render_params.cmap_params.cmap,
norm=None if categorical else render_params.cmap_params.norm,
alpha=render_params.outline_alpha,
origin="lower",
)
_cax = ax.imshow(
labels_infill,
rasterized=True,
cmap=None if categorical else render_params.cmap_params.cmap,
norm=None if categorical else render_params.cmap_params.norm,
alpha=render_params.fill_alpha,
origin="lower",
)
_cax.set_transform(trans_data)
cax = ax.add_image(_cax)
else:
# Default: no alpha, contour = infill
label = _map_color_seg(
seg=label.values,
cell_id=instance_id,
color_vector=color_vector,
color_source_vector=color_source_vector,
cmap_params=render_params.cmap_params,
seg_erosionpx=render_params.contour_px,
seg_boundaries=render_params.outline,
na_color=render_params.cmap_params.na_color,
)
trans = get_transformation(label, get_all=True)[coordinate_system]
affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))
trans = mtransforms.Affine2D(matrix=affine_trans)
trans_data = trans + ax.transData

_cax = ax.imshow(
label,
rasterized=True,
cmap=None if categorical else render_params.cmap_params.cmap,
norm=None if categorical else render_params.cmap_params.norm,
alpha=render_params.fill_alpha,
origin="lower",
)
_cax.set_transform(trans_data)
cax = ax.add_image(_cax)
color_source_vector, color_vector, categorical = _set_color_source_vec(
sdata=sdata_filt,
element=label,
element_name=element,
value_to_plot=color,
groups=groups,
palette=palette,
na_color=render_params.cmap_params.na_color,
cmap_params=render_params.cmap_params,
table_name=table_name,
)

_ = _decorate_axs(
ax=ax,
cax=cax,
fig_params=fig_params,
adata=table,
value_to_plot=color,
if (render_params.fill_alpha != render_params.outline_alpha) and render_params.contour_px is not None:
# First get the labels infill and plot them
labels_infill = _map_color_seg(
seg=label.values,
cell_id=instance_id,
color_vector=color_vector,
color_source_vector=color_source_vector,
cmap_params=render_params.cmap_params,
seg_erosionpx=None,
seg_boundaries=render_params.outline,
na_color=render_params.cmap_params.na_color,
)

# Then overlay the contour
labels_contour = _map_color_seg(
seg=label.values,
cell_id=instance_id,
color_vector=color_vector,
color_source_vector=color_source_vector,
palette=palettes[i],
cmap_params=render_params.cmap_params,
seg_erosionpx=render_params.contour_px,
seg_boundaries=render_params.outline,
na_color=render_params.cmap_params.na_color,
)

_cax = ax.imshow(
labels_contour,
rasterized=True,
cmap=None if categorical else render_params.cmap_params.cmap,
norm=None if categorical else render_params.cmap_params.norm,
alpha=render_params.outline_alpha,
origin="lower",
)
_cax = ax.imshow(
labels_infill,
rasterized=True,
cmap=None if categorical else render_params.cmap_params.cmap,
norm=None if categorical else render_params.cmap_params.norm,
alpha=render_params.fill_alpha,
origin="lower",
)
_cax.set_transform(trans_data)
cax = ax.add_image(_cax)
else:
# Default: no alpha, contour = infill
label = _map_color_seg(
seg=label.values,
cell_id=instance_id,
color_vector=color_vector,
color_source_vector=color_source_vector,
cmap_params=render_params.cmap_params,
seg_erosionpx=render_params.contour_px,
seg_boundaries=render_params.outline,
na_color=render_params.cmap_params.na_color,
legend_fontsize=legend_params.legend_fontsize,
legend_fontweight=legend_params.legend_fontweight,
legend_loc=legend_params.legend_loc,
legend_fontoutline=legend_params.legend_fontoutline,
na_in_legend=legend_params.na_in_legend,
colorbar=legend_params.colorbar,
scalebar_dx=scalebar_params.scalebar_dx,
scalebar_units=scalebar_params.scalebar_units,
# scalebar_kwargs=scalebar_params.scalebar_kwargs,
)

_cax = ax.imshow(
label,
rasterized=True,
cmap=None if categorical else render_params.cmap_params.cmap,
norm=None if categorical else render_params.cmap_params.norm,
alpha=render_params.fill_alpha,
origin="lower",
)
_cax.set_transform(trans_data)
cax = ax.add_image(_cax)

_ = _decorate_axs(
ax=ax,
cax=cax,
fig_params=fig_params,
adata=table,
value_to_plot=color,
color_source_vector=color_source_vector,
palette=palette,
alpha=render_params.fill_alpha,
na_color=render_params.cmap_params.na_color,
legend_fontsize=legend_params.legend_fontsize,
legend_fontweight=legend_params.legend_fontweight,
legend_loc=legend_params.legend_loc,
legend_fontoutline=legend_params.legend_fontoutline,
na_in_legend=legend_params.na_in_legend,
colorbar=legend_params.colorbar,
scalebar_dx=scalebar_params.scalebar_dx,
scalebar_units=scalebar_params.scalebar_units,
# scalebar_kwargs=scalebar_params.scalebar_kwargs,
)
Loading