def set_train_state(self, model_cls, rng):
  """Set up train state.

  Args:
    model_cls: Type of the flax module.
    rng: Jax PRNG.
  """
  # Build flax_model.
  self.hparams.output_dim = self.task.task_params.output_dim
  flax_module, self.hparams = model_cls.build_flax_module(self.hparams)

  # Initialize flax module.
  rng, dropout_rng = jax.random.split(rng)
  (flax_module, model_state,
   self.num_trainable_params) = pipeline_utils.create_flax_module(
       flax_module, self.task.dataset.meta_data['input_shape'], self.hparams,
       dropout_rng,
       self.task.dataset.meta_data.get('input_dtype', jnp.float32))

  if self.hparams.get('pretrained', None):
    pretrained_config = self.hparams.pretrained.get('config')
    pretrained_checkpoint_path = self.hparams.pretrained.get(
        'checkpoint_path')
    pretrained_checkpoint_step = self.hparams.pretrained.get(
        'checkpoint_step', None)
    rng, new_rng = jax.random.split(rng)

    # Create and load the model from the pretrained path.
    if pretrained_checkpoint_step is not None:
      logging.info('load pretrained model at step %d',
                   pretrained_checkpoint_step)
    pretrained_train_state = pipeline_utils.load_model(
        rng=new_rng,
        model_config=pretrained_config,
        model_ckpt=pretrained_checkpoint_path,
        task=self.task,
        load_full_train_state=self.hparams.pretrained.get(
            'full_trainstate_ckpt', True),
        checkpoint_step=pretrained_checkpoint_step)

    if self.hparams.pretrained.get('full_trainstate_ckpt', True):
      pretrained_model = pretrained_train_state.optimizer.target
      pretrained_model_state = pretrained_train_state.model_state
    else:
      (pretrained_model, pretrained_model_state) = pretrained_train_state

    if self.hparams.pretrained.get('only_backbone_pretrained', False):
      # Update params with pretrained params (everything except the task head
      # and discriminator modules).
      for m_key, m_params in pretrained_model.params.items():
        logging.info(m_key)
        if m_key not in ['head'] and ('disc' not in m_key):
          flax_module.params[m_key] = m_params
        else:
          logging.info('Not updated!')

      # Update model_state with pretrained model_state.
      new_state_dict = {}
      for state_key, state_val in pretrained_model_state.as_dict().items():
        logging.info(state_key)
        if 'head' not in state_key and ('disc' not in state_key):
          new_state_dict[state_key] = pretrained_model_state[state_key]
        else:
          logging.info('Not updated!')
          new_state_dict[state_key] = state_val
      model_state = nn.Collection(new_state_dict)
    else:
      flax_module = pretrained_model
      model_state = pretrained_model_state

  # Create optimizer.
  optimizer = optimizers.get_optimizer(self.hparams).create(flax_module)

  # Create train state.
  rng, train_rng = jax.random.split(rng)
  train_state = pipeline_utils.TrainState(
      global_step=0,
      optimizer=optimizer,
      model_state=model_state,
      rng=train_rng)
  self.start_step = train_state.global_step

  # Reset gift regularizer's init point.
  if self.hparams.get('gift_factor', None):
    self.task.regularisers = [
        functools.partial(
            metrics.parameter_distance,
            base_params=train_state.optimizer.target.params,
            norm_factor=self.hparams.get('gift_factor'),
            mode='l2')
    ]

  if self.hparams.checkpoint:
    train_state, self.start_step = pipeline_utils.restore_checkpoint(
        self.experiment_dir, train_state)
    logging.info('Loading checkpoint at step %d', self.start_step)

  # Replicate the optimizer, model state, and rng across devices.
  self.train_state = jax_utils.replicate(train_state)
  del flax_module  # Do not keep a copy of the initial model.

  # Save the initial state.
  if self.start_step == 0 and self.hparams.checkpoint:
    self.checkpoint(self.train_state, self.start_step)
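
# The block below is a minimal, self-contained sketch (not part of the
# pipeline) of the selective transfer performed above when
# `only_backbone_pretrained` is set: every top-level module except the task
# head and any discriminator ('disc*') module is overwritten with its
# pretrained counterpart. The helper name `merge_backbone_params` and the
# plain-dict example are illustrative assumptions, not existing pipeline APIs.
def merge_backbone_params(target_params, pretrained_params):
  """Returns `target_params` with backbone entries taken from the checkpoint."""
  merged = dict(target_params)
  for m_key, m_params in pretrained_params.items():
    if m_key != 'head' and 'disc' not in m_key:
      merged[m_key] = m_params
  return merged


# Example: the head and discriminator keep their freshly initialized values.
# merge_backbone_params(
#     {'stem': 'init', 'head': 'init', 'disc_0': 'init'},
#     {'stem': 'ckpt', 'head': 'ckpt', 'disc_0': 'ckpt'})
# -> {'stem': 'ckpt', 'head': 'init', 'disc_0': 'init'}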

def training_loop(
    *,
    module,
    rng,
    train_ds,
    eval_ds,
    loss_fn,
    optimizer,
    train_metrics_dict,
    eval_metrics_dict,
    stats_aggregators,
    config,
    workdir,
):
  """Runs a training and evaluation loop.

  Args:
    module: The module that should be trained.
    rng: A jax pseudo-random number generator key.
    train_ds: Dataset used for training.
    eval_ds: Dataset used for evaluation.
    loss_fn: Loss function to use for training.
    optimizer: Optax optimizer to use for training.
    train_metrics_dict: Collection of metrics to be collected during training.
    eval_metrics_dict: Collection of metrics to be collected during evaluation.
    stats_aggregators: Dictionary of statistics aggregator functions to be run
      on the first evaluation batch. These functions ingest the stats returned
      by the model and output a Dict[str, image/scalar] that will be logged.
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.

  Raises:
    RuntimeError: If a training metric is NaN or inf.

  Returns:
    Training state.
  """
  rng, model_rng = jax.random.split(rng)
  input_shape = tuple(train_ds.element_spec["image"].shape[1:])
  model, init_params, init_state = create_model(module, input_shape, model_rng)
  parameter_overview.log_parameter_overview(model.params)

  # Load pretrained model parameters and state. Ignore the step and the
  # optimizer state in the checkpoint.
  pretrained_path = config.get("pretrained_checkpoint", "")
  if pretrained_path:
    logging.info("Load pretrained weights from '%s'", pretrained_path)
    state_dict = checkpoint.load_state_dict(pretrained_path)
    flatten_model_params = utils.flatten_dict(state_dict["model_params"],
                                              sep="/")
    model_state = state_dict["model_state"]

    # A prefix can be used to replace only a subpart of the network (e.g. the
    # encoder). Prepend the prefix (if any) to model parameters and states.
    prefix = config.get("pretrained_prefix", "")
    if prefix:
      flatten_model_params = utils.add_prefix_to_dict_keys(
          flatten_model_params, f"{prefix}/")
      model_state = utils.add_prefix_to_dict_keys(model_state, f"/{prefix}")

    # Merge the params/state from the checkpoint into the initial params/state.
    flatten_init_params = utils.flatten_dict(init_params, sep="/")
    flatten_init_params, ignored_params = utils.override_dict(
        flatten_init_params, flatten_model_params)
    init_params = utils.unflatten_dict(flatten_init_params, delimiter="/")
    init_state, _ = utils.override_dict(init_state, model_state)

    if ignored_params:
      logging.warning("%d/%d parameters from the pretrained checkpoint "
                      "were ignored: %s", len(ignored_params),
                      len(flatten_init_params), ignored_params)

  optimizer_state = optimizer.init(init_params)

  state = TrainState(
      step=1,
      model_params=init_params,
      model_state=init_state,
      optimizer_state=optimizer_state)  # type: ignore
  # Do not keep a copy of the initial model.
  del init_params, init_state, optimizer_state

  train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
  checkpoint_dir = os.path.join(workdir, "checkpoints")
  ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
  state = ckpt.restore_or_initialize(state)
  initial_step = int(state.step)

  # Replicate our parameters.
  state = flax.jax_utils.replicate(state)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)
  step_timer = utils.StepTimer(
      batch_size=config.batch_size, initial_step=initial_step)

  # Write config to the summary files. This makes the hyperparameters available
  # in TensorBoard and makes it easier to compare runs.
  if initial_step == 1:
    writer.write_hparams(utils.flatten_dict(config.to_dict()))

  # Generate per-device PRNG keys for the training loop.
  rng, train_rng = jax.random.split(rng)
  train_rngs = jax.random.split(train_rng, jax.local_device_count())

  # Generate per-device PRNG keys for model evaluation.
  rng, eval_rng = jax.random.split(rng)
  eval_rngs = jax.random.split(eval_rng, jax.local_device_count())

  logging.info("Starting training loop at step %d.", initial_step)
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  train_metrics = utils.Means()

  do_eval_only = config.get("do_eval_only", False)
  if do_eval_only:
    config.num_train_steps = 1

  debug_enabled = config.get("debug", False)
  previous_grads = grads = None
  previous_updates = updates = None
  previous_state = None

  for step in range(initial_step, config.num_train_steps + 1):
    is_last_step = step == config.num_train_steps
    if debug_enabled:
      previous_grads = grads
      previous_updates = updates
      previous_state = state

    # Skip training when only running eval.
    if not do_eval_only:
      # Use ._numpy() to avoid a copy.
      batch = jax.tree_map(lambda x: x._numpy(), next(train_iter))  # pylint: disable=protected-access
      state, grads, updates, metrics, training_stats, train_rngs = train_step(
          state, batch, module, loss_fn, optimizer, train_metrics_dict,
          train_rngs)

      train_metrics.append(flax.jax_utils.unreplicate(metrics))

      # Update topk temperature with linearly decreasing schedule if enabled.
      if (config.get("linear_decrease_perturbed_sigma", False) and
          config.get("selection_method", "") == "perturbed-topk"):

        model_state = state.model_state.as_dict()
        if "/PatchNet_0" in model_state:
          net_str = "/PatchNet_0"
        else:
          net_str = "/"

        progress = step / config.num_train_steps
        sigma_multiplier = 1. - progress
        previous_mult = model_state[net_str]["sigma_mutiplier"]
        sigma_multiplier = sigma_multiplier + jnp.zeros_like(previous_mult)
        model_state[net_str]["sigma_mutiplier"] = sigma_multiplier
        state = state.replace(model_state=nn.Collection(model_state))

      if debug_enabled:
        if utils.has_any_inf_or_nan(metrics):
          # Save checkpoints of the current and previous state for inspection.
          if previous_state:
            ckpt.save(flax.jax_utils.unreplicate(previous_state))
          ckpt.save(flax.jax_utils.unreplicate(state))

          # Log gradients and updates.
          if previous_grads or previous_updates:
            write_gradient_histogram(writer, step, grads=previous_grads,
                                     updates=previous_updates)
          write_gradient_histogram(writer, step + 1, grads=grads,
                                   updates=updates)

          raise RuntimeError("A training metric took an invalid value: "
                             f"{metrics}.")

      logging.log_first_n(logging.INFO, "Finished training step %d.", 3, step)
      report_progress(step)

      if step % config.log_loss_every_steps == 0 or is_last_step:
        results = train_metrics.result()
        writer.write_scalars(step, results)
        writer.write_scalars(step, step_timer.get_and_reset(step))
        if utils.has_any_inf_or_nan(results):
          raise ValueError("A training metric took an invalid value.")
        train_metrics.reset()

    if step % config.checkpoint_every_steps == 0 or is_last_step:
      with step_timer.paused():
        ckpt.save(flax.jax_utils.unreplicate(state))

    # Evaluation.
    if step % config.eval_every_steps == 0 or is_last_step:
      with step_timer.paused():
        eval_metrics, first_batch_stats, eval_rngs = evaluate(
            state, module, eval_ds, eval_metrics_dict, eval_rngs)

      if jax.host_id() == 0:
        log_histograms = config.get("log_histograms", False)
        log_images = config.get("log_images", True)
        # Log the last gradients and updates histograms.
        if not do_eval_only:
          write_stats_results(writer, step, training_stats, stats_aggregators,
                              prefix="train/", log_images=log_images)
          if log_histograms:
            write_gradient_histogram(writer, step, grads=grads,
                                     updates=updates)

        write_stats_results(writer, step, first_batch_stats,
                            stats_aggregators, prefix="eval/",
                            log_images=log_images)

        # Write patch representation histograms.
        if (log_histograms and first_batch_stats and
            "patch_representations" in first_batch_stats):
          patch_representations = first_batch_stats["patch_representations"]
          writer.write_histograms(
              step, {"patch_representations": patch_representations})

        if eval_metrics:
          writer.write_scalars(step, eval_metrics)

      writer.flush()

  return state
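
# A minimal sketch (under the assumption that parameters are plain nested
# Python dicts) of the flatten -> override -> unflatten pattern that
# `training_loop` uses to merge a pretrained checkpoint into the freshly
# initialized parameters. The helpers below only mirror the behaviour of the
# `utils.flatten_dict` / `utils.override_dict` calls referenced above for
# illustration; they are not the actual implementations.
def _flatten_dict(d, sep="/", parent=""):
  """Flattens a nested dict into {'a/b': value, ...}."""
  out = {}
  for k, v in d.items():
    key = f"{parent}{sep}{k}" if parent else str(k)
    if isinstance(v, dict):
      out.update(_flatten_dict(v, sep=sep, parent=key))
    else:
      out[key] = v
  return out


def _override_dict(target, source):
  """Overrides matching keys in `target`; returns the keys with no match."""
  merged = dict(target)
  ignored = []
  for k, v in source.items():
    if k in merged:
      merged[k] = v
    else:
      ignored.append(k)
  return merged, ignored


# Example: only the encoder weights are replaced; 'decoder/w' from the
# checkpoint has no counterpart in the model and is reported as ignored.
# init = _flatten_dict({"encoder": {"w": 0.0}, "head": {"w": 1.0}})
# merged, ignored = _override_dict(init, {"encoder/w": 0.5, "decoder/w": 2.0})
# -> merged == {"encoder/w": 0.5, "head/w": 1.0}, ignored == ["decoder/w"]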