Source code for dandelion.external.immcantation.polars.shazam

import os
import re
import functools
import warnings

import pandas as pd  # used only for rpy2 interop
import polars as pl
import numpy as np

from scanpy import logging as logg
from typing import Literal

from plotnine import (
    ggplot,
    options,
    aes,
    xlab,
    ylab,
    facet_wrap,
    theme,
    annotate,
    theme_bw,
    geom_histogram,
    geom_vline,
    save_as_pdf_pages,
)

from dandelion.polars.core._core import (
    DandelionPolars,
    load_polars,
    _sanitize_data_polars,
    write_airr,
    SCHEMA_OVERRIDES,
)


def quantify_mutations(
    data: DandelionPolars | pl.DataFrame | pl.LazyFrame | str,
    split_locus: bool = False,
    sequence_column: str | None = None,
    germline_column: str | None = None,
    region_definition: str | None = None,
    mutation_definition: str | None = None,
    frequency: bool = False,
    combine: bool = True,
    existing_columns: Literal["overwrite", "suffix"] = "overwrite",
    existing_column_suffix: str = "_new",
    **kwargs,
) -> pl.DataFrame | None:
    """
    Run basic mutation load analysis.

    Implemented in `shazam` https://shazam.readthedocs.io/en/stable/vignettes/Mutation-Vignette.

    Only contigs whose ``ambiguous`` column equals ``"F"`` (when that column
    exists) are sent to shazam, and every column is passed to R as a string.

    Parameters
    ----------
    data : DandelionPolars | pl.DataFrame | pl.LazyFrame | str
        Dandelion object, Polars data frame, or file path to AIRR file.
    split_locus : bool, optional
        whether to run heavy chain (IGH) and light chain (IGK/IGL) contigs
        separately. Per-cell summaries are then reported per locus, with the
        locus name appended to each column name.
    sequence_column : str | None, optional
        passed to shazam's `observedMutations`; defaults to
        ``"sequence_alignment"``.
        https://shazam.readthedocs.io/en/stable/topics/observedMutations
    germline_column : str | None, optional
        passed to shazam's `observedMutations`; defaults to
        ``"germline_alignment_d_mask"``.
        https://shazam.readthedocs.io/en/stable/topics/observedMutations
    region_definition : str | None, optional
        name of an R object (looked up with ``base::get``) passed to shazam's
        `observedMutations`.
        https://shazam.readthedocs.io/en/stable/topics/IMGT_SCHEMES/
    mutation_definition : str | None, optional
        name of an R object (looked up with ``base::get``) passed to shazam's
        `observedMutations`.
        https://shazam.readthedocs.io/en/stable/topics/MUTATION_SCHEMES/
    frequency : bool, optional
        whether to return the results as frequencies rather than counts.
    combine : bool, optional
        whether to return one aggregate mutation value per sequence (True) or
        separate values for replacement and silent mutations (False).
    existing_columns : Literal["overwrite", "suffix"], optional
        how to handle mutation output columns that already exist in input data.
        - "overwrite": replace existing columns with freshly computed values.
        - "suffix": preserve existing columns and append new values with
          ``existing_column_suffix`` (repeated until the name is unused).
    existing_column_suffix : str, optional
        suffix used when ``existing_columns='suffix'``.
    **kwargs
        passed to shazam::observedMutations.

    Returns
    -------
    pl.DataFrame | None
        - DandelionPolars input: ``data._data`` gains the mutation columns
          (left-joined on ``sequence_id``) and ``data._metadata`` gains per-cell
          sums; None is returned.
        - Polars DataFrame/LazyFrame input: the filtered contig table with the
          mutation columns joined on ``sequence_id``.
        - Path to an existing file: that same table, which is also written
          back to the path with ``write_airr``.
        - Anything else: None.

    Raises
    ------
    ValueError
        if ``existing_columns`` is not "overwrite"/"suffix", or if
        ``existing_column_suffix`` is empty in "suffix" mode.
    ImportError
        if rpy2 cannot be imported.
    """
    # Validate options up front so bad arguments fail before R is started.
    if existing_columns not in {"overwrite", "suffix"}:
        raise ValueError(
            "existing_columns must be one of {'overwrite', 'suffix'}."
        )
    if existing_columns == "suffix" and existing_column_suffix == "":
        raise ValueError(
            "existing_column_suffix cannot be empty when existing_columns='suffix'."
        )

    start = logg.info("Quantifying mutations")
    try:
        from rpy2.robjects.packages import importr
        from rpy2.rinterface import NULL
    except Exception as err:
        raise ImportError(
            "Unable to initialise R instance. Please run this separately through R with shazam's tutorials."
        ) from err

    sh = importr("shazam")
    base = importr("base")

    # Load input as Polars
    if isinstance(data, DandelionPolars):
        dat = load_polars(data._data)
    else:
        dat = load_polars(data)

    # Silence warnings for the rest of this call only. A bare
    # ``warnings.filterwarnings("ignore")`` would disable every warning for
    # the remainder of the Python session.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        dat = _sanitize_data_polars(dat)
        if isinstance(dat, pl.LazyFrame):
            dat = dat.collect(engine="streaming")
        # Only unambiguous contigs are sent to shazam.
        if "ambiguous" in dat.columns:
            dat_ = dat.filter(pl.col("ambiguous") == "F")
        else:
            dat_ = dat

        observed_kwargs = dict(
            sequenceColumn=(
                "sequence_alignment"
                if sequence_column is None
                else sequence_column
            ),
            germlineColumn=(
                "germline_alignment_d_mask"
                if germline_column is None
                else germline_column
            ),
            regionDefinition=(
                NULL
                if region_definition is None
                else base.get(region_definition)
            ),
            mutationDefinition=(
                NULL
                if mutation_definition is None
                else base.get(mutation_definition)
            ),
            frequency=frequency,
            combine=combine,
            **kwargs,
        )
        pd_df = _observed_mutations_frame(
            sh, dat_, split_locus, observed_kwargs
        )

        # Columns added by shazam are those absent from the input. If none
        # are new (presumably because the input already carries mutation
        # columns from an earlier run), fall back to every ``mu_`` column.
        input_cols = set(dat_.columns)
        cols_to_return = [c for c in pd_df.columns if c not in input_cols]
        if not cols_to_return:
            mu_pattern = re.compile("mu_.*")
            cols_to_return = [c for c in pd_df.columns if mu_pattern.match(c)]

        r_out_pl = _r_mutations_to_polars(pd_df)

        if isinstance(data, DandelionPolars):
            # Append new columns to data._data via sequence_id join
            base_df = data._data
            if isinstance(base_df, pl.LazyFrame):
                base_df = base_df.collect(engine="streaming")
            data._data, cols_to_return = _join_mutation_columns(
                base_df,
                r_out_pl,
                cols_to_return,
                existing_columns,
                existing_column_suffix,
            )
            metadata_ = _summarise_mutations_per_cell(
                data._data, cols_to_return, split_locus
            )
            data._metadata = _merge_cell_metadata(
                data._metadata,
                metadata_,
                existing_columns,
                existing_column_suffix,
            )
            logg.info(
                " finished",
                time=start,
                deep=(
                    "Updated Dandelion object: \n"
                    " 'data', contig-indexed AIRR table\n"
                    " 'metadata', cell-indexed observations table\n"
                ),
            )
            return None

        # Merge results back into the (filtered) Polars table and return/save
        out_df, _ = _join_mutation_columns(
            dat_,
            r_out_pl,
            cols_to_return,
            existing_columns,
            existing_column_suffix,
        )
        if isinstance(data, (pl.DataFrame, pl.LazyFrame)):
            logg.info(
                " finished", time=start, deep=("Returning Polars DataFrame\n")
            )
            return out_df
        if os.path.isfile(str(data)):
            logg.info(
                " finished",
                time=start,
                deep=("saving DataFrame at {}\n".format(str(data))),
            )
            write_airr(out_df, data)
            return out_df
    return None


def _observed_mutations_frame(
    sh, contigs: pl.DataFrame, split_locus: bool, observed_kwargs: dict
) -> pd.DataFrame:
    """Run shazam::observedMutations on ``contigs`` and return pandas output.

    With ``split_locus`` the IGH contigs and the IGK/IGL contigs are run
    separately and concatenated; contigs on any other locus are not sent to
    R in that mode.
    """

    def _run(frame: pl.DataFrame) -> pd.DataFrame:
        # pandas is used only as the rpy2 bridge; every column goes as string.
        frame_r = safe_py2rpy(
            frame.with_columns(pl.col("*").cast(pl.String)).to_pandas()
        )
        return safe_rpy2py(sh.observedMutations(frame_r, **observed_kwargs))

    if not split_locus:
        return _run(contigs)
    heavy = _run(contigs.filter(pl.col("locus") == "IGH"))
    light = _run(contigs.filter(pl.col("locus").is_in(["IGK", "IGL"])))
    return pd.concat([heavy, light], ignore_index=True)


def _is_r_na_sentinel(value: object) -> bool:
    """True for rpy2 NA singletons (``NACharacterType``, ``NAIntegerType``, ...)
    and pandas ``NAType``, all of which PyArrow cannot ingest."""
    name = type(value).__name__
    return name.startswith("NA") and name.endswith("Type")


def _is_missing_scalar(value: object) -> bool:
    """True for None and float NaN."""
    return value is None or (isinstance(value, float) and np.isnan(value))


def _r_mutations_to_polars(pd_df: pd.DataFrame) -> pl.DataFrame:
    """Convert shazam's pandas output to Polars, normalising missing values.

    Object columns first have NA sentinels replaced by None. Then, if a
    column mixes ``str`` with other types (``pd.concat`` can produce this when
    ``split_locus=True``), non-missing values are coerced to ``str`` so
    PyArrow accepts them. Missing values stay missing rather than becoming
    the literal string "nan".
    """
    pd_df = pd_df.copy()
    for col in pd_df.columns:
        if pd_df[col].dtype != object:
            continue
        pd_df[col] = pd_df[col].apply(
            lambda x: None if _is_r_na_sentinel(x) else x
        )
        types_seen = {type(v) for v in pd_df[col].dropna()}
        if len(types_seen) > 1 and str in types_seen:
            pd_df[col] = pd_df[col].apply(
                lambda x: None if _is_missing_scalar(x) else str(x)
            )
    return pl.from_pandas(pd_df, schema_overrides=SCHEMA_OVERRIDES)


def _resolve_column_conflicts(
    base_df: pl.DataFrame,
    incoming_cols: list[str],
    existing_columns: str,
    suffix: str,
) -> tuple[pl.DataFrame, dict[str, str]]:
    """Handle ``incoming_cols`` that already exist in ``base_df``.

    "overwrite" drops the clashing columns from ``base_df``. "suffix" keeps
    them and returns a rename map for the incoming columns, appending
    ``suffix`` until the name is free in both ``base_df`` and
    ``incoming_cols``.
    """
    overlap = [c for c in incoming_cols if c in base_df.columns]
    rename_map: dict[str, str] = {}
    if not overlap:
        return base_df, rename_map
    if existing_columns == "overwrite":
        return base_df.drop(overlap), rename_map
    taken = set(base_df.columns) | set(incoming_cols)
    for col in overlap:
        candidate = f"{col}{suffix}"
        while candidate in taken or candidate in rename_map.values():
            candidate = f"{candidate}{suffix}"
        rename_map[col] = candidate
    return base_df, rename_map


def _join_mutation_columns(
    base_df: pl.DataFrame,
    r_out_pl: pl.DataFrame,
    cols: list[str],
    existing_columns: str,
    suffix: str,
) -> tuple[pl.DataFrame, list[str]]:
    """Left-join ``cols`` of ``r_out_pl`` onto ``base_df`` by ``sequence_id``.

    Returns the joined frame and the (possibly renamed) output column names.
    """
    base_df, rename_map = _resolve_column_conflicts(
        base_df, cols, existing_columns, suffix
    )
    # Select before renaming so suffixed names cannot clash with unrelated
    # columns carried in the R output.
    add_df = r_out_pl.select(["sequence_id", *cols]).rename(rename_map)
    cols = [rename_map.get(c, c) for c in cols]
    return base_df.join(add_df, on="sequence_id", how="left"), cols


def _summarise_mutations_per_cell(
    contig_df: pl.DataFrame, cols: list[str], split_locus: bool
) -> pl.DataFrame:
    """Sum ``cols`` (cast to Float64) per ``cell_id``.

    With ``split_locus`` the sums are computed per locus, and each column is
    renamed ``f"{col}_{locus}"``. The per-locus tables are then outer-joined
    on ``cell_id``, in sorted locus order.
    """
    as_float = [pl.col(c).cast(pl.Float64) for c in cols]
    if not split_locus:
        return (
            contig_df.select(["cell_id", *cols])
            .with_columns(as_float)
            .group_by("cell_id")
            .sum()
        )
    grouped = (
        contig_df.select(["locus", "cell_id", *cols])
        .with_columns(as_float)
        .group_by(["locus", "cell_id"])
        .sum()
    )
    loci = grouped.get_column("locus").unique().sort().to_list()
    per_locus = [
        grouped.filter(pl.col("locus") == loc)
        .drop("locus")
        .rename({c: f"{c}_{loc}" for c in cols})
        for loc in loci
    ]
    if not per_locus:
        return pl.DataFrame({"cell_id": []})
    return functools.reduce(
        lambda left, right: left.join(
            right, on="cell_id", how="full", coalesce=True
        ),
        per_locus,
    )


def _merge_cell_metadata(
    existing_metadata: pl.DataFrame | pl.LazyFrame | None,
    metadata_: pl.DataFrame,
    existing_columns: str,
    suffix: str,
) -> pl.DataFrame:
    """Full-join the new per-cell summary onto any existing metadata.

    The new summary replaces the metadata outright when there is none or
    when it lacks a ``cell_id`` column.
    """
    if existing_metadata is None:
        return metadata_
    if isinstance(existing_metadata, pl.LazyFrame):
        existing_metadata = existing_metadata.collect(engine="streaming")
    if "cell_id" not in existing_metadata.columns:
        return metadata_
    meta_cols = [c for c in metadata_.columns if c != "cell_id"]
    existing_metadata, rename_meta = _resolve_column_conflicts(
        existing_metadata, meta_cols, existing_columns, suffix
    )
    metadata_ = metadata_.rename(rename_meta)
    return existing_metadata.join(
        metadata_, on="cell_id", how="full", coalesce=True
    )
def calculate_threshold(
    data: DandelionPolars | pl.DataFrame | pl.LazyFrame | str,
    mode: Literal["single-cell", "heavy"] = "single-cell",
    manual_threshold: float | None = None,
    VJthenLen: bool = False,
    model: (
        Literal[
            "ham",
            "aa",
            "hh_s1f",
            "hh_s5f",
            "mk_rs1nf",
            "hs1f_compat",
            "m1n_compat",
        ]
        | None
    ) = None,
    normalize_method: Literal["len"] | None = None,
    threshold_method: Literal["gmm", "density"] | None = None,
    edge: float | None = None,
    cross: list[float] | None = None,
    subsample: int | None = None,
    threshold_model: (
        Literal["norm-norm", "norm-gamma", "gamma-norm", "gamma-gamma"] | None
    ) = None,
    cutoff: Literal["optimal", "intersect", "user"] | None = None,
    sensitivity: float | None = None,
    specificity: float | None = None,
    plot: bool = True,
    plot_group: str | None = None,
    figsize: tuple[float, float] = (4.5, 2.5),
    save_plot: str | None = None,
    n_cpus: int = 1,
    **kwargs,
) -> float:
    """
    Calculate nearest neighbor distances for tuning clonal assignment with `shazam`.

    https://shazam.readthedocs.io/en/stable/vignettes/DistToNearest-Vignette/

    Runs the following:

    distToNearest
        Get non-zero distance of every heavy chain (IGH) sequence (as defined by sequenceColumn) to its nearest sequence
        in a partition of heavy chains sharing the same V gene, J gene, and junction length (VJL), or in a partition of
        single cells with heavy chains sharing the same heavy chain VJL combination, or of single cells with heavy and
        light chains sharing the same heavy chain VJL and light chain VJL combinations.
    findThreshold
        automatically determines an optimal threshold for clonal assignment of Ig sequences using a vector of nearest
        neighbor distances. It provides two alternative methods using either a Gamma/Gaussian Mixture Model fit
        (threshold_method="gmm") or kernel density fit (threshold_method="density").

    Parameters
    ----------
    data : DandelionPolars | pl.DataFrame | pl.LazyFrame | str
        input `Dandelion`, AIRR data as Polars data frame, or path to file.
    mode : Literal["single-cell", "heavy"], optional
        accepts one of "heavy" or "single-cell". If single-cell mode fails, the
        function logs a message and falls back to heavy mode.
        Refer to https://shazam.readthedocs.io/en/stable/vignettes/DistToNearest-Vignette.
    manual_threshold : float | None, optional
        threshold to use (and plot) instead of the automatically detected one.
    VJthenLen : bool, optional
        logical value specifying whether to perform partitioning as a 2-stage process.
        If True, partitions are made first based on V and J gene, and then further split
        based on junction lengths corresponding to sequenceColumn.
        If False, perform partition as a 1-stage process during which V gene, J gene, and junction length
        are used to create partitions simultaneously.
        Only used in single-cell mode. Defaults to False.
    model : Literal["ham", "aa", "hh_s1f", "hh_s5f", "mk_rs1nf", "hs1f_compat", "m1n_compat", ] | None, optional
        underlying SHM model, which must be one of "ham","aa","hh_s1f","hh_s5f","mk_rs1nf","hs1f_compat","m1n_compat".
        Defaults to "ham".
    normalize_method : Literal["len"] | None, optional
        method of normalization. The default is "len", which divides the distance by the length
        of the sequence group. shazam also accepts "none" for no normalization.
    threshold_method : Literal["gmm", "density"] | None, optional
        string defining the method to use for determining the optimal threshold. One of "gmm" or "density".
        Defaults to "density"; if density fitting returns no value, "gmm" is tried.
    edge : float | None, optional
        upper range as a fraction of the data density to rule initialization of Gaussian fit parameters.
        Default value is 0.9 (or 90). Applies only when threshold_method="density".
    cross : list[float] | None, optional
        supplementary nearest neighbor distance vector output from distToNearest for initialization of the Gaussian fit
        parameters. Applies only when method="gmm".
    subsample : int | None, optional
        maximum number of distances to subsample to before threshold detection.
    threshold_model : Literal["norm-norm", "norm-gamma", "gamma-norm", "gamma-gamma"] | None, optional
        allows the user to choose among four possible combinations of fitting curves: "norm-norm", "norm-gamma",
        "gamma-norm", and "gamma-gamma". Applies only when method="gmm". Defaults to "gamma-gamma".
    cutoff : Literal["optimal", "intersect", "user"] | None, optional
        method to use for threshold selection: the optimal threshold "optimal", the intersection point of the two fitted
        curves "intersect", or a value defined by user for one of the sensitivity or specificity "user". Applies only
        when method="gmm". Defaults to "optimal".
    sensitivity : float | None, optional
        sensitivity required. Applies only when method="gmm" and cutoff="user".
    specificity : float | None, optional
        specificity required. Applies only when method="gmm" and cutoff="user".
    plot : bool, optional
        whether or not to show a histogram of the distances with the threshold marked.
    plot_group : str | None, optional
        determines the fill color and facets. Defaults to "sample_id" when that column is present.
    figsize : tuple[float, float], optional
        size of plot.
    save_plot : str | None, optional
        if specified, plot will be saved as PDF to this path.
    n_cpus : int, optional
        number of cpus to run `distToNearest` (passed as ``nproc``). defaults to 1.
    **kwargs
        passed to shazam's `distToNearest <https://shazam.readthedocs.io/en/stable/topics/distToNearest/>`__.

    Returns
    -------
    float
        threshold value for clonal assignment in DefineClones.

    Raises
    ------
    ValueError
        if ``mode`` is not "heavy"/"single-cell", or if automatic thresholding failed.
    ImportError
        if rpy2 cannot be imported.
    """
    # Previously an unknown mode fell through and crashed with UnboundLocalError.
    if mode not in ("heavy", "single-cell"):
        raise ValueError(
            "mode must be one of 'heavy' or 'single-cell'; got {!r}.".format(
                mode
            )
        )
    logg.info("Calculating threshold")
    try:
        from rpy2.robjects.packages import importr
        from rpy2.rinterface import NULL
    except Exception as err:
        raise ImportError(
            "Unable to initialise R instance. Please run this separately through R with shazam's tutorials."
        ) from err

    # Load input as Polars
    if isinstance(data, DandelionPolars):
        dat = load_polars(data._data)
    else:
        dat = load_polars(data)

    # Suppress warnings for this call only, instead of permanently changing
    # the interpreter-wide warning filters.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        sh = importr("shazam")

        dat_cols = (
            dat.collect_schema().names()
            if isinstance(dat, pl.LazyFrame)
            else dat.columns
        )
        # Prefer genotyped V calls when present.
        v_call = (
            "v_call_genotyped" if "v_call_genotyped" in dat_cols else "v_call"
        )
        model_ = "ham" if model is None else model
        norm_ = "len" if normalize_method is None else normalize_method
        threshold_method_ = (
            "density" if threshold_method is None else threshold_method
        )
        subsample_ = NULL if subsample is None else subsample

        # Sanitize using Polars; convert to pandas only for rpy2
        dat_pl = _sanitize_data_polars(dat)
        if isinstance(dat_pl, pl.LazyFrame):
            dat_pl = dat_pl.collect(engine="streaming")

        # Heavy-only calls honour n_cpus too; an explicit ``nproc`` in kwargs wins.
        heavy_kwargs = dict(kwargs)
        heavy_kwargs.setdefault("nproc", n_cpus)

        if mode == "heavy":
            dist_ham = _heavy_dist_to_nearest(
                sh, dat_pl, v_call, model_, norm_, heavy_kwargs
            )
        else:
            dat_r = safe_py2rpy(
                dat_pl.with_columns(pl.col("*").cast(pl.String)).to_pandas()
            )
            try:
                dist_ham = sh.distToNearest(
                    dat_r,
                    cellIdColumn="cell_id",
                    locusColumn="locus",
                    VJthenLen=VJthenLen,
                    vCallColumn=v_call,
                    normalize=norm_,
                    model=model_,
                    nproc=n_cpus,
                    **kwargs,
                )
            except Exception:
                # Deliberate best-effort fallback to heavy-chain-only distances.
                logg.info(
                    "Rerun this after filtering. For now, switching to heavy mode."
                )
                dist_ham = _heavy_dist_to_nearest(
                    sh,
                    dat_pl,
                    v_call,
                    model_,
                    norm_,
                    heavy_kwargs,
                    drop_cell_id=True,
                )
        dist_ham = safe_rpy2py(dist_ham)

        if manual_threshold is None:
            tr = _automatic_threshold(
                sh,
                dist_ham,
                threshold_method_,
                edge,
                subsample_,
                threshold_model,
                cross,
                cutoff,
                sensitivity,
                specificity,
            )
        else:
            tr = manual_threshold

        if plot:
            _plot_nearest_distances(
                dist_ham, tr, plot_group, figsize, save_plot
            )
    return tr


def _heavy_dist_to_nearest(
    sh,
    contigs: pl.DataFrame,
    v_call: str,
    model_: str,
    norm_: str,
    dist_kwargs: dict,
    drop_cell_id: bool = False,
):
    """Run shazam::distToNearest on IGH/TRB/TRD contigs only."""
    heavy = contigs.filter(pl.col("locus").is_in(["IGH", "TRB", "TRD"]))
    if drop_cell_id:
        # cell_id causes issues in the fallback call. strict=False keeps this
        # working when the column is absent, which may be why single-cell
        # mode failed in the first place.
        heavy = heavy.drop("cell_id", strict=False)
    heavy_r = safe_py2rpy(
        heavy.with_columns(pl.col("*").cast(pl.String)).to_pandas()
    )
    return sh.distToNearest(
        heavy_r, vCallColumn=v_call, model=model_, normalize=norm_, **dist_kwargs
    )


def _gmm_threshold(
    sh,
    distances,
    method: str,
    subsample_,
    threshold_model,
    cross,
    cutoff,
    sensitivity,
    specificity,
) -> float:
    """Call shazam::findThreshold with mixture-model options; return the threshold."""
    from rpy2.rinterface import NULL
    from rpy2.robjects import FloatVector

    dist_threshold = sh.findThreshold(
        distances,
        method=method,
        model="gamma-gamma" if threshold_model is None else threshold_model,
        # Convert explicitly so R receives a numeric vector, not a Python list.
        cross=NULL if cross is None else FloatVector(cross),
        subsample=subsample_,
        cutoff="optimal" if cutoff is None else cutoff,
        sen=NULL if sensitivity is None else sensitivity,
        spc=NULL if specificity is None else specificity,
    )
    dist_threshold = safe_rpy2py(dist_threshold)
    return np.array(dist_threshold.slots["threshold"])[0]


def _automatic_threshold(
    sh,
    dist_ham,
    threshold_method_: str,
    edge,
    subsample_,
    threshold_model,
    cross,
    cutoff,
    sensitivity,
    specificity,
) -> float:
    """Detect a threshold from the non-NaN ``dist_nearest`` values.

    "density" falls back to "gmm" when it returns NaN. Raises ValueError if
    the final threshold is still NaN.
    """
    from rpy2.robjects import FloatVector

    dist = np.asarray(dist_ham["dist_nearest"], dtype=float)
    distances = FloatVector(dist[~np.isnan(dist)])
    if threshold_method_ == "density":
        dist_threshold = sh.findThreshold(
            distances,
            method=threshold_method_,
            subsample=subsample_,
            edge=0.9 if edge is None else edge,
        )
        threshold = np.array(dist_threshold.slots["threshold"])[0]
        if np.isnan(threshold):
            logg.info(
                " Threshold method 'density' did not return with any values. Switching to method = 'gmm'."
            )
            threshold = _gmm_threshold(
                sh,
                distances,
                "gmm",
                subsample_,
                threshold_model,
                cross,
                cutoff,
                sensitivity,
                specificity,
            )
    else:
        threshold = _gmm_threshold(
            sh,
            distances,
            threshold_method_,
            subsample_,
            threshold_model,
            cross,
            cutoff,
            sensitivity,
            specificity,
        )
    if np.isnan(threshold):
        raise ValueError(
            "Automatic thresholding failed. Please visually inspect the resulting distribution fits"
            + " and choose a threshold value manually."
        )
    return threshold


def _plot_nearest_distances(
    dist_ham,
    tr: float,
    plot_group: str | None,
    figsize: tuple[float, float],
    save_plot: str | None,
) -> None:
    """Show (and optionally save) a distance histogram with the threshold marked."""
    options.figure_size = figsize
    # Default grouping: colour/facet by sample when the column is available.
    if plot_group is None and "sample_id" in dist_ham.columns:
        plot_group = "sample_id"
    mapping = (
        aes("dist_nearest")
        if plot_group is None
        else aes("dist_nearest", fill=str(plot_group))
    )
    p = (
        ggplot(dist_ham, mapping)
        + theme_bw()
        + xlab("Grouped Hamming distance")
        + ylab("Count")
        + geom_histogram(binwidth=0.01)
        + geom_vline(xintercept=tr, linetype="dashed", color="blue", size=0.5)
        + annotate(
            "text",
            x=tr + 0.02,
            y=10,
            label="Threshold:\n" + str(np.around(tr, decimals=2)),
            size=8,
            color="Blue",
        )
    )
    if plot_group is not None:
        p = p + facet_wrap("~" + str(plot_group), scales="free_y")
    p = p + theme(legend_position="none")
    if save_plot is not None:
        save_as_pdf_pages([p], filename=save_plot, verbose=False)
    p.show()
def _import_rpy2_conversion():
    """Import the rpy2 pieces needed for pandas <-> R conversion.

    Returns ``(rpy2.robjects, localconverter, pandas2ri)``.

    Raises
    ------
    ImportError
        if rpy2 (or its pandas bridge) cannot be imported; the original
        error is chained as the cause.
    """
    try:
        import rpy2.robjects
        from rpy2.robjects.conversion import localconverter
        from rpy2.robjects import pandas2ri
    except Exception as err:
        raise ImportError(
            "Unable to initialise R instance. Please run this separately through R with shazam's tutorials."
        ) from err
    return rpy2.robjects, localconverter, pandas2ri


def _is_r_na(value: object) -> bool:
    """Return True for NA sentinel objects.

    Matches rpy2's NA singletons (``NACharacterType``, ``NAIntegerType``,
    ``NALogicalType``, ...) and pandas' ``NAType``, detected by class name.
    """
    name = type(value).__name__
    return name.startswith("NA") and name.endswith("Type")


def safe_py2rpy(df: pd.DataFrame) -> object:
    """Convert pandas DataFrame to R object safely.

    If direct conversion fails (for example on mixed-type object columns),
    the frame is retried with every value converted to ``str``.
    """
    robjects, localconverter, pandas2ri = _import_rpy2_conversion()
    converter = robjects.default_converter + pandas2ri.converter
    try:
        with localconverter(converter):
            return pandas2ri.py2rpy(df)
    except Exception:
        with localconverter(converter):
            return pandas2ri.py2rpy(df.astype(str))


def safe_rpy2py(r_object):
    """Convert R object to pandas DataFrame or other Python object safely.

    For DataFrame results, NA sentinel objects in object-dtype columns are
    replaced with None so the frame can be handed to Polars/PyArrow.
    """
    robjects, localconverter, pandas2ri = _import_rpy2_conversion()
    with localconverter(robjects.default_converter + pandas2ri.converter):
        result = robjects.conversion.rpy2py(r_object)
    if isinstance(result, pd.DataFrame):
        for col in result.columns:
            if result[col].dtype == object:
                result[col] = result[col].apply(
                    lambda x: None if _is_r_na(x) else x
                )
    return result