Skip to content
6 changes: 4 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ and this project adheres to [Semantic Versioning][].

### Fixed

- Fix color assignment for NaN values (#257)
- Fix channel str support #221
- Fixed channel str support (#221)
- Fixed color assignment for NaN values (#257)
- Updated incorrect link to documentation (#261)
- Fixed plotting of categorical data (#262)

## [0.2.2] - 2024-05-02

Expand Down
17 changes: 10 additions & 7 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,16 +503,16 @@ def render_images(
def render_labels(
self,
elements: list[str] | str | None = None,
color: list[str | None] | str | None = None,
groups: list[list[str | None]] | list[str | None] | str | None = None,
color: list[str] | str | None = None,
groups: list[str] | str | None = None,
contour_px: int = 3,
outline: bool = False,
palette: list[list[str | None]] | list[str | None] | str | None = None,
palette: list[ColorLike] | ColorLike | None = None,
cmap: Colormap | str | None = None,
norm: Normalize | None = None,
na_color: ColorLike | None = (0.0, 0.0, 0.0, 0.0),
na_color: ColorLike | None = "lightgrey",
outline_alpha: float | int = 1.0,
fill_alpha: float | int = 0.3,
fill_alpha: float | int = 0.35, # 0.3
scale: list[str] | str | None = None,
table_name: list[str] | str | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -554,8 +554,11 @@ def render_labels(
Colormap for continuous annotations, see :class:`matplotlib.colors.Colormap`.
norm : Normalize | None, optional
Colormap normalization for continuous annotations, see :class:`matplotlib.colors.Normalize`.
na_color : ColorLike | None, optional
Color to be used for NAs values, if present.
na_color : str | list[float] | None, default "lightgrey"
Color to be used for NAs values, if present. Can either be a named color
("red"), a hex representation ("#000000ff") or a list of floats that
represent RGB/RGBA values (1.0, 0.0, 0.0, 1.0). When None, the values won't
be shown.
outline_alpha : float | int, default 1.0
Alpha value for the outline of the labels.
fill_alpha : float | int, default 0.3
Expand Down
24 changes: 21 additions & 3 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,25 @@ def _render_shapes(
)
sdata_filt[table_name].obs[col_for_color] = sdata_filt[table_name].obs[col_for_color].astype("category")

assert isinstance(groups, list), "groups must be a list"
if isinstance(groups[index], list) and groups[index][0] is None:
group = None
elif all(isinstance(g, str) for g in groups[index]):
group = groups[index]
else:
raise ValueError("groups must be a list of strings or a list of lists of strings")

if group is not None or (isinstance(group, list) and all(isinstance(g, str) for g in group)):
raise ValueError("groups must be a list of strings or a list of lists of strings")

# get color vector (categorical or continuous)
color_source_vector, color_vector, _ = _set_color_source_vec(
sdata=sdata_filt,
element=sdata_filt.shapes[e],
element_index=index,
element_name=e,
value_to_plot=col_for_color,
groups=groups[index] if groups[index][0] is not None else None,
groups=group,
palette=(
palettes[index] if palettes is not None else None
), # and render_params.palette[index][0] is not None
Expand Down Expand Up @@ -192,6 +203,7 @@ def _render_shapes(
fig_params=fig_params,
adata=table,
value_to_plot=col_for_color,
color_vector=color_vector,
color_source_vector=color_source_vector,
palette=palette,
alpha=render_params.fill_alpha,
Expand Down Expand Up @@ -349,6 +361,7 @@ def _render_points(
fig_params=fig_params,
adata=adata,
value_to_plot=col_for_color,
color_vector=color_vector,
color_source_vector=color_source_vector,
palette=palette,
alpha=render_params.alpha,
Expand Down Expand Up @@ -596,6 +609,7 @@ def _render_labels(
palettes = _return_list_list_str_none(render_params.palette)
colors = _return_list_str_none(render_params.color)
groups = _return_list_list_str_none(render_params.groups)
print(element_table_mapping)

if render_params.outline is False:
render_params.outline_alpha = 0
Expand Down Expand Up @@ -640,11 +654,11 @@ def _render_labels(
instance_id = np.unique(label)
table = None
else:
regions, region_key, instance_key = get_table_keys(sdata[table_name])
_, region_key, instance_key = get_table_keys(sdata[table_name])
table = sdata[table_name][sdata[table_name].obs[region_key].isin([e])]

# get instance id based on subsetted table
instance_id = table.obs[instance_key].values
instance_id = np.unique(table.obs[instance_key].values)

trans = get_transformation(label, get_all=True)[coordinate_system]
affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))
Expand Down Expand Up @@ -731,13 +745,17 @@ def _render_labels(
_cax.set_transform(trans_data)
cax = ax.add_image(_cax)

if groups[i][0] is not None and color_source_vector is not None:
color_source_vector = color_source_vector.set_categories(groups[i])

_ = _decorate_axs(
ax=ax,
cax=cax,
fig_params=fig_params,
adata=table,
value_to_plot=color,
color_source_vector=color_source_vector,
color_vector=color_vector,
palette=palettes[i],
alpha=render_params.fill_alpha,
na_color=render_params.cmap_params.na_color,
Expand Down
Loading