# Source code for dandelion.base.tools._trajectory

#!/usr/bin/env python
# @author: chenqu, kp9, kelvin
from __future__ import annotations

import re

from anndata import AnnData
import numpy as np
import pandas as pd
import scipy as sp

from scanpy import logging as logg
from typing import Literal


def _filter_cells(
    adata: AnnData,
    col: str,
    filter_pattern: str | None = ",|None|No_contig",
    remove_missing: bool = True,
) -> AnnData:
    """
    Helper function that identifies filter_pattern hits in `.obs[col]` of adata, and then either removes the
    offending cells or masks the matched values with a uniform value of `col+"_missing"`.
    """
    # find filter pattern hits in our column of interest
    mask = (
        adata.obs[col].astype("string").str.contains(filter_pattern, na=False)
    )
    mask_arr = mask.to_numpy(dtype=bool)
    if remove_missing:
        # remove the cells
        adata = adata[~mask_arr].copy()
    else:
        # uniformly mask the filter pattern hits
        adata.obs.loc[mask, col] = col + "_missing"
    return adata


[docs] def setup_vdj_pseudobulk( adata: AnnData, mode: Literal["B", "abT", "gdT"] | None = "abT", subsetby: str | None = None, groups: list[str] | None = None, allowed_chain_status: list[str] | None = [ "Single pair", "Extra pair", "Extra pair-exception", "Orphan VDJ", "Orphan VDJ-exception", ], productive_vdj: bool = True, productive_vj: bool = True, extract_cols: list[str] | None = None, productive_cols: list[str] | None = None, check_vdj_mapping: list[Literal["v_call", "d_call", "j_call"]] | None = [ "v_call", "j_call", ], check_vj_mapping: list[Literal["v_call", "j_call"]] | None = [ "v_call", "j_call", ], check_extract_cols_mapping: list[str] | None = None, filter_pattern: str | None = ",|None|No_contig", remove_missing: bool = True, ) -> AnnData: """Function for prepare anndata for computing pseudobulk vdj feature space. Parameters ---------- adata : AnnData cell adata before constructing anndata. mode : Literal["B", "abT", "gdT"] | None, optional Optional mode for extractin the V(D)J genes. If set as `None`, it requires the option `extract_cols` to be specified with a list of column names where this will be used to retrieve the main call. subsetby : str | None, optional If provided, only the groups/categories in this column will be used for computing the VDJ feature space. groups : list[str] | None, optional If provided, only the following groups/categories will be used for computing the VDJ feature space. allowed_chain_status : list[str] | None, optional If provided, only the ones in this list are kept from the `chain_status` column. productive_vdj : bool, optional If True, cells will only be kept if the main VDJ chain is productive. productive_vj : bool, optional If True, cells will only be kept if the main VJ chain is productive. extract_cols : list[str] | None, optional Column names where VDJ/VJ information is stored so that this will be used instead of the standard columns. 
productive_cols : list[str] | None, optional Column names where contig productive status is stored so that this will be used instead of the standard columns. check_vdj_mapping : list[Literal["v_call", "d_call", "j_call"]] | None, optional Only columns in the argument will be checked for unclear mapping (containing comma) in VDJ calls. Specifying None will skip this step. check_vj_mapping : list[Literal["v_call", "j_call"]] | None, optional Only columns in the argument will be checked for unclear mapping (containing comma) in VJ calls. Specifying None will skip this step. check_extract_cols_mapping : list[str] | None, optional Only columns in the argument will be checked for unclear mapping (containing comma) in columns specified in extract_cols. Specifying None will skip this step. filter_pattern : str | None, optional pattern to filter from object. If `None`, does not filter. remove_missing : bool, optional If True, will remove cells with contigs matching the filter from the object. If False, will mask them with a uniform value dependent on the column name. Returns ------- AnnData filtered cell adata object. 
""" # keep ony relevant cells (ones with a pair of chains) based on productive column if mode is not None: if productive_vdj: mask_vdj = ( adata.obs["productive_" + mode + "_VDJ"] .astype("string") .str.startswith("T", na=False) .to_numpy(dtype=bool) ) adata = adata[mask_vdj].copy() if productive_vj: mask_vj = ( adata.obs["productive_" + mode + "_VJ"] .astype("string") .str.startswith("T", na=False) .to_numpy(dtype=bool) ) adata = adata[mask_vj].copy() else: if productive_cols is not None: for col in productive_cols: mask_col = ( adata.obs[col] .astype("string") .str.startswith("T", na=False) .to_numpy(dtype=bool) ) adata = adata[mask_col].copy() if any([re.search("_VDJ_main|_VJ_main", i) for i in adata.obs]): if check_vdj_mapping is not None: if not isinstance(check_vdj_mapping, list): check_vdj_mapping = [check_vdj_mapping] if check_vj_mapping is not None: if not isinstance(check_vj_mapping, list): check_vj_mapping = [check_vj_mapping] if allowed_chain_status is not None: adata = adata[ adata.obs["chain_status"].isin(allowed_chain_status) ].copy() if (groups is not None) and (subsetby is not None): adata = adata[adata.obs[subsetby].isin(groups)].copy() if extract_cols is None: if not any([re.search("_VDJ_main|_VJ_main", i) for i in adata.obs]): if mode is not None: adata.obs["v_call_" + mode + "_VDJ_main"] = [ x.split("|")[0] if x != "None" else "None" for x in adata.obs["v_call_" + mode + "_VDJ"] ] adata.obs["d_call_" + mode + "_VDJ_main"] = [ x.split("|")[0] if x != "None" else "None" for x in adata.obs["d_call_" + mode + "_VDJ"] ] adata.obs["j_call_" + mode + "_VDJ_main"] = [ x.split("|")[0] if x != "None" else "None" for x in adata.obs["j_call_" + mode + "_VDJ"] ] adata.obs["v_call_" + mode + "_VJ_main"] = [ x.split("|")[0] if x != "None" else "None" for x in adata.obs["v_call_" + mode + "_VJ"] ] adata.obs["j_call_" + mode + "_VJ_main"] = [ x.split("|")[0] if x != "None" else "None" for x in adata.obs["j_call_" + mode + "_VJ"] ] else: 
adata.obs["v_call_VDJ_main"] = [ x.split("|")[0] if x != "None" else "None" for x in adata.obs["v_call_VDJ"] ] adata.obs["d_call_VDJ_main"] = [ x.split("|")[0] if x != "None" else "None" for x in adata.obs["d_call_VDJ"] ] adata.obs["j_call_VDJ_main"] = [ x.split("|")[0] if x != "None" else "None" for x in adata.obs["j_call_VDJ"] ] adata.obs["v_call_VJ_main"] = [ x.split("|")[0] if x != "None" else "None" for x in adata.obs["v_call_VJ"] ] adata.obs["j_call_VJ_main"] = [ x.split("|")[0] if x != "None" else "None" for x in adata.obs["j_call_VJ"] ] else: for col in extract_cols: adata.obs[col + "_main"] = [ x.split("|")[0] if x != "None" else "None" for x in adata.obs[col] ] # remove any cells if there's unclear mapping if filter_pattern is not None: if mode is not None: if check_vdj_mapping is not None: for col in check_vdj_mapping: adata = _filter_cells( adata=adata, col=col + "_" + mode + "_VDJ_main", filter_pattern=filter_pattern, remove_missing=remove_missing, ) if check_vj_mapping is not None: for col in check_vj_mapping: adata = _filter_cells( adata=adata, col=col + "_" + mode + "_VJ_main", filter_pattern=filter_pattern, remove_missing=remove_missing, ) else: if extract_cols is None: if check_vdj_mapping is not None: for col in check_vdj_mapping: adata = _filter_cells( adata=adata, col=col + "_VDJ_main", filter_pattern=filter_pattern, remove_missing=remove_missing, ) if check_vj_mapping is not None: for col in check_vj_mapping: adata = _filter_cells( adata=adata, col=col + "_VJ_main", filter_pattern=filter_pattern, remove_missing=remove_missing, ) else: if check_extract_cols_mapping is not None: for col in check_extract_cols_mapping: adata = _filter_cells( adata=adata, col=col + "_main", filter_pattern=filter_pattern, remove_missing=remove_missing, ) return adata
def _get_pbs( pbs: np.ndarray | sp.sparse.csr_matrix | None, obs_to_bulk: str | None, adata: AnnData, ) -> np.ndarray: """ Helper function to ensure we have a cells by pseudobulks matrix which we can use for pseudobulking. Uses the pbs and obs_to_bulk inputs to vdj_pseudobulk() and gex_pseudobulk(). """ # well, we need some way to pseudobulk if pbs is None and obs_to_bulk is None: raise ValueError( "You need to specify `pbs` or `obs_to_bulk` when calling the function" ) # but just one if pbs is not None and obs_to_bulk is not None: raise ValueError("You need to specify `pbs` or `obs_to_bulk`, not both") # turn the pseubodulk matrix dense if need be if pbs is not None: if sp.sparse.issparse(pbs): pbs = pbs.todense() # get the obs-derived pseudobulk if obs_to_bulk is not None: if type(obs_to_bulk) is list: # this will create a single value by pasting all the columns together tobulk = adata.obs[obs_to_bulk].T.astype(str).agg(",".join) else: # we just have a single column tobulk = adata.obs[obs_to_bulk] # this pandas function creates the exact pseudobulk assignment we want # this needs to be different than the default uint8 # as you can have more than 255 cells in a pseudobulk, it turns out pbs = pd.get_dummies(tobulk, dtype="uint16").values return pbs def _get_pbs_obs( pbs: np.ndarray, obs_to_take: str | None, adata: AnnData ) -> pd.DataFrame: """ Helper function to create the pseudobulk object's obs. Uses the pbs and obs_to_take inputs to vdj_pseudobulk() and gex_pseudobulk(). 
""" # prepare per-pseudobulk calls of specified metadata columns pbs_obs = pd.DataFrame(index=np.arange(pbs.shape[1])) if obs_to_take is not None: # just in case a single is passed as a string if type(obs_to_take) is not list: obs_to_take = [obs_to_take] # now we can iterate over this nicely # using the logic of milopy's annotate_nhoods() for anno_col in obs_to_take: anno_dummies = pd.get_dummies(adata.obs[anno_col]) # this needs to be turned to a matrix so dimensions get broadcast correctly anno_count = np.asmatrix(pbs).T.dot(anno_dummies.values) anno_frac = np.array(anno_count / anno_count.sum(1)) anno_frac = pd.DataFrame( anno_frac, index=np.arange(pbs.shape[1]), columns=anno_dummies.columns, ) pbs_obs[anno_col] = anno_frac.idxmax(1) pbs_obs[anno_col + "_fraction"] = anno_frac.max(1) # report the number of cells for each pseudobulk # ensure pbs is an array so that it sums into a vector that can go in easily pbs_obs["cell_count"] = np.sum(np.asarray(pbs), axis=0) return pbs_obs
def vdj_pseudobulk(
    adata: AnnData,
    pbs: np.ndarray | sp.sparse.csr_matrix | None = None,
    obs_to_bulk: list[str] | str | None = None,
    obs_to_take: list[str] | str | None = None,
    normalise: bool = True,
    renormalise: bool = False,
    min_count: int = 1,
    mode: Literal["B", "abT", "gdT"] | None = "abT",
    extract_cols: list[str] | None = [
        "v_call_abT_VDJ_main",
        "j_call_abT_VDJ_main",
        "v_call_abT_VJ_main",
        "j_call_abT_VJ_main",
    ],
) -> AnnData:
    """Make a pseudobulk V(D)J feature space.

    One of `pbs` or `obs_to_bulk` needs to be specified when calling.

    Parameters
    ----------
    adata : AnnData
        Cell adata, preferably after `ddl.tl.setup_vdj_pseudobulk()`.
    pbs : np.ndarray | sp.sparse.csr_matrix | None, optional
        Optional binary matrix with cells as rows and pseudobulk groups as columns.
    obs_to_bulk : list[str] | str | None, optional
        Optional obs column(s) to group pseudobulks into; if multiple are
        provided, they will be combined.
    obs_to_take : list[str] | str | None, optional
        Optional obs column(s) to identify the most common value of for each
        pseudobulk.
    normalise : bool, optional
        If True, scale the counts of each V(D)J gene group to 1 for each pseudobulk.
    renormalise : bool, optional
        If True, re-scale the counts of each V(D)J gene group to 1 for each
        pseudobulk with any "missing" calls removed. Relevant with `normalise`
        as True, if `setup_vdj_pseudobulk()` was run with `remove_missing`
        set to False.
    min_count : int, optional
        Pseudobulks with fewer than this many non-"missing" calls in a V(D)J
        gene group have their non-"missing" calls set to 0 for that group.
        Relevant with `normalise` as True.
    mode : Literal["B", "abT", "gdT"] | None, optional
        Only used when `extract_cols` is None: selects columns matching
        `{mode}_VDJ_main` / `{mode}_VJ_main`, or `_call_VDJ_main` /
        `_call_VJ_main` when `mode` is None.
    extract_cols : list[str] | None, optional
        Column names where VDJ/VJ information is stored, used instead of the
        auto-detected columns.

    Returns
    -------
    AnnData
        pb_adata, whereby each observation is a pseudobulk:
        VDJ usage frequency (float) or counts (int64) stored in pb_adata.X;
        VDJ genes stored in pb_adata.var;
        pseudobulk metadata stored in pb_adata.obs;
        pseudobulk assignment (binary matrix with input cells as columns)
        stored in pb_adata.obsm['pbs'].
    """
    # get our cells by pseudobulks matrix, as a plain ndarray for the algebra below
    pbs = np.asarray(_get_pbs(pbs, obs_to_bulk, adata))
    # with extract_cols=None, discover the *_main columns from adata.obs
    if extract_cols is None:
        if mode is None:
            pattern = "_call_VDJ_main|_call_VJ_main"
        else:
            pattern = mode + "_VDJ_main|" + mode + "_VJ_main"
        extract_cols = [i for i in adata.obs if re.search(pattern, i)]
    # Cells x V(D)J genes indicator matrix. Skip the prefix, as the VJ genes
    # are expected to be unique across the columns.
    vjs = pd.get_dummies(adata.obs[extract_cols], prefix="", prefix_sep="")
    # TODO: DENAN SOMEHOW? AS IN NAN GENES?
    # Pseudobulks x genes counts. Widen to int64 first: uint16 pbs times bool
    # dummies would stay uint16 and silently wrap above 65535 cells.
    vj_pb_count = pbs.T @ vjs.to_numpy(dtype=np.int64)
    df = pd.DataFrame(
        vj_pb_count, index=np.arange(pbs.shape[1]), columns=vjs.columns
    )
    if normalise:
        # fractions are written back in place, so the frame must be float first
        # (pandas deprecates setting floats into integer columns)
        df = df.astype(np.float64)
        # missing calls inserted by the setup end with _missing; negate to keep the defined ones
        defined_mask = ~(df.columns.str.endswith("_missing"))
        # loop over V(D)J gene categories
        for col in extract_cols:
            # columns holding genes of this category (NaN dropped, as get_dummies does)
            group_values = np.asarray(adata.obs[col].dropna().unique(), dtype=object)
            group_mask = np.isin(df.columns, group_values)
            group_defined_mask = group_mask & defined_mask
            # pseudobulks with enough non-missing calls for this category
            defined_min_count = (
                df.loc[:, group_defined_mask].sum(axis=1) >= min_count
            )
            # normalise the category to 1 for every pseudobulk
            df.loc[:, group_mask] = df.loc[:, group_mask].div(
                df.loc[:, group_mask].sum(axis=1), axis=0
            )
            if renormalise:
                # repeat over non-missing values only, for pseudobulks passing min_count
                df.loc[defined_min_count, group_defined_mask] = df.loc[
                    defined_min_count, group_defined_mask
                ].div(
                    df.loc[defined_min_count, group_defined_mask].sum(axis=1),
                    axis=0,
                )
            # mask the pseudobulks with insufficient defined counts
            df.loc[~defined_min_count, group_defined_mask] = 0
    # create obs for the new pseudobulk object
    pbs_obs = _get_pbs_obs(pbs, obs_to_take, adata)
    pb_adata = AnnData(
        np.array(df), var=pd.DataFrame(index=df.columns), obs=pbs_obs
    )
    # store the assignments sparse, transposed to pseudobulks x cells
    pb_adata.obsm["pbs"] = sp.sparse.csr_matrix(pbs.T)
    return pb_adata
def pseudotime_transfer(
    adata: AnnData, pr_res: PResults, suffix: str = ""
) -> AnnData:
    """Copy palantir pseudotime and branch probabilities into `adata.obs` in place.

    Parameters
    ----------
    adata : AnnData
        adata that receives the new obs columns.
    pr_res : PResults
        palantir pseudotime inference output; its `pseudotime` and each
        column of `branch_probs` are copied over.
    suffix : str, optional
        Appended to every added column name, default "" (none).

    Returns
    -------
    AnnData
        The same `adata`, now holding `pseudotime<suffix>` and
        `prob_<branch><suffix>` for each branch.
    """
    pseudotime_key = "pseudotime" + suffix
    adata.obs[pseudotime_key] = pr_res.pseudotime.copy()
    for branch, probabilities in pr_res.branch_probs.items():
        adata.obs["prob_" + branch + suffix] = probabilities.copy()
    return adata
def project_pseudotime_to_cell(
    adata: AnnData, pb_adata: AnnData, term_states: list[str], suffix: str = ""
) -> AnnData:
    """Project pseudotime & branch probabilities from pb_adata (pseudobulk adata) to adata (cell adata).

    Parameters
    ----------
    adata : AnnData
        Cell adata, preferably after `ddl.tl.setup_vdj_pseudobulk()`.
    pb_adata : AnnData
        Neighbourhood/pseudobulked adata with the assignments in `.obsm["pbs"]`
        (pseudobulks x cells) and the values to project in `.obs`.
    term_states : list[str]
        List of terminal states with branch probabilities to be transferred.
    suffix : str, optional
        Suffix added after the column names, default "" (none).

    Returns
    -------
    AnnData
        Subset of adata whereby cells that don't belong to any neighbourhood
        are removed. The projected values are stored in .obs as
        `pseudotime+suffix` and `'prob_'+term_state+suffix` for each terminal
        state. The trimmed cells x pseudobulks assignment matrix is stored in
        `.uns["pseudobulk_assignments"]`.
    """
    # obsm["pbs"] is pseudobulks x cells; transpose to cells x pseudobulks.
    # Accept both a sparse matrix and a dense array.
    pbs = pb_adata.obsm["pbs"]
    if sp.sparse.issparse(pbs):
        nhoods = pbs.T.toarray()
    else:
        nhoods = np.asarray(pbs).T
    # leave out cells that don't belong to any neighbourhood
    nhoodsum = np.sum(nhoods, axis=1)
    in_nhood = nhoodsum > 0
    cdata = adata[in_nhood].copy()
    logg.info(
        "number of cells removed due to not belonging to any neighbourhood: "
        f"{int(np.sum(~in_nhood))}",
    )
    nhoods_cdata = nhoods[in_nhood]
    # A cell's pseudotime is the average over the pseudobulks it belongs to,
    # each weighted by 1/neighbourhood size. Empty pseudobulks get weight 0
    # instead of a 0/0 NaN that would spread to every cell.
    nhood_size = np.sum(nhoods_cdata, axis=0, keepdims=True)
    nhoods_cdata_norm = np.divide(
        nhoods_cdata,
        nhood_size,
        out=np.zeros(nhoods_cdata.shape, dtype=np.float64),
        where=nhood_size > 0,
    )
    # every retained cell sits in at least one non-empty pseudobulk, so this is > 0
    cell_weight = np.sum(nhoods_cdata_norm, axis=1)
    col_list = ["pseudotime" + suffix] + [
        "prob_" + state + suffix for state in term_states
    ]
    for col in col_list:
        cdata.obs[col] = (
            nhoods_cdata_norm.dot(pb_adata.obs[col].to_numpy()) / cell_weight
        )
    cdata.uns["pseudobulk_assignments"] = nhoods_cdata.copy()
    return cdata
def pseudobulk_gex(
    adata_raw: AnnData,
    pbs: np.ndarray | sp.sparse.csr_matrix | None = None,
    obs_to_bulk: list[str] | str | None = None,
    obs_to_take: list[str] | str | None = None,
) -> AnnData:
    """Pseudobulk gene expression (raw counts) by summing cells per pseudobulk.

    Parameters
    ----------
    adata_raw : AnnData
        Needs to have raw counts in .X.
    pbs : np.ndarray | sp.sparse.csr_matrix | None, optional
        Optional binary matrix with cells as rows and pseudobulk groups as columns.
    obs_to_bulk : list[str] | str | None, optional
        Optional obs column(s) to group pseudobulks into; if multiple are
        provided, they will be combined.
    obs_to_take : list[str] | str | None, optional
        Optional obs column(s) to identify the most common value of for each
        pseudobulk.

    Returns
    -------
    AnnData
        pb_adata whereby each observation is a pseudobulk:
        summed gene expression stored in pb_adata.X;
        genes stored in pb_adata.var (a copy of adata_raw.var);
        pseudobulk metadata stored in pb_adata.obs;
        pseudobulk assignment (binary matrix with input cells as columns)
        stored in pb_adata.obsm['pbs'].
    """
    # get our cells by pseudobulks matrix
    pbs = _get_pbs(pbs, obs_to_bulk, adata_raw)
    # genes x pseudobulks sums
    pbs_X = adata_raw.X.T.dot(pbs)
    # the product is an np.matrix whenever pbs is one; normalise to ndarray
    if not sp.sparse.issparse(pbs_X):
        pbs_X = np.asarray(pbs_X)
    # create obs for the new pseudobulk object
    pbs_obs = _get_pbs_obs(pbs, obs_to_take, adata_raw)
    # defensive copy so the new object does not share its var frame with adata_raw
    pb_adata = AnnData(pbs_X.T, obs=pbs_obs, var=adata_raw.var.copy())
    # store the assignments sparse, transposed to pseudobulks x cells
    pb_adata.obsm["pbs"] = sp.sparse.csr_matrix(np.asarray(pbs).T)
    return pb_adata
# def bin_expression( # adata: AnnData, bin_no: int, genes: list[str], pseudotime_col: str # ) -> pd.DataFrame: # """Function to compute average gene expression in bins along pseudotime. # Parameters # ---------- # adata : AnnData # cell adata. # bin_no : int # number of bins to be divided along pseudotime. # genes : list[str] # list of genes for the computation # pseudotime_col : str # column in adata.obs where pseudotime is stored # Returns # ------- # pd.DataFrame # a data frame with genes as rows, and pseudotime bins as columns, and averaged gene expression as the data # """ # # define bins # bins = np.linspace(0, 1, bin_no + 1) # # get gene expression # x = np.array(adata[:, genes].X.todense()) # # get pseudotime # y = np.array(adata.obs[pseudotime_col]) # # calculate average gene expression in each bin # gene_summary = pd.DataFrame(columns=bins[:-1], index=genes) # for i in range(gene_summary.shape[1]): # time = bins[i] # select = np.array(bins[i] <= y) & np.array(y < bins[i + 1]) # gene_summary.loc[:, time] = np.mean(x[select, :], axis=0) # return gene_summary # def chatterjee_corr( # adata: AnnData, genes: list[str], pseudotime_col: str # ) -> pd.DataFrame: # """Function to compute chatterjee correlation of gene expression with pseudotime. # Parameters # ---------- # adata : AnnData # cell adata # genes : list[str] # List of genes selected to compute the correlation # pseudotime_col : str # column in adata.obs where pseudotime is stored # Returns # ------- # pd.DataFrame # a data frame with genes as rows, with cor_res (correlation statistics), # pval (p-value), adj_pval (p-value adjusted by BH method) as columns. 
# """ # # get gene expression # x = np.array(adata[:, genes].X.todense()) # # add small perturbation for random tie breaking # x = x + np.random.randn(x.shape[0], x.shape[1]) * 1e-15 # # get pseudotime # y = list(adata.obs[pseudotime_col]) # # compute chatterjee correlation # # ref: Sourav Chatterjee (2021) A New Coefficient of Correlation, Journal of the American Statistical Association, 116:536, 2009-2022, DOI: 10.1080/01621459.2020.1758115 # stat = 1 - np.sum( # np.abs(np.diff(np.argsort(x[np.argsort(y), :], axis=0), axis=0)), axis=0 # ) * 3 / (x.shape[0] ** 2 - 1) # stat = np.array(stat).flatten() # pval = 1 - sp.stats.norm.cdf(stat, loc=0, scale=np.sqrt(2 / 5 / x.shape[0])) # # put results into data frame cor_res # cor_res = pd.DataFrame({"cor_stat": stat, "pval": pval}) # cor_res.index = genes # # compute adjusted pval using BH method # cor_res["adj_pval"] = bh(cor_res["pval"].to_numpy()) # # sort genes based on adjusted pval # cor_res = cor_res.sort_values(by="adj_pval") # return cor_res