Source code for dandelion.utilities._utilities

from __future__ import annotations
import h5py
import os
import re
import unicodedata
import warnings
import zarr

import numpy as np
import pandas as pd

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    from airr import RearrangementSchema
from collections import defaultdict
from collections.abc import Iterable, Iterator
from packaging import version
from pathlib import Path
from subprocess import run
from typing import TypeVar, Literal, Callable, TextIO

# Compatibility shim: True when the installed zarr is version 3.0.0 or newer.
# Both branches below bind the same names (LocalStore, ZipStore, BloscCodec,
# open_zarr_group, create_zarr_array, create_zarr_dataset) on top of whichever
# zarr API is installed.
ZARR_V3 = version.parse(zarr.__version__) >= version.parse("3.0.0")
if ZARR_V3:
    from zarr.storage import LocalStore, ZipStore
    from zarr.codecs import BloscCodec

    def open_zarr_group(store, mode="a"):
        """Open a zarr group on ``store``, forwarding ``mode`` to ``zarr.open_group``."""
        return zarr.open_group(store=store, mode=mode)

    def create_zarr_array(root, name, **kwargs):
        """Create array ``name`` under ``root`` via ``root.create_array``."""
        return root.create_array(name, **kwargs)

    def create_zarr_dataset(group, *args, **kwargs):
        """
        Create an array in ``group`` via ``group.create_array``.

        ``shape``/``dtype`` keyword arguments are dropped when a non-None
        ``data`` is supplied, and every entry of a tuple ``chunks`` is raised
        to at least 1. A ``chunks`` value that is not a tuple is passed
        through unchanged.
        """
        # Zarr v3 uses create_array() instead of create_dataset()
        # When data is provided, shape and dtype are inferred - don't pass them
        if "data" in kwargs and kwargs["data"] is not None:
            kwargs.pop("shape", None)
            kwargs.pop("dtype", None)
        # Zarr v3 requires chunk sizes >= 1
        if "chunks" in kwargs and kwargs["chunks"] is not None:
            chunks = kwargs["chunks"]
            if isinstance(chunks, tuple):
                kwargs["chunks"] = tuple(max(1, c) for c in chunks)
        return group.create_array(*args, **kwargs)

else:  # pragma: no cover
    from zarr import DirectoryStore, ZipStore
    from zarr.codecs import Blosc as BloscCodec

    def LocalStore(path):
        """Stand-in for the ``LocalStore`` name: return a ``DirectoryStore`` for ``path``."""
        return DirectoryStore(path)

    def open_zarr_group(store, mode="a"):
        """
        Return ``zarr.group`` for ``store``.

        Only ``mode == "w"`` is handled specially (``overwrite=True``); any
        other ``mode`` value is not forwarded to ``zarr.group``.
        """
        if mode == "w":
            return zarr.group(store=store, overwrite=True)
        else:
            return zarr.group(store=store)

    def create_zarr_array(root, name, **kwargs):
        """Create array ``name`` under ``root`` via ``root.create``."""
        return root.create(name, **kwargs)

    def create_zarr_dataset(group, *args, **kwargs):
        """Create a dataset in ``group`` via ``group.create_dataset`` (arguments passed through as-is)."""
        return group.create_dataset(*args, **kwargs)


# Module-level side effect at import time: ignore every pandas DtypeWarning
# for the rest of the process.
warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)

F = TypeVar("F", bound=Callable)  # Define a TypeVar for any callable type

# Recognised receptor labels. Presumably B cell, alpha-beta T cell and
# gamma-delta T cell - confirm at the point of use.
RECEPTOR_SET = {"B", "abT", "gdT"}
# Values treated as boolean true / false; string, bool and int spellings are mixed.
TRUES = ["T", "t", "True", "true", "TRUE", True, "1", 1]
FALSES = ["F", "f", "False", "false", "FALSE", False, "0", 0]
# Upper-cased string forms of the above. Entries are not de-duplicated, so
# e.g. "TRUE" and "1" appear more than once.
TRUES_STR = [str(x).upper() for x in TRUES]
FALSES_STR = [str(x).upper() for x in FALSES]
# Locus names split into two groups; judging by the names, heavy/long-chain
# versus light/short-chain loci - confirm at the point of use.
HEAVYLONG = ["IGH", "TRB", "TRD"]
LIGHTSHORT = ["IGK", "IGL", "TRA", "TRG"]
# Column names for V/J calls, plain and "genotyped" variants.
VCALL = "v_call"
JCALL = "j_call"
VCALLG = "v_call_genotyped"
JCALLG = "j_call_genotyped"
# Regex pattern text: a literal "*" followed by exactly two digits.
STRIPALLELENUM = "[*][0-9][0-9]"
# Strain-like identifiers. What "NO_DS" stands for is not established in this
# part of the file - confirm at the point of use.
NO_DS = [
    "129S1_SvImJ",
    "AKR_J",
    "A_J",
    "C3H_HeJ",
    "C57BL_6J",
    "BALB_c_ByJ",
    "CBA_J",
    "DBA_1J",
    "DBA_2J",
    "MRL_MpJ",
    "NOR_LtJ",
    "NZB_BlNJ",
    "SJL_J",
]
# String spellings treated as "empty".
EMPTIES_STR = [
    "nan",
    "NaN",
    "",
    "None",
    "none",
]
# The string spellings plus the None / NaN / pd.NA sentinel objects.
# NOTE(review): with pd.NA as an element, `x in EMPTIES` can raise
# "boolean value of NA is ambiguous" for values that get compared against
# pd.NA - verify how callers test membership.
EMPTIES = EMPTIES_STR + [
    None,
    np.nan,
    pd.NA,
]


# Default file-name prefix.
DEFAULT_PREFIX = "all"
BOOLEAN_LIKE_COLUMNS = ["extra", "ambiguous", "full_length", "complete_vdj"]
# BOOLEAN_LIKE_COLUMNS plus further flag columns. "complete_vdj" ends up in
# this list twice (once via BOOLEAN_LIKE_COLUMNS, once below).
CHECK_COLS = BOOLEAN_LIKE_COLUMNS + [
    "rev_comp",
    "productive",
    "vj_in_frame",
    "stop_codon",
    "complete_vdj",
    "v_frameshift",
    "j_frameshift",
]
# AIRR-style column names (36 entries).
# NOTE(review): AIRR and CELLRANGER have the same length and look
# position-aligned (e.g. "cell_id" <-> "barcode", "sequence_id" <->
# "contig_id"); presumably they are zipped into a rename mapping elsewhere -
# confirm before reordering either list.
AIRR = [
    "cell_id",
    "sequence_id",
    "sequence",
    "sequence_aa",
    "productive",
    "complete_vdj",
    "vj_in_frame",
    "locus",
    "v_call",
    "d_call",
    "j_call",
    "c_call",
    "junction",
    "junction_aa",
    "consensus_count",
    "umi_count",
    "cdr3_start",
    "cdr3_end",
    "sequence_length_10x",
    "high_confidence_10x",
    "is_cell_10x",
    "fwr1_aa",
    "fwr1",
    "cdr1_aa",
    "cdr1",
    "fwr2_aa",
    "fwr2",
    "cdr2_aa",
    "cdr2",
    "fwr3_aa",
    "fwr3",
    "fwr4_aa",
    "fwr4",
    "clone_id",
    "raw_consensus_id_10x",
    "exact_subclonotype_id_10x",
]
# Cell Ranger-style column names (36 entries), in the same positional order as
# AIRR above.
CELLRANGER = [
    "barcode",
    "contig_id",
    "sequence",
    "aa_sequence",
    "productive",
    "full_length",
    "frame",
    "chain",
    "v_gene",
    "d_gene",
    "j_gene",
    "c_gene",
    "cdr3_nt",
    "cdr3",
    "reads",
    "umis",
    "cdr3_start",
    "cdr3_stop",
    "length",
    "high_confidence",
    "is_cell",
    "fwr1",
    "fwr1_nt",
    "cdr1",
    "cdr1_nt",
    "fwr2",
    "fwr2_nt",
    "cdr2",
    "cdr2_nt",
    "fwr3",
    "fwr3_nt",
    "fwr4",
    "fwr4_nt",
    "raw_clonotype_id",
    "raw_consensus_id",
    "exact_subclonotype_id",
]

# Prefixes of the mutation summary columns.
mutations_type = ["mu_count", "mu_freq"]

# Suffixes of the mutation columns, in this fixed order:
#   1. the pooled cdr/fwr totals,
#   2. numbered positions 1..104,
#   3. the individual regions (cdr1, cdr2, fwr1, fwr2, fwr3, v),
# each with an "_r" entry followed by an "_s" entry.
mutationsdef = (
    ["cdr_r", "cdr_s", "fwr_r", "fwr_s"]
    + [f"{pos}_{kind}" for pos in range(1, 105) for kind in ("r", "s")]
    + [
        f"{region}_{kind}"
        for region in ("cdr1", "cdr2", "fwr1", "fwr2", "fwr3", "v")
        for kind in ("r", "s")
    ]
)

# Full list of mutation column names: the two bare prefixes first, then every
# "<prefix>_<suffix>" combination, grouped by prefix.
# The loop variables are deliberately kept as module-level names ``m`` and
# ``d`` so the module namespace stays exactly as before.
MUTATIONS = list(mutations_type)
for m in mutations_type:
    for d in mutationsdef:
        MUTATIONS.append(f"{m}_{d}")
# Names of the length-type columns (junction and np1/np2 lengths).
VDJLENGTHS = [
    "junction_length",
    "junction_aa_length",
    "np1_length",
    "np2_length",
]
# Names of the sequence-bearing columns: full sequences, alignments, junction,
# and per-region (fwr/cdr) sequences with their "_aa" counterparts.
SEQINFO = [
    "sequence",
    "sequence_alignment",
    "sequence_alignment_aa",
    "junction",
    "junction_aa",
    "germline_alignment",
    "fwr1",
    "fwr1_aa",
    "fwr2",
    "fwr2_aa",
    "fwr3",
    "fwr3_aa",
    "fwr4",
    "fwr4_aa",
    "cdr1",
    "cdr1_aa",
    "cdr2",
    "cdr2_aa",
    "cdr3",
    "cdr3_aa",
    "v_sequence_alignment_aa",
    "d_sequence_alignment_aa",
    "j_sequence_alignment_aa",
]


def fasta_iterator(fh: TextIO) -> Iterator[tuple[str, str]]:
    """
    Read in a fasta file as an iterator.

    Parameters
    ----------
    fh : TextIO
        open text-mode file handle (anything with a ``readline()`` method
        returning ``str``).

    Yields
    ------
    tuple[str, str]
        ``(header, sequence)`` for every record, where ``header`` is the
        header line without the leading ``>`` and trailing whitespace, and
        ``sequence`` is the concatenation of the record's lines with trailing
        whitespace removed. A header with no sequence lines yields an empty
        sequence.
    """
    # Skip anything before the first header. `readline()` returns "" at EOF,
    # so an empty or header-less input ends the generator instead of looping
    # forever.
    line = fh.readline()
    while line and not line.startswith(">"):
        line = fh.readline()

    while line:
        # `rstrip()` already removes the newline; slicing with [1:-1] would
        # eat the last character of a final header that lacks one.
        header = line[1:].rstrip()
        parts = []
        line = fh.readline()
        # Collect sequence lines up to the next header or EOF. A header that
        # immediately follows another header starts a new record rather than
        # being absorbed into the previous sequence.
        while line and not line.startswith(">"):
            parts.append(line.rstrip())
            line = fh.readline()
        yield (header, "".join(parts))


class Tree(defaultdict):
    """
    Recursive ``defaultdict``.

    Looking up a missing key creates a new, empty ``Tree`` for it, so nested
    levels can be addressed without being created first. Each node also
    carries an optional payload in ``value``; nodes created on demand get
    ``value=None``.
    """

    def __init__(self, value=None) -> None:
        # Payload attached to this node.
        self.value = value
        # Missing keys are filled in by calling ``Tree()``.
        super(Tree, self).__init__(Tree)

def dict_from_table(meta: pd.DataFrame, columns: tuple[str, str]) -> dict:
    """
    Generate a dictionary from a dataframe.

    Parameters
    ----------
    meta : pd.DataFrame
        pandas data frame or file path
    columns : tuple[str, str]
        column names in data frame

    Returns
    -------
    dict
        dictionary
    """
    if (isinstance(meta, pd.DataFrame)) & (columns is not None):
        meta_ = meta
        if len(columns) == 2:
            sample_dict = dict(zip(meta_[columns[0]], meta_[columns[1]]))
    elif (os.path.isfile(str(meta))) & (columns is not None):
        meta_ = pd.read_csv(meta, sep="\t", dtype="object")
        if len(columns) == 2:
            sample_dict = dict(zip(meta_[columns[0]], meta_[columns[1]]))

    sample_dict = clean_nan_dict(sample_dict)
    return sample_dict


def clean_nan_dict(d: dict) -> dict:
    """
    Remove nan from dictionary.

    Parameters
    ----------
    d : dict
        dictionary

    Returns
    -------
    dict
        dictionary with no NAs.
    """
    return {k: v for k, v in d.items() if v is not np.nan}


def flatten(l: list) -> list:
    """
    Lazily flatten arbitrarily nested iterables.

    Parameters
    ----------
    l : list
        iterable whose items may themselves be iterables, to any depth.

    Yields
    ------
    object
        the leaf items in left-to-right, depth-first order. ``str`` and
        ``bytes`` items are treated as leaves and yielded whole rather than
        being split into characters.
    """
    for item in l:
        # Leaves: text/bytes, or anything that is not iterable.
        if isinstance(item, (str, bytes)) or not isinstance(item, Iterable):
            yield item
            continue
        # Nested container: descend into it.
        yield from flatten(item)


def makeblastdb(ref: Path | str) -> None:
    """
    Run makeblastdb on constant region fasta file.

    Wrapper for makeblastdb.

    Parameters
    ----------
    ref : Path | str
        constant region fasta file.
    """
    # Argument list (no shell), so `ref` is passed as a single argument even
    # if the path contains spaces.
    cmd = ["makeblastdb", "-dbtype", "nucl", "-parse_seqids", "-in", str(ref)]
    # The exit status is not checked: a failing makeblastdb run does not raise
    # here. A missing `makeblastdb` executable does raise (from `run`).
    run(cmd)


def bh(pvalues: np.ndarray) -> np.ndarray:  # pragma: no cover
    """
    Compute the Benjamini-Hochberg FDR correction.

    Parameters
    ----------
    pvalues : np.ndarray
        array of p-values to correct

    Returns
    -------
    np.ndarray
        np.array of corrected p-values, in the same order as the input.
    """
    n = int(pvalues.shape[0])
    new_pvalues = np.empty(n)
    # Pair every p-value with its original position; largest p-value first.
    values = [(pvalue, i) for i, pvalue in enumerate(pvalues)]
    values.sort()
    values.reverse()
    # Scale by n / rank, where rank 1 belongs to the smallest p-value.
    new_values = [
        (n / (n - i)) * pvalue for i, (pvalue, _) in enumerate(values)
    ]
    # Walking from the largest to the smallest p-value, an adjusted value may
    # never exceed the one before it.
    for i in range(n - 1):
        if new_values[i] < new_values[i + 1]:
            new_values[i + 1] = new_values[i]
    # Put the adjusted values back into the input order.
    for i, (_, index) in enumerate(values):
        new_pvalues[index] = new_values[i]
    return new_pvalues


def is_categorical(array_like: pd.Series) -> bool:
    """Check if a pandas Series has categorical dtype.

    Parameters
    ----------
    array_like : pd.Series
        Series to check.

    Returns
    -------
    bool
        True if the Series dtype is ``"category"``.
    """
    return array_like.dtype.name == "category"


def type_check(dataframe: pd.DataFrame, key: str) -> bool:
    """Check if a DataFrame column is string-like, categorical, or boolean.

    Parameters
    ----------
    dataframe : pd.DataFrame
        DataFrame containing the column.
    key : str
        Column name to check.

    Returns
    -------
    bool
        True if the column dtype is str, object, categorical, or bool.
    """
    return (
        dataframe[key].dtype == str
        or dataframe[key].dtype == object
        or is_categorical(dataframe[key])
        or dataframe[key].dtype == bool
    )


def _first_contig_file(folder: Path, prefix: str, suffix: str) -> Path | None:
    """Return the first non-hidden file in ``folder`` named ``<prefix>_contig...<suffix>``."""
    for file in folder.iterdir():
        if file.name[0] == ".":
            continue
        if (
            file.is_file()
            and str(file).endswith(suffix)
            and file.name.startswith(prefix + "_contig")
        ):
            return file
    return None


def check_filepath(
    file_or_folder_path: Path | str,
    filename_prefix: str | None = None,
    ends_with: str | None = None,
    sub_dir: str | None = None,
    within_dandelion: bool = True,
) -> Path | None:
    """
    Checks whether file path exists.

    Parameters
    ----------
    file_or_folder_path : Path | str
        either a string or Path object pointing to a file or folder.
    filename_prefix : str | None, optional
        the prefix of the filename. ``None`` uses ``DEFAULT_PREFIX``.
    ends_with : str | None, optional
        the suffix of the filename. Can be flexible i.e. not just the extension.
    sub_dir : str | None, optional
        the subdirectory to look for the file if specified
    within_dandelion : bool, optional
        whether to look for the file within a 'dandelion' sub folder.

    Returns
    -------
    Path | None
        Path object if file is found, else None. For a folder input, the file
        must be named ``<filename_prefix>_contig...<ends_with>``.
    """
    filename_pre = (
        DEFAULT_PREFIX if filename_prefix is None else filename_prefix
    )
    ends_with = "" if ends_with is None else ends_with
    # `expanduser` only touches paths that start with "~", so it can be applied
    # unconditionally (indexing the first character failed on an empty string).
    input_path = Path(str(file_or_folder_path)).expanduser()

    if input_path.is_file() and str(input_path).endswith(ends_with):
        return input_path
    if not input_path.is_dir():
        return None

    if within_dandelion:
        for child in input_path.iterdir():
            if child.is_dir() and child.name == "dandelion":
                out_dir = child if sub_dir is None else child / sub_dir
                # Only one entry can carry this name, so the result is final.
                return _first_contig_file(out_dir, filename_pre, ends_with)
        return None

    if sub_dir is not None:
        input_path = input_path / sub_dir
    return _first_contig_file(input_path, filename_pre, ends_with)


def cmp_to_key(mycmp):
    """Convert a cmp= function into a key= function."""

    class K:
        """Key class wrapping an object so comparisons are delegated to ``mycmp``."""

        def __init__(self, obj, *args) -> None:
            self.obj = obj

        def __lt__(self, other) -> bool:
            """Less than."""
            return mycmp(self.obj, other.obj) < 0

        def __gt__(self, other) -> bool:
            """Greater than."""
            return mycmp(self.obj, other.obj) > 0  # pragma: no cover

        def __eq__(self, other) -> bool:
            """Equal."""
            return mycmp(self.obj, other.obj) == 0  # pragma: no cover

        def __le__(self, other) -> bool:
            """Less than or equal."""
            return mycmp(self.obj, other.obj) <= 0  # pragma: no cover

        def __ge__(self, other) -> bool:
            """Greater than or equal."""
            return mycmp(self.obj, other.obj) >= 0  # pragma: no cover

        def __ne__(self, other) -> bool:
            """Not equal."""
            return mycmp(self.obj, other.obj) != 0  # pragma: no cover

    return K


def not_same_call(a: str, b: str, pattern: str) -> bool:
    """Check if exactly one of ``a`` or ``b`` matches ``pattern``.

    Parameters
    ----------
    a : str
        First string.
    b : str
        Second string.
    pattern : str
        Regex pattern to match against.

    Returns
    -------
    bool
        True if exactly one of ``a`` or ``b`` matches ``pattern``.
    """
    # Exclusive-or on the two match results; always a real bool (the previous
    # and/or chain returned None when neither string matched).
    return bool(re.search(pattern, a)) != bool(re.search(pattern, b))


def same_call(a: str, b: str, c: str, pattern: str) -> bool:
    """Check if all non-null values among ``a``, ``b``, ``c`` match ``pattern``.

    Parameters
    ----------
    a : str
        First string.
    b : str
        Second string.
    c : str
        Third string.
    pattern : str
        Regex pattern to match against.

    Returns
    -------
    bool
        True if all non-null values match ``pattern``.
    """
    queries = [q for q in [a, b, c] if pd.notnull(q)]
    return all([re.search(pattern, x) for x in queries])


def present(x: str | None) -> bool:
    """Check if ``x`` is not null or a blank/missing sentinel string.

    Parameters
    ----------
    x : str | None
        Value to check.

    Returns
    -------
    bool
        True if ``x`` is not null and not one of the blank sentinel values
        (``""``, ``"None"``, ``"none"``, ``"NA"``, ``"na"``, ``"NaN"``, ``"nan"``).
    """
    return pd.notnull(x) and x not in [
        "",
        "None",
        "none",
        "NA",
        "na",
        "NaN",
        "nan",
    ]


def check_missing(x: str | None) -> bool:
    """Check if ``x`` is null or an empty string.

    Parameters
    ----------
    x : str | None
        Value to check.

    Returns
    -------
    bool
        True if ``x`` is null or ``""``.
    """
    return pd.isnull(x) or x == ""


def all_missing(x: pd.Series | np.ndarray) -> bool:
    """Check if all elements in ``x`` are null or empty strings.

    Parameters
    ----------
    x : pd.Series | np.ndarray
        Values to check. Must support element-wise ``==`` comparison, e.g. a
        pandas Series or NumPy array.

    Returns
    -------
    bool
        True if all values are null or ``""``.
    """
    return all(pd.isnull(x)) or all(x == "")


def all_missing2(x: pd.Series | np.ndarray) -> bool:
    """Check if all elements in ``x`` are null, empty strings, or the string ``"None"``.

    Parameters
    ----------
    x : pd.Series | np.ndarray
        Values to check. Must support ``len`` and element-wise ``==``
        comparison, e.g. a pandas Series or NumPy array.

    Returns
    -------
    bool
        False if ``x`` is empty; True if all values are null, ``""``, or ``"None"``.
    """
    if len(x) == 0:
        return False
    return all(pd.isnull(x)) or all(x == "") or all(x == "None")


def get_numpy_dtype(series: pd.Series) -> str:
    """
    Map a Pandas dtype to an appropriate NumPy dtype.

    Parameters
    ----------
    series : pd.Series
        The Pandas Series.

    Returns
    -------
    str
        A string representing the NumPy dtype corresponding to the Pandas dtype.

    Raises
    ------
    TypeError
        If the Pandas dtype is unsupported.
    """
    if pd.api.types.is_integer_dtype(series):
        return "i4"  # 32-bit integer
    elif pd.api.types.is_float_dtype(series):
        return "f8"  # 64-bit float
    elif pd.api.types.is_bool_dtype(series):
        return "i1"  # 8-bit integer for booleans (True/False)
    elif pd.api.types.is_string_dtype(series) or pd.api.types.is_object_dtype(
        series
    ):
        # Handle object or string columns; dynamically calculate the max string length
        max_length = series.astype(str).map(len).max()
        return "S{}".format(max(1, max_length))  # String with max length
    else:
        # Report the dtype itself; the column name alone did not say what was wrong.
        raise TypeError(
            f"Unsupported data type {series.dtype} for column {series.name!r}"
        )  # pragma: no cover


def sanitize_data_for_saving(
    data: pd.DataFrame,
) -> tuple[pd.DataFrame, list[tuple[str, object]]]:
    """
    Quick sanitize dtypes for saving.

    Parameters
    ----------
    data : pd.DataFrame
        Input dataframe.

    Returns
    -------
    tuple[pd.DataFrame, list[tuple[str, object]]]
        Sanitized copy of the DataFrame and the matching list of
        ``(column, dtype)`` pairs describing a NumPy structured dtype.
    """
    tmp = data.copy()
    dtype_dict = {}
    for col in tmp:
        if col in RearrangementSchema.properties:
            dtype = RearrangementSchema.properties[col]["type"]
            tmp[col] = sanitize_column(tmp[col], dtype)
        elif col in BOOLEAN_LIKE_COLUMNS:
            dtype = "boolean"
            tmp[col] = sanitize_column(tmp[col], dtype)
        else:
            tmp[col] = try_numeric_conversion(tmp[col])
        dtype_dict[col] = get_numpy_dtype(tmp[col])
        # Ensure string columns use a variable-length UTF-8 dtype rather than
        # the fixed-width ASCII byte dtype chosen by get_numpy_dtype.
        if tmp[col].dtype == object or tmp[col].dtype == "string":
            dtype_dict[col] = h5py.string_dtype(encoding="utf-8")
    dtypes = list(dtype_dict.items())
    return tmp, dtypes


def sanitize_boolean(value: str | bool) -> str:
    """
    Sanitize a boolean-like value to 'T' or 'F'.

    Parameters
    ----------
    value : str | bool
        The value to sanitize.

    Returns
    -------
    str
        'T' for True, 'F' for False, or the original value if not boolean-like.
    """
    if isinstance(value, bool):
        return "T" if value else "F"
    elif isinstance(value, str):
        stripped_value = value.strip().lower()
        if stripped_value in ["true", "t"]:
            return "T"
        elif stripped_value in ["false", "f"]:
            return "F"
    elif isinstance(value, (int, float)):
        if value == 1:
            return "T"
        elif value == 0:
            return "F"
    return value


def sanitize_column(series: pd.Series, dtype: str) -> pd.Series:
    """
    Sanitize a column based on the specified dtype.

    Note that this sets the pandas option ``future.no_silent_downcasting``
    globally on every call.

    Parameters
    ----------
    series : pd.Series
        The column to be sanitized.
    dtype : str
        The expected data type of the column (`string`, `boolean`, `integer`, or `number`).

    Returns
    -------
    pd.Series
        The sanitized column with replaced values and appropriate data type.
        Any other ``dtype`` returns the column unchanged.
    """
    pd.set_option("future.no_silent_downcasting", True)
    if dtype == "boolean":
        series = series.apply(lambda x: "" if check_missing(x) else x)
        series = series.replace([None, np.nan, "nan", "na", "NaN", ""], "")
        return series.apply(sanitize_boolean)
    elif dtype == "string":
        series = series.apply(lambda x: "" if check_missing(x) else x)
        return (
            series.replace([None, np.nan, "nan", "na", "NaN", ""], "")
            .astype(str)
            .apply(clean_unicode)
        )
    elif dtype in ["number"]:
        series = series.apply(lambda x: np.nan if check_missing(x) else x)
        series = series.replace([None, np.nan, "nan", "na", "NaN", ""], np.nan)
        # for dtype to be float
        return series.astype("float64").fillna(np.nan)
    elif dtype in ["integer"]:
        # Integers are stored as strings so that missing entries can stay "".
        series = series.apply(lambda x: "" if check_missing(x) else int(x))
        series = series.replace([None, np.nan, "nan", "na", "NaN", ""], "")
        return series.astype(str).apply(clean_unicode)
    return series


def clean_unicode(x: str) -> str:
    """Normalize and ensure valid UTF-8 text; non-string input becomes ``""``."""
    if not isinstance(x, str):
        return ""
    # Normalize to NFKC form (handles Greek/Unicode nicely)
    x = unicodedata.normalize("NFKC", x)
    # Remove invalid or unencodable characters safely
    return x.encode("utf-8", "ignore").decode("utf-8")


def try_numeric_conversion(series: pd.Series) -> pd.Series:
    """
    Attempt to convert a column to numeric, or fallback to treating it as a string.

    Parameters
    ----------
    series : pd.Series
        The column to be converted.

    Returns
    -------
    pd.Series
        The column converted to numeric if possible, or sanitized as a string if not.
        Columns holding any ``"|"``-delimited string are always kept as strings.
    """
    if series.dtype.name == "category":
        series = sanitize_column(series, "string")
    if series.apply(lambda x: isinstance(x, str) and "|" in x).any():
        return sanitize_column(series, "string")
    try:
        return pd.to_numeric(series)
    except Exception:
        # Not numeric: keep as text. `Exception` (rather than a bare except)
        # lets KeyboardInterrupt/SystemExit propagate.
        return sanitize_column(series, "string")


def sanitize_data(data: pd.DataFrame, ignore: str = "clone_id") -> pd.DataFrame:
    """Quick sanitize dtypes.

    Parameters
    ----------
    data : pd.DataFrame
        AIRR-format table to sanitize. The input is not modified; a converted
        copy is returned.
    ignore : str, optional
        Name of a non-AIRR column that is left untouched by the numeric
        conversion.

    Returns
    -------
    pd.DataFrame
        Sanitized table, sorted by ``cell_id``/``productive``/``umi_count``
        when those columns exist, and validated with ``validate_airr``.
    """
    data = data.astype("object")
    data = data.infer_objects()
    for d in data:
        if d in BOOLEAN_LIKE_COLUMNS:
            data[d] = data[d].apply(sanitize_boolean)
        if d in RearrangementSchema.properties:
            if RearrangementSchema.properties[d]["type"] in [
                "string",
                "boolean",
                "integer",
            ]:
                data[d] = data[d].replace(
                    EMPTIES,
                    "",
                )
                if RearrangementSchema.properties[d]["type"] == "integer":
                    data[d] = [
                        int(x) if present(x) else ""
                        for x in pd.to_numeric(data[d])
                    ]
                if RearrangementSchema.properties[d]["type"] == "boolean":
                    data[d] = data[d].apply(sanitize_boolean)
            else:
                data[d] = data[d].replace(
                    EMPTIES,
                    np.nan,
                )
        else:
            if d != ignore:
                try:
                    data[d] = pd.to_numeric(data[d])
                except Exception:
                    data[d] = data[d].replace(
                        to_replace=EMPTIES,
                        value="",
                    )
        # NOTE(review): the nesting of the two mutation-column checks below was
        # reconstructed from whitespace-collapsed source; confirm that they
        # belong at loop level rather than inside the except branch above.
        if re.search("mu_freq", d):
            data[d] = [
                float(x) if present(x) else np.nan
                for x in pd.to_numeric(data[d])
            ]
        if re.search("mu_count", d):
            data[d] = [
                int(x) if present(x) else ""
                for x in pd.to_numeric(data[d])
            ]
    if (
        pd.Series(["cell_id", "umi_count", "productive"])
        .isin(data.columns)
        .all()
    ):
        # sort so that the productive contig with the largest umi is first
        data.sort_values(
            by=["cell_id", "productive", "umi_count"],
            inplace=True,
            ascending=[True, False, False],
        )
    # check if airr-standards is happy
    validate_airr(data)
    return data


def sanitize_blastn(data: pd.DataFrame) -> pd.DataFrame:
    """Sanitize dtypes in a blastn output DataFrame to AIRR schema types.

    Parameters
    ----------
    data : pd.DataFrame
        DataFrame to sanitize. The input is not modified; a converted copy is
        returned.

    Returns
    -------
    pd.DataFrame
        Sanitized DataFrame with corrected dtypes.
    """
    data = data.astype("object")
    data = data.infer_objects()
    for d in data:
        if d in RearrangementSchema.properties:
            if RearrangementSchema.properties[d]["type"] in [
                "string",
                "boolean",
                "integer",
            ]:
                data[d] = data[d].replace(
                    EMPTIES,
                    "",
                )
                if RearrangementSchema.properties[d]["type"] == "integer":
                    data[d] = [
                        int(x) if present(x) else ""
                        for x in pd.to_numeric(data[d])
                    ]
            else:
                data[d] = data[d].replace(
                    EMPTIES,
                    np.nan,
                )
        else:
            try:
                data[d] = pd.to_numeric(data[d])
            except Exception:
                data[d] = data[d].replace(
                    to_replace=EMPTIES,
                    value="",
                )
    return data


def validate_airr(data: pd.DataFrame) -> None:
    """Validate dtypes in an AIRR table against the AIRR schema.

    Every row is converted to a ``Contig`` dictionary, a fixed set of fields
    is filled with ``""`` when absent, and the header and row are passed to
    ``RearrangementSchema.validate_header`` / ``validate_row``.

    Parameters
    ----------
    data : pd.DataFrame
        DataFrame containing AIRR-format contig data.
    """
    required_fields = [
        "sequence",
        "rev_comp",
        "sequence_alignment",
        "germline_alignment",
        "v_cigar",
        "d_cigar",
        "j_cigar",
    ]
    # A former per-column "Int64" trial conversion was dropped: its result was
    # never used and it silently swallowed every error.
    for _, row in data.iterrows():
        contig = Contig(row).contig
        for required in required_fields:
            if required not in contig:
                contig.update({required: ""})
        RearrangementSchema.validate_header(contig.keys())
        RearrangementSchema.validate_row(contig)


class ContigDict(dict):
    """Class Object to extract the contigs as a dictionary."""

    def __setitem__(self, key: str, value: str) -> None:
        """Standard __setitem__."""
        super().__setitem__(key, value)

    def __hash__(self) -> int:
        """Make it hashable (the hash is computed from the keys only)."""
        return hash(tuple(self))


class Contig:
    """Class Object to hold contig."""

    def __init__(self, contig: dict, mapper: dict | None = None) -> None:
        """
        Parameters
        ----------
        contig : dict
            Dictionary of contig fields (typically a row from an AIRR DataFrame).
        mapper : dict | None, optional
            Optional column-name mapping to rename keys before storing. Any keys
            not in ``mapper`` are kept under their original names. The mapping
            itself is left unmodified.
        """
        if mapper is not None:
            # Look up with a fallback instead of writing identity entries into
            # the caller's mapping.
            self._contig = ContigDict(
                {mapper.get(key, key): vals for (key, vals) in contig.items()}
            )
        else:
            self._contig = ContigDict(contig)
        # Represent missing (NaN) fields as empty strings.
        for key, value in self._contig.items():
            if isinstance(value, float) and np.isnan(value):
                self._contig[key] = ""

    @property
    def contig(self) -> ContigDict:
        """Contig slot."""
        return self._contig


def deprecated(
    details: str, deprecated_in: str, removed_in: str
) -> Callable[[F], F]:
    """Decorator to mark a function as deprecated.

    Parameters
    ----------
    details : str
        Message describing what to use instead.
    deprecated_in : str
        Version in which the function was deprecated.
    removed_in : str
        Version in which the function will be removed.

    Returns
    -------
    Callable
        Wrapped function that emits a ``DeprecationWarning`` when called.
    """
    from functools import wraps

    def deprecated_decorator(func: F) -> F:
        """Deprecate decorator"""

        # `wraps` keeps the wrapped function's name and docstring.
        @wraps(func)
        def deprecated_func(*args, **kwargs):
            """Deprecate function"""
            warnings.warn(
                "{} is a deprecated in {} and will be removed in {}."
                " {}".format(func.__name__, deprecated_in, removed_in, details),
                category=DeprecationWarning,
                stacklevel=2,
            )
            return func(*args, **kwargs)

        return deprecated_func

    return deprecated_decorator


def format_isotype1(metadata: pd.DataFrame) -> list[str]:
    """Format isotype column, collapsing IgM/IgD co-expression and marking multiples.

    Parameters
    ----------
    metadata : pd.DataFrame
        Metadata DataFrame containing an ``isotype`` column.

    Returns
    -------
    list[str]
        List of formatted isotype strings; co-expressed IgM+IgD becomes
        ``"IgM/IgD"`` and any other multi-isotype entry becomes ``"Multi"``.
    """
    isotype_status = [
        (
            None
            if i is None
            else (
                "IgM/IgD"
                if (i == "IgM|IgD") or (i == "IgD|IgM")
                else "Multi" if "|" in i else i
            )
        )
        for i in metadata["isotype"]
    ]
    return isotype_status


def format_isotype2(metadata: pd.DataFrame) -> list[str]:
    """Format isotype_status, allowing IgM/IgD for exception chain statuses.

    Parameters
    ----------
    metadata : pd.DataFrame
        Metadata DataFrame containing ``isotype_status`` and ``chain_status``
        columns.

    Returns
    -------
    list[str]
        List of formatted isotype-status strings. Cells with an exception
        chain status keep their original ``isotype_status``; cells with
        ``"Extra pair"`` status are relabelled ``"Multi"``.
    """
    isotype_status = [
        (
            x
            if y is None or "exception" in y
            else ("Multi" if y == "Extra pair" else x)
        )
        for x, y in zip(metadata["isotype_status"], metadata["chain_status"])
    ]
    return isotype_status


def _vj_locus_label(short_loci: list[str], full_loci: list[str]) -> str:
    """Collapse the VJ loci of one cell (2-letter prefixes and full names) into one label."""
    if len(short_loci) == 0:
        return "None"
    if len(set(short_loci)) > 1:
        return "ambiguous"
    if len(short_loci) > 1:
        if all(x in ["TRA", "TRG"] for x in full_loci) and (
            len(set(full_loci)) == 2
        ):
            return "Extra VJ-exception"
        return "Extra VJ"
    return full_loci[0]


def format_locus(
    metadata: pd.DataFrame,
    vcall: str,
    suffix_vdj: str = "_VDJ",
    suffix_vj: str = "_VJ",
    productive_only: bool = True,
) -> pd.Series:
    """Extract locus call value from data.

    Parameters
    ----------
    metadata : pd.DataFrame
        Per-cell metadata DataFrame.
    vcall : str
        Base name of the V-call column (e.g. ``"v_call"`` or
        ``"v_call_genotyped"``).
    suffix_vdj : str, optional
        Suffix appended to locus/productive/call columns for the VDJ chain.
    suffix_vj : str, optional
        Suffix appended to locus/productive/call columns for the VJ chain.
    productive_only : bool, optional
        If True, only consider productive chains when determining the locus
        call.

    Returns
    -------
    pd.Series
        Series of locus call strings indexed by cell barcode.
    """

    def _pipe_split(value) -> list[str]:
        # Non-string entries (e.g. NaN) are treated as "no chains".
        if isinstance(value, str):
            return value.split("|")
        return []

    locus_1 = dict(metadata["locus" + suffix_vdj])
    locus_2 = dict(metadata["locus" + suffix_vj])
    constant_1 = dict(metadata["isotype_status"])
    prod_1 = dict(metadata["productive" + suffix_vdj])
    prod_2 = dict(metadata["productive" + suffix_vj])
    # also extract the v/d/j calls
    v_call_1 = dict(metadata[vcall + suffix_vdj])
    j_call_1 = dict(metadata["j_call" + suffix_vdj])
    d_call_1 = dict(metadata["d_call" + suffix_vdj])

    locus_dict = {}
    for i in metadata.index:
        locus1_split = _pipe_split(locus_1[i])
        locus2_split = _pipe_split(locus_2[i])
        prod1_split = _pipe_split(prod_1[i])
        prod2_split = _pipe_split(prod_2[i])

        if productive_only:
            loc1xx = [
                ll for ll, p in zip(locus1_split, prod1_split) if p in TRUES
            ]
            loc2xx = [
                ll for ll, p in zip(locus2_split, prod2_split) if p in TRUES
            ]
        else:
            loc1xx = list(locus1_split)
            loc2xx = list(locus2_split)

        # 2-letter locus prefixes; a chain list made up only of "None"
        # placeholders counts as having no chains.
        loc1x = (
            []
            if all(px == "None" for px in loc1xx)
            else [ij[:2] for ij in loc1xx]
        )
        loc2x = (
            []
            if all(px == "None" for px in loc2xx)
            else [ij[:2] for ij in loc2xx]
        )

        # The VJ label is derived the same way whatever the VDJ outcome is.
        tmp2 = _vj_locus_label(loc2x, loc2xx)

        if len(loc1x) == 0:
            tmp1 = "None"
        elif len(set(loc1x)) > 1:  # pragma: no cover
            tmp1 = "ambiguous"
        else:
            if len(loc1x) > 1:
                if constant_1[i] == "IgM/IgD":
                    # for BCR e.g. IgM/IgD, also check that the v/d/j calls are the same
                    v1 = _pipe_split(v_call_1[i])
                    d1 = _pipe_split(d_call_1[i])
                    j1 = _pipe_split(j_call_1[i])
                    if productive_only:
                        v1 = [
                            vv
                            for vv, pp in zip(v1, prod1_split)
                            if pp in TRUES
                        ]
                        d1 = [
                            dd
                            for dd, pp in zip(d1, prod1_split)
                            if pp in TRUES
                        ]
                        j1 = [
                            jj
                            for jj, pp in zip(j1, prod1_split)
                            if pp in TRUES
                        ]
                    same_vdj = (
                        len(v1) == 2
                        and len(d1) == 2
                        and len(j1) == 2
                        and v1[0] == v1[1]
                        and d1[0] == d1[1]
                        and j1[0] == j1[1]
                    )
                    tmp1 = "Extra VDJ-exception" if same_vdj else "Extra VDJ"
                elif (all(x in ["TRB", "TRD"] for x in loc1xx)) and (
                    len(set(loc1xx)) == 2
                ):
                    tmp1 = "Extra VDJ-exception"
                else:
                    tmp1 = "Extra VDJ"
            else:
                tmp1 = loc1xx[0]

            # A VDJ chain paired with a VJ chain of a different locus family
            # makes the whole cell ambiguous.
            if (
                tmp1 not in ["None", "Extra VDJ", "Extra VDJ-exception"]
            ) and (tmp2 not in ["None", "Extra VJ", "Extra VJ-exception"]):
                if list(set(loc1x)) != list(set(loc2x)):
                    tmp1 = "ambiguous"
                    tmp2 = "ambiguous"

        if any(tmp == "ambiguous" for tmp in [tmp1, tmp2]):
            locus_dict[i] = "ambiguous"
        else:
            locus_dict[i] = tmp1 + " + " + tmp2
        # Later rules deliberately override the label set above.
        if any(tmp == "None" for tmp in [tmp1, tmp2]):
            if tmp1 == "None":
                locus_dict[i] = "Orphan " + tmp2
            elif tmp2 == "None":
                locus_dict[i] = "Orphan " + tmp1
        if any(re.search("No_contig", tmp) for tmp in [tmp1, tmp2]):
            locus_dict[i] = "No_contig"
    result = pd.Series(locus_dict)
    return result


def lib_type(lib: str):
    """Dictionary of acceptable loci for library type.

    ``lib`` must be one of ``"tr-ab"``, ``"tr-gd"`` or ``"ig"``; any other
    value raises ``KeyError``.
    """
    librarydict = {
        "tr-ab": ["TRA", "TRB"],
        "tr-gd": ["TRG", "TRD"],
        "ig": ["IGH", "IGK", "IGL"],
    }
    return librarydict[lib]


def movecol(
    df: pd.DataFrame,
    cols_to_move: list = [],
    ref_col: str = "",
) -> pd.DataFrame:
    """Reorder DataFrame columns, inserting ``cols_to_move`` after ``ref_col``.

    Parameters
    ----------
    df : pd.DataFrame
        Input DataFrame.
    cols_to_move : list, optional
        Columns to reposition after ``ref_col``. The list is only read, never
        modified.
    ref_col : str, optional
        Column after which ``cols_to_move`` will be inserted. Must be an
        existing column, otherwise ``ValueError`` is raised.

    Returns
    -------
    pd.DataFrame
        DataFrame with reordered columns.
    """
    # https://towardsdatascience.com/reordering-pandas-dataframe-columns-thumbs-down-on-standard-solutions-1ff0bc2941d5
    cols = df.columns.tolist()
    seg1 = cols[: list(cols).index(ref_col) + 1]
    seg2 = cols_to_move
    seg1 = [i for i in seg1 if i not in seg2]
    seg3 = [i for i in cols if i not in seg1 + seg2]
    return df[seg1 + seg2 + seg3]


def format_chain_status(locus_status):
    """Format chain status labels from per-cell locus-status strings.

    Parameters
    ----------
    locus_status : iterable of str
        Iterable of locus-status strings, one per cell.

    Returns
    -------
    list[str]
        List of chain-status labels such as ``"Single pair"``,
        ``"Extra pair"``, ``"Orphan VDJ"``, ``"Orphan VJ"``,
        ``"Orphan VDJ-exception"``, etc.
    """
    chain_status = []
    for ls in locus_status:
        if ("Orphan" in ls) and (re.search("TRB|IGH|TRD|VDJ", ls)):
            if not re.search("exception", ls):
                if re.search("Extra", ls):
                    chain_status.append("Orphan Extra VDJ")
                else:
                    chain_status.append("Orphan VDJ")
            else:
                chain_status.append("Orphan VDJ-exception")
        elif ("Orphan" in ls) and (re.search("TRA|TRG|IGK|IGL|VJ", ls)):
            if not re.search("exception", ls):
                if re.search("Extra", ls):
                    chain_status.append("Orphan Extra VJ")
                else:
                    chain_status.append("Orphan VJ")
            else:
                chain_status.append("Orphan VJ-exception")
        elif re.search("exception|IgM/IgD", ls):
            chain_status.append("Extra pair-exception")
        elif re.search("Extra", ls):
            chain_status.append("Extra pair")
        elif re.search("ambiguous|None", ls):
            chain_status.append("ambiguous")
        else:
            chain_status.append("Single pair")
    return chain_status


def set_germline_env(
    germline: str | None = None,
    org: Literal["human", "mouse"] = "human",
    input_file: Path | str | None = None,
    db: Literal["imgt", "ogrdb"] = "imgt",
) -> tuple[dict[str, str], Path, Path]:
    """
    Set the paths to germline database and environment variables and relevant input files.

    Parameters
    ----------
    germline : str | None, optional
        path to germline database. None defaults to environmental variable $GERMLINE.
    org : Literal["human", "mouse"], optional
        organism for germline sequences. Only used when ``germline`` is None.
    input_file : Path | str | None, optional
        path to input file.
    db : Literal["imgt", "ogrdb"], optional
        database to use. Defaults to imgt. Only used when ``germline`` is None.

    Returns
    -------
    tuple[dict[str, str], Path, Path]
        environment dictionary, path to germline database and the input file
        as a Path (None if no input file was given).

    Raises
    ------
    KeyError
        if $GERMLINE environmental variable is not set.
    """
    env = os.environ.copy()
    if germline is None:
        try:
            gml = Path(env["GERMLINE"])
        except KeyError:
            raise KeyError(
                "Environmental variable $GERMLINE is missing. "
                "Please 'export GERMLINE=/path/to/database/germlines/'"
            )
        # $GERMLINE is the database root; descend to the requested subset.
        gml = gml / db / org / "vdj"
    else:
        # An explicit path is used as given and exported through the env copy.
        gml = env["GERMLINE"] = Path(germline)
    if input_file is not None:
        input_file = Path(input_file)
    return env, gml, input_file


def set_igblast_env(
    igblast_db: Path | str | None = None,
    input_file: Path | str | None = None,
) -> tuple[dict[str, str], Path, Path]:
    """
    Set the igblast database and environment variables and relevant input files.

    Parameters
    ----------
    igblast_db : Path | str | None, optional
        path to igblast database. None defaults to environmental variable $IGDATA.
    input_file : Path | str | None, optional
        path to input file.

    Returns
    -------
    tuple[dict[str, str], Path, Path]
        environment dictionary, path to igblast database and the input file
        as a Path (None if no input file was given).

    Raises
    ------
    KeyError
        if $IGDATA environmental variable is not set.
    """
    env = os.environ.copy()
    if igblast_db is None:
        try:
            igdb = Path(env["IGDATA"])
        except KeyError:
            raise KeyError(
                "Environmental variable $IGDATA is missing. "
                "Please 'export IGDATA=/path/to/database/igblast/'"
            )
    else:
        igdb = env["IGDATA"] = Path(igblast_db)
    if input_file is not None:
        input_file = Path(input_file)
    return env, igdb, input_file


def set_blast_env(
    blast_db: str | None = None,
    input_file: Path | str | None = None,
) -> tuple[dict[str, str], Path, Path]:
    """
    Set the blast database and environment variables and relevant input files.

    Parameters
    ----------
    blast_db : str | None, optional
        path to blast database. None defaults to environmental variable $BLASTDB.
    input_file : Path | str | None, optional
        path to input file.

    Returns
    -------
    tuple[dict[str, str], Path, Path]
        environment dictionary, path to blast database and the input file as
        a Path (None if no input file was given).

    Raises
    ------
    KeyError
        if $BLASTDB environmental variable is not set.
    """
    env = os.environ.copy()
    if blast_db is None:
        try:
            bdb = Path(env["BLASTDB"])
        except KeyError:
            raise KeyError(
                "Environmental variable $BLASTDB is missing. "
                "Please 'export BLASTDB=/path/to/database/blast/'"
            )
    else:
        bdb = env["BLASTDB"] = Path(blast_db)
    if input_file is not None:
        input_file = Path(input_file)
    return env, bdb, input_file


def check_data(
    data: list[Path | str] | Path | str, filename_prefix: list[str] | str | None
) -> tuple[list[str], list[str]]:
    """Normalise ``data`` and ``filename_prefix`` to matching-length lists.

    Parameters
    ----------
    data : list[Path | str] | Path | str
        One or more paths to data folders or files.
    filename_prefix : list[str] | str | None
        One or more filename prefixes preceding ``'_contig'``. If a single
        value is given and ``data`` has multiple entries, it is broadcast.

    Returns
    -------
    tuple[list[str], list[str]]
        ``(data, filename_prefix)`` both as lists of equal length.
    """
    if type(data) is not list:
        data = [data]
    if not isinstance(filename_prefix, list):
        filename_prefix = [filename_prefix]
    # Broadcast a single prefix over all data entries.
    if len(filename_prefix) == 1 and len(data) > 1:
        filename_prefix = filename_prefix * len(data)
    if all(t is None for t in filename_prefix):
        filename_prefix = [None for _ in data]
    return data, filename_prefix


def check_same_celltype(clone_def1: str, clone_def2: str) -> bool:
    """Check whether two clone definition strings share the same cell-type prefix.

    Parameters
    ----------
    clone_def1 : str
        First clone definition key (e.g. ``"B_clone_id"``).
    clone_def2 : str
        Second clone definition key.

    Returns
    -------
    bool
        True if the portion before the first ``"_"`` is identical in both
        strings.
    """
    return clone_def1.split("_", 1)[0] == clone_def2.split("_", 1)[0]


def clear_h5file(filename: Path | str) -> None:
    """Clear all datasets from an existing HDF5 file.

    Parameters
    ----------
    filename : Path | str
        Path to the HDF5 file to clear.
    """
    # Mode "w" creates the file or truncates an existing one, so the file is
    # already empty once it is open; deleting datasets one by one is not needed.
    with h5py.File(filename, "w"):
        pass


def get_vcall_key(data: dict, v_call_key: str) -> str:
    """
    Determine which V-call key to use based on the provided data and key.

    Parameters
    ----------
    data : dict
        The data dictionary containing possible keys.
    v_call_key : str
        The requested key to check (e.g. "v_call" or "v_call_genotyped").

    Returns
    -------
    str
        ``v_call_key`` if it exists in ``data``, otherwise ``"v_call"`` as a
        default fallback.
    """
    # The former dedicated branches for "v_call_genotyped" and "v_call" were
    # special cases of this same membership test.
    return v_call_key if v_call_key in data else "v_call"


def write_fasta(
    fasta_dict: dict[str, str], out_fasta: Path | str, overwrite=True
) -> None:
    """
    Generic fasta writer.

    Parameters
    ----------
    fasta_dict : dict[str, str]
        dictionary containing fasta headers and sequences as keys and records respectively.
    out_fasta : Path | str
        path to write fasta file to.
    overwrite : bool, optional
        whether or not to overwrite the output file (out_fasta). If False,
        records are appended to an existing file.
    """
    if overwrite:
        # Truncate (or create) the output file before appending records.
        with open(out_fasta, "w"):
            pass
    # Build the whole payload once instead of reopening the file per record.
    records = "".join(
        ">" + header + "\n" + sequence + "\n"
        for header, sequence in fasta_dict.items()
    )
    if fasta_dict:
        _write_output(records, out_fasta)


def _write_output(out: str, file: Path | str) -> None:
    """General line writer; appends ``out`` to ``file``."""
    with open(file, "a") as fh:
        fh.write(out)