import logging
from typing import Any, Dict, Optional, Tuple

logger = logging.getLogger(__name__)


def on_epoch(
    self,
    trainer: "GradientDescentTrainer",
    metrics: Dict[str, Any],
    epoch: int,
    is_primary: bool = True,
    **kwargs,
) -> None:
    # In distributed training, only the primary worker writes logs; the
    # replicas return immediately.
    if not is_primary:
        return

    assert self.trainer is not None
    # Split the combined metrics dict into its training and validation halves,
    # then hand both to the writer-specific epoch hook.
    train_metrics, val_metrics = get_train_and_validation_metrics(metrics)
    self.log_epoch(train_metrics, val_metrics, epoch)
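# `log_epoch` is the backend-specific hook the base method above dispatches
# to. Its assumed shape, inferred from the call site and the helper's return
# values (a sketch, not the actual declaration in the codebase):
#
#     def log_epoch(
#         self,
#         train_metrics: Dict[str, Any],
#         val_metrics: Optional[Dict[str, Any]],
#         epoch: int,
#     ) -> None: ...
#
# The variant below skips that indirection and renders the same split
# directly as a fixed-width table via the console logger.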
def on_epoch(
    self,
    trainer: "GradientDescentTrainer",
    metrics: Dict[str, Any],
    epoch: int,
    is_primary: bool = True,
    **kwargs,
) -> None:
    # In distributed training, only the primary worker writes logs.
    if not is_primary:
        return

    train_metrics, val_metrics = get_train_and_validation_metrics(metrics)

    # One table row per metric name, whether it appears on the training side,
    # the validation side, or both.
    metric_names = set(train_metrics.keys())
    if val_metrics is not None:
        metric_names.update(val_metrics.keys())
    val_metrics = val_metrics or {}

    # Row templates: both values present, training-only, validation-only.
    dual_message_template = "%s |  %8.3f  |  %8.3f"
    no_val_message_template = "%s |  %8.3f  |  %8s"
    no_train_message_template = "%s |  %8s  |  %8.3f"
    header_template = "%s |  %-10s"

    # Right-pad the names to the longest one so the columns line up.
    # `default=0` keeps max() from raising if the metrics dict is empty.
    name_length = max((len(x) for x in metric_names), default=0)
    logger.info(header_template, "Training".rjust(name_length + 13), "Validation")

    for name in sorted(metric_names):
        train_metric = train_metrics.get(name)
        val_metric = val_metrics.get(name)
        if val_metric is not None and train_metric is not None:
            logger.info(dual_message_template, name.ljust(name_length), train_metric, val_metric)
        elif val_metric is not None:
            logger.info(no_train_message_template, name.ljust(name_length), "N/A", val_metric)
        elif train_metric is not None:
            logger.info(no_val_message_template, name.ljust(name_length), train_metric, "N/A")
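# Both `on_epoch` variants above call `get_train_and_validation_metrics`,
# which is not shown in this excerpt. A minimal sketch of what it presumably
# does, assuming the trainer reports metrics under "training_"/"validation_"
# key prefixes (the actual helper in the codebase may differ):
def get_train_and_validation_metrics(
    metrics: Dict[str, Any]
) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]:
    """Split a combined metrics dict into (train_metrics, val_metrics)."""
    train_metrics = {
        key[len("training_"):]: value
        for key, value in metrics.items()
        if key.startswith("training_")
    }
    val_metrics = {
        key[len("validation_"):]: value
        for key, value in metrics.items()
        if key.startswith("validation_")
    }
    # Return None rather than an empty dict when there was no validation pass,
    # matching the `if val_metrics is not None` checks above.
    return train_metrics, val_metrics or None


# Under that assumption, metrics such as
# {"training_loss": 0.532, "training_accuracy": 0.81, "validation_loss": 0.641}
# would be logged by the console variant roughly as:
#
#              Training |  Validation
# accuracy |     0.810  |       N/A
# loss     |     0.532  |     0.641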