Source code for descope.arguments

from typing import Optional
from wppkg import TrainingArguments
from dataclasses import dataclass, field


[docs] @dataclass class DeSCOPETrainingArguments(TrainingArguments): alpha_mse_loss: float = field( default=1.0, metadata={ "help": "The weight (or coefficient) for the MSE loss." } ) alpha_kl_loss: float = field( default=1.0, metadata={ "help": "The weight (or coefficient) for the KL loss." } ) pretrained_model_name_or_path: Optional[str] = field( default=None, metadata={ "help": "Path to a pretrained model." } )
[docs] @dataclass class DeSCOPEDataArguments: tokenized_datasets_dir: str = field( default="./tokenized_dataset/K562", metadata={ "help": "Path to the tokenized huggingface datasets." } ) keep_in_memory: bool = field( default=False, metadata={ "help": "Whether to keep the datasets in memory." } ) ctrl_name: str = field( default="control", metadata={ "help": "The name in huggingface datasets that represents control cells." } ) gene_embs_file: str = field( default="./ESM2_pert_features.pt", metadata={ "help": "Path to the gene embedding file." } )
[docs] @dataclass class DeSCOPEModelArguments: hidden_act: str = field( default="gelu", metadata={ "help": "The activation function for the hidden layers." } ) hidden_size: int = field( default=672, metadata={ "help": "The hidden size of the model." } ) hidden_dropout: float = field( default=0, metadata={ "help": "The dropout rate for the hidden layers." } ) pert_gene_encoder_layers: int = field( default=1, metadata={ "help": "The number of layers in the perturbation gene encoder." } ) variational_encoder_layers: int = field( default=4, metadata={ "help": "The number of layers in the variational encoder." } ) variational_decoder_layers: int = field( default=4, metadata={ "help": "The number of layers in the variational decoder." } ) add_layernorm: bool = field( default=True, metadata={ "help": "Whether to add layer normalization to the hidden layers." } )