Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,7 @@ def show(
wants_labels = False
wants_points = False
wants_shapes = False
wanted_elements = []
wanted_elements: list[str] = []

for cmd, params in render_cmds:
# We create a copy here as the wanted elements can change from one cs to another.
Expand Down
182 changes: 95 additions & 87 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from collections import abc
from copy import copy
from typing import Union
from typing import Union, cast

import dask
import geopandas as gpd
Expand Down Expand Up @@ -107,7 +107,12 @@ def _render_shapes(
color_vector = [render_params.cmap_params.na_color]

# filter by `groups`
if render_params.groups[index][0] is not None and color_source_vector is not None:

if (
isinstance(render_params.groups, list)
and render_params.groups[index][0] is not None
and color_source_vector is not None
):
mask = color_source_vector.isin(render_params.groups[index])
shapes = shapes[mask]
shapes = shapes.reset_index()
Expand Down Expand Up @@ -256,7 +261,7 @@ def _render_points(
palette=render_params.palette[index] if render_params.palette[index][0] is not None else None,
na_color=default_color,
cmap_params=render_params.cmap_params,
table_name=table_name,
table_name=cast(str, table_name),
)

# color_source_vector is None when the values aren't categorical
Expand Down Expand Up @@ -396,10 +401,11 @@ def _render_images(
if render_params.cmap_params.norm is not None: # type: ignore[attr-defined]
layer = render_params.cmap_params.norm(layer) # type: ignore[attr-defined]

if render_params.palette[i][0] is None:
cmap = render_params.cmap_params.cmap # type: ignore[attr-defined]
else:
cmap = _get_linear_colormap(render_params.palette[i], "k")[0] # type: ignore[arg-type]
if isinstance(render_params.palette, list):
if render_params.palette[i][0] is None:
cmap = render_params.cmap_params.cmap # type: ignore[attr-defined]
else:
cmap = _get_linear_colormap(render_params.palette[i], "k")[0] # type: ignore[arg-type]

# Overwrite alpha in cmap: https://stackoverflow.com/a/10127675
cmap._init()
Expand Down Expand Up @@ -433,97 +439,99 @@ def _render_images(
layers[c] = render_params.cmap_params[ch_index].norm(layers[c])

# 2A) Image has 3 channels, no palette info, and no/only one cmap was given
if (
n_channels == 3
and render_params.palette[i][0] is None
and not isinstance(render_params.cmap_params, list)
):
if render_params.cmap_params.is_default: # -> use RGB
stacked = np.stack([layers[c] for c in channels], axis=-1)
else: # -> use given cmap for each channel
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
# Apply cmaps to each channel, add up and normalize to [0, 1]
stacked = (
np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0)
/ n_channels
)
# Remove alpha channel so we can overwrite it from render_params.alpha
stacked = stacked[:, :, :3]
logger.warning(
"One cmap was given for multiple channels and is now used for each channel. "
"You're blending multiple cmaps. "
"If the plot doesn't look like you expect, it might be because your "
"cmaps go from a given color to 'white', and not to 'transparent'. "
"Therefore, the 'white' of higher layers will overlay the lower layers. "
"Consider using 'palette' instead."
if isinstance(render_params.palette, list):
if (
n_channels == 3
and render_params.palette[i][0] is None
and not isinstance(render_params.cmap_params, list)
):
if render_params.cmap_params.is_default: # -> use RGB
stacked = np.stack([layers[c] for c in channels], axis=-1)
else: # -> use given cmap for each channel
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
# Apply cmaps to each channel, add up and normalize to [0, 1]
stacked = (
np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0)
/ n_channels
)
# Remove alpha channel so we can overwrite it from render_params.alpha
stacked = stacked[:, :, :3]
logger.warning(
"One cmap was given for multiple channels and is now used for each channel. "
"You're blending multiple cmaps. "
"If the plot doesn't look like you expect, it might be because your "
"cmaps go from a given color to 'white', and not to 'transparent'. "
"Therefore, the 'white' of higher layers will overlay the lower layers. "
"Consider using 'palette' instead."
)

im = ax.imshow(
stacked,
alpha=render_params.alpha,
)
im.set_transform(trans_data)

im = ax.imshow(
stacked,
alpha=render_params.alpha,
)
im.set_transform(trans_data)
# 2B) Image has n channels, no palette/cmap info -> sample n categorical colors
elif render_params.palette[i][0] is None and not got_multiple_cmaps:
# overwrite if n_channels == 2 for intuitive result
if n_channels == 2:
seed_colors = ["#ff0000ff", "#00ff00ff"]
else:
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))

# 2B) Image has n channels, no palette/cmap info -> sample n categorical colors
elif render_params.palette[i][0] is None and not got_multiple_cmaps:
# overwrite if n_channels == 2 for intuitive result
if n_channels == 2:
seed_colors = ["#ff0000ff", "#00ff00ff"]
else:
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]

channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
# Apply cmaps to each channel and add up
colored = np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0)

# Apply cmaps to each channel and add up
colored = np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0)

# Remove alpha channel so we can overwrite it from render_params.alpha
colored = colored[:, :, :3]
# Remove alpha channel so we can overwrite it from render_params.alpha
colored = colored[:, :, :3]

im = ax.imshow(
colored,
alpha=render_params.alpha,
)
im.set_transform(trans_data)
im = ax.imshow(
colored,
alpha=render_params.alpha,
)
im.set_transform(trans_data)

# 2C) Image has n channels and palette info
elif render_params.palette[i][0] is not None and not got_multiple_cmaps:
if len(render_params.palette[i]) != n_channels:
raise ValueError("If 'palette' is provided, its length must match the number of channels.")
# 2C) Image has n channels and palette info
elif render_params.palette[i][0] is not None and not got_multiple_cmaps:
if len(render_params.palette[i]) != n_channels:
raise ValueError("If 'palette' is provided, its length must match the number of channels.")

channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in render_params.palette[i]]
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in render_params.palette[i]]

# Apply cmaps to each channel and add up
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0)
# Apply cmaps to each channel and add up
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0)

# Remove alpha channel so we can overwrite it from render_params.alpha
colored = colored[:, :, :3]
# Remove alpha channel so we can overwrite it from render_params.alpha
colored = colored[:, :, :3]

im = ax.imshow(
colored,
alpha=render_params.alpha,
)
im.set_transform(trans_data)
im = ax.imshow(
colored,
alpha=render_params.alpha,
)
im.set_transform(trans_data)

elif render_params.palette[i][0] is None and got_multiple_cmaps:
channel_cmaps = [cp.cmap for cp in render_params.cmap_params] # type: ignore[union-attr]
elif render_params.palette[i][0] is None and got_multiple_cmaps:
channel_cmaps = [cp.cmap for cp in render_params.cmap_params] # type: ignore[union-attr]

# Apply cmaps to each channel, add up and normalize to [0, 1]
colored = (
np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0) / n_channels
)
# Apply cmaps to each channel, add up and normalize to [0, 1]
colored = (
np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0)
/ n_channels
)

# Remove alpha channel so we can overwrite it from render_params.alpha
colored = colored[:, :, :3]
# Remove alpha channel so we can overwrite it from render_params.alpha
colored = colored[:, :, :3]

im = ax.imshow(
colored,
alpha=render_params.alpha,
)
im.set_transform(trans_data)
im = ax.imshow(
colored,
alpha=render_params.alpha,
)
im.set_transform(trans_data)

elif render_params.palette[i][0] is not None and got_multiple_cmaps:
raise ValueError("If 'palette' is provided, 'cmap' must be None.")
elif render_params.palette[i][0] is not None and got_multiple_cmaps:
raise ValueError("If 'palette' is provided, 'cmap' must be None.")


def _render_labels(
Expand All @@ -537,7 +545,7 @@ def _render_labels(
rasterize: bool,
) -> None:
elements = render_params.elements
element_table_mapping = render_params.element_table_mapping
element_table_mapping = cast(dict[str, str], render_params.element_table_mapping)

sdata_filt = sdata.filter_by_coordinate_system(
coordinate_system=coordinate_system,
Expand Down Expand Up @@ -573,7 +581,7 @@ def _render_labels(
extent=extent,
)

table_name = element_table_mapping.get(e)
table_name = mapping.get(e) if isinstance((mapping := element_table_mapping), dict) else None
if table_name is None:
instance_id = np.unique(label)
table = None
Expand All @@ -595,12 +603,12 @@ def _render_labels(
element=label,
element_index=i,
element_name=e,
value_to_plot=render_params.color[i],
value_to_plot=cast(list[str], render_params.color)[i],
groups=render_params.groups,
palette=render_params.palette,
na_color=render_params.cmap_params.na_color,
cmap_params=render_params.cmap_params,
table_name=table_name,
table_name=cast(str, table_name),
)

if (render_params.fill_alpha != render_params.outline_alpha) and render_params.contour_px is not None:
Expand Down Expand Up @@ -676,7 +684,7 @@ def _render_labels(
cax=cax,
fig_params=fig_params,
adata=table,
value_to_plot=render_params.color[i],
value_to_plot=cast(list[str], render_params.color)[i],
color_source_vector=color_source_vector,
palette=render_params.palette,
alpha=render_params.fill_alpha,
Expand Down
2 changes: 1 addition & 1 deletion src/spatialdata_plot/pl/render_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class LabelsRenderParams:

cmap_params: CmapParams
elements: str | Sequence[str] | None = None
color: str | None = None
color: list[str | None] | str | None = None
groups: str | Sequence[str] | None = None
contour_px: int | None = None
outline: bool = False
Expand Down
20 changes: 10 additions & 10 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from functools import partial
from pathlib import Path
from types import MappingProxyType
from typing import Any, Literal, Union
from typing import Any, Literal, Union, cast

import matplotlib
import matplotlib.patches as mpatches
Expand Down Expand Up @@ -618,7 +618,7 @@ def _set_color_source_vec(

# numerical case, return early
if color_source_vector is not None and not isinstance(color_source_vector.dtype, pd.CategoricalDtype):
if palette[0] is not None:
if isinstance(palette, list) and palette[0] is not None:
logger.warning(
"Ignoring categorical palette which is given for a continuous variable. "
"Consider using `cmap` to pass a ColorMap."
Expand All @@ -633,8 +633,9 @@ def _set_color_source_vec(
categories = groups

if groups is not None:
palette_input = palette[0] if palette[0] is None else palette
elif palette is not None:
if isinstance(palette, list):
palette_input = palette[0] if palette[0] is None else palette
elif palette is not None and isinstance(palette, list):
palette_input = palette[0]
else:
palette_input = palette
Expand Down Expand Up @@ -1345,7 +1346,7 @@ def _update_element_table_mapping_label_colors(
element_table_mapping = params.element_table_mapping

# If one color column check presence for each table annotating the specific element
if len(params.color) == 1:
if isinstance(params.color, list) and len(params.color) == 1:
params.color = params.color * len(render_elements)
for element_name in render_elements:
for table_name in element_table_mapping[element_name].copy():
Expand Down Expand Up @@ -1382,7 +1383,7 @@ def _update_element_table_mapping_label_colors(
def _validate_colors_element_table_mapping_points_shapes(
sdata: SpatialData, params: PointsRenderParams | ShapesRenderParams, render_elements: list[str]
) -> PointsRenderParams | ShapesRenderParams:
element_table_mapping = params.element_table_mapping
element_table_mapping = cast(dict, params.element_table_mapping)
if len(params.color) == 1:
color = params.color[0]
col_color = params.col_for_color[0]
Expand Down Expand Up @@ -1558,7 +1559,7 @@ def _validate_render_params(
alpha: float | int | None = None,
channel: list[str] | list[int] | str | int | None = None,
cmap: list[Colormap] | Colormap | str | None = None,
color: str | None = None,
color: list[str] | str | None = None,
contour_px: int | None = None,
elements: list[str] | str | None = None,
fill_alpha: float | int | None = None,
Expand Down Expand Up @@ -1792,7 +1793,7 @@ def _match_length_elements_groups_palette(
render_elements: list[str],
image: bool = False,
):
if image:
if image and isinstance(params, ImageRenderParams):
if params.palette is None:
params.palette = [[None] for _ in range(len(render_elements))]
else:
Expand Down Expand Up @@ -1849,5 +1850,4 @@ def _update_params(sdata, params, wanted_elements_on_cs, element_type: Literal["
# if params.palette is None:
# params.palette = [[None] for _ in wanted_elements_on_cs]
image_flag = element_type == "images"
params = _match_length_elements_groups_palette(params, wanted_elements_on_cs, image=image_flag)
return params
return _match_length_elements_groups_palette(params, wanted_elements_on_cs, image=image_flag)