#!/usr/bin/env python
from __future__ import annotations
import copy
import h5py
import os
import re
import warnings
import networkx as nx
import numpy as np
import pandas as pd
from changeo.IO import readGermlines
from collections import defaultdict
from pandas.api.types import infer_dtype
from pathlib import Path
from scanpy import logging as logg
from scipy.sparse import csr_matrix
from textwrap import dedent
from tqdm import tqdm
from typing import Literal
from dandelion.utilities._utilities import (
RECEPTOR_SET,
TRUES,
FALSES,
EMPTIES_STR,
CHECK_COLS,
MUTATIONS,
VDJLENGTHS,
SEQINFO,
deprecated,
cmp_to_key,
present,
all_missing,
all_missing2,
same_call,
sanitize_data_for_saving,
sanitize_data,
format_isotype1,
format_isotype2,
format_locus,
lib_type,
movecol,
format_chain_status,
clear_h5file,
get_vcall_key,
Tree,
write_fasta,
)
from dandelion.external.anndata._compat import (
_normalize_index,
unpack_index,
Index,
)
class Dandelion:
"""Dandelion class object."""
def __init__(
    self,
    data: pd.DataFrame | Path | str | None = None,
    metadata: pd.DataFrame | None = None,
    germline: dict[str, str] | None = None,
    layout: tuple[dict[str, np.array], dict[str, np.array]] | None = None,
    graph: tuple[nx.Graph, nx.Graph] | None = None,
    distances: csr_matrix | None = None,
    initialize: bool = True,
    library_type: Literal["tr-ab", "tr-gd", "ig"] | None = None,
    verbose: bool = True,
    **kwargs,
) -> None:
    """
    Init method for Dandelion.

    Parameters
    ----------
    data : pd.DataFrame | Path | str | None, optional
        AIRR formatted data (a data frame, or anything accepted by `load_data`).
    metadata : pd.DataFrame | None, optional
        AIRR data collapsed per cell. When None and `initialize` is True it is
        built from `data` via `update_metadata`.
    germline : dict[str, str] | None, optional
        dictionary of germline gene:sequence records.
    layout : tuple[dict[str, np.array], dict[str, np.array]] | None, optional
        node positions for computed graph.
    graph : tuple[nx.Graph, nx.Graph] | None, optional
        networkx graphs for clonotype networks.
    distances : csr_matrix | None, optional
        distance matrix for sequences.
    initialize : bool, optional
        whether or not to initialize `.metadata` slot.
    library_type : Literal["tr-ab", "tr-gd", "ig"] | None, optional
        One of "tr-ab", "tr-gd", "ig". When given, only contigs whose `locus`
        is in `lib_type(library_type)` are kept.
    verbose : bool, optional
        whether or not to print initialization messages.
    **kwargs
        passed to `Dandelion.update_metadata`.
    """
    self._data = data
    self._metadata = metadata
    self.layout = layout
    self.graph = graph
    self.distances = distances
    self.germline = {}
    self.querier = None
    self.library_type = library_type
    # Route through the property setters so paths are loaded and the
    # frames/indices are validated.
    self.data = self._data
    self.metadata = self._metadata
    if germline is not None:
        self.germline.update(germline)
    if self.data is not None:
        acceptable = (
            lib_type(self.library_type)
            if self.library_type is not None
            else None
        )
        if acceptable is not None:
            self._data = self._data[
                self._data.locus.isin(acceptable)
            ].copy()
        try:
            self._data = check_travdv(self._data)
        except Exception:
            # best-effort step (check_travdv is defined elsewhere); keep the
            # data unchanged if it cannot be applied.
            pass
        if (
            pd.Series(["cell_id", "umi_count", "productive"])
            .isin(self._data.columns)
            .all()
        ):  # sort so that the productive contig with the largest umi is first
            self._data.sort_values(
                by=["cell_id", "productive", "umi_count"],
                inplace=True,
                ascending=[True, False, False],
            )
        # self._data = sanitize_data(self._data) # this is too slow. and unnecessary at this point.
        self.n_contigs = self._data.shape[0]
        if metadata is None:
            if initialize is True:
                self._ensure_sanitized_data(verbose=verbose)
                self.update_metadata(**kwargs)
            self.n_obs = (
                self._metadata.shape[0] if self._metadata is not None else 0
            )
        else:
            self._metadata = metadata
            self.n_obs = self._metadata.shape[0]
    else:
        self.n_contigs = 0
        self.n_obs = 0
    # Snapshot the IDs used by the add_*_prefix/suffix helpers. An object
    # created without data has nothing to snapshot (previously this crashed
    # by subscripting None).
    if self._data is not None:
        self._original_sequence_ids = self._data["sequence_id"].copy()
        self._original_cell_ids = self._data["cell_id"].copy()
    else:
        self._original_sequence_ids = None
        self._original_cell_ids = None
def _gen_repr(self, n_obs, n_contigs) -> str:
"""Report."""
# inspire by AnnData's function
descr = f"Dandelion class object with n_obs = {n_obs} and n_contigs = {n_contigs}"
for attr in ["data", "metadata"]:
try:
keys = getattr(self, attr).keys()
except AttributeError:
keys = []
if len(keys) > 0:
descr += f"\n {attr}: {str(list(keys))[1:-1]}"
if self.layout is not None:
descr += f"\n layout: {', '.join(['layout for '+ str(len(x)) + ' vertices' for x in (self.layout[0], self.layout[1]) if x is not None])}"
if self.graph is not None:
descr += f"\n graph: {', '.join(['networkx graph of '+ str(len(x)) + ' vertices' for x in (self.graph[0], self.graph[1]) if x is not None])} "
if self.distances is not None:
descr += f"\n distances: distance matrix of shape {self.distances.shape}"
return descr
def __repr__(self) -> str:
"""Report."""
# inspire by AnnData's function
return self._gen_repr(self.n_obs, self.n_contigs)
def __getitem__(self, index: Index) -> Dandelion:
    """Return a sliced Dandelion object with synchronized data and metadata.

    Parameters
    ----------
    index : Index
        Either a numpy array applied to `.metadata.index` (when its length
        equals the number of metadata rows) or to `.data.index` (when its
        length equals the number of contigs), or a boolean pd.Series whose
        True entries' labels are selected.

    Returns
    -------
    Dandelion
        New object whose `.data`, `.metadata`, `.distances`, `.layout` and
        `.graph` are restricted to cells present in both the sliced data and
        the sliced metadata. The germline reference and library type are
        carried over.

    Raises
    ------
    IndexError
        If a numpy array index matches neither dimension.
    TypeError
        If `_normalize_indices` reports an unrecognised index type.
    """
    # Determine index type (metadata-based or data-based)
    if isinstance(index, np.ndarray):
        if len(index) == self._metadata.shape[0]:
            idx, idxtype = self._normalize_indices(
                self._metadata.index[index]
            )
        elif len(index) == self._data.shape[0]:
            idx, idxtype = self._normalize_indices(self._data.index[index])
        else:
            raise IndexError(
                "Index length does not match either metadata or data dimensions."
            )
    else:
        # Expecting index to be a boolean Series or DataFrame subset
        idx, idxtype = self._normalize_indices(index[index].index)

    # Slice data and metadata based on idxtype. In both branches the rows of
    # `_distances` follow the row order of `_metadata`.
    if idxtype == "metadata":
        # positional selection keeps metadata rows aligned with distances[idx]
        _metadata = self._metadata.iloc[idx]
        selected_cells = _metadata.index
        _data = self._data[self._data["cell_id"].isin(selected_cells)]
        _distances = (
            None
            if self.distances is None
            else self.distances[idx, :][:, idx]
        )
    elif idxtype == "data":
        selected_cells = self._data.iloc[idx]["cell_id"]
        _data = self._data.iloc[idx]
        # np.where preserves metadata order (and duplicates) so that the
        # metadata rows and the distance rows/cols stay aligned
        meta_indices = np.where(self._metadata.index.isin(selected_cells))[0]
        _metadata = self._metadata.iloc[meta_indices]
        _distances = (
            None
            if self.distances is None
            else self.distances[meta_indices, :][:, meta_indices]
        )
    else:
        raise TypeError(f"Unrecognized idxtype: {idxtype}")

    # --- Final synchronization step ---
    valid_cells = set(_data["cell_id"]).intersection(_metadata.index)
    _data = _data[_data["cell_id"].isin(valid_cells)].copy()
    keep = _metadata.index.isin(valid_cells)
    _metadata = _metadata.loc[keep].copy()
    if _distances is not None:
        if not keep.all():
            # metadata rows were dropped; drop the matching rows/cols too
            positions = np.flatnonzero(keep)
            _distances = _distances[positions, :][:, positions]
        if isinstance(_distances, csr_matrix):
            _distances._index_names = _metadata.index
    # -------------------------------------

    # Filter layout and graph if present; individual entries may be None.
    _keep_cells = valid_cells
    if self.layout is not None:
        _layout = tuple(
            (
                None
                if positions is None
                else {k: r for k, r in positions.items() if k in _keep_cells}
            )
            for positions in (self.layout[0], self.layout[1])
        )
    else:
        _layout = None
    if self.graph is not None:
        _graph = tuple(
            (
                None
                if g is None
                else g.subgraph([n for n in g.nodes if n in _keep_cells])
            )
            for g in (self.graph[0], self.graph[1])
        )
    else:
        _graph = None

    # Construct new object
    sliced = Dandelion(
        data=_data,
        metadata=_metadata,
        germline=getattr(self, "germline", None),
        layout=_layout,
        graph=_graph,
        distances=_distances,
        verbose=False,
    )
    sliced.library_type = getattr(self, "library_type", None)
    return sliced
@property
def data(self) -> pd.DataFrame:
    """Contig-level AIRR table.

    Returns
    -------
    pd.DataFrame
        The frame held in the private `_data` slot (one row per contig).
    """
    contigs = self._data
    return contigs
@data.setter
def data(self, value: pd.DataFrame):
    """Load `value` with `load_data` and store it via `_set_dim_df`."""
    self._set_dim_df(load_data(value), "data")
@property
def data_names(self) -> pd.Index:
    """Row labels of the contig table (alias for `.data.index`).

    Returns
    -------
    pd.Index
        The index of the frame stored in `_data`.
    """
    contig_table = self._data
    return contig_table.index
@data_names.setter
def data_names(self, names: list[str]):
    """Validate `names` and install them as the `.data` row labels."""
    prepared = self._prep_dim_index(names, "data")
    self._set_dim_index(prepared, "data")
@property
def metadata(self) -> pd.DataFrame:
    """Cell-level table.

    Returns
    -------
    pd.DataFrame
        The frame held in the private `_metadata` slot (may be None).
    """
    cells = self._metadata
    return cells
@metadata.setter
def metadata(self, value: pd.DataFrame):
    """Validate `value` and store it in the `_metadata` slot."""
    slot_name = "metadata"
    self._set_dim_df(value, slot_name)
@property
def metadata_names(self) -> pd.Index:
    """Row labels of the cell table (alias for `.metadata.index`).

    Returns
    -------
    pd.Index
        The index of the frame stored in `_metadata`.
    """
    cell_table = self._metadata
    return cell_table.index
@metadata_names.setter
def metadata_names(self, names: list[str]):
    """Validate `names` and install them as the `.metadata` row labels."""
    prepared = self._prep_dim_index(names, "metadata")
    self._set_dim_index(prepared, "metadata")
def _ensure_sanitized_data(self, verbose: bool = False) -> None:
"""Ensure that the data is sanitized."""
if not self._is_sanitized(self._data):
if verbose:
logg.info(
"The AIRR data needs to undergo sanitization, apologies for any delays..."
)
self._data = sanitize_data(self._data)
def _is_sanitized(self, df):
    """Check whether `df` looks sanitized.

    Parameters
    ----------
    df : pd.DataFrame
        Contig table to inspect.

    Returns
    -------
    bool
        True when every column of `CHECK_COLS` present in `df` contains only
        values from `TRUES + FALSES`.
    """
    valid_flags = TRUES + FALSES
    for col in CHECK_COLS:
        # inspect the frame that was passed in, not self._data
        if col in df and not df[col].isin(valid_flags).all():
            return False
    return True
def _normalize_indices(self, index: Index) -> tuple[slice, str]:
    """Resolve `index` against the metadata and data labels.

    Delegates to the module-level `_normalize_indices` helper. The method name
    does not shadow that helper inside this body, because class attributes
    are not part of a function's scope.
    """
    meta_labels = self.metadata_names
    contig_labels = self.data_names
    return _normalize_indices(index, meta_labels, contig_labels)
def _set_dim_df(self, value: pd.DataFrame, attr: str):
"""dim df setter"""
if value is not None:
if not isinstance(value, pd.DataFrame):
raise ValueError(f"Can only assign pd.DataFrame to {attr}.")
value_idx = self._prep_dim_index(value.index, attr)
setattr(self, f"_{attr}", value)
def _prep_dim_index(self, value, attr: str) -> pd.Index:
"""Prepares index to be uses as metadata_names or data_names for Dandelion object.
If a pd.Index is passed, this will use a reference, otherwise a new index object is created.
"""
if isinstance(value, pd.Index) and not isinstance(
value.name, (str, type(None))
):
raise ValueError(
f"Dandelion expects .{attr}.index.name to be a string or None, "
f"but you passed a name of type {type(value.name).__name__!r}"
)
else:
value = pd.Index(value)
if not isinstance(value.name, (str, type(None))):
value.name = None
# fmt: off
if (
not isinstance(value, pd.RangeIndex)
and not infer_dtype(value) in ("string", "bytes")
):
sample = list(value[: min(len(value), 5)])
warnings.warn(dedent(
f"""
Dandelion expects .{attr}.index to contain strings, but got values like:
{sample}
Inferred to be: {infer_dtype(value)}
"""
), # noqa
stacklevel=2,
)
# fmt: on
return value
def _set_dim_index(self, value: pd.Index, attr: str) -> None:
"""set dim index"""
# Assumes _prep_dim_index has been run
getattr(self, attr).index = value
for v in getattr(self, f"{attr}m").values():
if isinstance(v, pd.DataFrame):
v.index = value
def _update_ids(
    self,
    column: str,
    operation: str,
    value: str,
    sync: bool = True,
    sep: str | None = None,
    remove_trailing_hyphen_number: bool = False,
    **kwargs,
) -> None:
    """
    Internal method to update IDs and optionally sync changes.

    IDs are always rebuilt from the values captured at construction
    (`_original_sequence_ids` / `_original_cell_ids`), so repeated calls
    replace rather than stack affixes.

    Parameters
    ----------
    column : str
        The column to update ('sequence_id' or 'cell_id'; anything other than
        'sequence_id' is treated as 'cell_id').
    operation : str
        The operation to perform ('prefix' or 'suffix').
    value : str
        The value to add as prefix or suffix.
    sync : bool, optional
        Whether to sync changes to the other column, by default True.
    sep : str, optional
        Separator to use when adding prefix or suffix, by default None, which means no separator.
    remove_trailing_hyphen_number : bool, optional
        Whether to remove trailing hyphen numbers, by default False.
    **kwargs
        Additional arguments to pass to the update_metadata method

    Raises
    ------
    ValueError
        If `operation` is not 'prefix' or 'suffix'.
    """
    if operation not in ("prefix", "suffix"):
        raise ValueError(
            f"operation must be 'prefix' or 'suffix', got {operation!r}."
        )
    sep = "" if sep is None else sep

    def _cleaned(col: str) -> list[str]:
        # Original IDs for `col`, optionally stripped of the hyphen suffix.
        # NOTE(review): the result is assigned positionally, so this assumes
        # .data rows are still in their construction-time order.
        if col == "sequence_id":
            originals, clean = (
                self._original_sequence_ids,
                self._clean_sequence_id,
            )
        else:
            originals, clean = self._original_cell_ids, self._clean_cell_id
        return [
            clean(x, remove_trailing_hyphen_number)
            for x in originals.astype(str)
        ]

    def _affix(ids: list[str]) -> list[str]:
        # Attach `value` on the requested side of each ID.
        if operation == "prefix":
            return [value + sep + x for x in ids]
        return [x + sep + value for x in ids]

    other_column = "cell_id" if column == "sequence_id" else "sequence_id"
    self._data[column] = _affix(_cleaned(column))
    if sync:
        self._data[other_column] = _affix(_cleaned(other_column))
    self._data = load_data(self._data)
    if self.metadata is not None:
        self.update_metadata(**kwargs)
def _clean_sequence_id(
self, value: str, remove_trailing_hyphen_number: bool = False
) -> str:
"""
Clean sequence_id based on specified rules.
Parameters
----------
value : str
Original sequence_id value.
remove_trailing_hyphen_number : bool, optional
Whether to remove trailing hyphen numbers and _contig suffix, by default False.
Returns
-------
str
Cleaned sequence_id value.
"""
if remove_trailing_hyphen_number:
# First remove _contig and everything after it, then remove trailing hyphen number
return (
value.split("_contig")[0].split("-")[0]
+ "_contig"
+ value.split("_contig")[1]
)
return value
def _clean_cell_id(
self, value: str, remove_trailing_hyphen_number: bool = False
) -> str:
"""
Clean cell_id based on specified rules.
Parameters
----------
value : str
Original cell_id value.
remove_trailing_hyphen_number : bool, optional
Whether to remove trailing hyphen numbers, by default False.
Returns
-------
str
Cleaned cell_id value.
"""
if remove_trailing_hyphen_number:
# Remove the last occurrence of hyphen and everything after it
return value.rsplit("-", 1)[0]
return value
[docs]
def add_sequence_prefix(
    self,
    prefix: str,
    sync: bool = True,
    remove_trailing_hyphen_number: bool = False,
    **kwargs,
) -> None:
    """
    Prepend `prefix` to every sequence_id (and, when `sync`, every cell_id).

    Parameters
    ----------
    prefix : str
        Text placed in front of each ID.
    sync : bool, optional
        Also prefix `cell_id`; True by default.
    remove_trailing_hyphen_number : bool, optional
        Strip the trailing hyphen segment from the original IDs first; False by default.
    **kwargs
        Forwarded to `_update_ids` (e.g. `sep`) and from there to `update_metadata`.
    """
    request = dict(
        column="sequence_id",
        operation="prefix",
        value=prefix,
        sync=sync,
        remove_trailing_hyphen_number=remove_trailing_hyphen_number,
    )
    self._update_ids(**request, **kwargs)
[docs]
def add_sequence_suffix(
    self,
    suffix: str,
    sync: bool = True,
    remove_trailing_hyphen_number: bool = False,
    **kwargs,
) -> None:
    """
    Append `suffix` to every sequence_id (and, when `sync`, every cell_id).

    Parameters
    ----------
    suffix : str
        Text placed after each ID.
    sync : bool, optional
        Also suffix `cell_id`; True by default.
    remove_trailing_hyphen_number : bool, optional
        Strip the trailing hyphen segment from the original IDs first; False by default.
    **kwargs
        Forwarded to `_update_ids` (e.g. `sep`) and from there to `update_metadata`.
    """
    request = dict(
        column="sequence_id",
        operation="suffix",
        value=suffix,
        sync=sync,
        remove_trailing_hyphen_number=remove_trailing_hyphen_number,
    )
    self._update_ids(**request, **kwargs)
[docs]
def add_cell_prefix(
    self,
    prefix: str,
    sync: bool = True,
    remove_trailing_hyphen_number: bool = False,
    **kwargs,
) -> None:
    """
    Prepend `prefix` to every cell_id (and, when `sync`, every sequence_id).

    Parameters
    ----------
    prefix : str
        Text placed in front of each ID.
    sync : bool, optional
        Also prefix `sequence_id`; True by default.
    remove_trailing_hyphen_number : bool, optional
        Strip the trailing hyphen segment from the original IDs first; False by default.
    **kwargs
        Forwarded to `_update_ids` (e.g. `sep`) and from there to `update_metadata`.
    """
    request = dict(
        column="cell_id",
        operation="prefix",
        value=prefix,
        sync=sync,
        remove_trailing_hyphen_number=remove_trailing_hyphen_number,
    )
    self._update_ids(**request, **kwargs)
[docs]
def add_cell_suffix(
    self,
    suffix: str,
    sync: bool = True,
    remove_trailing_hyphen_number: bool = False,
    **kwargs,
) -> None:
    """
    Append `suffix` to every cell_id (and, when `sync`, every sequence_id).

    Parameters
    ----------
    suffix : str
        Text placed after each ID.
    sync : bool, optional
        Also suffix `sequence_id`; True by default.
    remove_trailing_hyphen_number : bool, optional
        Strip the trailing hyphen segment from the original IDs first; False by default.
    **kwargs
        Forwarded to `_update_ids` (e.g. `sep`) and from there to `update_metadata`.
    """
    request = dict(
        column="cell_id",
        operation="suffix",
        value=suffix,
        sync=sync,
        remove_trailing_hyphen_number=remove_trailing_hyphen_number,
    )
    self._update_ids(**request, **kwargs)
# def reset_ids(self) -> None:
# """
# Reset both IDs to their original values.
# This method restores both sequence_id and cell_id in the .data and .metadata slots to their original state when the Dandelion class was initialized.
# """
# self._data.index = self._original_data_ids
# self._metadata.index = self._original_metadata_ids
# self._data["sequence_id"] = self._original_sequence_ids
# self._data["cell_id"] = self._original_cell_ids
[docs]
def simplify(self, **kwargs) -> None:
    """Reduce V/D/J/C gene calls to their first, allele-free call.

    For each of v_call, v_call_genotyped, d_call, j_call and c_call present in
    `.data`, everything from the first "*" to the end of the line is removed.
    Then only the first comma-separated call is kept. Metadata is refreshed
    afterwards.

    Parameters
    ----------
    **kwargs
        Additional arguments passed to `update_metadata`.
    """
    allele_cols = ("v_call", "v_call_genotyped", "d_call", "j_call", "c_call")
    for name in allele_cols:
        if name not in self._data:
            continue
        without_allele = self._data[name].str.replace(r"\*.*", "", regex=True)
        self._data[name] = without_allele.str.split(",").str[0]
    self.update_metadata(**kwargs)
[docs]
def update_data(self, skip: list[str] | None = None) -> None:
    """Sync missing metadata columns into data via dictionary mapping.

    Each copied column is mapped onto contigs through their `cell_id`.
    Columns already in `.data` are not copied. Neither are cell-level variants
    of an existing contig column, i.e. `<base>_<tokens>` where `<base>` is a
    contig column and every `_`-separated token is one of VDJ, VJ, B, abT, gdT,
    status or main (e.g. `v_call_VDJ`, `v_call_B_VDJ_main`, `locus_status`).

    Parameters
    ----------
    skip : list[str] | None, optional
        Column names to skip when syncing metadata to data. Defaults to None
        (skip nothing).
    """
    skip = [] if skip is None else skip
    derived_tokens = {"VDJ", "VJ", "B", "abT", "gdT", "status", "main"}
    # Snapshot the contig columns up front so that columns added by this loop
    # do not change how later metadata columns are classified.
    contig_cols = list(self._data.columns)
    contig_set = set(contig_cols)

    def _is_variant(col: str) -> bool:
        # True when `col` is `<contig column>_<only derived tokens>`.
        for base in contig_cols:
            prefix = f"{base}_"
            if col.startswith(prefix):
                tokens = col[len(prefix):].split("_")
                if all(tok in derived_tokens for tok in tokens):
                    return True
        return False

    for col in self._metadata.columns:
        if col in skip or col in contig_set or _is_variant(col):
            continue
        mapping = self._metadata[col].to_dict()
        self._data[col] = self._data["cell_id"].map(mapping)
def _initialize_metadata(
    self,
    cols: list[str],
    clonekey: str,
    v_call_key: str,
    collapse_alleles: bool,
    report_productive_only: bool,
    reinitialize: bool,
    custom_isotype_dict: dict[str, str] | None = None,
) -> None:
    """Initialize Dandelion metadata by collapsing `.data` per cell.

    Parameters
    ----------
    cols : list[str]
        Contig columns to collapse per cell. They are retrieved with
        "split and merge"; `clonekey` and "sample_id" use
        "merge and unique only".
    clonekey : str
        Clone identifier column.
    v_call_key : str
        Preferred V call column, resolved through `get_vcall_key`.
    collapse_alleles : bool
        If True, remove `*NN` allele suffixes from V/D/J call columns and
        deduplicate comma-separated calls.
    report_productive_only : bool
        Passed to `Query` as `productive_only`. Also restricts isotype and
        locus status reporting to productive contigs.
    reinitialize : bool
        If True (and both a cached querier and metadata exist), build a fresh
        `Query` for this call without caching it.
    custom_isotype_dict : dict[str, str] | None, optional
        Extra entries merged into the mapping from the first 4 characters of
        a constant-gene call to an isotype label.
    """
    # Default retrieval: merge each requested column per cell.
    init_dict = {}
    for col in cols:
        init_dict.update(
            {
                col: {
                    "query": col,
                    "retrieve_mode": "split and merge",
                }
            }
        )
    # Clone and sample identifiers are merged and deduplicated instead.
    if clonekey in init_dict:
        init_dict.update(
            {
                clonekey: {
                    "query": clonekey,
                    "retrieve_mode": "merge and unique only",
                }
            }
        )
    if "sample_id" in init_dict:
        init_dict.update(
            {
                "sample_id": {
                    "query": "sample_id",
                    "retrieve_mode": "merge and unique only",
                }
            }
        )
    # Writes the contig-level "rearrangement_status" column queried below.
    self._update_rearrangement_status(v_call_key)
    # Contigs whose "ambiguous" flag is not in FALSES are excluded from queries.
    if "ambiguous" in self._data:
        dataq = self._data[self._data["ambiguous"].isin(FALSES)]
    else:
        dataq = self._data
    # Reuse the cached Query where possible. Rebuild it when reinitialize is
    # requested or when some metadata cells are no longer in data.
    if self.querier is None:
        querier = Query(dataq, productive_only=report_productive_only)
        self.querier = querier
    else:
        if self.metadata is not None:
            if reinitialize:
                # NOTE(review): this rebuilt querier is not cached on self.querier.
                querier = Query(
                    dataq, productive_only=report_productive_only
                )
            else:
                if any(~self.metadata_names.isin(self._data.cell_id)):
                    querier = Query(
                        dataq, productive_only=report_productive_only
                    )
                    self.querier = querier
                else:
                    querier = self.querier
        else:
            querier = self.querier
    meta_ = defaultdict(dict)
    # Iterate over a copy because entries for all-missing columns are popped.
    for k, v in init_dict.copy().items():
        if all_missing(self._data[k]):
            init_dict.pop(k)
            continue
        meta_[k] = querier.retrieve(**v)
        if k in [
            "v_call",
            "v_call_genotyped",
            "d_call",
            "j_call",
            "c_call",
            "productive",
        ]:
            meta_[k + "_split"] = querier.retrieve_celltype(**v)
        if k in ["umi_count", "mu_count", "mu_freq"]:
            # Numeric columns are summed instead of merged. `v` is shared with
            # init_dict (shallow copy), so this also changes init_dict[k].
            v.update({"retrieve_mode": "split and sum"})
            meta_[k] = querier.retrieve_celltype(**v)
    # join="inner": only cells present in every retrieved frame are kept.
    tmp_metadata = pd.concat(meta_.values(), axis=1, join="inner")
    reqcols1 = [
        "locus_VDJ",
    ]
    vcall = get_vcall_key(self._data, v_call_key)
    # remap v_call_genotyped_* to just v_call_* for column names in tmp_metadata if vcall == "v_call_genotyped"
    if vcall == "v_call_genotyped":
        for col in tmp_metadata.columns:
            if col.startswith("v_call_genotyped"):
                new_col = col.replace("v_call_genotyped", "v_call")
                tmp_metadata.rename(columns={col: new_col}, inplace=True)
    # This way, the function will only initialise as v_call regardless of whether v_call_key is v_call or v_call_genotyped
    reqcols2 = [
        "locus_VJ",
        "productive_VDJ",
        "productive_VJ",
        "v_call_VDJ",
        "d_call_VDJ",
        "j_call_VDJ",
        "v_call_VJ",
        "j_call_VJ",
        "c_call_VDJ",
        "c_call_VJ",
        "junction_VDJ",
        "junction_VJ",
        "junction_aa_VDJ",
        "junction_aa_VJ",
        "v_call_B_VDJ",
        "d_call_B_VDJ",
        "j_call_B_VDJ",
        "v_call_B_VJ",
        "j_call_B_VJ",
        "c_call_B_VDJ",
        "c_call_B_VJ",
        "productive_B_VDJ",
        "productive_B_VJ",
        "v_call_abT_VDJ",
        "d_call_abT_VDJ",
        "j_call_abT_VDJ",
        "v_call_abT_VJ",
        "j_call_abT_VJ",
        "c_call_abT_VDJ",
        "c_call_abT_VJ",
        "productive_abT_VDJ",
        "productive_abT_VJ",
        "v_call_gdT_VDJ",
        "d_call_gdT_VDJ",
        "j_call_gdT_VDJ",
        "v_call_gdT_VJ",
        "j_call_gdT_VJ",
        "c_call_gdT_VDJ",
        "c_call_gdT_VJ",
        "productive_gdT_VDJ",
        "productive_gdT_VJ",
    ]
    reqcols = reqcols1 + reqcols2
    # Guarantee that the required columns exist ("" placeholder for now).
    for rc in reqcols:
        if rc not in tmp_metadata:
            tmp_metadata[rc] = ""
    # D-call columns are not kept for the VJ side.
    for dc in [
        "d_call_VJ",
        "d_call_B_VJ",
        "d_call_abT_VJ",
        "d_call_gdT_VJ",
    ]:
        if dc in tmp_metadata:
            tmp_metadata.drop(dc, axis=1, inplace=True)
    # "*_main" columns are derived with return_none_call (defined elsewhere).
    for _call in ["v_call", "d_call", "j_call", "c_call"]:
        tmp_metadata[_call + "_VDJ_main"] = [
            return_none_call(x) for x in tmp_metadata[_call + "_VDJ"]
        ]
        if _call != "d_call":
            tmp_metadata[_call + "_VJ_main"] = [
                return_none_call(x) for x in tmp_metadata[_call + "_VJ"]
            ]
    for mode in ["B", "abT", "gdT"]:
        tmp_metadata["v_call_" + mode + "_VDJ_main"] = [
            return_none_call(x)
            for x in tmp_metadata["v_call_" + mode + "_VDJ"]
        ]
        tmp_metadata["d_call_" + mode + "_VDJ_main"] = [
            return_none_call(x)
            for x in tmp_metadata["d_call_" + mode + "_VDJ"]
        ]
        tmp_metadata["j_call_" + mode + "_VDJ_main"] = [
            return_none_call(x)
            for x in tmp_metadata["j_call_" + mode + "_VDJ"]
        ]
        tmp_metadata["v_call_" + mode + "_VJ_main"] = [
            return_none_call(x)
            for x in tmp_metadata["v_call_" + mode + "_VJ"]
        ]
        tmp_metadata["j_call_" + mode + "_VJ_main"] = [
            return_none_call(x)
            for x in tmp_metadata["j_call_" + mode + "_VJ"]
        ]
    # locus_VDJ is always present after the reqcols loop above, so
    # suffix_vdj resolves to "_VDJ" in practice.
    if "locus_VDJ" in tmp_metadata:
        suffix_vdj = "_VDJ"
        # suffix_vj = "_VJ"
    else:
        suffix_vdj = ""
        # suffix_vj = ""
    if clonekey in init_dict:
        tmp_metadata = _add_clone_info(
            tmp_metadata=tmp_metadata, clonekey=str(clonekey)
        )
    # First 4 characters of a constant call -> isotype label.
    conversion_dict = {
        "IGHA": "IgA",
        "IGHD": "IgD",
        "IGHE": "IgE",
        "IGHG": "IgG",
        "IGHM": "IgM",
        "IGKC": "IgK",
        "IGLC": "IgL",
    }
    if custom_isotype_dict is not None:
        conversion_dict.update(custom_isotype_dict)
    # Build one "|"-joined isotype string per cell from the VDJ constant calls.
    # With report_productive_only, keep only entries whose paired productive
    # flag is in TRUES or EMPTIES_STR. Unmapped calls are always dropped.
    isotype = []
    if "c_call" + suffix_vdj in tmp_metadata:
        for k, p in zip(
            tmp_metadata["c_call" + suffix_vdj],
            tmp_metadata["productive" + suffix_vdj],
        ):
            if isinstance(k, str):
                if report_productive_only:
                    # NOTE(review): assumes `p` is a str whenever `k` is.
                    isotype.append(
                        "|".join(
                            [
                                str(z)
                                for z, pp in zip(
                                    [
                                        (
                                            conversion_dict[
                                                y.split(",")[0][:4]
                                            ]
                                            if y.split(",")[0][:4]
                                            in conversion_dict
                                            else None
                                        )
                                        for y in k.split("|")
                                    ],
                                    p.split("|"),
                                )
                                if pp in TRUES + EMPTIES_STR
                                and z is not None
                            ]
                        )
                    )
                else:
                    isotype.append(
                        "|".join(
                            [
                                str(z)
                                for z in [
                                    (
                                        conversion_dict[y.split(",")[0][:4]]
                                        if y.split(",")[0][:4]
                                        in conversion_dict
                                        else None
                                    )
                                    for y in k.split("|")
                                    if (
                                        conversion_dict.get(
                                            y.split(",")[0][:4], None
                                        )
                                        is not None
                                    )
                                ]
                            ]
                        )
                    )
            else:
                isotype.append(None)
        # An empty join means no isotype could be assigned.
        isotype = [x if x != "" else None for x in isotype]
        tmp_metadata["isotype"] = isotype
        tmp_metadata["isotype_status"] = format_isotype1(tmp_metadata)
    vdj_gene_calls = ["v_call", "d_call", "j_call"]
    if collapse_alleles:
        for x in vdj_gene_calls:
            if x in self._data:
                for c in tmp_metadata:
                    # Substring match: every column whose name contains x.
                    if x in c:
                        # Remove "*NN" allele suffixes, then deduplicate the
                        # comma-separated calls within each "|" segment.
                        # NOTE(review): set() makes the order of the
                        # deduplicated calls depend on string hashing.
                        tmp_metadata[c] = [
                            (
                                "|".join(
                                    [
                                        ",".join(list(set(yy.split(","))))
                                        for yy in [
                                            re.sub("[*][0-9][0-9]", "", tx)
                                            for tx in t.split("|")
                                        ]
                                    ]
                                )
                                if isinstance(t, str)
                                else None
                            )
                            for t in tmp_metadata[c]
                        ]
    tmp_metadata["locus_status"] = format_locus(
        tmp_metadata,
        vcall="v_call",
        productive_only=report_productive_only,
    )
    tmp_metadata["chain_status"] = format_chain_status(
        tmp_metadata["locus_status"]
    )
    tmp_metadata["isotype_status"] = format_isotype2(tmp_metadata)
    # Drop the isotype columns when no cell received an isotype.
    if "isotype" in tmp_metadata:
        if tmp_metadata["isotype"].isna().all():
            tmp_metadata.drop(
                ["isotype", "isotype_status"], axis=1, inplace=True
            )
    # Turn the "" / "None" placeholders into real missing values.
    for rc in reqcols:
        tmp_metadata[rc] = tmp_metadata[rc].replace(["", "None"], None)
    if clonekey in init_dict:
        tmp_metadata[clonekey] = tmp_metadata[clonekey].replace(
            ["", "None"], None
        )
    tmp_metadata = movecol(
        tmp_metadata,
        cols_to_move=[rc2 for rc2 in reqcols2 if rc2 in tmp_metadata],
        ref_col="locus_VDJ",
    )
    # Remove columns that are entirely missing (per all_missing2).
    for tmpm in tmp_metadata:
        if all_missing2(tmp_metadata[tmpm]):
            tmp_metadata.drop(tmpm, axis=1, inplace=True)
    # Per-cell rearrangement status: any "chimeric" wins; several distinct
    # "|"-joined values become "Multi".
    tmpxregstat = querier.retrieve(
        query="rearrangement_status", retrieve_mode="split and unique only"
    )
    for x in tmpxregstat:
        tmpxregstat[x] = [
            (
                "chimeric"
                if isinstance(y, str) and re.search("chimeric", y)
                else "Multi" if isinstance(y, str) and "|" in y else y
            )
            for y in tmpxregstat[x]
        ]
        # assigned as a Series, so values align on the cell index
        tmp_metadata[x] = pd.Series(tmpxregstat[x])
    tmp_metadata = movecol(
        tmp_metadata,
        cols_to_move=[
            rs
            for rs in [
                "rearrangement_status_VDJ",
                "rearrangement_status_VJ",
            ]
            if rs in tmp_metadata
        ],
        ref_col="chain_status",
    )
    # if metadata already exist, just overwrite the default columns?
    # Existing metadata is kept and its columns are overwritten, aligned on
    # the index. It is replaced outright when it references cells no longer
    # in data.
    if self.metadata is not None:
        if any(~self.metadata_names.isin(self._data.cell_id)):
            self._metadata = tmp_metadata.copy()  # reindex and replace.
        for col in tmp_metadata:
            self._metadata[col] = pd.Series(tmp_metadata[col])
    else:
        self._metadata = tmp_metadata.copy()
def _update_rearrangement_status(self, v_call_key: str) -> None:
    """Label each contig in `.data["rearrangement_status"]`.

    A contig is "unknown" unless both its V and J calls are present. It is
    "chimeric" when the first three characters of its V, J and (if present)
    C calls disagree, and "standard" otherwise. The 3-character prefix is
    presumably the locus family; not verified here.
    """
    vcall = get_vcall_key(self._data, v_call_key)
    statuses = []
    for v_gene, j_gene, c_gene in zip(
        self._data[vcall], self._data["j_call"], self._data["c_call"]
    ):
        if not (present(v_gene) and present(j_gene)):
            statuses.append("unknown")
            continue
        prefixes = {v_gene[:3], j_gene[:3]}
        if present(c_gene):
            prefixes.add(c_gene[:3])
        statuses.append("chimeric" if len(prefixes) > 1 else "standard")
    self._data["rearrangement_status"] = statuses
[docs]
def compute(self):
    """Convert `self.distances` to a concrete csr matrix.

    Objects exposing a `.compute()` method (lazily evaluated arrays) are
    materialized first. Anything else is passed to `csr_matrix` directly.
    The matrix is tagged with the current metadata names via
    `_index_names`. Does nothing when no distance matrix is stored.
    """
    if self.distances is None:
        return
    if not isinstance(self.distances, csr_matrix):
        if hasattr(self.distances, "compute"):
            # let errors from .compute() surface instead of masking them
            self.distances = csr_matrix(self.distances.compute())
        else:
            self.distances = csr_matrix(self.distances)
    self.distances._index_names = self.metadata_names
[docs]
def copy(self) -> Dandelion:
    """
    Return an independent copy of this object.

    Returns
    -------
    Dandelion
        Result of `copy.deepcopy(self)`; no slot is shared with the original.
    """
    # local import: the method name `copy` must not be confused with the module
    import copy as _copy_module

    return _copy_module.deepcopy(self)
[docs]
def update_plus(
    self,
    option: Literal[
        "all",
        "sequence",
        "mutations",
        "cdr3 lengths",
        "mutations and cdr3 lengths",
    ] = "mutations and cdr3 lengths",
    **kwargs,
) -> None:
    """
    Retrieve additional per-cell columns into `.metadata`.

    Mutation columns are summed ("split and sum", then "sum"). V(D)J length
    columns are averaged ("split and average"). Sequence columns are merged
    ("split and merge"). Only columns that are actually present in `.data`
    are requested.

    Parameters
    ----------
    option : Literal["all", "sequence", "mutations", "cdr3 lengths", "mutations and cdr3 lengths"], optional
        Which group(s) of columns to retrieve.
    **kwargs
        passed to `Dandelion.update_metadata`.
    """
    mutations = [x for x in MUTATIONS if x in self._data]
    vdjlengths = [x for x in VDJLENGTHS if x in self._data]
    seqinfo = [x for x in SEQINFO if x in self._data]

    wants_mutations = option in ("all", "mutations", "mutations and cdr3 lengths")
    wants_lengths = option in ("all", "cdr3 lengths", "mutations and cdr3 lengths")
    wants_sequence = option in ("all", "sequence")

    # Requests are issued in the order mutations -> lengths -> sequence.
    requests = []
    if wants_mutations and mutations:
        requests.append((mutations, "split and sum"))
        requests.append((mutations, "sum"))
    if wants_lengths and vdjlengths:
        requests.append((vdjlengths, "split and average"))
    if wants_sequence and seqinfo:
        requests.append((seqinfo, "split and merge"))
    for columns, mode in requests:
        self.update_metadata(retrieve=columns, retrieve_mode=mode, **kwargs)
[docs]
def store_germline_reference(
    self,
    corrected: dict[str, str] | str | None = None,
    germline: str | list[str] | None = None,
    org: Literal["human", "mouse"] = "human",
    db: Literal["imgt", "ogrdb"] = "imgt",
) -> None:
    """
    Update germline reference with corrected sequences and store in Dandelion object.

    Parameters
    ----------
    corrected : dict[str, str] | str | None, optional
        dictionary of corrected germline sequences or file path to corrected germline sequences fasta file.
    germline : str | list[str] | None, optional
        path to germline database folder, path to a single fasta file, or a
        list of at least three fasta file paths. Defaults to
        `$GERMLINE/<db>/<org>/vdj`, read from the `GERMLINE` environmental variable.
    org : Literal["human", "mouse"], optional
        organism of reference folder. Default is 'human'.
    db : Literal["imgt", "ogrdb"], optional
        database of reference sequences. Default is 'imgt'.

    Raises
    ------
    KeyError
        if `germline` is None and the `GERMLINE` environmental variable is not set.
    TypeError
        if incorrect germline or corrected input is provided.
    """
    start = logg.info("Updating germline reference")
    fasta_ext = (".fasta", ".fa")
    bad_germline_msg = (
        "Input for germline is incorrect. Please provide path to folder containing germline IGHV, IGHD, "
        + "and IGHJ fasta files, or individual paths to the germline IGHV, IGHD, and IGHJ fasta "
        + "files (with .fasta extension) as a list."
    )

    def _validated_fastas(paths: list) -> list[str]:
        # At least V, D and J files are needed, and every entry must be fasta.
        if len(paths) < 3:
            raise TypeError(bad_germline_msg)
        if not all(str(p).endswith(fasta_ext) for p in paths):
            raise TypeError(bad_germline_msg)
        return [str(p) for p in paths]

    if germline is None:
        germline_root = os.environ.get("GERMLINE")
        if germline_root is None:
            raise KeyError(
                "Environmental variable GERMLINE must be set. Otherwise, "
                + "please provide path to folder containing germline IGHV, IGHD, and IGHJ fasta files."
            )
        gml = [str(Path(germline_root) / db / org / "vdj")]
    elif isinstance(germline, (list, tuple)):
        gml = _validated_fastas(list(germline))
    elif os.path.isdir(germline):
        gml = _validated_fastas(
            [str(Path(germline, g)) for g in os.listdir(germline)]
        )
    elif os.path.isfile(germline) and str(germline).endswith(fasta_ext):
        gml = [str(germline)]
        warnings.warn(
            "Only 1 fasta file provided to updating germline slot. Please check if this is intended.",
            RuntimeWarning,
            stacklevel=2,
        )
    else:
        # not a list, a directory or a fasta file: reject explicitly rather
        # than continuing without any germline paths
        raise TypeError(bad_germline_msg)

    germline_ref = readGermlines(gml)
    if corrected is not None:
        if isinstance(corrected, dict):
            personalized_ref_dict = corrected
        elif os.path.isfile(str(corrected)):
            personalized_ref_dict = readGermlines([str(corrected)])
        else:
            raise TypeError(
                "Input for corrected germline fasta is incorrect. Please provide path to file containing "
                + "corrected germline fasta sequences."
            )
        # update with the personalized germline database
        germline_ref.update(personalized_ref_dict)
    self.germline.update(germline_ref)
    logg.info(
        " finished",
        time=start,
        deep=(
            "Updated Dandelion object: \n"
            " 'germline', updated germline reference\n"
        ),
    )
[docs]
def write_airr(
    self, filename: str = "dandelion_airr.tsv", **kwargs
) -> None:
    """
    Write the contig-level `.data` table as an AIRR-formatted, tab-separated file.

    The table is passed through `sanitize_data` before writing, and the
    DataFrame index is not written.

    Parameters
    ----------
    filename : str, optional
        destination path of the `.tsv` file.
    **kwargs
        extra keyword arguments forwarded to `pandas.DataFrame.to_csv`.
    """
    airr_table = sanitize_data(self._data)
    airr_table.to_csv(filename, index=False, sep="\t", **kwargs)
[docs]
def write_h5ddl(
    self,
    filename: str = "dandelion_data.h5ddl",
    compression: (
        Literal[
            "gzip",
            "lzf",
            "szip",
        ]
        | None
    ) = None,
    compression_level: int | None = None,
):
    """
    Writes a Dandelion class to .h5ddl format.

    Every slot is written through a single HDF5 file handle. Scipy sparse
    distance matrices are stored as CSR components. Dask-array distances are
    written to a sibling `.zarr` store instead, because they cannot be stored
    inline.

    Parameters
    ----------
    filename : str, optional
        path to `.h5ddl` file.
    compression : Literal["gzip", "lzf", "szip"] | None, optional
        Specifies the compression algorithm to use. `None` writes uncompressed
        datasets.
    compression_level : int | None, optional
        Forwarded to h5py as `compression_opts`.
        - For "gzip", this is the level (0-9, where 0 disables compression);
          it defaults to 9.
        - It is ignored for "lzf", which accepts no options.
        - For "szip", it is only forwarded when given explicitly.
    """
    from scipy.sparse import issparse

    save_args = {}
    if compression is not None:
        save_args["compression"] = compression
        if compression == "gzip":
            save_args["compression_opts"] = (
                9 if compression_level is None else compression_level
            )
        elif compression != "lzf" and compression_level is not None:
            # h5py raises if any option is given for lzf, and szip expects its
            # own option format, so a default gzip level must not be forced on them.
            save_args["compression_opts"] = compression_level

    def _write_csr(hf, prefix, matrix):
        # Store a CSR matrix as its data/indices/indptr/shape components under `prefix`.
        hf.create_dataset(f"{prefix}/data", data=matrix.data, **save_args)
        hf.create_dataset(f"{prefix}/indices", data=matrix.indices, **save_args)
        hf.create_dataset(f"{prefix}/indptr", data=matrix.indptr, **save_args)
        hf.create_dataset(f"{prefix}/shape", data=matrix.shape, **save_args)

    def _encode(labels):
        # HDF5 stores the row/column labels as UTF-8 byte strings.
        return np.array([s.encode("utf-8") for s in labels])

    clear_h5file(filename)
    # now to actually saving
    data = sanitize_data(self._data.copy())
    data, data_dtypes = sanitize_data_for_saving(data)
    # Convert the DataFrame to a NumPy structured array
    structured_data_array = np.array(
        [tuple(row) for row in data.to_numpy()], dtype=data_dtypes
    )
    with h5py.File(filename, "w") as hf:
        hf.create_dataset("data", data=structured_data_array, **save_args)

        if self.metadata is not None:
            metadata, metadata_dtypes = sanitize_data_for_saving(
                self._metadata.copy()
            )
            structured_metadata_array = np.array(
                [tuple(row) for row in metadata.to_numpy()],
                dtype=metadata_dtypes,
            )
            hf.create_dataset(
                "metadata", data=structured_metadata_array, **save_args
            )
            hf.create_dataset(
                "metadata_names",
                data=_encode(metadata.index.to_numpy()),
                **save_args,
            )

        if self.graph is not None:
            for i, g in enumerate(self.graph):
                G_df = nx.to_pandas_adjacency(g, nonedge=np.nan)
                prefix = f"graph/graph_{i}"
                _write_csr(hf, prefix, csr_matrix(G_df.to_numpy()))
                hf.create_dataset(
                    f"{prefix}/column",
                    data=_encode(G_df.columns.to_numpy()),
                    **save_args,
                )
                hf.create_dataset(
                    f"{prefix}/index",
                    data=_encode(G_df.index.to_numpy()),
                    **save_args,
                )

        if self.distances is not None:
            distances = self.distances
            if issparse(distances):
                # csr_matrix is written as-is; other scipy sparse types
                # (e.g. csr_array, coo) are converted instead of being dropped.
                _write_csr(hf, "distances", distances.tocsr())
            else:
                try:
                    import dask.array as da
                except ImportError:
                    da = None
                if da is not None and isinstance(distances, da.Array):
                    zarr_path = Path(filename).with_suffix(".zarr")
                    try:
                        da.to_zarr(
                            distances,
                            str(zarr_path / "distance_matrix"),
                            overwrite=True,
                        )
                    except ImportError:
                        logg.warning(
                            "Distances are a dask array but could not be written "
                            "to .zarr (missing optional dependency); they were "
                            "not saved."
                        )
                    else:
                        logg.warning(
                            f"Distances are a dask array and cannot be stored "
                            f"inline in .h5ddl. Written to {zarr_path}. Pass "
                            f"`distance_zarr='{zarr_path}'` when reading, or "
                            f"it will be detected automatically."
                        )
                else:
                    logg.warning(
                        f"Distances of type {type(distances).__name__} cannot be "
                        f"stored in .h5ddl and were not saved."
                    )

        if self.layout is not None:
            for i, l in enumerate(self.layout):
                layout_group = hf.create_group("layout/layout_" + str(i))
                # one dataset per node key in this layout dictionary
                for key, value in l.items():
                    layout_group.create_dataset(key, data=value, **save_args)

        if len(self.germline) > 0:
            hf.create_dataset(
                "germline/keys",
                data=np.array(list(self.germline.keys()), dtype="S"),
                **save_args,
            )
            hf.create_dataset(
                "germline/values",
                data=np.array(list(self.germline.values()), dtype="S"),
                **save_args,
            )

write = write_ddl = write_h5ddl
[docs]
def write_vdj(
    self,
    folder: Path | str = "dandelion_data",
    filename_prefix: str = "all",
    sequence_key: str = "sequence",
    clone_key: str = "clone_id",
) -> None:
    """
    Writes a Dandelion object to contig-annotation formatted files compatible with
    multiple platforms (10x Genomics, SeekGene, etc.) so that it can be ingested by
    other tools.

    Produces:
    - ``{filename_prefix}_contig.fasta`` : sequences in FASTA format.
    - ``{filename_prefix}_contig_annotations.csv`` : contig annotation table with
      columns matching the 10x / SeekGene contig annotation schema.

    Every output column is filled from its mapped `.data` column.
    - An output column whose source column is absent is left blank. It is
      never filled with a neighbouring column's value.
    - ``length`` falls back to the length of ``sequence_key`` when `.data`
      has no ``length`` column.
    - AIRR boolean spellings (T/F/TRUE/FALSE) are normalised to "True"/"False".

    Parameters
    ----------
    folder : Path | str, optional
        path to save the output files.
    filename_prefix : str, optional
        prefix for the output files.
    sequence_key : str, optional
        column name in `.data` slot to retrieve and write out in fasta format.
    clone_key : str, optional
        column name in `.data` slot for clone id information.
    """
    folder = Path(folder)
    folder.mkdir(parents=True, exist_ok=True)
    out_fasta = folder / f"{filename_prefix}_contig.fasta"
    out_anno_path = folder / f"{filename_prefix}_contig_annotations.csv"
    seqs = self._data[sequence_key].to_dict()
    write_fasta(seqs, out_fasta=out_fasta)

    # also create the contig_annotations.csv: output column -> `.data` column
    data_columns = self._data.columns
    column_map = {
        "barcode": "cell_id",
        "is_cell": "is_cell_10x",
        "contig_id": "sequence_id",
        "high_confidence": "high_confidence_10x",
        "length": "length",
        "chain": "locus",
        "v_gene": "v_call",
        "d_gene": "d_call",
        "j_gene": "j_call",
        "c_gene": "c_call",
        "full_length": "complete_vdj",
        "productive": "productive",
        "cdr3": "junction_aa",
        "cdr3_nt": "junction",
        "reads": "consensus_count",
        "umis": "umi_count",
        "raw_clonotype_id": clone_key,
        "raw_consensus_id": clone_key,
    }
    if "complete_vdj" not in data_columns:
        column_map.pop("full_length")
    # Support both _10x-suffixed (10x CellRanger) and plain (SeekGene) column names.
    is_cell_col = next(
        (c for c in ["is_cell_10x", "is_cell"] if c in data_columns), None
    )
    if is_cell_col:
        column_map["is_cell"] = is_cell_col
    else:
        column_map.pop("is_cell")
    high_confidence_col = next(
        (
            c
            for c in ["high_confidence_10x", "high_confidence"]
            if c in data_columns
        ),
        None,
    )
    if high_confidence_col:
        column_map["high_confidence"] = high_confidence_col
    else:
        column_map.pop("high_confidence")

    bool_map = {
        "T": "True",
        "F": "False",
        "True": "True",
        "False": "False",
        "TRUE": "True",
        "FALSE": "False",
    }

    def _format(value):
        # Normalise AIRR-style boolean spellings to the 10x "True"/"False" form.
        return bool_map[value] if value in bool_map else value

    records = []
    for _, row in self._data.iterrows():
        record = {}
        for out_col, src_col in column_map.items():
            if src_col in row.index:
                value = row[src_col]
            elif out_col == "length":
                seq = row[sequence_key]
                value = len(seq) if isinstance(seq, str) else None
            else:
                # keep the column aligned: a missing source becomes a blank cell
                value = None
            record[out_col] = _format(value)
        records.append(record)
    # explicit columns keep the header even when `.data` is empty
    anno = pd.DataFrame(records, columns=list(column_map))
    anno.to_csv(out_anno_path, index=False)
[docs]
def write_10x(
self,
folder: Path | str = "dandelion_data",
filename_prefix: str = "all",
sequence_key: str = "sequence",
clone_key: str = "clone_id",
) -> None:
"""
Alias for :meth:`write_vdj` kept for backwards compatibility.
Parameters
----------
folder : Path | str, optional
path to save the output files.
filename_prefix : str, optional
prefix for the output files.
sequence_key : str, optional
column name in `.data` slot to retrieve and write out in fasta format.
clone_key : str, optional
column name in `.data` slot for clone id information.
"""
self.write_vdj(
folder=folder,
filename_prefix=filename_prefix,
sequence_key=sequence_key,
clone_key=clone_key,
)
class Query:
    """Query class to retrieve per-cell summaries of contig-level data."""

    # Locus -> column-suffix group used by `retrieve`.
    _CHAIN_GROUPS = {
        "IGH": "VDJ",
        "TRB": "VDJ",
        "TRD": "VDJ",
        "IGK": "VJ",
        "IGL": "VJ",
        "TRA": "VJ",
        "TRG": "VJ",
    }
    # Column order for split modes, and value order for merged modes, in `retrieve`.
    _CHAIN_ORDER = ("VDJ", "VJ")
    # Locus -> column-suffix group used by `retrieve_celltype`.
    _CELLTYPE_GROUPS = {
        "IGH": "B_VDJ",
        "IGK": "B_VJ",
        "IGL": "B_VJ",
        "TRB": "abT_VDJ",
        "TRA": "abT_VJ",
        "TRD": "gdT_VDJ",
        "TRG": "gdT_VJ",
    }
    # Column order used by the split modes of `retrieve_celltype`.
    _CELLTYPE_SPLIT_ORDER = (
        "B_VDJ",
        "B_VJ",
        "abT_VDJ",
        "abT_VJ",
        "gdT_VDJ",
        "gdT_VJ",
    )
    # Value order used by the single-column modes of `retrieve_celltype`:
    # all VDJ chains first, then all VJ chains.
    _CELLTYPE_MERGE_ORDER = (
        "B_VDJ",
        "abT_VDJ",
        "gdT_VDJ",
        "B_VJ",
        "abT_VJ",
        "gdT_VJ",
    )
    # Modes whose output is numeric and returned without None-substitution.
    _NUMERIC_MODES = ("split and sum", "split and average", "sum", "average")

    def __init__(
        self,
        data: pd.DataFrame,
        productive_only: bool = True,
        verbose: bool = False,
    ) -> None:
        """
        Query class to retrieve data from the Dandelion object.

        Parameters
        ----------
        data : pd.DataFrame
            Dataframe to query; must contain `cell_id`, `locus` and (when
            `productive_only`) `productive` columns.
        productive_only : bool
            whether to only query productive contigs.
        verbose : bool, optional
            Whether to print the process, by default False.
        """
        if productive_only:
            data = data[data["productive"].isin(TRUES)]
        self.data = data.copy()
        # cell_id -> contig (row index) -> row values
        self.Cell = Tree()
        for contig, row in tqdm(
            data.iterrows(),
            desc="Setting up data",
            disable=not verbose,
        ):
            self.Cell[row["cell_id"]][contig].update(row)

    @property
    def querydtype(self):
        """Check dtype of the last queried column (set by `retrieve`/`retrieve_celltype`)."""
        return str(self.data[self.query].dtype)

    def retrieve(
        self,
        query: str,
        retrieve_mode: Literal[
            "split and unique only",
            "merge and unique only",
            "split and merge",
            "split and sum",
            "split and average",
            "split",
            "merge",
            "sum",
            "average",
        ],
    ) -> pd.DataFrame:
        """
        Retrieve query.

        Contigs are grouped into VDJ (IGH/TRB/TRD) and VJ (IGK/IGL/TRA/TRG) chains.

        Parameters
        ----------
        query : str
            column name in `.data` slot to retrieve and update the metadata.
        retrieve_mode : Literal["split and unique only", "merge and unique only", "split and merge", "split and sum", "split and average", "split", "merge", "sum", "average", ]
            one of:
                `split and unique only`
                    returns the retrieval splitted into two columns,
                    i.e. one for VDJ and one for VJ chains, separated by `|` for unique elements.
                `merge and unique only`
                    returns the retrieval merged into one column,
                    separated by `|` for unique elements (in order of first appearance).
                `split and merge`
                    returns the retrieval splitted into two columns,
                    i.e. one for VDJ and one for VJ chains, separated by `|` for every elements.
                `split`
                    returns the retrieval splitted into separate columns for each contig.
                `merge`
                    returns the retrieval merged into one column,
                    separated by `|` for every element (duplicates kept).
                `split and sum`
                    returns the retrieval sum in the VDJ and VJ columns (separately).
                `split and average`
                    returns the retrieval averaged in the VDJ and VJ columns (separately).
                `sum`
                    returns the retrieval sum into one column for all contigs.
                `average`
                    returns the retrieval averaged into one column for all contigs.

        Returns
        -------
        pd.DataFrame
            Retrieved data, one row per cell.
        """
        return self._retrieve_grouped(
            query,
            retrieve_mode,
            self._CHAIN_GROUPS,
            self._CHAIN_ORDER,
            self._CHAIN_ORDER,
        )

    def retrieve_celltype(
        self,
        query: str,
        retrieve_mode: Literal[
            "split and unique only",
            "merge and unique only",
            "split and merge",
            "split and sum",
            "split and average",
            "split",
            "merge",
            "sum",
            "average",
        ],
    ) -> pd.DataFrame:
        """
        Retrieve query split by celltype.

        Contigs are grouped into B (IGH | IGK/IGL), abT (TRB | TRA) and
        gdT (TRD | TRG) VDJ/VJ chains.

        Parameters
        ----------
        query : str
            column name in `.data` slot to retrieve and update the metadata.
        retrieve_mode : Literal["split and unique only", "merge and unique only", "split and merge", "split and sum", "split and average", "split", "merge", "sum", "average", ]
            one of:
                `split and unique only`
                    returns the retrieval splitted into one column per celltype/chain,
                    separated by `|` for unique elements.
                `merge and unique only`
                    returns the retrieval merged into one column,
                    separated by `|` for unique elements (in order of first appearance).
                `split and merge`
                    returns the retrieval splitted into one column per celltype/chain,
                    separated by `|` for every elements.
                `split`
                    returns the retrieval splitted into separate columns for each contig.
                `merge`
                    returns the retrieval merged into one column,
                    separated by `|` for every element (duplicates kept).
                `split and sum`
                    returns the retrieval sum per celltype/chain column (separately).
                `split and average`
                    returns the retrieval averaged per celltype/chain column (separately).
                `sum`
                    returns the retrieval sum into one column for all contigs.
                `average`
                    returns the retrieval averaged into one column for all contigs.

        Returns
        -------
        pd.DataFrame
            Retrieved data, one row per cell.
        """
        return self._retrieve_grouped(
            query,
            retrieve_mode,
            self._CELLTYPE_GROUPS,
            self._CELLTYPE_SPLIT_ORDER,
            self._CELLTYPE_MERGE_ORDER,
        )

    def _retrieve_grouped(
        self,
        query: str,
        retrieve_mode: str,
        locus_groups: dict,
        split_order: tuple,
        merge_order: tuple,
    ) -> pd.DataFrame:
        """Group each cell's contigs via `locus_groups`, summarise them and tidy the result."""
        self.query = query
        ret = {}
        for cell in self.Cell:
            grouped = {name: [] for name in split_order}
            for contig in self.Cell[cell].values():
                if not isinstance(contig, dict):
                    continue
                locus = contig["locus"]
                # only string loci can match a group; others (e.g. NaN) are skipped
                group = locus_groups.get(locus) if isinstance(locus, str) else None
                if group is not None:
                    grouped[group].append(contig[query])
            ret[cell] = self._summarise_cell(
                query, retrieve_mode, grouped, split_order, merge_order
            )
        out = pd.DataFrame.from_dict(ret, orient="index")
        return self._tidy_output(out, retrieve_mode)

    @staticmethod
    def _summarise_cell(
        query: str,
        retrieve_mode: str,
        grouped: dict,
        split_order: tuple,
        merge_order: tuple,
    ) -> dict:
        """Build one cell's output columns from its per-group value lists."""
        cols = {}
        merged = [x for name in merge_order for x in grouped[name]]
        if retrieve_mode == "split and unique only":
            for name in split_order:
                values = grouped[name]
                if values:
                    cols[f"{query}_{name}"] = "|".join(
                        str(x) for x in dict.fromkeys(values) if present(x)
                    )
        elif retrieve_mode == "split and merge":
            for name in split_order:
                values = grouped[name]
                if values:
                    cols[f"{query}_{name}"] = "|".join(
                        str(x) for x in values if present(x)
                    )
        elif retrieve_mode == "merge and unique only":
            # dict.fromkeys keeps first-seen order (a set would be hash-ordered)
            cols[query] = "|".join(
                str(x) for x in dict.fromkeys(merged) if present(x)
            )
        elif retrieve_mode in ("split and sum", "split and average"):
            aggregate = np.sum if retrieve_mode == "split and sum" else np.mean
            for name in split_order:
                values = grouped[name]
                cols[f"{query}_{name}"] = (
                    aggregate([float(x) for x in values if present(x)])
                    if values
                    else np.nan
                )
        elif retrieve_mode == "merge":
            cols[query] = "|".join(str(x) for x in merged if present(x))
        elif retrieve_mode == "split":
            for name in split_order:
                for i, value in enumerate(grouped[name], start=1):
                    cols[f"{query}_{name}_{i}"] = value
        elif retrieve_mode in ("sum", "average"):
            aggregate = np.sum if retrieve_mode == "sum" else np.mean
            total = aggregate([float(x) for x in merged if present(x)])
            cols[query] = total if present(total) else np.nan
        return cols

    def _tidy_output(self, out: pd.DataFrame, retrieve_mode: str) -> pd.DataFrame:
        """Coerce split columns to numeric where possible and turn missing values into None."""
        if retrieve_mode in self._NUMERIC_MODES:
            return out
        if retrieve_mode == "split":
            for col in list(out.columns):
                try:
                    out[col] = pd.to_numeric(out[col])
                except (ValueError, TypeError):
                    out[col] = out[col].where(pd.notna(out[col]), None)
            return out
        return out.where(pd.notna(out), None)
def _normalize_indices(
    index: Index | None, names0: pd.Index, names1: pd.Index
) -> tuple[slice, str]:
    """
    Return the normalised row indexer and which slot it refers to.

    Parameters
    ----------
    index : Index | None
        indexer passed to the Dandelion object (optionally a 1- or 2-tuple).
    names0 : pd.Index
        row names of the metadata slot.
    names1 : pd.Index
        row names of the data slot.

    Returns
    -------
    tuple[slice, str]
        normalised indexer and either "metadata" or "data".

    Raises
    ------
    ValueError
        if `index` is a tuple with more than two elements.
    KeyError
        if the row indexer matches neither the metadata nor the data row names.
    """
    # deal with tuples of length 1
    if isinstance(index, tuple) and len(index) == 1:
        index = index[0]
    # deal with pd.Series
    if isinstance(index, pd.Series):
        index = index.values
    if isinstance(index, tuple):
        if len(index) > 2:
            raise ValueError(
                "Dandelion can only be sliced in data or metadata rows."
            )
        # deal with pd.Series
        # TODO: The series should probably be aligned first
        if isinstance(index[1], pd.Series):
            index = index[0], index[1].values
        if isinstance(index[0], pd.Series):
            index = index[0].values, index[1]
    ax0_, _ = unpack_index(index)
    if all(ax_ in names0 for ax_ in ax0_):
        return _normalize_index(ax0_, names0), "metadata"
    if all(ax_ in names1 for ax_ in ax0_):
        return _normalize_index(ax0_, names1), "data"
    # previously fell through to an UnboundLocalError on `ax0`/`axtype`
    raise KeyError(
        "Index does not match the row names of either the metadata or the data slot."
    )
def return_none_call(call: str) -> str:
    """Return the first `|`-separated call, or None for "None", "" or None."""
    if call in ("None", "", None):
        return None
    first_call, _, _ = call.partition("|")
    return first_call
def clean_clone_list(clone_series: pd.Series) -> pd.Series:
    """
    Remove empty/None clone tokens, deduplicate and sort; keep empty entries as None.

    Each entry is split on ``|``. Empty tokens and ``"None"`` tokens are dropped,
    and the remaining tokens are de-duplicated, sorted and re-joined with ``|``.
    An entry left with no tokens becomes None. This covers ``""``, ``"None"``,
    missing values and non-string entries.

    Parameters
    ----------
    clone_series : pd.Series
        clone id entries, possibly ``|``-joined.

    Returns
    -------
    pd.Series
        cleaned clone ids.
    """

    def _clean(tokens):
        # non-list results come from missing / non-string entries
        if not isinstance(tokens, list):
            return None
        kept = {token for token in tokens if token and token != "None"}
        return "|".join(sorted(kept)) if kept else None

    return clone_series.str.split("|").apply(_clean)
def flatten_and_count(tmp_metadata: pd.DataFrame, clonekey: str) -> pd.Series:
    """
    Return a Series of clone counts for all unique clones.

    Entries of `tmp_metadata[clonekey]` may hold several ``|``-joined clone ids.
    Every token is counted once per occurrence. Missing entries and the
    ``None``/``"None"`` placeholders are excluded. The result is sorted by
    descending count, as produced by `value_counts`.

    Parameters
    ----------
    tmp_metadata : pd.DataFrame
        table holding the clone column.
    clonekey : str
        name of the clone column.

    Returns
    -------
    pd.Series
        clone id -> number of occurrences.
    """
    # explode keeps the same row-major token order as split(expand=True).stack()
    # without the wide intermediate frame or the pandas>=2.1 stack FutureWarning.
    tokens = tmp_metadata[clonekey].dropna().str.split("|").explode()
    clone_counts = tokens.value_counts()
    # Filter out None and "None" string values from the index
    clone_counts = clone_counts[~clone_counts.index.isin([None, "None"])]
    return clone_counts
def get_receptor_prefix(clone: str) -> str:
    """Return the text before the first "_" of `clone` if it is in RECEPTOR_SET, else None."""
    head, _sep, _rest = clone.partition("_")
    if head in RECEPTOR_SET:
        return head
    return None
def assign_clone_numbers(clone_counts: pd.Series) -> dict:
    """
    Map each clone id in `clone_counts.index` to a size-based rank label.

    With at most one receptor type present (per `get_receptor_prefix`), clones
    are labelled "1", "2", ... in index order. Otherwise each receptor type in
    RECEPTOR_SET is ranked separately by descending count ("<type>_1", ...).
    Clones without a recognised prefix become "other_1", "other_2", ...;
    the "None"/None placeholders instead map to None.
    """
    clone_ids = list(clone_counts.index)
    prefix_of = {clone: get_receptor_prefix(clone) for clone in clone_ids}
    receptor_types = set(prefix_of.values()) - {None}

    if len(receptor_types) <= 1:
        # zero or one receptor type: plain 1-based numbering in index order
        return {clone: str(rank) for rank, clone in enumerate(clone_ids, start=1)}

    def _by_size_desc(clone):
        return -clone_counts[clone]

    per_type = {receptor: [] for receptor in RECEPTOR_SET}
    unprefixed = []
    for clone in clone_ids:
        receptor = prefix_of[clone]
        if receptor in RECEPTOR_SET:
            per_type[receptor].append(clone)
        else:
            unprefixed.append(clone)

    labels = {}
    for receptor, members in per_type.items():
        ranked = sorted(members, key=_by_size_desc)
        for rank, clone in enumerate(ranked, start=1):
            labels[clone] = f"{receptor}_{rank}"
    for rank, clone in enumerate(sorted(unprefixed, key=_by_size_desc), start=1):
        labels[clone] = None if clone in ["None", None] else f"other_{rank}"
    return labels
def _add_clone_info(tmp_metadata: pd.DataFrame, clonekey: str) -> pd.DataFrame:
    """
    Add a `{clonekey}_rank` column with per-receptor sequential numbering by clone size.

    `tmp_metadata[clonekey]` is cleaned in place with `clean_clone_list`, and
    the new rank column is assigned on `tmp_metadata` itself. The returned
    frame has the clone and rank columns moved to the front.
    """
    rank_key = clonekey + "_rank"
    tmp_metadata[clonekey] = clean_clone_list(tmp_metadata[clonekey])
    numbering = assign_clone_numbers(flatten_and_count(tmp_metadata, clonekey))

    def _rank_of(entry):
        # multi-clone entries are ranked token by token; unknown tokens pass through
        if not isinstance(entry, str):
            return None
        return "|".join(numbering.get(part, part) for part in entry.split("|"))

    tmp_metadata[rank_key] = tmp_metadata[clonekey].apply(_rank_of).astype("category")
    leading = [clonekey, rank_key]
    trailing = [col for col in tmp_metadata.columns if col not in leading]
    return tmp_metadata[leading + trailing]
@deprecated(
    deprecated_in="1.0.0",
    removed_in="1.1.0",
    details="legacy .h5ddl format will no longer be supported.",
)
def write_h5ddl_legacy(
    self: Dandelion,
    filename: Path | str = "dandelion_data.h5ddl",
    **kwargs,
) -> None:  # pragma: no cover
    """
    Writes a Dandelion class to .h5ddl format for legacy support.

    The writes are performed in this order:
    - data and metadata are written with `pd.DataFrame.to_hdf`;
    - graphs are written as pandas adjacency tables;
    - layouts and germline records are written as HDF5 group attributes.

    Graph and layout writing are best-effort: a failure there (e.g. the slot
    is None) is skipped silently.

    Parameters
    ----------
    self : Dandelion
        input Dandelion object.
    filename : Path | str, optional
        path to `.h5ddl` file, by default "dandelion_data.h5ddl".
    **kwargs
        Additional arguments to `pd.DataFrame.to_hdf`.
    """
    clear_h5file(filename)
    # now to actually saving
    data = sanitize_data(self._data.copy())
    data, _ = sanitize_data_for_saving(data)
    data.to_hdf(filename, key="data", **kwargs)

    if self.metadata is not None:
        metadata = self._metadata.copy()
        # DataFrame.map replaced applymap in pandas 2.1; probe the class rather
        # than comparing version strings lexicographically.
        elementwise = "map" if hasattr(pd.DataFrame, "map") else "applymap"
        for col in metadata.columns:
            col_frame = metadata[[col]]
            cell_types = getattr(col_frame, elementwise)(type)
            # rows whose value type differs from the first row's
            weird = (cell_types != col_frame.iloc[0].apply(type)).any(axis=1)
            if len(metadata[weird]) > 0:
                metadata[col] = metadata[col].where(pd.notnull(metadata[col]), "")
        metadata.to_hdf(
            filename,
            key="metadata",
            format="table",
            nan_rep=np.nan,
            **kwargs,
        )

    try:
        for graph_counter, g in enumerate(self.graph):
            G = nx.to_pandas_adjacency(g, nonedge=np.nan)
            G.to_hdf(filename, key="graph/graph_" + str(graph_counter), **kwargs)
    except Exception:
        # best effort, e.g. `self.graph` is None
        pass

    with h5py.File(filename, "a") as hf:
        try:
            for layout_counter, l in enumerate(self.layout):
                layout_group = hf.require_group("layout/layout_" + str(layout_counter))
                for k in l.keys():
                    layout_group.attrs[k] = l[k]
        except Exception:
            # best effort, e.g. `self.layout` is None
            pass
        if len(self.germline) > 0:
            germline_group = hf.require_group("germline")
            for k in self.germline.keys():
                germline_group.attrs[k] = self.germline[k]
def load_data(obj: pd.DataFrame | Path | str | None) -> pd.DataFrame:
"""
Read in or copy dataframe object and set sequence_id as index without dropping.
Parameters
----------
obj : pd.DataFrame | Path | str | None
file path to .tsv file or pandas DataFrame object.
Returns
-------
pd.DataFrame
Raises
------
TypeError
if input is not found.
KeyError
if `sequence_id` not found in input.
"""
if obj is not None:
if os.path.isfile(str(obj)):
obj_ = pd.read_csv(obj, sep="\t")
elif isinstance(obj, pd.DataFrame):
obj_ = obj.copy()
else:
raise TypeError(
"Either input is not of pandas DataFrame or AIRR file does not exist."
)
if "sequence_id" in obj_.columns:
# assert that sequence_id is string
obj_["sequence_id"] = obj_["sequence_id"].astype(str)
obj_.set_index("sequence_id", drop=False, inplace=True)
if "cell_id" not in obj_.columns:
obj_["cell_id"] = [
c.split("_contig")[0] if "_contig" in c else c
for c in obj_["sequence_id"]
]
# assert that cell_id is string
obj_["cell_id"] = obj_["cell_id"].astype(str)
else:
raise KeyError("'sequence_id' not found in columns of input")
if "duplicate_count" in obj_.columns:
if "umi_count" not in obj_.columns:
obj_.rename(
columns={"duplicate_count": "umi_count"}, inplace=True
)
return obj_
def check_travdv(data: pd.DataFrame) -> pd.DataFrame:
    """
    Reassign the locus of contigs whose V call matches the TRAV.*/DV pattern.

    For such contigs, the J/C/D calls decide the locus (via `same_call`):
    - if they agree on "TRA", the locus is set to "TRA" unless it already
      contains "TRA";
    - otherwise, if they agree on "TRD", it is set to "TRD" likewise.

    Contigs with a missing (non-string) V call are left unchanged.

    Parameters
    ----------
    data : pd.DataFrame
        AIRR table (or path accepted by `load_data`) with `v_call`, `d_call`,
        `j_call`, `c_call` and `locus` columns.

    Returns
    -------
    pd.DataFrame
        the loaded table with an updated `locus` column.
    """
    data = load_data(data)
    travdv = re.compile(r"TRAV.*/DV")
    # keyed by sequence_id: later duplicates overwrite earlier ones
    calls = {
        seq_id: (v, d, j, c)
        for seq_id, v, d, j, c in zip(
            data["sequence_id"],
            data["v_call"],
            data["d_call"],
            data["j_call"],
            data["c_call"],
        )
    }
    l_dict = dict(zip(data["sequence_id"], data["locus"]))
    for seq_id, (v, d, j, c) in calls.items():
        # NaN V calls (e.g. empty cells read from a TSV) cannot be TRAV/DV genes
        if not isinstance(v, str) or not travdv.search(v):
            continue
        locus = l_dict[seq_id]
        locus_text = locus if isinstance(locus, str) else ""
        if same_call(j, c, d, "TRA"):
            if "TRA" not in locus_text:
                l_dict[seq_id] = "TRA"
        elif same_call(j, c, d, "TRD"):
            if "TRD" not in locus_text:
                l_dict[seq_id] = "TRD"
    # aligned on the sequence_id index set by load_data
    data["locus"] = pd.Series(l_dict)
    return data