from __future__ import annotations
import functools
import math
import networkx as nx
import numpy as np
import pandas as pd
import polars as pl
import scanpy as sc
from anndata import AnnData
from collections import defaultdict, Counter
from contextlib import contextmanager
from scanpy import logging as logg
from scipy.sparse import coo_matrix, csr_matrix, eye as speye
from scipy.sparse.csgraph import connected_components
from tqdm import tqdm
from typing import Callable, Literal, TYPE_CHECKING
if TYPE_CHECKING:
from mudata import MuData
from awkward import Array
from dandelion.polars.core._core import DandelionPolars, SCHEMA_OVERRIDES
from dandelion.utilities._utilities import (
VCALL,
TRUES_STR,
FALSES_STR,
FALSES,
JCALL,
VCALLG,
STRIPALLELENUM,
EMPTIES,
is_categorical,
Tree,
)
from dandelion.utilities._distances import (
IdentityMetric,
resolve_metric,
)
[docs]
def find_clones(
    vdj: DandelionPolars,
    identity: dict[str, float] | float = 0.85,
    hard_cutoff: int | float | None = None,
    key: dict[str, str] | str | None = None,
    dist_func: (
        Literal["hamming", "levenshtein", "identity"] | Callable | str
    ) = "hamming",
    same_vj: bool = True,
    same_length: bool = True,
    by_alleles: bool = False,
    key_added: str | None = None,
    recalculate_length: bool = True,
    store_distances: bool = True,
    verbose: bool = True,
) -> DandelionPolars:
    """
    Find clones based on VDJ chain and VJ chain junction distances.

    Contigs are processed per receptor family ('ig', 'tr-ab', 'tr-gd'). Within
    a family the VDJ and VJ chains are grouped (see `_group_sequences`),
    clustered by connected components on the pairwise distance matrix, and
    the per-chain clone labels are then combined per cell and joined across
    families with '|'.

    Parameters
    ----------
    vdj : DandelionPolars
        Dandelion object.
    identity : dict[str, float] | float, optional
        Similarity parameter. Default 0.85. When `hard_cutoff` is not given,
        the distance cutoff of each group is
        `threshold = floor(length * (1 - identity))`. If `dist_func` resolves
        to an `IdentityMetric`, `threshold` is set to 0. If a single float
        value is provided, this will be used for all loci. If provided as a
        dictionary, please use the following keys: 'ig', 'tr-ab', 'tr-gd';
        missing keys fall back to {'ig': 0.85, 'tr-ab': 1.0, 'tr-gd': 1.0}.
    hard_cutoff : int | float | None, optional
        Absolute distance cutoff. If supplied, it is used for clustering in
        place of the threshold derived from `identity`. Intended for distance
        functions such as levenshtein and substitution matrices.
        Default is `None`.
    key : dict[str, str] | str | None, optional
        column name for performing clone clustering. `None` defaults to a
        dictionary where:
        {'ig': 'junction_aa', 'tr-ab': 'junction', 'tr-gd': 'junction'}
        If provided as a string, this key will be used for all loci. If
        provided as a dictionary, missing loci fall back to the defaults above.
    dist_func : Literal["hamming", "levenshtein", "identity"] | Callable | str, optional
        Distance function, passed to `resolve_metric`. Defaults to 'hamming'.
    same_vj : bool, optional
        whether or not to require same V and J gene assignments to be in the
        same clone. Default is True.
    same_length : bool, optional
        whether or not to require same junction length to be in the same
        clone. Default is True.
    by_alleles : bool, optional
        if True, V/J calls are compared as-is; if False (default) the calls
        are first passed through the `STRIPALLELENUM` substitution.
    key_added : str | None, optional
        If specified, this will be the column name for clones. `None` defaults
        to 'clone_id'.
    recalculate_length : bool, optional
        whether or not to re-calculate junction length, rather than rely on
        parsed assignment. Default is True.
    store_distances : bool, optional
        whether or not to store the distance matrix as a sparse matrix in
        `vdj.distances`. Default is True.
    verbose : bool, optional
        whether or not to show progress bars.

    Returns
    -------
    DandelionPolars
        The same object, with the clone column written to its contig table,
        metadata refreshed through `update_metadata` and, when
        `store_distances` is True, `distances` set to a CSR matrix.

    Raises
    ------
    ValueError
        if a `key` column is not found in the contig table.
    """
    start = logg.info("Finding clonotypes")
    df = vdj._data
    # Clustering needs eager row access, so a lazy frame is materialised here.
    if isinstance(df, pl.LazyFrame):
        df = df.collect(engine="streaming")

    # (VDJ loci, VJ loci) handled together for each receptor family.
    locus_dict = {
        "ig": (["IGH"], ["IGK", "IGL"]),
        "tr-ab": (["TRB"], ["TRA"]),
        "tr-gd": (["TRD"], ["TRG"]),
    }
    # Prefix used in clone labels and progress messages.
    locus_log = {"ig": "B", "tr-ab": "abT", "tr-gd": "gdT"}
    default_identity = {"ig": 0.85, "tr-ab": 1.0, "tr-gd": 1.0}
    default_key = {
        "ig": "junction_aa",
        "tr-ab": "junction",
        "tr-gd": "junction",
    }
    metric = resolve_metric(dist_func)

    # Normalise `identity` to one value per receptor family.
    if identity is None:
        identity = default_identity
    elif isinstance(identity, dict):
        identity = {**default_identity, **identity}
    else:
        identity = dict.fromkeys(default_identity, identity)

    # Normalise `key` the same way; a partial dict is completed with defaults
    # instead of failing with a KeyError further down.
    if key is None:
        key = default_key
    elif isinstance(key, str):
        key = dict.fromkeys(default_key, key)
    else:
        key = {**default_key, **key}

    # Fail early (and with the documented exception) when a key column is
    # absent; otherwise polars only errors after all clustering work is done.
    schema = df.collect_schema()
    key_columns = sorted(set(key.values()))
    missing = [col for col in key_columns if col not in schema]
    if missing:
        raise ValueError(
            f"key column(s) {missing} not found in Dandelion.data."
        )

    locus_to_col = {
        chain_locus: key[locus]
        for locus, (vdj_loci, vj_loci) in locus_dict.items()
        for chain_locus in (*vdj_loci, *vj_loci)
    }
    key_added = "clone_id" if key_added is None else key_added
    # Initialise the clone column and remember the original row order so
    # per-row clone labels can be joined back.
    df = df.with_columns(pl.lit("").alias(key_added))
    df = df.with_row_index("_original_order")

    # COO triplets (rows, cols, values) of every per-group distance block.
    distance_results: list = []
    cell_id_to_meta_idx = None
    n_cells = 0
    if store_distances:
        metadata_for_mapping = vdj._metadata
        if isinstance(metadata_for_mapping, pl.LazyFrame):
            metadata_for_mapping = metadata_for_mapping.collect(
                engine="streaming"
            )
        all_cell_ids = metadata_for_mapping["cell_id"].to_list()
        cell_id_to_meta_idx = {
            cell_id: i for i, cell_id in enumerate(all_cell_ids)
        }
        n_cells = len(all_cell_ids)

    vdj_col = f"_{key_added}_VDJ"
    vj_col = f"_{key_added}_VJ"
    locus_results = {}
    for locus, (vdj_loci, vj_loci) in locus_dict.items():
        df_locus = df.filter(pl.col("locus").is_in(vdj_loci + vj_loci))
        if "ambiguous" in df_locus.collect_schema():
            df_locus = df_locus.filter(~pl.col("ambiguous").is_in(TRUES_STR))
        if df_locus.height == 0:
            continue
        locus_celltype = locus_log[locus]
        has_vdj, has_vj = _check_chains(df_locus, vdj_loci, vj_loci)

        # Cluster each chain type present for this receptor family.
        chain_results = {}
        for chain, chain_loci, present in (
            ("VDJ", vdj_loci, has_vdj),
            ("VJ", vj_loci, has_vj),
        ):
            if not present:
                continue
            chain_results[chain] = _assign_chain_clones(
                df_locus.filter(pl.col("locus").is_in(chain_loci)),
                chain=chain,
                locus_key=key[locus],
                locus_identity=identity[locus],
                locus_celltype=locus_celltype,
                key_added=key_added,
                metric=metric,
                hard_cutoff=hard_cutoff,
                same_vj=same_vj,
                same_length=same_length,
                recalculate_length=recalculate_length,
                by_alleles=by_alleles,
                verbose=verbose,
                cell_id_to_meta_idx=cell_id_to_meta_idx,
                distance_results=distance_results,
            )

        df_vdj_result = chain_results.get("VDJ")
        df_vj_result = chain_results.get("VJ")
        if df_vdj_result is not None and df_vj_result is not None:
            df_locus_chains = df_vdj_result.join(
                df_vj_result,
                on="_original_order",
                how="full",
                coalesce=True,
            )
        elif df_vdj_result is not None:
            df_locus_chains = df_vdj_result.with_columns(
                pl.lit(None, dtype=pl.String).alias(vj_col)
            )
        elif df_vj_result is not None:
            df_locus_chains = df_vj_result.with_columns(
                pl.lit(None, dtype=pl.String).alias(vdj_col)
            )
        else:
            # No chains found for this locus
            continue

        # Attach cell ids, then collect the distinct VDJ / VJ labels per cell.
        df_locus_chains = df_locus_chains.join(
            df.select(["_original_order", "cell_id"]),
            on="_original_order",
            how="left",
        )
        df_locus_summary = df_locus_chains.group_by("cell_id").agg(
            [
                pl.col(vdj_col).drop_nulls().unique().alias("_vdj_set"),
                pl.col(vj_col).drop_nulls().unique().alias("_vj_set"),
            ]
        )
        locus_col = f"_{key_added}_{locus_celltype}"
        df_locus_summary = df_locus_summary.with_columns(
            pl.struct(["_vdj_set", "_vj_set"])
            .map_elements(
                # celltype bound as a default so the closure cannot pick up a
                # later loop value.
                lambda s, celltype=locus_celltype: _combine_single_locus(
                    s["_vdj_set"], s["_vj_set"], celltype
                ),
                return_dtype=pl.List(pl.String),
            )
            .alias(locus_col)
        )
        locus_results[locus_celltype] = df_locus_summary.select(
            ["cell_id", locus_col]
        )

    # Merge the per-family labels into one '|'-joined string per cell.
    df_final = df.select("cell_id").unique()
    for locus_df in locus_results.values():
        df_final = df_final.join(locus_df, on="cell_id", how="left")
    locus_columns = [f"_{key_added}_{ct}" for ct in locus_results]
    if locus_columns:
        df_final = df_final.with_columns(
            pl.struct(locus_columns)
            .map_elements(
                lambda row: _flatten_and_join_loci(row, locus_columns),
                return_dtype=pl.String,
            )
            .alias(key_added)
        ).drop(locus_columns)
    else:
        # No loci processed: an all-null (but string-typed) clone column.
        df_final = df_final.with_columns(
            pl.lit(None, dtype=pl.String).alias(key_added)
        )

    # Replace the placeholder clone column with the computed one.
    if key_added in df.collect_schema():
        df = df.drop(key_added)
    df = df.join(df_final, on="cell_id", how="left")

    # The join propagates a cell's clone id to all of its contigs; clear it
    # again for contigs that lack a V call, J call or key sequence.
    v_col_common = VCALLG if VCALLG in df.collect_schema() else VCALL
    # Nested when/then/otherwise, built from the innermost branch outwards so
    # that any number of distinct key columns works.
    is_empty_expr = None
    for col in reversed(key_columns):
        branch = pl.when(pl.col("_key") == col).then(
            (pl.col(col).is_null()) | (pl.col(col) == "")
        )
        is_empty_expr = (
            branch
            if is_empty_expr is None
            else branch.otherwise(is_empty_expr)
        )
    condition = (
        (~pl.col("_is_key_empty"))
        & (pl.col(v_col_common).is_not_null())
        & (pl.col(v_col_common).ne(""))
        & (pl.col(JCALL).is_not_null())
        & (pl.col(JCALL).ne(""))
    )
    if "ambiguous" in df.collect_schema():
        condition = condition & (~pl.col("ambiguous").is_in(TRUES_STR))
    df = (
        df.with_columns(pl.col("locus").replace(locus_to_col).alias("_key"))
        .with_columns(is_empty_expr.alias("_is_key_empty"))
        .with_columns(
            pl.when(condition)
            .then(pl.col(key_added))
            .otherwise(None)
            .alias(key_added)
        )
        .drop(["_original_order", "_key", "_is_key_empty"])
    )

    vdj._data = df
    vdj.update_metadata(clone_key=str(key_added))
    # offload memory
    vdj._cache_data()

    if store_distances:
        logg.info("Storing distance matrix...")
        csr_dist = _assemble_distance_matrix(
            distance_results, n_cells=n_cells, verbose=verbose
        )
        vdj.distances = csr_dist
        density = csr_dist.nnz / (n_cells**2) if n_cells else 0.0
        logg.info(
            f"Stored distances as CSR sparse matrix: {csr_dist.shape}, density={density:.2%}"
        )

    # 'data' and 'metadata' are always updated; 'distances' only on request.
    updated = (
        "Updated Dandelion object: \n"
        " 'data', contig AIRR table\n"
        " 'metadata', cell observations table\n"
    )
    if store_distances:
        updated += " 'distances', sparse distance matrix\n"
    logg.info(" finished", time=start, deep=updated)
    return vdj


def _assign_chain_clones(
    df_chain: pl.DataFrame,
    *,
    chain: str,
    locus_key: str,
    locus_identity: float,
    locus_celltype: str,
    key_added: str,
    metric,
    hard_cutoff: int | float | None,
    same_vj: bool,
    same_length: bool,
    recalculate_length: bool,
    by_alleles: bool,
    verbose: bool,
    cell_id_to_meta_idx: dict | None,
    distance_results: list,
) -> pl.DataFrame:
    """
    Cluster the contigs of one chain type and label every row with its clone.

    Returns a frame with `_original_order` and `_{key_added}_{chain}`. When
    `cell_id_to_meta_idx` is not None, the pairwise distances of each group
    are appended to `distance_results` as (rows, cols, values) triplets
    indexed by metadata position.
    """
    out_col = f"_{key_added}_{chain}"
    length_col = f"_{locus_key}_length"
    df_grp = _group_sequences(
        df=df_chain,
        key=locus_key,
        same_vj=same_vj,
        same_length=same_length,
        recalculate_length=recalculate_length,
        by_alleles=by_alleles,
    )
    # With mixed lengths in a group the longest junction sets the threshold.
    length_agg = (
        pl.col(length_col).first()
        if same_length
        else pl.col(length_col).max()
    )
    grouped = (
        df_grp.group_by("_membership")
        .agg(
            [
                pl.col(locus_key),
                pl.col("_original_order"),
                pl.col("cell_id"),
                length_agg,
            ]
        )
        .sort("_membership")
    )
    exact_match = isinstance(metric, IdentityMetric)
    row_to_clone: dict = {}
    for row in tqdm(
        grouped.iter_rows(named=True),
        desc=f"Finding clones based on {locus_celltype} cell {chain} chains using {locus_key}",
        bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}",
        total=grouped.height,
        disable=not verbose,
    ):
        membership = row["_membership"]
        if exact_match:
            threshold = 0
        elif hard_cutoff is None:
            # NOTE(review): `1 - identity` is evaluated in floating point, so
            # e.g. identity=0.9 with length 10 floors to 0 rather than 1.
            # Left unchanged to keep clone assignments stable - confirm.
            threshold = math.floor(
                int(row[length_col]) * (1 - locus_identity)
            )
        else:
            threshold = None

        # Only non-empty sequences take part in distance computation.
        entries = [
            (seq, order, cell_id)
            for seq, order, cell_id in zip(
                row[locus_key], row["_original_order"], row["cell_id"]
            )
            if seq
        ]
        if not entries:
            continue

        # Distances are computed once per distinct sequence; sorting makes
        # the matrix layout reproducible between runs.
        unique_seqs = sorted({seq for seq, _, _ in entries})
        seq_to_unique_idx = {seq: i for i, seq in enumerate(unique_seqs)}
        d_mat_unique = metric.compute_vectorized(unique_seqs)

        if cell_id_to_meta_idx is not None:
            # Keep only contigs whose cell is present in the metadata.
            kept = [
                (seq_to_unique_idx[seq], cell_id_to_meta_idx[cell_id])
                for seq, _, cell_id in entries
                if cell_id in cell_id_to_meta_idx
            ]
            if kept and d_mat_unique.size > 0:
                unique_indices = np.array([u for u, _ in kept])
                meta_indices = np.array([m for _, m in kept])
                # Expand unique-sequence distances to contig x contig.
                d_mat = d_mat_unique[np.ix_(unique_indices, unique_indices)]
                n_local = len(meta_indices)
                # Row-major (i, j) index pairs matching `d_mat.ravel()`.
                distance_results.append(
                    (
                        np.repeat(meta_indices, n_local),
                        np.tile(meta_indices, n_local),
                        d_mat.ravel(),
                    )
                )

        if len(unique_seqs) > 1:
            seq_tmp_dict = _clustering_scipy(
                d_mat_unique,
                threshold=threshold,
                sequences=unique_seqs,
                hard_threshold=hard_cutoff,
            )
        else:
            seq_tmp_dict = {unique_seqs[0]: (unique_seqs[0],)}

        # Largest clone first; ties broken by content so that sub-group
        # numbers do not depend on set iteration order.
        clone_groups = sorted(
            set(seq_tmp_dict.values()), key=lambda grp: (-len(grp), grp)
        )
        seq_to_label = {
            seq: f"{membership}_{sub_group}"
            for sub_group, clone_group in enumerate(clone_groups, 1)
            for seq in clone_group
        }
        for seq, order, _ in entries:
            row_to_clone[order] = seq_to_label[seq]

    if row_to_clone:
        clone_expr = pl.col("_original_order").replace_strict(
            row_to_clone, default=None, return_dtype=pl.String
        )
    else:
        # Nothing was clustered (e.g. every row was filtered out).
        clone_expr = pl.lit(None, dtype=pl.String)
    return df_grp.with_columns(clone_expr.alias(out_col)).select(
        ["_original_order", out_col]
    )


def _assemble_distance_matrix(
    distance_results: list,
    n_cells: int,
    verbose: bool = True,
    batch_size: int = 100,
) -> csr_matrix:
    """
    Sum the collected COO triplets into one (n_cells x n_cells) CSR matrix.

    Triplets are combined `batch_size` blocks at a time to avoid a single
    large concatenation. `distance_results` is emptied afterwards. An empty
    input yields an all-zero matrix instead of `None`.
    """
    csr_dist = None
    for batch_start in tqdm(
        range(0, len(distance_results), batch_size),
        desc="Building distance matrix (batched)",
        disable=not verbose,
    ):
        batch = distance_results[batch_start : batch_start + batch_size]
        batch_csr = coo_matrix(
            (
                np.concatenate([d for _, _, d in batch]),
                (
                    np.concatenate([r for r, _, _ in batch]),
                    np.concatenate([c for _, c, _ in batch]),
                ),
            ),
            shape=(n_cells, n_cells),
        ).tocsr()
        csr_dist = batch_csr if csr_dist is None else csr_dist + batch_csr
    distance_results.clear()
    if csr_dist is None:
        csr_dist = csr_matrix((n_cells, n_cells))
    return csr_dist
def _check_chains(
    df: pl.DataFrame | pl.LazyFrame,
    vdj_loci: list[str],
    vj_loci: list[str],
) -> tuple[bool, bool]:
    """
    Report whether any VDJ and any VJ contigs are present in a table.

    Parameters
    ----------
    df : pl.DataFrame | pl.LazyFrame
        Input AIRR dataframe; a lazy frame is collected first.
    vdj_loci : list[str]
        `locus` values counted as VDJ chains (e.g., ['IGH'], ['TRB']).
    vj_loci : list[str]
        `locus` values counted as VJ chains (e.g., ['IGK', 'IGL'], ['TRA']).

    Returns
    -------
    tuple[bool, bool]
        (has_vdj, has_vj): True when at least one row has a matching locus.
    """
    frame = df.collect(engine="streaming") if isinstance(df, pl.LazyFrame) else df
    # One row-count test per chain type, VDJ first.
    found = [
        frame.filter(pl.col("locus").is_in(loci)).shape[0] > 0
        for loci in (vdj_loci, vj_loci)
    ]
    return found[0], found[1]
def _group_sequences(
    df: pl.DataFrame | pl.LazyFrame,
    key: str,
    same_vj: bool = True,
    same_length: bool = True,
    recalculate_length: bool = True,
    by_alleles: bool = False,
):
    """
    Assign each contig a `_membership` label for pre-clustering grouping.

    Rows with a null or empty `key` (and, with `same_vj`, a null or empty
    V/J call) are removed. The remaining rows are labelled by dense rank of
    the (V, J) pair and/or the junction length, depending on the flags.

    Parameters
    ----------
    df : pl.DataFrame | pl.LazyFrame
        Input AIRR dataframe.
    key : str
        Column name for junction sequences (e.g., 'junction', 'junction_aa').
    same_vj : bool, optional
        Whether the V and J calls take part in the grouping.
    same_length : bool, optional
        Whether the junction length takes part in the grouping.
    recalculate_length : bool, optional
        Whether to derive the length from `key` instead of reading
        `{key}_length` (which is also recomputed when that column is absent).
    by_alleles : bool, optional
        If False, V/J calls are passed through the `STRIPALLELENUM`
        substitution before grouping; if True they are used unchanged.

    Returns
    -------
    pl.DataFrame
        Collected frame with `_{key}_length` and `_membership` columns added.
    """
    lf = df.lazy() if isinstance(df, pl.DataFrame) else df
    v_col = VCALLG if VCALLG in lf.collect_schema() else VCALL
    tmp_length = f"_{key}_length"

    # Temporary V/J columns, with or without the allele suffix removed.
    if same_vj:
        v_expr = pl.col(v_col)
        j_expr = pl.col(JCALL)
        if not by_alleles:
            v_expr = v_expr.str.replace_all(STRIPALLELENUM, "")
            j_expr = j_expr.str.replace_all(STRIPALLELENUM, "")
        lf = lf.with_columns(
            [v_expr.alias("_v_gene"), j_expr.alias("_j_gene")]
        )

    # `_{key}_length` must always exist because callers aggregate on it.
    parsed_length = key + "_length"
    if recalculate_length or (parsed_length not in lf.collect_schema()):
        length_expr = (
            pl.when(pl.col(key).is_null())
            .then(pl.lit(0))
            .otherwise(pl.col(key).cast(pl.String).str.len_bytes())
        )
    else:
        length_expr = pl.col(parsed_length)
    lf = lf.with_columns(length_expr.alias(tmp_length))

    # Null/empty junctions cannot be clustered by distance, so drop them.
    keep = [pl.col(key).is_not_null(), pl.col(key).ne("")]
    if same_vj:
        keep += [
            pl.col("_v_gene").is_not_null(),
            pl.col("_j_gene").is_not_null(),
            pl.col("_v_gene").ne(""),
            pl.col("_j_gene").ne(""),
        ]
    lf = lf.filter(pl.all_horizontal(keep))

    if same_vj and same_length:
        # "<VJ rank>_<length rank within that VJ group>"
        lf = lf.with_columns(
            pl.struct(["_v_gene", "_j_gene"])
            .rank(method="dense")
            .alias("_vj_group")
        )
        lf = lf.with_columns(
            pl.col(tmp_length)
            .rank(method="dense")
            .over("_vj_group")
            .alias("_length_group")
        )
        label = pl.concat_str(
            [
                pl.col("_vj_group").cast(pl.String),
                pl.lit("_"),
                pl.col("_length_group").cast(pl.String),
            ]
        )
        lf = lf.with_columns(label.alias("_membership"))
        lf = lf.drop(["_v_gene", "_j_gene", "_vj_group", "_length_group"])
    elif same_vj:
        lf = lf.with_columns(
            pl.struct(["_v_gene", "_j_gene"])
            .rank(method="dense")
            .cast(pl.String)
            .alias("_membership")
        )
        lf = lf.drop(["_v_gene", "_j_gene"])
    elif same_length:
        lf = lf.with_columns(
            pl.col(tmp_length)
            .rank(method="dense")
            .cast(pl.String)
            .alias("_membership")
        )
    else:
        # No grouping requested: everything shares one group.
        lf = lf.with_columns(pl.lit("1").alias("_membership"))

    return lf.collect(engine="streaming")
def _clustering_scipy(
d_mat: np.ndarray,
threshold: float,
sequences: list[str],
hard_threshold: int | float | None = None,
) -> dict:
"""
Cluster sequences using scipy connected components (fastest).
Parameters
----------
d_mat : np.ndarray
Distance matrix (n x n).
threshold : float
Distance threshold for clustering.
sequences : list[str]
List of sequences.
hard_threshold : int | float | None, optional
Absolute distance cutoff. If supplied, `threshold` is ignored.
Returns
-------
dict
Dictionary mapping sequences to cluster groups.
"""
_threshold = threshold if hard_threshold is None else hard_threshold
# Create adjacency matrix
adjacency = (d_mat <= _threshold).astype(int)
# Find connected components
_, labels = connected_components(
csgraph=csr_matrix(adjacency), directed=False
)
# Group sequences by cluster labels
clusters = defaultdict(list)
for idx, label in enumerate(labels):
clusters[label].append(idx)
# Build output dict
out_dict = {}
for cluster_indices in clusters.values():
cluster_seqs = tuple(
sorted([sequences[idx] for idx in cluster_indices])
)
for idx in cluster_indices:
out_dict[sequences[idx]] = cluster_seqs
return out_dict
def _combine_single_locus(
vdj_list: list[str] | None, vj_list: list[str] | None, celltype: str
) -> list[str]:
"""Combine VDJ/VJ for a single celltype/locus.
Args:
vdj_list: List of VDJ clone IDs
vj_list: List of VJ clone IDs
celltype: Cell type identifier (e.g., 'B', 'abT', 'gdT')
Returns:
List of combined clone IDs for this locus
"""
if not vdj_list and not vj_list:
return []
vdj_vals = vdj_list if vdj_list else [None]
vj_vals = vj_list if vj_list else [None]
combos: list[str] = []
for vdj in vdj_vals:
for vj in vj_vals:
parts = [celltype]
if vdj is not None:
parts.append(f"VDJ_{vdj}")
if vj is not None:
parts.append(f"VJ_{vj}")
combos.append("_".join(parts))
return combos
def _flatten_and_join_loci(row: dict, locus_columns: list[str]) -> str | None:
"""Flatten clone IDs from all loci and join with '|'.
Args:
row: Dictionary containing clone ID lists for each locus
locus_columns: List of column names containing locus-specific clone IDs
Returns:
Pipe-separated string of all clone IDs, or None if no clones found
"""
all_clones: list[str] = []
for col_name in locus_columns:
locus_clones = row[col_name]
if locus_clones is not None:
all_clones.extend(locus_clones)
return "|".join(all_clones) if all_clones else None
[docs]
def transfer(
    adata: AnnData | MuData,
    vdj: DandelionPolars,
    main_view: Literal["all", "expanded", "full"] = "all",
    gex_key: str | None = None,
    vdj_key: str | None = None,
    clone_key: str | None = None,
    collapse_nodes: bool = False,
    overwrite: bool | list[str] | str | None = None,
    obs: bool = True,
    obsm: bool = True,
    uns: bool = True,
    obsp: bool = True,
) -> None:
    """
    Transfer data in Dandelion slots to AnnData, updating `.obs`, `.uns`, `.obsm`, and `.obsp`.

    Unless `main_view='full'`, both graphs are converted and stashed:
        - graph[0] -> adata.uns['dandelion']['{vdj_key}_connectivities_all'] / ['{vdj_key}_distances_all']
        - graph[1] -> adata.uns['dandelion']['{vdj_key}_connectivities_expanded'] / ['{vdj_key}_distances_expanded']
    and both layouts:
        - layout[0] -> adata.uns['dandelion']['X_vdj_all']
        - layout[1] -> adata.uns['dandelion']['X_vdj_expanded']

    The `main_view` flag controls which graph becomes the *main* adjacency written to
    adata.obsp['connectivities'] / ['distances'] (but both graphs are stored).

    If a polars-backed `vdj` is passed, it is switched to the pandas backend for the
    duration of the call and switched back afterwards, including on early return or error.

    Parameters
    ----------
    adata : AnnData | MuData
        AnnData object or `MuData` object. For a `MuData`, the `'airr'` modality is updated.
    vdj : DandelionPolars
        Dandelion object.
    main_view : Literal["all", "expanded", "full"], optional
        Which graph becomes the *main* adjacency written to adata.obsp['connectivities'] / ['distances'].
        If the requested graph is unavailable, the first available graph is used instead.
        If 'full', the full distance matrix from Dandelion is transferred and stored in adata.obsp and no
        graph is transferred to obsm.
    gex_key : str | None, optional
        prefix for stashed RNA connectivities and distances. Defaults to 'gex'.
    vdj_key : str | None, optional
        prefix for stashed VDJ connectivities and distances. Defaults to 'vdj'.
    clone_key : str | None, optional
        column name of clone/clonotype ids. Only used for integration with scirpy. Defaults to 'clone_id'.
    collapse_nodes : bool, optional
        Whether or not to transfer a cell x cell or clone x clone connectivity matrix into `.uns`. Only used for
        integration with scirpy.
    overwrite : bool | list[str] | str | None, optional
        Whether or not to overwrite existing anndata columns. Specifying a string indicating column name or
        list of column names will overwrite that specific column(s). `False` behaves like `None`.
    obs : bool, optional
        Whether to transfer `.metadata` columns to `adata.obs`. Defaults to True.
    obsm : bool, optional
        Whether to transfer layout embeddings to `adata.obsm`. Defaults to True.
    uns : bool, optional
        Whether to transfer graph and layout data to `adata.uns`. Defaults to True.
    obsp : bool, optional
        Whether to transfer distance and connectivity matrices to `adata.obsp`. Defaults to True.

    Raises
    ------
    ValueError
        if a `MuData` without an 'airr' modality is provided, or if no graph/distance
        information is available for the requested `main_view`.
    """
    start = logg.info("Transferring network")
    # A MuData is detected by duck-typing on `.mod` so that mudata does not
    # have to be a dependency; its 'airr' modality receives the data.
    if hasattr(adata, "mod"):
        if "airr" in adata.mod:
            recipient = adata.mod["airr"]
        else:
            raise ValueError(
                "Provided AnnData is a MuData object without 'airr' modality."
            )
    else:
        recipient = adata

    # `vdj` is rebound to a subset further down, so keep a handle on the
    # caller's object: that is the one whose backend has to be restored.
    source_vdj = vdj
    original_backend = vdj._backend
    original_lazy = vdj._lazy
    if original_backend == "polars":
        vdj.to_pandas()

    try:
        # --- 1) metadata -> adata.obs (preserve original overwrite semantics) ---
        if obs:
            for x in vdj._metadata.columns:
                # new columns are always written; existing ones only when
                # overwrite is exactly True
                if x not in recipient.obs.columns or overwrite is True:
                    recipient.obs[x] = pd.Series(vdj._metadata[x]).reindex(
                        recipient.obs_names
                    )
                    if recipient.obs[x].dtype == "bool":
                        recipient.obs[x] = recipient.obs[x].astype(str)
            # explicit overwrite list/string handling; booleans are excluded
            # so that overwrite=False is not mistaken for a column name
            if overwrite is not None and not isinstance(overwrite, bool):
                explicit_columns = (
                    overwrite if isinstance(overwrite, list) else [overwrite]
                )
                for ow in explicit_columns:
                    recipient.obs[ow] = pd.Series(vdj._metadata[ow]).reindex(
                        recipient.obs_names
                    )
                    if recipient.obs[ow].dtype == "bool":
                        recipient.obs[ow] = recipient.obs[ow].astype(str)

        # restrict the Dandelion object to cells that exist in the recipient
        common_cells = recipient.obs_names.intersection(vdj._metadata.index)
        vdj = vdj[vdj._metadata.index.isin(common_cells)]

        # If there's no graph, we're done with metadata only
        if vdj.graph is None:
            logg.info(
                " finished",
                time=start,
                deep=("updated `.obs` with `.metadata`\n"),
            )
            return

        # --- 2) prepare neighbor keys and stash RNA neighbors/connectivities if present ---
        neighbors_key = "neighbors"
        skip_stash = neighbors_key not in recipient.uns
        if obsp:
            gex_key = "gex" if gex_key is None else gex_key
            g_connectivities_key = f"{gex_key}_connectivities"
            g_distances_key = f"{gex_key}_distances"
            vdj_key = "vdj" if vdj_key is None else vdj_key
            v_connectivities_key = f"{vdj_key}_connectivities"
            v_distances_key = f"{vdj_key}_distances"
            # Stash RNA connectivities/distances before we overwrite connectivities/distances
            if not skip_stash:
                # stash in uns["dandelion"] instead of obsp
                recipient.uns.setdefault("dandelion", {})[
                    g_connectivities_key
                ] = recipient.obsp["connectivities"]
                recipient.uns["dandelion"][g_distances_key] = recipient.obsp[
                    "distances"
                ]
                g_neighbors_key = f"{gex_key}_{neighbors_key}"
                recipient.uns[g_neighbors_key] = recipient.uns[neighbors_key]

        # --- 3) Convert both graphs ---
        graph_connectivities, graph_distances = {}, {}
        if main_view == "full":
            main_idx = 2
            # explicitly only transfer full distance matrix if requested
            if getattr(vdj, "distances", None) is not None:
                graph_connectivities[2], graph_distances[2] = (
                    _graph_to_matrices(None, recipient, vdj.distances)
                )
        else:
            # handle graph[0] and graph[1]
            for idx in (0, 1):
                G = None
                try:
                    G = vdj.graph[idx]
                except Exception:
                    # best effort: a view that cannot be read is skipped
                    pass
                if G is not None:
                    graph_connectivities[idx], graph_distances[idx] = (
                        _graph_to_matrices(G, recipient, None)
                    )
            # Determine main graph index, falling back to whichever view
            # exists; with no view at all the index is left untouched so
            # the ValueError below is raised instead of StopIteration.
            main_idx = 1 if main_view == "expanded" else 0
            if main_idx not in graph_connectivities:
                main_idx = next(iter(graph_connectivities), main_idx)

        if main_idx not in graph_connectivities:
            if main_view == "full":
                raise ValueError(
                    "main_view='full' requested but `vdj.distances` is not available. "
                    "Run clone finding with `store_distances=True` first."
                )
            raise ValueError(
                "No VDJ graph/distance information available to transfer."
            )

        if obsp:
            # --- 4) Update recipient.obsp (active view only) ---
            recipient.obsp["connectivities"] = graph_connectivities[main_idx]
            recipient.obsp["distances"] = graph_distances[main_idx]
            # store non-active views in uns["dandelion"] to keep .obsp clean
            _ddl = recipient.uns.setdefault("dandelion", {})
            if main_idx != 2:
                # store the all (graph[0]) and expanded graph (graph[1]) if available
                if 0 in graph_connectivities:
                    _ddl[f"{v_connectivities_key}_all"] = graph_connectivities[0]
                    _ddl[f"{v_distances_key}_all"] = graph_distances[0]
                if 1 in graph_connectivities:
                    _ddl[f"{v_connectivities_key}_expanded"] = (
                        graph_connectivities[1]
                    )
                    _ddl[f"{v_distances_key}_expanded"] = graph_distances[1]
            else:
                if 2 in graph_connectivities:
                    _ddl[f"{v_connectivities_key}_full"] = graph_connectivities[2]
                    _ddl[f"{v_distances_key}_full"] = graph_distances[2]
            recipient.uns[neighbors_key] = {
                "connectivities_key": "connectivities",
                "distances_key": "distances",
                "params": {
                    "n_neighbors": 1,
                    "method": "custom",
                    "metric": "precomputed",
                },
            }

        if uns:
            # --- 5) Clone-level mapping (scirpy compatible) ---
            clone_key = clone_key if clone_key is not None else "clone_id"
            if not collapse_nodes:
                # NOTE(review): this binarises the matrices in place, so any
                # reference already stored above sees the change as well -
                # confirm this is intended.
                for idx in graph_connectivities:
                    graph_connectivities[idx].data[:] = 1
                cell_indices = {
                    str(i): np.array([k])
                    for i, k in enumerate(recipient.obs_names)
                }
                bin_conn = graph_connectivities[main_idx]
            else:
                invalid = [
                    "",
                    "unassigned",
                    "NaN",
                    "NA",
                    "nan",
                    "None",
                    "none",
                    None,
                ]
                clone_members = Tree()
                for x, y in recipient.obs[clone_key].items():
                    # real missing values (float NaN / pd.NA) are not clones
                    if pd.isna(y) or y in invalid:
                        continue
                    clone_members[y][x].value = 1
                cell_indices = {
                    str(i): np.array(list(members))
                    for i, members in enumerate(clone_members.values())
                }
                bin_conn = speye(len(cell_indices), format="csr")
            recipient.uns[clone_key] = {
                # this is a symmetrical, pairwise, sparse distance matrix of clonotypes
                # the matrix is offset by 1, i.e. 0 = no connection, 1 = distance 0
                "distances": bin_conn,
                # '0' refers to the row/col index in the `distances` matrix
                # (numeric index, but needs to be str because of h5py)
                # np.array(["cell1", "cell2"]) points to the rows in `recipient.obs`
                "cell_indices": cell_indices,
            }

        if obsm:
            # --- 6) Layouts ---
            if vdj.layout is not None:
                stored_embeddings = {}
                for idx, obsm_name in (
                    (0, "X_vdj_all"),
                    (1, "X_vdj_expanded"),
                ):
                    try:
                        layout = vdj.layout[idx]
                    except Exception:
                        continue
                    if layout is None:
                        continue
                    coord = pd.DataFrame.from_dict(layout, orient="index")
                    coord = coord.reindex(index=recipient.obs_names).fillna(
                        np.nan
                    )
                    if coord.shape[1] == 0:
                        # Empty layout - skip this embedding
                        continue
                    elif coord.shape[1] >= 2:
                        embedding = coord.iloc[:, :2].to_numpy(dtype=np.float32)
                    else:
                        # one-dimensional layout: pad with a zero second axis
                        col0 = (
                            coord.iloc[:, 0]
                            .to_numpy(dtype=np.float32)
                            .reshape(-1, 1)
                        )
                        col1 = np.zeros_like(col0)
                        embedding = np.hstack([col0, col1])
                    _ddl = recipient.uns.setdefault("dandelion", {})
                    _ddl[obsm_name] = embedding
                    stored_embeddings[idx] = obsm_name
                # Set the "active" embedding safely
                active_obsm = stored_embeddings.get(main_idx)
                if active_obsm is not None:
                    _ddl = recipient.uns.setdefault("dandelion", {})
                    recipient.obsm["X_vdj"] = _ddl[active_obsm]

        # break up the message depending on which parts were executed
        message_parts = []
        if obs:
            message_parts += ["updated `.obs` with `.metadata`\n"]
        if obsm:
            message_parts += [
                "wrote active layout to `.obsm['X_vdj']`; stashed all views in `.uns['dandelion']` ('X_vdj_all', 'X_vdj_expanded')\n"
            ]
        if obsp:
            message_parts += [
                f"wrote `.obsp['connectivities']` & `['distances']` from graph[{main_idx}]\n",
                (
                    f"stashed GEX matrices in `.uns['dandelion']` ('{g_connectivities_key}', '{g_distances_key}')\n"
                    if not skip_stash
                    else ""
                ),
                (
                    f"stashed VDJ matrices in `.uns['dandelion']` under '{v_connectivities_key}_all' / '_expanded'\n"
                    if main_idx != 2
                    else f"stashed VDJ matrices in `.uns['dandelion']` under '{v_connectivities_key}_full'\n"
                ),
            ]
        if uns:
            message_parts += [
                f"added `.uns['{clone_key}']` clone-level mapping"
            ]

        # --- 7) Done ---
        logg.info(
            " finished",
            time=start,
            deep="".join(message_parts),
        )
    finally:
        # convert the caller's object back, whatever path left the function
        if original_backend == "polars":
            source_vdj.to_polars(lazy=original_lazy)


tf = transfer  # alias for transfer
def _graph_to_matrices(
G: nx.Graph | None,
adata: AnnData,
distances: csr_matrix | None = None,
) -> tuple[csr_matrix, csr_matrix]:
"""
Convert a graph or provided distances into properly aligned sparse
connectivities and distances matrices.
Rules:
- If G is provided, convert edges → sparse distance matrix.
- If a CSR distance matrix is provided, must have `._index_names`.
- Reindex to `adata.obs_names` without dense conversion.
- Compute connectivities as exp(-d) on non-zero entries.
- Add tiny self-edge if matrix is entirely empty.
Optimized to avoid memory explosion by minimizing intermediate copies.
"""
target_names = list(adata.obs_names)
n = len(target_names)
name_to_new = {name: i for i, name in enumerate(target_names)}
# CASE A: Build distances from a NetworkX graph
if distances is None and G is not None:
edges = list(G.edges(data=True))
if not edges:
distances = csr_matrix((n, n), dtype=np.float32)
else:
# Single-pass filtering and mapping - no intermediate arrays
valid_edges = []
for u_name, v_name, edge_data in edges:
u_idx = name_to_new.get(u_name)
v_idx = name_to_new.get(v_name)
if u_idx is not None and v_idx is not None:
weight = edge_data.get("weight", 1.0)
valid_edges.append((u_idx, v_idx, weight))
if not valid_edges:
distances = csr_matrix((n, n), dtype=np.float32)
else:
# Unpack and create symmetric matrix directly
u_idx, v_idx, weights = zip(*valid_edges)
u_idx = np.array(u_idx, dtype=np.int32)
v_idx = np.array(v_idx, dtype=np.int32)
weights = np.array(weights, dtype=np.float32)
# Make symmetric - concatenate once
rows = np.concatenate([u_idx, v_idx])
cols = np.concatenate([v_idx, u_idx])
vals = np.concatenate([weights, weights])
vals += 1.0
distances = csr_matrix(
(vals, (rows, cols)), shape=(n, n), dtype=np.float32
)
# Clean up immediately
del rows, cols, vals, u_idx, v_idx, weights, valid_edges
# CASE B: distances provided as a csr_matrix with _index_names
elif isinstance(distances, csr_matrix):
old_names = np.array(distances._index_names)
coo = distances.tocoo()
# Vectorized mapping using searchsorted (much faster than list comprehension)
# Build arrays for lookup
target_arr = np.array(target_names)
# Get unique old names and their mapping
unique_old = np.unique(old_names)
# Find which old names exist in target
# Use np.isin for vectorized membership test
old_in_target = np.isin(unique_old, target_arr)
# Create mapping for valid names only
valid_old_names = unique_old[old_in_target]
if len(valid_old_names) > 0:
# Create mapping using pandas for efficient lookup
name_series = pd.Series(
range(n), index=target_names, dtype=np.int32
)
# Map row and column names
old_row_names = old_names[coo.row]
old_col_names = old_names[coo.col]
row_idx = name_series.reindex(old_row_names, fill_value=-1).values
col_idx = name_series.reindex(old_col_names, fill_value=-1).values
# Keep only valid mappings
mask = (row_idx >= 0) & (col_idx >= 0)
rows = row_idx[mask]
cols = col_idx[mask]
vals = coo.data[mask]
vals += 1.0
distances = csr_matrix(
(vals, (rows, cols)), shape=(n, n), dtype=np.float32
)
# Clean up
del row_idx, col_idx, rows, cols, vals, mask
else:
distances = csr_matrix((n, n), dtype=np.float32)
# Build connectivities = exp(-d) for non-zero entries
connectivities = distances.copy()
if connectivities.nnz > 0:
connectivities.data = np.exp(-connectivities.data)
connectivities.data = np.clip(connectivities.data, 1e-45, np.inf)
distances.data -= 1.0
# Ensure matrix is not completely empty - avoid expensive conversion
if connectivities.nnz == 0:
# Create minimal non-empty matrix directly without conversion
connectivities = csr_matrix(
([1e-10], ([0], [0])), shape=(n, n), dtype=np.float32
)
distances = csr_matrix(
([0.0], ([0], [0])), shape=(n, n), dtype=np.float32
)
return connectivities, distances
[docs]
def clone_view(
    adata: AnnData,
    mode: Literal["all", "expanded", "full", "gex"] | None = "expanded",
    connectivities_key: str | None = None,
    distances_key: str | None = None,
    embedding_key: str | None = None,
):
    """
    Swap the 'active' connectivities, distances, and optionally embedding in AnnData.

    All requested keys are validated before anything is written, so a failed
    call leaves `adata` unchanged.

    Parameters
    ----------
    adata : AnnData
        The AnnData object.
    mode : Literal["all", "expanded", "full", "gex"] | None, optional
        If specified, set the active connectivities/distances/embedding to one of the preset modes.
        The presets read `{mode}_connectivities` / `{mode}_distances` for 'gex' and
        `vdj_connectivities_{mode}` / `vdj_distances_{mode}` otherwise from `.uns['dandelion']`,
        and also reset `.uns['neighbors']`.
    connectivities_key : str | None, optional
        The key in `.uns['dandelion']` to set as active `.obsp["connectivities"]` if `mode` is None.
    distances_key : str | None, optional
        The key in `.uns['dandelion']` to set as active `.obsp["distances"]` if `mode` is None.
    embedding_key : str | None, optional
        If specified, set `.obsm["X_vdj"]` to the embedding stored under this key if `mode` is None.
        The key is looked up in `.obsm` first and then in `.uns['dandelion']`.

    Raises
    ------
    KeyError
        if the requested connectivities, distances, embedding, or neighbors key is not found.
    """
    stash = adata.uns.get("dandelion", {})

    def _from_stash(key):
        # fetch a stashed entry or fail with a message naming the store
        if key not in stash:
            raise KeyError(f"{key} not found in adata.uns['dandelion']")
        return stash[key]

    if mode is None:
        # use the other keys directly
        connectivities = _from_stash(connectivities_key)
        distances = _from_stash(distances_key)
        set_embedding = embedding_key is not None
        embedding = None
        if set_embedding:
            if embedding_key in adata.obsm:
                embedding = adata.obsm[embedding_key]
            elif embedding_key in stash:
                embedding = stash[embedding_key]
            else:
                raise KeyError(
                    f"{embedding_key} not found in adata.obsm or adata.uns['dandelion']"
                )
        adata.obsp["connectivities"] = connectivities
        adata.obsp["distances"] = distances
        if set_embedding:
            adata.obsm["X_vdj"] = embedding
        return

    if mode == "gex":
        conn_key = f"{mode}_connectivities"
        dist_key = f"{mode}_distances"
        neighbors_key = f"{mode}_neighbors"
        emb_key = None
    else:
        conn_key = f"vdj_connectivities_{mode}"
        dist_key = f"vdj_distances_{mode}"
        neighbors_key = None
        emb_key = f"X_vdj_{mode}" if mode != "full" else None

    connectivities = _from_stash(conn_key)
    distances = _from_stash(dist_key)
    embedding = _from_stash(emb_key) if emb_key is not None else None
    if neighbors_key is not None:
        if neighbors_key not in adata.uns:
            raise KeyError(f"{neighbors_key} not found in adata.uns")
        neighbors = adata.uns[neighbors_key]
    else:
        neighbors = {
            "connectivities_key": "connectivities",
            "distances_key": "distances",
            "params": {
                "n_neighbors": 1,
                "method": "custom",
                "metric": "precomputed",
            },
        }

    adata.obsp["connectivities"] = connectivities
    adata.obsp["distances"] = distances
    if emb_key is not None:
        adata.obsm["X_vdj"] = embedding
    adata.uns["neighbors"] = neighbors
def _categorize_clone_size(size: int | float, max_size: int) -> str | float:
"""
Categorize clone size into bins or clip at maximum.
Parameters
----------
size : int | float
Clone size value
max_size : int
Maximum size for categorization
Returns
-------
str | float
String category if size is valid, otherwise NaN
"""
if pd.isna(size):
return np.nan
if size < max_size:
return str(int(size))
else:
return f">= {max_size}"
[docs]
def clone_size(
    vdj: DandelionPolars | AnnData | MuData,
    group_by: str | None = None,
    max_size: int | None = None,
    clone_key: str | None = None,
    key_added: str | None = None,
) -> None:
    """
    Quantify clone sizes, globally or per group.

    For each clone, the **proportion** is defined as the number of cells
    belonging to that clone divided by the denominator:

    - **Global** (``group_by=None``): denominator is the total number of
      cells in the metadata.
    - **Per group** (``group_by`` specified): denominator is the total number
      of cells within that group. Proportions are therefore independent
      across groups.

    Each clone proportion is then mapped to a **frequency category** using
    the following bins (matching scRepertoire conventions):

    +------------------+-------------------------+
    | Category         | Proportion range        |
    +==================+=========================+
    | Rare             | 0 – 0.0001              |
    +------------------+-------------------------+
    | Small            | 0.0001 – 0.001          |
    +------------------+-------------------------+
    | Medium           | 0.001 – 0.01            |
    +------------------+-------------------------+
    | Large            | 0.01 – 0.1              |
    +------------------+-------------------------+
    | Hyperexpanded    | 0.1 – 1                 |
    +------------------+-------------------------+

    If a cell is assigned to multiple clones (e.g. multiple chains mapped to
    different clone IDs, separated by ``|``), the clone with the largest size
    is used for all annotation columns.

    The following columns are added to the metadata:

    - ``{key_added}_size`` : number of cells in the clone.
    - ``{key_added}_size_prop`` : clone proportion (see above).
    - ``{key_added}_size_category`` : frequency category label
      (Rare / Small / Medium / Large / Hyperexpanded).
    - ``{key_added}_size_max_{max_size}`` : *(only when* ``max_size`` *is
      set)* clone size as a string, with any size ≥ ``max_size`` collapsed
      to the label ``">= {max_size}"``.

    A polars-backed ``DandelionPolars`` is switched to the pandas backend while
    the columns are computed and switched back afterwards, also on error.

    Parameters
    ----------
    vdj : DandelionPolars | AnnData | MuData
        VDJ data. For a ``MuData`` the ``'airr'`` modality is used.
    group_by : str | None, optional
        Column in metadata to group by before calculating clone sizes.
        If None, calculates global clone sizes across all cells.
    max_size : int | None, optional
        When provided, adds an extra column where clone sizes are represented
        as string labels; sizes strictly below ``max_size`` are kept as their
        integer value, while sizes ≥ ``max_size`` are labelled
        ``">= {max_size}"``.
    clone_key : str | None, optional
        Column specifying clone identifiers. Defaults to ``'clone_id'``.
    key_added : str | None, optional
        Prefix for the new metadata column names
        (e.g. ``{key_added}_size``, ``{key_added}_size_prop``).
        Defaults to the value of ``clone_key``.

    Raises
    ------
    KeyError
        if ``clone_key`` is not found in metadata.
    TypeError
        if ``vdj`` is not a supported object.
    """
    # --- Select metadata
    original_backend = None
    original_lazy = False
    if hasattr(vdj, "mod"):
        metadata_ = vdj.mod["airr"].obs.copy()
    elif isinstance(vdj, AnnData):
        metadata_ = vdj.obs.copy()
    elif isinstance(vdj, DandelionPolars):
        original_backend = vdj._backend
        if original_backend == "polars":
            # originally lazy or not
            original_lazy = vdj._lazy
            vdj.to_pandas()
        metadata_ = vdj._metadata.copy()
    else:
        raise TypeError(
            "`vdj` must be a DandelionPolars, AnnData or MuData object, "
            f"got {type(vdj).__name__}."
        )

    try:
        clone_key = "clone_id" if clone_key is None else clone_key
        if clone_key not in metadata_.columns:
            raise KeyError(f"Column '{clone_key}' not found in metadata.")

        # --- Expand multi-clone entries (one row per cell/clone pair)
        tmp = (
            metadata_[clone_key]
            .astype(str)
            .str.split("|", expand=True)
            .stack()
        )
        # drop None/No_contig entries
        tmp = tmp[~tmp.isin(["No_contig", "unassigned"] + EMPTIES)]
        tmp = tmp.reset_index(drop=False)
        tmp.columns = ["cell_id", "tmp", clone_key]

        # --- Compute clone sizes (global or per group)
        if group_by is None:
            clonesize = tmp[clone_key].value_counts()
            prop = clonesize / metadata_.shape[0]
        else:
            # Attach the group of every cell. The index is renamed explicitly
            # so the merge key exists whatever the metadata index is called.
            cell_groups = (
                metadata_[group_by].rename_axis("cell_id").reset_index()
            )
            tmp = tmp.merge(cell_groups, on="cell_id", how="left")
            clonesize = tmp.groupby([group_by, clone_key]).size()
            group_sizes = metadata_[group_by].value_counts()
            # proportion of each clone relative to the size of its own group
            prop_dict = {
                (grp, clone_id): size / group_sizes[grp]
                for (grp, clone_id), size in clonesize.items()
            }
            # Create Series with MultiIndex
            prop = pd.Series(prop_dict)
            prop.index = pd.MultiIndex.from_tuples(
                prop.index, names=[group_by, clone_key]
            )

        # --- Create max_size categories if specified
        if max_size is not None:
            categorize_fn = functools.partial(
                _categorize_clone_size, max_size=max_size
            )
            clonesize_cat_map = clonesize.apply(categorize_fn).to_dict()

        # --- Define clone frequency bins
        bins = [0, 0.0001, 0.001, 0.01, 0.1, 1]
        labels = ["Rare", "Small", "Medium", "Large", "Hyperexpanded"]
        prop_bins = pd.cut(prop, bins=bins, labels=labels, include_lowest=True)

        # --- Build lookup maps (keyed by clone, or by (group, clone))
        size_map = clonesize.to_dict()
        prop_map = prop.to_dict()
        cat_map = prop_bins.to_dict()

        # --- Assign to each cell
        cell_sizes = []
        cell_props = []
        cell_cats = []
        cell_size_cats = [] if max_size is not None else None
        unassigned = ("No_contig", "unassigned", "None", "nan")
        if group_by is None:
            cell_groups_iter = [None] * len(metadata_)
        else:
            cell_groups_iter = metadata_[group_by]
        for raw_clone, grp in zip(metadata_[clone_key], cell_groups_iter):
            clone_ids = str(raw_clone)
            keys = []
            sizes = []
            if clone_ids not in unassigned:
                clones = clone_ids.split("|")
                keys = (
                    clones
                    if group_by is None
                    else [(grp, c) for c in clones]
                )
                sizes = [size_map.get(k, np.nan) for k in keys]
            if not sizes or all(pd.isna(sizes)):
                cell_sizes.append(np.nan)
                cell_props.append(np.nan)
                cell_cats.append(np.nan)
                if max_size is not None:
                    cell_size_cats.append(np.nan)
                continue
            # take the largest available clone (by numeric size)
            max_idx = np.nanargmax(sizes)
            best = keys[max_idx]
            cell_sizes.append(sizes[max_idx])
            cell_props.append(prop_map.get(best, np.nan))
            cell_cats.append(cat_map.get(best, np.nan))
            if max_size is not None:
                cell_size_cats.append(clonesize_cat_map.get(best, np.nan))

        suffixes = ["size", "size_prop", "size_category"]
        metadata_[f"{clone_key}_size"] = cell_sizes
        metadata_[f"{clone_key}_size_prop"] = cell_props
        metadata_[f"{clone_key}_size_category"] = cell_cats
        if max_size is not None:
            metadata_[f"{clone_key}_size_max_{max_size}"] = cell_size_cats
            suffixes.append(f"size_max_{max_size}")

        # --- Write results back to object
        col_key = key_added if key_added is not None else clone_key
        for suffix in suffixes:
            values = metadata_[f"{clone_key}_{suffix}"]
            if isinstance(vdj, DandelionPolars):
                vdj._metadata[f"{col_key}_{suffix}"] = values
            elif isinstance(vdj, AnnData):
                vdj.obs[f"{col_key}_{suffix}"] = values
            elif hasattr(vdj, "mod"):
                vdj.mod["airr"].obs[f"{col_key}_{suffix}"] = values
    finally:
        # switch back to the original backend, also when something failed
        if original_backend == "polars":
            vdj.to_polars(lazy=original_lazy)
[docs]
def clone_overlap(
    vdj: DandelionPolars | AnnData,
    group_by: str,
    min_clone_size: int | None = None,
    weighted_overlap: bool = False,
    clone_key: str | None = None,
) -> pd.DataFrame:
    """
    A function to tabulate clonal overlap for input as a circos-style plot.

    A polars-backed `DandelionPolars` is switched to the pandas backend only
    while its metadata is copied and is switched back right after.

    Parameters
    ----------
    vdj : DandelionPolars | AnnData
        DandelionPolars or AnnData object. An object exposing `.mod['airr']`
        (MuData-like) is also accepted.
    group_by : str
        column name in obs/metadata for collapsing to columns in the clone_id x group_by data frame.
    min_clone_size : int | None, optional
        minimum size of clone for plotting connections. Defaults to 2 if left as None.
    weighted_overlap : bool, optional
        if True, instead of collapsing to overlap to binary, overlap will be returned as the number of cells.
        In the future, there will be the option to use something like a jaccard index.
    clone_key : str | None, optional
        column name for clones. `None` defaults to 'clone_id'.

    Returns
    -------
    pd.DataFrame
        clone_id x group_by overlap :class:`pandas.core.frame.DataFrame`.
        For an `AnnData` input the table is stored in `.uns['clone_overlap']`
        instead and nothing is returned.

    Raises
    ------
    ValueError
        if min_clone_size is smaller than 1.
    TypeError
        if `vdj` is not a supported object.
    """
    start = logg.info("Calculating clone overlap")

    # validate before touching the data
    min_size = 2 if min_clone_size is None else int(min_clone_size)
    if min_size < 1:
        raise ValueError("min_size must be greater than 0.")
    clone_ = "clone_id" if clone_key is None else clone_key

    if isinstance(vdj, DandelionPolars):
        was_polars = vdj._backend == "polars"
        if was_polars:
            original_lazy = vdj._lazy
            vdj.to_pandas()
        try:
            data = vdj._metadata.copy()
        finally:
            # only the copy is needed below, so restore the backend at once
            if was_polars:
                vdj.to_polars(lazy=original_lazy)
    elif isinstance(vdj, AnnData):
        data = vdj.obs.copy()
    elif hasattr(vdj, "mod"):
        data = vdj.mod["airr"].obs.copy()
    else:
        raise TypeError(
            "`vdj` must be a DandelionPolars, AnnData or MuData object, "
            f"got {type(vdj).__name__}."
        )

    missing = [np.nan, "nan", "NaN", "No_contig", "unassigned", "None", None]
    # remember every group before cells without a clone are removed, so that
    # groups left without any clone still get a (zero) column
    allgroups = list(data[group_by].unique())
    # get rid of problematic rows that appear because of category conversion?
    data = data[~data[clone_].isin(missing)]

    # prepare a summary table: one row per (cell, clone) pair
    datc_ = data[clone_].str.split("|", expand=True).stack()
    datc_ = pd.DataFrame(datc_)
    datc_.reset_index(drop=False, inplace=True)
    datc_.columns = ["cell_id", "tmp", clone_]
    datc_.drop("tmp", inplace=True, axis=1)
    datc_ = datc_[~datc_[clone_].isin([""] + missing)].copy()
    dictg_ = dict(data[group_by])
    datc_[group_by] = [dictg_[cell] for cell in datc_["cell_id"]]

    overlap = pd.crosstab(datc_[clone_], datc_[group_by])
    for x in allgroups:
        if x not in overlap:
            overlap[x] = 0

    if not weighted_overlap:
        if min_size > 2:
            overlap[overlap < min_size] = 0
            overlap[overlap >= min_size] = 1
        else:
            # a minimum clone size of 1 or 2 reduces to presence/absence
            overlap[overlap >= 1] = 1
    overlap.index.name = None
    overlap.columns.name = None

    if isinstance(vdj, AnnData):
        vdj.uns["clone_overlap"] = overlap.copy()
        logg.info(
            " finished",
            time=start,
            deep=("Updated AnnData: \n" " 'uns', clone overlap table"),
        )
    else:
        return overlap
@contextmanager
def _vj_usage_context(
    adata: AnnData, vdj: DandelionPolars, mode: Literal["B", "abT", "gdT"]
):
    """
    Context manager exposing per-cell V/J gene columns on `adata.obs`.

    On entry, the `v_call`/`j_call` columns of `vdj` are split for the
    requested `mode` (once keeping the first gene, once keeping all genes
    joined), merged on `cell_id`, and written into `adata.obs`. On exit,
    `adata.obs` is replaced by the copy taken before the columns were added,
    so any other change made to `adata.obs` inside the block is discarded too.
    """
    genes = ("v_call", "j_call")
    # first ("main") gene per cell
    v_first, j_first = [
        vdj._split_first(cols=gene, key_added=f"{gene}_{mode}", celltype=mode)
        for gene in genes
    ]
    # all genes per cell, joined into one value
    v_all, j_all = [
        vdj._split(
            cols=gene,
            key_added=f"{gene}_{mode}",
            join=True,
            unique=False,
            celltype=mode,
        )
        for gene in genes
    ]

    # Chain the left joins on cell_id; the "full" tables share their column
    # names with the "main" ones and therefore receive the "_full" suffix.
    combined = v_first.drop("celltype_group")
    combined = combined.join(
        j_first.drop("celltype_group"), on="cell_id", how="left"
    )
    for full_table in (v_all, j_all):
        combined = combined.join(
            full_table.drop("celltype_group"),
            on="cell_id",
            how="left",
            suffix="_full",
        )

    usage = combined.to_pandas().set_index("cell_id")
    obs_snapshot = adata.obs.copy()
    try:
        for column_name, column_values in usage.items():
            adata.obs[column_name] = column_values
        yield adata
    finally:
        # put the untouched obs table back
        adata.obs = obs_snapshot
[docs]
def vj_usage_pca(
    adata: AnnData,
    vdj: DandelionPolars,
    group_by: str,
    min_size: int = 20,
    mode: Literal["B", "abT", "gdT"] = "abT",
    use_vdj_v: bool = True,
    use_vdj_j: bool = True,
    use_vj_v: bool = True,
    use_vj_j: bool = True,
    transfer_mapping=None,
    n_comps: int = 30,
    groups: list[str] | None = None,
    allowed_chain_status: list[str] | None = [
        "Single pair",
        "Extra pair",
        "Extra pair-exception",
        "Orphan VDJ-exception",
    ],
    verbose=False,
    **kwargs,
) -> AnnData:
    """
    Tabulate V/J gene usage per group and compute a PCA on the percentages.

    Parameters
    ----------
    adata : AnnData
        AnnData object holding the cell level metadata.
    vdj : DandelionPolars
        Dandelion VDJ object to extract V/J usage from.
    group_by : str
        Column name in `adata.obs`; each category becomes one observation of
        the returned object.
    min_size : int, optional
        Groups with fewer cells than this are dropped. Defaults to 20.
    mode : Literal["B", "abT", "gdT"], optional
        Mode used to extract the V/J genes.
    use_vdj_v : bool, optional
        Whether to tabulate V genes from VDJ contigs. Defaults to True.
    use_vdj_j : bool, optional
        Whether to tabulate J genes from VDJ contigs. Defaults to True.
    use_vj_v : bool, optional
        Whether to tabulate V genes from VJ contigs. Defaults to True.
    use_vj_j : bool, optional
        Whether to tabulate J genes from VJ contigs. Defaults to True.
    transfer_mapping : None, optional
        If provided, an iterable of `adata.obs` column names whose per-group
        value is copied onto the returned AnnData.
    n_comps : int, optional
        Number of principal components to compute. Defaults to 30.
    groups : list[str] | None, optional
        If provided, only these categories of `group_by` are used.
    allowed_chain_status : list[str] | None, optional
        If provided, only cells whose `chain_status` is in this list are kept.
        Defaults to ["Single pair", "Extra pair", "Extra pair-exception",
        "Orphan VDJ-exception"].
    verbose : bool, optional
        Whether to display a progress bar.
    **kwargs
        Additional keyword arguments passed to `scanpy.pp.pca`.

    Returns
    -------
    AnnData
        AnnData object with groups as obs and V/J genes as features.

    Raises
    ------
    ValueError
        If all four `use_*` switches are False.
    """
    start = logg.info("Computing PCA for V/J gene usage")
    # The V/J columns only exist in `adata.obs` while this block is active.
    with _vj_usage_context(adata, vdj, mode) as adata_:
        # Optional cell filters; each one continues on a fresh copy.
        if allowed_chain_status is not None:
            status_mask = adata_.obs["chain_status"].isin(allowed_chain_status)
            adata_ = adata_[status_mask].copy()
        if groups is not None:
            wanted_mask = adata_.obs[group_by].isin(groups)
            adata_ = adata_[wanted_mask].copy()

        # One spec per chain/segment: the switch, the column holding the first
        # call ("main") and the column holding all "|"-joined calls ("full").
        switches = (
            ("vdj_v", use_vdj_v),
            ("vdj_j", use_vdj_j),
            ("vj_v", use_vj_v),
            ("vj_j", use_vj_j),
        )
        chain_specs = {}
        for name, enabled in switches:
            chain, segment = name.split("_")
            first_col = f"{segment}_call_{mode}_{chain.upper()}"
            chain_specs[name] = {
                "enabled": enabled,
                "main": first_col,
                "full": f"{first_col}_full",
            }
        if not any(spec["enabled"] for spec in chain_specs.values()):
            raise ValueError("At least one of the use_vj/vdj_v/j must be True.")

        # Groups that are large enough to be kept.
        cell_counts = adata_.obs[group_by].value_counts()
        keep_groups = cell_counts[cell_counts >= min_size].index

        # Feature list per chain, taken from the "main" column.
        genes_by_chain = {}
        for name, spec in chain_specs.items():
            if not spec["enabled"]:
                genes_by_chain[name] = []
                continue
            observed = adata_.obs[spec["main"]].unique().tolist()
            genes_by_chain[name] = [
                g
                for g in observed
                if g not in ("None", "No_contig", None) and pd.notna(g)
            ]
        all_genes = [g for name in genes_by_chain for g in genes_by_chain[name]]

        # Groups x genes table, initialised with zeros.
        vdj_df = pd.DataFrame(
            index=keep_groups, columns=all_genes, dtype=float
        ).fillna(0)

        progress = tqdm(
            vdj_df.index,
            desc="Tabulating V/J gene usage",
            disable=not verbose,
        )
        for group in progress:
            obs_group = adata_.obs.loc[adata_.obs[group_by] == group]
            for name, spec in chain_specs.items():
                if not spec["enabled"]:
                    continue
                # Cells may carry several "|"-separated calls; count each one.
                counts = Counter(
                    token
                    for val in obs_group[spec["full"]]
                    if pd.notna(val) and val not in ("None", "No_contig")
                    for token in str(val).split("|")
                )
                for gene in genes_by_chain[name]:
                    vdj_df.loc[group, gene] = counts.get(gene, 0)

        # Convert counts to percentages, separately within each chain.
        for name, spec in chain_specs.items():
            if not spec["enabled"]:
                continue
            cols = genes_by_chain[name]
            # A row total of 0 is replaced by 1 to avoid dividing by zero.
            totals = vdj_df[cols].sum(axis=1).replace(0, 1)
            vdj_df.loc[:, cols] = vdj_df[cols].div(totals, axis=0) * 100

    # Build the group-level AnnData and run the PCA.
    obs_df = pd.DataFrame(index=vdj_df.index)
    obs_df["cell_type"] = vdj_df.index
    obs_df["cell_count"] = cell_counts.loc[vdj_df.index]
    vdj_adata = AnnData(
        X=vdj_df.values,
        obs=obs_df,
        var=pd.DataFrame(index=vdj_df.columns),
    )
    sc.pp.pca(vdj_adata, n_comps=n_comps, use_highly_variable=False, **kwargs)

    if transfer_mapping is not None:
        # One representative row per group from the original (unfiltered) obs.
        collapsed = adata.obs.drop_duplicates(subset=group_by)
        for to in transfer_mapping:
            lookup = dict(zip(collapsed[group_by], collapsed[to]))
            vdj_adata.obs[to] = vdj_adata.obs.index.map(lookup)

    logg.info(
        " finished",
        time=start,
        deep=("Returned AnnData: \n" " 'obsm', X_pca for V/J gene usage"),
    )
    return vdj_adata
[docs]
def vdj_sample(
    vdj_data: DandelionPolars,
    size: int | None = None,
    adata: AnnData | MuData | None = None,
    p: list[float] | np.ndarray[float] | None = None,
    force_replace: bool = False,
    random_state: int | np.random.RandomState | None = None,
) -> tuple[DandelionPolars, AnnData] | DandelionPolars:
    """
    Resample vdj data and corresponding AnnData to a specified size.

    Parameters
    ----------
    vdj_data : DandelionPolars
        Dandelion object containing VDJ data.
    size : int
        Desired number of cells after resampling.
    adata : AnnData | MuData | None, optional
        AnnData or MuData object corresponding to the gene expression data.
        For MuData the ``"gex"`` modality is sampled.
    p : list[float] | np.ndarray[float] | None, optional
        Drawing weights, one per cell. Without `adata` they are cleaned
        (non-finite values count as 0) and normalised to sum to 1 on a private
        copy; with `adata` they are forwarded unchanged to `scanpy.pp.sample`.
        If None, uniform probabilities are used.
    force_replace : bool, optional
        Whether to force sampling with replacement, by default False.
        Replacement is always used when `size` exceeds the number of cells.
    random_state : int | np.random.RandomState | None, optional
        Random state for reproducibility, by default None.

    Returns
    -------
    tuple[DandelionPolars, AnnData] | DandelionPolars
        Only the resampled Dandelion object when `adata` is None. Otherwise a
        tuple of the resampled Dandelion object and the resampled expression
        object: an AnnData for AnnData input, or the result of `to_scirpy`
        for MuData input.

    Raises
    ------
    TypeError
        If `size` is None.
    ValueError
        If `p` (without `adata`) has the wrong shape, contains negative
        values, or sums to 0 after cleaning.
    """
    if size is None:
        raise TypeError("vdj_sample requires `size` to be provided.")
    vdj = vdj_data
    logg.info("Resampling to {} cells.".format(str(size)))
    rng = np.random.default_rng(random_state)
    # Record the container type now: `adata` is rebound to the plain "gex"
    # AnnData below, after which a `hasattr(adata, "mod")` test can no longer
    # tell a MuData input apart.
    is_mudata = adata is not None and hasattr(adata, "mod")

    if adata is None:
        # Metadata holds one row per cell.
        n_cells = vdj.n_obs
        replace = (size > n_cells) or force_replace
        if p is not None:
            # np.array always copies, so the cleaning below never writes into
            # an array owned by the caller.
            p_array = np.array(p, dtype=float)
            if p_array.ndim != 1 or p_array.shape[0] != n_cells:
                raise ValueError(
                    "`p` must be a 1D array-like with one value per cell in metadata."
                )
            # Treat missing/non-finite probabilities as zero weight.
            p_array[~np.isfinite(p_array)] = 0.0
            if np.any(p_array < 0):
                raise ValueError("`p` must not contain negative probabilities.")
            p_sum = p_array.sum()
            if p_sum <= 0:
                raise ValueError(
                    "`p` sums to 0 after cleaning missing values; provide at least one positive probability."
                )
            p_array = p_array / p_sum
        else:
            p_array = None
        sample_indices = rng.choice(
            n_cells, size=size, replace=replace, p=p_array
        )
        # Translate sampled row positions into cell ids.
        if isinstance(vdj._metadata, pl.LazyFrame):
            keep_cells = (
                vdj._metadata.select("cell_id")
                .collect(engine="streaming")
                .to_series()
                .gather(sample_indices)
                .to_list()
            )
        elif isinstance(vdj._metadata, pl.DataFrame):
            keep_cells = (
                vdj._metadata["cell_id"].gather(sample_indices).to_list()
            )
        else:
            # pandas DataFrame indexed by cell id
            keep_cells = vdj._metadata.iloc[sample_indices].index.tolist()
    else:
        adata = adata.mod["gex"].copy() if is_mudata else adata.copy()
        # Restrict both objects to the cells they have in common.
        if isinstance(vdj._metadata, pl.LazyFrame):
            vdj_cell_ids = set(
                vdj._metadata.select("cell_id")
                .collect(engine="streaming")["cell_id"]
                .to_list()
            )
        elif isinstance(vdj._metadata, pl.DataFrame):
            vdj_cell_ids = set(vdj._metadata["cell_id"].to_list())
        else:
            vdj_cell_ids = set(vdj._metadata.index)
        common_cells = list(vdj_cell_ids.intersection(set(adata.obs_names)))
        adata = adata[adata.obs_names.isin(common_cells)].copy()
        # A list of cell ids is passed straight to __getitem__.
        vdj_filtered = vdj[common_cells]
        n_cells = vdj_filtered.n_obs
        replace = (size > n_cells) or force_replace
        sc.pp.sample(adata, n=size, replace=replace, rng=random_state, p=p)
        keep_cells = list(adata.obs_names)
        vdj = vdj_filtered

    # Filter the contig table to the sampled cells, staying lazy if possible.
    vdj_dat = vdj._data
    if isinstance(vdj_dat, (pl.LazyFrame, pl.DataFrame)):
        has_ambiguous = "ambiguous" in set(vdj_dat.collect_schema().names())
    else:
        has_ambiguous = "ambiguous" in vdj_dat.columns

    if isinstance(vdj_dat, pl.LazyFrame):
        if has_ambiguous:
            vdj_dat = vdj_dat.filter(pl.col("ambiguous").is_in(FALSES_STR))
        vdj_dat = vdj_dat.filter(pl.col("cell_id").is_in(keep_cells))
        # The duplication step below needs materialised rows.
        if replace:
            vdj_dat = vdj_dat.collect(engine="streaming")
    elif isinstance(vdj_dat, pl.DataFrame):
        if has_ambiguous:
            vdj_dat = vdj_dat.filter(pl.col("ambiguous").is_in(FALSES_STR))
        vdj_dat = vdj_dat.filter(pl.col("cell_id").is_in(keep_cells))
    else:
        # pandas DataFrame
        if has_ambiguous:
            vdj_dat = vdj_dat[vdj_dat["ambiguous"].isin(FALSES)].copy()
        vdj_dat = vdj_dat[vdj_dat["cell_id"].isin(keep_cells)].copy()

    if replace:
        # Cells drawn more than once get their contigs repeated; the 2nd, 3rd,
        # ... copy is renamed with "-1", "-2", ... on cell_id and sequence_id.
        cell_counts = Counter(keep_cells)
        duplicated_cells = {
            cell: count for cell, count in cell_counts.items() if count > 1
        }
        if duplicated_cells:
            if isinstance(vdj_dat, pl.DataFrame):
                duplicated_mask = pl.col("cell_id").is_in(
                    list(duplicated_cells.keys())
                )
                vdj_dat_to_duplicate = vdj_dat.filter(duplicated_mask)
                vdj_dat_to_keep = vdj_dat.filter(~duplicated_mask)
                all_duplicated_vdj = []
                for cell_id, count in duplicated_cells.items():
                    cell_rows = vdj_dat_to_duplicate.filter(
                        pl.col("cell_id") == cell_id
                    )
                    for i in range(count):
                        temp_rows = cell_rows.clone()
                        if i > 0:
                            suffix = f"-{i}"
                            temp_rows = temp_rows.with_columns(
                                [
                                    (pl.col("cell_id") + suffix).alias(
                                        "cell_id"
                                    ),
                                    (pl.col("sequence_id") + suffix).alias(
                                        "sequence_id"
                                    ),
                                ]
                            )
                        all_duplicated_vdj.append(temp_rows)
                vdj_dat = pl.concat([vdj_dat_to_keep] + all_duplicated_vdj)
            else:
                # pandas DataFrame
                is_duplicated = vdj_dat["cell_id"].isin(duplicated_cells.keys())
                vdj_dat_to_duplicate = vdj_dat[is_duplicated].copy()
                vdj_dat_to_keep = vdj_dat[~is_duplicated].copy()
                all_duplicated_vdj = []
                for cell_id, count in duplicated_cells.items():
                    cell_rows = vdj_dat_to_duplicate[
                        vdj_dat_to_duplicate["cell_id"] == cell_id
                    ].copy()
                    for i in range(count):
                        temp_rows = cell_rows.copy()
                        if i > 0:
                            suffix = f"-{i}"
                            temp_rows["cell_id"] = temp_rows["cell_id"] + suffix
                            temp_rows["sequence_id"] = (
                                temp_rows["sequence_id"] + suffix
                            )
                        all_duplicated_vdj.append(temp_rows)
                vdj_dat = pd.concat(
                    [vdj_dat_to_keep] + all_duplicated_vdj, ignore_index=True
                )

    # Rebuild the Dandelion object from the sampled contigs.
    vdj = DandelionPolars(vdj_dat)
    if adata is None:
        return vdj
    adata.obs_names_make_unique()
    if is_mudata:
        return vdj, to_scirpy(vdj, gex_adata=adata)
    return vdj, adata
[docs]
def to_scirpy(
    data: DandelionPolars,
    transfer: bool = False,
    to_mudata: bool = True,
    gex_adata: AnnData | None = None,
    key: tuple[str, str] = ("gex", "airr"),
    **kwargs,
) -> AnnData | MuData:
    """
    Convert Dandelion data to scirpy-compatible format.

    The object is switched to the pandas backend for the conversion. If it
    started on the polars backend, both the object passed in and the filtered
    working copy (made when `gex_adata` is given) are switched back afterwards,
    including when the conversion raises.

    Parameters
    ----------
    data : DandelionPolars
        The Dandelion object containing the data to be converted.
    transfer : bool, optional
        Whether to transfer additional information from Dandelion to the
        converted data. Defaults to False.
    to_mudata : bool, optional
        Whether to return MuData instead of AnnData. Defaults to True.
        When returning AnnData, only cells present in both the AIRR data and
        `gex_adata` are kept.
    gex_adata : AnnData, optional
        An existing AnnData object used as the base for the converted data.
        Only cells present in it are converted.
    key : tuple[str, str], optional
        Modality names for the 'gex' and 'airr' parts of the MuData.
        Defaults to ("gex", "airr").
    **kwargs
        Additional keyword arguments passed to `scirpy.io.read_airr`.

    Returns
    -------
    AnnData | MuData
        The converted data in either AnnData or MuData format.

    Raises
    ------
    ImportError
        If ``scirpy`` is not installed.
    """
    # `data` is rebound to a filtered copy below; keep a handle on the
    # caller's object so its backend can be restored as well.
    source = data
    original_backend = data._backend
    original_lazy = data._lazy
    try:
        if original_backend == "polars":
            data.to_pandas()
        # With gex_adata, only cells present in both objects are converted.
        if gex_adata is not None:
            data = data[data.metadata_names.isin(gex_adata.obs_names)].copy()
            if data._backend == "polars":
                data.to_pandas()
            tmp_gex = gex_adata.copy()
            if not to_mudata:
                tf(
                    tmp_gex, data, obs=False, uns=True, obsp=False, obsm=False
                )  # so that the slots are properly filled
        else:
            tmp_gex = None
        if "umi_count" not in data._data and "duplicate_count" in data._data:
            data._data["umi_count"] = data._data["duplicate_count"]
        # Add any of these columns that is missing, filled with None.
        for h in [
            "sequence",
            "rev_comp",
            "sequence_alignment",
            "germline_alignment",
            "v_cigar",
            "d_cigar",
            "j_cigar",
        ]:
            if h not in data._data:
                data._data[h] = None
        airr, obs = to_ak(data._data, **kwargs)
    finally:
        # Convert back to the original backend. Objects that are already on
        # polars (e.g. because the failure happened before they were
        # converted) are left alone.
        if original_backend == "polars":
            to_restore = (data,) if data is source else (data, source)
            for obj in to_restore:
                if obj._backend != "polars":
                    obj.to_polars(lazy=original_lazy)

    if to_mudata:
        airr_adata = _create_anndata(airr, obs)
        if tmp_gex is not None:
            tf(
                airr_adata,
                data,
                obs=False,
                uns=True,
                obsp=False,
                obsm=False,
            )
        mdata = _create_mudata(tmp_gex, airr_adata, key)
        if transfer:
            tf(mdata, data)
        return mdata
    else:
        adata = _create_anndata(airr, obs, tmp_gex)
        if transfer:
            tf(adata, data)
        return adata
[docs]
def from_scirpy(data: AnnData | MuData) -> DandelionPolars:
    """
    Convert data from scirpy format to Dandelion format.

    Parameters
    ----------
    data : AnnData | MuData
        The input data in scirpy format. For anything that is not an AnnData
        the ``"airr"`` modality is used.

    Returns
    -------
    DandelionPolars
        The converted data in Dandelion format.
    """
    source = data if isinstance(data, AnnData) else data.mod["airr"]
    # Work on a copy: the cell_id field is written into the AIRR array below.
    airr_adata = source.copy()
    airr_adata.obsm["airr"]["cell_id"] = airr_adata.obs.index
    contig_df = from_ak(airr_adata.obsm["airr"])
    vdj = DandelionPolars(contig_df, verbose=False)
    # Reverse transfer (recover metadata + clone graph)
    _reverse_transfer(airr_adata, vdj)
    return vdj
def _reverse_transfer(
    data: AnnData | MuData,
    dandelion: DandelionPolars,
    clone_key: str = "clone_id",
) -> None:
    """
    Reverse-transfer scirpy data (AnnData/MuData) into a Dandelion object.

    Copies ``.obs`` columns that the Dandelion metadata does not have yet and,
    when present, rebuilds the clone graph stored under ``.uns[clone_key]``.

    Parameters
    ----------
    data : AnnData | MuData
        Input scirpy object (AnnData or MuData with .mod['airr']).
    dandelion : DandelionPolars
        The Dandelion object to update in place.
    clone_key : str, optional
        Key under .uns containing scirpy clone-level mapping
        (default: 'clone_id').

    Raises
    ------
    ValueError
        If `data` has modalities but no ``'airr'`` modality.
    """
    # --- Pick the AnnData that carries the AIRR information ---
    if not hasattr(data, "mod"):
        adata = data
    elif "airr" in data.mod:
        adata = data.mod["airr"]
    else:
        raise ValueError(
            "MuData object must contain an 'airr' modality for scirpy data."
        )

    # --- Copy obs columns that are new to the Dandelion metadata ---
    meta = dandelion._metadata
    meta_is_lazy = isinstance(meta, pl.LazyFrame)
    known_cols = meta.collect_schema().names() if meta_is_lazy else meta.columns
    extra_cols = [col for col in adata.obs if col not in known_cols]
    if extra_cols:
        obs_extra = adata.obs[extra_cols].copy()
        obs_extra.index.name = "cell_id"
        obs_frame = pl.from_pandas(obs_extra.reset_index())
        if meta_is_lazy:
            obs_frame = obs_frame.lazy()
        dandelion._metadata = dandelion._metadata.join(
            obs_frame, on="cell_id", how="left"
        )

    # --- Rebuild the clone graph, if scirpy stored one ---
    if clone_key in adata.uns:
        clone_info = adata.uns[clone_key]
        distances = clone_info["distances"]
        cell_indices = clone_info["cell_indices"]
        graph = nx.from_scipy_sparse_array(distances)
        # Keys of cell_indices are numeric (possibly as strings); each maps to
        # the cell id(s) of that node. The node is renamed to its first cell
        # id and the full membership is kept in the "cells" attribute.
        relabel = {}
        for raw_key, members in cell_indices.items():
            node = int(raw_key)
            if isinstance(members, (list, np.ndarray)):
                relabel[node] = (
                    str(members[0]) if len(members) > 0 else str(raw_key)
                )
                graph.nodes[node]["cells"] = list(members)
            else:
                relabel[node] = str(members)
                graph.nodes[node]["cells"] = [members]
        dandelion.graph = [nx.relabel_nodes(graph, relabel), None]

    # map the obs back to data as well
    # NOTE(review): the source this was taken from had lost its indentation;
    # this call is assumed to run unconditionally (not only when clone_key is
    # present) -- confirm.
    dandelion.update_data()
def from_ak(airr: Array) -> pd.DataFrame:
    """
    Convert an AIRR-formatted array to a pandas DataFrame.

    If the ``sequence_id`` column is missing, or contains any null value, the
    whole column is regenerated as ``"<cell_id>_contig_<subentry + 1>"``.
    The result is indexed by ``sequence_id`` (the column itself is kept).

    Parameters
    ----------
    airr : Array
        The AIRR-formatted array to be converted.

    Returns
    -------
    pd.DataFrame
        The converted pandas DataFrame.

    Raises
    ------
    KeyError
        If ``sequence_id`` has to be regenerated but the ``cell_id`` or
        ``subentry`` column it is built from is not available.
    """
    import awkward as ak

    df = ak.to_dataframe(airr)
    if "sequence_id" not in df.columns or df["sequence_id"].isnull().any():
        # reset_index exposes the index levels (expected: 'entry' and
        # 'subentry') as ordinary columns.
        df_reset = df.reset_index()
        # Vectorised equivalent of formatting each row individually; it also
        # works on an empty frame, where a row-wise apply returns a DataFrame
        # that cannot be assigned to a single column.
        contig_number = (df_reset["subentry"] + 1).astype(str)
        df_reset["sequence_id"] = (
            df_reset["cell_id"].astype(str) + "_contig_" + contig_number
        )
        # set 'entry' and 'subentry' back as the index
        df = df_reset.set_index(["entry", "subentry"])
    if "sequence_id" in df.columns:
        df.set_index("sequence_id", drop=False, inplace=True)
    if "cell_id" not in df.columns:
        # Fall back to the part of the contig name before "_contig".
        df["cell_id"] = [c.split("_contig")[0] for c in df["sequence_id"]]
    return df
def to_ak(
    data: pd.DataFrame,
    **kwargs,
) -> tuple[Array, pd.DataFrame]:
    """
    Read contig-level data with scirpy and return its AIRR array and obs.

    Parameters
    ----------
    data : pd.DataFrame
        The input DataFrame containing the data. Polars DataFrames and
        LazyFrames are converted to pandas first.
    **kwargs
        Additional keyword arguments passed to `scirpy.io.read_airr`.

    Returns
    -------
    tuple[Array, pd.DataFrame]
        The AIRR-formatted data (``.obsm["airr"]`` of the object returned by
        scirpy) and the cell-level attributes (its ``.obs``).

    Raises
    ------
    ImportError
        If ``scirpy`` is not installed.
    """
    try:
        import scirpy as ir
    except ImportError:
        raise ImportError("Please install scirpy to use this function.")

    if isinstance(data, (pl.LazyFrame, pl.DataFrame)):
        # Lazy frames are materialised before the pandas conversion.
        eager = data.collect() if isinstance(data, pl.LazyFrame) else data
        data = eager.to_pandas()

    parsed = ir.io.read_airr(data, **kwargs)
    return parsed.obsm["airr"], parsed.obs
def _create_anndata(
    airr: Array,
    obs: pd.DataFrame,
    adata: AnnData | None = None,
) -> AnnData:
    """
    Create an AnnData object with the given AIRR array and observation data.

    Parameters
    ----------
    airr : Array
        The AIRR array; stored under ``.obsm["airr"]``.
    obs : pd.DataFrame
        The observation data belonging to `airr`.
    adata : AnnData | None, optional
        An existing AnnData object to attach the AIRR array to. If None, a new
        AnnData object holding only `obs` and the AIRR array is returned.

    Returns
    -------
    AnnData
        A new AnnData when `adata` is None. Otherwise a copy of `adata`
        restricted to the cells shared with the AIRR data, with the AIRR array
        added to ``.obsm`` in the same row order as that copy.
    """
    airr_adata = AnnData(X=None, obs=obs, obsm={"airr": airr})
    if adata is None:
        return airr_adata

    shared = adata.obs_names.intersection(airr_adata.obs_names)
    adata = adata[adata.obs_names.isin(shared)].copy()
    # obsm entries are attached by position, so the AIRR rows must follow
    # adata's row order. Selecting by name (instead of a boolean mask, which
    # keeps the AIRR object's own order) guarantees that alignment.
    airr_adata = airr_adata[adata.obs_names].copy()
    adata.obsm = dict() if adata.obsm is None else adata.obsm
    adata.obsm.update(airr_adata.obsm)
    return adata
def _create_mudata(
    gex: AnnData,
    adata: AnnData,
    key: tuple[str, str] = ("gex", "airr"),
) -> MuData:
    """
    Bundle a gene expression AnnData and an AIRR AnnData into a MuData.

    Parameters
    ----------
    gex : AnnData
        The AnnData object containing gene expression data. If None, the
        MuData only contains the second modality.
    adata : AnnData
        The AnnData object stored under the second modality name.
    key : tuple[str, str], optional
        Modality names for `gex` and `adata`, in that order.
        Defaults to ("gex", "airr").

    Returns
    -------
    MuData
        The created MuData object.

    Raises
    ------
    ImportError
        If the mudata package is not installed.
    """
    try:
        import mudata
    except ImportError:
        raise ImportError("Please install mudata. pip install mudata")

    gex_name, airr_name = key[0], key[1]
    modalities = {airr_name: adata}
    if gex is not None:
        # Expression modality first, then AIRR.
        modalities = {gex_name: gex, airr_name: adata}
    return mudata.MuData(modalities)
def _dtype_supertype(dt1: pl.DataType, dt2: pl.DataType) -> pl.DataType:
    """Return a common dtype that both `dt1` and `dt2` can be cast to.

    Rules, in order:

    * equal dtypes are returned unchanged;
    * ``pl.Null`` (an all-null column) yields the other dtype;
    * two floats, two signed ints or two unsigned ints: the wider one;
    * float with int: the float if it is ``Float64`` or the int has at most
      16 bits, otherwise ``Float64``;
    * signed with unsigned int: the signed type if it is strictly wider,
      otherwise the next wider signed type (``Float64`` for ``UInt64``, which
      has no wider signed integer among the types handled here);
    * anything else falls back to ``pl.String`` so that a concat at least
      does not crash.

    Parameters
    ----------
    dt1, dt2 : pl.DataType
        The two dtypes found for the same column.

    Returns
    -------
    pl.DataType
        The resolved dtype. One of the inputs is returned as-is whenever it
        already is the answer.
    """
    if dt1 == dt2:
        return dt1
    # A column of only nulls puts no constraint on the type.
    if dt1 == pl.Null:
        return dt2
    if dt2 == pl.Null:
        return dt1

    families = {
        "float": ((pl.Float32, 32), (pl.Float64, 64)),
        "signed": ((pl.Int8, 8), (pl.Int16, 16), (pl.Int32, 32), (pl.Int64, 64)),
        "unsigned": (
            (pl.UInt8, 8),
            (pl.UInt16, 16),
            (pl.UInt32, 32),
            (pl.UInt64, 64),
        ),
    }

    def classify(dtype):
        """Return (family, bit width) for a numeric dtype, else None."""
        for family, members in families.items():
            for candidate, bits in members:
                if dtype == candidate:
                    return family, bits
        return None

    info1, info2 = classify(dt1), classify(dt2)
    if info1 is None or info2 is None:
        # Non-numeric mismatch: cast to String as a safe fallback.
        return pl.String

    (family1, bits1), (family2, bits2) = info1, info2
    if family1 == family2:
        return dt1 if bits1 >= bits2 else dt2

    if "float" in (family1, family2):
        if family1 == "float":
            float_dt, float_bits, int_bits = dt1, bits1, bits2
        else:
            float_dt, float_bits, int_bits = dt2, bits2, bits1
        # Float32 represents every 8/16-bit integer exactly, but not wider
        # ones, so those are promoted to Float64.
        if float_bits == 64 or int_bits <= 16:
            return float_dt
        return pl.Float64

    # Mixed signedness: the result must hold negative values and the full
    # unsigned range.
    if family1 == "signed":
        signed_dt, signed_bits, unsigned_bits = dt1, bits1, bits2
    else:
        signed_dt, signed_bits, unsigned_bits = dt2, bits2, bits1
    if signed_bits > unsigned_bits:
        return signed_dt
    next_signed = {8: pl.Int16, 16: pl.Int32, 32: pl.Int64}
    return next_signed.get(unsigned_bits, pl.Float64)
[docs]
def _concat_column_names(frame) -> list[str]:
    """Return the column names of a frame without collecting a LazyFrame."""
    if isinstance(frame, pl.LazyFrame):
        return frame.collect_schema().names()
    return frame.columns


def _concat_column_values(frame, column: str) -> list:
    """Return one column of a polars frame as a list ([] for anything else)."""
    if isinstance(frame, pl.LazyFrame):
        return (
            frame.select(column)
            .collect(engine="streaming")
            .to_series()
            .to_list()
        )
    if isinstance(frame, pl.DataFrame):
        return frame.select(column).to_series().to_list()
    return []


def concat(
    arrays: (
        list[DandelionPolars | pl.DataFrame | pl.LazyFrame | pd.DataFrame]
        | dict[
            str, DandelionPolars | pl.DataFrame | pl.LazyFrame | pd.DataFrame
        ]
    ),
    check_unique: bool = True,
    collapse_cells: bool = True,
    sep: str = "_",
    suffixes: list[str] | None = None,
    prefixes: list[str] | None = None,
    remove_trailing_hyphen_number: bool = False,
    verbose: bool = True,
) -> DandelionPolars:
    """
    Concatenate data frames and return as Dandelion object.

    If both suffixes and prefixes are `None` and check_unique is True, then a
    sequential number suffix will be appended when identifiers collide.

    Parameters
    ----------
    arrays : list | dict
        List or dictionary of DandelionPolars objects or pandas/polars
        DataFrames (polars LazyFrames are accepted too) to concatenate. For a
        dictionary only the values are used. The inputs are not modified.
    check_unique : bool, optional
        Check the new index for duplicates and append suffixes/prefixes when
        duplicates are found. With False, duplicated identifiers raise.
    collapse_cells : bool, optional
        whether or not to collapse multiple contigs per cell into one row in
        the metadata. By default True.
    sep : str, optional
        the separator to append suffix/prefix.
    suffixes : list[str] | None, optional
        List of suffixes to append to sequence_id and cell_id.
    prefixes : list[str] | None, optional
        List of prefixes to append to sequence_id and cell_id.
    remove_trailing_hyphen_number : bool, optional
        whether or not to remove the trailing hyphen number e.g. '-1' from the
        cell/contig barcodes.
    verbose : bool, optional
        Whether to print the messages, by default True.

    Returns
    -------
    DandelionPolars
        concatenated Dandelion object

    Raises
    ------
    ValueError
        if both prefixes and suffixes are provided, if their length does not
        match the number of inputs, if an input has an unsupported type, or if
        identifiers are duplicated while `check_unique` is False.
    """
    if (suffixes is not None) and (prefixes is not None):
        raise ValueError("Please provide only prefixes or suffixes, not both.")
    if suffixes is not None and len(arrays) != len(suffixes):
        raise ValueError(
            "Please provide the same number of suffixes as the number of objects to concatenate."
        )
    if prefixes is not None and len(arrays) != len(prefixes):
        raise ValueError(
            "Please provide the same number of prefixes as the number of objects to concatenate."
        )

    # Keep references to the caller's objects here (no copy yet): the
    # metadata schema and the back-fill further down must read the originals,
    # because a deep copy of a parquet-backed DandelionPolars can lose
    # columns that were added to _metadata after construction. Working copies
    # are made in the conversion loop below.
    if isinstance(arrays, dict):
        arrays = list(arrays.values())

    all_meta_cols: set[str] = set()
    for x in arrays:
        if isinstance(x, DandelionPolars) and x._metadata is not None:
            all_meta_cols.update(_concat_column_names(x._metadata))

    # Convert all inputs to (copied) DandelionPolars objects.
    vdjs_ = []
    for x in arrays:
        if isinstance(x, DandelionPolars):
            vdjs_.append(x.copy())
        elif isinstance(x, pl.LazyFrame):
            vdjs_.append(
                DandelionPolars(x.collect(engine="streaming"), verbose=False)
            )
        elif isinstance(x, pl.DataFrame):
            vdjs_.append(DandelionPolars(x, verbose=False))
        elif isinstance(x, pd.DataFrame):
            vdjs_.append(
                DandelionPolars(
                    pl.from_pandas(x, schema_overrides=SCHEMA_OVERRIDES),
                    verbose=False,
                )
            )
        else:
            raise ValueError(
                "All input arrays must be either DandelionPolars instances, "
                "Polars DataFrames/LazyFrames, or pandas DataFrames."
            )

    # Cell and contig identifiers in input order.
    tmp_meta_names, tmp_data_names = [], []
    for tmp in vdjs_:
        tmp_meta_names.extend(_concat_column_values(tmp._metadata, "cell_id"))
        tmp_data_names.extend(_concat_column_values(tmp._data, "sequence_id"))
    if collapse_cells:
        # Preserve order but remove duplicates
        tmp_meta_names = list(dict.fromkeys(tmp_meta_names))
    # None marks "contains duplicates".
    metadata_index_order = (
        tmp_meta_names
        if len(tmp_meta_names) == len(set(tmp_meta_names))
        else None
    )
    data_index_order = (
        tmp_data_names
        if len(tmp_data_names) == len(set(tmp_data_names))
        else None
    )

    if check_unique:
        use_prefix = prefixes is not None
        if use_prefix:
            labels = [str(label) for label in prefixes]
        elif suffixes is not None:
            labels = [str(label) for label in suffixes]
        else:
            labels = [str(i) for i in range(len(vdjs_))]

        if metadata_index_order is None:
            # Duplicated cell ids: rename cells, then re-read both identifier
            # lists. This also covers duplicated cells with unique contig
            # ids, which previously left the order as None and failed later.
            metadata_index_order, data_index_order = [], []
            for vdj, label in zip(vdjs_, labels):
                if use_prefix:
                    vdj.add_cell_prefix(
                        label,
                        sep=sep,
                        remove_trailing_hyphen_number=remove_trailing_hyphen_number,
                    )
                else:
                    vdj.add_cell_suffix(
                        label,
                        sep=sep,
                        remove_trailing_hyphen_number=remove_trailing_hyphen_number,
                    )
                metadata_index_order.extend(
                    _concat_column_values(vdj._metadata, "cell_id")
                )
                data_index_order.extend(
                    _concat_column_values(vdj._data, "sequence_id")
                )
        elif data_index_order is None:
            # Only contig ids collide: rename the contigs, leave cells alone.
            data_index_order = []
            for vdj, label in zip(vdjs_, labels):
                if use_prefix:
                    vdj.add_sequence_prefix(
                        label,
                        sep=sep,
                        sync=False,
                        remove_trailing_hyphen_number=remove_trailing_hyphen_number,
                    )
                else:
                    vdj.add_sequence_suffix(
                        label,
                        sep=sep,
                        sync=False,
                        remove_trailing_hyphen_number=remove_trailing_hyphen_number,
                    )
                data_index_order.extend(
                    _concat_column_values(vdj._data, "sequence_id")
                )
    elif metadata_index_order is None or data_index_order is None:
        raise ValueError(
            "Cell/contig indices are not unique. Please set check_unique=True to append suffixes/prefixes or ensure unique indices before concatenation."
        )

    # If only some inputs carry 'v_call_genotyped', create it from 'v_call'
    # in the others.
    has_genotyped = [
        "v_call_genotyped" in _concat_column_names(vdj._data) for vdj in vdjs_
    ]
    if any(has_genotyped) and not all(has_genotyped):
        if verbose:
            logg.info(
                "For consistency, 'v_call_genotyped' will be used where available. Filling missing values from 'v_call'."
            )
        for vdj, present in zip(vdjs_, has_genotyped):
            if not present:
                vdj._data = vdj._data.with_columns(
                    pl.col("v_call").alias("v_call_genotyped")
                )

    # Align schemas manually before pl.concat: the union of all columns, with
    # one resolved dtype per column. Frames that already match are not passed
    # through with_columns([]), which fails on a parquet-backed LazyFrame.
    # collect_schema() only inspects the plan, so lazy inputs stay lazy.
    all_schema: dict[str, pl.DataType] = {}
    for vdj in vdjs_:
        frame_schema = (
            vdj._data.collect_schema()
            if isinstance(vdj._data, pl.LazyFrame)
            else vdj._data.schema
        )
        for col_name, col_dtype in frame_schema.items():
            if col_name not in all_schema:
                all_schema[col_name] = col_dtype
            elif all_schema[col_name] != col_dtype:
                all_schema[col_name] = _dtype_supertype(
                    all_schema[col_name], col_dtype
                )
    col_order = list(all_schema.keys())

    arrays_ = []
    for vdj in vdjs_:
        frame = vdj._data
        present_schema = dict(
            frame.collect_schema()
            if isinstance(frame, pl.LazyFrame)
            else frame.schema
        )
        fix_exprs: list[pl.Expr] = []
        for n in col_order:
            target_dt = all_schema[n]
            if n not in present_schema:
                # Column absent: add a typed null.
                fix_exprs.append(pl.lit(None).cast(target_dt).alias(n))
            elif present_schema[n] != target_dt:
                # Column present but wrong dtype: cast to the resolved type.
                fix_exprs.append(pl.col(n).cast(target_dt))
        if fix_exprs:
            frame = frame.with_columns(fix_exprs)
        arrays_.append(frame.select(col_order))

    # pl.concat needs a homogeneous list: if any frame is lazy, lazify all.
    if any(isinstance(f, pl.LazyFrame) for f in arrays_):
        arrays_ = [
            f if isinstance(f, pl.LazyFrame) else f.lazy() for f in arrays_
        ]
    vdj_concat = DandelionPolars(
        pl.concat(arrays_, how="diagonal"), verbose=False
    )

    # Metadata columns known to the inputs but absent from the rebuilt
    # metadata are added and back-filled from the original inputs.
    missing_meta_cols = all_meta_cols - set(
        _concat_column_names(vdj_concat._metadata)
    )
    if len(missing_meta_cols) > 0:
        if isinstance(vdj_concat._metadata, pl.LazyFrame):
            meta_df = vdj_concat._metadata.collect(engine="streaming")
        else:
            meta_df = vdj_concat._metadata.clone()
        for col in missing_meta_cols:
            meta_df = meta_df.with_columns(pl.lit(None).alias(col))
        # NOTE(review): the join below is keyed on the cell_id values of the
        # original inputs; if cells were renamed with a suffix/prefix above,
        # presumably nothing matches and the columns stay null -- confirm.
        for x in arrays:
            if not isinstance(x, DandelionPolars) or x._metadata is None:
                continue
            if isinstance(x._metadata, pl.LazyFrame):
                source_meta = x._metadata.collect(engine="streaming")
            else:
                source_meta = x._metadata
            source_cols = source_meta.columns
            for col in missing_meta_cols:
                if col not in source_cols:
                    continue
                # Bring the source values in under a temporary name and
                # prefer them over what is already there.
                mapping = source_meta.select(["cell_id", col])
                meta_df = (
                    meta_df.join(
                        mapping.rename({col: f"_temp_{col}"}),
                        on="cell_id",
                        how="left",
                    )
                    .with_columns(
                        pl.coalesce(
                            pl.col(f"_temp_{col}"), pl.col(col)
                        ).alias(col)
                    )
                    .drop(f"_temp_{col}")
                )
        vdj_concat._metadata = meta_df

    # Put the metadata rows back into input order.
    if isinstance(vdj_concat._metadata, pl.LazyFrame):
        concat_meta = vdj_concat._metadata.collect(engine="streaming")
    else:
        concat_meta = vdj_concat._metadata
    order_df = pl.DataFrame(
        {
            "cell_id": metadata_index_order,
            "_original_order": range(len(metadata_index_order)),
        }
    )
    reordered_meta = (
        concat_meta.join(order_df, on="cell_id", how="inner")
        .sort("_original_order")
        .drop("_original_order")
    )
    vdj_concat._metadata = (
        reordered_meta.lazy() if vdj_concat._lazy else reordered_meta
    )
    return vdj_concat
[docs]
def productive_ratio(
    adata: AnnData,
    vdj: DandelionPolars,
    group_by: str,
    groups: list[str] | None = None,
    locus: Literal["TRB", "TRA", "TRD", "TRG", "IGH", "IGK", "IGL"] = "TRB",
) -> None:
    """
    Compute the cell-level productive/non-productive contig ratio.

    A single contig per cell is used for the tabulation: the first row for
    each `cell_id` among the contigs of the requested locus (excluding contigs
    flagged as ambiguous when an `ambiguous` column is present).

    Parameters
    ----------
    adata : AnnData
        AnnData object holding the cell level metadata (`.obs`).
    vdj : DandelionPolars
        DandelionPolars object holding the repertoire data (`.data`).
    group_by : str
        Name of column in `AnnData.obs` to return the row tabulations.
    groups : list[str] | None, optional
        Optional list of categories to return. If `None`, the categories of
        `adata.obs[group_by]` are used when it is categorical, otherwise its
        unique non-missing values in order of first appearance.
    locus : Literal["TRB", "TRA", "TRD", "TRG", "IGH", "IGK", "IGL"], optional
        One of the accepted loci to perform the tabulation on.

    Returns
    -------
    None
        Modifies ``adata`` in place:

        - ``adata.obs[f'{locus}_productive']`` holds the productive status of
          the selected contig for each cell (missing for cells without one).
        - ``adata.uns['productive_ratio']`` holds a dictionary with keys
          ``'results'`` (a DataFrame indexed by group with the columns
          ``'productive'``, ``'non-productive'``, ``'total'``, `group_by` and
          ``'productive+non-productive'``; the ratios are percentages of the
          cells in the group), ``'locus'`` and ``'group_by'``.
    """
    start = logg.info("Tabulating productive ratio")

    # Restrict the repertoire to cells that are present in adata.
    vdjx = vdj[vdj.metadata["cell_id"].is_in(list(adata.obs_names))]

    # Materialise the contig table if it is lazy.
    if isinstance(vdjx._data, pl.LazyFrame):
        data = vdjx._data.collect(engine="streaming")
    else:
        data = vdjx._data

    # Keep contigs of the requested locus, dropping ambiguous ones when the
    # column is available.
    if "ambiguous" in data.columns:
        data_filtered = data.filter(
            (pl.col("locus") == locus) & (pl.col("ambiguous").is_in(FALSES_STR))
        )
    else:
        data_filtered = data.filter(pl.col("locus") == locus)

    # One contig per cell: the first row encountered for each cell_id.
    # NOTE(review): this is only the highest-UMI contig if the contig table is
    # already ordered that way; nothing here sorts it — confirm upstream.
    df_unique = data_filtered.unique(subset=["cell_id"], keep="first")
    dict_df = dict(zip(df_unique["cell_id"], df_unique["productive"]))

    # Determine groups if not provided.
    if groups is None:
        if is_categorical(adata.obs[group_by]):
            groups = list(adata.obs[group_by].cat.categories)
        else:
            # Order-stable and free of missing values, unlike list(set(...)),
            # which yields a run-dependent order and can put (possibly
            # duplicated) NaN labels into the result index.
            groups = list(adata.obs[group_by].dropna().unique())

    res = pd.DataFrame(
        columns=["productive", "non-productive", "total"],
        index=groups,
    )

    # Attach the per-cell status to adata; cells without a selected contig
    # are left missing by the index alignment.
    status_col = locus + "_productive"
    adata.obs[status_col] = pd.Series(dict_df)

    # NOTE(review): statuses are matched against the literals "T"/"F" while
    # the ambiguity filter above uses FALSES_STR — confirm that `productive`
    # is always encoded as "T"/"F".
    group_labels = adata.obs[group_by]
    status = adata.obs[status_col]
    for group in res.index:
        # Compute the membership mask once per group and reuse it.
        in_group = group_labels == group
        total = in_group.sum()
        res.loc[group, "total"] = total
        if total > 0:
            group_status = status.loc[in_group]
            res.loc[group, "productive"] = (
                group_status.isin(["T"]).sum() / total * 100
            )
            res.loc[group, "non-productive"] = (
                group_status.isin(["F"]).sum() / total * 100
            )
        # Groups without any cells keep missing values for both ratios.

    res[group_by] = res.index
    res["productive+non-productive"] = res["productive"] + res["non-productive"]
    out = {"results": res, "locus": locus, "group_by": group_by}
    adata.uns["productive_ratio"] = out
    logg.info(
        " finished",
        time=start,
        deep=(
            f"Updated AnnData: \n"
            f"   'obs', '{locus}_productive'\n"
            "   'uns', 'productive_ratio'\n"
        ),
    )