Source code for dandelion.tools._trajectory

#!/usr/bin/env python
# @author: chenqu, kp9, kelvin
import re
import numpy as np
import pandas as pd
import scanpy as sc
import scipy as sp

from anndata import AnnData
from typing import Literal

from dandelion.utilities._utilities import bh, PResults


def _filter_cells(
    adata: AnnData,
    col: str,
    filter_pattern: str | None = ",|None|No_contig",
    remove_missing: bool = True,
) -> AnnData:
    """
    Helper function that identifies `filter_pattern` hits in `.obs[col]` of adata, and then
    either removes the offending cells or masks the matched values with a uniform value of
    `col + "_missing"`.

    Parameters
    ----------
    adata : AnnData
        Cell adata.
    col : str
        `.obs` column to check; expected to hold strings (object or categorical dtype).
    filter_pattern : str | None, optional
        Regular expression passed to `Series.str.contains`. Missing values (NaN) in the
        column are counted as hits. If None, `adata` is returned unchanged.
    remove_missing : bool, optional
        If True, return a filtered copy without the hit cells. If False, overwrite the hits
        of `adata.obs[col]` in place with `col + "_missing"`.

    Returns
    -------
    AnnData
        The filtered copy, or `adata` itself when masking (or when not filtering).
    """
    if filter_pattern is None:
        return adata
    values = adata.obs[col]
    # NaN calls are as unusable as "None"; count them as hits instead of letting
    # them leak into (and break) the boolean mask
    mask = values.str.contains(filter_pattern, na=True).to_numpy(dtype=bool)
    if remove_missing:
        # remove the cells
        return adata[~mask].copy()
    # uniformly mask the filter pattern hits
    fill = col + "_missing"
    if (
        isinstance(values.dtype, pd.CategoricalDtype)
        and fill not in values.cat.categories
    ):
        # a categorical column rejects values that are not already a category
        adata.obs[col] = values.cat.add_categories([fill])
    adata.obs.loc[mask, col] = fill
    return adata


def _as_col_list(value):
    """Wrap a single column name in a list; lists, other iterables and None pass through."""
    return [value] if isinstance(value, str) else value


def _first_call(values) -> list[str]:
    """Return the first `|`-separated call of each entry ("None" maps to itself)."""
    return [x.split("|")[0] for x in values]


def _add_main_call_cols(adata: AnnData, mode: str | None) -> None:
    """Write `*_VDJ_main` / `*_VJ_main` columns holding the first call of each gene column."""
    # prefer genotyped V calls when they are available
    v_call = (
        "v_call_genotyped_" if "v_call_genotyped_VDJ" in adata.obs else "v_call_"
    )
    infix = "" if mode is None else mode + "_"
    main_to_source = {
        "v_call_" + infix + "VDJ_main": v_call + infix + "VDJ",
        "d_call_" + infix + "VDJ_main": "d_call_" + infix + "VDJ",
        "j_call_" + infix + "VDJ_main": "j_call_" + infix + "VDJ",
        "v_call_" + infix + "VJ_main": v_call + infix + "VJ",
        "j_call_" + infix + "VJ_main": "j_call_" + infix + "VJ",
    }
    for main_col, source_col in main_to_source.items():
        adata.obs[main_col] = _first_call(adata.obs[source_col])


def setup_vdj_pseudobulk(
    adata: AnnData,
    mode: Literal["B", "abT", "gdT"] | None = "abT",
    subsetby: str | None = None,
    groups: list[str] | None = None,
    allowed_chain_status: list[str] | None = [
        "Single pair",
        "Extra pair",
        "Extra pair-exception",
        "Orphan VDJ",
        "Orphan VDJ-exception",
    ],
    productive_vdj: bool = True,
    productive_vj: bool = True,
    extract_cols: list[str] | None = None,
    productive_cols: list[str] | None = None,
    check_vdj_mapping: list[Literal["v_call", "d_call", "j_call"]] | None = [
        "v_call",
        "j_call",
    ],
    check_vj_mapping: list[Literal["v_call", "j_call"]] | None = [
        "v_call",
        "j_call",
    ],
    check_extract_cols_mapping: list[str] | None = None,
    filter_pattern: str | None = ",|None|No_contig",
    remove_missing: bool = True,
) -> AnnData:
    """Prepare a cell AnnData for computing the pseudobulk V(D)J feature space.

    Cells are filtered by productive status, `chain_status` and (optionally) a group
    column; the main (first `|`-separated) V(D)J calls are extracted into `*_main`
    columns; finally cells whose checked main calls match `filter_pattern` are removed
    or masked.

    Parameters
    ----------
    adata : AnnData
        Cell adata. If no filtering step removes cells, the new `*_main` columns are
        written onto this object itself.
    mode : Literal["B", "abT", "gdT"] | None, optional
        Mode used to build column names such as `productive_<mode>_VDJ` and
        `v_call_<mode>_VDJ`. If `None`, productive filtering uses `productive_cols`
        and the mode-less columns (e.g. `v_call_VDJ`) are used unless `extract_cols`
        is specified.
    subsetby : str | None, optional
        If provided together with `groups`, only cells whose value in this column is
        in `groups` are kept.
    groups : list[str] | None, optional
        Groups/categories of `subsetby` to keep.
    allowed_chain_status : list[str] | None, optional
        If provided, only cells whose `chain_status` is in this list are kept.
    productive_vdj : bool, optional
        If True (and `mode` is set), only keep cells whose main VDJ chain is productive
        (`productive_<mode>_VDJ` starts with "T"). Cells with a missing status are
        dropped.
    productive_vj : bool, optional
        As `productive_vdj`, for the main VJ chain.
    extract_cols : list[str] | None, optional
        Columns where VDJ/VJ calls are stored, used instead of the standard columns; a
        `<col>_main` column is created for each.
    productive_cols : list[str] | None, optional
        Used only when `mode` is None: cells are kept only if every listed column
        starts with "T".
    check_vdj_mapping : list[Literal["v_call", "d_call", "j_call"]] | None, optional
        Calls whose VDJ `*_main` column is checked for unclear mapping against
        `filter_pattern`. A single string is accepted. None skips this step.
    check_vj_mapping : list[Literal["v_call", "j_call"]] | None, optional
        As `check_vdj_mapping`, for the VJ `*_main` columns.
    check_extract_cols_mapping : list[str] | None, optional
        Entries of `extract_cols` whose `<col>_main` column is checked against
        `filter_pattern` (only used when `mode` is None). None skips this step.
    filter_pattern : str | None, optional
        Regular expression marking unclear calls; the default matches calls containing
        a comma, "None" or "No_contig". If `None`, no filtering is done.
    remove_missing : bool, optional
        If True, cells whose checked calls match the filter are removed. If False, the
        matching values are masked with `<column name>_missing`.

    Returns
    -------
    AnnData
        filtered cell adata object.
    """
    # accept a single column name wherever a list of names is expected
    productive_cols = _as_col_list(productive_cols)
    extract_cols = _as_col_list(extract_cols)
    check_vdj_mapping = _as_col_list(check_vdj_mapping)
    check_vj_mapping = _as_col_list(check_vj_mapping)
    check_extract_cols_mapping = _as_col_list(check_extract_cols_mapping)

    # keep only relevant cells (ones with productive chains) based on productive columns
    if mode is not None:
        required = []
        if productive_vdj:
            required.append("productive_" + mode + "_VDJ")
        if productive_vj:
            required.append("productive_" + mode + "_VJ")
    else:
        required = productive_cols or []
    for col in required:
        adata = adata[adata.obs[col].str.startswith("T", na=False)].copy()

    if allowed_chain_status is not None:
        adata = adata[adata.obs["chain_status"].isin(allowed_chain_status)].copy()
    if (groups is not None) and (subsetby is not None):
        adata = adata[adata.obs[subsetby].isin(groups)].copy()

    # extract the main (first) call of each V(D)J gene column
    if extract_cols is None:
        # existing *_main columns are reused rather than recomputed
        if not any(re.search("_VDJ_main|_VJ_main", i) for i in adata.obs):
            _add_main_call_cols(adata, mode)
    else:
        for col in extract_cols:
            adata.obs[col + "_main"] = _first_call(adata.obs[col])

    if filter_pattern is None:
        return adata

    # remove (or mask) cells if there's unclear mapping in the checked columns
    if mode is not None:
        to_check = [c + "_" + mode + "_VDJ_main" for c in check_vdj_mapping or []]
        to_check += [c + "_" + mode + "_VJ_main" for c in check_vj_mapping or []]
    elif extract_cols is None:
        to_check = [c + "_VDJ_main" for c in check_vdj_mapping or []]
        to_check += [c + "_VJ_main" for c in check_vj_mapping or []]
    else:
        to_check = [c + "_main" for c in check_extract_cols_mapping or []]
    for col in to_check:
        adata = _filter_cells(
            adata=adata,
            col=col,
            filter_pattern=filter_pattern,
            remove_missing=remove_missing,
        )
    return adata
def _get_pbs(
    pbs: np.ndarray | sp.sparse.csr_matrix | None,
    obs_to_bulk: str | None,
    adata: AnnData,
) -> np.ndarray:
    """
    Helper function to ensure we have a cells by pseudobulks matrix which we can use for
    pseudobulking. Uses the pbs and obs_to_bulk inputs to vdj_pseudobulk() and
    gex_pseudobulk().

    Exactly one of `pbs` (cells x pseudobulks matrix, dense or sparse) or `obs_to_bulk`
    (obs column name, or list of names whose values are combined) must be given.
    A dense `np.ndarray` is always returned (never `np.matrix`).

    Raises
    ------
    ValueError
        If neither or both of `pbs` and `obs_to_bulk` are specified.
    """
    # well, we need some way to pseudobulk
    if pbs is None and obs_to_bulk is None:
        raise ValueError(
            "You need to specify `pbs` or `obs_to_bulk` when calling the function"
        )
    # but just one
    if pbs is not None and obs_to_bulk is not None:
        raise ValueError("You need to specify `pbs` or `obs_to_bulk`, not both")
    if pbs is not None:
        # `.todense()` would give np.matrix; `.toarray()` gives a plain ndarray
        return pbs.toarray() if sp.sparse.issparse(pbs) else np.asarray(pbs)
    # get the obs-derived pseudobulk
    if isinstance(obs_to_bulk, list):
        # this will create a single value by pasting all the columns together
        tobulk = adata.obs[obs_to_bulk].astype(str).agg(",".join, axis=1)
    else:
        # we just have a single column
        tobulk = adata.obs[obs_to_bulk]
    # the indicator itself only holds 0/1, but its dtype is inherited by downstream
    # matrix products (cell counts per pseudobulk), so it must hold the total number
    # of cells: uint16 normally, a wider unsigned type only for > 65535 cells
    dtype = np.promote_types(np.uint16, np.min_scalar_type(len(tobulk)))
    return pd.get_dummies(tobulk, dtype=dtype).values


def _get_pbs_obs(
    pbs: np.ndarray, obs_to_take: str | None, adata: AnnData
) -> pd.DataFrame:
    """
    Helper function to create the pseudobulk object's obs. Uses the pbs and obs_to_take
    inputs to vdj_pseudobulk() and gex_pseudobulk().

    For each column in `obs_to_take` (a single name or a list of names), `<col>` holds
    the most common value among the pseudobulk's cells and `<col>_fraction` the
    fraction of cells with that value; pseudobulks without any annotated cell get NaN
    for both. `cell_count` holds the number of cells in each pseudobulk.
    """
    pbs = np.asarray(pbs)
    # prepare per-pseudobulk calls of specified metadata columns
    pbs_obs = pd.DataFrame(index=np.arange(pbs.shape[1]))
    if obs_to_take is not None:
        # just in case a single is passed as a string
        if isinstance(obs_to_take, str):
            obs_to_take = [obs_to_take]
        # using the logic of milopy's annotate_nhoods()
        for anno_col in obs_to_take:
            anno_dummies = pd.get_dummies(adata.obs[anno_col])
            # int64 so the counts cannot overflow a narrow pbs dtype
            anno_count = pbs.T @ anno_dummies.to_numpy(dtype=np.int64)
            totals = anno_count.sum(axis=1, keepdims=True)
            with np.errstate(divide="ignore", invalid="ignore"):
                anno_frac = anno_count / totals
            anno_frac = pd.DataFrame(
                anno_frac,
                index=np.arange(pbs.shape[1]),
                columns=anno_dummies.columns,
            )
            # empty pseudobulks have all-NaN rows; give them no label
            has_cells = totals.ravel() > 0
            pbs_obs[anno_col] = anno_frac.fillna(0).idxmax(axis=1).where(has_cells)
            pbs_obs[anno_col + "_fraction"] = anno_frac.max(axis=1)
    # report the number of cells for each pseudobulk
    pbs_obs["cell_count"] = pbs.sum(axis=0)
    return pbs_obs
def _row_frequencies(counts: pd.DataFrame) -> np.ndarray:
    """Scale each row of `counts` to sum to 1; all-zero rows stay 0 instead of NaN."""
    totals = counts.sum(axis=1).to_numpy()
    return counts.to_numpy() / np.where(totals > 0, totals, 1)[:, None]


def vdj_pseudobulk(
    adata: AnnData,
    pbs: np.ndarray | sp.sparse.csr_matrix | None = None,
    obs_to_bulk: list[str] | str | None = None,
    obs_to_take: list[str] | str | None = None,
    normalise: bool = True,
    renormalise: bool = False,
    min_count: int = 1,
    mode: Literal["B", "abT", "gdT"] | None = "abT",
    extract_cols: list[str] | None = [
        "v_call_abT_VDJ_main",
        "j_call_abT_VDJ_main",
        "v_call_abT_VJ_main",
        "j_call_abT_VJ_main",
    ],
) -> AnnData:
    """Build the pseudobulk V(D)J feature space.

    One of `pbs` or `obs_to_bulk` needs to be specified when calling.

    Parameters
    ----------
    adata : AnnData
        Cell adata, preferably after `ddl.tl.setup_vdj_pseudobulk()`.
    pbs : np.ndarray | sp.sparse.csr_matrix | None, optional
        Optional binary matrix with cells as rows and pseudobulk groups as columns.
    obs_to_bulk : list[str] | str | None, optional
        Optional obs column(s) to group pseudobulks into; if multiple are provided,
        they will be combined.
    obs_to_take : list[str] | str | None, optional
        Optional obs column(s) to identify the most common value of for each pseudobulk.
    normalise : bool, optional
        If True, scale the counts of each V(D)J gene group to sum to 1 for each
        pseudobulk (pseudobulks without any call in a group keep 0).
    renormalise : bool, optional
        If True, re-scale the counts of each V(D)J gene group to 1 for each pseudobulk
        with any "missing" calls removed. Relevant with `normalise` as True, if
        `setup_vdj_pseudobulk()` was run with `remove_missing` set to False.
    min_count : int, optional
        Pseudobulks with fewer than this many non-"missing" calls in a V(D)J gene group
        will have their non-"missing" calls set to 0 for that group. Relevant with
        `normalise` as True.
    mode : Literal["B", "abT", "gdT"] | None, optional
        Only used when `extract_cols` is None, to auto-detect the `*_main` columns: if
        `None`, columns matching `_call_VDJ_main`/`_call_VJ_main` are used, otherwise
        columns matching `<mode>_VDJ_main`/`<mode>_VJ_main`.
    extract_cols : list[str] | None, optional
        Column names where VDJ/VJ information is stored so that this will be used
        instead of the standard columns.

    Returns
    -------
    AnnData
        pb_adata, whereby each observation is a pseudobulk:\n
        VDJ usage frequency (float) or counts stored in pb_adata.X\n
        VDJ genes stored in pb_adata.var\n
        pseudobulk metadata stored in pb_adata.obs\n
        pseudobulk assignment (binary matrix with input cells as columns) stored in
        pb_adata.obsm['pbs']\n
    """
    # get our cells by pseudobulks matrix as a plain ndarray (never np.matrix)
    pbs = np.asarray(_get_pbs(pbs, obs_to_bulk, adata))
    # if not specified by the user, detect the dandelion *_main columns
    if extract_cols is None:
        if mode is None:
            pattern = "_call_VDJ_main|_call_VJ_main"
        else:
            pattern = mode + "_VDJ_main|" + mode + "_VJ_main"
        extract_cols = [i for i in adata.obs if re.search(pattern, i)]
    # cells x genes indicator; skip the prefix as the genes are unique across columns
    vjs = pd.get_dummies(adata.obs[extract_cols], prefix="", prefix_sep="")
    # pseudobulks x genes counts; int64 so large pseudobulks cannot overflow uint16
    vj_pb_count = pbs.T @ vjs.to_numpy(dtype=np.int64)
    df = pd.DataFrame(
        vj_pb_count, index=np.arange(pbs.shape[1]), columns=vjs.columns
    )
    if normalise:
        # frequencies are written back into the frame, so it has to be float
        df = df.astype(float)
        # missing calls inserted by the setup end with "_missing"
        defined_mask = ~df.columns.str.endswith("_missing")
        # loop over V(D)J gene categories
        for col in extract_cols:
            # columns holding the calls observed in this category (NaN-safe)
            calls = adata.obs[col].dropna().astype(str).unique()
            group_mask = np.isin(df.columns, calls)
            group_defined_mask = group_mask & defined_mask
            # pseudobulks with enough non-missing calls in this category
            defined_min_count = (
                df.loc[:, group_defined_mask].sum(axis=1) >= min_count
            )
            # normalise all the pseudobulks within this category
            df.loc[:, group_mask] = _row_frequencies(df.loc[:, group_mask])
            if renormalise:
                # repeat for non-missing values only, for pseudobulks over min_count
                df.loc[defined_min_count, group_defined_mask] = _row_frequencies(
                    df.loc[defined_min_count, group_defined_mask]
                )
            # mask the pseudobulks with insufficient defined counts
            df.loc[~defined_min_count, group_defined_mask] = 0
    # create obs for the new pseudobulk object
    pbs_obs = _get_pbs_obs(pbs, obs_to_take, adata)
    pb_adata = sc.AnnData(
        np.array(df), var=pd.DataFrame(index=df.columns), obs=pbs_obs
    )
    # store the pseudobulk assignments sparse; transpose as pbs is cells x pseudobulks
    pb_adata.obsm["pbs"] = sp.sparse.csr_matrix(pbs.T)
    return pb_adata
def pseudotime_transfer(
    adata: AnnData, pr_res: PResults, suffix: str = ""
) -> AnnData:
    """Copy Palantir pseudotime and branch probabilities into `adata.obs`.

    The columns are written onto `adata` itself and that same object is returned.

    Parameters
    ----------
    adata : AnnData
        adata receiving the pseudotime information.
    pr_res : PResults
        palantir pseudotime inference output object; its `pseudotime` and each
        column of `branch_probs` are copied.
    suffix : str, optional
        appended to every new column name, default "" (none).

    Returns
    -------
    AnnData
        `adata` with `pseudotime<suffix>` and `prob_<branch><suffix>` columns in .obs.
    """
    pseudotime_key = "pseudotime" + suffix
    adata.obs[pseudotime_key] = pr_res.pseudotime.copy()
    probabilities = pr_res.branch_probs
    for branch in probabilities.columns:
        prob_key = "prob_" + branch + suffix
        adata.obs[prob_key] = probabilities[branch].copy()
    return adata
def project_pseudotime_to_cell(
    adata: AnnData, pb_adata: AnnData, term_states: list[str], suffix: str = ""
) -> AnnData:
    """Project pseudotime & branch probabilities from pb_adata (pseudobulk adata) to adata (cell adata).

    Each cell receives the average of the values of the pseudobulks it belongs to,
    each pseudobulk weighted by 1 / (its number of cells). Pseudobulks with a NaN value
    are ignored; a cell only gets NaN if all its pseudobulks lack a value.

    Parameters
    ----------
    adata : AnnData
        Cell adata, preferably after `ddl.tl.setup_vdj_pseudobulk()`
    pb_adata : AnnData
        neighbourhood/pseudobulked adata; `obsm["pbs"]` (pseudobulk x cell, sparse
        or dense) holds the assignments
    term_states : list[str]
        list of terminal states with branch probabilities to be transferred
    suffix : str, optional
        suffix to be added after the added column names, default "" (none)

    Returns
    -------
    AnnData
        subset of adata whereby cells that don't belong to any neighbourhood are
        removed and projected pseudotime information stored in .obs -
        `pseudotime+suffix`, and `'prob_'+term_state+suffix` for each terminal state;
        the trimmed cell x pseudobulk assignments are stored in
        `.uns["pseudobulk_assignments"]`.
    """
    # extract out cell x pseudobulk matrix; it's stored as pseudobulk x cell
    pbs = pb_adata.obsm["pbs"]
    pbs = pbs.toarray() if sp.sparse.issparse(pbs) else np.asarray(pbs)
    nhoods = pbs.T
    # leave out cells that don't belong to any neighbourhood
    nhoodsum = np.sum(nhoods, axis=1)
    in_nhood = nhoodsum > 0
    cdata = adata[in_nhood].copy()
    print(
        "number of cells removed due to not belonging to any neighbourhood",
        int(np.sum(nhoodsum == 0)),
    )
    nhoods_cdata = nhoods[in_nhood, :]
    # weight by 1/neighbourhood size; an empty pseudobulk keeps weight 0, not NaN
    sizes = np.sum(nhoods_cdata, axis=0, keepdims=True)
    nhoods_cdata_norm = np.divide(
        nhoods_cdata,
        sizes,
        out=np.zeros(nhoods_cdata.shape, dtype=float),
        where=sizes > 0,
    )
    col_list = ["pseudotime" + suffix] + [
        "prob_" + state + suffix for state in term_states
    ]
    for col in col_list:
        values = pb_adata.obs[col].to_numpy(dtype=float)
        # drop NaN pseudobulks so 0 * NaN cannot spread NaN to unrelated cells
        defined = ~np.isnan(values)
        weights = nhoods_cdata_norm[:, defined]
        with np.errstate(divide="ignore", invalid="ignore"):
            cdata.obs[col] = (weights @ values[defined]) / weights.sum(axis=1)
    cdata.uns["pseudobulk_assignments"] = nhoods_cdata.copy()
    return cdata
def pseudobulk_gex(
    adata_raw: AnnData,
    pbs: np.ndarray | sp.sparse.csr_matrix | None = None,
    obs_to_bulk: list[str] | str | None = None,
    obs_to_take: list[str] | str | None = None,
) -> AnnData:
    """Pseudobulk gene expression (raw counts) by summing cells per pseudobulk.

    Parameters
    ----------
    adata_raw : AnnData
        Needs to have raw counts in .X
    pbs : np.ndarray | sp.sparse.csr_matrix | None, optional
        Optional binary matrix with cells as rows and pseudobulk groups as columns
    obs_to_bulk : list[str] | str | None, optional
        Optional obs column(s) to group pseudobulks into; if multiple are provided,
        they will be combined
    obs_to_take : list[str] | str | None, optional
        Optional obs column(s) to identify the most common value of for each pseudobulk

    Returns
    -------
    AnnData
        pb_adata whereby each observation is a cell neighbourhood\n
        pseudobulked gene expression stored in pb_adata.X (dense)\n
        genes stored in pb_adata.var (a copy of adata_raw.var)\n
        pseudobulk metadata stored in pb_adata.obs\n
        pseudobulk assignment (binary matrix with input cells as columns) stored in
        pb_adata.obsm['pbs']\n
    """
    # get our cells by pseudobulks matrix as a plain ndarray
    pbs = np.asarray(_get_pbs(pbs, obs_to_bulk, adata_raw))
    # genes x pseudobulks sums; asarray guards against an np.matrix product
    pbs_X = np.asarray(adata_raw.X.T.dot(pbs))
    # create obs for the new pseudobulk object
    pbs_obs = _get_pbs_obs(pbs, obs_to_take, adata_raw)
    # copy var so the new object does not share (and mutate) adata_raw's frame
    pb_adata = sc.AnnData(pbs_X.T, obs=pbs_obs, var=adata_raw.var.copy())
    # store the pseudobulk assignments sparse; transpose as pbs is cells x pseudobulks
    pb_adata.obsm["pbs"] = sp.sparse.csr_matrix(pbs.T)
    return pb_adata
def bin_expression(
    adata: AnnData, bin_no: int, genes: list[str], pseudotime_col: str
) -> pd.DataFrame:
    """Compute average gene expression in equal-width bins along pseudotime.

    The bins span [0, 1], so pseudotime is expected in that range. Bins are half-open
    `[lo, hi)` except the last, which is closed `[lo, 1]` so that cells at pseudotime
    exactly 1 are counted. Cells outside [0, 1] or with NaN pseudotime fall in no bin;
    a bin without cells yields NaN.

    Parameters
    ----------
    adata : AnnData
        cell adata (dense or sparse .X).
    bin_no : int
        number of bins to be divided along pseudotime.
    genes : list[str]
        list of genes for the computation
    pseudotime_col : str
        column in adata.obs where pseudotime is stored

    Returns
    -------
    pd.DataFrame
        a float data frame with genes as rows, pseudotime bins (labelled by their lower
        edge) as columns, and averaged gene expression as the data
    """
    edges = np.linspace(0, 1, bin_no + 1)
    X = adata[:, genes].X
    x = X.toarray() if sp.sparse.issparse(X) else np.asarray(X)
    y = np.asarray(adata.obs[pseudotime_col], dtype=float)
    means = {}
    for i in range(bin_no):
        lo, hi = edges[i], edges[i + 1]
        # close the last bin so the maximum pseudotime is not dropped
        upper = (y <= hi) if i == bin_no - 1 else (y < hi)
        means[lo] = x[(lo <= y) & upper].mean(axis=0)
    return pd.DataFrame(means, index=genes)
def chatterjee_corr(
    adata: AnnData, genes: list[str], pseudotime_col: str
) -> pd.DataFrame:
    """Compute the Chatterjee correlation of gene expression with pseudotime.

    For each gene, cells are ordered by pseudotime and the ranks of the gene's
    expression along that order give
    xi = 1 - 3 * sum(|r[i+1] - r[i]|) / (n^2 - 1), which measures how well expression
    is a (not necessarily monotone) function of pseudotime. Ties in pseudotime and in
    expression are broken uniformly at random using NumPy's global random state (seed
    `np.random` for reproducible results). The p-value is one-sided, from the
    asymptotic null distribution N(0, 2 / (5 n)).

    ref: Sourav Chatterjee (2021) A New Coefficient of Correlation, Journal of the
    American Statistical Association, 116:536, 2009-2022,
    DOI: 10.1080/01621459.2020.1758115

    Parameters
    ----------
    adata : AnnData
        cell adata (dense or sparse .X)
    genes : list[str]
        List of genes selected to compute the correlation
    pseudotime_col : str
        column in adata.obs where pseudotime is stored

    Returns
    -------
    pd.DataFrame
        a data frame with genes as rows, with cor_stat (correlation statistic), pval
        (p-value), adj_pval (p-value adjusted by BH method) as columns, sorted by
        adj_pval.

    Raises
    ------
    ValueError
        If adata has fewer than two cells.
    """
    X = adata[:, genes].X
    x = X.toarray() if sp.sparse.issparse(X) else np.asarray(X)
    t = np.asarray(adata.obs[pseudotime_col], dtype=float)
    n = x.shape[0]
    if n < 2:
        raise ValueError(
            "At least two cells are needed to compute the Chatterjee correlation."
        )
    # order cells along pseudotime; a random secondary key breaks pseudotime ties
    x = x[np.lexsort((np.random.random(n), t)), :]
    # rank each gene's expression (0-based) with random tie breaking, independent of
    # the cell order; argsort of the sorting permutation gives the ranks
    value_order = np.lexsort((np.random.random(x.shape), x), axis=0)
    ranks = np.argsort(value_order, axis=0)
    stat = 1 - 3 * np.abs(np.diff(ranks, axis=0)).sum(axis=0) / (n**2 - 1)
    pval = sp.stats.norm.sf(stat, loc=0, scale=np.sqrt(2 / (5 * n)))
    # put results into data frame cor_res
    cor_res = pd.DataFrame({"cor_stat": stat, "pval": pval}, index=genes)
    # compute adjusted pval using BH method
    cor_res["adj_pval"] = bh(cor_res["pval"].to_numpy())
    # sort genes based on adjusted pval
    return cor_res.sort_values(by="adj_pval")