Beispiel #1
0
def run_eval_epoch(val_fn: EvalStepFn, val_dataset: tf.data.Dataset,
                   val_summary_writer: tf.summary.SummaryWriter,
                   test_fn: EvalStepFn, test_dataset: tf.data.Dataset,
                   test_summary_writer: tf.summary.SummaryWriter,
                   current_step: int, hparams: Optional[Dict[str, Any]]):
    """Run one evaluation epoch on the validation and test splits.

    The validation split is evaluated only when `val_dataset` is truthy. The
    test split is evaluated only when both `test_summary_writer` and
    `test_dataset` are truthy. For each evaluated split, the step function is
    called on a fresh iterator over its dataset. Every metric it returns is
    written as a scalar summary at `current_step` on that split's writer,
    preceded by `hp.hparams(hparams)` when `hparams` is non-empty.

    Args:
      val_fn: Evaluation step function for the validation split; called with
        an iterator and expected to return a mapping of metric name to a value
        exposing `.numpy()`.
      val_dataset: Validation dataset, or a falsy value to skip validation.
      val_summary_writer: Writer for validation summaries (used only when
        validation runs).
      test_fn: Evaluation step function for the test split (same contract as
        `val_fn`).
      test_dataset: Test dataset, or a falsy value to skip the test split.
      test_summary_writer: Writer for test summaries, or a falsy value to skip
        the test split.
      current_step: Step at which the summaries are recorded.
      hparams: Optional hyperparameter dict logged via `hp.hparams` to each
        writer that is used.

    Returns:
      A tuple `(val_outputs_np, test_outputs_np)`. `val_outputs_np` maps
      metric names to their `.numpy()` values, or is None if validation was
      skipped. `test_outputs_np` is the same kind of dict, empty if the test
      split was skipped.
    """
    val_outputs_np = None
    if val_dataset:
        val_outputs = val_fn(iter(val_dataset))
        with val_summary_writer.as_default():
            if hparams:
                hp.hparams(hparams)
            for name, metric in val_outputs.items():
                tf.summary.scalar(name, metric, step=current_step)
        val_outputs_np = {k: v.numpy() for k, v in val_outputs.items()}
        logging.info('Validation metrics for step %d: %s', current_step,
                     val_outputs_np)

    test_outputs_np = {}
    # Require the dataset as well as the writer. Guarding on the writer alone
    # let a missing test dataset reach `iter(None)` and raise TypeError.
    if test_summary_writer and test_dataset:
        test_outputs = test_fn(iter(test_dataset))
        with test_summary_writer.as_default():
            if hparams:
                hp.hparams(hparams)
            for name, metric in test_outputs.items():
                tf.summary.scalar(name, metric, step=current_step)
        test_outputs_np = {k: v.numpy() for k, v in test_outputs.items()}
    return val_outputs_np, test_outputs_np
Beispiel #2
0
def _write_summaries(train_step_outputs: Dict[str, Any],
                     current_step: int,
                     train_summary_writer: tf.summary.SummaryWriter,
                     hparams: Optional[Dict[str, Any]] = None) -> None:
    """Write each training-step metric as a scalar summary.

    Inside `train_summary_writer`'s default context, optionally records
    `hparams` via `hp.hparams`. It then emits one `tf.summary.scalar` per
    entry of `train_step_outputs`, all at `current_step`.

    Args:
      train_step_outputs: Mapping of metric name to scalar value.
      current_step: Step at which the scalars are recorded.
      train_summary_writer: Writer that receives the summaries.
      hparams: Optional hyperparameter dict; skipped when None or empty.
    """
    writer_scope = train_summary_writer.as_default()
    with writer_scope:
        if hparams:
            hp.hparams(hparams)
        for metric_name, metric_value in train_step_outputs.items():
            tf.summary.scalar(metric_name, metric_value, step=current_step)
Beispiel #3
0
def run_eval_epoch(
    current_step: int,
    test_fn: EvalStepFn,
    test_dataset: tf.data.Dataset,
    test_summary_writer: tf.summary.SummaryWriter,
    val_fn: Optional[EvalStepFn] = None,
    val_dataset: Optional[tf.data.Dataset] = None,
    val_summary_writer: Optional[tf.summary.SummaryWriter] = None,
    ood_fn: Optional[EvalStepFn] = None,
    ood_dataset: Optional[tf.data.Dataset] = None,
    ood_summary_writer: Optional[tf.summary.SummaryWriter] = None,
    hparams: Optional[Dict[str, Any]] = None,
):
    """Run one evaluation epoch on the test and optionally val/OOD splits.

    The test split is always evaluated. The validation and OOD splits are
    evaluated only when `val_dataset` / `ood_dataset` are truthy, in which
    case the matching `*_fn` and `*_summary_writer` must also be provided. For
    each evaluated split, the step function is called on a fresh iterator
    over its dataset. Every returned metric is written as a scalar summary at
    `current_step`, preceded by `hp.hparams(hparams)` when `hparams` is
    non-empty.

    Args:
      current_step: Step at which the summaries are recorded.
      test_fn: Evaluation step function for the test split; called with an
        iterator and expected to return a mapping of metric name to a value
        exposing `.numpy()`.
      test_dataset: Test dataset.
      test_summary_writer: Writer for test summaries.
      val_fn: Step function for the validation split (same contract).
      val_dataset: Validation dataset, or None to skip validation.
      val_summary_writer: Writer for validation summaries.
      ood_fn: Step function for the OOD split (same contract).
      ood_dataset: OOD dataset, or None to skip the OOD split.
      ood_summary_writer: Writer for OOD summaries.
      hparams: Optional hyperparameter dict logged to every writer used.

    Returns:
      A tuple `(val_outputs_np, ood_outputs_np, test_outputs_np)`. Each entry
      maps metric names to their `.numpy()` values. The val and OOD entries
      are None when the corresponding split was skipped.
    """

    def evaluate_split(eval_fn, dataset, summary_writer):
        # Evaluate one pass over `dataset`, write its metrics to
        # `summary_writer`, and return them converted via `.numpy()`.
        outputs = eval_fn(iter(dataset))
        with summary_writer.as_default():
            if hparams:
                hp.hparams(hparams)
            for name, metric in outputs.items():
                tf.summary.scalar(name, metric, step=current_step)
        return {k: v.numpy() for k, v in outputs.items()}

    val_outputs_np = None
    if val_dataset:
        val_outputs_np = evaluate_split(val_fn, val_dataset,
                                        val_summary_writer)
        logging.info('Validation metrics for step %d: %s', current_step,
                     val_outputs_np)

    # Bind unconditionally. Previously this name was assigned only inside the
    # `if ood_dataset:` branch, so the final `return` raised
    # UnboundLocalError whenever no OOD dataset was supplied.
    ood_outputs_np = None
    if ood_dataset:
        ood_outputs_np = evaluate_split(ood_fn, ood_dataset,
                                        ood_summary_writer)
        logging.info('OOD metrics for step %d: %s', current_step,
                     ood_outputs_np)

    test_outputs_np = evaluate_split(test_fn, test_dataset,
                                     test_summary_writer)
    return val_outputs_np, ood_outputs_np, test_outputs_np
def _write_summaries(train_step_outputs: Dict[str, Any], current_step: int,
                     train_summary_writer: tf.summary.SummaryWriter):
    """Write each training-step metric as a scalar summary.

    Emits one `tf.summary.scalar` per entry of `train_step_outputs`, all at
    `current_step`, inside `train_summary_writer`'s default context.

    Args:
      train_step_outputs: Mapping of metric name to scalar value.
      current_step: Step at which the scalars are recorded.
      train_summary_writer: Writer that receives the summaries.
    """
    writer_scope = train_summary_writer.as_default()
    with writer_scope:
        for metric_name, metric_value in train_step_outputs.items():
            tf.summary.scalar(metric_name, metric_value, step=current_step)