class Evaluator:
    """Model evaluator.

    Runs a model over a dataloader to compute binary-classification accuracy
    or to produce thresholded predictions. The model is assumed to output
    logits; sigmoid is applied before thresholding.

    Args:
        model: model to be evaluated
        device: device on which to evaluate model
        loader: dataloader on which to evaluate model
        checkpoint_path: path to model checkpoint
    """

    def __init__(
        self,
        model: torch.nn.Module,
        device: torch.device,
        loader: DataLoader,
        checkpoint_path: Optional[str] = None,
    ) -> None:
        # Device
        self.device = device

        # Data
        self.loader = loader

        # Model
        self.model = model
        if checkpoint_path:
            self._load_from_checkpoint(checkpoint_path)

        # Metrics
        self.acc_metric = BinaryAccuracyMetric(threshold=0.5)

    def evaluate(self) -> float:
        """Evaluates the model

        Returns:
            (float) accuracy (on a 0 to 1 scale)
        """
        # Progress bar
        pbar = tqdm.tqdm(total=len(self.loader), leave=False)
        pbar.set_description("Evaluating... ")

        # Set to eval
        self.model.eval()

        # Loop — no gradients needed anywhere during evaluation, so the
        # context is hoisted around the whole loop instead of per batch
        with torch.no_grad():
            for data, target in self.loader:
                # To device
                data, target = data.to(self.device), target.to(self.device)

                # Forward
                out = self.model(data)
                self.acc_metric.update(out.sigmoid(), target)

                # Update progress bar
                pbar.update()

        pbar.close()

        accuracy = self.acc_metric.compute()
        print(f"Accuracy: {accuracy:.4f}\n")

        return accuracy

    def predict(self, threshold: float = 0.5) -> torch.Tensor:
        """Returns predictions for the given data

        Assumes the output of the model are the logits and applies sigmoid
        to the output

        Args:
            threshold: prediction threshold (fixed type hint: this is a
                float probability cutoff, not an int)

        Returns:
            (torch.Tensor) Model predictions for the given data of shape
            [N,], where N is the number of samples in the data
        """
        # Progress bar
        pbar = tqdm.tqdm(total=len(self.loader), leave=False)
        pbar.set_description("Predicting... ")

        # Set to eval
        self.model.eval()

        preds = []

        # Loop (gradient-free, same hoisting as in evaluate)
        with torch.no_grad():
            for data, _ in self.loader:
                # To device
                data = data.to(self.device)

                # Forward
                out = self.model(data)
                pred = torch.where(out.sigmoid() > threshold, 1, 0)
                preds.append(pred)

                # Update progress bar
                pbar.update()

        pbar.close()

        # Flatten per-batch predictions into a single [N,] tensor
        preds = torch.cat(preds).reshape(-1)
        return preds

    def _load_from_checkpoint(self, checkpoint_path: str) -> None:
        # Restore model weights only; map to the evaluator's device so a
        # GPU-saved checkpoint loads cleanly on CPU (and vice versa)
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model"])
        print(f"Checkpoint loaded: {checkpoint_path}")
class Trainer:
    """Model trainer

    Args:
        model: model to train
        loss_fn: loss function
        optimizer: model optimizer
        epochs: number of epochs
        device: device to train the model on
        train_loader: training dataloader
        val_loader: validation dataloader
        scheduler: learning rate scheduler
        writer: writer which logs metrics to TensorBoard (disabled if None)
        save_path: folder in which to save models (disabled if None)
        checkpoint_path: path to model checkpoint, to resume training
        show_pbar: whether to display the progress bar or not
    """

    def __init__(
        self,
        model: torch.nn.Module,
        loss_fn: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        epochs: int,
        device: torch.device,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader] = None,
        # Forward-reference string: _LRScheduler is private API, so the
        # annotation is kept lazy rather than the bare `Optional` the
        # original had (which is not a valid generic annotation).
        scheduler: Optional["torch.optim.lr_scheduler._LRScheduler"] = None,
        writer: Optional[SummaryWriter] = None,
        save_path: Optional[str] = None,
        checkpoint_path: Optional[str] = None,
        show_pbar: bool = True,
    ) -> None:
        self.writer = writer

        # Saving
        self.save_path = save_path

        # Device
        self.device = device

        # Data
        self.train_loader = train_loader
        self.val_loader = val_loader

        # Model
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.epochs = epochs
        self.start_epoch = 0
        if checkpoint_path:
            self._load_from_checkpoint(checkpoint_path)

        # Metrics
        self.train_loss_metric = LossMetric()
        self.val_loss_metric = LossMetric()
        self.train_acc_metric = BinaryAccuracyMetric(threshold=0.5)
        self.val_acc_metric = BinaryAccuracyMetric(threshold=0.5)

        # Progress bar
        self.show_pbar = show_pbar

    def train(self) -> None:
        """Trains the model"""
        print("Beginning training")
        start_time = time.time()

        for epoch in range(self.start_epoch, self.epochs):
            start_epoch_time = time.time()
            self._train_loop(epoch)
            if self.val_loader is not None:
                self._val_loop(epoch)
            epoch_time = time.time() - start_epoch_time
            self._end_loop(epoch, epoch_time)

        train_time_h = (time.time() - start_time) / 3600
        print(f"Finished training! Total time: {train_time_h:.2f}h")
        if self.save_path:
            self._save_model(
                os.path.join(self.save_path, "final_model.pt"), self.epochs
            )

    def _train_loop(self, epoch: int) -> None:
        """Regular train loop

        Args:
            epoch: current epoch
        """
        # Progress bar
        if self.show_pbar:
            pbar = tqdm.auto.tqdm(total=len(self.train_loader), leave=False)
            pbar.set_description(f"Epoch {epoch} | Train")

        # Set to train
        self.model.train()

        # Loop
        for data, target in self.train_loader:
            # To device
            data, target = data.to(self.device), target.to(self.device)

            # Forward + backward
            self.optimizer.zero_grad()
            out = self.model(data)
            loss = self.loss_fn(out, target)
            loss.backward()
            self.optimizer.step()

            # Update metrics (loss is weighted by batch size for a correct
            # epoch-level average)
            self.train_loss_metric.update(loss.item(), data.shape[0])
            self.train_acc_metric.update(out.sigmoid(), target)

            # Update progress bar
            if self.show_pbar:
                pbar.update()
                pbar.set_postfix_str(f"Loss: {loss.item():.3f}", refresh=False)

        # Update scheduler if it is epoch-based; ReduceLROnPlateau is
        # metric-based and is stepped in the validation loop instead
        if self.scheduler is not None and not isinstance(
            self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
        ):
            self.scheduler.step()

        if self.show_pbar:
            pbar.close()

    def _val_loop(self, epoch: int) -> None:
        """Standard validation loop

        Args:
            epoch: current epoch
        """
        # Progress bar
        if self.show_pbar:
            pbar = tqdm.auto.tqdm(total=len(self.val_loader), leave=False)
            pbar.set_description(f"Epoch {epoch} | Validation")

        # Set to eval
        self.model.eval()

        # Loop
        for data, target in self.val_loader:
            with torch.no_grad():
                # To device
                data, target = data.to(self.device), target.to(self.device)

                # Forward
                out = self.model(data)
                loss = self.loss_fn(out, target)

                # Update metrics
                self.val_loss_metric.update(loss.item(), data.shape[0])
                self.val_acc_metric.update(out.sigmoid(), target)

                # Update progress bar
                if self.show_pbar:
                    pbar.update()
                    pbar.set_postfix_str(
                        f"Loss: {loss.item():.3f}", refresh=False
                    )

        if self.show_pbar:
            pbar.close()

        # ReduceLROnPlateau steps on the validation loss
        if self.scheduler is not None and isinstance(
            self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
        ):
            self.scheduler.step(self.val_loss_metric.compute())

    def _end_loop(self, epoch: int, epoch_time: float):
        """Per-epoch bookkeeping: print, log, save, reset metrics."""
        # Print epoch results
        print(self._epoch_str(epoch, epoch_time))

        # Write to tensorboard
        if self.writer is not None:
            self._write_to_tb(epoch)

        # Save model
        if self.save_path is not None:
            self._save_model(os.path.join(self.save_path, "most_recent.pt"), epoch)

        # Clear metrics
        self.train_loss_metric.reset()
        self.train_acc_metric.reset()
        if self.val_loader is not None:
            self.val_loss_metric.reset()
            self.val_acc_metric.reset()

    def _epoch_str(self, epoch: int, epoch_time: float):
        """Builds the one-line summary string for an epoch."""
        s = f"Epoch {epoch} "
        s += f"| Train loss: {self.train_loss_metric.compute():.3f} "
        s += f"| Train acc: {self.train_acc_metric.compute():.3f} "
        if self.val_loader is not None:
            s += f"| Val loss: {self.val_loss_metric.compute():.3f} "
            s += f"| Val acc: {self.val_acc_metric.compute():.3f} "
        s += f"| Epoch time: {epoch_time:.1f}s"
        return s

    def _write_to_tb(self, epoch):
        """Logs epoch metrics to TensorBoard."""
        self.writer.add_scalar("Loss/train", self.train_loss_metric.compute(), epoch)
        self.writer.add_scalar("Acc/train", self.train_acc_metric.compute(), epoch)
        if self.val_loader is not None:
            self.writer.add_scalar("Loss/val", self.val_loss_metric.compute(), epoch)
            self.writer.add_scalar("Acc/val", self.val_acc_metric.compute(), epoch)

    def _save_model(self, path, epoch):
        """Saves a resumable checkpoint.

        `epoch + 1` is stored so that a resumed run starts at the NEXT
        epoch, not the one just completed.
        """
        obj = {
            "epoch": epoch + 1,
            "optimizer": self.optimizer.state_dict(),
            "model": self.model.state_dict(),
            # None when training without a scheduler — the loader below
            # must tolerate this
            "scheduler": self.scheduler.state_dict()
            if self.scheduler is not None
            else None,
        }
        torch.save(obj, path)

    def _load_from_checkpoint(self, checkpoint_path: str) -> None:
        """Restores model/optimizer/scheduler state to resume training.

        Raises:
            ValueError: if the checkpoint's epoch exceeds total epochs
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.start_epoch = checkpoint["epoch"]
        if self.scheduler:
            # Guard: _save_model stores None when the run had no scheduler;
            # the original unconditionally called load_state_dict(None),
            # which crashes when resuming such a checkpoint with a scheduler
            scheduler_state = checkpoint.get("scheduler")
            if scheduler_state is not None:
                self.scheduler.load_state_dict(scheduler_state)

        if self.start_epoch > self.epochs:
            raise ValueError("Starting epoch is larger than total epochs")

        print(f"Checkpoint loaded, resuming from epoch {self.start_epoch}")