from typing import Any, Callable, Dict, Iterator, Optional

from absl import logging
import tensorflow as tf
from tensorboard.plugins.hparams import api as hp

# Assumed signature for the eval-step callables used below: a function that
# consumes a dataset iterator and returns a dict of scalar metric tensors.
EvalStepFn = Callable[[Iterator[Any]], Dict[str, tf.Tensor]]


def run_eval_epoch(val_fn: EvalStepFn,
                   val_dataset: tf.data.Dataset,
                   val_summary_writer: tf.summary.SummaryWriter,
                   test_fn: EvalStepFn,
                   test_dataset: tf.data.Dataset,
                   test_summary_writer: tf.summary.SummaryWriter,
                   current_step: int,
                   hparams: Optional[Dict[str, Any]]):
  """Run one evaluation epoch on the test and optionally validation splits."""
  val_outputs_np = None
  if val_dataset:
    val_iterator = iter(val_dataset)
    val_outputs = val_fn(val_iterator)
    with val_summary_writer.as_default():
      if hparams:
        hp.hparams(hparams)
      for name, metric in val_outputs.items():
        tf.summary.scalar(name, metric, step=current_step)
    val_outputs_np = {k: v.numpy() for k, v in val_outputs.items()}
    logging.info('Validation metrics for step %d: %s', current_step,
                 val_outputs_np)
  test_outputs = {}
  if test_summary_writer:
    test_iterator = iter(test_dataset)
    test_outputs = test_fn(test_iterator)
    with test_summary_writer.as_default():
      if hparams:
        hp.hparams(hparams)
      for name, metric in test_outputs.items():
        tf.summary.scalar(name, metric, step=current_step)
  return val_outputs_np, {k: v.numpy() for k, v in test_outputs.items()}
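
# A minimal sketch of an eval step compatible with `run_eval_epoch`. The
# model, metric name, and fixed batch count are illustrative placeholders,
# not part of this module.
def make_eval_step(model: tf.keras.Model, num_batches: int) -> EvalStepFn:
  accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

  def eval_step(iterator: Iterator[Any]) -> Dict[str, tf.Tensor]:
    accuracy.reset_state()
    for _ in range(num_batches):
      features, labels = next(iterator)
      accuracy.update_state(labels, model(features, training=False))
    return {'accuracy': accuracy.result()}

  return eval_step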
def _write_summaries(train_step_outputs: Dict[str, Any],
                     current_step: int,
                     train_summary_writer: tf.summary.SummaryWriter,
                     hparams: Optional[Dict[str, Any]] = None) -> None:
  """Log metrics using tf.summary."""
  with train_summary_writer.as_default():
    if hparams:
      hp.hparams(hparams)
    for name, result in train_step_outputs.items():
      tf.summary.scalar(name, result, step=current_step)
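
# Hypothetical one-off usage of the hparams-aware `_write_summaries` above;
# the log directory, metric value, and hparams dict are placeholders. The
# default argument binds the definition above at function-definition time.
def _example_write_summaries_with_hparams(write_fn=_write_summaries) -> None:
  writer = tf.summary.create_file_writer('/tmp/logs/train')
  write_fn(
      train_step_outputs={'train/loss': tf.constant(0.42)},
      current_step=0,
      train_summary_writer=writer,
      hparams={'learning_rate': 1e-3})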
def run_eval_epoch(
    current_step: int,
    test_fn: EvalStepFn,
    test_dataset: tf.data.Dataset,
    test_summary_writer: tf.summary.SummaryWriter,
    val_fn: Optional[EvalStepFn] = None,
    val_dataset: Optional[tf.data.Dataset] = None,
    val_summary_writer: Optional[tf.summary.SummaryWriter] = None,
    ood_fn: Optional[EvalStepFn] = None,
    ood_dataset: Optional[tf.data.Dataset] = None,
    ood_summary_writer: Optional[tf.summary.SummaryWriter] = None,
    hparams: Optional[Dict[str, Any]] = None,
):
  """Run one evaluation epoch on the test and optional val/OOD splits."""
  val_outputs_np = None
  if val_dataset:
    val_iterator = iter(val_dataset)
    val_outputs = val_fn(val_iterator)
    with val_summary_writer.as_default():  # pytype: disable=attribute-error
      if hparams:
        hp.hparams(hparams)
      for name, metric in val_outputs.items():
        tf.summary.scalar(name, metric, step=current_step)
    val_outputs_np = {k: v.numpy() for k, v in val_outputs.items()}
    logging.info('Validation metrics for step %d: %s', current_step,
                 val_outputs_np)
  # Initialize so the return below is well-defined when no OOD split is given.
  ood_outputs_np = None
  if ood_dataset:
    ood_iterator = iter(ood_dataset)
    ood_outputs = ood_fn(ood_iterator)
    with ood_summary_writer.as_default():  # pytype: disable=attribute-error
      if hparams:
        hp.hparams(hparams)
      for name, metric in ood_outputs.items():
        tf.summary.scalar(name, metric, step=current_step)
    ood_outputs_np = {k: v.numpy() for k, v in ood_outputs.items()}
    logging.info('OOD metrics for step %d: %s', current_step, ood_outputs_np)
  test_iterator = iter(test_dataset)
  test_outputs = test_fn(test_iterator)
  with test_summary_writer.as_default():
    if hparams:
      hp.hparams(hparams)
    for name, metric in test_outputs.items():
      tf.summary.scalar(name, metric, step=current_step)
  test_outputs_np = {k: v.numpy() for k, v in test_outputs.items()}
  return val_outputs_np, ood_outputs_np, test_outputs_np
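
# Hypothetical wiring of `run_eval_epoch` above; the log directories are
# placeholders, and the step functions and datasets are assumed to come from
# elsewhere in the training script. With the OOD arguments omitted, the
# returned OOD metrics are None.
def _example_run_eval_epoch(test_fn: EvalStepFn,
                            test_dataset: tf.data.Dataset,
                            val_fn: EvalStepFn,
                            val_dataset: tf.data.Dataset,
                            current_step: int):
  test_writer = tf.summary.create_file_writer('/tmp/logs/test')
  val_writer = tf.summary.create_file_writer('/tmp/logs/val')
  val_metrics, ood_metrics, test_metrics = run_eval_epoch(
      current_step=current_step,
      test_fn=test_fn,
      test_dataset=test_dataset,
      test_summary_writer=test_writer,
      val_fn=val_fn,
      val_dataset=val_dataset,
      val_summary_writer=val_writer)
  return val_metrics, ood_metrics, test_metrics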
def _write_summaries(train_step_outputs: Dict[str, Any],
                     current_step: int,
                     train_summary_writer: tf.summary.SummaryWriter):
  """Log metrics using tf.summary."""
  with train_summary_writer.as_default():
    for name, result in train_step_outputs.items():
      tf.summary.scalar(name, result, step=current_step)
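
# Hypothetical periodic-logging loop around `_write_summaries`; the step
# function, logging interval, and log directory are placeholders.
def _example_train_loop(train_step_fn: Callable[[], Dict[str, tf.Tensor]],
                        num_steps: int,
                        log_interval: int = 100) -> None:
  writer = tf.summary.create_file_writer('/tmp/logs/train')
  for step in range(num_steps):
    outputs = train_step_fn()
    if step % log_interval == 0:
      _write_summaries(outputs, step, writer)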