def test_args_validation():
    """EarlyStopping must reject invalid constructor arguments with specific messages."""
    dummy_trainer = Engine(do_nothing_update_fn)

    # patience must be a positive integer
    with pytest.raises(ValueError, match=r"Argument patience should be positive integer."):
        EarlyStopping(patience=-1, score_function=lambda engine: 0, trainer=dummy_trainer)

    # min_delta must not be negative
    with pytest.raises(ValueError, match=r"Argument min_delta should not be a negative number."):
        EarlyStopping(patience=2, min_delta=-0.1, score_function=lambda engine: 0, trainer=dummy_trainer)

    # score_function must be callable
    with pytest.raises(TypeError, match=r"Argument score_function should be a function."):
        EarlyStopping(patience=2, score_function=12345, trainer=dummy_trainer)

    # trainer must be an Engine instance
    with pytest.raises(TypeError, match=r"Argument trainer should be an instance of Engine."):
        EarlyStopping(patience=2, score_function=lambda engine: 0, trainer=None)
def test_with_engine_early_stopping():
    """EarlyStopping on the evaluator should halt the trainer after patience is exhausted."""

    class Counter(object):
        def __init__(self, count=0):
            self.count = count

    epochs_run = Counter()
    # Best score (1.5) occurs at epoch 4; the next 3 epochs do not improve -> stop at epoch 7.
    scores = iter([1.0, 0.8, 1.2, 1.5, 0.9, 1.0, 0.99, 1.1, 0.9])

    def score_function(engine):
        return next(scores)

    def update_fn(engine, batch):
        pass

    trainer = Engine(update_fn)
    evaluator = Engine(update_fn)
    early_stopping = EarlyStopping(patience=3, score_function=score_function, trainer=trainer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluation(engine):
        evaluator.run([0])
        epochs_run.count += 1

    evaluator.add_event_handler(Events.COMPLETED, early_stopping)
    trainer.run([0], max_epochs=10)

    assert epochs_run.count == 7
def create_callbacks(self):
    """Attach early-stopping and best-model checkpointing callbacks to the evaluator."""
    ## SETUP CALLBACKS
    print('[INFO] Creating callback functions for training loop...', end='')

    # Early Stopping - stops training if the validation loss does not decrease after 5 epochs
    es_handler = EarlyStopping(
        patience=self.config.EARLY_STOPPING_PATIENCE,
        score_function=score_function_loss,
        trainer=self.train_engine,
    )
    self.evaluator.add_event_handler(Events.COMPLETED, es_handler)
    print('Early Stopping ({} epochs)...'.format(self.config.EARLY_STOPPING_PATIENCE), end='')

    # Keep only the single best checkpoint, ranked by validation accuracy.
    val_checkpointer = Checkpoint(
        {"model": self.model},
        ClearMLSaver(),
        n_saved=1,
        score_function=score_function_acc,
        score_name="val_acc",
        filename_prefix='cub200_{}_ignite_best'.format(self.config.MODEL.MODEL_NAME),
        global_step_transform=global_step_from_engine(self.train_engine),
    )
    self.evaluator.add_event_handler(Events.EPOCH_COMPLETED, val_checkpointer)
    print('Model Checkpointing...', end='')
    print('Done')
def finalize(self, context):
    """On rank 0 only: register stats/model publishing and optional early stopping."""
    if context.local_rank != 0:
        return

    publisher = PublishStatsAndModel(
        self._stats_path,
        self._publish_path,
        self._key_metric_filename,
        context.start_ts,
        context.run_id,
        context.output_dir,
        context.trainer,
        context.evaluator,
    )
    # Publish on evaluator epochs when an evaluator exists, otherwise on trainer epochs.
    publish_target = context.evaluator if context.evaluator else context.trainer
    publish_target.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=publisher)

    early_stop_patience = int(context.request.get("early_stop_patience", 0))
    if early_stop_patience > 0 and context.evaluator:
        early_stopper = EarlyStopping(
            patience=early_stop_patience,
            score_function=stopping_fn_from_metric(self.VAL_KEY_METRIC),
            trainer=context.trainer,
        )
        context.evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)
def _setup_early_stopping(self, trainer, val_evaluator, score_function):
    """Build an EarlyStopping handler from ``self.early_stopping_kwargs`` and attach it.

    A ``score_function`` already present in the configured kwargs takes precedence
    over the one passed in.
    """
    kwargs = dict(self.early_stopping_kwargs)
    kwargs.setdefault('score_function', score_function)
    handler = EarlyStopping(trainer=trainer, **kwargs)
    # Route the handler's internal logger to the configured log file/level.
    setup_logger(handler._logger, self.log_filepath, self.log_level)
    val_evaluator.add_event_handler(Events.COMPLETED, handler)
def train(param, device):
    """Fine-tune the model (text encoder frozen) with early stopping and checkpointing."""
    model = Model(param)

    # Copy only the t_encoder weights from the checkpoint into a fresh state dict.
    state_dict = torch.load(CKPT)
    merged_state = model.state_dict().copy()
    for key in state_dict:
        if key.startswith('t_encoder'):
            merged_state[key] = state_dict[key]
    model.load_state_dict(merged_state)

    # Freeze the text encoder so only the remaining parameters are updated.
    for p in model.t_encoder.parameters():
        p.requires_grad = False

    optimizer = AdamW(model.parameters(), lr=param.lr, eps=1e-8)
    total_steps = MAX_EPOCH * len(train_loader)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=total_steps)
    loss_fn = L1Loss()

    trainer = create_trainer(model, optimizer, scheduler, loss_fn, MAX_GRAD_NORM, device)
    dev_evaluator = create_evaluator(model, val_metrics, device)

    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=10), log_training_loss)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, log_results, *[dev_evaluator, dev_loader, 'Dev'])

    es_handler = EarlyStopping(patience=PATIENCE, score_function=score_fn, trainer=trainer)
    dev_evaluator.add_event_handler(Events.COMPLETED, es_handler)

    ckpt_handler = ModelCheckpoint(SAVE_PATH, f'lr_{param.lr}', score_function=score_fn,
                                   score_name='score', require_empty=True)
    dev_evaluator.add_event_handler(Events.COMPLETED, ckpt_handler,
                                    {SAVE_PATH.split("/")[-1]: model})

    print(f'Start running {SAVE_PATH.split("/")[-1]} at device: {DEVICE}\tlr: {param.lr}')
    trainer.run(train_loader, max_epochs=MAX_EPOCH)
def test_with_engine_early_stopping_on_plateau():
    """A constant score never improves, so training stops once patience (4) is spent."""

    class Counter(object):
        def __init__(self, count=0):
            self.count = count

    epochs_run = Counter()

    def score_function(engine):
        # Plateau: identical score on every evaluation.
        return 0.047

    trainer = Engine(do_nothing_update_fn)
    evaluator = Engine(do_nothing_update_fn)
    stopper = EarlyStopping(patience=4, score_function=score_function, trainer=trainer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluation(engine):
        evaluator.run([0])
        epochs_run.count += 1

    evaluator.add_event_handler(Events.COMPLETED, stopper)
    trainer.run([0], max_epochs=10)

    assert epochs_run.count == 5
    assert trainer.state.epoch == 5
def test_args_validation():
    """Legacy API: invalid EarlyStopping arguments raise AssertionError."""

    def update_fn(engine, batch):
        pass

    trainer = Engine(update_fn)

    # save_interval & score_func
    with pytest.raises(AssertionError):
        EarlyStopping(patience=-1, score_function=lambda engine: 0, trainer=trainer)
    with pytest.raises(AssertionError):
        EarlyStopping(patience=2, score_function=12345, trainer=trainer)
    with pytest.raises(AssertionError):
        EarlyStopping(patience=2, score_function=lambda engine: 0, trainer=None)
def test_with_engine_no_early_stopping():
    """Scores keep improving within patience=5, so all 10 epochs complete."""

    class Counter(object):
        def __init__(self, count=0):
            self.count = count

    epochs_run = Counter()
    scores = iter([1.0, 0.8, 1.2, 1.23, 0.9, 1.0, 1.1, 1.253, 1.26, 1.2])

    def score_function(engine):
        return next(scores)

    trainer = Engine(do_nothing_update_fn)
    evaluator = Engine(do_nothing_update_fn)
    stopper = EarlyStopping(patience=5, score_function=score_function, trainer=trainer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluation(engine):
        evaluator.run([0])
        epochs_run.count += 1

    evaluator.add_event_handler(Events.COMPLETED, stopper)
    trainer.run([0], max_epochs=10)

    assert epochs_run.count == 10
    assert trainer.state.epoch == 10
def add_early_stopping_and_checkpoint(evaluator: Engine, trainer: Engine,
                                      checkpoint_filename: str, model: Module) -> None:
    """
    adds two event handlers to an ``ignite`` trainer/evaluator pair:

    * early stopping
    * best model checkpoint saver

    :param evaluator: an evaluator to add hooks to
    :param trainer: a trainer from which to make a checkpoint
    :param checkpoint_filename: some pretty name for a checkpoint
    :param model: a network which is saved in checkpoints
    """

    def score(engine):
        # Lower loss is better; negate so EarlyStopping's "higher is better" holds.
        return -engine.state.metrics["loss"]

    stopper = EarlyStopping(100, score, trainer)
    evaluator.add_event_handler(Events.COMPLETED, stopper)

    saver = ModelCheckpoint("checkpoints", "", score_function=score, require_empty=False)
    evaluator.add_event_handler(Events.COMPLETED, saver, {checkpoint_filename: model})
def _test_distrib_with_engine_early_stopping(device):
    """Distributed variant: the per-epoch score is all-reduced (averaged) across workers."""
    import torch.distributed as dist

    torch.manual_seed(12)

    class Counter(object):
        def __init__(self, count=0):
            self.count = count

    epochs_run = Counter()
    scores = torch.tensor([1.0, 0.8, 1.2, 1.5, 0.9, 1.0, 0.99, 1.1, 0.9],
                          requires_grad=False).to(device)

    def score_function(engine):
        # Pick this epoch's score and average it across the world.
        idx = trainer.state.epoch - 1
        value = scores[idx]
        dist.all_reduce(value)
        value /= dist.get_world_size()
        return value.item()

    trainer = Engine(do_nothing_update_fn)
    evaluator = Engine(do_nothing_update_fn)
    stopper = EarlyStopping(patience=3, score_function=score_function, trainer=trainer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluation(engine):
        evaluator.run([0])
        epochs_run.count += 1

    evaluator.add_event_handler(Events.COMPLETED, stopper)
    trainer.run([0], max_epochs=10)

    assert trainer.state.epoch == 7
    assert epochs_run.count == 7
def train(epochs, model, train_loader, valid_loader, criterion, optimizer, writer, device, log_interval):
    """Run a supervised training loop with logging, checkpointing, and early stopping.

    NOTE(review): ``score_function`` used by EarlyStopping below is not defined in
    this function — it must come from the enclosing module; verify it exists.
    """
    # Note: device is expected to be a str here.
    # Note: no DataLoader is attached to the engines at this point.
    trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                'accuracy': Accuracy(),
                                                'nll': Loss(criterion)
                                            },
                                            device=device)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        # 1-based iteration index within the current epoch.
        i = (engine.state.iteration - 1) % len(train_loader) + 1
        if i % log_interval == 0:
            print(
                f"Epoch[{engine.state.epoch}] Iteration[{i}/{len(train_loader)}] "
                f"Loss: {engine.state.output:.2f}")
            # engine.state.output presumably represents criterion(model(input)) — TODO confirm
            writer.add_scalar("training/loss", engine.state.output, engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        # Re-evaluate on the training set to log training metrics.
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        write_metrics(metrics, writer, 'training', engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(valid_loader)
        metrics = evaluator.state.metrics
        write_metrics(metrics, writer, 'validation', engine.state.epoch)

    # # Checkpoint setting
    # Saves to ./checkpoints/sample_mymodel_{step_number}
    # Keeps up to n_saved parameter files.
    handler = ModelCheckpoint(dirname='./checkpoints', filename_prefix='sample',
                              save_interval=2, n_saved=3, create_dir=True)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, {'mymodel': model})

    # # Early stopping
    handler = EarlyStopping(patience=5, score_function=score_function, trainer=trainer)
    # Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset)
    evaluator.add_event_handler(Events.COMPLETED, handler)

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)
def test_state_dict():
    """EarlyStopping state survives a state_dict round-trip into a new handler."""
    # Scores: 1.0 (best), then 0.8 and 0.88 — neither improves, so patience=2 is exhausted.
    scores = iter([1.0, 0.8, 0.88])

    def score_function(engine):
        return next(scores)

    trainer = Engine(do_nothing_update_fn)
    h = EarlyStopping(patience=2, score_function=score_function, trainer=trainer)

    # Call 3 times and check if stopped
    assert not trainer.should_terminate
    h(None)
    assert not trainer.should_terminate

    # Swap to new object, but maintain state
    h2 = EarlyStopping(patience=2, score_function=score_function, trainer=trainer)
    h2.load_state_dict(h.state_dict())

    h2(None)
    assert not trainer.should_terminate
    h2(None)
    # Second non-improving call on the restored state triggers termination.
    assert trainer.should_terminate
def add_early_stopping(trainer, val_evaluator, configuration):
    """Attach a configured EarlyStopping handler (with file logging) to the validation evaluator."""
    # Setup early stopping:
    stopper = EarlyStopping(
        patience=configuration.early_stop_patience,
        score_function=_score_function,
        trainer=trainer,
    )
    # Send the handler's internal logger output to the configured location.
    setup_logger(stopper._logger, configuration.log_dir, configuration.log_level)
    val_evaluator.add_event_handler(Events.COMPLETED, stopper)
def register_early_stopping(evaluator_test, trainer, args):
    """Stop training when the evaluator's 'bce' metric stops improving.

    NOTE(review): EarlyStopping treats higher scores as better, and this score
    returns the raw BCE loss without negation — confirm that is intended.
    """

    def score_function(engine):
        return engine.state.metrics['bce']

    stopper = EarlyStopping(patience=args.patience,
                            score_function=score_function,
                            trainer=trainer)
    evaluator_test.add_event_handler(Events.COMPLETED, stopper)
def _build_objects(acc_list):
    """Build a trainer/evaluator pair plus model, optimizer, LR scheduler,
    early-stopping handler and a self-including checkpointer; the evaluator's
    accuracy metric is fed from ``acc_list``, one value per epoch."""
    model = DummyModel().to(device)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.5)

    def update_fn(engine, batch):
        x = torch.rand((4, 1)).to(device)
        optim.zero_grad()
        loss = model(x).pow(2.0).sum()
        loss.backward()
        # XLA needs a barrier-synchronized optimizer step.
        if idist.has_xla_support:
            import torch_xla.core.xla_model as xm
            xm.optimizer_step(optim, barrier=True)
        else:
            optim.step()
        lr_scheduler.step()

    trainer = Engine(update_fn)
    evaluator = Engine(lambda e, b: None)
    acc_iter = iter(acc_list)

    @evaluator.on(Events.EPOCH_COMPLETED)
    def setup_result():
        evaluator.state.metrics["accuracy"] = next(acc_iter)

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_eval():
        evaluator.run([0, 1, 2])

    def score_function(engine):
        return engine.state.metrics["accuracy"]

    save_handler = DiskSaver(dirname, create_dir=True, require_empty=False)
    early_stop = EarlyStopping(score_function=score_function, patience=2, trainer=trainer)
    evaluator.add_event_handler(Events.COMPLETED, early_stop)

    # The checkpointer saves itself too (include_self=True) so resuming restores it.
    checkpointer = Checkpoint(
        {
            "trainer": trainer,
            "model": model,
            "optim": optim,
            "lr_scheduler": lr_scheduler,
            "early_stop": early_stop,
        },
        save_handler,
        include_self=True,
        global_step_transform=global_step_from_engine(trainer),
    )
    evaluator.add_event_handler(Events.COMPLETED, checkpointer)

    return trainer, evaluator, model, optim, lr_scheduler, early_stop, checkpointer
def make_early_stopper(self, trainer):
    """Build an EarlyStopping handler on validation loss (negated) or validation accuracy."""
    if self.early_stop_metric == 'loss':
        # Lower loss is better, so flip the sign for EarlyStopping's comparison.
        key_name, sign = 'val_loss', -1
    else:
        key_name, sign = 'val_accuracy', 1
    return EarlyStopping(self.early_stop_patience,
                         lambda e: sign * e.state.metrics[key_name],
                         trainer,
                         min_delta=self.early_stop_delta)
def test_args_validation():
    """EarlyStopping raises ValueError/TypeError for invalid constructor arguments."""

    def update_fn(engine, batch):
        pass

    trainer = Engine(update_fn)

    # negative patience
    with pytest.raises(ValueError):
        EarlyStopping(patience=-1, score_function=lambda engine: 0, trainer=trainer)
    # non-callable score_function
    with pytest.raises(TypeError):
        EarlyStopping(patience=2, score_function=12345, trainer=trainer)
    # trainer that is not an Engine
    with pytest.raises(TypeError):
        EarlyStopping(patience=2, score_function=lambda engine: 0, trainer=None)
def _test_distrib_integration_engine_early_stopping(device):
    """Integration test: a distributed Accuracy metric drives EarlyStopping(patience=3)."""
    from ignite.metrics import Accuracy

    if device is None:
        device = idist.device()
    if isinstance(device, str):
        device = torch.device(device)
    # XLA devices compute the metric on CPU.
    metric_device = "cpu" if device.type == "xla" else device

    rank = idist.get_rank()
    ws = idist.get_world_size()
    torch.manual_seed(12)

    n_epochs = 10
    n_iters = 20

    # Epoch 1: random predictions; epoch 2: perfect (all ones); then random again,
    # so accuracy peaks at epoch 2 and never improves afterwards.
    y_preds = (
        [torch.randint(0, 2, size=(n_iters, ws)).to(device)]
        + [torch.ones(n_iters, ws).to(device)]
        + [torch.randint(0, 2, size=(n_iters, ws)).to(device) for _ in range(n_epochs - 2)]
    )
    y_true = (
        [torch.randint(0, 2, size=(n_iters, ws)).to(device)]
        + [torch.ones(n_iters, ws).to(device)]
        + [torch.randint(0, 2, size=(n_iters, ws)).to(device) for _ in range(n_epochs - 2)]
    )

    def update(engine, _):
        e = trainer.state.epoch - 1
        i = engine.state.iteration - 1
        return y_preds[e][i, rank], y_true[e][i, rank]

    evaluator = Engine(update)
    acc = Accuracy(device=metric_device)
    acc.attach(evaluator, "acc")

    def score_function(engine):
        return engine.state.metrics["acc"]

    trainer = Engine(lambda e, b: None)
    early_stopping = EarlyStopping(patience=3, score_function=score_function, trainer=trainer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluation(engine):
        evaluator.run(data=list(range(n_iters)))

    evaluator.add_event_handler(Events.COMPLETED, early_stopping)
    trainer.run([0], max_epochs=10)

    assert trainer.state.epoch == 5
def add_early_stopping_by_val_score(patience, evaluator, trainer, metric_name):
    """Method setups early stopping handler based on the score (named by `metric_name`)
    provided by `evaluator`.

    Args:
        patience (int): number of events to wait if no improvement and then stop the training.
        evaluator (Engine): evaluation engine used to provide the score
        trainer (Engine): trainer engine to stop the run if no improvement.
        metric_name (str): metric name to use for score evaluation. This metric should be present in
            `evaluator.state.metrics`.
    """
    handler = EarlyStopping(patience=patience,
                            score_function=get_default_score_fn(metric_name),
                            trainer=trainer)
    evaluator.add_event_handler(Events.COMPLETED, handler)
def test_simple_early_stopping_on_plateau():
    """A constant score means no improvement, so patience=1 stops on the second call."""

    def score_function(engine):
        return 42

    trainer = Engine(do_nothing_update_fn)
    stopper = EarlyStopping(patience=1, score_function=score_function, trainer=trainer)

    # Call 2 times and check if stopped
    assert not trainer.should_terminate
    stopper(None)  # first call establishes the best score
    assert not trainer.should_terminate
    stopper(None)  # no improvement -> patience exhausted
    assert trainer.should_terminate
def test_args_validation():
    """EarlyStopping argument validation raises with the documented messages."""

    def update_fn(engine, batch):
        pass

    trainer = Engine(update_fn)

    with pytest.raises(ValueError, match=r"Argument patience should be positive integer."):
        EarlyStopping(patience=-1, score_function=lambda engine: 0, trainer=trainer)

    with pytest.raises(TypeError, match=r"Argument score_function should be a function."):
        EarlyStopping(patience=2, score_function=12345, trainer=trainer)

    with pytest.raises(TypeError, match=r"Argument trainer should be an instance of Engine."):
        EarlyStopping(patience=2, score_function=lambda engine: 0, trainer=None)
def _early_stopping_handler(self):
    """Create the EarlyStopping handler that will evaluate the `score_function` class
    on each `evaluator_engine` run and stop the `trainer_engine` if there has been no
    improvement in the `_score_function` for the number of epochs specified in
    `early_stopping_patience`.

    Returns:
        the early stopping handler
    """
    return EarlyStopping(
        patience=self.early_stopping_patience,
        score_function=self._score_function,
        trainer=self.trainer_engine,
    )
def init_function(h_model):
    """Build trainer/evaluators plus logging, LR scheduling, checkpointing and
    early-stopping handlers for a single hyper-parameter run.

    Args:
        h_model: the model instance to train and evaluate.

    Returns:
        Tuple of (h_trainer, h_train_evaluator, h_evaluator).
    """
    h_criterion = torch.nn.CrossEntropyLoss()
    h_evaluator = SupervisedEvaluator(model=h_model, criterion=h_criterion, device=device)
    h_train_evaluator = SupervisedEvaluator(model=h_model, criterion=h_criterion, device=device)
    h_optimizer = torch.optim.Adam(params=h_model.parameters(), lr=1e-3)
    h_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(h_optimizer, 'max', verbose=True,
                                                                patience=5, factor=0.5)
    h_trainer = SupervisedTrainer(model=h_model, optimizer=h_optimizer, criterion=h_criterion,
                                  device=device)

    # Tqdm logger
    h_pbar = ProgressBar(persist=False, bar_format=config.IGNITE_BAR_FORMAT)
    h_pbar.attach(h_trainer.engine, metric_names='all')
    h_tqdm_logger = TqdmLogger(pbar=h_pbar)
    # noinspection PyTypeChecker
    h_tqdm_logger.attach_output_handler(
        h_evaluator.engine,
        event_name=Events.COMPLETED,
        tag="validation",
        global_step_transform=global_step_from_engine(h_trainer.engine),
    )
    # noinspection PyTypeChecker
    h_tqdm_logger.attach_output_handler(
        h_train_evaluator.engine,
        event_name=Events.COMPLETED,
        tag="train",
        global_step_transform=global_step_from_engine(h_trainer.engine),
    )

    # Learning rate scheduling
    # The PyTorch Ignite LRScheduler class does not work with ReduceLROnPlateau
    h_evaluator.engine.add_event_handler(
        Events.COMPLETED,
        lambda engine: h_lr_scheduler.step(engine.state.metrics['accuracy']))

    # Model checkpoints
    # BUG FIX: the original referenced the outer-scope `trainer` and `model`
    # instead of this run's `h_trainer` and `h_model`, so checkpoints were
    # stepped/saved against the wrong objects.
    h_handler = ModelCheckpoint(config.MODELS_DIR, run.replace('/', '-'), n_saved=1,
                                create_dir=True, require_empty=False, score_name='acc',
                                score_function=lambda engine: engine.state.metrics['accuracy'],
                                global_step_transform=global_step_from_engine(h_trainer.engine))
    h_evaluator.engine.add_event_handler(Events.EPOCH_COMPLETED, h_handler, {'m': h_model})

    # Early stopping
    h_es_handler = EarlyStopping(patience=15, min_delta=0.0001,
                                 score_function=lambda engine: engine.state.metrics['accuracy'],
                                 trainer=h_trainer.engine, cumulative_delta=True)
    h_es_handler.logger.setLevel(logging.DEBUG)
    h_evaluator.engine.add_event_handler(Events.COMPLETED, h_es_handler)

    return h_trainer, h_train_evaluator, h_evaluator
def _register_early_stopping(self, loss_fn, device, trainer):
    """Create a supervised evaluator with an NLL-driven early-stopping handler attached."""
    prepare_batch = __class__._prepare_batch
    evaluator = create_supervised_evaluator(model=self.model,
                                            metrics={'nll': Loss(loss_fn)},
                                            device=device,
                                            prepare_batch=prepare_batch)

    def score_fn(engine):
        # Negate the loss: EarlyStopping expects higher scores to be better.
        return -engine.state.metrics['nll']

    evaluator.add_event_handler(
        Events.COMPLETED,
        EarlyStopping(patience=5, score_function=score_fn, trainer=trainer))
    return evaluator
def test_simple_no_early_stopping():
    """The third score (1.2) beats the best so far, so termination is never triggered."""
    scores = iter([1.0, 0.8, 1.2])

    def score_function(engine):
        return next(scores)

    trainer = Engine(do_nothing_update_fn)
    stopper = EarlyStopping(patience=2, score_function=score_function, trainer=trainer)

    # Call 3 times and check if not stopped
    assert not trainer.should_terminate
    for _ in range(3):
        stopper(None)
    assert not trainer.should_terminate
def test_early_stopping_on_last_event_delta():
    """With cumulative_delta=False each step's +0.3 gain is below min_delta=0.4, so it stops."""
    scores = iter([0.0, 0.3, 0.6])
    trainer = Engine(do_nothing_update_fn)
    stopper = EarlyStopping(
        patience=2,
        min_delta=0.4,
        cumulative_delta=False,
        score_function=lambda _: next(scores),
        trainer=trainer,
    )

    assert not trainer.should_terminate
    stopper(None)  # counter == 0
    assert not trainer.should_terminate
    stopper(None)  # delta == 0.3; counter == 1
    assert not trainer.should_terminate
    stopper(None)  # delta == 0.3; counter == 2
    assert trainer.should_terminate
def _finetune(self, train_dl, val_dl, criterion, iter_num):
    """Retrain the pruned model for a configured number of epochs, with progress
    logging, NaN termination, checkpointing and optional early stopping."""
    print("Recovery")
    self.model.to_rank = False

    finetune_epochs = config["pruning"]["finetune_epochs"].get()
    optimizer = optimizer_constructor_from_config(config)(self.model.parameters())
    finetune_engine = create_supervised_trainer(self.model, optimizer, criterion, self.device)

    # progress bar
    pbar = Progbar(train_dl, metrics='none')
    finetune_engine.add_event_handler(Events.ITERATION_COMPLETED, pbar)

    # log training loss
    if self.writer:
        finetune_engine.add_event_handler(
            Events.ITERATION_COMPLETED,
            lambda engine: log_training_loss(engine, self.writer))

    # terminate on Nan
    finetune_engine.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    # model checkpoints
    checkpoint = ModelCheckpoint(config["pruning"]["out_path"].get(), require_empty=False,
                                 filename_prefix=f"pruning_iteration_{iter_num}", save_interval=1)
    finetune_engine.add_event_handler(Events.COMPLETED, checkpoint, {"weights": self.model.cpu()})

    # add early stopping
    validation_evaluator = create_supervised_evaluator(self.model, device=self.device,
                                                       metrics=self._metrics)
    if config["pruning"]["early_stopping"].get():
        def _score_function(evaluator):
            # Negate the loss so that "higher is better" for EarlyStopping.
            return -evaluator.state.metrics["loss"]

        early_stop = EarlyStopping(config["pruning"]["patience"].get(),
                                   _score_function, finetune_engine)
        validation_evaluator.add_event_handler(Events.EPOCH_COMPLETED, early_stop)

    finetune_engine.add_event_handler(
        Events.EPOCH_COMPLETED,
        lambda engine: run_evaluator(engine, validation_evaluator, val_dl))

    # user-supplied hooks registered for the fine-tuning phase
    for handler_dict in self._finetune_handlers:
        finetune_engine.add_event_handler(handler_dict["event_name"], handler_dict["handler"],
                                          *handler_dict["args"], **handler_dict["kwargs"])

    # run training engine
    finetune_engine.run(train_dl, max_epochs=finetune_epochs)
def add_early_stopping_by_val_score(patience: int, evaluator: Engine, trainer: Engine, metric_name: str):
    """Method setups early stopping handler based on the score (named by `metric_name`)
    provided by `evaluator`. Metric value should increase in order to keep training and
    not early stop.

    Args:
        patience (int): number of events to wait if no improvement and then stop the training.
        evaluator (Engine): evaluation engine used to provide the score
        trainer (Engine): trainer engine to stop the run if no improvement.
        metric_name (str): metric name to use for score evaluation. This metric should be present in
            `evaluator.state.metrics`.

    Returns:
        A :class:`~ignite.handlers.EarlyStopping` handler.
    """
    handler = EarlyStopping(patience=patience,
                            score_function=get_default_score_fn(metric_name),
                            trainer=trainer)
    evaluator.add_event_handler(Events.COMPLETED, handler)
    return handler
def assign_event_handlers(trainer, evaluator, val_set):
    """Attach a progress bar, early stopping, and per-epoch result logging to the engines."""
    pbar = ProgressBar()
    pbar.attach(trainer, ['loss'])

    # Negate the loss so that "higher is better" for EarlyStopping.
    stopper = EarlyStopping(patience=2,
                            score_function=lambda e: -e.state.metrics['loss'],
                            trainer=trainer)
    evaluator.add_event_handler(Events.COMPLETED, stopper)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        print("\nTraining Results - Epoch: {} : Avg loss: {:.3f}"
              .format(trainer.state.epoch, trainer.state.metrics['avg_loss']))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_set)
        metrics_eval = evaluator.state.metrics
        print("Validation Results - Epoch: {} Avg loss: {:.3f}, Avg abs. error: {:.2f}"
              .format(trainer.state.epoch, metrics_eval['loss'], metrics_eval['mae']))