Example #1
# jax, jnp, logging, and checkpoints are the standard imports this snippet
# needs; all_models, create_flax_module, optimizers, TrainState, and
# restore_checkpoint are project-local helpers from the surrounding codebase.
import jax
import jax.numpy as jnp
from absl import logging
from flax.training import checkpoints


def load_model(rng,
               model_config,
               model_ckpt,
               task,
               load_full_train_state=True,
               checkpoint_step=None):
    """Set up a train state model and loads it from the given checkpoint path.

  Args:
    rng: float; JAX PRNG key.
    model_config: configdict; Hparams of the model.
    model_ckpt: str; Path to model checkpoint.
    task: Task; Task on the which the model will be applied.
    load_full_train_state: bool; Whether to load the full TrainState or just the
      model and model_state.
    checkpoint_step: int; Checkpoint step to load (if None loads the most recent
      checkpoint).

  Returns:
    TrainState if load_full_train_state else (model, model_state).
  """
    teacher_cls = all_models.get_model_class(model_config.get('model_name'))

    model_config.output_dim = task.task_params.output_dim
    flax_module, _ = teacher_cls.build_flax_module(model_config)

    # Initialize flax model.
    rng, dropout_rng = jax.random.split(rng)
    (flax_model, init_model_state, _) = create_flax_module(
        flax_module, task.dataset.meta_data['input_shape'], model_config,
        dropout_rng, task.dataset.meta_data.get('input_dtype', jnp.float32))

    if load_full_train_state:
        # Create train state.
        rng, teacher_rng = jax.random.split(rng)
        train_state = TrainState(global_step=0,
                                 optimizer=optimizers.get_optimizer(
                                     model_config).create(flax_model),
                                 model_state=init_model_state,
                                 rng=teacher_rng)

        # Load from the checkpoint if one is specified.
        if model_ckpt:
            train_state, start_step = restore_checkpoint(
                model_ckpt, train_state, checkpoint_step)
            logging.info('Loading model checkpoint at step %d', start_step)

        return train_state
    elif model_ckpt:
        model, model_state = checkpoints.restore_checkpoint(
            model_ckpt, (flax_model, init_model_state), checkpoint_step)

        return model, model_state
    else:
        # No checkpoint specified: return the freshly initialized model and
        # state instead of falling through and returning None.
        return flax_model, init_model_state
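
A minimal usage sketch for Example #1. The model_config, task, and checkpoint
path below are placeholders standing in for objects the surrounding codebase
provides (a ConfigDict of model hparams and a Task instance):

rng = jax.random.PRNGKey(0)
train_state = load_model(
    rng=rng,
    model_config=model_config,         # ConfigDict with the model hparams
    model_ckpt='/path/to/checkpoint',  # placeholder checkpoint directory
    task=task,                         # a Task instance from the codebase
    load_full_train_state=True,
    checkpoint_step=None)              # None -> load the most recent step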
Example #2

    def maybe_reset_train_state(self):
        """Re-initializes the model params and/or optimizer if configured."""
        optimizer = jax_utils.unreplicate(self.train_state.optimizer)

        if self.hparams.get('reinitilize_params_at_each_step', False):
            # (The hparams key above is spelled as it appears in the configs.)
            # Re-initialize the model parameters from scratch, reusing the
            # module definition of the current optimizer target.
            (flax_model, _, _) = pipeline_utils.create_flax_module(
                optimizer.target.module,
                self.task.dataset.meta_data['input_shape'], self.hparams,
                nn.make_rng(),
                self.task.dataset.meta_data.get('input_dtype', jnp.float32))
        else:
            flax_model = optimizer.target

        # Reset optimizer
        if self.hparams.get('reinitialize_optimizer_at_each_step', False):
            optimizer = optimizers.get_optimizer(
                self.hparams).create(flax_model)
        else:
            optimizer = optimizer.replace(target=flax_model)

        optimizer = jax_utils.replicate(optimizer)
        self.train_state = self.train_state.replace(optimizer=optimizer)
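
Example #2 relies on the unreplicate -> edit -> replicate pattern for mutating
pmap-replicated state on the host. A self-contained sketch of that pattern,
assuming the jax_utils above is flax.jax_utils:

import jax.numpy as jnp
from flax import jax_utils

state = {'step': jnp.zeros([])}
replicated = jax_utils.replicate(state)         # add a leading device axis
host_state = jax_utils.unreplicate(replicated)  # pull one copy to the host
host_state['step'] = host_state['step'] + 1     # edit on the host
replicated = jax_utils.replicate(host_state)    # broadcast back to devices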
Example #3
    def set_train_state(self, model_cls, rng):
        """Set up train state.

    Args:
      model_cls: Type of the flax module.
      rng: Jax PRNG.
    """
        # Build flax_model.
        self.hparams.output_dim = self.task.task_params.output_dim
        flax_module, self.hparams = model_cls.build_flax_module(self.hparams)

        # Initialize flax module.
        rng, dropout_rng = jax.random.split(rng)
        (flax_module, model_state,
         self.num_trainable_params) = pipeline_utils.create_flax_module(
             flax_module, self.task.dataset.meta_data['input_shape'],
             self.hparams, dropout_rng,
             self.task.dataset.meta_data.get('input_dtype', jnp.float32))

        if self.hparams.get('pretrained', None):
            pretrained_config = self.hparams.pretrained.get('config')
            pretrained_checkpoint_path = self.hparams.pretrained.get(
                'checkpoint_path')
            pretrained_checkpoint_step = self.hparams.pretrained.get(
                'checkpoint_step', None)

            rng, new_rng = jax.random.split(rng)
            # Create and load the model from the pretrained checkpoint path.
            if pretrained_checkpoint_step is not None:
                logging.info('Loading pretrained model at step %d',
                             pretrained_checkpoint_step)
            pretrained_train_state = pipeline_utils.load_model(
                rng=new_rng,
                model_config=pretrained_config,
                model_ckpt=pretrained_checkpoint_path,
                task=self.task,
                load_full_train_state=self.hparams.pretrained.get(
                    'full_trainstate_ckpt', True),
                checkpoint_step=pretrained_checkpoint_step)

            if self.hparams.pretrained.get('full_trainstate_ckpt', True):
                pretrained_model = pretrained_train_state.optimizer.target
                pretrained_model_state = pretrained_train_state.model_state
            else:
                (pretrained_model,
                 pretrained_model_state) = pretrained_train_state

            if self.hparams.pretrained.get('only_backbone_pretrained', False):
                # Update backbone params with the pretrained params; the head
                # (and any discriminator) keeps its fresh initialization.
                for m_key, m_params in pretrained_model.params.items():
                    logging.info(m_key)
                    if m_key != 'head' and 'disc' not in m_key:
                        flax_module.params[m_key] = m_params
                    else:
                        logging.info('Not updated!')
                # Update model_state with the pretrained model_state, again
                # keeping the fresh values for head/discriminator entries.
                new_state_dict = {}
                for state_key, state_val in model_state.as_dict().items():
                    logging.info(state_key)
                    if 'head' not in state_key and 'disc' not in state_key:
                        new_state_dict[state_key] = pretrained_model_state[
                            state_key]
                    else:
                        logging.info('Not updated!')
                        new_state_dict[state_key] = state_val
                model_state = nn.Collection(new_state_dict)
            else:
                flax_module = pretrained_model
                model_state = pretrained_model_state

        # Create optimizer.
        optimizer = optimizers.get_optimizer(self.hparams).create(flax_module)

        # Create train state.
        rng, train_rng = jax.random.split(rng)
        train_state = pipeline_utils.TrainState(global_step=0,
                                                optimizer=optimizer,
                                                model_state=model_state,
                                                rng=train_rng)
        self.start_step = train_state.global_step

        # Reset gift regularizer's init point.
        if self.hparams.get('gift_factor', None):
            self.task.regularisers = [
                functools.partial(
                    metrics.parameter_distance,
                    base_params=train_state.optimizer.target.params,
                    norm_factor=self.hparams.get('gift_factor'),
                    mode='l2')
            ]

        if self.hparams.checkpoint:
            train_state, self.start_step = pipeline_utils.restore_checkpoint(
                self.experiment_dir, train_state)
            logging.info('Loading checkpoint at step %d', self.start_step)

        # Replicate the optimizer, state, and rng.
        self.train_state = jax_utils.replicate(train_state)
        del flax_module  # Do not keep a copy of the initial model.

        # Save the initial state.
        if self.start_step == 0 and self.hparams.checkpoint:
            self.checkpoint(self.train_state, self.start_step)
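
All three examples construct and mutate a pipeline_utils.TrainState, whose
definition is not part of this listing. A minimal sketch of such a container,
with the field set inferred from the constructor calls above (the real class
may carry more):

from typing import Any

import flax


@flax.struct.dataclass
class TrainState:
    """Sketch of the train-state container used throughout these examples."""
    global_step: int
    optimizer: Any    # a flax.optim.Optimizer wrapping the model params
    model_state: Any  # mutable collections (e.g. batch norm statistics)
    rng: Any          # JAX PRNG key threaded through training

flax.struct.dataclass provides the .replace(...) method for free, which is
what the train_state.replace(optimizer=optimizer) call in Example #2 assumes.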