#!/usr/bin/env python
# @author: chenqu, kp9, kelvin
from __future__ import annotations
import re
from anndata import AnnData
import numpy as np
import pandas as pd
import polars as pl
import scipy as sp
from contextlib import contextmanager
from scanpy import logging as logg
from typing import Literal
from dandelion.polars.core._core import DandelionPolars
@contextmanager
def _trajectory_context(
    adata: AnnData,
    vdj: DandelionPolars,
    mode: Literal["B", "abT", "gdT"] | None,
    extract_cols: list[str] | None = None,
    productive_cols: list[str] | None = None,
):
    """
    Context manager that temporarily adds VDJ/VJ gene and productive columns to adata.obs
    for trajectory analysis, then restores the original adata.obs upon exit.

    Parameters
    ----------
    adata : AnnData
        AnnData object to add columns to. Its ``.obs`` is modified in place while the
        context is active and replaced by a copy of the original on exit.
    vdj : DandelionPolars
        Dandelion object containing VDJ data. If its data/metadata are lazy, they are
        made eager for the duration of the context and made lazy again on exit.
    mode : Literal["B", "abT", "gdT"] | None
        Cell type mode. When set, columns are added with a mode infix
        (e.g. ``v_call_abT_VDJ``, ``productive_abT_VDJ``).
    extract_cols : list[str] | None
        Custom column names to extract instead of standard mode-based columns
        (only used when ``mode`` is None).
    productive_cols : list[str] | None
        Custom productive column names (only used when ``mode`` is None).

    Yields
    ------
    AnnData
        The same ``adata`` object, with the extra columns in ``.obs``.
    """
    gene_calls = ("v_call", "d_call", "j_call")
    # Snapshot obs so every column added below can be discarded on exit.
    original_obs = adata.obs.copy()
    # Work on eager polars frames inside the context; remember to undo this on exit.
    original_lazy = isinstance(vdj._data, pl.LazyFrame) or isinstance(
        vdj._metadata, pl.LazyFrame
    )
    if original_lazy:
        vdj.to_eager()
    try:
        # chain_status is aligned to adata.obs by cell_id (pandas index alignment).
        if vdj._metadata is not None:
            metadata_pd = vdj._metadata.to_pandas().set_index("cell_id")
            if "chain_status" in metadata_pd.columns:
                adata.obs["chain_status"] = metadata_pd["chain_status"]

        if mode is not None:
            # Productive status of the main VDJ / VJ chains, renamed with the mode infix.
            productive_pd = (
                vdj._split_first(
                    cols="productive", key_added="productive", celltype=mode
                )
                .to_pandas()
                .set_index("cell_id")
            )
            for locus in ("VDJ", "VJ"):
                source = f"productive_{locus}"
                if source in productive_pd.columns:
                    adata.obs[f"productive_{mode}_{locus}"] = productive_pd[source]

            # Split v/d/j calls (in that order) and combine them per cell.
            splits = [
                vdj._split_first(
                    cols=call, key_added=call, celltype=mode
                ).drop("celltype_group", strict=False)
                for call in gene_calls
            ]
            merged = (
                splits[0]
                .join(splits[1], on="cell_id", how="left")
                .join(splits[2], on="cell_id", how="left")
            )
            merged_pd = merged.to_pandas().set_index("cell_id")
            for col in merged_pd.columns:
                if not col.endswith(("_VDJ", "_VJ")):
                    continue
                # e.g. v_call_VDJ -> v_call_abT_VDJ
                new_col = col
                for call in gene_calls:
                    new_col = new_col.replace(f"{call}_", f"{call}_{mode}_")
                adata.obs[new_col] = merged_pd[col]
        elif extract_cols is None:
            # Generic (mode-less) VDJ/VJ columns such as v_call_VDJ.
            splits = [
                vdj._split_first(cols=call, key_added=call) for call in gene_calls
            ]
            merged = (
                splits[0]
                .join(splits[1], on="cell_id", how="left")
                .join(splits[2], on="cell_id", how="left")
            )
            merged_pd = merged.to_pandas().set_index("cell_id")
            for col in merged_pd.columns:
                adata.obs[col] = merged_pd[col]
        else:
            # Reduce each requested column to its base gene call by stripping a
            # mode infix (e.g. "v_call_abT_VDJ_main" -> "v_call"). Several
            # requested columns usually share one base call, so each base is
            # extracted only once (first-seen order keeps the obs column order).
            base_cols: list[str] = []
            for col in extract_cols:
                base_col = col
                for infix in ("_B_", "_abT_", "_gdT_"):
                    if infix in col:
                        base_col = col.split(infix)[0]
                        break
                if base_col in gene_calls and base_col not in base_cols:
                    base_cols.append(base_col)
            for base_col in base_cols:
                extracted_pd = (
                    vdj._split_first(cols=base_col, key_added=base_col)
                    .to_pandas()
                    .set_index("cell_id")
                )
                for extracted_col in extracted_pd.columns:
                    adata.obs[extracted_col] = extracted_pd[extracted_col]

            if productive_cols is not None:
                productive_pd = (
                    vdj._split_first(cols="productive", key_added="productive")
                    .to_pandas()
                    .set_index("cell_id")
                )
                for col in productive_pd.columns:
                    adata.obs[col] = productive_pd[col]

        yield adata
    finally:
        # Discard every temporary column by restoring the snapshot.
        adata.obs = original_obs
        # Restore the original laziness of vdj if we changed it.
        if original_lazy:
            vdj.to_lazy()
def _filter_cells(
    adata: AnnData,
    col: str,
    filter_pattern: str | None = ",|None|No_contig",
    remove_missing: bool = True,
) -> AnnData:
    """
    Identify ``filter_pattern`` hits in ``adata.obs[col]`` and either remove the
    offending cells or mask the matched values with ``col + "_missing"``.

    Parameters
    ----------
    adata : AnnData
        Object whose ``.obs[col]`` is inspected.
    col : str
        Column of ``adata.obs`` to test.
    filter_pattern : str | None, optional
        Regular expression searched (not anchored) in the string form of each value;
        missing values never match. If None, ``adata`` is returned unchanged.
    remove_missing : bool, optional
        If True, return a copy without the matching cells. If False, overwrite the
        matching values in place with ``col + "_missing"`` and return ``adata``.

    Returns
    -------
    AnnData
        Filtered copy (``remove_missing=True``) or the same, masked object.
    """
    if filter_pattern is None:
        return adata
    values = adata.obs[col]
    # boolean numpy mask of pattern hits (NA treated as no hit)
    mask = (
        values.astype("string")
        .str.contains(filter_pattern, na=False)
        .to_numpy(dtype=bool)
    )
    if remove_missing:
        return adata[~mask].copy()
    missing_label = col + "_missing"
    # Categorical columns reject values outside their categories, so register the
    # masking label first.
    if (
        isinstance(values.dtype, pd.CategoricalDtype)
        and missing_label not in values.cat.categories
    ):
        adata.obs[col] = values.cat.add_categories([missing_label])
    adata.obs.loc[mask, col] = missing_label
    return adata
[docs]
def setup_vdj_pseudobulk(
    adata: AnnData,
    vdj: DandelionPolars,
    mode: Literal["B", "abT", "gdT"] | None = "abT",
    subsetby: str | None = None,
    groups: list[str] | None = None,
    allowed_chain_status: list[str] | None = [
        "Single pair",
        "Extra pair",
        "Extra pair-exception",
        "Orphan VDJ",
        "Orphan VDJ-exception",
    ],
    productive_vdj: bool = True,
    productive_vj: bool = True,
    extract_cols: list[str] | None = None,
    productive_cols: list[str] | None = None,
    check_vdj_mapping: list[Literal["v_call", "d_call", "j_call"]] | None = [
        "v_call",
        "j_call",
    ],
    check_vj_mapping: list[Literal["v_call", "j_call"]] | None = [
        "v_call",
        "j_call",
    ],
    check_extract_cols_mapping: list[str] | None = None,
    filter_pattern: str | None = ",|None|No_contig",
    remove_missing: bool = True,
) -> AnnData:
    """Function to prepare AnnData for computing pseudobulk vdj feature space.

    V(D)J calls, productive status and ``chain_status`` are temporarily pulled from
    ``vdj`` into ``adata.obs``. Cells are then filtered, ``*_main`` columns (first of
    any ``|``-separated calls) are derived, and cells with unclear mapping are removed
    or masked. The input ``adata`` is left unchanged; a filtered copy is returned.

    Parameters
    ----------
    adata : AnnData
        Cell-level AnnData to filter.
    vdj : DandelionPolars
        Dandelion object containing VDJ data.
    mode : Literal["B", "abT", "gdT"] | None, optional
        Optional mode for extracting the V(D)J genes. If set as `None`, it requires the option `extract_cols` to be
        specified with a list of column names where this will be used to retrieve the main call.
    subsetby : str | None, optional
        If provided, only the groups/categories in this column will be used for computing the VDJ feature space.
    groups : list[str] | None, optional
        If provided, only the following groups/categories will be used for computing the VDJ feature space.
    allowed_chain_status : list[str] | None, optional
        If provided, only the ones in this list are kept from the `chain_status` column.
    productive_vdj : bool, optional
        If True, cells will only be kept if the main VDJ chain is productive.
    productive_vj : bool, optional
        If True, cells will only be kept if the main VJ chain is productive.
    extract_cols : list[str] | None, optional
        Column names where VDJ/VJ information is stored so that this will be used instead of the standard columns.
    productive_cols : list[str] | None, optional
        Column names where contig productive status is stored so that this will be used instead of the standard columns.
    check_vdj_mapping : list[Literal["v_call", "d_call", "j_call"]] | None, optional
        Only columns in the argument will be checked for unclear mapping (containing comma) in VDJ calls.
        A single string is accepted. Specifying None will skip this step.
    check_vj_mapping : list[Literal["v_call", "j_call"]] | None, optional
        Only columns in the argument will be checked for unclear mapping (containing comma) in VJ calls.
        A single string is accepted. Specifying None will skip this step.
    check_extract_cols_mapping : list[str] | None, optional
        Only columns in the argument will be checked for unclear mapping (containing comma) in columns specified in
        extract_cols. A single string is accepted. Specifying None will skip this step.
    filter_pattern : str | None, optional
        Pattern to filter from object. If `None`, does not filter.
    remove_missing : bool, optional
        If True, will remove cells with contigs matching the filter from the object. If False, will mask them with a
        uniform value dependent on the column name.

    Returns
    -------
    AnnData
        Filtered cell adata object.
    """
    # Normalise up front: a bare string such as "v_call" would otherwise be
    # iterated character by character when building the column names below.
    check_vdj_mapping = _as_column_list(check_vdj_mapping)
    check_vj_mapping = _as_column_list(check_vj_mapping)
    check_extract_cols_mapping = _as_column_list(check_extract_cols_mapping)
    # "_abT" etc. for mode-specific columns (v_call_abT_VDJ), "" for generic ones (v_call_VDJ)
    infix = "" if mode is None else "_" + mode

    with _trajectory_context(
        adata, vdj, mode, extract_cols, productive_cols
    ) as adata_:
        # keep only cells whose relevant chains are productive
        if mode is not None:
            productive_to_check = []
            if productive_vdj:
                productive_to_check.append("productive_" + mode + "_VDJ")
            if productive_vj:
                productive_to_check.append("productive_" + mode + "_VJ")
        elif productive_cols is not None:
            productive_to_check = productive_cols
        else:
            productive_to_check = []
        for col in productive_to_check:
            adata_ = adata_[_productive_mask(adata_.obs[col])].copy()

        if allowed_chain_status is not None:
            adata_ = adata_[
                adata_.obs["chain_status"].isin(allowed_chain_status)
            ].copy()
        if (groups is not None) and (subsetby is not None):
            adata_ = adata_[adata_.obs[subsetby].isin(groups)].copy()

        # derive the "*_main" columns holding the first of any "|"-separated calls
        if extract_cols is None:
            source_cols = [
                f"{call}{infix}_{locus}" for call, locus in _MAIN_CALL_COLUMNS
            ]
        else:
            source_cols = extract_cols
        for col in source_cols:
            adata_.obs[col + "_main"] = [
                _main_call(x) for x in adata_.obs[col]
            ]

        # remove (or mask) any cells with unclear mapping
        if filter_pattern is not None:
            if mode is not None or extract_cols is None:
                cols_to_check = [
                    f"{call}{infix}_VDJ_main"
                    for call in (check_vdj_mapping or [])
                ] + [
                    f"{call}{infix}_VJ_main" for call in (check_vj_mapping or [])
                ]
            else:
                cols_to_check = [
                    col + "_main" for col in (check_extract_cols_mapping or [])
                ]
            for col in cols_to_check:
                adata_ = _filter_cells(
                    adata=adata_,
                    col=col,
                    filter_pattern=filter_pattern,
                    remove_missing=remove_missing,
                )
        # copy before the context manager restores the original adata.obs on exit
        return adata_.copy()


# (gene call, locus) pairs for which setup_vdj_pseudobulk derives "*_main" columns,
# in the order the columns are added to .obs.
_MAIN_CALL_COLUMNS = (
    ("v_call", "VDJ"),
    ("d_call", "VDJ"),
    ("j_call", "VDJ"),
    ("v_call", "VJ"),
    ("j_call", "VJ"),
)


def _as_column_list(value):
    """Return `value` as a list of column names; None passes through, a str is wrapped."""
    if value is None:
        return None
    if isinstance(value, str):
        return [value]
    return list(value)


def _productive_mask(values: pd.Series) -> np.ndarray:
    """Boolean numpy mask of entries whose string form starts with "T" (missing -> False)."""
    return (
        values.astype("string")
        .str.startswith("T", na=False)
        .to_numpy(dtype=bool)
    )


def _main_call(value) -> str:
    """First of the "|"-separated calls in `value`; "None" for missing (None/NA/NaN) values."""
    if (
        value is None
        or value is pd.NA
        or (isinstance(value, float) and np.isnan(value))
    ):
        # cells absent from the vdj object come through as NaN after obs alignment
        return "None"
    text = str(value)
    if text == "None":
        return "None"
    return text.split("|")[0]
def _get_pbs(
    pbs: np.ndarray | sp.sparse.csr_matrix | None,
    obs_to_bulk: str | None,
    adata: AnnData,
) -> np.ndarray:
    """
    Helper function to ensure we have a cells by pseudobulks matrix which we can use for
    pseudobulking. Uses the pbs and obs_to_bulk inputs to vdj_pseudobulk() and
    gex_pseudobulk().

    Parameters
    ----------
    pbs : np.ndarray | sp.sparse.csr_matrix | None
        Cells x pseudobulks assignment matrix; sparse input is densified.
    obs_to_bulk : str | list[str] | None
        obs column(s) defining the pseudobulk groups; multiple columns are combined
        by joining their string values with ",".
    adata : AnnData
        Source of ``.obs`` when ``obs_to_bulk`` is used.

    Returns
    -------
    np.ndarray
        Dense cells x pseudobulks matrix (a plain ndarray, never ``np.matrix``).

    Raises
    ------
    ValueError
        If neither or both of ``pbs`` and ``obs_to_bulk`` are given.
    """
    # well, we need some way to pseudobulk
    if pbs is None and obs_to_bulk is None:
        raise ValueError(
            "You need to specify `pbs` or `obs_to_bulk` when calling the function"
        )
    # but just one
    if pbs is not None and obs_to_bulk is not None:
        raise ValueError("You need to specify `pbs` or `obs_to_bulk`, not both")
    if pbs is not None:
        # .toarray() yields an ndarray; .todense() would yield an np.matrix
        return np.asarray(pbs.toarray() if sp.sparse.issparse(pbs) else pbs)
    if isinstance(obs_to_bulk, list):
        # this will create a single value per cell by pasting all the columns together
        tobulk = adata.obs[obs_to_bulk].T.astype(str).agg(",".join)
    else:
        # we just have a single column
        tobulk = adata.obs[obs_to_bulk]
    # one-hot encoding of the groups is exactly the pseudobulk assignment we want;
    # uint16 rather than the default as pseudobulks can exceed 255 cells
    return pd.get_dummies(tobulk, dtype="uint16").values
def _get_pbs_obs(
    pbs: np.ndarray, obs_to_take: str | None, adata: AnnData
) -> pd.DataFrame:
    """
    Helper function to create the pseudobulk object's obs. Uses the pbs and obs_to_take
    inputs to vdj_pseudobulk() and gex_pseudobulk().

    Parameters
    ----------
    pbs : np.ndarray
        Cells x pseudobulks assignment matrix.
    obs_to_take : str | list[str] | None
        obs column(s) whose most common value per pseudobulk is reported, together
        with its fraction as ``<column>_fraction``.
    adata : AnnData
        Cell object providing ``.obs``.

    Returns
    -------
    pd.DataFrame
        One row per pseudobulk (integer index), plus a ``cell_count`` column.
    """
    # plain ndarray so that @ and broadcasting behave predictably (no np.matrix)
    pbs = np.asarray(pbs)
    n_pbs = pbs.shape[1]
    pbs_obs = pd.DataFrame(index=np.arange(n_pbs))
    if obs_to_take is not None:
        # just in case a single is passed as a string
        if not isinstance(obs_to_take, list):
            obs_to_take = [obs_to_take]
        # using the logic of milopy's annotate_nhoods()
        for anno_col in obs_to_take:
            anno_dummies = pd.get_dummies(adata.obs[anno_col])
            # float dummies: an integer product would take pbs' uint16 dtype and
            # silently wrap for pseudobulks with more than 65535 cells
            anno_count = pbs.T @ anno_dummies.to_numpy(dtype=float)
            anno_frac = anno_count / anno_count.sum(axis=1, keepdims=True)
            anno_frac = pd.DataFrame(
                anno_frac,
                index=np.arange(n_pbs),
                columns=anno_dummies.columns,
            )
            pbs_obs[anno_col] = anno_frac.idxmax(axis=1)
            pbs_obs[anno_col + "_fraction"] = anno_frac.max(axis=1)
    # report the number of cells for each pseudobulk
    pbs_obs["cell_count"] = np.sum(pbs, axis=0)
    return pbs_obs
[docs]
def vdj_pseudobulk(
    adata: AnnData,
    vdj: DandelionPolars | None = None,
    pbs: np.ndarray | sp.sparse.csr_matrix | None = None,
    obs_to_bulk: list[str] | str | None = None,
    obs_to_take: list[str] | str | None = None,
    normalise: bool = True,
    renormalise: bool = False,
    min_count: int = 1,
    mode: Literal["B", "abT", "gdT"] | None = "abT",
    extract_cols: list[str] | None = [
        "v_call_abT_VDJ_main",
        "j_call_abT_VDJ_main",
        "v_call_abT_VJ_main",
        "j_call_abT_VJ_main",
    ],
) -> AnnData:
    """Function for making pseudobulk vdj feature space. One of `pbs` or `obs_to_bulk`
    needs to be specified when calling.

    Parameters
    ----------
    adata : AnnData
        Cell adata, preferably after `ddl.tl.setup_vdj_pseudobulk()`
    vdj : DandelionPolars | None, optional
        Dandelion object containing VDJ data. Only needed if columns are not already in adata.obs.
        In that case the V(D)J columns are added temporarily, and any missing `*_main` column is
        derived from its source column by keeping the first `|`-separated call (no filtering of
        ambiguous calls is done; use `setup_vdj_pseudobulk()` for that).
    pbs : np.ndarray | sp.sparse.csr_matrix | None, optional
        Optional binary matrix with cells as rows and pseudobulk groups as columns
    obs_to_bulk : list[str] | str | None, optional
        Optional obs column(s) to group pseudobulks into; if multiple are provided, they
        will be combined
    obs_to_take : list[str] | str | None, optional
        Optional obs column(s) to identify the most common value of for each pseudobulk.
    normalise : bool, optional
        If True, will scale the counts of each V(D)J gene group to 1 for each pseudobulk.
    renormalise : bool, optional
        If True, will re-scale the counts of each V(D)J gene group to 1 for each pseudobulk with any "missing" calls removed.
        Relevant with `normalise` as True, if `setup_vdj_pseudobulk()` was ran with `remove_missing` set to False.
    min_count : int, optional
        Pseudobulks with fewer than these many non-"missing" calls in a V(D)J gene group will have their non-"missing" calls
        set to 0 for that group. Relevant with `normalise` as True.
    mode : Literal["B", "abT", "gdT"] | None, optional
        Optional mode for extracting the V(D)J genes. If set as `None`, it will use e.g. `v_call_VDJ` instead of `v_call_abT_VDJ`.
        If `extract_cols` is provided, then this argument is ignored.
    extract_cols : list[str] | None, optional
        Column names where VDJ/VJ information is stored so that this will be used instead of the standard columns.

    Returns
    -------
    AnnData
        pb_adata, whereby each observation is a pseudobulk:\n
        VDJ usage frequency/counts stored in pb_adata.X\n
        VDJ genes stored in pb_adata.var\n
        pseudobulk metadata stored in pb_adata.obs\n
        pseudobulk assignment (binary matrix with input cells as columns) stored in pb_adata.obsm['pbs']\n

    Raises
    ------
    ValueError
        if neither ``pbs`` nor ``obs_to_bulk`` is specified, or if both are specified.
    ValueError
        if required VDJ columns are not in ``adata.obs`` and ``vdj`` is not provided.
    """
    # columns the implementation will read from adata.obs
    if extract_cols is not None:
        required_cols = list(extract_cols)
    elif mode is not None:
        required_cols = [
            f"v_call_{mode}_VDJ_main",
            f"j_call_{mode}_VDJ_main",
            f"v_call_{mode}_VJ_main",
            f"j_call_{mode}_VJ_main",
        ]
    else:
        required_cols = []
    impl_args = (
        pbs,
        obs_to_bulk,
        obs_to_take,
        normalise,
        renormalise,
        min_count,
        mode,
        extract_cols,
    )
    if all(col in adata.obs.columns for col in required_cols):
        # columns already exist, no need for the context manager
        return _vdj_pseudobulk_impl(adata, *impl_args)
    if vdj is None:
        raise ValueError(
            "vdj parameter must be provided when required columns are not in adata.obs"
        )

    def first_call(value) -> str:
        # same convention as setup_vdj_pseudobulk: first "|"-separated call,
        # "None" for missing values (NaN for cells absent from vdj)
        if (
            value is None
            or value is pd.NA
            or (isinstance(value, float) and np.isnan(value))
        ):
            return "None"
        text = str(value)
        return "None" if text == "None" else text.split("|")[0]

    with _trajectory_context(adata, vdj, mode, extract_cols) as adata_:
        # The context only adds the raw split columns (e.g. v_call_abT_VDJ), never
        # "*_main" ones, so derive any missing "*_main" column from its source.
        for col in required_cols:
            if col in adata_.obs.columns or not col.endswith("_main"):
                continue
            source_col = col[: -len("_main")]
            if source_col in adata_.obs.columns:
                adata_.obs[col] = [
                    first_call(value) for value in adata_.obs[source_col]
                ]
        # the temporary columns are dropped again when the context exits
        return _vdj_pseudobulk_impl(adata_, *impl_args)
def _vdj_pseudobulk_impl(
    adata: AnnData,
    pbs: np.ndarray | sp.sparse.csr_matrix | None,
    obs_to_bulk: list[str] | str | None,
    obs_to_take: list[str] | str | None,
    normalise: bool,
    renormalise: bool,
    min_count: int,
    mode: Literal["B", "abT", "gdT"] | None,
    extract_cols: list[str] | None,
) -> AnnData:
    """Implementation of vdj_pseudobulk that operates on adata with required columns.

    Builds a pseudobulk x V(D)J-gene count matrix from one-hot encoded gene calls,
    optionally normalises it per gene group, and wraps it in an AnnData with the
    pseudobulk assignment stored (sparse, pseudobulks x cells) in ``.obsm["pbs"]``.
    See ``vdj_pseudobulk`` for the meaning of the parameters.
    """
    # cells x pseudobulks; np.asarray guards against an np.matrix assignment
    pbs = np.asarray(_get_pbs(pbs, obs_to_bulk, adata))
    # if not specified by the user, use the default dandelion "*_main" columns
    if extract_cols is None:
        if mode is None:
            pattern = "|".join(["_call_VDJ_main", "_call_VJ_main"])
        else:
            pattern = "|".join([mode + "_VDJ_main", mode + "_VJ_main"])
        extract_cols = [i for i in adata.obs if re.search(pattern, i)]
    # cells x genes one-hot matrix; no prefix as the gene names are unique columns
    vjs = pd.get_dummies(adata.obs[extract_cols], prefix="", prefix_sep="")
    # TODO: handle NaN genes (get_dummies silently drops NaN calls)
    # int64 operand so the product does not take pbs' uint16 dtype and wrap around
    # for pseudobulks with more than 65535 cells
    vj_pb_count = pbs.T.dot(vjs.to_numpy(dtype=np.int64))
    df = pd.DataFrame(
        vj_pb_count, index=np.arange(pbs.shape[1]), columns=vjs.columns
    )
    if normalise:
        # fractions are written back via .loc; use a float frame so this never
        # depends on implicit upcasting of integer columns
        df = df.astype(float)
        # missing calls inserted by the setup end with "_missing"; negate to keep the defined ones
        defined_mask = ~(df.columns.str.endswith("_missing"))
        # loop over V(D)J gene categories
        for col in extract_cols:
            # gene columns belonging to this category (Series.unique avoids the
            # sorting np.unique does, which fails on str/NaN mixtures)
            group_mask = np.isin(df.columns, adata.obs[col].unique())
            group_defined_mask = group_mask & defined_mask
            # pseudobulks with enough defined calls in this category
            defined_min_count = (
                df.loc[:, group_defined_mask].sum(axis=1) >= min_count
            )
            # normalise the category to 1 for every pseudobulk
            df.loc[:, group_mask] = df.loc[:, group_mask].div(
                df.loc[:, group_mask].sum(axis=1), axis=0
            )
            if renormalise:
                # repeat the normalisation over defined calls only, for the
                # pseudobulks crossing the min_count threshold
                df.loc[defined_min_count, group_defined_mask] = df.loc[
                    defined_min_count, group_defined_mask
                ].div(
                    df.loc[defined_min_count, group_defined_mask].sum(axis=1),
                    axis=0,
                )
            # zero out pseudobulks with insufficient defined counts
            df.loc[~defined_min_count, group_defined_mask] = 0
    # create obs for the new pseudobulk object
    pbs_obs = _get_pbs_obs(pbs, obs_to_take, adata)
    pb_adata = AnnData(
        np.array(df), var=pd.DataFrame(index=df.columns), obs=pbs_obs
    )
    # store the assignments sparse, transposed to pseudobulks x cells
    pb_adata.obsm["pbs"] = sp.sparse.csr_matrix(pbs.T)
    return pb_adata
[docs]
def pseudotime_transfer(
    adata: AnnData, pr_res: PResults, suffix: str = ""
) -> AnnData:
    """Copy palantir pseudotime and branch probabilities into ``adata.obs``.

    Columns are written in place: ``"pseudotime" + suffix`` first, then one
    ``"prob_" + <branch> + suffix`` column per column of ``pr_res.branch_probs``,
    in that column order. Values are aligned by pandas index.

    Parameters
    ----------
    adata : AnnData
        Object receiving the new ``.obs`` columns.
    pr_res : PResults
        Palantir pseudotime inference output, read via its ``pseudotime`` and
        ``branch_probs`` attributes.
    suffix : str, optional
        Appended to every added column name; default "" (none).

    Returns
    -------
    AnnData
        The same ``adata`` object that was passed in.
    """
    pseudotime_col = "pseudotime" + suffix
    adata.obs[pseudotime_col] = pr_res.pseudotime.copy()
    branch_probs = pr_res.branch_probs
    for branch in branch_probs.columns:
        probs = branch_probs[branch]
        adata.obs["prob_" + branch + suffix] = probs.copy()
    return adata
[docs]
def project_pseudotime_to_cell(
    adata: AnnData, pb_adata: AnnData, term_states: list[str], suffix: str = ""
) -> AnnData:
    """Function to project pseudotime & branch probabilities from pb_adata (pseudobulk adata) to adata (cell adata).

    For each cell, each projected value is the average over the pseudobulks the cell
    belongs to, with every pseudobulk weighted by 1 / (its number of retained cells).

    Parameters
    ----------
    adata : AnnData
        Cell adata, preferably after `ddl.tl.setup_vdj_pseudobulk()`
    pb_adata : AnnData
        neighbourhood/pseudobulked adata; its ``.obsm["pbs"]`` (pseudobulks x cells,
        sparse or dense) must list the cells in the same order as ``adata``.
    term_states : list[str]
        list of terminal states with branch probabilities to be transferred
    suffix : str, optional
        suffix to be added after the added column names, default "" (none)

    Returns
    -------
    AnnData
        subset of adata whereby cells that don't belong to any neighbourhood are removed
        and projected pseudotime information stored in .obs - `pseudotime+suffix`, and `'prob_'+term_state+suffix` for each terminal state.
        The trimmed cell x pseudobulk matrix is stored in `.uns["pseudobulk_assignments"]`.
    """
    # cells x pseudobulks (obsm stores pseudobulks x cells)
    stored = pb_adata.obsm["pbs"]
    if sp.sparse.issparse(stored):
        nhoods = stored.T.toarray()
    else:
        nhoods = np.asarray(stored).T
    # leave out cells that don't belong to any neighbourhood
    in_nhood = nhoods.sum(axis=1) > 0
    cdata = adata[in_nhood].copy()
    logg.info(
        f"number of cells removed due to not belonging to any neighbourhood: {int(np.sum(~in_nhood))}",
    )
    # fancy indexing already returns a fresh array
    nhoods_cdata = nhoods[in_nhood]
    # weight each pseudobulk by 1/size; an empty pseudobulk would give 0/0 = NaN,
    # which would turn every cell's weighted average into NaN, so keep its column at 0
    nhood_sizes = nhoods_cdata.sum(axis=0, keepdims=True)
    nhoods_cdata_norm = np.divide(
        nhoods_cdata,
        nhood_sizes,
        out=np.zeros(nhoods_cdata.shape, dtype=float),
        where=nhood_sizes > 0,
    )
    # > 0 for every retained cell, as it belongs to at least one non-empty pseudobulk
    cell_weights = nhoods_cdata_norm.sum(axis=1)
    col_list = ["pseudotime" + suffix] + [
        "prob_" + state + suffix for state in term_states
    ]
    for col in col_list:
        pb_values = np.asarray(pb_adata.obs[col], dtype=float)
        cdata.obs[col] = nhoods_cdata_norm.dot(pb_values) / cell_weights
    cdata.uns["pseudobulk_assignments"] = nhoods_cdata
    return cdata
[docs]
def pseudobulk_gex(
    adata_raw: AnnData,
    pbs: np.ndarray | sp.sparse.csr_matrix | None = None,
    obs_to_bulk: list[str] | str | None = None,
    obs_to_take: list[str] | str | None = None,
) -> AnnData:
    """Function to pseudobulk gene expression (raw count).

    Expression is summed over the cells assigned to each pseudobulk
    (``X.T @ assignment``), so ``.X`` of the input should hold raw counts.

    Parameters
    ----------
    adata_raw : AnnData
        Needs to have raw counts in .X
    pbs : np.ndarray | sp.sparse.csr_matrix | None, optional
        Optional binary matrix with cells as rows and pseudobulk groups as columns
    obs_to_bulk : list[str] | str | None, optional
        Optional obs column(s) to group pseudobulks into; if multiple are provided, they
        will be combined
    obs_to_take : list[str] | str | None, optional
        Optional obs column(s) to identify the most common value of for each pseudobulk

    Returns
    -------
    AnnData
        pb_adata whereby each observation is a cell neighbourhood\n
        pseudobulked gene expression stored in pb_adata.X\n
        genes stored in pb_adata.var\n
        pseudobulk metadata stored in pb_adata.obs\n
        pseudobulk assignment (binary matrix with input cells as columns) stored in pb_adata.obsm['pbs']\n
    """
    assignment = _get_pbs(pbs, obs_to_bulk, adata_raw)
    # genes x pseudobulks sums of the raw counts
    summed_counts = adata_raw.X.T.dot(assignment)
    pb_obs = _get_pbs_obs(assignment, obs_to_take, adata_raw)
    result = AnnData(summed_counts.T, obs=pb_obs, var=adata_raw.var)
    # keep the assignment sparse and oriented pseudobulks x cells
    result.obsm["pbs"] = sp.sparse.csr_matrix(assignment.T)
    return result