Source code for descope.inference

import torch
import random
import logging
import cell_eval
import numpy as np
import pandas as pd
import scanpy as sc
import polars as pl
import anndata as ad

from pathlib import Path
from tqdm.auto import tqdm
from typing import Union, Optional
from abc import ABC, abstractmethod
from easydict import EasyDict as edict
from cell_eval import MetricsEvaluator
from cell_eval._types import MetricType
from transformers import PreTrainedModel
from cell_eval.metrics import metrics_registry
from cell_eval._evaluator import PerturbationAnndataPair
from .utils import (
    pearson, 
    pearson_delta_on_topk_de,
    direction_match_on_topk_de,
    edistance, 
    load_gene_embs,
    preprocess_atac_perturbation_adata_consistent_with_epiagent,
    preprocess_rna_perturbation_adata
)
from .models.modeling_descope import DeSCOPEForATAC, DeSCOPEForRNA

logger = logging.getLogger(__name__)


def _build_anndata_pair_NOCHECK(
    real: Union[sc.AnnData, str],
    pred: Union[sc.AnnData, str],
    control_pert: str,
    pert_col: str,
    allow_discrete: bool = False,
) -> PerturbationAnndataPair:
    r"""
    Builds a PerturbationAnndataPair without performing cell-eval's default data validation,
    particularly the check for log1p-normalized expression values.
    
    This is useful when the model outputs raw or non-log-transformed predictions (e.g., negative values),
    and the user wishes to bypass format enforcement to compute metrics directly on the given scale.
    """
    logger.warning("-------------------------------------------------")
    logger.warning("Running cell-eval without data format validation.")
    logger.warning("-------------------------------------------------")

    if isinstance(real, str):
        logger.info(f"Reading real anndata from {real}")
        real = ad.read_h5ad(real)
    if isinstance(pred, str):
        logger.info(f"Reading pred anndata from {pred}")
        pred = ad.read_h5ad(pred)

    # Validate that the input is normalized and log-transformed
    # _convert_to_normlog(real, which="real", allow_discrete=allow_discrete)
    # _convert_to_normlog(pred, which="pred", allow_discrete=allow_discrete)

    # Build the anndata pair
    return PerturbationAnndataPair(
        real=real, pred=pred, control_pert=control_pert, pert_col=pert_col
    )


[docs] class CellEvalMixin: r""" A mixin class designed to extend evaluation capabilities for single-cell perturbation prediction tasks. This class provides methods to compute various metrics on predicted and real AnnData objects, including differential expression (DE) analysis etc. **Example:** >>> from descope.inference import CellEvalMixin >>> evalmixin = CellEvalMixin() >>> print(evalmixin.all_metrics) # list all metrics in cell-eval >>> results, agg_results, evaluator = evalmixin.compute_metrics(...) >>> # Compute extra metrics >>> evalmixin.extra_metrics_func.pearson(evaluator.anndata_pair, ...) # pearson >>> evalmixin.extra_metrics_func.pearson_delta_on_topk_de(evaluator.anndata_pair, ...) # pearson_delta_on_topk_de >>> evalmixin.extra_metrics_func.direction_match_on_topk_de(evaluator.anndata_pair, ...) # direction_match_on_topk_de >>> evalmixin.extra_metrics_func.edistance(...) # edistance """ def __init__(self): self.extra_metrics_func = { "pearson": pearson, "pearson_delta_on_topk_de": pearson_delta_on_topk_de, "direction_match_on_topk_de": direction_match_on_topk_de, "edistance": edistance } self.extra_metrics_func = edict(self.extra_metrics_func) @property def all_metrics(self) -> list: # Includes only the built-in metrics provided by cell-eval. return metrics_registry.list_metrics() @property def all_de_metrics(self) -> list: # Includes only the built-in metrics provided by cell-eval. return metrics_registry.list_metrics(MetricType.DE) @property def all_anndata_pair_metrics(self) -> list: # Includes only the built-in metrics provided by cell-eval. return metrics_registry.list_metrics(MetricType.ANNDATA_PAIR) @property def extra_metrics(self) -> list: return list(self.extra_metrics_func.keys())
[docs] def check_whether_skip_de_metrics( self, metrics_to_calculate: list ) -> bool: de_metrics = self.all_de_metrics skip_de = True if len(set(metrics_to_calculate) & set(de_metrics)) == 0 else False return skip_de
def _compute_metrics( self, adata_pred: Union[str, sc.AnnData], adata_real: Union[str, sc.AnnData], metrics_to_calculate: list, de_pred: Optional[str] = None, de_real: Optional[str] = None, # If available, provide DE results to speed up computation. control_pert: str = "non-targeting", pert_col: str = "target_gene", de_method: str = "wilcoxon", num_threads: int = -1, outdir: str = "./cell-eval-outdir", ) -> tuple[pl.DataFrame, pl.DataFrame, MetricsEvaluator]: evaluator = MetricsEvaluator( adata_pred=adata_pred, adata_real=adata_real, de_pred=de_pred, de_real=de_real, control_pert=control_pert, pert_col=pert_col, de_method=de_method, num_threads=num_threads, outdir=outdir, skip_de=self.check_whether_skip_de_metrics(metrics_to_calculate) ) (results, agg_results) = evaluator.compute( profile="full", skip_metrics=list(set(self.all_metrics) - set(metrics_to_calculate)) ) # Return evaluator for calculating pearson and pearson_delta_on_topk_de return results, agg_results, evaluator
[docs] def compute_metrics( self, adata_pred: Union[str, sc.AnnData], adata_real: Union[str, sc.AnnData], metrics_to_calculate: list, de_pred: Optional[str] = None, de_real: Optional[str] = None, # If available, provide DE results to speed up computation. control_pert: str = "non-targeting", pert_col: str = "target_gene", de_method: str = "wilcoxon", num_threads: int = -1, outdir: str = "./cell-eval-outdir", ) -> tuple[pl.DataFrame, pl.DataFrame, MetricsEvaluator]: # ATAC: # Since the prediction target is TF-IDF, there is no need to perform normalization or log1p checks. # RNA: # As some methods allow the model to output negative values, # normalization and log1p checks should also be disabled; # however, note that this step needs to be handled manually. # If you want to validate the adata format, please call the `self._compute_metrics` method. cell_eval._evaluator._build_anndata_pair = _build_anndata_pair_NOCHECK return self._compute_metrics( adata_pred=adata_pred, adata_real=adata_real, # self.adata_test_ground_truth metrics_to_calculate=metrics_to_calculate, de_pred=de_pred, de_real=de_real, control_pert=control_pert, pert_col=pert_col, de_method=de_method, num_threads=num_threads, outdir=outdir, )
[docs] class BaseInference(CellEvalMixin, ABC): def __init__( self, test_csv_template_fp: str, adata: Union[str, sc.AnnData], # where you select control cells for inference pretrained_model_name_or_path: str, pert_col: str = "perturbation", ctrl_name: str = "control" ): super().__init__() self.pert_col = pert_col self.ctrl_name = ctrl_name # Load test csv template self.test_csv_template = pd.read_csv(test_csv_template_fp) self._check_test_csv_template(self.test_csv_template) # Add test_genes attribute (involve control) self.test_genes = self.test_csv_template["target_gene"].unique().tolist() if ctrl_name not in self.test_genes: self.test_genes.append(ctrl_name) # Load model (abstract) self.model = self.load_model(pretrained_model_name_or_path) # Load and prepare anndata (abstract) -> Return (adata_test_ground_truth & adata_ctrl) if isinstance(adata, str): logger.info(f"Read AnnData from {adata} ...") adata = sc.read_h5ad(adata) adata = self.preprocess_adata(adata) self.adata_ctrl = self._prepare_control_adata(adata) self.adata_test_ground_truth = self._prepare_test_adata(adata) self.adata_pred = None # Placeholder # Ensure control row exists in template if ctrl_name not in self.test_csv_template["target_gene"].unique(): self.test_csv_template = pd.concat( [ self.test_csv_template, pd.DataFrame({"target_gene": [ctrl_name], "n_cells": len(self.adata_ctrl)}) ], ignore_index=True )
[docs] @staticmethod def write_h5ad(adata: sc.AnnData, save_path: str): output_dir = Path(save_path).parent output_dir.mkdir(parents=True, exist_ok=True) adata.write_h5ad(save_path)
@staticmethod def _check_test_csv_template(test_csv_template: pd.DataFrame): columns = test_csv_template.columns required_columns = ["target_gene", "n_cells"] missing_columns = [c for c in required_columns if c not in columns] if missing_columns: raise ValueError(f"Missing required columns in test csv template: {missing_columns}") def _prepare_control_adata(self, adata: sc.AnnData) -> sc.AnnData: if self.ctrl_name not in adata.obs[self.pert_col].unique().tolist(): raise ValueError( "No control cells detected in the input adata; inference cannot be performed." ) else: adata_ctrl = adata[adata.obs[self.pert_col] == self.ctrl_name] logger.info("self.adata_ctrl registered successfully.") return adata_ctrl def _prepare_test_adata(self, adata: sc.AnnData) -> sc.AnnData: pert_genes_in_adata = adata.obs[self.pert_col].unique().tolist() missing = 0 for g in self.test_genes: if g not in pert_genes_in_adata: missing += 1 if missing > 0: logger.warning( "The input adata does not contain all the test perturbation genes; " "`self.adata_test_ground_truth` has been set to None; " "You need to manually provide the ground truth when computing metrics." ) return else: adata_test_ground_truth = adata[adata.obs[self.pert_col].isin(self.test_genes)] logger.info("self.adata_test_ground_truth registered successfully.") return adata_test_ground_truth
[docs] @torch.no_grad() def inference( self, gene_embs_file: str, device: torch.device, batch_size: int = 128, use_generated_control_cells: bool = False, seed: int = 42 ): # ----------------- Experimental Flags (UNSTABLE - Keep Disabled) ----------------- assert not use_generated_control_cells, ( "[CRITICAL]: Generated control cells are for internal debugging only. " "Using them for final evaluation will cause result instability. " "Please ensure `use_generated_control_cells=False` for all model evaluations." ) # ---------------------------------------------------------------------------------- gene_embs = load_gene_embs(gene_embs_file, perts_to_emb=self.test_genes) if hasattr(self.adata_ctrl.X, "toarray"): X_ctrl = self.adata_ctrl.X.toarray() else: X_ctrl = self.adata_ctrl.X self.model = self.model.to(device) self.model.eval() random.seed(seed) pbar = tqdm(total=len(self.test_csv_template), desc="Inference") results = [] for _, row in self.test_csv_template.iterrows(): target_gene = row["target_gene"] n_cells = row["n_cells"] gen_n_cells = 0 while True: selected_ctrl_cells_indices = random.choices(range(len(X_ctrl)), k=batch_size) selected_ctrl_cells_X = torch.tensor( X_ctrl[selected_ctrl_cells_indices], dtype=torch.float32, device=device ) pert_gene_emb = gene_embs[target_gene].unsqueeze(0).repeat( len(selected_ctrl_cells_X), 1 ).to(torch.float32).to(device) logits = self.model.inference( selected_ctrl_cells_X, # Main Input Name pert_gene_emb=pert_gene_emb, ).logits.cpu().numpy() if np.any(np.isnan(logits)): # Remove cells contain NaN. print("Detected NaN, drop current batch.") continue gen_n_cells += len(logits) if gen_n_cells <= n_cells: results.append(logits) else: results.append(logits[:-(gen_n_cells - n_cells)]) pbar.update(1) break results = np.concatenate(results, axis=0) torch.cuda.empty_cache() # Organize final results self.adata_pred = self._organize_final_results( results, use_generated_control_cells ) logger.info("Predicted results have been registered to self.adata_pred.")
def _organize_final_results( self, results: np.ndarray, use_generated_control_cells: bool = False ) -> sc.AnnData: target_gene = self.test_csv_template.apply( lambda r: [r["target_gene"]] * r["n_cells"], axis=1 ).values.tolist() target_gene = sum(target_gene, []) predicted_adata = ad.AnnData(X=results) predicted_adata.obs[self.pert_col] = target_gene predicted_adata.var_names = self.adata_ctrl.var_names # If not use_generated_control_cells, replace generated control cells with real control cells. if not use_generated_control_cells: logger.info("You are using real control cells in the final predicted_adata.") predicted_adata = predicted_adata[predicted_adata.obs[self.pert_col] != self.ctrl_name] predicted_adata = sc.concat([predicted_adata, self.adata_ctrl], axis=0) # predicted_adata.X[predicted_adata.X < np.log1p(1)] = 0 # cutoff small values return predicted_adata
[docs] @abstractmethod def load_model(self, pretrained_model_name_or_path: str) -> PreTrainedModel: raise NotImplementedError()
[docs] @abstractmethod def preprocess_adata(self, adata: sc.AnnData) -> sc.AnnData: raise NotImplementedError()
[docs] class InferenceForATAC(BaseInference):
[docs] def preprocess_adata(self, adata: sc.AnnData) -> sc.AnnData: # Extract topk ccres from model.config.input_length topk_ccres = self.model.config.input_length return preprocess_atac_perturbation_adata_consistent_with_epiagent( adata, topk_ccres, self.pert_col )
[docs] def load_model(self, pretrained_model_name_or_path: str) -> PreTrainedModel: return DeSCOPEForATAC.from_pretrained(pretrained_model_name_or_path)
[docs] class InferenceForRNA(BaseInference): def __init__( self, test_csv_template_fp: str, adata: Union[str, sc.AnnData], # where you select control cells for inference pretrained_model_name_or_path: str, pert_col: str = "target_gene", ctrl_name: str = "non-targeting", target_sum: float = 1e4, skip_raw_counts_check: bool = False ): self.target_sum = target_sum self.skip_raw_counts_check = skip_raw_counts_check super().__init__( test_csv_template_fp=test_csv_template_fp, adata=adata, pretrained_model_name_or_path=pretrained_model_name_or_path, pert_col=pert_col, ctrl_name=ctrl_name )
[docs] def preprocess_adata(self, adata: sc.AnnData) -> sc.AnnData: return preprocess_rna_perturbation_adata( adata=adata, target_sum=self.target_sum, pert_col=self.pert_col, skip_raw_counts_check=self.skip_raw_counts_check )
[docs] def load_model(self, pretrained_model_name_or_path: str) -> PreTrainedModel: return DeSCOPEForRNA.from_pretrained(pretrained_model_name_or_path)