Example #1
    def set_train_state(self, model_cls, rng):
        """Set up train state.

        Args:
          model_cls: Type of the flax module.
          rng: Jax PRNG.
        """
        # Build flax_model.
        self.hparams.output_dim = self.task.task_params.output_dim
        flax_module, self.hparams = model_cls.build_flax_module(self.hparams)

        # Initialize flax module.
        rng, dropout_rng = jax.random.split(rng)
        (flax_module, model_state,
         self.num_trainable_params) = pipeline_utils.create_flax_module(
             flax_module, self.task.dataset.meta_data['input_shape'],
             self.hparams, dropout_rng,
             self.task.dataset.meta_data.get('input_dtype', jnp.float32))

        if self.hparams.get('pretrained', None):
            pretrained_config = self.hparams.pretrained.get('config')
            pretrained_checkpoint_path = self.hparams.pretrained.get(
                'checkpoint_path')
            pretrained_checkpoint_step = self.hparams.pretrained.get(
                'checkpoint_step', None)

            rng, new_rng = jax.random.split(rng)
            # Create and load the model from the pretrained path.
            if pretrained_checkpoint_step is not None:
                logging.info('Load pretrained model at step %d',
                             pretrained_checkpoint_step)
            pretrained_train_state = pipeline_utils.load_model(
                rng=new_rng,
                model_config=pretrained_config,
                model_ckpt=pretrained_checkpoint_path,
                task=self.task,
                load_full_train_state=self.hparams.pretrained.get(
                    'full_trainstate_ckpt', True),
                checkpoint_step=pretrained_checkpoint_step)

            if self.hparams.pretrained.get('full_trainstate_ckpt', True):
                pretrained_model = pretrained_train_state.optimizer.target
                pretrained_model_state = pretrained_train_state.model_state
            else:
                (pretrained_model,
                 pretrained_model_state) = pretrained_train_state

            if self.hparams.pretrained.get('only_backbone_pretrained', False):
                # Update params with pretrained params.
                for m_key, m_params in pretrained_model.params.items():
                    logging.info(m_key)
                    if m_key not in ['head'] and ('disc' not in m_key):
                        flax_module.params[m_key] = m_params
                    else:
                        logging.info('Not updated!')
                # Update model_state with pretrained model_state.
                new_state_dict = {}
                for state_key, state_val in pretrained_model_state.as_dict(
                ).items():
                    logging.info(state_key)
                    if 'head' not in state_key and ('disc' not in state_key):
                        new_state_dict[state_key] = pretrained_model_state[
                            state_key]
                    else:
                        logging.info('Not updated!')
                        new_state_dict[state_key] = state_val
                model_state = nn.Collection(new_state_dict)
            else:
                flax_module = pretrained_model
                model_state = pretrained_model_state

        # Create optimizer.
        optimizer = optimizers.get_optimizer(self.hparams).create(flax_module)

        # Create train state.
        rng, train_rng = jax.random.split(rng)
        train_state = pipeline_utils.TrainState(global_step=0,
                                                optimizer=optimizer,
                                                model_state=model_state,
                                                rng=train_rng)
        self.start_step = train_state.global_step

        # Reset gift regularizer's init point.
        if self.hparams.get('gift_factor', None):
            self.task.regularisers = [
                functools.partial(
                    metrics.parameter_distance,
                    base_params=train_state.optimizer.target.params,
                    norm_factor=self.hparams.get('gift_factor'),
                    mode='l2')
            ]

        if self.hparams.checkpoint:
            train_state, self.start_step = pipeline_utils.restore_checkpoint(
                self.experiment_dir, train_state)
            logging.info('Loading checkpoint at step %d', self.start_step)

        # Replicate the optimizer, state, and rng.
        self.train_state = jax_utils.replicate(train_state)
        del flax_module  # do not keep a copy of the initial model

        # Save the initial state.
        if self.start_step == 0 and self.hparams.checkpoint:
            self.checkpoint(self.train_state, self.start_step)
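A minimal sketch of the selective-restore pattern used above when
'only_backbone_pretrained' is set: pretrained parameters are copied into a
freshly initialised module, while the head (and any discriminator) keeps its
random initialisation. The helper name merge_pretrained, the skip list, and the
toy shapes below are illustrative assumptions, not part of the original code.

import jax.numpy as jnp

def merge_pretrained(init_params, pretrained_params, skip=('head', 'disc')):
    # Copy pretrained entries into the init params, except modules whose key
    # contains one of the skipped substrings (e.g. the task head).
    merged = dict(init_params)
    for key, value in pretrained_params.items():
        if any(s in key for s in skip):
            continue  # keep the freshly initialised head / discriminator
        merged[key] = value
    return merged

init = {'backbone': jnp.zeros((3, 3)), 'head': jnp.zeros((3, 2))}
pretrained = {'backbone': jnp.ones((3, 3)), 'head': jnp.ones((3, 2))}
params = merge_pretrained(init, pretrained)
# params['backbone'] comes from the checkpoint; params['head'] stays at init.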
Example #2
def training_loop(
    *,
    module,
    rng,
    train_ds,
    eval_ds,
    loss_fn,
    optimizer,
    train_metrics_dict,
    eval_metrics_dict,
    stats_aggregators,
    config,
    workdir,
):
  """Runs a training and evaluation loop.

  Args:
    module: The module that should be trained.
    rng: A jax pseudo-random number generator key.
    train_ds: Dataset used for training.
    eval_ds: Dataset used for evaluation.
    loss_fn: Loss function to use for training.
    optimizer: Optax optimizer to use for training.
    train_metrics_dict: Collection of metrics to be collected during training.
    eval_metrics_dict: Collection of metrics to be collected during evaluation.
    stats_aggregators: Dictionary of statistics aggregator functions to be run
      on the first evaluation batch. These functions ingest the stats returned
      by the model and output a Dict[str, image/scalar] that will be logged.
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.

  Raises:
    RuntimeError: If a training metric is NaN or inf.

  Returns:
    Training state.
  """
  rng, model_rng = jax.random.split(rng)
  input_shape = tuple(train_ds.element_spec["image"].shape[1:])
  model, init_params, init_state = create_model(module, input_shape, model_rng)
  parameter_overview.log_parameter_overview(model.params)

  # Load pretrained model parameters and state. Ignore the step and the
  # optimizer state in the checkpoint.
  pretrained_path = config.get("pretrained_checkpoint", "")
  if pretrained_path:
    logging.info("Load pretrained weights from '%s'", pretrained_path)
    state_dict = checkpoint.load_state_dict(pretrained_path)
    flatten_model_params = utils.flatten_dict(state_dict["model_params"],
                                              sep="/")
    model_state = state_dict["model_state"]

    # A prefix can be used to replace only a subpart of the network (e.g. the
    # encoder). Prepend the prefix (if any) to model parameters and states.
    prefix = config.get("pretrained_prefix", "")
    if prefix:
      flatten_model_params = utils.add_prefix_to_dict_keys(
          flatten_model_params, f"{prefix}/")
      model_state = utils.add_prefix_to_dict_keys(
          model_state, f"/{prefix}")

    # Merge the params/state from the checkpoint into the initial params/state.
    flatten_init_params = utils.flatten_dict(init_params, sep="/")
    flatten_init_params, ignored_params = utils.override_dict(
        flatten_init_params, flatten_model_params)
    init_params = utils.unflatten_dict(flatten_init_params, delimiter="/")
    init_state, _ = utils.override_dict(init_state, model_state)

    if ignored_params:
      logging.warning("%d/%d parameters from the pretrained checkpoint "
                      "were ignored: %s", len(ignored_params),
                      len(flatten_init_params), ignored_params)

  optimizer_state = optimizer.init(init_params)

  state = TrainState(
      step=1,
      model_params=init_params,
      model_state=init_state,
      optimizer_state=optimizer_state)  # type: ignore
  # Do not keep a copy of the initial model.
  del init_params, init_state, optimizer_state

  train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
  checkpoint_dir = os.path.join(workdir, "checkpoints")

  ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
  state = ckpt.restore_or_initialize(state)
  initial_step = int(state.step)
  # Replicate our parameters.
  state = flax.jax_utils.replicate(state)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.host_id() > 0)
  step_timer = utils.StepTimer(
      batch_size=config.batch_size, initial_step=initial_step)

  # Write config to the summary files. This makes the hyperparameters available
  # in TensorBoard and makes comparison of runs with tensorboard/ easier.
  if initial_step == 1:
    writer.write_hparams(utils.flatten_dict(config.to_dict()))

  # Generate per-device PRNG keys for the training loop.
  rng, train_rng = jax.random.split(rng)
  train_rngs = jax.random.split(train_rng, jax.local_device_count())

  # Generate per-device PRNG keys for model evaluation.
  rng, eval_rng = jax.random.split(rng)
  eval_rngs = jax.random.split(eval_rng, jax.local_device_count())

  logging.info("Starting training loop at step %d.", initial_step)
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  train_metrics = utils.Means()

  do_eval_only = config.get("do_eval_only", False)
  if do_eval_only:
    config.num_train_steps = 1

  debug_enabled = config.get("debug", False)
  previous_grads = grads = None
  previous_updates = updates = None
  previous_state = None
  for step in range(initial_step, config.num_train_steps + 1):
    is_last_step = step == config.num_train_steps
    if debug_enabled:
      previous_grads = grads
      previous_updates = updates
      previous_state = state

    # Skip training when only running evaluation.
    if not do_eval_only:
      # Use ._numpy() to avoid copy.
      batch = jax.tree_map(lambda x: x._numpy(), next(train_iter))  # pylint: disable=protected-access
      state, grads, updates, metrics, training_stats, train_rngs = train_step(
          state, batch, module, loss_fn, optimizer, train_metrics_dict,
          train_rngs)
      train_metrics.append(flax.jax_utils.unreplicate(metrics))

      # Linearly decrease the perturbed top-k sigma multiplier if enabled.
      if (config.get("linear_decrease_perturbed_sigma", False) and
          config.get("selection_method", "") == "perturbed-topk"):

        model_state = state.model_state.as_dict()

        if "/PatchNet_0" in model_state:
          net_str = "/PatchNet_0"
        else:
          net_str = "/"

        progress = step / config.num_train_steps
        sigma_multiplier = 1. - progress
        previous_mult = model_state[net_str]["sigma_mutiplier"]
        sigma_multiplier = sigma_multiplier + jnp.zeros_like(previous_mult)
        model_state[net_str]["sigma_mutiplier"] = sigma_multiplier
        state = state.replace(model_state=nn.Collection(model_state))

      if debug_enabled:
        if utils.has_any_inf_or_nan(metrics):
          # Save checkpoints of the previous and current state for debugging.
          if previous_state:
            ckpt.save(flax.jax_utils.unreplicate(previous_state))
          ckpt.save(flax.jax_utils.unreplicate(state))

          # Log gradients and updates.
          if previous_grads or previous_updates:
            write_gradient_histogram(writer, step,
                                     grads=previous_grads,
                                     updates=previous_updates)
          write_gradient_histogram(writer, step + 1,
                                   grads=grads, updates=updates)

          raise RuntimeError("A training metric took an invalid value: "
                             f"{metrics}.")

      logging.log_first_n(logging.INFO, "Finished training step %d.", 3, step)
      report_progress(step)

      if step % config.log_loss_every_steps == 0 or is_last_step:
        results = train_metrics.result()
        writer.write_scalars(step, results)
        writer.write_scalars(step, step_timer.get_and_reset(step))
        if utils.has_any_inf_or_nan(results):
          raise ValueError("A training metric took an invalid value.")
        train_metrics.reset()

    if step % config.checkpoint_every_steps == 0 or is_last_step:
      with step_timer.paused():
        ckpt.save(flax.jax_utils.unreplicate(state))

    # Evaluation
    if step % config.eval_every_steps == 0 or is_last_step:
      with step_timer.paused():
        eval_metrics, first_batch_stats, eval_rngs = evaluate(
            state, module, eval_ds, eval_metrics_dict, eval_rngs)

      if jax.host_id() == 0:
        log_histograms = config.get("log_histograms", False)
        log_images = config.get("log_images", True)
        # Log the last gradients and updates histograms.
        if not do_eval_only:
          write_stats_results(writer, step, training_stats, stats_aggregators,
                              prefix="train/", log_images=log_images)
          if log_histograms:
            write_gradient_histogram(writer, step, grads=grads, updates=updates)

        write_stats_results(writer, step, first_batch_stats,
                            stats_aggregators, prefix="eval/",
                            log_images=log_images)

        # Write patch representation histograms.
        if (log_histograms and first_batch_stats and
            "patch_representations" in first_batch_stats):
          patch_representations = first_batch_stats["patch_representations"]
          writer.write_histograms(step, {
              "patch_representations": patch_representations
          })

        if eval_metrics:
          writer.write_scalars(step, eval_metrics)

  writer.flush()
  return state
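Example #2 replicates the train state across local devices and draws one PRNG
key per device before entering the loop. A minimal, self-contained sketch of
that setup, using a toy state dict in place of the real TrainState (the dict
contents and printed shapes are illustrative assumptions; flax.jax_utils.replicate
and jax.random.split are the same calls used above):

import jax
import jax.numpy as jnp
from flax import jax_utils

# Toy train state standing in for the TrainState used in the example.
state = {'params': jnp.zeros((4,)), 'step': jnp.asarray(0)}

# One copy of the state per local device, as done before the training loop.
replicated_state = jax_utils.replicate(state)

# One independent PRNG key per local device, for use inside pmapped steps.
rng = jax.random.PRNGKey(0)
rng, train_rng = jax.random.split(rng)
train_rngs = jax.random.split(train_rng, jax.local_device_count())

print(jax.tree_util.tree_map(lambda x: x.shape, replicated_state))
print(train_rngs.shape)  # (local_device_count, 2)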