# Source code for descope.trainer

import os
import torch
import inspect
import logging
from tqdm.auto import tqdm
from wppkg import (
    Trainer, 
    Accumulator, 
    get_logger
)


def _init_logger(self) -> logging.Logger:
    """Build the Trainer logger that writes to ``<output_dir>/run.log``.

    The logger is requested from ``get_logger`` under the name
    ``"wppkg.Trainer"`` with the log file opened in ``"w"`` mode.  One
    message announcing the log file location is emitted first; afterwards
    every handler that is not a ``logging.FileHandler`` is raised to
    ``ERROR`` so that routine records only reach the file.

    Returns:
        The configured ``logging.Logger``.
    """
    run_log_path = os.path.join(self.args.output_dir, "run.log")

    logger_options = dict(
        name="wppkg.Trainer",
        log_file=run_log_path,
        log_file_mode="w",
        fmt="%(asctime)s | %(levelname)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        main_process_level=logging.INFO,
        other_process_level=logging.WARN,
        local_rank=self.accelerator.local_process_index,
    )
    trainer_logger = get_logger(**logger_options)

    # Announce the file location while the non-file handlers are still at
    # their initial level; they are restricted right after this call.
    trainer_logger.info(f"Training logs have been written to {os.path.abspath(run_log_path)}")

    # Silence console output: only file handlers keep their original level.
    non_file_handlers = (
        candidate
        for candidate in trainer_logger.handlers
        if not isinstance(candidate, logging.FileHandler)
    )
    for console_handler in non_file_handlers:
        console_handler.setLevel(logging.ERROR)

    return trainer_logger


# Monkey-patch: replace `Trainer._init_logger` with the function defined in
# this module, so `Trainer` and its subclasses use the run.log-based logger.
Trainer._init_logger = _init_logger


class DeSCOPETrainer(Trainer):
    r"""Trainer for a model whose outputs expose ``loss``, ``mse_loss`` and ``kl_loss``.

    NOTE:
        1. Early stopping does not currently support resuming training.
           If training is forcibly resumed, the early stopping callback
           will be reinitialized.
        2. If you enable early stopping, ensure that `eval_every_n_epochs`
           and `checkpointing_steps` are aligned, as the Trainer does not
           automatically save the best model.
        3. The final model is always saved at the end of training, even if
           early stopping is triggered.
    """

    def train(self) -> None:
        """Run the training loop.

        Steps performed:
            * optionally restore state from ``args.resume_from_checkpoint``
              (a directory whose name is ``epoch_{i}`` or ``step_{i}``);
            * iterate over epochs and batches with gradient accumulation,
              logging averaged losses every ``args.logging_steps``
              optimizer steps;
            * save checkpoints every ``args.checkpointing_steps`` optimizer
              steps (int) or every n epochs (``"epoch"`` / ``"epoch-<n>"``);
            * evaluate every ``args.eval_every_n_epochs`` epochs and apply
              early stopping when ``self.earlystop_callback`` is set;
            * always save the final weights to ``<output_dir>/last_model``.

        Raises:
            ValueError: if the checkpoint directory name or a string
                ``args.checkpointing_steps`` does not end in an integer.
        """
        # Train!
        total_batch_size = (
            self.args.per_device_train_batch_size
            * self.accelerator.num_processes
            * self.args.gradient_accumulation_steps
        )
        self.logger.info("***** Running training *****")
        self.logger.info(f" Num examples = {len(self.train_dataset)}")
        self.logger.info(f" Num Epochs = {self.args.num_train_epochs}")
        self.logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
        self.logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        self.logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
        self.logger.info(f" Total optimization steps = {self.args.max_train_steps}")
        self.logger.info("*****************************")

        # Only show the progress bar once on each machine.
        progress_bar = tqdm(
            range(self.args.max_train_steps),
            disable=not self.accelerator.is_local_main_process,
        )
        completed_steps = 0
        starting_epoch = 0
        # Number of batches to skip in the first resumed epoch; stays None
        # unless we resume from a `step_{i}` checkpoint.
        resume_step = None

        # Potentially load in the weights and states from a previous save.
        if self.args.resume_from_checkpoint:
            checkpoint_path = self.args.resume_from_checkpoint
            # normpath strips a trailing separator, which would otherwise
            # make basename() return an empty string.
            checkpoint_name = os.path.basename(os.path.normpath(checkpoint_path))
            self.logger.info(f"Resumed from checkpoint: {checkpoint_path}")
            self.accelerator.load_state(checkpoint_path)

            # Extract `epoch_{i}` or `step_{i}`
            training_difference = os.path.splitext(checkpoint_name)[0]
            if "epoch" in training_difference:
                starting_epoch = int(training_difference.replace("epoch_", "")) + 1
                completed_steps = starting_epoch * self.num_update_steps_per_epoch
            else:
                # Need to multiply by `gradient_accumulation_steps` to
                # reflect real (dataloader) steps.
                resume_step = (
                    int(training_difference.replace("step_", ""))
                    * self.args.gradient_accumulation_steps
                )
                starting_epoch = resume_step // len(self.train_dataloader)
                completed_steps = resume_step // self.args.gradient_accumulation_steps
                resume_step -= starting_epoch * len(self.train_dataloader)

        # Update the progress bar if loaded from a checkpoint.
        progress_bar.update(completed_steps)

        # Parse the epoch-based checkpoint interval once, up front, so a
        # malformed value fails before any training time is spent.
        # "epoch" -> every epoch, "epoch-<number>" -> every <number> epochs.
        checkpointing_every_n_epochs = None
        if isinstance(self.args.checkpointing_steps, str):
            checkpointing_every_n_epochs = (
                1
                if self.args.checkpointing_steps == "epoch"
                else int(self.args.checkpointing_steps.split("-")[-1])
            )

        accumulator_train = Accumulator(name=["loss", "mse_loss", "kl_loss"])
        # Only batch entries matching a parameter of `forward` are passed on.
        # NOTE(review): if `self.model` is a wrapper whose `forward` is
        # declared as `(*inputs, **kwargs)`, this filter would drop every
        # key -- confirm against how the base Trainer prepares the model.
        model_forward_keys = set(inspect.signature(self.model.forward).parameters)

        for epoch in range(starting_epoch, self.args.num_train_epochs):
            self.model.train()
            if resume_step is not None and epoch == starting_epoch:
                # Skip the first `resume_step` batches when resuming mid-epoch.
                active_dataloader = self.accelerator.skip_first_batches(
                    self.train_dataloader, resume_step
                )
            else:
                active_dataloader = self.train_dataloader

            for batch in active_dataloader:
                with self.accelerator.accumulate(self.model):
                    filtered_batch = {
                        k: v.to(self.accelerator.device)
                        for k, v in batch.items()
                        if k in model_forward_keys
                    }
                    outputs = self.model(**filtered_batch)
                    loss = outputs.loss
                    mse_loss = outputs.mse_loss
                    kl_loss = outputs.kl_loss
                    self.accelerator.backward(loss)
                    if self.accelerator.sync_gradients:
                        grad_norm = self.accelerator.clip_grad_norm_(
                            self.model.parameters(), self.args.max_grad_norm
                        )
                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()

                # True when the accelerator performed an optimization step
                # behind the scenes for this batch.
                optimizer_stepped = self.accelerator.sync_gradients
                if optimizer_stepped:
                    progress_bar.update(1)
                    completed_steps += 1

                # Track the losses of every batch between two log records.
                # NOTE(review): the original indentation was lost; this
                # assumes per-batch accumulation -- confirm.
                accumulator_train.add(
                    self.accelerator.reduce(loss.detach().clone(), "mean").item(),
                    self.accelerator.reduce(mse_loss.detach().clone(), "mean").item(),
                    self.accelerator.reduce(kl_loss.detach().clone(), "mean").item(),
                )

                # Log training progress. Gated on `optimizer_stepped` so the
                # record is emitted once per matching optimizer step and
                # `grad_norm` is guaranteed to be bound.
                if optimizer_stepped and completed_steps % self.args.logging_steps == 0:
                    accumulator_train.mean()
                    log_dict = accumulator_train.to_dict()
                    accumulator_train.reset()  # reset accumulator
                    extra_log_dict = {"lr": self.lr_scheduler.get_last_lr()[0]}
                    # Leave grad_norm out if clipping returned None, since
                    # round() would fail on it.
                    if grad_norm is not None:
                        extra_log_dict["grad_norm"] = (
                            grad_norm.detach().item() if torch.is_tensor(grad_norm) else grad_norm
                        )
                    log_dict = log_dict | extra_log_dict
                    log_dict_round = {
                        k: round(v, 6) if k == "lr" else round(v, 4)
                        for k, v in log_dict.items()
                    }
                    self.logger.info({"epoch": epoch, "step": completed_steps, **log_dict_round})
                    if self.args.with_tracking:
                        self.accelerator.log(log_dict, step=completed_steps)

                # Step-based checkpointing.
                if (
                    optimizer_stepped
                    and isinstance(self.args.checkpointing_steps, int)
                    and completed_steps % self.args.checkpointing_steps == 0
                ):
                    output_dir = os.path.join(self.args.output_dir, f"step_{completed_steps}")
                    self.accelerator.save_state(output_dir)
                    # Save the model checkpoint et al.
                    self._save(os.path.join(output_dir, "model"))

                if completed_steps >= self.args.max_train_steps:
                    break

            # NOTE: Evaluation will be performed at the end of each epoch
            # (or every `eval_every_n_epochs`).
            if self.eval_dataloader is not None and (epoch + 1) % self.args.eval_every_n_epochs == 0:
                eval_log_dict = self.evaluate()
                # Log evaluation progress.
                self.logger.info({"epoch": epoch, **eval_log_dict})
                if self.args.with_tracking:
                    # NOTE(review): eval metrics use the epoch index as the
                    # tracker step while training metrics use optimizer
                    # steps -- confirm the tracker in use accepts both.
                    self.accelerator.log(eval_log_dict, step=epoch)

                # EarlyStop: check if we should stop the training on any process.
                if self.earlystop_callback is not None:
                    if self.earlystop_callback.check_early_stopping(eval_log_dict["eval_loss"]):
                        self.accelerator.set_trigger()
                    # If so, we break the loop.
                    if self.accelerator.check_trigger():
                        self.logger.info(f"Model has not improved for {self.args.earlystop_patience} evaluations, so we halt the training session.")
                        break

            # Epoch-based checkpointing ("epoch" / "epoch-<number>").
            if (
                checkpointing_every_n_epochs is not None
                and (epoch + 1) % checkpointing_every_n_epochs == 0
            ):
                output_dir = os.path.join(self.args.output_dir, f"epoch_{epoch}")
                self.accelerator.save_state(output_dir)
                # Save the model checkpoint et al.
                self._save(os.path.join(output_dir, "model"))

            # The inner `break` only leaves the batch loop; stop the epoch
            # loop too so no further batches run past `max_train_steps`.
            if completed_steps >= self.args.max_train_steps:
                break

        # Save the last model checkpoint.
        self._save(os.path.join(self.args.output_dir, "last_model"))

        self.accelerator.wait_for_everyone()
        self.accelerator.end_training()
        self.logger.info("Training exited successfully.")

    def evaluate(self) -> dict[str, float]:
        """Evaluate the model on ``self.eval_dataloader``.

        The model is switched to eval mode and each batch is run under
        ``torch.no_grad()``.  Every per-batch loss is repeated
        ``args.per_device_eval_batch_size`` times, gathered with
        ``accelerator.gather_for_metrics`` and averaged over all gathered
        values.

        Returns:
            A dict with the float entries ``"eval_loss"``,
            ``"eval_mse_loss"`` and ``"eval_kl_loss"``.
        """
        self.model.eval()
        metric_names = ("loss", "mse_loss", "kl_loss")
        gathered = {name: [] for name in metric_names}
        model_forward_keys = set(inspect.signature(self.model.forward).parameters)
        repeats = self.args.per_device_eval_batch_size

        for batch in self.eval_dataloader:
            with torch.no_grad():
                filtered_batch = {
                    k: v.to(self.accelerator.device)
                    for k, v in batch.items()
                    if k in model_forward_keys
                }
                outputs = self.model(**filtered_batch)

            # Gather in a fixed order (loss, mse_loss, kl_loss) so every
            # process issues the same sequence of gather calls.
            for name in metric_names:
                batch_value = getattr(outputs, name)
                gathered[name].append(
                    self.accelerator.gather_for_metrics(batch_value.repeat(repeats))
                )

        return {
            f"eval_{name}": torch.mean(torch.cat(gathered[name])).item()
            for name in metric_names
        }