Source code for dandelion.polars.plotting._plotting

from __future__ import annotations

import circlify
import warnings

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    import nxviz as nxv
import pandas as pd
import seaborn as sns

from anndata import AnnData
from collections import defaultdict
from contextlib import contextmanager
from itertools import product, cycle
from matplotlib.axes import Axes
from matplotlib.collections import PatchCollection
from matplotlib.figure import Figure
from nxviz import annotate

from scanpy.plotting import palettes
from scanpy.plotting._tools.scatterplots import embedding
from typing import Callable, Literal, TYPE_CHECKING

from dandelion.polars.core._core import DandelionPolars

if TYPE_CHECKING:
    from mudata import MuData


def clone_network(
    adata: AnnData | MuData, basis: str = "vdj", edges: bool = True, **kwargs
) -> None:
    """
    Plot the clone network with scanpy's embedding plot.

    This is a thin wrapper around `sc.pl.embedding`; the only differences are
    the defaults `basis = 'vdj'` and `edges = True`.

    Parameters
    ----------
    adata : AnnData | MuData
        AnnData or scirpy-formatted MuData object. For MuData, the `airr`
        modality is the one that gets plotted.
    basis : str, optional
        key for embedding.
    edges : bool, optional
        whether or not to plot edges.
    **kwargs
        forwarded through `_temporary_obs_columns` and then passed to
        `sc.pl.embedding`.
    """
    if hasattr(adata, "mod"):
        # scirpy-style MuData: plot the AIRR modality, keep the parent around.
        airr_adata = adata.mod["airr"]
        parent = adata
    else:
        airr_adata = adata
        parent = None
    with _temporary_obs_columns(airr_adata, parent, **kwargs) as plot_kwargs:
        embedding(airr_adata, basis=basis, edges=edges, **plot_kwargs)
def barplot(
    data: AnnData | DandelionPolars,
    color: str,
    palette: str = "Set1",
    figsize: tuple[float, float] = (8, 3),
    normalize: bool = True,
    sort_descending: bool = True,
    title: str | None = None,
    xtick_fontsize: int | None = None,
    xtick_rotation: int | float | None = None,
    min_clone_size: int = 1,
    clone_key: str | None = None,
    **kwargs,
) -> tuple[Figure, Axes]:
    """
    A barplot function to plot usage of V/J genes in the data.

    Parameters
    ----------
    data : AnnData | DandelionPolars
        DandelionPolars or AnnData object. For DandelionPolars, a copy of the
        metadata is used (converted to pandas on the copy if needed); for
        AnnData, a copy of `.obs` is used.
    color : str
        column name in metadata for plotting in bar plot.
    palette : str, optional
        Colors to use for the different levels of the color variable. Should be something that can be interpreted by [color_palette](https://seaborn.pydata.org/generated/seaborn.color_palette.html#seaborn.color_palette), or a dictionary mapping hue levels to matplotlib colors. See [seaborn.barplot](https://seaborn.pydata.org/generated/seaborn.barplot.html).
    figsize : tuple[float, float], optional
        figure size.
    normalize : bool, optional
        if True, will return as proportion out of 1. Otherwise False will return counts.
    sort_descending : bool, optional
        if True, bars are ordered by decreasing frequency; if False, bars are
        ordered by category label instead.
    title : str | None, optional
        title of plot. None defaults to "<color> usage".
    xtick_fontsize : int | None, optional
        size of x tick labels
    xtick_rotation : int | float | None, optional
        rotation of x tick labels. None defaults to 90.
    min_clone_size : int, optional
        minimum clone size to keep.
    clone_key : str | None, optional
        column name for clones. None defaults to 'clone_id'.
    **kwargs
        passed to `sns.barplot`.

    Returns
    -------
    tuple[Figure, Axes]
        bar plot.
    """
    # Work on a private pandas copy of the cell-level metadata.
    if isinstance(data, DandelionPolars):
        data = data.copy()
        if data._backend == "polars":
            data.to_pandas()
        data = data._metadata.copy()
    elif isinstance(data, AnnData):
        data = data.obs.copy()

    # Keep only cells belonging to clones with at least `min_clone_size` cells.
    clone_ = "clone_id" if clone_key is None else clone_key
    size = data[clone_].value_counts()
    keep = list(size[size >= min_clone_size].index)
    data_ = data[data[clone_].isin(keep)]

    sns.set_style("whitegrid", {"axes.grid": False})
    yname = "proportion" if normalize else "count"
    # Name the index and values explicitly. The default names produced by
    # value_counts() changed in pandas 2.0, so relying on them (and catching
    # ValueError to guess which version is in use) is fragile and can mask
    # genuine errors raised from **kwargs.
    res = (
        data_[color]
        .value_counts(normalize=normalize)
        .rename_axis(color)
        .rename(yname)
    )
    if not sort_descending:
        res = res.sort_index()
    res = res.reset_index()

    # Initialize the matplotlib figure; seaborn draws on the current axes.
    fig, ax = plt.subplots(figsize=figsize)
    sns.barplot(x=color, y=yname, data=res, palette=palette, **kwargs)

    if title is None:
        ax.set_title(color.replace("_", " ") + " usage")
    else:
        ax.set_title(title)
    ax.set_ylabel(yname)
    ax.set_xlabel("")

    # modify the x ticks accordingly
    xtick_params = {"rotation": 90 if xtick_rotation is None else xtick_rotation}
    if xtick_fontsize is not None:
        xtick_params["fontsize"] = xtick_fontsize
    plt.xticks(**xtick_params)
    return fig, ax
def stackedbarplot(
    data: AnnData | DandelionPolars,
    color: str,
    group_by: str | None,
    figsize: tuple[float, float] = (8, 3),
    normalize: bool = False,
    title: str | None = None,
    sort_descending: bool | None = True,
    xtick_fontsize: int | None = None,
    xtick_rotation: int | float | None = None,
    hide_legend: bool = False,
    legend_options: tuple[str, tuple[float, float], int] = (
        "upper left",
        (1, 1),
        1,
    ),
    labels: list[str] | None = None,
    min_clone_size: int = 1,
    clone_key: str | None = None,
    **kwargs,
) -> tuple[Figure, Axes]:
    """
    A stacked bar plot function to plot usage of V/J genes in the data split by groups.

    Parameters
    ----------
    data : AnnData | DandelionPolars
        DandelionPolars or AnnData object.
    color : str
        column name in metadata for plotting in bar plot.
    group_by : str | None
        column name in metadata to split by during plotting. The column is
        indexed directly, so a column name must be supplied in practice.
    figsize : tuple[float, float], optional
        figure size.
    normalize : bool, optional
        if True, will return as proportion out of 1 (within each `color`
        category), otherwise False will return counts.
    title : str | None, optional
        title of plot.
    sort_descending : bool | None, optional
        True orders bars by decreasing frequency of `color`, False by
        increasing frequency, None sorts by the `color` label.
    xtick_fontsize : int | None, optional
        size of x tick labels
    xtick_rotation : int | float | None, optional
        rotation of x tick labels.
    hide_legend : bool, optional
        whether or not to hide the legend.
    legend_options : tuple[str, tuple[float, float], int], optional
        a tuple holding 3 options for specify legend options: 1) loc (string), 2) bbox_to_anchor (tuple), 3) ncol (int).
    labels : list[str] | None, optional
        Names of objects will be used for the legend if list of multiple data frames supplied.
    min_clone_size : int, optional
        minimum clone size to keep.
    clone_key : str | None, optional
        column name for clones. None defaults to 'clone_id'.
    **kwargs
        other kwargs passed to `pandas.DataFrame.plot`.

    Returns
    -------
    tuple[Figure, Axes]
        stacked barplot.
    """
    if isinstance(data, DandelionPolars):
        data = data.copy()
        if data._backend == "polars":
            data.to_pandas()
        data = data._metadata.copy()
    elif isinstance(data, AnnData):
        data = data.obs.copy()
    # quick fix to prevent dropping of nan
    data[group_by] = [str(l) for l in data[group_by]]

    # Keep only cells belonging to clones with at least `min_clone_size` cells.
    clone_ = "clone_id" if clone_key is None else clone_key
    size = data[clone_].value_counts()
    keep = list(size[size >= min_clone_size].index)
    data_ = data[data[clone_].isin(keep)]

    dat_ = pd.DataFrame(
        data_.groupby(color)[group_by]
        .value_counts(normalize=normalize)
        .unstack(fill_value=0)
        .stack(),
        columns=["value"],
    )
    dat_.reset_index(drop=False, inplace=True)
    # Order bars by the frequencies of the *filtered* data. Using the unfiltered
    # data here reintroduces categories that the clone filter removed, which
    # then appear as empty (NaN) bars and skew the ordering.
    dat_order = pd.DataFrame(data_[color].value_counts(normalize=normalize))
    dat_ = dat_.pivot(index=color, columns=group_by, values="value")
    if sort_descending is True:
        dat_ = dat_.reindex(dat_order.index)
    elif sort_descending is False:
        dat_ = dat_.reindex(dat_order.index[::-1])
    elif sort_descending is None:
        dat_ = dat_.sort_index()

    def _plot_bar_stacked(
        dfall: pd.DataFrame,
        labels: list[str] | None = None,
        figsize: tuple[float, float] = (8, 3),
        title: str = "multiple stacked bar plot",
        xtick_fontsize: int | None = None,
        xtick_rotation: int | float | None = None,
        legend_options: tuple[str, tuple[float, float], int] | None = None,
        hide_legend: bool = False,
        H: Literal["/"] = "/",
        **kwargs,
    ) -> tuple[Figure, Axes]:
        """
        Given a list of data frames, with identical columns and index, create a clustered stacked bar plot.

        Parameters
        ----------
        dfall : pd.DataFrame
            data frame (or list of data frames) for plotting.
        labels : list[str] | None, optional
            a list of the data frame objects. Names of objects will be used for the legend.
        figsize : tuple[float, float], optional
            size of figure.
        title : str, optional
            string for the title of the plot
        xtick_fontsize : int | None, optional
            xtick fontsize.
        xtick_rotation : int | float | None, optional
            rotation of xtick labels
        legend_options : tuple[str, tuple[float, float], int] | None, optional
            legend options (loc, bbox_to_anchor, ncol).
        hide_legend : bool, optional
            whether to hide the legend.
        H : Literal["/"], optional
            hatch pattern; data frame k is drawn with `H * k`.
        **kwargs
            other kwargs passed to `pandas.DataFrame.plot`.

        Returns
        -------
        tuple[Figure, Axes]
            stacked barplot.
        """
        if type(dfall) is not list:
            dfall = [dfall]
        n_df = len(dfall)
        n_col = len(dfall[0].columns)
        n_ind = len(dfall[0].index)
        # Initialize the matplotlib figure
        fig, ax = plt.subplots(figsize=figsize)
        for df in dfall:
            ax = df.plot(
                kind="bar",
                linewidth=0,
                stacked=True,
                ax=ax,
                legend=False,
                grid=False,
                **kwargs,
            )
        # Shift, hatch and narrow each data frame's bars so the frames sit
        # side by side (len(handles) == n_col * n_df).
        handles, handle_labels = ax.get_legend_handles_labels()
        for i in range(0, n_df * n_col, n_col):
            for pa in handles[i : i + n_col]:
                for rect in pa.patches:
                    rect.set_x(
                        rect.get_x() + 1 / float(n_df + 1) * i / float(n_col)
                    )
                    rect.set_hatch(H * int(i / n_col))
                    rect.set_width(1 / float(n_df + 1))
        ax.set_xticks((np.arange(0, 2 * n_ind, 2) + 1 / float(n_df + 1)) / 2.0)
        ax.set_xticklabels(dfall[-1].index, rotation=0)
        ax.set_title(title)
        ax.set_ylabel("proportion" if normalize else "count")
        # Add invisible bars that serve as hatch entries for a second legend.
        hatch_proxies = [
            ax.bar(0, 0, color="grey", hatch=H * i) for i in range(n_df)
        ]
        if legend_options is None:
            legend_opts = ("center right", (1.15, 0.5), 1)
        else:
            legend_opts = legend_options
        if hide_legend is False:
            l1 = ax.legend(
                handles[:n_col],
                handle_labels[:n_col],
                loc=legend_opts[0],
                bbox_to_anchor=legend_opts[1],
                ncol=legend_opts[2],
                frameon=False,
            )
            if labels is not None:
                l2 = plt.legend(
                    hatch_proxies,
                    labels,
                    loc=legend_opts[0],
                    bbox_to_anchor=legend_opts[1],
                    ncol=legend_opts[2],
                    frameon=False,
                )
                ax.add_artist(l2)
            ax.add_artist(l1)
        # modify the x ticks accordingly
        xtick_params = {
            "rotation": 90 if xtick_rotation is None else xtick_rotation
        }
        if xtick_fontsize is not None:
            xtick_params["fontsize"] = xtick_fontsize
        plt.xticks(**xtick_params)
        return fig, ax

    if title is None:
        title = "multiple stacked bar plot : " + color.replace("_", " ") + " usage"
    return _plot_bar_stacked(
        dat_,
        labels=labels,
        figsize=figsize,
        title=title,
        xtick_fontsize=xtick_fontsize,
        xtick_rotation=xtick_rotation,
        legend_options=legend_options,
        hide_legend=hide_legend,
        **kwargs,
    )
def spectratype(
    vdj: DandelionPolars,
    color: str,
    group_by: str,
    locus: str,
    figsize: tuple[float, float] = (5, 3),
    width: int | float | None = None,
    title: str | None = None,
    xtick_fontsize: int | None = None,
    xtick_rotation: int | float | None = None,
    hide_legend: bool = False,
    legend_options: tuple[str, tuple[float, float], int] = (
        "upper left",
        (1, 1),
        1,
    ),
    labels: list[str] | None = None,
    **kwargs,
) -> tuple[Figure, Axes]:
    """
    A spectratype function to plot usage of CDR3 length.

    Parameters
    ----------
    vdj : DandelionPolars
        DandelionPolars object. The input is not modified; a polars-backed
        object is converted on a copy.
    color : str
        column name in the contig data for plotting in bar plot (converted to
        numeric; non-numeric values are dropped).
    group_by : str
        column name in the contig data to split by during plotting.
    locus : str
        either IGH or IGL. A list of loci is also accepted.
    figsize : tuple[float, float], optional
        figure size.
    width : int | float | None, optional
        width of bars.
    title : str | None, optional
        title of plot.
    xtick_fontsize : int | None, optional
        size of x tick labels
    xtick_rotation : int | float | None, optional
        rotation of x tick labels.
    hide_legend : bool, optional
        whether or not to hide the legend.
    legend_options : tuple[str, tuple[float, float], int], optional
        a tuple holding 3 options for specify legend options: 1) loc (string), 2) bbox_to_anchor (tuple), 3) ncol (int).
    labels : list[str] | None, optional
        Names of objects will be used for the legend if list of multiple data frames supplied.
    **kwargs
        other kwargs passed to `pandas.DataFrame.plot`.

    Returns
    -------
    tuple[Figure, Axes]
        spectratype plot.

    Raises
    ------
    ValueError
        if no numeric `color` values remain after filtering.
    """
    if vdj._backend == "polars":
        # Convert a private copy so the caller's object keeps its backend
        # (to_pandas() converts in place), consistent with barplot.
        vdj = vdj.copy()
        vdj.to_pandas()
    data = vdj._data.copy()
    if "ambiguous" in data:
        data = data[data["ambiguous"] == "F"].copy()
    if type(locus) is not list:
        locus = [locus]
    data = data[data["locus"].isin(locus)].copy()
    data[group_by] = [str(l) for l in data[group_by]]

    dat_ = pd.DataFrame(
        data.groupby(color)[group_by]
        .value_counts(normalize=False)
        .unstack(fill_value=0)
        .stack(),
        columns=["value"],
    )
    dat_.reset_index(drop=False, inplace=True)
    dat_[color] = pd.to_numeric(dat_[color], errors="coerce")
    max_length = dat_[color].max()
    if pd.isna(max_length):
        raise ValueError(
            f"No numeric values in `{color}` for locus {locus}; nothing to plot."
        )
    # pivot() sorts the index itself; reindex to fill every length from 0.
    dat_2 = dat_.pivot(index=color, columns=group_by, values="value")
    dat_2 = dat_2.reindex(range(0, int(max_length) + 1), fill_value=0)

    def _plot_spectra_stacked(
        dfall: pd.DataFrame,
        labels: list[str] | None = None,
        figsize: tuple[float, float] = (5, 3),
        title: str = "multiple stacked bar plot",
        width: int | float | None = None,
        xtick_fontsize: int | None = None,
        xtick_rotation: int | float | None = None,
        legend_options: tuple[str, tuple[float, float], int] | None = None,
        hide_legend: bool = False,
        H: Literal["/"] = "/",
        **kwargs,
    ) -> tuple[Figure, Axes]:
        """Stacked spectratype plots.

        Parameters
        ----------
        dfall : pd.DataFrame
            data frame (or list of data frames) for plotting.
        labels : list[str] | None, optional
            a list of the data frame objects. Names of objects will be used for the legend.
        figsize : tuple[float, float], optional
            size of figure.
        title : str, optional
            string for the title of the plot.
        width : int | float | None, optional
            width of bars. None derives it from the number of bars.
        xtick_fontsize : int | None, optional
            size of x tick labels
        xtick_rotation : int | float | None, optional
            rotation of x tick labels.
        legend_options : tuple[str, tuple[float, float], int] | None, optional
            a tuple holding 3 options for specify legend options: 1) loc (string), 2) bbox_to_anchor (tuple), 3) ncol (int).
        hide_legend : bool, optional
            whether or not to hide the legend.
        H : Literal["/"], optional
            hatch pattern; data frame k is drawn with `H * k`.
        **kwargs
            other kwargs passed to `pandas.DataFrame.plot`.

        Returns
        -------
        tuple[Figure, Axes]
            spectratype plot.
        """
        if type(dfall) is not list:
            dfall = [dfall]
        n_df = len(dfall)
        n_col = len(dfall[0].columns)
        n_ind = len(dfall[0].index)
        wdth = 0.1 * n_ind / 60 + 0.8 if width is None else width
        # Initialize the matplotlib figure
        fig, ax = plt.subplots(figsize=figsize)
        for df in dfall:
            ax = df.plot(
                kind="bar",
                linewidth=0,
                stacked=True,
                ax=ax,
                legend=False,
                grid=False,
                **kwargs,
            )
        handles, handle_labels = ax.get_legend_handles_labels()
        for i in range(0, n_df * n_col, n_col):  # len(handles) = n_col * n_df
            for pa in handles[i : i + n_col]:
                for rect in pa.patches:
                    rect.set_x(
                        rect.get_x() + 1 / float(n_df + 1) * i / float(n_col)
                    )
                    rect.set_hatch(H * int(i / n_col))
                    # need to see if there's a better way to toggle this.
                    rect.set_width(wdth)
        # Keep every 5th x tick label visible and hide the rest.
        label_step = 5
        for idx, tick_label in enumerate(ax.xaxis.get_ticklabels()):
            if idx % label_step != 0:
                tick_label.set_visible(False)
        ax.set_title(title)
        ax.set_ylabel("count")
        # Add invisible bars that serve as hatch entries for a second legend.
        hatch_proxies = [
            ax.bar(0, 0, color="gray", hatch=H * i) for i in range(n_df)
        ]
        if legend_options is None:
            legend_opts = ("center right", (1.25, 0.5), 1)
        else:
            legend_opts = legend_options
        if hide_legend is False:
            l1 = ax.legend(
                handles[:n_col],
                handle_labels[:n_col],
                loc=legend_opts[0],
                bbox_to_anchor=legend_opts[1],
                ncol=legend_opts[2],
                frameon=False,
            )
            if labels is not None:
                # plt.legend replaces the axes legend; l1 is re-added below.
                plt.legend(
                    hatch_proxies,
                    labels,
                    loc=legend_opts[0],
                    bbox_to_anchor=legend_opts[1],
                    ncol=legend_opts[2],
                    frameon=False,
                )
            ax.add_artist(l1)
        # modify the x ticks accordingly
        xtick_params = {
            "rotation": 90 if xtick_rotation is None else xtick_rotation
        }
        if xtick_fontsize is not None:
            xtick_params["fontsize"] = xtick_fontsize
        plt.xticks(**xtick_params)
        return fig, ax

    return _plot_spectra_stacked(
        dat_2,
        labels=labels,
        figsize=figsize,
        title=title,
        width=width,
        xtick_fontsize=xtick_fontsize,
        xtick_rotation=xtick_rotation,
        legend_options=legend_options,
        hide_legend=hide_legend,
        **kwargs,
    )
def clone_overlap(
    adata: AnnData,
    group_by: str,
    color_by: str | None = None,
    weighted_overlap: bool = False,
    clone_key: str | None = None,
    color_mapping: list | dict | None = None,
    node_labels: bool = True,
    return_graph: bool = False,
    save: str | None = None,
    legend_kwargs: dict | None = None,
    node_label_size: int = 10,
    as_heatmap: bool = False,
    return_heatmap_data: bool = False,
    scale_edge_lambda: Callable | None = None,
    **kwargs,
) -> nx.Graph | pd.DataFrame | None:
    """
    A plot function to visualise clonal overlap as a circos-style plot.

    Originally written with nxviz < 0.7.3. Ported from https://github.com/zktuong/nxviz/tree/custom_color_mapping_circos_nodes_and_edges

    Parameters
    ----------
    adata : AnnData
        AnnData object with `adata.uns['clone_overlap']` computed by `tl.clone_overlap`.
    group_by : str
        column name in obs for collapsing to nodes in circos plot.
    color_by : str | None, optional
        column name in obs for grouping and color of nodes in plot. Must be a same or subcategory of the `group_by` categories e.g. `group_by="group_tissue", color_by="tissue"`.
    weighted_overlap : bool, optional
        if True, instead of collapsing to overlap to binary, edge thickness will reflect the number of cells found in the overlap. In the future, there will be the option to use something like a jaccard index instead.
    clone_key : str | None, optional
        column name for clones. None defaults to 'clone_id'.
    color_mapping : list | dict | None, optional
        custom color mapping provided as a sequence (corresponding to order of categories or alpha-numeric order if dtype is not category), or dictionary containing custom {category:color} mapping.
    node_labels : bool, optional
        whether to use node objects as labels or not
    return_graph : bool, optional
        whether or not to return the graph for fine tuning (circos mode only).
    save : str | None, optional
        file path for saving plot (circos mode only).
    legend_kwargs : dict | None, optional
        options for adjusting legend placement. None uses
        `{"ncol": 2, "bbox_to_anchor": (1, 0.5), "frameon": False, "loc": "center left"}`.
    node_label_size : int, optional
        size of labels if node_labels = True
    as_heatmap : bool, optional
        whether to return plot as heatmap.
    return_heatmap_data : bool, optional
        whether to return heatmap data as a pandas dataframe.
    scale_edge_lambda : Callable | None, optional
        a lambda function to scale the edge thickness. If None, will not scale.
    **kwargs
        passed to `sns.clustermap` when `as_heatmap=True`, otherwise to
        `matplotlib.pyplot.savefig` when `save` is given.

    Returns
    -------
    nx.Graph | pd.DataFrame | None
        the adjacency `pd.DataFrame` if `as_heatmap` and `return_heatmap_data`;
        the `nx.Graph` if not `as_heatmap` and `return_graph`; otherwise None.

    Raises
    ------
    KeyError
        if `clone_overlap` not found in `adata.uns`.
    ValueError
        if input is not AnnData.
    """
    if legend_kwargs is None:
        legend_kwargs = {
            "ncol": 2,
            "bbox_to_anchor": (1, 0.5),
            "frameon": False,
            "loc": "center left",
        }
    clone_ = "clone_id" if clone_key is None else clone_key

    if not isinstance(adata, AnnData):
        raise ValueError("Please provide a AnnData object.")
    data = adata.obs.copy()
    if "clone_overlap" not in adata.uns:
        raise KeyError(
            "`clone_overlap` not found in `adata.uns`. Did you run `tl.clone_overlap`?"
        )
    overlap = adata.uns["clone_overlap"].copy()

    def _group_pairs(clone) -> list:
        """All ordered pairs (self-pairs included) of groups sharing `clone`."""
        row = overlap.loc[clone]
        shared = list(row[row > 0].index)
        return list(product(shared, repeat=2))

    # Build edge lists keyed by clone; each edge carries the clone id and,
    # for weighted overlap, an aggregated weight.
    edges = {}
    if not weighted_overlap:
        for x in overlap.index:
            if overlap.loc[x].sum() > 1:
                edges[x] = [pair + ({str(clone_): x},) for pair in _group_pairs(x)]
    else:
        n_groups = overlap.astype(bool).sum(axis=1)
        combis = {
            x: _group_pairs(x) for x in n_groups.index if n_groups.loc[x] > 1
        }
        edge_weights = defaultdict(list)
        for k_clone, pairs in combis.items():
            for pair in pairs:
                edge_weights[pair].append(overlap.loc[k_clone, list(pair)].sum())
        for pair in edge_weights:  # values replaced in place; keys unchanged
            total = sum(edge_weights[pair])
            edge_weights[pair] = (
                scale_edge_lambda(total) if scale_edge_lambda is not None else total
            )
        for x in overlap.index:
            if overlap.loc[x].sum() > 1:
                # Pairs never seen above default to an empty list -> weight 0.
                edges[x] = [
                    pair
                    + (
                        {
                            str(clone_): x,
                            "weight": (
                                edge_weights[pair]
                                if not isinstance(edge_weights[pair], list)
                                else 0
                            ),
                        },
                    )
                    for pair in _group_pairs(x)
                ]

    color_by = group_by if color_by is None else color_by
    color_key = str(color_by)

    # create graph: one node per group_by value, annotated with its color_by value
    G = nx.Graph()
    G.add_nodes_from(
        [(p, {color_key: d}) for p, d in zip(data[group_by], data[color_by])]
    )
    for edge_list in edges.values():
        G.add_edges_from(edge_list)

    weighted_attr = "weight" if weighted_overlap else None

    # Resolve node colours. The categorical check must look at the column whose
    # categories are read (color_by), not group_by.
    if color_mapping is None:
        obs_col = adata.obs[color_key]
        if color_key + "_colors" in adata.uns:
            if isinstance(obs_col.dtype, pd.CategoricalDtype):
                levels = list(obs_col.cat.categories)
            else:
                levels = list(obs_col.unique())
            color_by_dict = dict(zip(levels, adata.uns[color_key + "_colors"]))
        else:
            levels = list(obs_col.unique())
            if len(levels) <= 20:
                pal = cycle(palettes.default_20)
            elif len(levels) <= 28:
                pal = cycle(palettes.default_28)
            else:
                pal = cycle(palettes.default_102)
            color_by_dict = dict(zip(levels, pal))
    elif isinstance(color_mapping, dict):
        color_by_dict = color_mapping
    else:
        data_col = data[color_key]
        if isinstance(data_col.dtype, pd.CategoricalDtype):
            levels = list(data_col.cat.categories)
        else:
            levels = sorted(set(data_col))
        color_by_dict = dict(zip(levels, color_mapping))

    if as_heatmap:
        hm = nx.to_pandas_adjacency(G)
        sns.clustermap(hm, **kwargs)
        if return_heatmap_data:
            return hm
    else:
        # remove self loops
        G.remove_edges_from(nx.selfloop_edges(G))
        nxv.circos(
            G,
            group_by=color_by,
            node_color_by=color_by,
            edge_lw_by=weighted_attr,
            node_palette=color_by_dict,
        )
        if node_labels:
            annotate.circos_group(
                G,
                group_by=color_by,
                midpoint=False,
                fontdict={"size": node_label_size},
            )
        annotate.node_colormapping(
            G,
            color_by=color_by,
            palette=color_by_dict,
            legend_kwargs=legend_kwargs,
        )
        if save is not None:
            plt.savefig(save, bbox_inches="tight", **kwargs)
        if return_graph:
            return G
def productive_ratio(
    adata: AnnData,
    figsize: tuple[float, float] = (8, 4),
    palette: list[str] = ["lightblue", "darkblue"],
    fontsize: int | float = 8,
    rotation: int | float = 90,
    legend_kwargs: dict = {
        "bbox_to_anchor": (1, 0.5),
        "loc": "center left",
        "frameon": False,
    },
):
    """Plot productive/non-productive contig ratio from AnnData (cell level).

    Two overlaid seaborn bar plots are drawn on a new figure: the total
    ("productive+non-productive") bars first, then the "productive" bars on
    top, with a y-axis fixed to 0-100.

    Parameters
    ----------
    adata : AnnData
        AnnData object with `.uns['productive_ratio']` computed from `tl.productive_ratio`;
        its "results", "locus" and "group_by" entries are read.
    figsize : tuple[float, float], optional
        Size of figure.
    palette : list[str], optional
        List of colours to plot non-productive and productive respectively.
    fontsize : int | float, optional
        Font size of x and y tick labels.
    rotation : int | float, optional
        Rotation of x tick labels.
    legend_kwargs : dict, optional
        Any additional kwargs to `plt.legend`
    """
    stored = adata.uns["productive_ratio"]
    res = stored["results"]
    locus = stored["locus"]
    group_by = stored["group_by"]
    background_colour = palette[0]
    foreground_colour = palette[1]

    plt.figure(figsize=figsize)
    ax = sns.barplot(
        x=group_by, y="productive+non-productive", data=res, color=background_colour
    )
    ax = sns.barplot(
        x=group_by, y="productive", data=res, color=foreground_colour, ax=ax
    )

    plt.xticks(fontsize=fontsize, rotation=rotation)
    plt.yticks(fontsize=fontsize)
    plt.xlabel("")
    plt.ylabel("")
    ax.set(ylim=(0, 100))
    plt.title(locus)

    # Proxy patches so the legend describes both overlaid bar series.
    handles = [
        mpatches.Patch(
            color=background_colour, label="% with non-productive " + locus
        ),
        mpatches.Patch(color=foreground_colour, label="% with productive " + locus),
    ]
    plt.legend(handles=handles, **legend_kwargs)
def _pack(hierarchy: list[dict], target_enc, packer: str) -> list: """Dispatch circle packing to circlify or packcircles.""" if packer == "circlify": return circlify.circlify( hierarchy, show_enclosure=False, target_enclosure=target_enc ) if packer == "packcircles": return _pack_with_packcircles(hierarchy, target_enc) raise ValueError( f"Unknown packer '{packer}'. Choose 'circlify' or 'packcircles'." ) def _pack_with_packcircles(hierarchy: list[dict], target_enc) -> list: """Recursive circle packer using packcircles for leaf levels. Groups (nodes with children) are positioned with circlify (small N, fast). Leaves are packed with packcircles, which uses an iterative approach that scales much better than circlify's O(N²) sequential algorithm. """ import math from types import SimpleNamespace try: import packcircles as pc except ImportError: raise ImportError( "packcircles is required for packer='packcircles'. " "Install it with: pip install packcircles" ) all_circles: list = [] def _do_pack( nodes: list[dict], enc_x: float, enc_y: float, enc_r: float, level: int ) -> None: if not nodes: return has_children = any("children" in n for n in nodes) if not has_children: # Leaf level — use packcircles. # Sort descending so largest clones match packcircles' internal ordering. sorted_nodes = sorted(nodes, key=lambda n: n["datum"], reverse=True) radii = [math.sqrt(n["datum"]) for n in sorted_nodes] if len(radii) < 3: # packcircles requires >= 3 circles; fall back to circlify for small N. 
fallback_data = [ {"id": n["id"], "datum": n["datum"]} for n in sorted_nodes ] enc = circlify.Circle(enc_x, enc_y, enc_r) fb = circlify.circlify( fallback_data, show_enclosure=False, target_enclosure=enc ) id_to_node = {n["id"]: n for n in sorted_nodes} for fc in fb: node = id_to_node.get(fc.ex.get("id", "")) if node is not None: all_circles.append( SimpleNamespace(x=fc.x, y=fc.y, r=fc.r, ex=node) ) return packed = list(pc.pack(radii)) if not packed: return # Use area-weighted centroid (r² weights) so the centre of mass # sits at the origin rather than the bounding-box midpoint. # This produces a noticeably rounder overall shape compared to # bounding-box centering, because large circles pull the centre # toward where most of the area actually is. total_w = sum(pr * pr for px, py, pr in packed) if total_w == 0: return w_cx = sum(pr * pr * px for px, py, pr in packed) / total_w w_cy = sum(pr * pr * py for px, py, pr in packed) / total_w max_reach = max( math.sqrt((px - w_cx) ** 2 + (py - w_cy) ** 2) + pr for px, py, pr in packed ) if max_reach == 0: return scale = enc_r / max_reach for (px, py, pr), node in zip(packed, sorted_nodes): all_circles.append( SimpleNamespace( x=enc_x + (px - w_cx) * scale, y=enc_y + (py - w_cy) * scale, r=pr * scale, ex=node, ) ) else: # Group level — circlify positions the group circles (small N, fast). stripped = [{"id": n["id"], "datum": n["datum"]} for n in nodes] group_enc = circlify.Circle(enc_x, enc_y, enc_r) group_circles = circlify.circlify( stripped, show_enclosure=False, target_enclosure=group_enc ) node_by_id = {n["id"]: n for n in nodes} for gc in group_circles: node = node_by_id.get(gc.ex.get("id", "")) if node is None: continue # Use the original node dict as ex so color_group_lookup matches. all_circles.append( SimpleNamespace(x=gc.x, y=gc.y, r=gc.r, ex=node) ) if "children" in node: _do_pack(node["children"], gc.x, gc.y, gc.r, level + 1) _do_pack(hierarchy, target_enc.x, target_enc.y, target_enc.r, 1) return all_circles
def clone_circlepackplot(
    data: AnnData | DandelionPolars,
    group_by: str | list[str],
    palette: str | dict | None = None,
    figsize: tuple[float, float] = (8, 8),
    title: str | None = None,
    min_clone_size: int = 1,
    clone_key: str | None = None,
    show_group_labels: bool = True,
    show_clone_labels: bool = False,
    show_count_labels: bool = False,
    alpha: float = 0.6,
    show_legend: str | list[str] | bool | None = None,
    legend_kwargs: dict | None = None,
    as_subplots: bool = False,
    n_col: int | None = None,
    n_row: int | None = None,
    scale_subplots: bool = True,
    scale_factor: float | None = None,
    outer_ring_color: str | None = None,
    show_enclosure_label: bool = True,
    max_clones_per_group: int | None = None,
    aggregate_by_size: bool = False,
    packer: Literal["circlify", "packcircles"] = "circlify",
) -> tuple[Figure, Axes] | tuple[Figure, list[Axes]]:
    """
    A bubble plot to visualise clone sizes within groups using circle packing.

    Each group (e.g. sample, celltype) is represented as an enclosing circle,
    with clones within that group shown as packed inner circles sized
    proportionally to clone size.

    When `group_by` is a list the hierarchy follows the list order: the first
    element is the outermost ring, subsequent elements are nested rings, and
    clone circles sit at the innermost level. Each level is coloured
    independently using its own colour map.

    Parameters
    ----------
    data : AnnData | DandelionPolars
        DandelionPolars or AnnData object.
    group_by : str | list[str]
        Column name(s) in metadata to group clones by. A single string gives
        one level of nesting; a list gives multi-level nesting in list order
        (e.g. ``['sample_id', 'celltype']``).
    palette : str | dict | None, optional
        Colour specification.

        * ``None`` (default): for DandelionPolars objects, uses ``"Set2"``
          for every level; for AnnData objects, each column is looked up in
          ``adata.uns`` (e.g. ``uns["leiden_colors"]``) and falls back to
          scanpy's default palette cycle when the key is absent.
        * ``str``: a seaborn palette name. Each level takes the next
          consecutive slice of colours from it, so levels do not share
          colours.
        * ``dict``: a nested mapping where each key is a column name from
          ``group_by`` and each value is either a ``{category: colour}``
          dict or a list of colours assigned in category order (respecting
          ``.cat.categories`` for categoricals, numeric sort order, or
          alphabetical otherwise). Missing columns or values are
          auto-assigned.

        Examples for ``group_by=["A", "B"]``::

            palette={"A": {"x": "red", "y": "blue"}, "B": {"x": "green"}}
            palette={"A": ["red", "blue"], "B": ["green", "orange"]}

    figsize : tuple[float, float], optional
        Figure size. When ``as_subplots=True`` this is interpreted as the
        size of each individual subplot panel.
    title : str | None, optional
        Title of plot.
    min_clone_size : int, optional
        Minimum clone size to include. Set to 2 to exclude singletons. The
        threshold is applied both globally and again within each leaf group.
    clone_key : str | None, optional
        Column name for clones. None defaults to 'clone_id'.
    show_group_labels : bool, optional
        Whether to annotate each group enclosure with its label.
    show_clone_labels : bool, optional
        Whether to annotate each clone circle with its clone ID.
    show_count_labels : bool, optional
        Whether to annotate each circle with its cell count.
    alpha : float, optional
        Transparency of clone circles.
    show_legend : str | list[str] | bool | None, optional
        Controls which group_by levels appear in the legend.

        * ``None`` or ``True`` (default ``None``): show all levels.
        * ``str``: show only that level (e.g. ``"isotype"``).
        * ``list[str]``: show only those levels.
        * ``False``: hide the legend entirely.

    legend_kwargs : dict | None, optional
        Keyword arguments forwarded to ``ax.legend``. ``None`` places the
        legend to the right of the axes (``{"bbox_to_anchor": (1, 0.5),
        "loc": "center left", "frameon": False}``).
    as_subplots : bool, optional
        If ``True``, split the figure into one subplot per top-level group
        (the first element of ``group_by``). The subplots are tiled
        according to ``n_col`` and ``n_row``.
    n_col : int | None, optional
        Number of subplot columns when ``as_subplots=True``. If ``None``,
        defaults to 4 when there are more than 4 subplots, otherwise the
        number of subplots. Ignored when ``as_subplots=False``.
    n_row : int | None, optional
        Number of subplot rows when ``as_subplots=True``. If ``None``,
        computed automatically from ``n_col``. Ignored when
        ``as_subplots=False``.
    scale_subplots : bool, optional
        When ``as_subplots=True``, scale each subplot's enclosing circle so
        that its area is proportional to its total cell count relative to
        the largest subplot. This keeps circle sizes visually comparable
        across panels: a clone of size *n* always occupies the same area
        regardless of which subplot it appears in. Uniform axis limits are
        applied to all subplots. Has no effect when ``as_subplots=False``.
        Default ``True``.
    scale_factor : float | None, optional
        Direct multiplier on the enclosure radius. Larger values produce
        larger circles; smaller values produce smaller circles. Applies in
        both ``as_subplots=True`` and ``as_subplots=False`` modes.

        * ``None`` (default): in single-panel mode the enclosure fills the
          panel (radius = 1.0); in subplot mode the radius is determined by
          ``scale_subplots`` alone.
        * ``float``: the enclosure radius is set to ``scale_factor`` in
          single-panel mode, or multiplied by ``scale_factor`` on top of the
          ``scale_subplots`` auto-computed radius in subplot mode. Axes are
          fixed to ``[-1.05, 1.05]`` so relative sizes remain visually
          meaningful. Values < 1 shrink circles; values > 1 enlarge them.

    outer_ring_color : str | None, optional
        If set, all outermost group rings (depth-0 circles in single-panel
        mode and the enclosing ring in each subplot) are drawn in this single
        colour instead of their level-0 palette colour. Any valid matplotlib
        colour string is accepted (e.g. ``"black"``, ``"#333333"``).
        ``None`` (default) preserves per-group colouring from ``palette``.
    show_enclosure_label : bool, optional
        Whether to display the total cell count below each enclosing group
        ring. Only has an effect when ``show_count_labels=True``.
        Default ``True``.
    max_clones_per_group : int | None, optional
        If set, only the top-*N* largest clones within each leaf group are
        drawn. Clones are ranked by size (descending) and any beyond the
        *N*-th are silently dropped before packing. Useful for keeping very
        large groups readable. ``None`` (default) includes all clones that
        pass ``min_clone_size``.
    aggregate_by_size : bool, optional
        When ``True``, instead of drawing one circle per clone, clones are
        grouped into buckets by their size. Each bucket becomes a single
        circle whose area is proportional to
        ``clone_size × number_of_clones`` (total cells in that bucket). The
        circle label is ``"n=<size>"`` and, when ``show_count_labels=True``,
        the annotation reads ``"<size>\\n(<total_cells>)"``. This greatly
        reduces the number of circles for samples with many small clones
        and gives a compact overview of the size distribution.
        Default ``False``.
    packer : {"circlify", "packcircles"}, optional
        Circle-packing backend to use.

        * ``"circlify"`` (default): uses the *circlify* library, which
          produces deterministic, aesthetically balanced layouts.
        * ``"packcircles"``: uses the *packcircles* library, which applies
          an iterative overlap-removal algorithm. Leaf-level circles are
          packed with *packcircles*; enclosing group circles are still laid
          out by *circlify*. Requires ``pip install packcircles``.

    Returns
    -------
    tuple[Figure, Axes]
        Circle-packing bubble plot (when ``as_subplots=False``).
    tuple[Figure, list[Axes]]
        Figure and list of per-group axes (when ``as_subplots=True``).

    Raises
    ------
    ValueError
        If no clones remain after filtering by `min_clone_size`, or after
        removing cells with missing values in the ``group_by`` column(s).
    ValueError
        If ``packer`` is not ``"circlify"`` or ``"packcircles"``.
    ValueError
        If ``as_subplots=True`` and ``n_col`` is < 1 or the
        ``n_row`` × ``n_col`` grid is too small for all top-level groups.
    """
    # None sentinel avoids a mutable default argument shared across calls.
    if legend_kwargs is None:
        legend_kwargs = {
            "bbox_to_anchor": (1, 0.5),
            "loc": "center left",
            "frameon": False,
        }

    # --- Normalise input to a pandas copy of the cell metadata -----------
    _is_adata = isinstance(data, AnnData)
    _adata_uns = data.uns if _is_adata else {}
    if isinstance(data, DandelionPolars):
        data = data.copy()
        if data._backend == "polars":
            data.to_pandas()
        data = data._metadata.copy()
    elif isinstance(data, AnnData):
        data = data.obs.copy()

    clone_ = clone_key if clone_key is not None else "clone_id"
    group_by_cols = [group_by] if isinstance(group_by, str) else list(group_by)

    # Remove cells with no assigned clone (No_contig, NaN, etc.)
    _no_clone = {"No_contig", "nan", "None", "NA", ""}
    data = data[data[clone_].notna()].copy()
    data = data[~data[clone_].astype(str).isin(_no_clone)].copy()

    size = data[clone_].value_counts()
    keep = list(size[size >= min_clone_size].index)
    data_ = data[data[clone_].isin(keep)]
    if data_.empty:
        raise ValueError(
            f"No clones remaining after filtering with min_clone_size={min_clone_size}."
        )

    # Remove rows where any group_by column contains no-data sentinel values
    for col in group_by_cols:
        data_ = data_[~data_[col].astype(str).isin(_no_clone)].copy()
    # Re-check after the group filter: otherwise an all-missing group column
    # yields an empty plot, or a ZeroDivisionError when sizing the subplot
    # grid (zero top-level groups).
    if data_.empty:
        raise ValueError(
            "No clones remaining after removing cells with missing values in "
            f"group_by column(s) {group_by_cols}."
        )

    def _is_categorical(series: pd.Series) -> bool:
        """Return True if *series* has a categorical dtype.

        Replaces ``pd.api.types.is_categorical_dtype``, which is deprecated
        since pandas 2.1.
        """
        return isinstance(series.dtype, pd.CategoricalDtype)

    def _auto_colors(
        col: str, vals: list[str], offset: int = 0
    ) -> dict[str, tuple]:
        """Return a {value: colour} map for *vals* in *col*.

        For AnnData inputs the column's ``.uns`` entry is consulted first.
        DandelionPolars inputs always use Set2.
        *offset* shifts the starting index in the palette so that successive
        hierarchy levels draw from distinct colours rather than all cycling
        from position 0.
        """
        if _is_adata:
            uns_key = f"{col}_colors"
            if uns_key in _adata_uns:
                obs_col = data[col] if col in data.columns else None
                if obs_col is not None and _is_categorical(obs_col):
                    cats = list(obs_col.cat.categories)
                else:
                    cats = (
                        sorted(data[col].dropna().unique().tolist())
                        if col in data.columns
                        else []
                    )
                # uns colours are positional: i-th colour ↔ i-th category.
                color_dict = {
                    str(c): clr for c, clr in zip(cats, _adata_uns[uns_key])
                }
                return {v: color_dict.get(v, "gray") for v in vals}
            # Fallback: scanpy default palette cycle with per-level offset
            n = len(vals)
            total_needed = offset + n
            if total_needed <= 20:
                pal = list(palettes.default_20)
            elif total_needed <= 28:
                pal = list(palettes.default_28)
            else:
                pal = list(palettes.default_102)
            colors = [pal[i % len(pal)] for i in range(offset, offset + n)]
            return dict(zip(vals, colors))
        # DandelionPolars / non-AnnData: use Set2 with per-level offset
        full_pal = sns.color_palette("Set2", max(8, offset + len(vals)))
        return dict(zip(vals, full_pal[offset : offset + len(vals)]))

    def _ordered_vals(col: str) -> list[str]:
        """Return unique string values for col in the correct display order.

        Categorical columns respect their ``.cat.categories`` order (matching
        scanpy). Purely numeric columns sort numerically (1, 2, 10 not 1, 10,
        2). Everything else sorts alphabetically.
        """
        col_data = data_[col]
        if _is_categorical(col_data):
            present = set(col_data.dropna().astype(str))
            return [
                str(c) for c in col_data.cat.categories if str(c) in present
            ]
        raw_vals = col_data.dropna().unique()
        as_numeric = pd.to_numeric(pd.Series(raw_vals), errors="coerce")
        if as_numeric.notna().all():
            return [str(v) for v in sorted(raw_vals, key=float)]
        return sorted(str(v) for v in raw_vals)

    # Build one colour map per group_by level.
    # _palette_offset ensures each level starts from a fresh position in the
    # auto-assigned palette so that outer-ring colours (level 0) are always
    # visually distinct from inner-ring colours (level 1+).
    level_color_maps: list[dict[str, tuple]] = []
    level_ordered_vals: list[list[str]] = []
    _palette_offset = 0
    for col in group_by_cols:
        unique_vals = _ordered_vals(col)
        level_ordered_vals.append(unique_vals)
        if isinstance(palette, str):
            full_pal = sns.color_palette(
                palette, _palette_offset + len(unique_vals)
            )
            cmap: dict[str, tuple] = dict(
                zip(unique_vals, full_pal[_palette_offset:])
            )
        elif isinstance(palette, dict):
            col_palette = palette.get(col, {})
            if isinstance(col_palette, list):
                cmap = dict(zip(unique_vals, col_palette))
                missing = [v for v in unique_vals if v not in cmap]
                if missing:
                    cmap.update(_auto_colors(col, missing, _palette_offset))
            else:
                cmap = {}
                missing = []
                for v in unique_vals:
                    if v in col_palette:
                        cmap[v] = col_palette[v]
                    else:
                        missing.append(v)
                if missing:
                    cmap.update(_auto_colors(col, missing, _palette_offset))
        else:
            cmap = _auto_colors(col, unique_vals, _palette_offset)
        _palette_offset += len(unique_vals)
        level_color_maps.append(cmap)

    # Side-lookup: id(node) → (level_index, group_value)
    # Avoids embedding extra keys in the dicts that circlify would warn about.
    color_group_lookup: dict[int, tuple[int, str]] = {}

    def _build_hierarchy(
        df: pd.DataFrame,
        levels: list[str],
        depth: int,
        parent_info: tuple[int, str] | None,
    ) -> list[dict]:
        """Recursively build a circlify-compatible hierarchy.

        Each node dict is registered by its object identity so the renderer
        can retrieve the correct per-level colour without polluting the
        dicts. Groups that end up with no clones are omitted.
        """
        if not levels:
            # Leaves inherit the deepest group_by level's colour.
            clone_sizes = df[clone_].value_counts()
            # Re-apply min_clone_size per group: a clone whose cells are spread
            # across groups can pass the global filter (e.g. 1 cell in each of
            # 2 samples = global size 2) yet show up with size 1 inside a group.
            clone_sizes = clone_sizes[clone_sizes >= min_clone_size]
            if (
                max_clones_per_group is not None
                and len(clone_sizes) > max_clones_per_group
            ):
                # value_counts is sorted descending, so head() keeps the largest.
                clone_sizes = clone_sizes.head(max_clones_per_group)
            leaf_info: tuple[int, str] = (
                parent_info if parent_info is not None else (0, "")
            )
            result = []
            if aggregate_by_size:
                # One circle per distinct clone size; datum = size × count (total cells).
                # Reduces hundreds of individual-clone circles to ~10-20 size buckets.
                size_dist = clone_sizes.value_counts().sort_index(
                    ascending=False
                )
                for clone_size, n_clones in size_dist.items():
                    node: dict = {
                        "id": f"n={int(clone_size)}",
                        "datum": int(clone_size) * int(n_clones),
                    }
                    color_group_lookup[id(node)] = leaf_info
                    result.append(node)
            else:
                for cid, cnt in clone_sizes.items():
                    node = {"id": str(cid), "datum": int(cnt)}
                    color_group_lookup[id(node)] = leaf_info
                    result.append(node)
            return result

        current, rest = levels[0], levels[1:]
        result = []
        for grp, gdata in df.groupby(current, observed=True):
            grp_str = str(grp)
            my_info: tuple[int, str] = (depth, grp_str)
            children = _build_hierarchy(gdata, rest, depth + 1, my_info)
            if not children:
                # Every clone in this group was removed by the per-group
                # filters; drawing it would give a zero-datum (zero-radius) ring.
                continue
            node = {
                "id": grp_str,
                "datum": sum(c["datum"] for c in children),
                "children": children,
            }
            color_group_lookup[id(node)] = my_info
            result.append(node)
        return result

    def _render_circles_on_ax(
        ax: Axes,
        circles: list,
        lookup: dict[int, tuple[int, str]],
        count_groups: bool = True,
    ):
        """Draw packed *circles* on *ax*: filled leaves, outlined group rings, labels."""
        filled_patches = []
        filled_facecolors = []
        ring_patches = []
        ring_edgecolors = []
        texts = []

        for circle in circles:
            if circle.ex is None:
                continue
            x, y, r = circle.x, circle.y, circle.r
            label = circle.ex.get("id", "")
            level_idx, group_val = lookup.get(id(circle.ex), (0, label))
            level_idx = min(level_idx, len(level_color_maps) - 1)
            color = level_color_maps[level_idx].get(group_val, "gray")
            is_group = "children" in circle.ex

            if is_group:
                _ring_color = (
                    outer_ring_color
                    if outer_ring_color is not None and level_idx == 0
                    else color
                )
                ring_patches.append(mpatches.Circle((x, y), r))
                ring_edgecolors.append(_ring_color)
                if show_group_labels:
                    texts.append(
                        (
                            x,
                            y + r,
                            label,
                            dict(
                                ha="center",
                                va="bottom",
                                fontsize=9,
                                fontweight="bold",
                                color=_ring_color,
                            ),
                        )
                    )
                if show_count_labels and count_groups and show_enclosure_label:
                    texts.append(
                        (
                            x,
                            y - r - 0.05,
                            str(circle.ex["datum"]),
                            dict(
                                ha="center",
                                va="top",
                                fontsize=7,
                                color=_ring_color,
                            ),
                        )
                    )
            else:
                filled_patches.append(mpatches.Circle((x, y), r))
                filled_facecolors.append(color)
                # For aggregate_by_size nodes (id="n=<size>"), show "size (total_cells)"
                # so the reader can see both the clone size and the total cell count.
                # Keyed on the flag so a real clone id such as "n=5" is not
                # mistaken for a size bucket.
                _is_agg = aggregate_by_size and label.startswith("n=")
                _count_str = (
                    f"{label[2:]}\n({circle.ex['datum']})"
                    if _is_agg
                    else str(circle.ex["datum"])
                )
                if show_clone_labels and show_count_labels:
                    texts.append(
                        (
                            x,
                            y + r * 0.25,
                            label,
                            dict(ha="center", va="center", fontsize=7),
                        )
                    )
                    texts.append(
                        (
                            x,
                            y - r * 0.25,
                            _count_str,
                            dict(ha="center", va="center", fontsize=7),
                        )
                    )
                elif show_clone_labels:
                    texts.append(
                        (
                            x,
                            y,
                            label,
                            dict(ha="center", va="center", fontsize=7),
                        )
                    )
                elif show_count_labels:
                    texts.append(
                        (
                            x,
                            y,
                            _count_str,
                            dict(ha="center", va="center", fontsize=7),
                        )
                    )

        if filled_patches:
            pc = PatchCollection(
                filled_patches,
                facecolors=filled_facecolors,
                edgecolors="white",
                linewidths=0.5,
                alpha=alpha,
                zorder=1,
            )
            ax.add_collection(pc)
        if ring_patches:
            rc = PatchCollection(
                ring_patches,
                facecolors="none",
                edgecolors=ring_edgecolors,
                linewidths=2,
                zorder=2,
            )
            ax.add_collection(rc)
        for tx, ty, ts, kw in texts:
            ax.text(tx, ty, ts, **kw)

    def _set_ax_limits(ax: Axes, circles) -> None:
        """Fit axis limits to the circles' extent plus a small margin."""
        xs = [c.x for c in circles if c.ex is not None]
        ys = [c.y for c in circles if c.ex is not None]
        rs = [c.r for c in circles if c.ex is not None]
        if xs:
            margin = 0.05
            ax.set_xlim(
                min(cx - cr for cx, cr in zip(xs, rs)) - margin,
                max(cx + cr for cx, cr in zip(xs, rs)) + margin,
            )
            ax.set_ylim(
                min(cy - cr for cy, cr in zip(ys, rs)) - margin,
                max(cy + cr for cy, cr in zip(ys, rs)) + margin,
            )
        else:
            ax.set_xlim(-1.1, 1.1)
            ax.set_ylim(-1.1, 1.1)

    def _build_legend_handles() -> list:
        """Return legend handles: a header per selected level, then its colours."""
        if show_legend is None or show_legend is True:
            _legend_levels = None
        elif isinstance(show_legend, str):
            _legend_levels = {show_legend}
        else:
            _legend_levels = set(show_legend)

        handles = []
        first_added = True
        for col, cmap, ordered_vals in zip(
            group_by_cols, level_color_maps, level_ordered_vals
        ):
            if _legend_levels is not None and col not in _legend_levels:
                continue
            if not first_added:
                # Blank spacer between level sections.
                handles.append(
                    mpatches.Patch(
                        facecolor="none", edgecolor="none", label=" "
                    )
                )
            first_added = False
            handles.append(
                mpatches.Patch(facecolor="none", edgecolor="none", label=col)
            )
            handles += [
                mpatches.Patch(color=cmap[v], label=v) for v in ordered_vals
            ]
        return handles

    if as_subplots:
        _top_col_str = data_[group_by_cols[0]].astype(str)
        top_groups = level_ordered_vals[0]
        n_subplots = len(top_groups)
        _n_col = n_col if n_col is not None else min(n_subplots, 4)
        if _n_col < 1:
            raise ValueError(f"n_col must be >= 1, got {n_col}.")
        _n_row = n_row if n_row is not None else -(-n_subplots // _n_col)
        if _n_row * _n_col < n_subplots:
            raise ValueError(
                f"A {_n_row}x{_n_col} subplot grid cannot hold {n_subplots} "
                "groups; increase n_row and/or n_col."
            )

        # Pre-build all sub-hierarchies so totals are available for scaling
        _sub_builds: list[tuple[str, list[dict], dict]] = []
        for grp_val in top_groups:
            sub_data = data_[_top_col_str == grp_val]
            color_group_lookup.clear()
            sub_hier = _build_hierarchy(
                sub_data, group_by_cols[1:], 1, (0, grp_val)
            )
            _sub_builds.append((grp_val, sub_hier, dict(color_group_lookup)))

        if scale_subplots:
            _sub_totals = [
                sum(c["datum"] for c in hier) if hier else 0
                for _, hier, _ in _sub_builds
            ]
            _max_total = max(_sub_totals) if any(_sub_totals) else 1

        fig, axes = plt.subplots(
            _n_row,
            _n_col,
            figsize=(figsize[0] * _n_col, figsize[1] * _n_row),
            squeeze=False,
        )
        axes_flat: np.ndarray = axes.flatten()

        for idx, (grp_val, sub_hierarchy, lookup) in enumerate(_sub_builds):
            ax = axes_flat[idx]
            if not sub_hierarchy:
                ax.axis("off")
                continue
            if scale_subplots:
                # Area ∝ total cells ⇒ radius ∝ sqrt(total / max_total).
                _base_r = (_sub_totals[idx] / _max_total) ** 0.5
                _enc_r = (
                    _base_r * scale_factor
                    if scale_factor is not None
                    else _base_r
                )
            else:
                _enc_r = scale_factor if scale_factor is not None else 1.0
            _target_enc = circlify.Circle(0, 0, _enc_r)
            sub_circles = _pack(sub_hierarchy, _target_enc, packer)
            ax.set_aspect("equal")
            _render_circles_on_ax(ax, sub_circles, lookup, count_groups=False)
            # Outer ring for the top-level group (mirrors single-panel behaviour)
            _outer_color = (
                outer_ring_color
                if outer_ring_color is not None
                else level_color_maps[0].get(grp_val, "gray")
            )
            ax.add_patch(
                mpatches.Circle(
                    (0, 0),
                    _enc_r,
                    fill=False,
                    edgecolor=_outer_color,
                    linewidth=2,
                )
            )
            if show_count_labels and show_enclosure_label:
                _total = sum(c["datum"] for c in sub_hierarchy)
                ax.text(
                    0,
                    -_enc_r - 0.05,
                    str(_total),
                    ha="center",
                    va="top",
                    fontsize=8,
                    color=_outer_color,
                )
            if scale_subplots or scale_factor is not None:
                # Fixed limits keep circle sizes comparable across panels.
                _margin = 0.05
                ax.set_xlim(-1.0 - _margin, 1.0 + _margin)
                ax.set_ylim(-1.0 - _margin, 1.0 + _margin)
            else:
                _set_ax_limits(ax, sub_circles)
            ax.set_title(grp_val, fontsize=10, fontweight="bold")
            ax.axis("off")

        for idx in range(n_subplots, _n_row * _n_col):
            axes_flat[idx].axis("off")

        if show_legend is not False:
            handles = _build_legend_handles()
            if handles:
                axes_flat[n_subplots - 1].legend(
                    handles=handles, **legend_kwargs
                )

        if title is not None:
            fig.suptitle(title)
        return fig, list(axes_flat[:n_subplots])

    hierarchy = _build_hierarchy(data_, group_by_cols, 0, None)
    _target_enc = circlify.Circle(
        0, 0, scale_factor if scale_factor is not None else 1
    )
    circles = _pack(hierarchy, _target_enc, packer)
    fig, ax = plt.subplots(figsize=figsize)
    ax.set_aspect("equal")
    _render_circles_on_ax(ax, circles, color_group_lookup)
    if scale_factor is not None:
        _margin = 0.05
        ax.set_xlim(-1.0 - _margin, 1.0 + _margin)
        ax.set_ylim(-1.0 - _margin, 1.0 + _margin)
    else:
        _set_ax_limits(ax, circles)

    if show_legend is not False:
        handles = _build_legend_handles()
        # Skip an empty legend box (e.g. show_legend names no group_by level),
        # matching the subplot branch.
        if handles:
            ax.legend(handles=handles, **legend_kwargs)

    if title is not None:
        ax.set_title(title)
    ax.axis("off")
    return fig, ax
@contextmanager def _temporary_obs_columns(adata: AnnData, mudata: MuData | None, **kwargs): """Temporarily add columns from submodalities or shared obs to adata.obs.""" if mudata is None: # plain AnnData, nothing to do yield kwargs return added = {} try: for key, value in kwargs.items(): if key in {"color", "size", "shape"} and isinstance(value, str): if ":" in value: # case: "mod:col" mod, col = value.split(":", 1) if mod not in mudata.mod: raise KeyError(f"MuData has no modality '{mod}'") sub = mudata.mod[mod] if col not in sub.obs.columns: raise KeyError( f"'{col}' not found in mudata.mod['{mod}'].obs" ) temp_col = f"{mod}:{col}" adata.obs[temp_col] = sub.obs[col].reindex(adata.obs.index) kwargs[key] = temp_col added[temp_col] = None else: # case: shared obs in mudata.obs if value not in mudata.obs: raise KeyError(f"'{value}' not found in mudata.obs") adata.obs[value] = mudata.obs[value].reindex( adata.obs.index ) kwargs[key] = value added[value] = None yield kwargs finally: # cleanup temporary columns for temp_col in added: adata.obs.drop(columns=temp_col, inplace=True, errors="ignore")