def loss_lookahead_diff(model: NeuralTeleportationModel, data: Tensor, target: Tensor,
                        metrics: TrainingMetrics, config: OptimalTeleportationTrainingConfig,
                        **kwargs) -> Number:
    # Save the state of the model prior to performing the lookahead
    # (cloned so the in-place optimizer update below cannot alter the saved tensors)
    state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}

    # Initialize a new optimizer to perform the lookahead
    optimizer = get_optimizer_from_model_and_config(model, config)
    optimizer.zero_grad()

    # Compute the loss at the teleported point
    loss = torch.stack([metrics.criterion(model(data_batch), target_batch)
                        for data_batch, target_batch in zip(data, target)]).mean(dim=0)

    # Take a step using the gradient at the teleported point
    loss.backward()
    optimizer.step()

    # Compute the loss after the optimizer step
    lookahead_loss = torch.stack([metrics.criterion(model(data_batch), target_batch)
                                  for data_batch, target_batch in zip(data, target)]).mean(dim=0)

    # Restore the state of the model prior to the lookahead
    model.load_state_dict(state_dict)

    # Return the decrease in loss obtained by the lookahead step
    return (loss - lookahead_loss).item()
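# A minimal, hypothetical usage sketch (not part of the original module): pick, among a few
# random teleportations of the model, the one whose single lookahead step decreases the loss
# the most. `num_candidates` and the `model.random_teleport()` call are assumptions about the
# surrounding NeuralTeleportationModel API, not guaranteed by this excerpt.
def select_best_teleportation(model: NeuralTeleportationModel, data: Tensor, target: Tensor,
                              metrics: TrainingMetrics, config: OptimalTeleportationTrainingConfig,
                              num_candidates: int = 10) -> Number:
    best_score = float("-inf")
    best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
    for _ in range(num_candidates):
        model.random_teleport()  # assumed API: applies a random change of basis in-place
        score = loss_lookahead_diff(model, data, target, metrics, config)
        if score > best_score:
            best_score = score
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
    # Leave the model at the teleported point with the largest one-step loss decrease
    model.load_state_dict(best_state)
    return best_score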
def train_epoch(model: nn.Module, metrics: TrainingMetrics, optimizer: Optimizer, train_loader: DataLoader,
                epoch: int, device: str = 'cpu', progress_bar: bool = True, config: TrainingConfig = None,
                lr_scheduler=None) -> None:
    lr_scheduler_interval = None
    if config is not None and config.lr_scheduler is not None:
        lr_scheduler_interval = config.lr_scheduler[1]

    # Init data structures to keep track of the metrics at each batch
    metrics_by_batch = {metric.__name__: [] for metric in metrics.metrics}
    metrics_by_batch.update(loss=[])

    model.train()
    pbar = tqdm(enumerate(train_loader))
    for batch_idx, (data, target) in pbar:
        if config is not None and batch_idx == config.max_batch:
            break
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = metrics.criterion(output, target)

        # Record the loss and every metric for the current batch
        metrics_by_batch["loss"].append(loss.item())
        for metric in metrics.metrics:
            metrics_by_batch[metric.__name__].append(metric(output, target))

        loss.backward()
        optimizer.step()

        if progress_bar:
            postfix = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * train_loader.batch_size, len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item())
            pbar.set_postfix_str(postfix)

        # Schedulers configured with a "step" interval are updated after every batch
        if lr_scheduler and lr_scheduler_interval == "step":
            lr_scheduler.step()

        pbar.update()
    pbar.close()

    # Log the mean of each metric at the end of the epoch
    if config is not None and config.logger is not None:
        reduced_metrics = {metric: mean(values_by_batch)
                           for metric, values_by_batch in metrics_by_batch.items()}
        config.logger.log_metrics(reduced_metrics, epoch=epoch)
        for metric_name, value in reduced_metrics.items():
            config.logger.add_scalar(metric_name, value, epoch)
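# A minimal, hypothetical training loop built on top of train_epoch (illustration only).
# `config.epochs` and `config.device`, as well as the convention that an "epoch"-interval
# scheduler is stepped once per epoch here while the per-batch "step" interval is handled
# inside train_epoch, are assumptions about the surrounding TrainingConfig, not guaranteed
# by this excerpt.
def train(model: nn.Module, train_dataset: Dataset, metrics: TrainingMetrics, config: TrainingConfig,
          optimizer: Optimizer, lr_scheduler=None) -> nn.Module:
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    for epoch in range(1, config.epochs + 1):
        train_epoch(model, metrics, optimizer, train_loader, epoch,
                    device=config.device, config=config, lr_scheduler=lr_scheduler)
        if lr_scheduler is not None and config.lr_scheduler is not None and config.lr_scheduler[1] == "epoch":
            lr_scheduler.step()
    return model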
def test(model: nn.Module, dataset: Dataset, metrics: TrainingMetrics, config: TrainingConfig,
         eval_mode: bool = True) -> Dict[str, Any]:
    test_loader = DataLoader(dataset, batch_size=config.batch_size)
    if eval_mode:
        model.eval()

    results = defaultdict(list)
    pbar = tqdm(enumerate(test_loader))
    with torch.no_grad():
        for i, (data, target) in pbar:
            if i == config.max_batch:
                break
            data, target = data.to(config.device), target.to(config.device)
            output = model(data)

            # Record the loss and every metric for the current batch
            results['loss'].append(metrics.criterion(output, target).item())
            if metrics is not None:
                batch_results = compute_metrics(metrics.metrics, y=target, y_hat=output, to_tensor=False)
                for k, v in batch_results.items():
                    results[k].append(v)

            pbar.update()
            pbar.set_postfix(loss=pd.DataFrame(results['loss']).mean().values,
                             accuracy=pd.DataFrame(results['accuracy']).mean().values)
    pbar.close()

    # Average each metric over all test batches
    reduced_results = dict(pd.DataFrame(results).mean())
    if config.logger is not None:
        config.logger.log_metrics(reduced_results, epoch=0)
    return reduced_results
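# A minimal, hypothetical end-to-end example for test() (illustration only). The CIFAR-10
# dataset, the inline `accuracy` metric, the tiny linear model, and the TrainingConfig /
# TrainingMetrics constructor arguments are all assumptions made for the example and are not
# guaranteed by this excerpt; it also assumes compute_metrics invokes each metric as
# metric(output, target), mirroring the call in train_epoch above.
if __name__ == "__main__":
    from torchvision.datasets import CIFAR10
    from torchvision.transforms import ToTensor

    def accuracy(output: Tensor, target: Tensor) -> float:
        # Fraction of correctly classified samples in the batch
        return (output.argmax(dim=1) == target).float().mean().item()

    config = TrainingConfig(batch_size=128, device="cuda" if torch.cuda.is_available() else "cpu")
    metrics = TrainingMetrics(criterion=nn.CrossEntropyLoss(), metrics=[accuracy])

    test_set = CIFAR10("/tmp/cifar10", train=False, download=True, transform=ToTensor())
    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(config.device)

    print(test(model, test_set, metrics, config))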