Source code for dandelion.plotting._plotting

#!/usr/bin/env python
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import nxviz as nxv
import pandas as pd
import seaborn as sns

from anndata import AnnData
from itertools import product, cycle
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from nxviz import annotate
from plotnine import (
    aes,
    geom_line,
    ggplot,
    ggtitle,
    labs,
    options,
    scale_color_manual,
    theme_classic,
    xlab,
    ylab,
)
from scanpy.plotting import palettes
from scanpy.plotting._tools.scatterplots import embedding
from time import sleep
from tqdm import tqdm
from typing import Callable, Literal

from dandelion.tools._diversity import rarefun
from dandelion.utilities._core import *
from dandelion.utilities._io import *
from dandelion.utilities._utilities import *


def _rarefaction_default_palette(n_groups: int) -> list[str]:
    """
    Pick a scanpy default palette with at least `n_groups` colours.

    Parameters
    ----------
    n_groups : int
        Number of distinct groups (curves) that need a colour.

    Returns
    -------
    list[str]
        `palettes.default_20`, `palettes.default_28` or `palettes.default_102`
        (the smallest that fits); for more than 102 groups, `default_102`
        repeated and truncated to exactly `n_groups` colours.
    """
    if n_groups <= 20:
        return list(palettes.default_20)
    if n_groups <= 28:
        return list(palettes.default_28)
    base = list(palettes.default_102)
    if n_groups <= 102:
        return base
    # Repeat the largest palette so every group still gets a colour; a bare
    # `cycle` object cannot be handed to `scale_color_manual`.
    return (base * (n_groups // len(base) + 1))[:n_groups]


def clone_rarefaction(
    vdj_data: AnnData | Dandelion,
    color: str,
    clone_key: str | None = None,
    palette: list[str] | None = None,
    figsize: tuple[float, float] = (5, 3),
    chain_status_include: list[
        Literal[
            "Single pair",
            "Orphan VDJ",
            "Orphan VDJ-exception",
            "Orphan VJ",
            "Orphan VJ-exception",
            "Extra pair",
            "Extra pair-exception",
        ]
    ] = [
        "Single pair",
        "Orphan VDJ",
        "Orphan VDJ-exception",
        "Extra pair",
        "Extra pair-exception",
    ],
    save: str | None = None,
) -> ggplot:
    """
    Plot rarefaction curve for cell numbers vs clone size.

    For every group in `color`, `rarefun` is evaluated at every 10th cell
    (1, 11, 21, ...) plus the group's total cell count, and the results are
    drawn as one line per group.

    Parameters
    ----------
    vdj_data : AnnData | Dandelion
        `AnnData` or `Dandelion` object.
    color : str
        Column name to split the calculation of clone numbers for a given
        number of cells for e.g. sample, patient etc.
    clone_key : str | None, optional
        Column name specifying the clone_id column in metadata/obs.
        None defaults to 'clone_id'.
    palette : list[str] | None, optional
        Colours for the unique elements in `color`. If None, the
        `.uns[color + "_colors"]` entry is used for `AnnData` input when
        present, otherwise a scanpy default palette sized to the number of
        groups.
    figsize : tuple[float, float], optional
        (width, height) of the plot in inches; also used when saving.
    chain_status_include : list[Literal["Single pair", "Orphan VDJ", "Orphan VDJ-exception", "Orphan VJ", "Orphan VJ-exception", "Extra pair", "Extra pair-exception", ]], optional
        chain statuses to include. Only applied when the metadata has a
        `chain_status` column but no `contig_QC_pass` column.
    save : str | None, optional
        If given, the plot is saved to `"figures/rarefaction" + save`.

    Returns
    -------
    ggplot
        rarefaction plot.

    Raises
    ------
    TypeError
        if `vdj_data` is neither `AnnData` nor `Dandelion`.
    ValueError
        if no group has any clone counts after filtering.
    """
    if isinstance(vdj_data, AnnData):
        metadata = vdj_data.obs.copy()
    elif isinstance(vdj_data, Dandelion):
        metadata = vdj_data.metadata.copy()
    else:
        raise TypeError("Please provide an AnnData or Dandelion object.")

    clonekey = "clone_id" if clone_key is None else clone_key
    # Groups are collected before filtering; groups emptied by the filter are
    # reported and dropped below.
    groups = list(set(metadata[color]))

    # `.copy()` so the column assignment below does not act on a view.
    if "contig_QC_pass" in metadata:
        metadata = metadata[metadata["contig_QC_pass"].isin(TRUES)].copy()
    elif "chain_status" in metadata:
        metadata = metadata[
            metadata["chain_status"].isin(chain_status_include)
        ].copy()

    # Drop clone categories that no longer occur after filtering so they do
    # not become all-zero columns in the count table.
    if isinstance(metadata[clonekey].dtype, pd.CategoricalDtype):
        metadata[clonekey] = metadata[clonekey].cat.remove_unused_categories()

    # One row per group, one column per clone, values = cells per clone.
    # Clones absent from a group come out as NaN, so fill them with 0.
    res = {
        g: metadata.loc[metadata[color] == g, clonekey].value_counts()
        for g in groups
    }
    res_ = pd.DataFrame.from_dict(res, orient="index").fillna(0)

    # remove those with no counts
    zero_rows = res_.sum(axis=1) == 0
    if zero_rows.any():
        print(
            "removing due to zero counts:",
            ", ".join(str(g) for g in res_.index[zero_rows]),
        )
        sleep(0.5)  # let the message flush before the progress bar starts
    res_ = res_[~zero_rows]
    if res_.empty:
        raise ValueError("No groups with non-zero clone counts to plot.")

    tot = res_.sum(axis=1)
    curves = []
    for i in tqdm(
        range(res_.shape[0]),
        desc="Calculating rarefaction curve ",
        bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}",
    ):
        total = tot.iloc[i]
        n = np.arange(1, total, step=10)
        # `arange` excludes the stop value, so append the group total; the
        # size check covers groups with a single cell (empty `n`).
        if n.size == 0 or n[-1] != total:
            n = np.append(n, total)
        counts = np.array(res_.iloc[i, :])
        curves.append(
            pd.DataFrame(
                {
                    "variable": res_.index[i],
                    "value": n,
                    "yhat": [rarefun(counts, z) for z in n],
                }
            )
        )
    pred = pd.concat(curves, ignore_index=True)

    options.figure_size = figsize
    if palette is not None:
        pal = palette
    elif isinstance(vdj_data, AnnData) and (
        str(color) + "_colors" in vdj_data.uns
    ):
        pal = vdj_data.uns[str(color) + "_colors"]
    else:
        pal = _rarefaction_default_palette(pred["variable"].nunique())

    p = (
        ggplot(pred, aes(x="value", y="yhat", color="variable"))
        + theme_classic()
        + xlab("number of cells")
        + ylab("number of clones")
        + ggtitle("rarefaction curve")
        + labs(color=color)
        + scale_color_manual(values=(pal))
        + geom_line()
    )
    if save:
        p.save(
            filename="figures/rarefaction" + str(save),
            width=figsize[0],
            height=figsize[1],
            units="in",
            dpi=plt.rcParams["savefig.dpi"],
        )
    return p
def clone_network(
    adata: AnnData, basis: str = "vdj", edges: bool = True, **kwargs
) -> None:
    """
    Using scanpy's plotting module to plot the network.

    Thin wrapper around scanpy's `embedding` plotting function. The only
    changes are the defaults: `basis = 'vdj'` and `edges = True`.

    Parameters
    ----------
    adata : AnnData
        AnnData object.
    basis : str, optional
        key for embedding.
    edges : bool, optional
        whether or not to plot edges.
    **kwargs
        passed `sc.pl.embedding`.
    """
    embedding(adata, basis=basis, edges=edges, **kwargs)
def barplot(
    vdj_data: AnnData | Dandelion,
    color: str,
    palette: str = "Set1",
    figsize: tuple[float, float] = (8, 3),
    normalize: bool = True,
    sort_descending: bool = True,
    title: str | None = None,
    xtick_fontsize: int | None = None,
    xtick_rotation: int | float | None = None,
    min_clone_size: int = 1,
    clone_key: str | None = None,
    **kwargs,
) -> tuple[Figure, Axes]:
    """
    A barplot function to plot usage of V/J genes in the data.

    Parameters
    ----------
    vdj_data : AnnData | Dandelion
        `Dandelion` or `AnnData` object.
    color : str
        column name in metadata for plotting in bar plot.
    palette : str, optional
        Colors to use for the different levels of the color variable.
        Should be something that can be interpreted by
        [color_palette](https://seaborn.pydata.org/generated/seaborn.color_palette.html#seaborn.color_palette),
        or a dictionary mapping hue levels to matplotlib colors.
        See [seaborn.barplot](https://seaborn.pydata.org/generated/seaborn.barplot.html).
    figsize : tuple[float, float], optional
        figure size.
    normalize : bool, optional
        if True, will return as proportion out of 1. Otherwise False will
        return counts.
    sort_descending : bool, optional
        if True, bars are ordered by decreasing value; if False, by the
        sorted level names. Applied via an explicit `order` unless one is
        given in `**kwargs`.
    title : str | None, optional
        title of plot.
    xtick_fontsize : int | None, optional
        size of x tick labels
    xtick_rotation : int | float | None, optional
        rotation of x tick labels.
    min_clone_size : int, optional
        minimum clone size to keep.
    clone_key : str | None, optional
        column name for clones. None defaults to 'clone_id'.
    **kwargs
        passed to `sns.barplot`.

    Returns
    -------
    tuple[Figure, Axes]
        bar plot.

    Raises
    ------
    TypeError
        if `vdj_data` is neither `Dandelion` nor `AnnData`.
    """
    if isinstance(vdj_data, Dandelion):
        data = vdj_data.metadata.copy()
    elif isinstance(vdj_data, AnnData):
        data = vdj_data.obs.copy()
    else:
        raise TypeError("Please provide a Dandelion or AnnData object.")

    clone_ = "clone_id" if clone_key is None else clone_key
    size = data[clone_].value_counts()
    keep = list(size[size >= min_clone_size].index)
    data_ = data[data[clone_].isin(keep)]

    sns.set_style("whitegrid", {"axes.grid": False})
    yname = "proportion" if normalize else "count"
    counts = data_[color].value_counts(normalize=normalize)
    if not sort_descending:
        counts = counts.sort_index()
    # Build the frame with explicit column names instead of relying on the
    # pandas-version-dependent names produced by `reset_index`.
    res = pd.DataFrame({color: counts.index, yname: counts.to_numpy()})

    # Initialize the matplotlib figure
    fig, ax = plt.subplots(figsize=figsize)
    # seaborn reorders categorical and numeric x values on its own, so pass
    # the computed order explicitly for `sort_descending` to take effect.
    kwargs.setdefault("order", list(counts.index))
    sns.barplot(x=color, y=yname, data=res, palette=palette, **kwargs)

    if title is None:
        ax.set_title(color.replace("_", " ") + " usage")
    else:
        ax.set_title(title)
    ax.set_ylabel(yname)
    ax.set_xlabel("")

    # modify the x ticks accordingly
    xtick_params = {"rotation": 90 if xtick_rotation is None else xtick_rotation}
    if xtick_fontsize is not None:
        xtick_params["fontsize"] = xtick_fontsize
    plt.xticks(**xtick_params)
    return fig, ax
def stackedbarplot(
    vdj_data: AnnData | Dandelion,
    color: str,
    groupby: str | None,
    figsize: tuple[float, float] = (8, 3),
    normalize: bool = False,
    title: str | None = None,
    sort_descending: bool = True,
    xtick_fontsize: int | None = None,
    xtick_rotation: int | float | None = None,
    hide_legend: bool = False,
    legend_options: tuple[str, tuple[float, float], int] = (
        "upper left",
        (1, 1),
        1,
    ),
    labels: list[str] | None = None,
    min_clone_size: int = 1,
    clone_key: str | None = None,
    **kwargs,
) -> tuple[Figure, Axes]:
    """
    A stacked bar plot function to plot usage of V/J genes in the data split by groups.

    Parameters
    ----------
    vdj_data : AnnData | Dandelion
        `Dandelion` or `AnnData` object.
    color : str
        column name in metadata for plotting in bar plot.
    groupby : str | None
        column name in metadata to split by during plotting.
    figsize : tuple[float, float], optional
        figure size.
    normalize : bool, optional
        if True, will return as proportion out of 1, otherwise False will
        return counts.
    title : str | None, optional
        title of plot.
    sort_descending : bool, optional
        True orders bars by decreasing frequency of `color` (among cells kept
        after `min_clone_size` filtering), False by increasing frequency, and
        None by the sorted level names.
    xtick_fontsize : int | None, optional
        size of x tick labels
    xtick_rotation : int | float | None, optional
        rotation of x tick labels.
    hide_legend : bool, optional
        whether or not to hide the legend.
    legend_options : tuple[str, tuple[float, float], int], optional
        a tuple holding 3 options for specify legend options: 1) loc
        (string), 2) bbox_to_anchor (tuple), 3) ncol (int).
    labels : list[str] | None, optional
        Names of objects will be used for the legend if list of multiple
        data frames supplied.
    min_clone_size : int, optional
        minimum clone size to keep.
    clone_key : str | None, optional
        column name for clones. None defaults to 'clone_id'.
    **kwargs
        other kwargs passed to `pandas.DataFrame.plot`.

    Returns
    -------
    tuple[Figure, Axes]
        stacked barplot.

    Raises
    ------
    TypeError
        if `vdj_data` is neither `Dandelion` nor `AnnData`.
    """
    if isinstance(vdj_data, Dandelion):
        data = vdj_data.metadata.copy()
    elif isinstance(vdj_data, AnnData):
        data = vdj_data.obs.copy()
    else:
        raise TypeError("Please provide a Dandelion or AnnData object.")
    # quick fix to prevent dropping of nan
    data[groupby] = [str(l) for l in data[groupby]]

    clone_ = "clone_id" if clone_key is None else clone_key
    size = data[clone_].value_counts()
    keep = list(size[size >= min_clone_size].index)
    data_ = data[data[clone_].isin(keep)]

    dat_ = pd.DataFrame(
        data_.groupby(color)[groupby]
        .value_counts(normalize=normalize)
        .unstack(fill_value=0)
        .stack(),
        columns=["value"],
    )
    dat_.reset_index(drop=False, inplace=True)
    # Order by frequency in the same filtered data that is plotted, so levels
    # present only in clones below `min_clone_size` do not become NaN rows.
    dat_order = data_[color].value_counts()
    dat_ = dat_.pivot(index=color, columns=groupby, values="value")
    if sort_descending is True:
        dat_ = dat_.reindex(dat_order.index)
    elif sort_descending is False:
        dat_ = dat_.reindex(dat_order.index[::-1])
    elif sort_descending is None:
        dat_ = dat_.sort_index()

    def _plot_bar_stacked(
        dfall: pd.DataFrame | list[pd.DataFrame],
        labels: list[str] | None = None,
        figsize: tuple[float, float] = (8, 3),
        title: str = "multiple stacked bar plot",
        xtick_fontsize: int | None = None,
        xtick_rotation: int | float | None = None,
        legend_options: tuple[str, tuple[float, float], int] | None = None,
        hide_legend: bool = False,
        H: Literal["/"] = "/",
        **kwargs,
    ) -> tuple[Figure, Axes]:
        """
        Given a list of data frames, with identical columns and index, create a clustered stacked bar plot.

        Parameters
        ----------
        dfall : pd.DataFrame | list[pd.DataFrame]
            data frame(s) for plotting.
        labels : list[str] | None, optional
            a list of the data frame objects. Names of objects will be used
            for the legend.
        figsize : tuple[float, float], optional
            size of figure.
        title : str, optional
            string for the title of the plot
        xtick_fontsize : int | None, optional
            xtick fontsize.
        xtick_rotation : int | float | None, optional
            rotation of xtick labels
        legend_options : tuple[str, tuple[float, float], int] | None, optional
            legend options; None uses ("center right", (1.15, 0.5), 1).
        hide_legend : bool, optional
            whether to hide the legend.
        H : Literal["/"], optional
            hatch pattern; the k-th data frame is drawn with `H * k`.
        **kwargs
            other kwargs passed to `pandas.DataFrame.plot`.

        Returns
        -------
        tuple[Figure, Axes]
            stacked barplot.
        """
        if not isinstance(dfall, list):
            dfall = [dfall]
        n_df = len(dfall)
        n_col = len(dfall[0].columns)
        n_ind = len(dfall[0].index)
        # Initialize the matplotlib figure
        fig, ax = plt.subplots(figsize=figsize)
        for df in dfall:
            ax = df.plot(
                kind="bar",
                linewidth=0,
                stacked=True,
                ax=ax,
                legend=False,
                grid=False,
                **kwargs,
            )
        h, l = ax.get_legend_handles_labels()
        # len(h) == n_col * n_df: shift, hatch and narrow each data frame's bars
        for i in range(0, n_df * n_col, n_col):
            for container in h[i : i + n_col]:
                for rect in container.patches:
                    rect.set_x(
                        rect.get_x() + 1 / float(n_df + 1) * i / float(n_col)
                    )
                    rect.set_hatch(H * int(i / n_col))
                    rect.set_width(1 / float(n_df + 1))
        ax.set_xticks((np.arange(0, 2 * n_ind, 2) + 1 / float(n_df + 1)) / 2.0)
        ax.set_xticklabels(dfall[-1].index, rotation=0)
        ax.set_title(title)
        ax.set_ylabel("proportion" if normalize else "count")

        # Invisible bars that carry the per-data-frame hatch for a 2nd legend
        hatch_handles = [
            ax.bar(0, 0, color="grey", hatch=H * i) for i in range(n_df)
        ]
        if legend_options is None:
            Legend = ("center right", (1.15, 0.5), 1)
        else:
            Legend = legend_options
        if not hide_legend:
            l1 = ax.legend(
                h[:n_col],
                l[:n_col],
                loc=Legend[0],
                bbox_to_anchor=Legend[1],
                ncol=Legend[2],
                frameon=False,
            )
            if labels is not None:
                l2 = plt.legend(
                    hatch_handles,
                    labels,
                    loc=Legend[0],
                    bbox_to_anchor=Legend[1],
                    ncol=Legend[2],
                    frameon=False,
                )
                ax.add_artist(l2)
            ax.add_artist(l1)

        # modify the x ticks accordingly
        xtick_params = {
            "rotation": 90 if xtick_rotation is None else xtick_rotation
        }
        if xtick_fontsize is not None:
            xtick_params["fontsize"] = xtick_fontsize
        plt.xticks(**xtick_params)
        return fig, ax

    if title is None:
        title = (
            "multiple stacked bar plot : " + color.replace("_", " ") + " usage"
        )

    return _plot_bar_stacked(
        dat_,
        labels=labels,
        figsize=figsize,
        title=title,
        xtick_fontsize=xtick_fontsize,
        xtick_rotation=xtick_rotation,
        legend_options=legend_options,
        hide_legend=hide_legend,
        **kwargs,
    )
def spectratype(
    vdj_data: Dandelion,
    color: str,
    groupby: str,
    locus: str | list[str],
    figsize: tuple[float, float] = (5, 3),
    width: int | float | None = None,
    title: str | None = None,
    xtick_fontsize: int | None = None,
    xtick_rotation: int | float | None = None,
    hide_legend: bool = False,
    legend_options: tuple[str, tuple[float, float], int] = (
        "upper left",
        (1, 1),
        1,
    ),
    labels: list[str] | None = None,
    **kwargs,
) -> tuple[Figure, Axes]:
    """
    A spectratype function to plot usage of CDR3 length.

    Parameters
    ----------
    vdj_data : Dandelion
        `Dandelion` object.
    color : str
        column name in `.data` holding the (numeric) length to plot on the
        x axis. Values that cannot be converted to numbers are ignored.
    groupby : str
        column name in `.data` to split by during plotting.
    locus : str | list[str]
        locus (e.g. IGH or IGL) or list of loci to keep.
    figsize : tuple[float, float], optional
        figure size.
    width : int | float | None, optional
        width of bars.
    title : str | None, optional
        title of plot.
    xtick_fontsize : int | None, optional
        size of x tick labels
    xtick_rotation : int | float | None, optional
        rotation of x tick labels.
    hide_legend : bool, optional
        whether or not to hide the legend.
    legend_options : tuple[str, tuple[float, float], int], optional
        a tuple holding 3 options for specify legend options: 1) loc
        (string), 2) bbox_to_anchor (tuple), 3) ncol (int).
    labels : list[str] | None, optional
        Names of objects will be used for the legend if list of multiple
        data frames supplied.
    **kwargs
        other kwargs passed to `pandas.DataFrame.plot`.

    Returns
    -------
    tuple[Figure, Axes]
        spectratype plot.

    Raises
    ------
    ValueError
        if no numeric `color` values remain after filtering.
    """
    data = vdj_data.data.copy()
    if "ambiguous" in vdj_data.data:
        data = data[data["ambiguous"] == "F"].copy()
    loci = [locus] if isinstance(locus, str) else list(locus)
    data = data[data["locus"].isin(loci)].copy()
    data[groupby] = [str(l) for l in data[groupby]]

    dat_ = pd.DataFrame(
        data.groupby(color)[groupby]
        .value_counts(normalize=False)
        .unstack(fill_value=0)
        .stack(),
        columns=["value"],
    )
    dat_.reset_index(drop=False, inplace=True)
    dat_[color] = pd.to_numeric(dat_[color], errors="coerce")
    # Non-numeric values became NaN. They would be dropped by the integer
    # reindex below anyway, but several of them would give `pivot` duplicate
    # (NaN, group) entries and make it fail, so drop them here.
    dat_ = dat_.dropna(subset=[color])
    if dat_.empty:
        raise ValueError(f"No numeric values found in column `{color}`.")
    dat_2 = dat_.pivot(index=color, columns=groupby, values="value")
    # one bar position for every integer length from 0 to the maximum
    new_index = range(0, int(dat_[color].max()) + 1)
    dat_2 = dat_2.reindex(new_index, fill_value=0)

    def _plot_spectra_stacked(
        dfall: pd.DataFrame | list[pd.DataFrame],
        labels: list[str] | None = None,
        figsize: tuple[float, float] = (5, 3),
        title: str = "multiple stacked bar plot",
        width: int | float | None = None,
        xtick_fontsize: int | None = None,
        xtick_rotation: int | float | None = None,
        legend_options: tuple[str, tuple[float, float], int] | None = None,
        hide_legend: bool = False,
        H: Literal["/"] = "/",
        **kwargs,
    ) -> tuple[Figure, Axes]:
        """Stacked spectratype plots.

        Parameters
        ----------
        dfall : pd.DataFrame | list[pd.DataFrame]
            data frame(s) for plotting.
        labels : list[str] | None, optional
            a list of the data frame objects. Names of objects will be used
            for the legend.
        figsize : tuple[float, float], optional
            size of figure.
        title : str, optional
            string for the title of the plot.
        width : int | float | None, optional
            width of bars; None uses `0.1 * n_rows / 60 + 0.8`.
        xtick_fontsize : int | None, optional
            size of x tick labels
        xtick_rotation : int | float | None, optional
            rotation of x tick labels.
        legend_options : tuple[str, tuple[float, float], int] | None, optional
            a tuple holding 3 options for specify legend options: 1) loc
            (string), 2) bbox_to_anchor (tuple), 3) ncol (int). None uses
            ("center right", (1.25, 0.5), 1).
        hide_legend : bool, optional
            whether or not to hide the legend.
        H : Literal["/"], optional
            hatch pattern; the k-th data frame is drawn with `H * k`.
        **kwargs
            other kwargs passed to `pandas.DataFrame.plot`.

        Returns
        -------
        tuple[Figure, Axes]
            spectratype plot.
        """
        if not isinstance(dfall, list):
            dfall = [dfall]
        n_df = len(dfall)
        n_col = len(dfall[0].columns)
        n_ind = len(dfall[0].index)
        wdth = 0.1 * n_ind / 60 + 0.8 if width is None else width

        # Initialize the matplotlib figure
        fig, ax = plt.subplots(figsize=figsize)
        for df in dfall:
            ax = df.plot(
                kind="bar",
                linewidth=0,
                stacked=True,
                ax=ax,
                legend=False,
                grid=False,
                **kwargs,
            )
        h, l = ax.get_legend_handles_labels()
        # len(h) == n_col * n_df: shift, hatch and resize each data frame's bars
        for i in range(0, n_df * n_col, n_col):
            for container in h[i : i + n_col]:
                for rect in container.patches:
                    rect.set_x(
                        rect.get_x() + 1 / float(n_df + 1) * i / float(n_col)
                    )
                    rect.set_hatch(H * int(i / n_col))
                    rect.set_width(wdth)

        # Keep every 5th x tick label visible and hide the rest.
        tick_step = 5
        for idx, ticklabel in enumerate(ax.xaxis.get_ticklabels()):
            if idx % tick_step != 0:
                ticklabel.set_visible(False)
        ax.set_title(title)
        ax.set_ylabel("count")

        # Invisible bars that carry the per-data-frame hatch for a 2nd legend
        hatch_handles = [
            ax.bar(0, 0, color="gray", hatch=H * i) for i in range(n_df)
        ]
        if legend_options is None:
            Legend = ("center right", (1.25, 0.5), 1)
        else:
            Legend = legend_options
        if not hide_legend:
            l1 = ax.legend(
                h[:n_col],
                l[:n_col],
                loc=Legend[0],
                bbox_to_anchor=Legend[1],
                ncol=Legend[2],
                frameon=False,
            )
            if labels is not None:
                plt.legend(
                    hatch_handles,
                    labels,
                    loc=Legend[0],
                    bbox_to_anchor=Legend[1],
                    ncol=Legend[2],
                    frameon=False,
                )
            # re-attach the first legend in case `plt.legend` replaced it
            ax.add_artist(l1)

        # modify the x ticks accordingly
        xtick_params = {
            "rotation": 90 if xtick_rotation is None else xtick_rotation
        }
        if xtick_fontsize is not None:
            xtick_params["fontsize"] = xtick_fontsize
        plt.xticks(**xtick_params)
        return fig, ax

    return _plot_spectra_stacked(
        dat_2,
        labels=labels,
        figsize=figsize,
        title=title,
        width=width,
        xtick_fontsize=xtick_fontsize,
        xtick_rotation=xtick_rotation,
        legend_options=legend_options,
        hide_legend=hide_legend,
        **kwargs,
    )
def clone_overlap(
    adata: AnnData,
    groupby: str,
    colorby: str | None = None,
    weighted_overlap: bool = False,
    clone_key: str | None = None,
    color_mapping: list | dict | None = None,
    node_labels: bool = True,
    return_graph: bool = False,
    save: str | None = None,
    legend_kwargs: dict | None = None,
    node_label_size: int = 10,
    as_heatmap: bool = False,
    return_heatmap_data: bool = False,
    scale_edge_lambda: Callable | None = None,
    **kwargs,
) -> "nx.Graph | pd.DataFrame | None":
    """
    A plot function to visualise clonal overlap as a circos-style plot.

    Originally written with nxviz < 0.7.3. Ported from https://github.com/zktuong/nxviz/tree/custom_color_mapping_circos_nodes_and_edges

    Parameters
    ----------
    adata : AnnData
        `AnnData` object.
    groupby : str
        column name in obs for collapsing to nodes in circos plot.
    colorby : str | None, optional
        column name in obs for grouping and color of nodes in plot. Must be
        a same or subcategory of the `groupby` categories e.g.
        `groupby="group_tissue", colorby="tissue"`.
    weighted_overlap : bool, optional
        if True, instead of collapsing to overlap to binary, edge thickness
        will reflect the number of cells found in the overlap. In the
        future, there will be the option to use something like a jaccard
        index instead.
    clone_key : str | None, optional
        column name for clones. None defaults to 'clone_id'.
    color_mapping : list | dict | None, optional
        custom color mapping provided as a sequence (corresponding to order
        of categories or alpha-numeric order if dtype is not category), or
        dictionary containing custom {category:color} mapping.
    node_labels : bool, optional
        whether to use node objects as labels or not
    return_graph : bool, optional
        whether or not to return the graph (circos mode only).
    save : str | None, optional
        file path for saving plot (circos mode only).
    legend_kwargs : dict | None, optional
        options for adjusting legend placement. None uses
        `{"ncol": 2, "bbox_to_anchor": (1, 0.5), "frameon": False,
        "loc": "center left"}`.
    node_label_size : int, optional
        size of labels if node_labels = True
    as_heatmap : bool, optional
        whether to draw a `sns.clustermap` of the adjacency matrix instead of
        a circos plot.
    return_heatmap_data : bool, optional
        whether to return heatmap data as a pandas dataframe.
    scale_edge_lambda : Callable | None, optional
        a lambda function to scale the edge thickness. If None, will not
        scale.
    **kwargs
        passed to `sns.clustermap` when `as_heatmap=True`, otherwise to
        `matplotlib.pyplot.savefig`.

    Returns
    -------
    nx.Graph | pd.DataFrame | None
        the graph if `return_graph=True` (circos mode), the adjacency data
        frame if `as_heatmap=True` and `return_heatmap_data=True`, otherwise
        None.

    Raises
    ------
    KeyError
        if `clone_overlap` not found in `adata.uns`.
    ValueError
        if input is not `AnnData`.
    """
    from collections import defaultdict

    if legend_kwargs is None:
        legend_kwargs = {
            "ncol": 2,
            "bbox_to_anchor": (1, 0.5),
            "frameon": False,
            "loc": "center left",
        }
    clone_ = "clone_id" if clone_key is None else clone_key

    if not isinstance(adata, AnnData):
        raise ValueError("Please provide a AnnData object.")
    data = adata.obs.copy()
    if "clone_overlap" not in adata.uns:
        raise KeyError(
            "`clone_overlap` not found in `adata.uns`. Did you run `tl.clone_overlap`?"
        )
    overlap = adata.uns["clone_overlap"].copy()

    def _groups_with(clone):
        # groups (columns of the overlap table) in which `clone` has a value > 0
        row = overlap.loc[clone]
        return list(row[row > 0].index)

    # edges[clone] = list of (group_a, group_b, attrs) tuples, one per ordered
    # pair of groups containing that clone (self pairs included).
    edges = {}
    if not weighted_overlap:
        for x in overlap.index:
            if overlap.loc[x].sum() > 1:
                edges[x] = [
                    pair + ({str(clone_): x},)
                    for pair in product(_groups_with(x), repeat=2)
                ]
    else:
        n_groups_per_clone = overlap.astype(bool).sum(axis=1)
        combis = {
            x: list(product(_groups_with(x), repeat=2))
            for x in n_groups_per_clone.index
            if n_groups_per_clone.loc[x] > 1
        }
        edge_weights = defaultdict(list)
        for k_clone, pairs in combis.items():
            for pair in pairs:
                edge_weights[pair].append(
                    overlap.loc[k_clone, list(pair)].sum()
                )
        for pair in edge_weights:
            total = sum(edge_weights[pair])
            edge_weights[pair] = (
                scale_edge_lambda(total)
                if scale_edge_lambda is not None
                else total
            )
        for x in overlap.index:
            if overlap.loc[x].sum() > 1:
                edges[x] = [
                    pair
                    + (
                        {
                            str(clone_): x,
                            # pairs never seen among shared clones weigh 0
                            "weight": (
                                edge_weights[pair]
                                if pair in edge_weights
                                and not isinstance(edge_weights[pair], list)
                                else 0
                            ),
                        },
                    )
                    for pair in product(_groups_with(x), repeat=2)
                ]

    colorby = groupby if colorby is None else colorby
    # create graph: one node per `groupby` value, tagged with its `colorby` value
    G = nx.Graph()
    G.add_nodes_from(
        [(p, {str(colorby): d}) for p, d in zip(data[groupby], data[colorby])]
    )
    for edge_list in edges.values():
        G.add_edges_from(edge_list)
    weighted_attr = "weight" if weighted_overlap else None

    def _is_categorical(series: pd.Series) -> bool:
        return isinstance(series.dtype, pd.CategoricalDtype)

    # The colours are keyed on `colorby` levels, so the categorical check
    # must be done on the `colorby` column as well.
    if color_mapping is None:
        uns_key = str(colorby) + "_colors"
        if uns_key in adata.uns:
            if _is_categorical(adata.obs[str(colorby)]):
                levels = list(adata.obs[str(colorby)].cat.categories)
            else:
                levels = list(adata.obs[str(colorby)].unique())
            colorby_dict = dict(zip(levels, adata.uns[uns_key]))
        else:
            n_levels = len(adata.obs[str(colorby)].unique())
            if n_levels <= 20:
                pal = cycle(palettes.default_20)
            elif n_levels <= 28:
                pal = cycle(palettes.default_28)
            else:
                pal = cycle(palettes.default_102)
            colorby_dict = dict(zip(list(adata.obs[str(colorby)].unique()), pal))
    elif isinstance(color_mapping, dict):
        colorby_dict = color_mapping
    else:
        if _is_categorical(data[str(colorby)]):
            levels = list(data[str(colorby)].cat.categories)
        else:
            levels = sorted(set(data[str(colorby)]))
        colorby_dict = dict(zip(levels, color_mapping))

    if as_heatmap:
        hm = nx.to_pandas_adjacency(G)
        sns.clustermap(hm, **kwargs)
        if return_heatmap_data:
            return hm
    else:
        # remove self loops
        G.remove_edges_from(nx.selfloop_edges(G))
        nxv.circos(
            G,
            group_by=colorby,
            node_color_by=colorby,
            edge_lw_by=weighted_attr,
            node_palette=colorby_dict,
        )
        if node_labels:
            annotate.circos_group(
                G,
                group_by=colorby,
                midpoint=False,
                fontdict={"size": node_label_size},
            )
        annotate.node_colormapping(
            G,
            color_by=colorby,
            palette=colorby_dict,
            legend_kwargs=legend_kwargs,
        )
        if save is not None:
            plt.savefig(save, bbox_inches="tight", **kwargs)
        if return_graph:
            return G
def productive_ratio(
    adata: AnnData,
    figsize: tuple[float, float] = (8, 4),
    palette: list[str] = ["lightblue", "darkblue"],
    fontsize: int | float = 8,
    rotation: int | float = 90,
    legend_kwargs: dict = {
        "bbox_to_anchor": (1, 0.5),
        "loc": "center left",
        "frameon": False,
    },
):
    """Plot productive/non-productive contig ratio from AnnData (cell level).

    Two bar plots are overlaid on one axes: first the
    `productive+non-productive` column in `palette[0]`, then the
    `productive` column in `palette[1]` on top of it. Values are taken from
    `adata.uns['productive_ratio']`, and the y axis is fixed to 0-100.

    Parameters
    ----------
    adata : AnnData
        AnnData object with `.uns['productive_ratio']` computed from
        `tl.productive_ratio` (keys read: 'results', 'locus', 'groupby').
    figsize : tuple[float, float], optional
        Size of figure.
    palette : list[str], optional
        Colours for the non-productive and productive bars, respectively
        (only the first two entries are used).
    fontsize : int | float, optional
        Font size of x and y tick labels.
    rotation : int | float, optional
        Rotation of x tick labels.
    legend_kwargs : dict, optional
        Any additional kwargs to `plt.legend`.
    """
    summary = adata.uns["productive_ratio"]
    table = summary["results"]
    locus_name = summary["locus"]
    x_col = summary["groupby"]
    colour_nonprod = palette[0]
    colour_prod = palette[1]

    plt.figure(figsize=figsize)
    # Draw the total first; the productive bar is then drawn over it on the
    # same axes.
    axis = sns.barplot(
        x=x_col, y="productive+non-productive", data=table, color=colour_nonprod
    )
    axis = sns.barplot(
        x=x_col, y="productive", data=table, color=colour_prod, ax=axis
    )

    plt.xticks(fontsize=fontsize, rotation=rotation)
    plt.yticks(fontsize=fontsize)
    axis.set_xlabel("")
    axis.set_ylabel("")
    axis.set(ylim=(0, 100))
    axis.set_title(locus_name)

    # add legend
    legend_handles = [
        mpatches.Patch(
            color=colour_nonprod, label="% with non-productive " + locus_name
        ),
        mpatches.Patch(
            color=colour_prod, label="% with productive " + locus_name
        ),
    ]
    axis.legend(handles=legend_handles, **legend_kwargs)