Example #1
 def test_collection_multiple_roots(self):
   rng = random.PRNGKey(0)
   with nn.Collection().mutate() as activations:
     x = jnp.array([1.])
     LoopModule.init(rng, x, activations, name='a')
     LoopModule.init(rng, x, activations, name='b')
   expected_state = {
       '/a/dummy': jnp.array([3.]),
       '/b/dummy': jnp.array([3.]),
   }
   self.assertEqual(activations.state, expected_state)
   with self.assertRaises(ValueError):
     with nn.Collection().mutate() as activations:
       x = jnp.array([1.])
       LoopModule.init(rng, x, activations)
       LoopModule.init(rng, x, activations)
Example #2
    def test_mutable_collection_cannot_be_passed_to_jax(self):
        with nn.Collection().mutate() as collection:

            def fn(col):
                return col

            with self.assertRaises(ValueError):
                jax.jit(fn)(collection)
Example #3
 def test_collection_multiple_calls(self):
   rng = random.PRNGKey(0)
   with nn.Collection().mutate() as activations:
     x = jnp.array([1.])
     _, _ = LoopModule.init(rng, x, activations)
   expected_state = {
       '/dummy': jnp.array([3.]),
   }
   self.assertEqual(activations.state, expected_state)
Example #4
 def test_collection_lookup(self):
   state = {
       '/dummy/sub': 1,
   }
   collection = nn.Collection(state=state)
   root = nn.base._ModuleFrame(None)
   frame = nn.base._ModuleFrame('dummy', parent=root)
   with nn.base._module_stack.frame(root):
     with nn.base._module_stack.frame(frame):
       self.assertEqual(collection['/dummy/sub'], 1)
Example #5
 def test_collection_store_and_retrieve(self):
   rng = random.PRNGKey(0)
   x = jnp.array([1.])
   with nn.Collection().mutate() as activations:
     (_, y), initial_params = CollectionModule.init(rng, x, activations)
     model = nn.Model(CollectionModule, initial_params)
     self.assertEqual(y, None)
   with activations.mutate() as new_activations:
     _, y2 = model(x, new_activations)
   self.assertEqual(y2, jnp.array([2.]))
Example #6
    def test_module_state(self):
        class StatefulModule(nn.Module):
            def apply(self, x, coll=None):
                state = self.state('state',
                                   x.shape,
                                   nn.initializers.zeros,
                                   collection=coll)
                state.value += x

        x = jnp.array([1.])
        # calling without a collection should raise an error
        with self.assertRaises(ValueError):
            StatefulModule.call({}, x)

        # pass collection explicitly
        with nn.Collection().mutate() as state:
            self.assertEqual(state.as_dict(), {})
            StatefulModule.init(random.PRNGKey(0), x, state)
            self.assertEqual(state.as_dict(), {'/': {'state': x}})
        self.assertEqual(state.as_dict(), {'/': {'state': x}})
        with state.mutate() as new_state:
            # assert new_state is a clone of state
            self.assertEqual(new_state.as_dict(), state.as_dict())
            StatefulModule.call({}, x, new_state)
        self.assertEqual(new_state.as_dict(), {'/': {'state': x + x}})

        # use stateful
        with nn.stateful() as state:
            self.assertEqual(state.as_dict(), {})
            StatefulModule.init(random.PRNGKey(0), x)
        self.assertEqual(state.as_dict(), {'/': {'state': x}})
        with nn.stateful(state) as new_state:
            # assert new_state is a clone of state
            self.assertEqual(new_state.as_dict(), state.as_dict())
            StatefulModule.call({}, x)
            self.assertEqual(new_state.as_dict(), {'/': {'state': x + x}})
        self.assertEqual(new_state.as_dict(), {'/': {'state': x + x}})
Example #7
 def outer_fn(x):
   with nn.Collection().mutate() as state:
     NestedTransform.init(random.PRNGKey(0), state, x)
Example #8
 def test_inner(f):
   with nn.Collection().mutate() as coll:
      # this should fail because f is a shared module defined
      # in the parent. Therefore we cannot capture it in the
      # scope of this Module.
     f(coll)
Example #9
 def test():
   with nn.Collection().mutate() as coll:
     coll.store(1)
Example #10
 def apply(self, x):
   with nn.Collection().mutate() as activations:
     LoopModule(x, activations, name='a')
     LoopModule(x, activations, name='b')
   return activations
def training_loop(
    *,
    module,
    rng,
    train_ds,
    eval_ds,
    loss_fn,
    optimizer,
    train_metrics_dict,
    eval_metrics_dict,
    stats_aggregators,
    config,
    workdir,
):
    """Runs a training and evaluation loop.

    Args:
      module: The module that should be trained.
      rng: A jax pseudo-random number generator key.
      train_ds: Dataset used for training.
      eval_ds: Dataset used for evaluation.
      loss_fn: Loss function to use for training.
      optimizer: Optax optimizer to use for training.
      train_metrics_dict: Collection of metrics to be collected during training.
      eval_metrics_dict: Collection of metrics to be collected during evaluation.
      stats_aggregators: Dictionary of statistics aggregator functions to be run
        on the first evaluation batch. These functions ingest the stats returned
        by the model and output a Dict[str, image/scalar] that will be logged.
      config: Configuration to use.
      workdir: Working directory for checkpoints and TF summaries. If this
        contains a checkpoint, training will be resumed from the latest
        checkpoint.

    Raises:
      RuntimeError: If a training metric is NaN or inf.

    Returns:
      Training state.
    """
    rng, model_rng = jax.random.split(rng)
    input_shape = tuple(train_ds.element_spec["image"].shape[1:])
    model, init_params, init_state = create_model(module, input_shape,
                                                  model_rng)
    parameter_overview.log_parameter_overview(model.params)

    # Load pretrained model parameters and state. Ignore the step and the
    # optimizer state in the checkpoint.
    pretrained_path = config.get("pretrained_checkpoint", "")
    if pretrained_path:
        logging.info("Load pretrained weights from '%s'", pretrained_path)
        state_dict = checkpoint.load_state_dict(pretrained_path)
        flatten_model_params = utils.flatten_dict(state_dict["model_params"],
                                                  sep="/")
        model_state = state_dict["model_state"]

        # A prefix can be used to replace only a subpart of the network (e.g.
        # the encoder). Prepend the prefix (if any) to model parameters and
        # states.
        prefix = config.get("pretrained_prefix", "")
        if prefix:
            flatten_model_params = utils.add_prefix_to_dict_keys(
                flatten_model_params, f"{prefix}/")
            model_state = utils.add_prefix_to_dict_keys(
                model_state, f"/{prefix}")

        # Merge the params/state from the checkpoint into the initial params/state.
        flatten_init_params = utils.flatten_dict(init_params, sep="/")
        flatten_init_params, ignored_params = utils.override_dict(
            flatten_init_params, flatten_model_params)
        init_params = utils.unflatten_dict(flatten_init_params, delimiter="/")
        init_state, _ = utils.override_dict(init_state, model_state)

        if ignored_params:
            logging.warning(
                "%d/%d parameters from the pretrained checkpoint "
                "were ignored: %s", len(ignored_params),
                len(flatten_init_params), ignored_params)

    optimizer_state = optimizer.init(init_params)

    state = TrainState(step=1,
                       model_params=init_params,
                       model_state=init_state,
                       optimizer_state=optimizer_state)  # type: ignore
    # Do not keep a copy of the initial model.
    del init_params, init_state, optimizer_state

    train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
    checkpoint_dir = os.path.join(workdir, "checkpoints")

    ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
    state = ckpt.restore_or_initialize(state)
    initial_step = int(state.step)
    # Replicate our parameters.
    state = flax.jax_utils.replicate(state)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.host_id() > 0)
    step_timer = utils.StepTimer(batch_size=config.batch_size,
                                 initial_step=initial_step)

    # Write config to the summary files. This makes the hyperparameters
    # available in TensorBoard and makes comparing runs easier.
    if initial_step == 1:
        writer.write_hparams(utils.flatten_dict(config.to_dict()))

    # Generate per-device PRNG keys for the training loop.
    rng, train_rng = jax.random.split(rng)
    train_rngs = jax.random.split(train_rng, jax.local_device_count())

    # Generate per-device PRNG keys for model evaluation.
    rng, eval_rng = jax.random.split(rng)
    eval_rngs = jax.random.split(eval_rng, jax.local_device_count())

    logging.info("Starting training loop at step %d.", initial_step)
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    train_metrics = utils.Means()

    do_eval_only = config.get("do_eval_only", False)
    if do_eval_only:
        config.num_train_steps = 1

    debug_enabled = config.get("debug", False)
    previous_grads = grads = None
    previous_updates = updates = None
    previous_state = None
    for step in range(initial_step, config.num_train_steps + 1):
        is_last_step = step == config.num_train_steps
        if debug_enabled:
            previous_grads = grads
            previous_updates = updates
            previous_state = state

        # Skip training when only running eval.
        if not do_eval_only:
            # Use ._numpy() to avoid a copy.
            batch = jax.tree_map(lambda x: x._numpy(), next(train_iter))  # pylint: disable=protected-access
            state, grads, updates, metrics, training_stats, train_rngs = train_step(
                state, batch, module, loss_fn, optimizer, train_metrics_dict,
                train_rngs)
            train_metrics.append(flax.jax_utils.unreplicate(metrics))

            # Update the perturbed top-k sigma multiplier with a linearly
            # decreasing schedule, if enabled.
            if (config.get("linear_decrease_perturbed_sigma", False) and
                    config.get("selection_method", "") == "perturbed-topk"):

                model_state = state.model_state.as_dict()

                if "/PatchNet_0" in model_state:
                    net_str = "/PatchNet_0"
                else:
                    net_str = "/"

                progress = step / config.num_train_steps
                sigma_multiplier = 1. - progress
                previous_mult = model_state[net_str]["sigma_mutiplier"]
                sigma_multiplier = sigma_multiplier + jnp.zeros_like(
                    previous_mult)
                model_state[net_str]["sigma_mutiplier"] = sigma_multiplier
                state = state.replace(model_state=nn.Collection(model_state))

            if debug_enabled:
                if utils.has_any_inf_or_nan(metrics):
                    # Save checkpoint
                    if previous_state:
                        ckpt.save(flax.jax_utils.unreplicate(previous_state))
                    ckpt.save(flax.jax_utils.unreplicate(state))

                    # Log gradients and updates.
                    if previous_grads or previous_updates:
                        write_gradient_histogram(writer,
                                                 step,
                                                 grads=previous_grads,
                                                 updates=previous_updates)
                    write_gradient_histogram(writer,
                                             step + 1,
                                             grads=grads,
                                             updates=updates)

                    raise RuntimeError(
                        "A training metric took an invalid value: "
                        f"{metrics}.")

            logging.log_first_n(logging.INFO, "Finished training step %d.", 3,
                                step)
            report_progress(step)

            if step % config.log_loss_every_steps == 0 or is_last_step:
                results = train_metrics.result()
                writer.write_scalars(step, results)
                writer.write_scalars(step, step_timer.get_and_reset(step))
                if utils.has_any_inf_or_nan(results):
                    raise ValueError(
                        "A training metric took an invalid value.")
                train_metrics.reset()

        if (step % config.checkpoint_every_steps == 0 or is_last_step):
            with step_timer.paused():
                ckpt.save(flax.jax_utils.unreplicate(state))

        # Evaluation
        if step % config.eval_every_steps == 0 or is_last_step:
            with step_timer.paused():
                eval_metrics, first_batch_stats, eval_rngs = evaluate(
                    state, module, eval_ds, eval_metrics_dict, eval_rngs)

            if jax.host_id() == 0:
                log_histograms = config.get("log_histograms", False)
                log_images = config.get("log_images", True)
                # Log the last gradients and updates histograms.
                if not do_eval_only:
                    write_stats_results(writer,
                                        step,
                                        training_stats,
                                        stats_aggregators,
                                        prefix="train/",
                                        log_images=log_images)
                    if log_histograms:
                        write_gradient_histogram(writer,
                                                 step,
                                                 grads=grads,
                                                 updates=updates)

                write_stats_results(writer,
                                    step,
                                    first_batch_stats,
                                    stats_aggregators,
                                    prefix="eval/",
                                    log_images=log_images)

                # write patch representation histograms
                if (log_histograms and first_batch_stats
                        and "patch_representations" in first_batch_stats):
                    patch_representations = first_batch_stats[
                        "patch_representations"]
                    writer.write_histograms(
                        step, {"patch_representations": patch_representations})

                if eval_metrics:
                    writer.write_scalars(step, eval_metrics)

    writer.flush()
    return state
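
The stats_aggregators argument above expects functions that take the stats
returned by the model on the first evaluation batch and return a dict of
scalars or images to log. Below is a minimal sketch of one such aggregator;
the function name, the "selection_scores" stats key, and the dictionary entry
are illustrative assumptions, not part of the codebase.

import jax
import jax.numpy as jnp


def mean_selection_entropy(stats):
    # Reduce one (hypothetical) stats entry to a single scalar for logging.
    probs = jax.nn.softmax(stats["selection_scores"], axis=-1)
    entropy = -jnp.sum(probs * jnp.log(probs + 1e-8), axis=-1)
    return {"mean_selection_entropy": float(jnp.mean(entropy))}


# Passed to training_loop(..., stats_aggregators=stats_aggregators, ...).
stats_aggregators = {"selection_entropy": mean_selection_entropy}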