def predict_once(run_configuration, optimizer=None, num_examples=200):
  """Predicts the result once for each element in the dataset.

  Args:
    run_configuration: Run configuration object providing the adapter, the
      dataset info, and the original checkpoint path (project type — structure
      assumed from usage here).
    optimizer: Optional optimizer to restore into; if None, one is created via
      `adapter.create_optimizer(run_configuration)`.
    num_examples: Maximum number of dataset examples to run prediction on.
      Defaults to 200, matching the previously hard-coded limit.

  Returns:
    A pytree of metrics summed across all predicted examples.
  """
  adapter = run_configuration.adapter
  checkpoint_path = run_configuration.original_checkpoint_path
  optimizer = optimizer or adapter.create_optimizer(run_configuration)
  dataset = run_configuration.dataset_info.dataset

  # Restore checkpoint, then replicate across local devices for pmap.
  optimizer = checkpoint_utils.restore_checkpoint(checkpoint_path, optimizer)
  optimizer = flax.jax_utils.replicate(optimizer)

  predict_step = adapter.make_predict_step()
  predict_step_parallel = jax.pmap(predict_step, axis_name='batch')

  # Perform inference on at most `num_examples` preprocessed examples.
  dataset_iter_raw = iter(dataset)
  dataset_iter = adapter.preprocess(dataset_iter_raw)

  metrics_all = []
  for example in itertools.islice(dataset_iter, num_examples):
    train_inputs = adapter.get_train_inputs(example)
    metrics, logits, state = predict_step_parallel(optimizer.target,
                                                   train_inputs)
    adapter.handle_predict(metrics, logits, state)
    metrics_all.append(metrics)

  # Stack per-step metrics and sum each metric over all steps.
  metrics_all = common_utils.get_metrics(metrics_all)
  metrics = jax.tree_map(jnp.sum, metrics_all)
  return metrics
def eval_once(run_configuration, checkpoint_path, optimizer=None):
  """Evaluates a single checkpoint on a single epoch of data.

  Args:
    run_configuration: Run configuration object providing the config, run
      directory, adapter, and dataset info (project type — structure assumed
      from usage here).
    checkpoint_path: Path of the checkpoint to restore and evaluate.
    optimizer: Optional optimizer to restore into; if None, one is created via
      `adapter.create_optimizer(run_configuration)`.
  """
  config = run_configuration.config
  run_dir = run_configuration.run_dir
  adapter = run_configuration.adapter
  optimizer = optimizer or adapter.create_optimizer(run_configuration)
  dataset = run_configuration.dataset_info.dataset
  info = run_configuration.dataset_info.info

  eval_name = config.eval_name or 'eval'
  log_dir = os.path.join(run_dir, eval_name)
  # Only the first host writes TensorBoard summaries.
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(log_dir)

  # Restore checkpoint and record its training step before replication.
  optimizer = checkpoint_utils.restore_checkpoint(checkpoint_path, optimizer)
  step = int(optimizer.state.step)

  # Replicate optimizer across local devices for pmap.
  optimizer = flax.jax_utils.replicate(optimizer)

  eval_step = adapter.make_eval_step()
  eval_step_parallel = jax.pmap(eval_step, axis_name='batch')

  # Perform evaluation.
  tick = time.time()
  metrics_all = []
  # Initialize loop outputs so the post-loop code is well-defined even when
  # the loop runs zero iterations (empty dataset or config.eval_steps == 0).
  # Previously only `example` was initialized; `logits`/`state` raised
  # NameError in adapter.write_summaries in that case.
  example = None
  logits = None
  state = None
  dataset_iter_raw = iter(dataset)
  dataset_iter = adapter.preprocess(dataset_iter_raw)
  # Count steps here: `metrics_all` is rebound below to a dict of stacked
  # metric arrays, so `len(metrics_all)` after that point would count metric
  # keys, not eval steps (this previously skewed the steps/sec summary).
  num_eval_steps = 0
  for unused_eval_step, example in zip(range(config.eval_steps), dataset_iter):
    train_inputs = adapter.get_train_inputs(example)
    metrics, logits, state = eval_step_parallel(optimizer.target, train_inputs)
    metrics_all.append(metrics)
    num_eval_steps += 1

  # Write results.
  metrics_all = common_utils.get_metrics(metrics_all)
  metrics_sums = jax.tree_map(jnp.sum, metrics_all)
  denominator = metrics_sums.pop('denominator')
  summary = jax.tree_map(
      lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
      metrics_sums)
  # Clip perplexity to avoid inf-like values in summaries.
  summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
  logging.info('eval @ train step: %d, loss: %.4f', step, summary['loss'])

  if jax.host_id() == 0:
    tock = time.time()
    steps_per_sec = num_eval_steps / (tock - tick)
    examples_per_sec = denominator / (tock - tick)
    summary_writer.scalar('per-second/steps', steps_per_sec, step)
    summary_writer.scalar('per-second/examples', examples_per_sec, step)
    for key, val in summary.items():
      summary_writer.scalar(key, val, step)
    adapter.write_summaries(example, logits, summary_writer, info, step, state)
    summary_writer.flush()