def test_collection_multiple_roots(self):
  rng = random.PRNGKey(0)
  with nn.Collection().mutate() as activations:
    x = jnp.array([1.])
    LoopModule.init(rng, x, activations, name='a')
    LoopModule.init(rng, x, activations, name='b')
  expected_state = {
      '/a/dummy': jnp.array([3.]),
      '/b/dummy': jnp.array([3.]),
  }
  self.assertEqual(activations.state, expected_state)
  with self.assertRaises(ValueError):
    with nn.Collection().mutate() as activations:
      x = jnp.array([1.])
      LoopModule.init(rng, x, activations)
      LoopModule.init(rng, x, activations)
def test_mutable_collection_cannot_be_passed_to_jax(self):
  with nn.Collection().mutate() as collection:
    def fn(col):
      return col
    with self.assertRaises(ValueError):
      jax.jit(fn)(collection)
def test_collection_multiple_calls(self):
  rng = random.PRNGKey(0)
  with nn.Collection().mutate() as activations:
    x = jnp.array([1.])
    _, _ = LoopModule.init(rng, x, activations)
  expected_state = {
      '/dummy': jnp.array([3.]),
  }
  self.assertEqual(activations.state, expected_state)
def test_collection_lookup(self):
  state = {
      '/dummy/sub': 1,
  }
  collection = nn.Collection(state=state)
  root = nn.base._ModuleFrame(None)
  frame = nn.base._ModuleFrame('dummy', parent=root)
  with nn.base._module_stack.frame(root):
    with nn.base._module_stack.frame(frame):
      self.assertEqual(collection['/dummy/sub'], 1)
def test_collection_store_and_retrieve(self):
  rng = random.PRNGKey(0)
  x = jnp.array([1.])
  with nn.Collection().mutate() as activations:
    (_, y), initial_params = CollectionModule.init(rng, x, activations)
    model = nn.Model(CollectionModule, initial_params)
  self.assertEqual(y, None)
  with activations.mutate() as new_activations:
    _, y2 = model(x, new_activations)
  self.assertEqual(y2, jnp.array([2.]))
def test_module_state(self):
  class StatefulModule(nn.Module):

    def apply(self, x, coll=None):
      state = self.state('state', x.shape, nn.initializers.zeros,
                         collection=coll)
      state.value += x

  x = jnp.array([1.,])

  # no collection should raise an error
  with self.assertRaises(ValueError):
    StatefulModule.call({}, x)

  # pass collection explicitly
  with nn.Collection().mutate() as state:
    self.assertEqual(state.as_dict(), {})
    StatefulModule.init(random.PRNGKey(0), x, state)
    self.assertEqual(state.as_dict(), {'/': {'state': x}})
  self.assertEqual(state.as_dict(), {'/': {'state': x}})
  with state.mutate() as new_state:
    # assert new_state is a clone of state
    self.assertEqual(new_state.as_dict(), state.as_dict())
    StatefulModule.call({}, x, new_state)
    self.assertEqual(new_state.as_dict(), {'/': {'state': x + x}})

  # use stateful
  with nn.stateful() as state:
    self.assertEqual(state.as_dict(), {})
    StatefulModule.init(random.PRNGKey(0), x)
    self.assertEqual(state.as_dict(), {'/': {'state': x}})
  with nn.stateful(state) as new_state:
    # assert new_state is a clone of state
    self.assertEqual(new_state.as_dict(), state.as_dict())
    StatefulModule.call({}, x)
    self.assertEqual(new_state.as_dict(), {'/': {'state': x + x}})
  self.assertEqual(new_state.as_dict(), {'/': {'state': x + x}})
def outer_fn(x):
  with nn.Collection().mutate() as state:
    NestedTransform.init(random.PRNGKey(0), state, x)
def test_inner(f):
  with nn.Collection().mutate() as coll:
    # This should fail because f is a shared module defined in the parent,
    # so it cannot be captured in the scope of this Module.
    f(coll)
def test():
  with nn.Collection().mutate() as coll:
    coll.store(1)
def apply(self, x):
  with nn.Collection().mutate() as activations:
    LoopModule(x, activations, name='a')
    LoopModule(x, activations, name='b')
  return activations
def training_loop(
    *,
    module,
    rng,
    train_ds,
    eval_ds,
    loss_fn,
    optimizer,
    train_metrics_dict,
    eval_metrics_dict,
    stats_aggregators,
    config,
    workdir,
):
  """Runs a training and evaluation loop.

  Args:
    module: The module that should be trained.
    rng: A jax pseudo-random number generator key.
    train_ds: Dataset used for training.
    eval_ds: Dataset used for evaluation.
    loss_fn: Loss function to use for training.
    optimizer: Optax optimizer to use for training.
    train_metrics_dict: Collection of metrics to be collected during training.
    eval_metrics_dict: Collection of metrics to be collected during evaluation.
    stats_aggregators: Dictionary of statistics aggregator functions to be run
      on the first evaluation batch. These functions ingest the stats returned
      by the model and output a Dict[str, image/scalar] that will be logged.
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.

  Returns:
    Training state.

  Raises:
    RuntimeError: If a training metric is NaN or inf.
  """
  rng, model_rng = jax.random.split(rng)
  input_shape = tuple(train_ds.element_spec["image"].shape[1:])
  model, init_params, init_state = create_model(module, input_shape, model_rng)
  parameter_overview.log_parameter_overview(model.params)

  # Load pretrained model parameters and state. Ignore the step and the
  # optimizer state in the checkpoint.
  pretrained_path = config.get("pretrained_checkpoint", "")
  if pretrained_path:
    logging.info("Load pretrained weights from '%s'", pretrained_path)
    state_dict = checkpoint.load_state_dict(pretrained_path)
    flatten_model_params = utils.flatten_dict(state_dict["model_params"],
                                              sep="/")
    model_state = state_dict["model_state"]

    # A prefix can be used to replace only a subpart of the network (e.g. the
    # encoder). Prepend the prefix (if any) to model parameters and states.
    prefix = config.get("pretrained_prefix", "")
    if prefix:
      flatten_model_params = utils.add_prefix_to_dict_keys(
          flatten_model_params, f"{prefix}/")
      model_state = utils.add_prefix_to_dict_keys(model_state, f"/{prefix}")

    # Merge the params/state from the checkpoint into the initial params/state.
    flatten_init_params = utils.flatten_dict(init_params, sep="/")
    flatten_init_params, ignored_params = utils.override_dict(
        flatten_init_params, flatten_model_params)
    init_params = utils.unflatten_dict(flatten_init_params, delimiter="/")
    init_state, _ = utils.override_dict(init_state, model_state)

    if ignored_params:
      logging.warning(
          "%d/%d parameters from the pretrained checkpoint were ignored: %s",
          len(ignored_params), len(flatten_init_params), ignored_params)

  optimizer_state = optimizer.init(init_params)

  state = TrainState(
      step=1,
      model_params=init_params,
      model_state=init_state,
      optimizer_state=optimizer_state)  # type: ignore
  # Do not keep a copy of the initial model.
  del init_params, init_state, optimizer_state

  train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
  checkpoint_dir = os.path.join(workdir, "checkpoints")
  ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
  state = ckpt.restore_or_initialize(state)
  initial_step = int(state.step)

  # Replicate our parameters.
  state = flax.jax_utils.replicate(state)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)
  step_timer = utils.StepTimer(
      batch_size=config.batch_size, initial_step=initial_step)

  # Write config to the summary files. This makes the hyperparameters available
  # in TensorBoard and makes comparison of runs easier.
  if initial_step == 1:
    writer.write_hparams(utils.flatten_dict(config.to_dict()))

  # Generate per-device PRNG keys for the training loop.
  rng, train_rng = jax.random.split(rng)
  train_rngs = jax.random.split(train_rng, jax.local_device_count())

  # Generate per-device PRNG keys for model evaluation.
  rng, eval_rng = jax.random.split(rng)
  eval_rngs = jax.random.split(eval_rng, jax.local_device_count())

  logging.info("Starting training loop at step %d.", initial_step)
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)

  train_metrics = utils.Means()

  do_eval_only = config.get("do_eval_only", False)
  if do_eval_only:
    config.num_train_steps = 1

  debug_enabled = config.get("debug", False)
  previous_grads = grads = None
  previous_updates = updates = None
  previous_state = None

  for step in range(initial_step, config.num_train_steps + 1):
    is_last_step = step == config.num_train_steps

    if debug_enabled:
      previous_grads = grads
      previous_updates = updates
      previous_state = state

    # Skip training when only running evaluation.
    if not do_eval_only:
      # Use ._numpy() to avoid copy.
      batch = jax.tree_map(lambda x: x._numpy(), next(train_iter))  # pylint: disable=protected-access
      state, grads, updates, metrics, training_stats, train_rngs = train_step(
          state, batch, module, loss_fn, optimizer, train_metrics_dict,
          train_rngs)

      train_metrics.append(flax.jax_utils.unreplicate(metrics))

      # Update the top-k sigma multiplier with a linearly decreasing schedule
      # if enabled.
      if (config.get("linear_decrease_perturbed_sigma", False) and
          config.get("selection_method", "") == "perturbed-topk"):
        model_state = state.model_state.as_dict()
        if "/PatchNet_0" in model_state:
          net_str = "/PatchNet_0"
        else:
          net_str = "/"
        progress = step / config.num_train_steps
        sigma_multiplier = 1. - progress
        previous_mult = model_state[net_str]["sigma_mutiplier"]
        sigma_multiplier = sigma_multiplier + jnp.zeros_like(previous_mult)
        model_state[net_str]["sigma_mutiplier"] = sigma_multiplier
        state = state.replace(model_state=nn.Collection(model_state))

      if debug_enabled:
        if utils.has_any_inf_or_nan(metrics):
          # Save checkpoints of the previous and current state.
          if previous_state:
            ckpt.save(flax.jax_utils.unreplicate(previous_state))
          ckpt.save(flax.jax_utils.unreplicate(state))

          # Log gradients and updates.
          if previous_grads or previous_updates:
            write_gradient_histogram(writer, step, grads=previous_grads,
                                     updates=previous_updates)
          write_gradient_histogram(writer, step + 1, grads=grads,
                                   updates=updates)

          raise RuntimeError("A training metric took an invalid value: "
                             f"{metrics}.")

    logging.log_first_n(logging.INFO, "Finished training step %d.", 3, step)
    report_progress(step)

    if step % config.log_loss_every_steps == 0 or is_last_step:
      results = train_metrics.result()
      writer.write_scalars(step, results)
      writer.write_scalars(step, step_timer.get_and_reset(step))
      if utils.has_any_inf_or_nan(results):
        raise ValueError("A training metric took an invalid value.")
      train_metrics.reset()

    if step % config.checkpoint_every_steps == 0 or is_last_step:
      with step_timer.paused():
        ckpt.save(flax.jax_utils.unreplicate(state))

    # Evaluation
    if step % config.eval_every_steps == 0 or is_last_step:
      with step_timer.paused():
        eval_metrics, first_batch_stats, eval_rngs = evaluate(
            state, module, eval_ds, eval_metrics_dict, eval_rngs)

      if jax.host_id() == 0:
        log_histograms = config.get("log_histograms", False)
        log_images = config.get("log_images", True)
        # Log the last gradient and update histograms.
        if not do_eval_only:
          write_stats_results(writer, step, training_stats,
                              stats_aggregators, prefix="train/",
                              log_images=log_images)
          if log_histograms:
            write_gradient_histogram(writer, step, grads=grads,
                                     updates=updates)

        write_stats_results(writer, step, first_batch_stats,
                            stats_aggregators, prefix="eval/",
                            log_images=log_images)

        # Write patch representation histograms.
        if (log_histograms and first_batch_stats and
            "patch_representations" in first_batch_stats):
          patch_representations = first_batch_stats["patch_representations"]
          writer.write_histograms(
              step, {"patch_representations": patch_representations})

        if eval_metrics:
          writer.write_scalars(step, eval_metrics)

        writer.flush()

  return state
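
# Usage sketch (illustrative only, not part of the original code): one
# plausible way to invoke training_loop. `make_datasets`, `PatchClassifier`,
# and `cross_entropy_loss` are hypothetical placeholders; only the config
# keys actually read inside training_loop above are assumed.
def example_training_run(workdir):
  import ml_collections
  import optax
  from jax import random

  config = ml_collections.ConfigDict()
  config.batch_size = 128
  config.num_train_steps = 10_000
  config.log_loss_every_steps = 100
  config.checkpoint_every_steps = 1_000
  config.eval_every_steps = 1_000

  # Hypothetical helpers: a tf.data pipeline whose batches contain an
  # "image" field, and a flax module/loss to train.
  train_ds, eval_ds = make_datasets(config.batch_size)
  module = PatchClassifier.partial(num_classes=10)

  return training_loop(
      module=module,
      rng=random.PRNGKey(0),
      train_ds=train_ds,
      eval_ds=eval_ds,
      loss_fn=cross_entropy_loss,
      optimizer=optax.adam(1e-3),
      train_metrics_dict={"loss": cross_entropy_loss},
      eval_metrics_dict={"loss": cross_entropy_loss},
      stats_aggregators={},
      config=config,
      workdir=workdir)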