def train_one_epoch(trainer: Trainer, epoch_count: int) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Run a single training epoch (plus optional validation) and return its metrics.

    Parameters
    ----------
    trainer : Trainer
        The trainer whose model is being optimized.  This function reads/calls its
        ``_train_epoch``, ``_validation_data``, ``_validation_loss``, ``model``,
        ``_validation_metric``, ``_serialization_dir``, ``_learning_rate_scheduler``
        and ``_momentum_scheduler`` members.
    epoch_count : int
        Zero-based index of the epoch, used for scheduler stepping and for naming
        the dumped ``metrics_epoch_{N}.json`` file.

    Returns
    -------
    Tuple[Dict[str, float], Dict[str, float]]
        ``(train_metrics, val_metrics)``.  ``val_metrics`` is empty when the
        trainer has no validation data.
    """
    val_metrics: Dict[str, float] = {}
    # Stays None when there is no validation set; schedulers that don't need a
    # validation metric ignore a None argument (see note below).
    this_epoch_val_metric: Optional[float] = None

    train_metrics = trainer._train_epoch(epoch_count)

    if trainer._validation_data is not None:
        # We have a validation set, so compute all the metrics on it.
        # no_grad avoids building the autograd graph during evaluation.
        with torch.no_grad():
            val_loss, num_batches = trainer._validation_loss()
            val_metrics = training_util.get_metrics(trainer.model, val_loss, num_batches, reset=True)
            this_epoch_val_metric = val_metrics[trainer._validation_metric]

    # Merge both metric dicts under distinguishing prefixes for the JSON dump.
    metrics: Dict[str, float] = {"training_" + key: value for key, value in train_metrics.items()}
    metrics.update({"validation_" + key: value for key, value in val_metrics.items()})

    if trainer._serialization_dir:
        dump_metrics(
            os.path.join(trainer._serialization_dir, f"metrics_epoch_{epoch_count}.json"), metrics)

    # The Scheduler API is agnostic to whether your schedule requires a validation metric -
    # if it doesn't, the validation metric passed here is ignored.
    if trainer._learning_rate_scheduler:
        trainer._learning_rate_scheduler.step(this_epoch_val_metric, epoch_count)
    if trainer._momentum_scheduler:
        trainer._momentum_scheduler.step(this_epoch_val_metric, epoch_count)

    return train_metrics, val_metrics