def test_synchronize_multiple_hosts(self, process_index_mock):
    multihost_base_dir = os.path.join(tempfile.mkdtemp(), "test")
    state = TrainState(step=1)
    process_index_mock.return_value = 0
    ckpt_0 = checkpoint.MultihostCheckpoint(multihost_base_dir)
    process_index_mock.return_value = 1
    ckpt_1 = checkpoint.MultihostCheckpoint(multihost_base_dir)
    # Initialize both at step=1.
    state_0 = ckpt_0.restore_or_initialize(state)
    state_1 = ckpt_1.restore_or_initialize(state)
    # Update both at step=2.
    state_0 = state_0.replace(step=2)
    ckpt_0.save(state_0)
    state_1 = state_1.replace(step=2)
    ckpt_1.save(state_1)
    # Update ckpt_1 at step=3.
    state_1 = state_1.replace(step=3)
    ckpt_1.save(state_1)
    # Reload both at step=2.
    process_index_mock.return_value = 0
    ckpt_0 = checkpoint.MultihostCheckpoint(multihost_base_dir)
    process_index_mock.return_value = 1
    ckpt_1 = checkpoint.MultihostCheckpoint(multihost_base_dir)
    self.assertEqual(ckpt_0.latest_checkpoint,
                     ckpt_0.get_latest_checkpoint_to_restore_from())
    self.assertNotEqual(ckpt_1.latest_checkpoint,
                        ckpt_1.get_latest_checkpoint_to_restore_from())
    state_0 = ckpt_0.restore_or_initialize(state)
    state_1 = ckpt_1.restore_or_initialize(state)
    self.assertEqual(state_0.step, 2)
    self.assertEqual(state_1.step, 2)

def test_initialize_mkdir(self, process_index_mock):
    multihost_base_dir = os.path.join(tempfile.mkdtemp(), "test")
    state = TrainState(step=1)
    process_index_mock.return_value = 0
    base_dir = f"{multihost_base_dir}-0"
    ckpt = checkpoint.MultihostCheckpoint(multihost_base_dir)
    self.assertIsNone(ckpt.latest_checkpoint)
    self.assertFalse(os.path.isdir(base_dir))
    state = ckpt.restore_or_initialize(state)
    self.assertIsNotNone(ckpt.latest_checkpoint)
    self.assertTrue(os.path.isdir(base_dir))

def test_preemption(self):
    multihost_base_dir = os.path.join(tempfile.mkdtemp(), "test")
    state = TrainState(step=1)
    state0 = state.replace(step=0)
    ckpt_0 = checkpoint.MultihostCheckpoint(multihost_base_dir, host_id=0)
    ckpt_1 = checkpoint.MultihostCheckpoint(multihost_base_dir, host_id=1)
    # Initialize both at step=1.
    state_0 = ckpt_0.restore_or_initialize(state)
    state_1 = ckpt_1.restore_or_initialize(state)
    self.assertEqual(state_0.step, 1)
    self.assertEqual(state_1.step, 1)
    # Restore both at step=1.
    state_0 = ckpt_0.restore_or_initialize(state0)
    state_1 = ckpt_1.restore_or_initialize(state0)
    self.assertEqual(state_0.step, 1)
    self.assertEqual(state_1.step, 1)
    # Update only ckpt_0 to step=2.
    state_0 = state_0.replace(step=2)
    ckpt_0.save(state_0)
    # Load both checkpoints at last common step=1.
    ckpt_0 = checkpoint.MultihostCheckpoint(multihost_base_dir, host_id=0)
    ckpt_1 = checkpoint.MultihostCheckpoint(multihost_base_dir, host_id=1)
    state_0 = ckpt_0.restore_or_initialize(state)
    state_1 = ckpt_1.restore_or_initialize(state)
    self.assertEqual(state_0.step, 1)
    self.assertEqual(state_1.step, 1)
    # Store both at step=2.
    state_0 = state_0.replace(step=2)
    state_1 = state_1.replace(step=2)
    ckpt_0.save(state_0)
    ckpt_1.save(state_1)
    # Restore both at step=2.
    state_0 = ckpt_0.restore_or_initialize(state0)
    state_1 = ckpt_1.restore_or_initialize(state0)
    self.assertEqual(state_0.step, 2)
    self.assertEqual(state_1.step, 2)
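A minimal sketch of the pattern these tests exercise, assuming the CLU `checkpoint` module and a flax.struct dataclass for the train state (both named here for illustration only): each host writes to its own `<base_dir>-<host_id>` directory, and after a restart `restore_or_initialize()` falls back to the newest step present on every host.

import os
import tempfile

import flax
from clu import checkpoint  # Assumed import path for MultihostCheckpoint.


@flax.struct.dataclass
class TrainState:  # Hypothetical stand-in for the TrainState used above.
    step: int


base_dir = os.path.join(tempfile.mkdtemp(), "test")
ckpt = checkpoint.MultihostCheckpoint(base_dir, host_id=0)
state = ckpt.restore_or_initialize(TrainState(step=1))  # Initializes on the first run.
state = state.replace(step=state.step + 1)
ckpt.save(state)
# After a preemption, rebuild the checkpoint object and restore; only the last
# step saved by *all* hosts is restored, as test_preemption() demonstrates.
state = checkpoint.MultihostCheckpoint(base_dir, host_id=0).restore_or_initialize(state)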
Example #4
    def train_and_evaluate(self, workdir):
        """Runs a training and evaluation loop.

    Args:
      workdir: Working directory for checkpoints and TF summaries. If this
        contains checkpoint training will be resumed from the latest checkpoint.
    """

        tf.io.gfile.makedirs(workdir)
        config = self.config
        substeps = config.training.substeps

        # Learning rate schedule.
        num_train_steps = config.training.num_train_steps
        logging.info('num_train_steps=%d', num_train_steps)

        # Get train state
        state = self._train_state

        # Set up checkpointing of the model and the input pipeline.
        checkpoint_dir = os.path.join(workdir, 'checkpoints')
        ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=5)
        state = ckpt.restore_or_initialize(state)
        initial_step = int(state.step)

        # Distribute training.
        state = flax_utils.replicate(state)

        writer = metric_writers.create_default_writer(
            workdir, just_logging=jax.process_index() > 0)
        if initial_step == 0:
            writer.write_hparams(dict(config))

        logging.info('Starting training loop at step %d.', initial_step)
        hooks = []
        report_progress = periodic_actions.ReportProgress(
            num_train_steps=num_train_steps, writer=writer)
        if jax.process_index() == 0:
            hooks += [
                report_progress,
                periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
            ]
        step = initial_step
        with metric_writers.ensure_flushes(writer):
            while step < num_train_steps:
                # `step` is a Python integer; `state.step` is a JAX integer on the
                # GPU/TPU devices.
                is_last_step = step + substeps >= num_train_steps

                with jax.profiler.StepTraceAnnotation('train', step_num=step):
                    inputs = jax.tree_map(np.asarray, next(self._train_iter))
                    state, outputs = self._update_func(state, inputs)

                # Quick indication that training is happening.
                logging.log_first_n(logging.INFO, 'Finished training step %d.',
                                    5, step)
                for h in hooks:
                    h(step)

                new_step = int(state.step[0])
                assert new_step == step + substeps
                step = new_step

                is_eval = step % config.logs.eval_full_every_steps == 0 or is_last_step
                if step % config.logs.log_loss_every_steps == 0 and not is_eval:

                    def avg_over_substeps(x):
                        assert x.shape[0] == substeps
                        return float(x.mean(axis=0))

                    # Extract scalars and images.
                    outputs = flax_utils.unreplicate(outputs)
                    outputs = jax.tree_map(avg_over_substeps, outputs)
                    scalars = outputs['scalars']
                    writer.write_scalars(step, scalars)

                if is_eval:
                    with report_progress.timed('eval_full'):
                        outputs = self._eval_epoch(params=state.ema_params)
                        outputs = flax_utils.unreplicate(outputs)
                        scalars = outputs['scalars']
                        writer.write_scalars(step, scalars)

                if step % config.logs.checkpoint_every_steps == 0 or is_last_step:
                    with report_progress.timed('checkpoint'):
                        ckpt.save(flax_utils.unreplicate(state))

        logging.info('Finishing training at step %d', num_train_steps)
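Stripped of the model specifics, the checkpoint handling above follows a restore/replicate/train/unreplicate/save pattern. The sketch below is an outline only; `ckpt`, `state`, `update_fn`, `train_iter` and `num_steps` are placeholders for the surrounding code.

from flax import jax_utils as flax_utils

state = ckpt.restore_or_initialize(state)   # Unreplicated, host-local state.
initial_step = int(state.step)
state = flax_utils.replicate(state)         # One copy per local device for pmap.
for _ in range(num_steps):                  # Placeholder training loop.
    state, outputs = update_fn(state, next(train_iter))
ckpt.save(flax_utils.unreplicate(state))    # Always save a single-device copy.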
Example #5
def train_and_evaluate(config, workdir):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """
    logging.info('Starting training at %s', workdir)
    tf.io.gfile.makedirs(workdir)
    if jax.process_index() == 0:
        with tf.io.gfile.GFile(os.path.join(workdir, 'config.json'), 'w') as f:
            json.dump(config.to_dict(), f, indent=2)
    rng = jax.random.PRNGKey(config.seed)

    # Build input pipeline.
    rng, data_rng = jax.random.split(rng)
    data_rng = jax.random.fold_in(data_rng, jax.process_index())
    train_ds, eval_ds = input_pipeline.create_datasets(config.dataset,
                                                       data_rng)
    train_iter = iter(train_ds)

    test_ds = []
    for split in config.dataset.test_splits:
        ds = input_pipeline.create_val_dataset(
            config.dataset, split, config.dataset.test_per_device_batch_size,
            config.dataset.test_pad_last_batch)
        test_ds.append(ds)

    # Learning rate schedule.
    num_train_steps = config.num_train_steps
    if num_train_steps == -1:
        num_train_steps = train_ds.cardinality().numpy()
    steps_per_epoch = num_train_steps // config.dataset.num_epochs
    logging.info('num_train_steps=%d, steps_per_epoch=%d', num_train_steps,
                 steps_per_epoch)
    learning_rate_fn = functools.partial(
        train_utils.get_learning_rate,
        base_learning_rate=config.learning_rate,
        num_train_steps=num_train_steps,
        schedule_type=config.learning_rate_schedule,
        warmup_proportion=config.warmup_proportion,
        step_boundaries=config.learning_rate_step_boundaries)

    # Initialize model.
    inputs = train_utils.get_init_inputs(train_ds)
    rng, model_rng = jax.random.split(rng)
    eval_config = models.TransformerConfig(**config.model.to_dict())
    train_config = eval_config.replace(deterministic=False)
    model = models.Model(eval_config)
    state = train_utils.create_train_state(model,
                                           config,
                                           model_rng,
                                           inputs=inputs)

    # Set up checkpointing of the model and the input pipeline.
    checkpoint_dir = os.path.join(workdir, 'checkpoints')
    ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=3)
    state = ckpt.restore_or_initialize(state)
    initial_step = int(state.step) + 1

    # Distribute training.
    state = flax_utils.replicate(state)
    p_train_step = jax.pmap(functools.partial(
        train_step,
        config=train_config,
        learning_rate_fn=learning_rate_fn,
        grad_clip=config.grad_clip),
                            axis_name='batch',
                            donate_argnums=(0, ))
    p_eval_step = jax.pmap(functools.partial(eval_step, config=eval_config),
                           axis_name='batch')

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.process_index() > 0)
    if initial_step == 1:
        writer.write_hparams(train_utils.flatten_config(config))

    logging.info('Starting training loop at step %d.', initial_step)
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=num_train_steps, writer=writer)
    if jax.process_index() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(
                num_profile_steps=config.num_profile_steps, logdir=workdir)
        ]

    rng, train_rngs = jax.random.split(rng)
    train_rngs = jax.random.fold_in(train_rngs, jax.process_index())
    train_rngs = jax.random.split(train_rngs, jax.local_device_count())

    train_metrics = []
    with metric_writers.ensure_flushes(writer):
        for step in range(initial_step, num_train_steps + 1):
            is_last_step = step == num_train_steps
            with jax.profiler.StepTraceContext('train', step_num=step):
                batch = jax.tree_map(np.asarray, next(train_iter))
                state, metrics = p_train_step(batch=batch,
                                              rng=train_rngs,
                                              state=state)
                train_metrics.append(metrics)

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, 'Finished training step %d.', 5,
                                step)
            for h in hooks:
                h(step)

            if config.log_loss_every_steps > 0 and (
                    step % config.log_loss_every_steps == 0 or is_last_step):
                train_metrics = common_utils.get_metrics(train_metrics)
                lr = train_metrics.pop('learning_rate').mean()
                train_summary = train_utils.metrics_summary(
                    train_metrics, 'train')
                train_summary['learning_rate'] = lr
                writer.write_scalars(step, train_summary)
                train_metrics = []

            if config.eval_every_steps > 0 and (step % config.eval_every_steps
                                                == 0 or is_last_step):
                with report_progress.timed('eval'):
                    eval_summary = evaluate(p_eval_step, state, eval_ds,
                                            config.num_eval_steps)
                writer.write_scalars(step, eval_summary)

            if config.checkpoint_every_steps > 0 and (
                    step % config.checkpoint_every_steps == 0 or is_last_step):
                with report_progress.timed('checkpoint'):
                    ckpt.save(flax_utils.unreplicate(state))
                logging.info('Checkpoint saved to %s', checkpoint_dir)

    logging.info('Finishing training at step %d', num_train_steps)
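The per-host data seeding used at the start of the example above can be isolated into a short sketch; `config.seed` is the only assumed input, everything else is the standard JAX PRNG API.

import jax

rng = jax.random.PRNGKey(config.seed)
rng, data_rng = jax.random.split(rng)
# Fold in the process index so each host sees a different data order.
data_rng = jax.random.fold_in(data_rng, jax.process_index())
# Per-device keys for the pmapped train step, folded in the same way.
rng, train_rng = jax.random.split(rng)
train_rng = jax.random.fold_in(train_rng, jax.process_index())
train_rngs = jax.random.split(train_rng, jax.local_device_count())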
Example #6
def training_loop(
    *,
    module,
    rng,
    train_ds,
    eval_ds,
    loss_fn,
    optimizer,
    train_metrics_dict,
    eval_metrics_dict,
    stats_aggregators,
    config,
    workdir,
):
    """Runs a training and evaluation loop.

  Args:
    module: The module that should be trained.
    rng: A jax pseudo-random number generator key.
    train_ds: Dataset used for training.
    eval_ds: Dataset used for evaluation.
    loss_fn: Loss function to use for training.
    optimizer: Optax optimizer to use for training.
    train_metrics_dict: Collection of metrics to be collected during training.
    eval_metrics_dict: Collection of metrics to be collected during evaluation.
    stats_aggregators: Dictionary of statistics aggregator functions to be run
      on the first evaluation batch. These functions ingest the stats returned
      by the model and output a Dict[str, image/scalar] that will be logged.
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.

  Raises:
    RuntimeError: If a training metric is NaN or inf.

  Returns:
    Training state.
  """
    rng, model_rng = jax.random.split(rng)
    input_shape = tuple(train_ds.element_spec["image"].shape[1:])
    model, init_params, init_state = create_model(module, input_shape,
                                                  model_rng)
    parameter_overview.log_parameter_overview(model.params)

    # Load pretrained model parameters and state. Ignore the step and the
    # optimizer state in the checkpoint.
    pretrained_path = config.get("pretrained_checkpoint", "")
    if pretrained_path:
        logging.info("Load pretrained weights from '%s'", pretrained_path)
        state_dict = checkpoint.load_state_dict(pretrained_path)
        flatten_model_params = utils.flatten_dict(state_dict["model_params"],
                                                  sep="/")
        model_state = state_dict["model_state"]

        # A prefix can be used to replace only a subpart of the network (e.g. the
        # encoder). Prepend the prefix (if any) to model parameters and states.
        prefix = config.get("pretrained_prefix", "")
        if prefix:
            flatten_model_params = utils.add_prefix_to_dict_keys(
                flatten_model_params, f"{prefix}/")
            model_state = utils.add_prefix_to_dict_keys(
                model_state, f"/{prefix}")

        # Merge the params/state from the checkpoint into the initial params/state.
        flatten_init_params = utils.flatten_dict(init_params, sep="/")
        flatten_init_params, ignored_params = utils.override_dict(
            flatten_init_params, flatten_model_params)
        init_params = utils.unflatten_dict(flatten_init_params, delimiter="/")
        init_state, _ = utils.override_dict(init_state, model_state)

        if ignored_params:
            logging.warning(
                "%d/%d parameters from the pretrained checkpoint "
                "were ignored: %s", len(ignored_params),
                len(flatten_init_params), ignored_params)

    optimizer_state = optimizer.init(init_params)

    state = TrainState(step=1,
                       model_params=init_params,
                       model_state=init_state,
                       optimizer_state=optimizer_state)  # type: ignore
    # Do not keep a copy of the initial model.
    del init_params, init_state, optimizer_state

    train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
    checkpoint_dir = os.path.join(workdir, "checkpoints")

    ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
    state = ckpt.restore_or_initialize(state)
    initial_step = int(state.step)
    # Replicate our parameters.
    state = flax.jax_utils.replicate(state)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.host_id() > 0)
    step_timer = utils.StepTimer(batch_size=config.batch_size,
                                 initial_step=initial_step)

    # Write config to the summary files. This makes the hyperparameters available
    # in TensorBoard and makes comparison of runs with tensorboard/ easier.
    if initial_step == 1:
        writer.write_hparams(utils.flatten_dict(config.to_dict()))

    # Generate per-device PRNG keys for the training loop.
    rng, train_rng = jax.random.split(rng)
    train_rngs = jax.random.split(train_rng, jax.local_device_count())

    # Generate per-device PRNG keys for model evaluation.
    rng, eval_rng = jax.random.split(rng)
    eval_rngs = jax.random.split(eval_rng, jax.local_device_count())

    logging.info("Starting training loop at step %d.", initial_step)
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    train_metrics = utils.Means()

    do_eval_only = config.get("do_eval_only", False)
    if do_eval_only:
        config.num_train_steps = 1

    debug_enabled = config.get("debug", False)
    previous_grads = grads = None
    previous_updates = updates = None
    previous_state = None
    for step in range(initial_step, config.num_train_steps + 1):
        is_last_step = step == config.num_train_steps
        if debug_enabled:
            previous_grads = grads
            previous_updates = updates
            previous_state = state

        # Skip training when only doing eval.
        if not do_eval_only:
            # Use ._numpy() to avoid copy.
            batch = jax.tree_map(lambda x: x._numpy(), next(train_iter))  # pylint: disable=protected-access
            state, grads, updates, metrics, training_stats, train_rngs = train_step(
                state, batch, module, loss_fn, optimizer, train_metrics_dict,
                train_rngs)
            train_metrics.append(flax.jax_utils.unreplicate(metrics))

            # Update topk temperature with a linearly decreasing schedule if enabled.
            if (config.get("linear_decrease_perturbed_sigma", False) and
                    config.get("selection_method", "") == "perturbed-topk"):

                model_state = state.model_state.as_dict()

                if "/PatchNet_0" in model_state:
                    net_str = "/PatchNet_0"
                else:
                    net_str = "/"

                progress = step / config.num_train_steps
                sigma_multiplier = 1. - progress
                previous_mult = model_state[net_str]["sigma_mutiplier"]
                sigma_multiplier = sigma_multiplier + jnp.zeros_like(
                    previous_mult)
                model_state[net_str]["sigma_mutiplier"] = sigma_multiplier
                state = state.replace(model_state=nn.Collection(model_state))

            if debug_enabled:
                if utils.has_any_inf_or_nan(metrics):
                    # Save checkpoint
                    if previous_state:
                        ckpt.save(flax.jax_utils.unreplicate(previous_state))
                    ckpt.save(flax.jax_utils.unreplicate(state))

                    # Log gradients and updates.
                    if previous_grads or previous_updates:
                        write_gradient_histogram(writer,
                                                 step,
                                                 grads=previous_grads,
                                                 updates=previous_updates)
                    write_gradient_histogram(writer,
                                             step + 1,
                                             grads=grads,
                                             updates=updates)

                    raise RuntimeError(
                        "A training metric took an invalid value: "
                        f"{metrics}.")

            logging.log_first_n(logging.INFO, "Finished training step %d.", 3,
                                step)
            report_progress(step)

            if step % config.log_loss_every_steps == 0 or is_last_step:
                results = train_metrics.result()
                writer.write_scalars(step, results)
                writer.write_scalars(step, step_timer.get_and_reset(step))
                if utils.has_any_inf_or_nan(results):
                    raise ValueError(
                        "A training metric took an invalid value.")
                train_metrics.reset()

        if (step % config.checkpoint_every_steps == 0 or is_last_step):
            with step_timer.paused():
                ckpt.save(flax.jax_utils.unreplicate(state))

        # Evaluation
        if step % config.eval_every_steps == 0 or is_last_step:
            with step_timer.paused():
                eval_metrics, first_batch_stats, eval_rngs = evaluate(
                    state, module, eval_ds, eval_metrics_dict, eval_rngs)

            if jax.host_id() == 0:
                log_histograms = config.get("log_histograms", False)
                log_images = config.get("log_images", True)
                # Log the last gradients and updates histograms.
                if not do_eval_only:
                    write_stats_results(writer,
                                        step,
                                        training_stats,
                                        stats_aggregators,
                                        prefix="train/",
                                        log_images=log_images)
                    if log_histograms:
                        write_gradient_histogram(writer,
                                                 step,
                                                 grads=grads,
                                                 updates=updates)

                write_stats_results(writer,
                                    step,
                                    first_batch_stats,
                                    stats_aggregators,
                                    prefix="eval/",
                                    log_images=log_images)

                # Write patch representation histograms.
                if (log_histograms and first_batch_stats
                        and "patch_representations" in first_batch_stats):
                    patch_representations = first_batch_stats[
                        "patch_representations"]
                    writer.write_histograms(
                        step, {"patch_representations": patch_representations})

                if eval_metrics:
                    writer.write_scalars(step, eval_metrics)

    writer.flush()
    return state
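The pretrained-weights merge near the top of training_loop() reduces to the following sketch. `utils.flatten_dict`, `utils.override_dict` and `utils.unflatten_dict` are the project's own helpers taken from the code above, not standard library functions.

state_dict = checkpoint.load_state_dict(pretrained_path)  # Raw dict; step and optimizer ignored.
flat_init = utils.flatten_dict(init_params, sep="/")
flat_ckpt = utils.flatten_dict(state_dict["model_params"], sep="/")
# Parameters missing from the checkpoint keep their freshly initialized values.
flat_init, ignored_params = utils.override_dict(flat_init, flat_ckpt)
init_params = utils.unflatten_dict(flat_init, delimiter="/")
init_state, _ = utils.override_dict(init_state, state_dict["model_state"])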
Example #7
def train_and_evaluate(config, workdir):
    """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
  """
    tf.io.gfile.makedirs(workdir)

    rng = jax.random.PRNGKey(config.seed)

    # Build input pipeline.
    rng, data_rng = jax.random.split(rng)
    data_rng = jax.random.fold_in(data_rng, jax.host_id())
    splits = input_pipeline.create_datasets(config, data_rng)
    num_classes = splits.info.features["label"].num_classes
    train_iter = iter(splits.train)  # pytype: disable=wrong-arg-types

    # Learning rate schedule.
    num_train_steps = config.num_train_steps
    if num_train_steps == -1:
        num_train_steps = splits.train.cardinality().numpy()
    steps_per_epoch = num_train_steps // config.num_epochs
    logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps,
                 steps_per_epoch)
    # We treat the learning rate in the config as the learning rate for batch size
    # 32 but scale it according to our batch size.
    global_batch_size = config.per_device_batch_size * jax.device_count()
    base_learning_rate = config.learning_rate * global_batch_size / 32.0
    learning_rate_fn = functools.partial(get_learning_rate,
                                         base_learning_rate=base_learning_rate,
                                         steps_per_epoch=steps_per_epoch,
                                         num_epochs=config.num_epochs,
                                         warmup_epochs=config.warmup_epochs)

    # Initialize model.
    rng, model_rng = jax.random.split(rng)
    model, state = create_train_state(
        config,
        model_rng,
        input_shape=splits.train.element_spec["input"].shape[1:],
        num_classes=num_classes)

    # Set up checkpointing of the model and the input pipeline.
    checkpoint_dir = os.path.join(workdir, "checkpoints")
    ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir,
                                          {"train_iter": train_iter},
                                          max_to_keep=2)
    state = ckpt.restore_or_initialize(state)
    initial_step = int(state.step) + 1

    # Count number of trainable parameters. This must be done before replicating
    # the state to avoid double-counting replicated parameters.
    param_count = sum(p.size for p in jax.tree_leaves(state.optimizer.target))

    # Distribute training over local devices.
    state = flax_utils.replicate(state)

    p_train_step = jax.pmap(functools.partial(
        train_step,
        model=model,
        learning_rate_fn=learning_rate_fn,
        weight_decay=config.weight_decay),
                            axis_name=_PMAP_AXIS_NAME)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.host_id() > 0)
    if initial_step == 1:
        writer.write_hparams(dict(config))
        # Log the number of trainable params.
        writer.write_scalars(initial_step, {"param_count": param_count})

    logging.info("Starting training loop at step %d.", initial_step)
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=num_train_steps, writer=writer)
    if jax.host_id() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
        ]
    train_metrics = None
    with metric_writers.ensure_flushes(writer):
        for step in range(initial_step, num_train_steps + 1):
            # `step` is a Python integer; `state.step` is a JAX integer on the
            # GPU/TPU devices.
            is_last_step = step == num_train_steps

            with jax.profiler.StepTraceContext("train", step_num=step):
                batch = jax.tree_map(np.asarray, next(train_iter))
                state, metrics_update = p_train_step(state=state, batch=batch)
                metric_update = flax_utils.unreplicate(metrics_update)
                train_metrics = (metric_update if train_metrics is None else
                                 train_metrics.merge(metric_update))

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            if step % config.log_loss_every_steps == 0 or is_last_step:
                writer.write_scalars(step, train_metrics.compute())
                train_metrics = None

            # When combining train and eval, we do not evaluate while training.
            if ((step % config.eval_every_steps == 0 or is_last_step)
                    and not config.combine_train_val_and_eval_on_test):
                with report_progress.timed("eval"):
                    eval_metrics = evaluate(model, state, splits.validation,
                                            config.num_eval_steps)
                writer.write_scalars(step, eval_metrics.compute())

            if step % config.checkpoint_every_steps == 0 or is_last_step:
                with report_progress.timed("checkpoint"):
                    ckpt.save(flax_utils.unreplicate(state))

            if is_last_step and config.combine_train_val_and_eval_on_test:
                # Evaluate a single time on the test set when requested.
                with report_progress.timed("test"):
                    test_metrics = evaluate(model, state, splits.test,
                                            config.num_eval_steps)
                writer.write_scalars(step, test_metrics.compute())

    logging.info("Finishing training at step %d", num_train_steps)
Example #8
def predict_and_evaluate(config, workdir, ckpt_path=None):
    """Runs a testing and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains checkpoint training will be resumed from the latest checkpoint.
    ckpt_path: The checkpoint to evaluate. If not specified, use the latest
      checkpoint.
  """
    logging.info('Starting testing at %s', workdir)
    tf.io.gfile.makedirs(workdir)

    rng = jax.random.PRNGKey(config.seed)
    # Build input pipeline.
    rng, data_rng = jax.random.split(rng)
    data_rng = jax.random.fold_in(data_rng, jax.process_index())
    test_ds = []
    for split in config.dataset.test_splits:
        ds = input_pipeline.create_val_dataset(
            config.dataset, split, config.dataset.test_per_device_batch_size,
            config.dataset.test_pad_last_batch)
        test_ds.append(ds)

    # Initialize model.
    inputs = train_utils.get_init_inputs(test_ds[0])
    rng, model_rng = jax.random.split(rng)
    predict_config = models.TransformerConfig(**config.model.to_dict())
    predict_config = predict_config.replace(decode=True)
    model = models.Model(predict_config)
    state = train_utils.create_train_state(model,
                                           config,
                                           model_rng,
                                           inputs=inputs)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.process_index() > 0)

    # Set up checkpointing of the model and the input pipeline.
    checkpoint_dir = os.path.join(workdir, 'checkpoints')
    ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=3)

    logging.info('Testing and evaluating checkpoint %s', ckpt_path)
    try:
        state = ckpt.restore(state, ckpt_path)
    except FileNotFoundError:
        state = ckpt.restore_or_initialize(state)
    step = int(state.step)

    p_pred_step = jax.pmap(functools.partial(predict_step,
                                             config=predict_config),
                           axis_name='batch',
                           static_broadcasted_argnums=(3, ))
    p_init_cache = jax.pmap(functools.partial(init_cache,
                                              config=predict_config),
                            axis_name='batch')

    # Distribute testing.
    state = flax_utils.replicate(state)
    with metric_writers.ensure_flushes(writer):
        test_metrics = {}
        for ds, split in zip(test_ds, config.dataset.test_splits):
            ds_metrics = evaluate_sequence_accuracy(p_pred_step, p_init_cache,
                                                    state, ds, config, split,
                                                    workdir,
                                                    config.num_test_steps)
            ds_metrics = {f'{k}_{split}': v for k, v in ds_metrics.items()}
            test_metrics.update(ds_metrics)
        writer.write_scalars(step, test_metrics)
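For evaluation jobs like the one above, restoring a specific checkpoint with a fallback to the latest one is a small reusable pattern; `ckpt_path` may be None, and `checkpoint_dir` and `state` are assumed from the surrounding code.

ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir, max_to_keep=3)
try:
    state = ckpt.restore(state, ckpt_path)     # A specific checkpoint, or None.
except FileNotFoundError:
    state = ckpt.restore_or_initialize(state)  # Fall back to the latest checkpoint.
step = int(state.step)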