# Source code for descope.dataset

import torch
import random
import pickle
import logging
import datasets
import numpy as np
import pandas as pd
import scanpy as sc

from typing import Union, Optional
from abc import ABC, abstractmethod
from torch.utils.data import Dataset
from .tokenizer import _filter_perturbations
from .utils import (
    load_gene_embs,
    preprocess_atac_perturbation_adata_consistent_with_epiagent,
    preprocess_rna_perturbation_adata
)

logger = logging.getLogger(__name__)


# Slow | Deprecated
class BaseDataset(Dataset, ABC):
    """AnnData-backed perturbation dataset (marked slow/deprecated in this file).

    Each item pairs a cell's flattened ``X`` row (``"labels"``) with a basal
    (control) vector stored under ``MAIN_INPUT_NAME`` and the embedding of the
    cell's perturbation (``"pert_gene_emb"``).

    Concrete subclasses must set ``MAIN_INPUT_NAME`` to a non-None value and
    implement :meth:`preprocess_adata`.

    Class attributes:
        MAIN_INPUT_NAME: Key under which the basal vector is returned.
        RANDOM_MAPPING_CONTROL_TO_CONTROL: When True, control cells are also
            paired with a randomly drawn control cell instead of themselves.
    """

    MAIN_INPUT_NAME = None
    RANDOM_MAPPING_CONTROL_TO_CONTROL = False

    def __init_subclass__(cls, **kwargs):
        """Reject subclasses that do not define a non-None ``MAIN_INPUT_NAME``."""
        super().__init_subclass__(**kwargs)
        if cls is BaseDataset:
            return
        if getattr(cls, "MAIN_INPUT_NAME", None) is None:
            raise NotImplementedError(
                f"Class {cls.__name__} must define the class attribute `MAIN_INPUT_NAME` with a non-None value."
            )

    def __init__(
        self,
        adata: Union[str, "sc.AnnData"],
        pert_col: str = "perturbation",
        ctrl_name: str = "control",
        perts_to_include: Optional[list] = None,
        perts_to_exclude: Optional[list] = None,
        gene_embs_file: str = "./ESM2_pert_features.pt"
    ):
        """Load, preprocess and filter the AnnData, then load perturbation embeddings.

        Args:
            adata: An AnnData object, or a path that is read with ``sc.read_h5ad``.
            pert_col: Column of ``adata.obs`` holding the perturbation name.
            ctrl_name: Value of ``pert_col`` that marks control cells.
            perts_to_include: Forwarded to ``_filter_perturbations``.
            perts_to_exclude: Forwarded to ``_filter_perturbations``.
            gene_embs_file: Forwarded to ``load_gene_embs``.

        Raises:
            ValueError: If the filtered data contains no control cells.
        """
        super().__init__()
        self.pert_col = pert_col
        self.ctrl_name = ctrl_name
        self.perts_to_include = perts_to_include
        self.perts_to_exclude = perts_to_exclude

        if isinstance(adata, str):
            logger.info("Read anndata from %s ...", adata)
            adata = sc.read_h5ad(adata)
        # Densify sparse matrices so rows can be flattened in __getitem__.
        if hasattr(adata.X, "toarray"):
            adata.X = adata.X.toarray()

        # Subclass hook; pert_col is already set on self at this point.
        adata = self.preprocess_adata(adata)
        self.adata = _filter_perturbations(
            adata=adata,
            pert_col=pert_col,
            ctrl_name=ctrl_name,
            perts_to_include=perts_to_include,
            perts_to_exclude=perts_to_exclude
        )
        self.ctrl_cell_indices = self.get_ctrl_cell_indices(self.adata)
        self.gene_embs = load_gene_embs(
            gene_embs_file=gene_embs_file,
            perts_to_emb=self.adata.obs[self.pert_col].unique().tolist()
        )

    def get_ctrl_cell_indices(self, adata: "sc.AnnData") -> np.ndarray:
        """Return the positional indices of all control cells in ``adata``.

        Raises:
            ValueError: If no row of ``adata.obs[pert_col]`` equals ``ctrl_name``.
        """
        ctrl_cell_indices = np.where(adata.obs[self.pert_col] == self.ctrl_name)[0]
        if len(ctrl_cell_indices) == 0:
            raise ValueError("No control cells found!")
        return ctrl_cell_indices

    @staticmethod
    def _to_flat_tensor(matrix) -> torch.Tensor:
        """Flatten a dense or sparse matrix into a 1-D float32 tensor."""
        if hasattr(matrix, "toarray"):
            matrix = matrix.toarray()
        return torch.tensor(np.asarray(matrix).reshape(-1), dtype=torch.float32)

    def __getitem__(self, idx):
        """Return the basal vector, perturbation embedding and labels for cell ``idx``.

        Perturbed cells (and control cells when
        ``RANDOM_MAPPING_CONTROL_TO_CONTROL`` is True) are paired with a control
        cell drawn at random; otherwise a control cell is paired with a copy of
        itself.
        """
        adata_pert = self.adata[idx]
        pert_name = adata_pert.obs[self.pert_col].item()
        pert_gene_emb = self.gene_embs[pert_name].to(torch.float32)
        labels = self._to_flat_tensor(adata_pert.X)

        if pert_name != self.ctrl_name or self.RANDOM_MAPPING_CONTROL_TO_CONTROL:
            random_ctrl_idx = np.random.choice(self.ctrl_cell_indices)
            basal_sequence = self._to_flat_tensor(self.adata[random_ctrl_idx].X)
        else:
            # A control cell is its own basal state; keep the tensors independent.
            basal_sequence = labels.clone()

        return {
            self.MAIN_INPUT_NAME: basal_sequence,
            "pert_gene_emb": pert_gene_emb,
            "labels": labels
        }

    def __len__(self) -> int:
        """Return the number of cells kept after filtering."""
        return len(self.adata)

    @abstractmethod
    def preprocess_adata(self, adata: "sc.AnnData") -> "sc.AnnData":
        """Modality-specific preprocessing, applied before perturbation filtering."""
        raise NotImplementedError()
# Slow | Deprecated
class DatasetForATAC(BaseDataset):
    """ATAC flavour of :class:`BaseDataset`.

    The basal (control) vector is returned under the key ``"ctrl_cell_tf_idf"``
    and preprocessing is delegated to
    ``preprocess_atac_perturbation_adata_consistent_with_epiagent``.
    """

    MAIN_INPUT_NAME = "ctrl_cell_tf_idf"
    RANDOM_MAPPING_CONTROL_TO_CONTROL = False

    def __init__(
        self,
        adata: Union[str, sc.AnnData],
        pert_col: str = "perturbation",
        ctrl_name: str = "control",
        topk_ccres: int = 50000,
        perts_to_include: Optional[list] = None,
        perts_to_exclude: Optional[list] = None,
        gene_embs_file: str = "./ESM2_pert_features.pt"
    ):
        """Store ``topk_ccres`` and hand the remaining arguments to the base class.

        Args:
            adata: AnnData object or path to an ``.h5ad`` file.
            pert_col: ``obs`` column holding the perturbation name.
            ctrl_name: Value of ``pert_col`` marking control cells.
            topk_ccres: Forwarded to the ATAC preprocessing helper.
            perts_to_include: Forwarded to the base class.
            perts_to_exclude: Forwarded to the base class.
            gene_embs_file: Forwarded to the base class.
        """
        # Must exist before the base initializer runs: it calls preprocess_adata.
        self.topk_ccres = topk_ccres
        base_kwargs = {
            "adata": adata,
            "pert_col": pert_col,
            "ctrl_name": ctrl_name,
            "perts_to_include": perts_to_include,
            "perts_to_exclude": perts_to_exclude,
            "gene_embs_file": gene_embs_file,
        }
        super().__init__(**base_kwargs)

    def preprocess_adata(self, adata: sc.AnnData) -> sc.AnnData:
        """Run the ATAC preprocessing helper with this dataset's settings."""
        topk, pert_column = self.topk_ccres, self.pert_col
        processed = preprocess_atac_perturbation_adata_consistent_with_epiagent(
            adata, topk, pert_column
        )
        return processed
# Slow | Deprecated
class DatasetForRNA(BaseDataset):
    """RNA flavour of :class:`BaseDataset`.

    The basal (control) vector is returned under the key ``"ctrl_cell_expr"``
    and preprocessing is delegated to ``preprocess_rna_perturbation_adata``.
    """

    MAIN_INPUT_NAME = "ctrl_cell_expr"
    RANDOM_MAPPING_CONTROL_TO_CONTROL = False

    def __init__(
        self,
        adata: Union[str, sc.AnnData],
        pert_col: str = "target_gene",
        ctrl_name: str = "non-targeting",
        target_sum: float = 1e4,
        skip_raw_counts_check: bool = False,
        perts_to_include: Optional[list] = None,
        perts_to_exclude: Optional[list] = None,
        gene_embs_file: str = "./ESM2_pert_features.pt"
    ):
        """Store the RNA preprocessing options and initialise the base class.

        Args:
            adata: AnnData object or path to an ``.h5ad`` file.
            pert_col: ``obs`` column holding the perturbation name.
            ctrl_name: Value of ``pert_col`` marking control cells.
            target_sum: Forwarded to ``preprocess_rna_perturbation_adata``.
            skip_raw_counts_check: Forwarded to
                ``preprocess_rna_perturbation_adata``.
            perts_to_include: Forwarded to the base class.
            perts_to_exclude: Forwarded to the base class.
            gene_embs_file: Forwarded to the base class.
        """
        # Both options must exist before the base initializer runs,
        # because it calls preprocess_adata.
        self.target_sum, self.skip_raw_counts_check = target_sum, skip_raw_counts_check
        base_kwargs = {
            "adata": adata,
            "pert_col": pert_col,
            "ctrl_name": ctrl_name,
            "perts_to_include": perts_to_include,
            "perts_to_exclude": perts_to_exclude,
            "gene_embs_file": gene_embs_file,
        }
        super().__init__(**base_kwargs)

    def preprocess_adata(self, adata: sc.AnnData) -> sc.AnnData:
        """Run the RNA preprocessing helper with this dataset's settings."""
        options = {
            "target_sum": self.target_sum,
            "pert_col": self.pert_col,
            "skip_raw_counts_check": self.skip_raw_counts_check,
        }
        return preprocess_rna_perturbation_adata(adata=adata, **options)
# Fast
class HFBaseDataset(Dataset):
    """HuggingFace-``datasets``-backed perturbation dataset (the "fast" path).

    The wrapped dataset must provide the features ``labels``, ``pert_gene`` and
    ``celltype``. Batches are produced by :meth:`__getitems__`, which adds the
    ``labels`` of a matched control cell under ``MAIN_INPUT_NAME``.

    Concrete subclasses must set ``MAIN_INPUT_NAME`` to a non-None value.

    Class attributes:
        MAIN_INPUT_NAME: Batch key under which the control-cell values are stored.
        RANDOM_MAPPING_CONTROL_TO_CONTROL: When True, control cells are also
            paired with a randomly drawn control cell instead of themselves.
    """

    MAIN_INPUT_NAME = None
    RANDOM_MAPPING_CONTROL_TO_CONTROL = False

    # Features every wrapped HuggingFace dataset must expose.
    _REQUIRED_FEATURES = ("labels", "pert_gene", "celltype")

    def __init_subclass__(cls, **kwargs):
        """Reject subclasses that do not define a non-None ``MAIN_INPUT_NAME``."""
        super().__init_subclass__(**kwargs)
        if cls is HFBaseDataset:
            return
        if getattr(cls, "MAIN_INPUT_NAME", None) is None:
            raise NotImplementedError(
                f"Class {cls.__name__} must define the class attribute `MAIN_INPUT_NAME` with a non-None value."
            )

    def __init__(
        self,
        hf_dataset: "datasets.Dataset",
        ctrl_name: str = "control",
        gene_embs_file: str = "./ESM2_pert_features.pt",
        mse_weights_pkl_file: Optional[str] = None
    ):
        """Validate the dataset, index control cells and attach extra columns.

        Args:
            hf_dataset: Dataset with features ``labels``, ``pert_gene``, ``celltype``.
            ctrl_name: Value of ``pert_gene`` that marks control cells.
            gene_embs_file: Forwarded to ``load_gene_embs``.
            mse_weights_pkl_file: Optional pickle holding a dict keyed by
                ``(celltype, perturbation)`` with array values.

        Raises:
            ValueError: If a required feature is missing, or a celltype has no
                control cells.
        """
        super().__init__()
        self._check_hf_dataset_features(hf_dataset)
        self.ds = hf_dataset  # features in self.ds: labels, pert_gene, celltype
        self.ctrl_name = ctrl_name
        self.ctrl_cell_indices = self.get_ctrl_cell_indices_for_each_celltype()
        self.gene_embs = load_gene_embs(
            gene_embs_file=gene_embs_file,
            perts_to_emb=self.ds.unique("pert_gene")
        )  # {gene: torch.Tensor}

        if mse_weights_pkl_file is not None:
            # SECURITY: pickle.load executes arbitrary code from the file;
            # only pass weight files from a trusted source.
            with open(mse_weights_pkl_file, "rb") as f:
                self.cp2weights = pickle.load(f)  # {("celltype", "perturbation"): np.array(...)}
        else:
            self.cp2weights = None

        # After this call self.ds has: labels, pert_gene, pert_gene_emb, celltype
        # (plus mse_weights when a weights file was given).
        self._preprocess_hf_dataset()

    @staticmethod
    def _check_hf_dataset_features(hf_dataset: "datasets.Dataset"):
        """Raise ``ValueError`` if any required feature is absent from ``hf_dataset``."""
        available = hf_dataset.features
        missing_features = [
            name for name in HFBaseDataset._REQUIRED_FEATURES if name not in available
        ]
        if missing_features:
            raise ValueError(
                f"The following features are missing from the HuggingFace dataset: {missing_features}. "
                "Please make sure that the dataset contains the following features: labels, pert_gene, celltype."
            )

    def get_ctrl_cell_indices_for_each_celltype(self) -> dict:
        """Map each celltype to the row indices of its control cells.

        Returns:
            Dict ``{celltype: pandas Index of row positions}`` covering every
            celltype present in the dataset.

        Raises:
            ValueError: If some celltype has no row whose ``pert_gene`` equals
                ``ctrl_name``.
        """
        df = pd.DataFrame({
            "celltype": np.array(self.ds["celltype"]),
            "pert_gene": np.array(self.ds["pert_gene"])
        })
        # Boolean filtering keeps the original row labels, so the group
        # members are positions into self.ds.
        ctrl_rows = df[df["pert_gene"] == self.ctrl_name]
        ctrl_cell_indices = dict(ctrl_rows.groupby("celltype").groups)

        # groupby never yields empty groups, so a celltype without controls is
        # simply absent from the mapping; detect that explicitly.
        for celltype in df["celltype"].unique():
            if celltype not in ctrl_cell_indices:
                raise ValueError(f"No control cells found for celltype {celltype}!")
        return ctrl_cell_indices

    def _preprocess_hf_dataset(self):
        """Add ``pert_gene_emb`` (and optional ``mse_weights``) columns; set torch format."""
        pert_genes = self.ds["pert_gene"]

        # Step1: Add pert_gene_emb to hf dataset
        gene_embs = {gene: embs.numpy() for gene, embs in self.gene_embs.items()}
        self.ds = self.ds.add_column(
            "pert_gene_emb", pd.Series(pert_genes).map(gene_embs).tolist()
        )

        # Step2: Add mse_weights from pickle if provided
        torch_columns = ["pert_gene_emb", "labels"]
        if self.cp2weights is not None:
            # Pairs without an entry get an all-zero weight vector shaped like
            # the first stored entry.
            zero_weights = np.zeros_like(next(iter(self.cp2weights.values())))
            weights = [
                self.cp2weights.get((ct, pg), zero_weights)
                for ct, pg in zip(self.ds["celltype"], pert_genes)
            ]
            self.ds = self.ds.add_column("mse_weights", weights)
            torch_columns.append("mse_weights")

        # Step3: Set format to torch
        self.ds.set_format("torch", columns=torch_columns, output_all_columns=True)

    def __getitem__(self, idx) -> dict:
        """Return the raw row(s) of the wrapped dataset; no control-cell pairing is added."""
        return self.ds[idx]

    def __getitems__(self, keys: list) -> dict:
        """Can be used to get a batch using a list of integers indices.

        Returns the batch produced by the wrapped dataset, extended with
        ``MAIN_INPUT_NAME`` holding the ``labels`` of one control cell per row.
        Perturbed cells get a random control cell of the same celltype. Control
        cells are paired with themselves unless
        ``RANDOM_MAPPING_CONTROL_TO_CONTROL`` is True.
        """
        batch = self.ds[keys]
        ctrl_pool = self.ctrl_cell_indices

        if self.RANDOM_MAPPING_CONTROL_TO_CONTROL:
            selected_ctrl_indices = [
                int(random.choice(ctrl_pool[ct])) for ct in batch["celltype"]
            ]
        else:
            # Compare against the configured control label, not a fixed string.
            selected_ctrl_indices = [
                int(cell_idx) if pg == self.ctrl_name
                else int(random.choice(ctrl_pool[ct]))
                for ct, pg, cell_idx in zip(batch["celltype"], batch["pert_gene"], keys)
            ]

        # MAIN_INPUT_NAME, pert_gene_emb, labels, pert_gene, celltype
        batch[self.MAIN_INPUT_NAME] = self.ds[selected_ctrl_indices]["labels"]
        return batch

    def __len__(self) -> int:
        """Return the number of rows in the wrapped dataset."""
        return len(self.ds)

    @staticmethod
    def collate_fn(batch):
        """Identity collate: ``__getitems__`` already returns a whole batch."""
        return batch
# Fast
class HFDatasetForATAC(HFBaseDataset):
    """ATAC flavour of :class:`HFBaseDataset`; only the class-level settings differ."""

    # Batch key under which HFBaseDataset stores the matched control cells' "labels".
    MAIN_INPUT_NAME: str = "ctrl_cell_tf_idf"
    # Control cells are paired with themselves rather than a random control.
    RANDOM_MAPPING_CONTROL_TO_CONTROL: bool = False
# Fast
class HFDatasetForRNA(HFBaseDataset):
    """RNA flavour of :class:`HFBaseDataset`; only the class-level settings differ."""

    # Batch key under which HFBaseDataset stores the matched control cells' "labels".
    MAIN_INPUT_NAME: str = "ctrl_cell_expr"
    # Control cells are paired with themselves rather than a random control.
    RANDOM_MAPPING_CONTROL_TO_CONTROL: bool = False