Beispiel #1
0
def predict_once(run_configuration, optimizer=None, max_steps=200):
    """Run the predict step over a prefix of the dataset and sum the metrics.

    The optimizer is restored from
    `run_configuration.original_checkpoint_path`, replicated across devices,
    and its `target` is fed to the pmapped predict step for each preprocessed
    item of the dataset. Every step's outputs are also handed to
    `adapter.handle_predict`.

    Args:
      run_configuration: Provides `adapter`, `original_checkpoint_path` and
        `dataset_info.dataset`.
      optimizer: Optional optimizer to restore the checkpoint into. When
        omitted (or falsy), one is created with `adapter.create_optimizer`.
      max_steps: Upper bound on the number of items consumed from the
        preprocessed dataset iterator. Defaults to 200, the limit that was
        previously hard-coded. Pass None to consume the iterator until it is
        exhausted.

    Returns:
      The per-step metrics gathered with `common_utils.get_metrics`, with
      `jnp.sum` applied to every leaf.
    """
    adapter = run_configuration.adapter
    checkpoint_path = run_configuration.original_checkpoint_path
    optimizer = optimizer or adapter.create_optimizer(run_configuration)
    dataset = run_configuration.dataset_info.dataset

    # Restore checkpoint
    optimizer = checkpoint_utils.restore_checkpoint(checkpoint_path, optimizer)

    # Replicate optimizer.
    optimizer = flax.jax_utils.replicate(optimizer)
    predict_step = adapter.make_predict_step()
    predict_step_parallel = jax.pmap(predict_step, axis_name='batch')

    # Perform inference
    dataset_iter_raw = iter(dataset)
    dataset_iter = adapter.preprocess(dataset_iter_raw)
    metrics_all = []
    # islice(it, None) places no cap, so max_steps=None walks the whole
    # iterator.
    for example in itertools.islice(dataset_iter, max_steps):
        train_inputs = adapter.get_train_inputs(example)
        metrics, logits, state = predict_step_parallel(optimizer.target,
                                                       train_inputs)
        adapter.handle_predict(metrics, logits, state)
        metrics_all.append(metrics)
    metrics_all = common_utils.get_metrics(metrics_all)
    metrics = jax.tree_map(jnp.sum, metrics_all)
    return metrics
Beispiel #2
0
def eval_once(run_configuration, checkpoint_path, optimizer=None):
    """Evaluate one checkpoint on up to `config.eval_steps` batches.

    The checkpoint is restored into the optimizer, the optimizer is replicated
    across devices, and the pmapped eval step is run on successive
    preprocessed items of the dataset. The summed metrics are divided by the
    summed 'denominator' metric, a clipped 'perplexity' is derived from the
    resulting 'loss', and host 0 writes the results as TensorBoard scalars
    under `<run_dir>/<eval_name>`.

    Args:
      run_configuration: Provides `config`, `run_dir`, `adapter` and
        `dataset_info` (`dataset` and `info`).
      checkpoint_path: Path passed to `checkpoint_utils.restore_checkpoint`.
      optimizer: Optional optimizer to restore the checkpoint into. When
        omitted (or falsy), one is created with `adapter.create_optimizer`.

    Returns:
      None. If no eval step could be run, a warning is logged and nothing is
      written.
    """
    config = run_configuration.config
    run_dir = run_configuration.run_dir
    adapter = run_configuration.adapter
    optimizer = optimizer or adapter.create_optimizer(run_configuration)
    dataset = run_configuration.dataset_info.dataset
    info = run_configuration.dataset_info.info

    eval_name = config.eval_name or 'eval'
    log_dir = os.path.join(run_dir, eval_name)

    # Only host 0 writes summaries; every later use of summary_writer is
    # behind the same host check.
    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(log_dir)

    # Restore checkpoint
    optimizer = checkpoint_utils.restore_checkpoint(checkpoint_path, optimizer)
    step = int(optimizer.state.step)

    # Replicate optimizer.
    optimizer = flax.jax_utils.replicate(optimizer)
    eval_step = adapter.make_eval_step()
    eval_step_parallel = jax.pmap(eval_step, axis_name='batch')

    # Perform evaluation
    tick = time.time()
    metrics_all = []

    # Pre-bind the values of the last step so they exist even if the loop
    # body never runs.
    example = None
    logits = None
    state = None
    dataset_iter_raw = iter(dataset)
    dataset_iter = adapter.preprocess(dataset_iter_raw)
    for _, example in zip(range(config.eval_steps), dataset_iter):
        train_inputs = adapter.get_train_inputs(example)
        metrics, logits, state = eval_step_parallel(optimizer.target,
                                                    train_inputs)
        metrics_all.append(metrics)

    # Record the step count now: metrics_all is rebound below to the stacked
    # metrics returned by get_metrics, whose len() is not the number of steps.
    num_steps = len(metrics_all)
    if num_steps == 0:
        logging.warning(
            'eval @ train step: %d ran no eval steps; nothing to report',
            step)
        return

    # Write results.
    metrics_all = common_utils.get_metrics(metrics_all)
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(lambda x: x / denominator, metrics_sums)
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
    logging.info('eval @ train step: %d, loss: %.4f', step, summary['loss'])
    if jax.host_id() == 0:
        tock = time.time()
        elapsed = tock - tick
        steps_per_sec = num_steps / elapsed
        examples_per_sec = denominator / elapsed
        summary_writer.scalar('per-second/steps', steps_per_sec, step)
        summary_writer.scalar('per-second/examples', examples_per_sec, step)
        for key, val in summary.items():
            summary_writer.scalar(key, val, step)

        # The adapter receives the example, logits and state of the last
        # eval step that ran.
        adapter.write_summaries(example, logits, summary_writer, info, step,
                                state)
        summary_writer.flush()