from __future__ import annotations
import multiprocessing
import re
import time
import dask.array as da
import networkx as nx
import numpy as np
import pandas as pd
import polars as pl
from collections import defaultdict
from pathlib import Path
from polyleven import levenshtein
from scanpy import logging as logg
from scipy.sparse.csgraph import (
connected_components as scipy_cc,
minimum_spanning_tree as scipy_mst,
)
from scipy.sparse import coo_matrix, csr_matrix
from tqdm import tqdm
from typing import Callable, Literal, TYPE_CHECKING
if TYPE_CHECKING:
from anndata import AnnData
from dandelion.polars.core._core import DandelionPolars
from dandelion.polars.tools._tools import vdj_sample
from dandelion.utilities._layout import generate_layout
from dandelion.utilities._distances import (
Metric,
prepare_sequences_with_separator,
resolve_metric,
)
from dandelion.utilities._utilities import FALSES
def _load_lazy_distance_from_zarr(zarr_path: Path | str):
    """
    Open a stored distance matrix as a lazy Dask array.

    Trailing slashes/backslashes are stripped from ``zarr_path`` and a short
    list of candidate array locations is probed in order. The first location
    that ``dask.array.from_zarr`` opens successfully is returned.

    Parameters
    ----------
    zarr_path : Path | str
        Either a path ending in ``.zarr`` (case-insensitive) or a directory
        expected to hold a ``distance_matrix.zarr`` store.

    Returns
    -------
    The object returned by ``dask.array.from_zarr`` for the first candidate
    that can be opened.

    Raises
    ------
    ValueError
        If every candidate fails; chained to the last underlying error.
    """
    import dask.array as da

    root = str(zarr_path).rstrip("/\\")
    if root.lower().endswith(".zarr"):
        # The array may live in a named group inside the store, or be the
        # store itself.
        suffixes = ("/distance_matrix", "")
    else:
        suffixes = (
            "/distance_matrix.zarr/distance_matrix",
            "/distance_matrix.zarr",
            "/distance_matrix",
        )
    candidates = [root + suffix for suffix in suffixes]

    failure = None
    for location in candidates:
        try:
            return da.from_zarr(location)
        except Exception as exc:
            # Remember why this location failed and try the next one.
            failure = exc
    raise ValueError(
        f"Could not load lazy distance matrix from zarr_path={zarr_path}. "
        f"Tried: {candidates}"
    ) from failure
def _merge_overlapping_clones(
    membership: pl.DataFrame, clone_col: str
) -> pl.DataFrame:
    """
    Collapse clones that share cells into common membership groups.

    Cells may carry several "|"-separated clone labels. Cells and clone
    labels are treated as the two sides of a bipartite graph and a single
    connected-components pass assigns every cell to one component, so that
    any two cells linked through a chain of shared clones end up together.

    Parameters
    ----------
    membership : pl.DataFrame
        DataFrame with a 'cell_id' column and the clone column
        ("|"-separated values).
    clone_col : str
        Name of the clone column.

    Returns
    -------
    pl.DataFrame
        'cell_id' (in the order of ``membership``) and 'membership_id', the
        component label as a string. Cells without any usable clone label
        get a null 'membership_id'.
    """
    # One row per (cell, clone) pair; blank, null and "none" labels dropped.
    pairs = (
        membership.with_columns(
            pl.col(clone_col).str.split("|").alias("_clones")
        )
        .explode("_clones")
        .with_columns(pl.col("_clones").str.strip_chars().alias("_clone"))
        .filter(
            pl.col("_clone").is_not_null()
            & (pl.col("_clone") != "")
            & (pl.col("_clone").str.to_lowercase() != "none")
        )
        .select(["cell_id", "_clone"])
    )

    if pairs.height == 0:
        # Nothing to connect: every cell gets a null group.
        return pl.DataFrame(
            {
                "cell_id": membership["cell_id"],
                "membership_id": [None] * membership.height,
            }
        )

    # Node numbering: cells occupy [0, n_cells), clones follow directly after.
    unique_cells = pairs["cell_id"].unique().sort().to_list()
    unique_clones = pairs["_clone"].unique().sort().to_list()
    n_cells = len(unique_cells)
    n_nodes = n_cells + len(unique_clones)
    cell_node = {cid: i for i, cid in enumerate(unique_cells)}
    clone_node = {
        label: n_cells + j for j, label in enumerate(unique_clones)
    }

    row_nodes = np.array([cell_node[c] for c in pairs["cell_id"].to_list()])
    col_nodes = np.array([clone_node[c] for c in pairs["_clone"].to_list()])
    edge_values = np.ones(len(row_nodes), dtype=np.float32)

    # Cell -> clone edges, then mirrored so the adjacency is symmetric.
    one_way = csr_matrix(
        (edge_values, (row_nodes, col_nodes)), shape=(n_nodes, n_nodes)
    )
    adjacency = one_way + one_way.T

    _, component = scipy_cc(adjacency, directed=False)

    # Only the leading n_cells labels belong to cell nodes.
    groups = pl.DataFrame(
        {
            "cell_id": unique_cells,
            "membership_id": [str(lbl) for lbl in component[:n_cells]],
        }
    )

    # Left join restores the caller's rows, including unlabelled cells.
    return membership.select("cell_id").join(
        groups, on="cell_id", how="left"
    )
[docs]
def generate_network(
    vdj: DandelionPolars,
    adata: AnnData | None = None,
    key: str | None = None,
    clone_key: str | None = None,
    min_size: int = 2,
    sample: int | None = None,
    force_replace: bool = False,
    verbose: bool = True,
    compute_graph: bool = True,
    compute_layout: bool = True,
    layout_method: Literal[
        "mod_fr",
        "mod_fr2",
        "mod_fr2_gpu",
        "mod_fr_bh",
        "mod_fr_bh_gpu",
        "fa2",
    ] = "mod_fr2",
    singleton_mass: float = 0.5,
    expanded_only: bool = False,
    use_existing_distance: bool = True,
    use_existing_graph: bool = True,
    n_cpus: int = 1,
    sequential_chain: bool = False,
    distance_mode: Literal["clone", "full"] = "clone",
    dist_func: Callable | str | None = None,
    pad_to_max: bool = False,
    lazy: bool = False,
    zarr_path: Path | str | None = None,
    chunk_size: int | None = None,
    memory_limit_gb: float | None = None,
    memory_safety_fraction: float = 0.3,
    compress: bool = True,
    random_state: int | np.random.RandomState | None = None,
    **kwargs,
) -> DandelionPolars | tuple[DandelionPolars, AnnData]:
    """
    Generate a Levenshtein distance network based on VDJ and VJ sequences.

    The distance matrices are then combined into a singular matrix.

    Parameters
    ----------
    vdj : DandelionPolars
        Dandelion object.
    adata : AnnData | None, optional
        AnnData object. Only used when `sample` is given, in which case it is passed to `vdj_sample` together with
        `vdj` and the AnnData returned by `vdj_sample` is returned alongside the new Dandelion object.
    key : str | None, optional
        column name for distance calculations. None defaults to 'sequence_alignment_aa'.
    clone_key : str | None, optional
        column name to build network on. None defaults to 'clone_id'.
    min_size : int, optional
        For visualization purposes, two graphs are created where one contains all cells and a trimmed second graph.
        This value specifies the minimum number of edges required otherwise node will be trimmed in the secondary graph.
    sample : int | None, optional
        If specified, cells will be randomly sampled to the integer provided. If the integer is larger than the number of cells,
        sampling with replacement is used and the same cell may appear multiple times with different sequence and cell ids. If None,
        no resampling is performed. A new Dandelion class will be returned.
    force_replace : bool, optional
        whether or not to sample with replacement when `sample` is smaller than or equal to the number of cells.
    verbose : bool, optional
        whether or not to print the progress bars.
    compute_graph : bool, optional
        whether or not to generate the graph after distance matrix calculation. If False, `compute_layout` is
        also treated as False.
    compute_layout : bool, optional
        whether or not to generate the layout. May be time consuming if too many cells.
    layout_method : Literal["mod_fr", "mod_fr2", "mod_fr2_gpu", "mod_fr_bh", "mod_fr_bh_gpu", "fa2"], optional
        Layout algorithm. Options:
        - 'mod_fr': Original python modified FR layout
        - 'mod_fr2': Numba-accelerated modified FR (faster CPU)
        - 'mod_fr2_gpu': PyTorch GPU modified FR (auto-tiles for >100K nodes)
        - 'mod_fr_bh': Barnes-Hut O(N log N) CPU layout (scalable for large graphs)
        - 'mod_fr_bh_gpu': Barnes-Hut O(N log N) GPU layout (scalable for large graphs, requires CUDA)
        - 'fa2': ForceAtlas2 (requires fa2-modified)
    singleton_mass : float, optional
        Mass assigned to singleton nodes (no edges) in Barnes-Hut layouts.
        Lower values reduce their impact on pushing connected components apart.
        Default 0.5. Only used with 'mod_fr_bh' and 'mod_fr_bh_gpu'.
    expanded_only : bool, optional
        whether or not to only compute layout on expanded clonotypes.
    use_existing_distance : bool, optional
        whether or not to use the pre-computed distance matrix in `vdj.distances` if it exists. If False, distances will be re-computed even if they already exist.
    use_existing_graph : bool, optional
        whether or not to just compute the layout using the existing graph if it exists in the object.
        Ignored when `min_size` differs from 2 or `sample` is given.
    n_cpus : int, optional
        number of cores to use for parallelizable steps. -1 uses all available cores.
    sequential_chain : bool, optional
        whether or not to use the original method for distance calculation method where each chain is calculated
        separately and sequentially added to the total distance matrix. This method is slower but would be more
        precise calculation. If False, concatenated sequences with a long separator are used for distance calculation.
        Ignored if lazy=True as the lazy method always uses the long separator approach. The long separator approach
        inserts a long string of consistent characters on a per-chain basis to ensure that distances between chains are large
        and do not interfere with intra-chain distances.
    distance_mode : Literal["clone", "full"], optional
        method to compute distance matrix. 'clone' refers to the original membership-based distance calculation where
        only distances within clones are calculated. Whereas 'full' computes the full pairwise distance matrix.
    dist_func : Callable | str | None, optional
        distance function to use. If None, `polyleven.levenshtein` is used. If a string is provided, it will use Bio.Align's
        substitution matrices (e.g., 'BLOSUM62', 'PAM250'). See `Bio.Align.substitution_matrices.load` for available options.
    pad_to_max : bool, optional
        whether or not to pad sequences to the maximum length in the dataset before distance calculation. This will
        allow for distance calculations that need sequences of the same length (e.g., Hamming distance). Note that this
        may increase memory usage and computation time.
    lazy : bool, optional
        If True, computation will be performed lazily using Dask/Zarr arrays. True will also return a Dask array view of the
        distance matrix stored on disk instead of a numpy array stored in memory. Forced to False when an existing
        sparse/dense distance matrix from `vdj.distances` is reused.
    zarr_path : Path | str | None, optional
        Path to store Zarr array when using lazy mode. If None, a temporary directory created with
        `tempfile.mkdtemp()` is used.
    chunk_size : int | None, optional
        Chunk size for distance matrix computation when using lazy mode. If None, chunk size is automatically computed
        based on available memory and number of cores. The automatic chunk size can be further adjusted using
        `memory_limit_gb` and `memory_safety_fraction` parameters.
    memory_limit_gb : float | None, optional
        Memory limit per worker in GB for Dask. None defaults to all available memory/cores.
    memory_safety_fraction : float, optional
        Fraction of available memory to use. Defaults to 0.3 (i.e., 30% of available memory will be used for chunk size calculation).
    compress : bool, optional
        Whether to compress the Zarr array using Blosc with zstd.
    random_state : int | np.random.RandomState | None, optional
        Random state for reproducible sampling.
    **kwargs
        additional kwargs passed to layout functions in `generate_layout`.

    Returns
    -------
    DandelionPolars | tuple[DandelionPolars, AnnData]
        When `sample` is given, a new DandelionPolars object with `.distances`, `.layout`, `.graph` initialized
        (together with the sampled `adata` if one was provided). When `sample` is None, `vdj` is updated through
        `vdj._reinitialize_attributes` and nothing is returned explicitly.

    Raises
    ------
    ValueError
        if `clone_key` (when needed) or `key` is not found in the data.
    """
    # normalize n_cpus convention (-1 => use all CPUs)
    if n_cpus == -1:
        n_cpus = multiprocessing.cpu_count()
    n_cpus = max(1, int(n_cpus))
    clone_key = clone_key if clone_key is not None else "clone_id"
    dist_func = levenshtein if dist_func is None else dist_func
    metric = resolve_metric(dist_func)
    # A layout needs a graph, so skipping the graph also skips the layout.
    if not compute_graph:
        compute_layout = False
    if distance_mode == "clone" or compute_graph or compute_layout:
        if clone_key not in vdj._data:
            raise ValueError(
                "Data does not contain clone information. Please run ddl.tl.find_clones."
            )
    # `regenerate` stays True unless an existing graph can be reused as-is.
    regenerate = True
    if vdj.graph is not None:
        if (min_size != 2) or (sample is not None):
            # Non-default trimming or resampling: the stored graph is not reused.
            pass
        elif use_existing_graph:
            start = logg.info(
                "Generating network layout from pre-computed network"
            )
            if isinstance(vdj, DandelionPolars):
                regenerate = False
                g, g_, lyt, lyt_ = generate_layout(
                    vertices=None,
                    edges=None,
                    min_size=min_size,
                    weight=None,
                    verbose=verbose,
                    compute_layout=compute_layout,
                    layout_method=layout_method,
                    expanded_only=expanded_only,
                    graphs=(vdj.graph[0], vdj.graph[1]),
                    singleton_mass=singleton_mass,
                    **kwargs,
                )
    if regenerate:
        start = logg.info("Generating network")
        key_ = key if key is not None else "sequence_alignment_aa"
        if key_ not in vdj._data:
            raise ValueError(f"key {key_} not found in data.")
        if sample is not None:
            # From here on `vdj` (and `adata`, if given) refer to the sampled objects.
            if adata is not None:
                vdj, adata = vdj_sample(
                    vdj_data=vdj,
                    size=sample,
                    adata=adata,
                    force_replace=force_replace,
                    random_state=random_state,
                )
            else:
                vdj = vdj_sample(
                    vdj_data=vdj,
                    size=sample,
                    force_replace=force_replace,
                    random_state=random_state,
                )
        # Restrict to rows whose locus is one of the listed IG/TR loci.
        dat = vdj[
            vdj.data.locus.is_in(
                ["IGH", "TRB", "TRD", "IGK", "IGL", "TRA", "TRG"]
            )
        ]
        if "ambiguous" in dat.data.collect_schema().names():
            # Convert FALSES to strings only (remove boolean False) for Polars compatibility
            falses_strings = [str(f) for f in FALSES if f is not False]
            falses_strings.append(
                "False"
            )  # Ensure both uppercase and lowercase are included
            # Keep only rows whose `ambiguous` value is one of the false-like strings.
            dat = dat[
                dat.data["ambiguous"].cast(pl.String).is_in(falses_strings)
            ]
        dat_seq = dat._split(key_, explode=True)
        # Strip the "<key>_" prefix from the columns produced by the split.
        dat_seq = dat_seq.rename(
            {
                col: re.sub(f"^{key_}_", "", col)
                for col in dat_seq.collect_schema().names()
                if col.startswith(f"{key_}_")
            }
        )
        # Align dat_seq to vdj._metadata order to ensure indices match
        # pre-computed distances from find_clones (which are indexed by metadata position)
        meta_cell_ids = vdj._metadata.select("cell_id")
        if isinstance(meta_cell_ids, pl.LazyFrame):
            meta_cell_ids = meta_cell_ids.collect(engine="streaming")
        # Left join to preserve metadata order, missing cells get null sequences
        dat_seq = meta_cell_ids.join(dat_seq, on="cell_id", how="left")
        # Build a position lookup table (cell_id → row index)
        # Positions match metadata indices
        pos_map = (
            dat_seq.with_row_index("_row_pos")
            if isinstance(dat_seq, pl.DataFrame)
            else dat_seq.lazy()
            .with_row_index("_row_pos")
            .collect(engine="streaming")
        ).select(["cell_id", pl.col("_row_pos").cast(pl.Int64).alias("pos")])
        # Plain dict for fast cell_id -> row position lookups below.
        cell_to_pos = dict(
            zip(pos_map["cell_id"].to_list(), pos_map["pos"].to_list())
        )
        if compute_graph or compute_layout or distance_mode == "clone":
            # Get clone membership as DataFrame and merge overlapping clones
            clone_df = dat._merge(clone_key, unique=True)
            membership = _merge_overlapping_clones(clone_df, clone_key)
        # Check if pre-computed distances are available
        if (
            use_existing_distance
            and hasattr(vdj, "distances")
            and vdj.distances is not None
        ):
            logg.info("Using pre-computed distances from .distances\n")
            total_dist = vdj.distances
            if isinstance(total_dist, np.ndarray):
                total_dist = csr_matrix(total_dist)
            # also force lazy=F if pre-computed distance is a csr_matrix
            if isinstance(total_dist, csr_matrix):
                lazy = False
        else:
            # compute total_dist using chosen mode (original uses membership)
            logg.info(
                f"Calculating distance matrix {'lazily ' if lazy else ' '}with distance_mode = '{distance_mode}'\n"
            )
            if distance_mode == "clone":
                if lazy:
                    from dandelion.polars.tools._lazydistances import (
                        calculate_distance_matrix_zarr,
                    )

                    # Determine Zarr destination and mark embedding intent
                    if zarr_path is None:
                        import tempfile

                        zarr_tmp = tempfile.mkdtemp()
                        # Flags on object to indicate pending embed on write
                        # (best effort: failures to set them are ignored)
                        try:
                            setattr(vdj, "_distance_zarr_path", zarr_tmp)
                            setattr(vdj, "_distance_embed_pending", True)
                        except Exception:
                            pass
                    else:
                        # External Zarr mode
                        try:
                            setattr(vdj, "_distance_zarr_path", str(zarr_path))
                            setattr(vdj, "_distance_embed_pending", False)
                        except Exception:
                            pass
                    # Return value is discarded; the matrix is re-opened from
                    # the Zarr location right after.
                    _ = calculate_distance_matrix_zarr(
                        dat_seq,
                        metric=metric,
                        pad_to_max=pad_to_max,
                        membership=membership,
                        zarr_path=(
                            zarr_tmp if zarr_path is None else zarr_path
                        ),
                        chunk_size=chunk_size,
                        n_cpus=n_cpus,
                        memory_limit_gb=memory_limit_gb,
                        memory_safety_fraction=memory_safety_fraction,
                        compress=compress,
                        verbose=verbose,
                    )
                    zpath = zarr_tmp if zarr_path is None else zarr_path
                    total_dist = _load_lazy_distance_from_zarr(zpath)
                else:
                    if sequential_chain:
                        total_dist = calculate_distance_matrix_original(
                            dat_seq,
                            membership,
                            metric=metric,
                            pad_to_max=pad_to_max,
                            verbose=verbose,
                        )
                    else:
                        total_dist = calculate_distance_matrix_long(
                            dat_seq,
                            membership=membership,
                            metric=metric,
                            pad_to_max=pad_to_max,
                            n_cpus=n_cpus,
                            verbose=verbose,
                        )
            elif distance_mode == "full":
                if lazy:
                    from dandelion.polars.tools._lazydistances import (
                        calculate_distance_matrix_zarr,
                    )

                    # Determine Zarr destination and mark embedding intent
                    if zarr_path is None:
                        import tempfile

                        zarr_tmp = tempfile.mkdtemp()
                        try:
                            setattr(vdj, "_distance_zarr_path", zarr_tmp)
                            setattr(vdj, "_distance_embed_pending", True)
                        except Exception:
                            pass
                    else:
                        try:
                            setattr(vdj, "_distance_zarr_path", str(zarr_path))
                            setattr(vdj, "_distance_embed_pending", False)
                        except Exception:
                            pass
                    # Same call as in "clone" mode, but without a membership table.
                    _ = calculate_distance_matrix_zarr(
                        dat_seq,
                        metric=metric,
                        pad_to_max=pad_to_max,
                        membership=None,
                        zarr_path=(
                            zarr_tmp if zarr_path is None else zarr_path
                        ),
                        chunk_size=chunk_size,
                        n_cpus=n_cpus,
                        memory_limit_gb=memory_limit_gb,
                        memory_safety_fraction=memory_safety_fraction,
                        compress=compress,
                        verbose=verbose,
                    )
                    zpath = zarr_tmp if zarr_path is None else zarr_path
                    total_dist = _load_lazy_distance_from_zarr(zpath)
                else:
                    if sequential_chain:
                        total_dist = calculate_distance_matrix_original_full(
                            dat_seq,
                            metric=metric,
                            pad_to_max=pad_to_max,
                            n_cpus=n_cpus,
                            verbose=verbose,
                        )
                    else:
                        total_dist = calculate_distance_matrix_long(
                            dat_seq,
                            membership=None,
                            metric=metric,
                            pad_to_max=pad_to_max,
                            n_cpus=n_cpus,
                            verbose=verbose,
                        )
        if compute_graph:
            # ===================================================================
            # NORMALIZE METADATA
            # ===================================================================
            if isinstance(vdj._metadata, pl.LazyFrame):
                meta_df = vdj._metadata.collect(engine="streaming")
            elif isinstance(vdj._metadata, pl.DataFrame):
                meta_df = vdj._metadata
            else:
                # Anything else is handed to pl.from_pandas.
                meta_df = pl.from_pandas(vdj._metadata)
            # ===================================================================
            # SPLIT CLONE KEY AND DETECT OVERLAPS
            # ===================================================================
            meta_with_order = meta_df.with_row_index("_cell_order")
            meta_df_split = meta_with_order.with_columns(
                pl.col(str(clone_key)).str.split("|").alias("_clone_list")
            )
            # Explode to one row per clone_id, join positions
            meta_exploded = (
                meta_df_split.select(["cell_id", "_clone_list"])
                .explode("_clone_list")
                .filter(pl.col("_clone_list") != "None")
                .rename({"_clone_list": clone_key})
                .join(pos_map, on="cell_id", how="left")
            )
            # ===================================================================
            # BUILD OVERLAP GROUPS
            # Each overlap group = sorted, pipe-joined string of all clone ids
            # that co-occur in at least one cell. This matches the original's
            # `"|".join(ol)` key built from `overlap` lists.
            # ===================================================================
            # Group labels for zero-distance handling.
            # Only cells carrying more than one (non-"None") clone id are kept.
            overlap_cells = (
                meta_df_split.with_columns(
                    pl.col("_clone_list")
                    .list.eval(pl.element().filter(pl.element() != "None"))
                    .alias("_overlap_group")
                )
                .filter(pl.col("_overlap_group").list.len() > 1)
                .select(
                    [
                        "_cell_order",
                        pl.col("_overlap_group")
                        .list.join("|")
                        .alias("group_key"),
                        "_overlap_group",
                    ]
                )
            )
            # Map clone_id -> canonical overlap group key
            # (a clone may appear in multiple cells; we need the UNION of all
            # cells that belong to every clone in the group)
            clone_to_group = (
                overlap_cells.explode("_overlap_group")
                .rename({"_overlap_group": clone_key})
                .select([clone_key, "group_key"])
                .unique()
            )
            # Attach group_key to every exploded row
            meta_exploded = meta_exploded.join(
                clone_to_group, on=clone_key, how="left"
            )
            # For non-overlap clones group_key is null; fill with clone_id
            meta_exploded = meta_exploded.with_columns(
                pl.when(pl.col("group_key").is_null())
                .then(pl.col(clone_key))
                .otherwise(pl.col("group_key"))
                .alias("group_key")
            )
            # ===================================================================
            # MST construction
            # ===================================================================
            # One row per (cell, clone) pair, numbered in order of appearance.
            clone_rows = (
                meta_df_split.select(["_cell_order", "cell_id", "_clone_list"])
                .explode("_clone_list")
                .filter(pl.col("_clone_list") != "None")
                .rename({"_clone_list": clone_key})
                .with_row_index("_clone_seen_order")
            )
            # Per clone: its member cells and the row number where it first appears.
            clone_members = clone_rows.group_by(
                clone_key, maintain_order=True
            ).agg(
                [
                    pl.col("cell_id").alias("cell_ids"),
                    pl.col("_clone_seen_order")
                    .min()
                    .alias("_first_clone_seen_order"),
                ]
            )
            overlap_group_clone_map = (
                overlap_cells.select(["group_key", "_overlap_group"])
                .explode("_overlap_group")
                .rename({"_overlap_group": clone_key})
            )
            overlap_clone_ids = overlap_group_clone_map.select(
                clone_key
            ).unique()
            overlap_candidates = overlap_group_clone_map.join(
                clone_members, on=clone_key, how="left"
            )
            # Union of all cells belonging to any clone of each overlap group.
            overlap_union_cells = (
                overlap_candidates.select(["group_key", "cell_ids"])
                .explode("cell_ids")
                .group_by("group_key", maintain_order=True)
                .agg(pl.col("cell_ids").unique().alias("_union_cell_ids"))
            )
            # Within each group, the clone with the largest first-seen row
            # number is the "winner" whose cells are used for the MST.
            overlap_winner_rank = overlap_candidates.group_by(
                "group_key", maintain_order=True
            ).agg(pl.col("_first_clone_seen_order").max().alias("_winner_rank"))
            overlap_group_summary = (
                overlap_union_cells.join(
                    overlap_winner_rank, on="group_key", how="inner"
                )
                .with_columns(
                    pl.col("_union_cell_ids").list.len().alias("_union_size")
                )
                .filter(pl.col("_union_size") > 1)
            )
            overlap_mst_groups = (
                overlap_candidates.join(
                    overlap_group_summary.select(["group_key", "_winner_rank"]),
                    on="group_key",
                    how="inner",
                )
                .filter(
                    pl.col("_first_clone_seen_order") == pl.col("_winner_rank")
                )
                .group_by("group_key", maintain_order=True)
                .agg(
                    [
                        pl.first("cell_ids").alias("cell_ids"),
                        pl.first("_winner_rank").alias("_group_order"),
                    ]
                )
            )
            # Clones that never take part in an overlap and have more than one cell.
            non_overlap_mst_groups = (
                clone_members.join(overlap_clone_ids, on=clone_key, how="anti")
                .filter(pl.col("cell_ids").list.len() > 1)
                .select(
                    [
                        pl.col(clone_key).alias("group_key"),
                        pl.col("cell_ids"),
                        pl.col("_first_clone_seen_order").alias("_group_order"),
                    ]
                )
            )
            mst_groups_df = (
                pl.concat(
                    [overlap_mst_groups, non_overlap_mst_groups],
                    how="vertical_relaxed",
                )
                .sort("_group_order")
                .select(["group_key", "cell_ids"])
            )
            # ===================================================================
            # ZERO-DISTANCE GROUPS
            # Same union logic; kept separate because the edge-filtering differs.
            # ===================================================================
            zero_groups_df = (
                meta_exploded.select(["group_key", "cell_id", "pos"])
                .unique(subset=["group_key", "cell_id"])
                .group_by("group_key")
                .agg(
                    [
                        pl.col("cell_id").alias("cell_ids"),
                        pl.col("pos").alias("positions"),
                    ]
                )
                .filter(pl.col("positions").list.len() >= 2)
            )
            # ===================================================================
            # MST COMPUTATION
            # ===================================================================
            mst_edge_dict = {}
            for row in mst_groups_df.iter_rows(named=True):
                ids = row["cell_ids"]
                pos = [cell_to_pos[i] for i in ids]
                edges = _create_mst_edges(
                    total_dist=total_dist,
                    positions=pos,
                    cell_ids=ids,
                    lazy=lazy,
                )
                if edges is not None:
                    mst_edge_dict[row["group_key"]] = edges
            # ===================================================================
            # ZERO-DISTANCE EDGES
            # ===================================================================
            zero_edge_dict = {}
            for row in zero_groups_df.iter_rows(named=True):
                edges = _find_zero_dist_edges(
                    total_dist=total_dist,
                    positions=row["positions"],
                    cell_ids=row["cell_ids"],
                    lazy=lazy,
                )
                if edges is not None:
                    zero_edge_dict[row["group_key"]] = edges
            # ===================================================================
            # MERGE MST AND ZERO-DISTANCE EDGES
            #
            # Index MUST be the canonical "min_cell_id|max_cell_id" string so
            # combine_first aligns on edge identity, not row position.
            # This matches the original's set_edge_list_index / _add_sorted_index
            # behaviour.
            # ===================================================================
            # NOTE(review): any exception raised in this section is swallowed
            # and results in `edge_list_final = None` being passed to
            # `generate_layout` — confirm this best-effort fallback is intended.
            try:
                if mst_edge_dict:
                    edge_listx = _make_canonical_index(
                        pd.concat(
                            list(mst_edge_dict.values()), ignore_index=True
                        )
                    )
                else:
                    edge_listx = pd.DataFrame(
                        columns=["source", "target", "weight"]
                    )
                if zero_edge_dict:
                    tmp_edge_listx = _make_canonical_index(
                        pd.concat(
                            list(zero_edge_dict.values()), ignore_index=True
                        )
                    )
                    tmp_edge_listx = tmp_edge_listx[
                        tmp_edge_listx["weight"] == 0
                    ]
                else:
                    tmp_edge_listx = pd.DataFrame(
                        columns=["source", "target", "weight"]
                    )
                # MST edges take priority (combine_first fills NaN from tmp)
                edge_list_final = edge_listx.combine_first(tmp_edge_listx)
                # ---------------------------------------------------------------
                # WEIGHT LOOKUP
                # Re-read actual distances from total_dist for every edge,
                # using the cell_id -> position dict built earlier.
                # ---------------------------------------------------------------
                if len(edge_list_final) > 0:
                    # cell_to_pos holds a single row position per cell_id
                    sources = edge_list_final["source"].tolist()
                    targets = edge_list_final["target"].tolist()
                    src_pos = np.array([cell_to_pos[s] for s in sources])
                    tgt_pos = np.array([cell_to_pos[t] for t in targets])
                    if lazy:
                        # Element-by-element lookups stacked into one array,
                        # then computed in a single call.
                        actual_weights = da.stack(
                            [
                                total_dist[src_pos[i], tgt_pos[i]]
                                for i in range(len(src_pos))
                            ]
                        ).compute()
                    else:
                        actual_weights = total_dist[src_pos, tgt_pos]
                        # Flatten results that expose `.A1` (e.g. np.matrix
                        # returned by sparse fancy indexing) to 1-D.
                        if hasattr(actual_weights, "A1"):
                            actual_weights = actual_weights.A1
                    # Apply weight offset correction matching original's
                    # csgraph_from_dense(total_dist + 1) / weight - 1 pattern:
                    # _create_mst_edges should already return corrected weights,
                    # but if total_dist itself carries the +1 offset uncomment:
                    # actual_weights = actual_weights - 1
                    # np.clip(actual_weights, 0, None, out=actual_weights)
                    # Only overwrite MST edge weights; zero-distance edges must
                    # stay at 0 (matching original's weight == 0 filter).
                    is_zero_edge = edge_list_final.index.isin(
                        tmp_edge_listx.index
                    )
                    edge_list_final["weight"] = actual_weights
                    edge_list_final.loc[is_zero_edge, "weight"] = 0
                # The canonical string index is no longer needed downstream.
                edge_list_final = edge_list_final.reset_index(drop=True)
            except Exception:
                edge_list_final = None
            # ===================================================================
            # FINAL LAYOUT + GRAPH CREATION
            # ===================================================================
            g, g_, lyt, lyt_ = generate_layout(
                vertices=meta_df["cell_id"].to_list(),
                edges=edge_list_final,
                min_size=min_size,
                weight=None,
                verbose=verbose,
                compute_layout=compute_layout,
                layout_method=layout_method,
                expanded_only=expanded_only,
                singleton_mass=singleton_mass,
                **kwargs,
            )
    # The summary line names the most downstream product that was computed.
    logg.info(
        " finished.\n Updated Dandelion object\n",
        time=start,
        deep=(
            " 'layout', graph layout\n"
            if compute_layout
            else (
                ""
                " 'graph', network constructed from distance matrices of VDJ- and VJ- chains\n"
                if compute_graph
                else (
                    "" " 'distances', VDJ + VJ distance matrix\n"
                    if regenerate
                    else ""
                )
            )
        ),
    )
    # return or re-initialize vdj
    germline = getattr(vdj, "germline", None)
    if regenerate:
        if lazy:
            # Keep the lazily loaded array as-is.
            distances = total_dist
        elif isinstance(total_dist, csr_matrix):
            distances = total_dist
        else:
            distances = csr_matrix(total_dist)
        if not lazy:
            distances._index_names = vdj.metadata_names
    else:
        distances = vdj.distances
    graph = (g, g_) if compute_graph else None
    layout = (lyt, lyt_) if compute_graph and compute_layout else None
    if sample is not None:
        # Sampling produced a new object, so build and return a fresh instance.
        # NOTE(review): `.collect()` is called unconditionally here, unlike the
        # LazyFrame-guarded calls in the branch below — confirm `vdj_sample`
        # always yields LazyFrames.
        out = DandelionPolars(
            data=vdj._data.collect(),
            metadata=vdj._metadata.collect(),
            clone_key=clone_key,
            layout=layout,
            graph=graph,
            distances=distances,
            germline=germline,
            verbose=False,
        )
        if adata is None:
            return out
        else:
            return out, adata
    else:
        # No sampling: update the object that was passed in.
        vdj._reinitialize_attributes(
            data=(
                vdj._data.collect()
                if isinstance(vdj._data, pl.LazyFrame)
                else vdj._data
            ),
            metadata=(
                vdj._metadata.collect()
                if isinstance(vdj._metadata, pl.LazyFrame)
                else vdj._metadata
            ),
            clone_key=clone_key,
            layout=layout,
            graph=graph,
            distances=distances,
            germline=germline,
            reinitialize=False,
        )
def _get_positions_for_group(
    cells_list: list, meta_exploded: pl.DataFrame
) -> list[int]:
    """
    Look up the matrix positions of a set of cells.

    Parameters
    ----------
    cells_list : list
        Cell IDs to look up.
    meta_exploded : pl.DataFrame
        Exploded metadata with at least 'cell_id' and 'pos' columns.

    Returns
    -------
    list[int]
        Distinct 'pos' values of the cells found in ``meta_exploded``.
        Cells without a match are dropped (inner join).
    """
    lookup = meta_exploded.select(["cell_id", "pos"])
    wanted = pl.DataFrame({"cell_id": cells_list})
    matched = wanted.join(lookup, on="cell_id", how="inner")
    # A cell can occur on several exploded rows, hence the de-duplication.
    distinct_pos = matched.select("pos").unique()
    return distinct_pos.to_series().to_list()
def _create_mst_edges(
    total_dist: np.ndarray,
    positions: list[int],
    cell_ids: list[str],
    lazy: bool = False,
) -> pd.DataFrame | None:
    """
    Build minimum-spanning-tree edges for one group of cells.

    Parameters
    ----------
    total_dist : np.ndarray
        Pairwise distance matrix. Results exposing ``toarray`` are densified
        after slicing; with ``lazy=True`` the slice is taken through
        ``dask_safe_slice_square`` and computed.
    positions : list[int]
        Row/column positions of the group's cells in ``total_dist``.
    cell_ids : list[str]
        Cell IDs, in the same order as ``positions``.
    lazy : bool, optional
        Whether ``total_dist`` must be sliced via the lazy helper.

    Returns
    -------
    pd.DataFrame | None
        Columns 'source', 'target', 'weight' with one row per tree edge, or
        None when fewer than two cells are given or no usable edge exists.
    """
    if len(positions) < 2:
        return None

    if lazy:
        from dandelion.polars.tools._lazydistances import dask_safe_slice_square

        sub = dask_safe_slice_square(total_dist, positions).compute()
    else:
        sub = total_dist[np.ix_(positions, positions)]
    if hasattr(sub, "toarray"):
        sub = sub.toarray()

    n_rows, n_cols = sub.shape[0], sub.shape[1]
    if n_rows < 2 or n_cols < 2:
        return None

    # Shift by +1 so genuine zero distances survive as explicit entries in
    # the sparse graph; NaN pairs become 0 and are dropped as edges below.
    offset = sub.astype(float) + 1.0
    offset[np.isnan(offset)] = 0.0
    size = offset.shape[0]

    upper_r, upper_c = np.triu_indices(size, k=1)
    pair_w = offset[upper_r, upper_c]
    usable = pair_w > 0
    upper_r = upper_r[usable]
    upper_c = upper_c[usable]
    pair_w = pair_w[usable]
    if pair_w.size == 0:
        return None

    # Deterministic tie-breaking: a tiny increment that grows with the
    # pair's position in the upper-triangle enumeration.
    jitter = 1e-9 * np.arange(pair_w.size, dtype=float)
    sym_w = np.tile(pair_w + jitter, 2)
    sym_r = np.concatenate([upper_r, upper_c])
    sym_c = np.concatenate([upper_c, upper_r])
    graph = csr_matrix((sym_w, (sym_r, sym_c)), shape=(size, size))

    tree = scipy_mst(graph).tocoo()
    if tree.nnz == 0:
        return None

    # Read weights back from the unperturbed matrix and undo the +1 shift.
    edge_w = np.maximum(offset[tree.row, tree.col] - 1.0, 0.0)
    sources = [cell_ids[r] for r in tree.row]
    targets = [cell_ids[c] for c in tree.col]
    return pd.DataFrame(
        {"source": sources, "target": targets, "weight": edge_w}
    )
def _find_zero_dist_edges(
    total_dist: np.ndarray,
    positions: list[int],
    cell_ids: list[str],
    lazy: bool = False,
) -> pd.DataFrame | None:
    """
    List every pair of cells in a group whose distance is exactly zero.

    Parameters
    ----------
    total_dist : np.ndarray
        Pairwise distance matrix. Results exposing ``toarray`` are densified
        after slicing; with ``lazy=True`` the slice is taken through
        ``dask_safe_slice_square`` and computed.
    positions : list[int]
        Row/column positions of the group's cells in ``total_dist``.
    cell_ids : list[str]
        Cell IDs, in the same order as ``positions``.
    lazy : bool, optional
        Whether ``total_dist`` must be sliced via the lazy helper.

    Returns
    -------
    pd.DataFrame | None
        Columns 'source', 'target', 'weight' (always 0.0), one row per
        zero-distance pair taken from the strict lower triangle, or None
        when there are fewer than two cells or no such pair.
    """
    if len(positions) < 2:
        return None

    if lazy:
        from dandelion.polars.tools._lazydistances import dask_safe_slice_square

        sub = dask_safe_slice_square(total_dist, positions).compute()
    else:
        sub = total_dist[np.ix_(positions, positions)]
    if hasattr(sub, "toarray"):
        sub = sub.toarray()

    # Strict lower triangle: each unordered pair once, diagonal excluded.
    lower_r, lower_c = np.tril_indices(len(cell_ids), k=-1)
    is_zero = sub[lower_r, lower_c] == 0
    if not is_zero.any():
        return None

    sources = [cell_ids[r] for r in lower_r[is_zero]]
    targets = [cell_ids[c] for c in lower_c[is_zero]]
    return pd.DataFrame(
        {"source": sources, "target": targets, "weight": 0.0}
    )
def _make_canonical_index(df: pd.DataFrame) -> pd.DataFrame:
    """
    Index an edge table by its order-independent endpoint pair.

    Parameters
    ----------
    df : pd.DataFrame
        Edge table with 'source' and 'target' columns.

    Returns
    -------
    pd.DataFrame
        A copy of ``df`` whose index is "a|b" with the two endpoints sorted,
        so the same edge gets the same label regardless of direction. An
        empty ``df`` is returned unchanged (same object, no copy).
    """
    if df.empty:
        return df
    labelled = df.copy()
    labelled.index = [
        "|".join(sorted(pair)) for pair in zip(df["source"], df["target"])
    ]
    return labelled
def _add_sorted_index(df: pd.DataFrame) -> pd.DataFrame:
    """
    Label each edge with its sorted endpoint pair.

    The two endpoints of every row are sorted and joined as "a|b", which
    gives both directions of an edge the same index label.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame with 'source' and 'target' columns.

    Returns
    -------
    pd.DataFrame
        Copy of ``df`` indexed by the sorted pair label; ``df`` itself is
        left untouched.
    """
    endpoints = df[["source", "target"]].to_numpy()
    # Row-wise sort puts the smaller endpoint first.
    ordered = np.sort(endpoints, axis=1)
    labelled = df.copy()
    labelled.index = [f"{low}|{high}" for low, high in ordered]
    return labelled
[docs]
def clone_degree(
    vdj: DandelionPolars, weight: str | None = None
) -> None:
    """
    Calculate node degree in BCR/TCR network.

    Parameters
    ----------
    vdj : DandelionPolars
        DandelionPolars object after `tl.generate_network` has been run.
    weight : str | None, optional
        Attribute name for retrieving edge weight in graph. None defaults to ignoring this. See `networkx.Graph.degree`.

    Returns
    -------
    None
        Modifies ``vdj._metadata`` in place, adding a ``clone_degree`` column.
        An existing ``clone_degree`` column is replaced.

    Raises
    ------
    AttributeError
        if graph not found.
    TypeError
        if input is not DandelionPolars class.
    """
    if not isinstance(vdj, DandelionPolars):
        raise TypeError("Input object must be of {}".format(DandelionPolars))
    if vdj.graph is None:
        raise AttributeError(
            "Graph not found. Please run tl.generate_network."
        )

    # Degrees are taken from the first graph stored on the object.
    degree_dict = dict(vdj.graph[0].degree(weight=weight))
    df = pl.DataFrame(
        {
            "cell_id": list(degree_dict.keys()),
            "clone_degree": list(degree_dict.values()),
        }
    )

    meta = vdj._metadata.lazy()
    # Re-running the function must overwrite the previous result. Without
    # this, the join keeps the stale column and appends the fresh values
    # under a suffixed name ("clone_degree_right").
    if "clone_degree" in meta.collect_schema().names():
        meta = meta.drop("clone_degree")

    # The row index + sort restores the original metadata order after the join.
    vdj._metadata = (
        meta.with_row_index("_orig_idx")
        .join(df.lazy(), on="cell_id", how="left")
        .sort("_orig_idx")
        .drop("_orig_idx")
        .collect(engine="streaming")
    )
[docs]
def clone_centrality(vdj: DandelionPolars) -> None:
    """
    Calculate node closeness centrality in BCR/TCR network.

    Parameters
    ----------
    vdj : DandelionPolars
        DandelionPolars object after `tl.generate_network` has been run.

    Returns
    -------
    None
        Modifies ``vdj._metadata`` in place, adding a ``clone_centrality``
        column. A ``clone_centrality`` column left over from a previous call
        is replaced.

    Raises
    ------
    AttributeError
        if graph not found.
    TypeError
        if input is not DandelionPolars class.
    """
    if not isinstance(vdj, DandelionPolars):
        raise TypeError("Input object must be of {}".format(DandelionPolars))
    if vdj.graph is None:
        raise AttributeError(
            "Graph not found. Please run tl.generate_network."
        )

    # The first element of ``vdj.graph`` is the graph centrality is computed on.
    G = vdj.graph[0]
    cc = nx.closeness_centrality(G)
    df = pl.DataFrame(
        {
            "cell_id": list(cc.keys()),
            "clone_centrality": list(cc.values()),
        }
    )
    # Drop any previous result before joining: joining onto a frame that
    # already has ``clone_centrality`` would keep the stale values and add
    # the new ones under a suffixed name instead.
    vdj._metadata = (
        vdj._metadata.lazy()
        .drop("clone_centrality", strict=False)
        # Temporary row index so the original row order can be restored
        # after the join.
        .with_row_index("_orig_idx")
        .join(df.lazy(), on="cell_id", how="left")
        .sort("_orig_idx")
        .drop("_orig_idx")
        .collect(engine="streaming")
    )
def calculate_distance_matrix_original(
    dat_seq: pl.DataFrame,
    membership: pl.DataFrame,
    metric: Metric,
    pad_to_max: bool = False,
    verbose: bool = True,
) -> csr_matrix:
    """
    Re-implementation of original membership-based distance calculation.

    Distances are computed per sequence column and only between cells that
    share a ``membership_id``; the per-column matrices are then summed.

    Parameters
    ----------
    dat_seq : pl.DataFrame
        Polars DataFrame with sequence columns and 'cell_id' column.
    membership : pl.DataFrame
        DataFrame with 'cell_id' and 'membership_id' columns mapping cells to clone groups.
    metric : Metric
        Distance metric to use.
    pad_to_max : bool, optional
        Whether to pad sequences to maximum length before distance calculation.
    verbose : bool, optional
        Whether to show progress.

    Returns
    -------
    total_dist : csr_matrix
        Sparse distance matrix; diagonal is 0 (self-distance).
    """
    # Ensure dat_seq is a DataFrame (not LazyFrame)
    if isinstance(dat_seq, pl.LazyFrame):
        dat_seq = dat_seq.collect(engine="streaming")
    n = dat_seq.height
    cell_id_list = dat_seq["cell_id"].to_list()
    dmat_per_column = defaultdict(list)

    # Join with membership and partition by membership_id
    dat_seq_with_membership = dat_seq.join(
        membership, on="cell_id", how="inner"
    )
    # Every partition shares the joined schema, so the sequence columns and
    # the separator are resolved once instead of once per group.
    seq_cols = [
        col
        for col in dat_seq_with_membership.collect_schema().names()
        if col not in ("cell_id", "membership_id")
    ]
    sep = "#" if pad_to_max else ""
    groups = dat_seq_with_membership.partition_by("membership_id", as_dict=True)

    for group_df in tqdm(
        groups.values(),
        disable=not verbose,
        bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}",
    ):
        # A single-cell group has no pairwise distances to contribute.
        if group_df.height <= 1:
            continue
        tmp_cell_ids = group_df["cell_id"].to_list()
        for col in seq_cols:
            # Strip "." characters and turn nulls / literal "None" into "".
            seq_series = (
                group_df[col]
                .cast(pl.String)
                .str.replace_all(r"\.", "")
                .fill_null("")
                .str.replace_all("None", "")
            )
            seqs_raw = [[s] for s in seq_series.to_list()]
            prepared_seqs = prepare_sequences_with_separator(
                seqs_raw,
                metric=metric,
                pad_to_max=pad_to_max,
                sep=sep,
            )
            # Deduplicate sequences for faster computation
            unique_seqs = list(set(prepared_seqs))
            seq_to_unique_idx = {
                seq: i for i, seq in enumerate(unique_seqs)
            }
            # Compute distances only for unique sequences
            d_mat_unique = metric.compute_vectorized(unique_seqs)
            # Use vectorized indexing to expand to full matrix
            unique_indices = np.array(
                [seq_to_unique_idx[seq] for seq in prepared_seqs]
            )
            d_mat_tmp = d_mat_unique[np.ix_(unique_indices, unique_indices)]
            df_block = pd.DataFrame(
                d_mat_tmp, index=tmp_cell_ids, columns=tmp_cell_ids
            )
            dmat_per_column[col].append(df_block)

    dist_matrices = []
    for col, blocks in dmat_per_column.items():
        if not blocks:
            continue
        full = pd.concat(blocks)
        if full.index.has_duplicates:
            # A cell present in several blocks yields several rows with the
            # same label. Sum each label exactly once; the sum skips the NaNs
            # that concat introduced for columns absent from a block.
            full = full.groupby(level=0).sum()
        full = full.reindex(index=cell_id_list, columns=cell_id_list).fillna(
            0.0
        )
        dist_matrices.append(full.values)

    if len(dist_matrices) == 0:
        return csr_matrix((n, n))
    total_dist = np.sum(dist_matrices, axis=0)
    np.fill_diagonal(total_dist, 0.0)
    return csr_matrix(total_dist)
def calculate_distance_matrix_original_full(
    dat_seq: pl.DataFrame,
    metric: Metric,
    pad_to_max: bool = False,
    n_cpus: int = 1,
    verbose: bool = True,
) -> csr_matrix:
    """
    Full pairwise variant of the original per-column distance calculation.

    No membership grouping is applied: for every sequence column the
    distances between all rows are computed and the per-column matrices are
    summed.

    Parameters
    ----------
    dat_seq : pl.DataFrame
        Polars DataFrame with sequence columns and 'cell_id' column.
    metric : Metric
        Distance metric to use.
    pad_to_max : bool, optional
        Whether to pad sequences to maximum length before distance calculation.
    n_cpus : int, optional
        Number of cores to run this step; forwarded to
        ``metric.compute_vectorized``.
    verbose : bool, optional
        Whether to log the elapsed time.

    Returns
    -------
    total_dist : csr_matrix
        Sparse distance matrix; diagonal is 0 (self-distance).
    """
    start_time = time.time()
    # Ensure dat_seq is a DataFrame (not LazyFrame), as the sibling distance
    # functions do; ``height`` and column indexing below need a DataFrame.
    if isinstance(dat_seq, pl.LazyFrame):
        dat_seq = dat_seq.collect(engine="streaming")
    n = dat_seq.height
    total_dist = np.zeros((n, n), dtype=float)
    seq_cols = [
        col for col in dat_seq.collect_schema().names() if col != "cell_id"
    ]
    sep = "#" if pad_to_max else ""

    for col in seq_cols:
        # Strip "." characters and turn nulls / literal "None" into "".
        seq_series = (
            dat_seq[col]
            .cast(pl.String)
            .str.replace_all(r"\.", "")
            .fill_null("")
            .str.replace_all("None", "")
        )
        # NOTE(review): nulls were already replaced by "" above, so this only
        # skips the column when there is at most one row. If the intent was
        # to skip columns with fewer than two non-empty sequences, confirm
        # against the pandas implementation before changing it.
        nonnull = seq_series.drop_nulls()
        if nonnull.len() <= 1:
            continue
        # Prepare sequences for single column (list of single-element lists)
        seqs_raw = [[s] for s in seq_series.to_list()]
        prepared_seqs = prepare_sequences_with_separator(
            seqs_raw,
            metric=metric,
            pad_to_max=pad_to_max,
            sep=sep,
        )
        # Deduplicate sequences for faster computation
        unique_seqs = list(set(prepared_seqs))
        seq_to_unique_idx = {seq: i for i, seq in enumerate(unique_seqs)}
        # Compute distances only for unique sequences
        results_unique = metric.compute_vectorized(unique_seqs, n_cpus=n_cpus)
        # Use vectorized indexing to expand to full matrix
        unique_indices = np.array(
            [seq_to_unique_idx[seq] for seq in prepared_seqs]
        )
        results = results_unique[np.ix_(unique_indices, unique_indices)]
        total_dist += results

    np.fill_diagonal(total_dist, 0.0)
    if verbose:
        end_time = time.time()
        logg.info(
            f"Distances calculated in {end_time - start_time:.2f} seconds"
        )
    return csr_matrix(total_dist)
def calculate_distance_matrix_long(
    dat_seq: pl.DataFrame,
    membership: pl.DataFrame | None,
    metric: Metric,
    pad_to_max: bool = False,
    n_cpus: int = 1,
    verbose: bool = True,
) -> csr_matrix:
    """
    Distance calculation on concatenated sequences joined by a separator.

    All sequence columns of a cell are cleaned and combined into one prepared
    sequence (separator ``"#"``). Distances are then computed either between
    every pair of cells, or only between cells sharing a ``membership_id``.

    Parameters
    ----------
    dat_seq : pl.DataFrame
        Polars DataFrame with sequence columns and 'cell_id' column.
    membership : pl.DataFrame | None
        DataFrame with 'cell_id' and 'membership_id' columns mapping cells to
        clone groups. None indicates full pairwise distance calculation.
    metric : Metric
        Distance metric to use.
    pad_to_max : bool, optional
        Whether to pad sequences before distance calculation; forwarded to
        ``prepare_sequences_with_separator``.
    n_cpus : int, optional
        Number of cores to run this step; forwarded to
        ``metric.compute_vectorized``.
    verbose : bool, optional
        Whether to show progress.

    Returns
    -------
    total_dist : csr_matrix (n x n)
        Sparse distance matrix; diagonal is 0 (self-distance).
    """
    t_begin = time.time()

    # Materialise a LazyFrame so that height / column access work below.
    if isinstance(dat_seq, pl.LazyFrame):
        dat_seq = dat_seq.collect(engine="streaming")

    sequence_columns = [
        name for name in dat_seq.collect_schema().names() if name != "cell_id"
    ]

    # Cleaning: strip "." characters, turn nulls / literal "None" into "".
    cleaning_exprs = [
        pl.col(name)
        .cast(pl.String)
        .str.replace_all(r"\.", "")
        .fill_null("")
        .str.replace_all("None", "")
        .alias(name)
        for name in sequence_columns
    ]
    cleaned = dat_seq.select([pl.col("cell_id"), *cleaning_exprs])

    # One prepared (concatenated) sequence per cell, built once up front.
    per_cell_parts = cleaned.select(sequence_columns).to_numpy().tolist()
    joined_seqs = prepare_sequences_with_separator(
        per_cell_parts,
        metric=metric,
        pad_to_max=pad_to_max,
        sep="#",
    )

    n_cells = cleaned.height
    position_of = {
        cid: pos for pos, cid in enumerate(cleaned["cell_id"].to_list())
    }

    if membership is None:
        # Full mode: dense all-vs-all matrix, converted to CSR afterwards.
        dense = metric.compute_vectorized(joined_seqs, n_cpus=n_cpus)
        np.fill_diagonal(dense, 0.0)
        distances = csr_matrix(dense)
    else:
        # Clone mode: gather COO triplets per group, never a dense N x N.
        row_parts: list[np.ndarray] = []
        col_parts: list[np.ndarray] = []
        val_parts: list[np.ndarray] = []

        by_clone = cleaned.join(
            membership, on="cell_id", how="inner"
        ).partition_by("membership_id", as_dict=True)

        for clone_df in tqdm(
            by_clone.values(),
            disable=not verbose,
            bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}",
        ):
            # A single-cell group has no pairs to compare.
            if clone_df.height <= 1:
                continue

            # Global matrix positions of this group's cells.
            positions = np.array(
                [position_of[cid] for cid in clone_df["cell_id"].to_list()]
            )
            member_seqs = [joined_seqs[pos] for pos in positions]

            # Compute distances on distinct sequences only, then expand back
            # to one row/column per group member.
            distinct_seqs = list(set(member_seqs))
            slot_of = {seq: slot for slot, seq in enumerate(distinct_seqs)}
            distinct_dist = metric.compute_vectorized(
                distinct_seqs, n_cpus=n_cpus
            )
            slots = np.array([slot_of[seq] for seq in member_seqs])
            expanded = distinct_dist[np.ix_(slots, slots)]

            # Row-major grids of global row / column positions; drop the
            # entries where both coincide (self-distance is 0).
            row_grid, col_grid = np.meshgrid(
                positions, positions, indexing="ij"
            )
            keep = row_grid != col_grid
            row_parts.append(row_grid[keep])
            col_parts.append(col_grid[keep])
            val_parts.append(expanded[keep])

        if row_parts:
            triplets = coo_matrix(
                (
                    np.concatenate(val_parts),
                    (np.concatenate(row_parts), np.concatenate(col_parts)),
                ),
                shape=(n_cells, n_cells),
            )
            distances = csr_matrix(triplets)
        else:
            distances = csr_matrix((n_cells, n_cells))

    if verbose:
        t_end = time.time()
        logg.info(f"Distances calculated in {t_end - t_begin:.2f} seconds")
    return distances