# Example #1
def train_one_epoch(
        trainer: "Trainer",
        epoch_count: int) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Run one training epoch on ``trainer`` and, if a validation set is
    configured, one validation pass.

    Side effects on ``trainer``: advances its model by one epoch of training,
    optionally writes a ``metrics_epoch_{epoch_count}.json`` file to the
    serialization directory, and steps the learning-rate and momentum
    schedulers when present. Checkpointing is intentionally NOT performed here.

    Parameters
    ----------
    trainer : Trainer
        The trainer whose private hooks (``_train_epoch``, ``_validation_loss``,
        schedulers, ...) drive the epoch.
    epoch_count : int
        Zero-based index of the current epoch; used for metric file naming and
        passed to the schedulers.

    Returns
    -------
    Tuple[Dict[str, float], Dict[str, float]]
        ``(train_metrics, val_metrics)``; ``val_metrics`` is empty when the
        trainer has no validation data.
    """
    val_metrics: Dict[str, float] = {}
    # Optional[float]: stays None when there is no validation set; the
    # scheduler API tolerates a None metric (see comment below).
    this_epoch_val_metric = None
    metrics: Dict[str, float] = {}

    train_metrics = trainer._train_epoch(epoch_count)

    if trainer._validation_data is not None:
        with torch.no_grad():
            # We have a validation set, so compute all the metrics on it.
            val_loss, num_batches = trainer._validation_loss()
            val_metrics = training_util.get_metrics(trainer.model,
                                                    val_loss,
                                                    num_batches,
                                                    reset=True)
            this_epoch_val_metric = val_metrics[trainer._validation_metric]

    # Merge both metric dicts under prefixed keys for the on-disk record.
    for key, value in train_metrics.items():
        metrics["training_" + key] = value
    for key, value in val_metrics.items():
        metrics["validation_" + key] = value

    if trainer._serialization_dir:
        dump_metrics(
            os.path.join(trainer._serialization_dir,
                         f"metrics_epoch_{epoch_count}.json"), metrics)

    # The Scheduler API is agnostic to whether your schedule requires a validation metric -
    # if it doesn't, the validation metric passed here is ignored.
    if trainer._learning_rate_scheduler:
        trainer._learning_rate_scheduler.step(this_epoch_val_metric,
                                              epoch_count)
    if trainer._momentum_scheduler:
        trainer._momentum_scheduler.step(this_epoch_val_metric, epoch_count)
    # NOTE: saving a checkpoint (trainer._save_checkpoint) is left to the caller.
    return train_metrics, val_metrics