# NOTE: the imports below are inferred from the identifiers used in this
# excerpt; the project-local modules (all_models, optimizers, metrics,
# pipeline_utils) and the helpers defined alongside (create_flax_module,
# restore_checkpoint, TrainState) are assumed to be available.
import functools

from absl import logging
from flax import jax_utils
from flax import nn
from flax.training import checkpoints
import jax
import jax.numpy as jnp


def load_model(rng,
               model_config,
               model_ckpt,
               task,
               load_full_train_state=True,
               checkpoint_step=None):
  """Sets up a train state and loads it from the given checkpoint path.

  Args:
    rng: jnp.ndarray; JAX PRNG key.
    model_config: configdict; Hparams of the model.
    model_ckpt: str; Path to the model checkpoint.
    task: Task; Task on which the model will be applied.
    load_full_train_state: bool; Whether to load the full TrainState or just
      the model and model_state.
    checkpoint_step: int; Checkpoint step to load (if None, loads the most
      recent checkpoint).

  Returns:
    TrainState if load_full_train_state else (model, model_state).
  """
  teacher_cls = all_models.get_model_class(model_config.get('model_name'))
  model_config.output_dim = task.task_params.output_dim
  flax_module, _ = teacher_cls.build_flax_module(model_config)

  # Initialize the flax model.
  rng, dropout_rng = jax.random.split(rng)
  (flax_model, init_model_state, _) = create_flax_module(
      flax_module, task.dataset.meta_data['input_shape'], model_config,
      dropout_rng, task.dataset.meta_data.get('input_dtype', jnp.float32))

  if load_full_train_state:
    # Create the train state.
    rng, teacher_rng = jax.random.split(rng)
    train_state = TrainState(
        global_step=0,
        optimizer=optimizers.get_optimizer(model_config).create(flax_model),
        model_state=init_model_state,
        rng=teacher_rng)

    # Load from the checkpoint if one is specified.
    if model_ckpt:
      train_state, start_step = restore_checkpoint(model_ckpt, train_state,
                                                   checkpoint_step)
      logging.info('Loading model checkpoint at step %d', start_step)
    return train_state

  if model_ckpt:
    model, model_state = checkpoints.restore_checkpoint(
        model_ckpt, (flax_model, init_model_state), checkpoint_step)
    return model, model_state

  # No checkpoint given: fall back to the freshly initialized model so the
  # documented (model, model_state) contract always holds.
  return flax_model, init_model_state
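# Usage sketch (not part of the original module): loading a pretrained
# teacher, e.g. for distillation. `teacher_config`, the checkpoint path, and
# `task` are hypothetical placeholders; `task` is assumed to be a fully
# constructed Task whose dataset meta_data is populated.
#
#   rng = jax.random.PRNGKey(0)
#   teacher_state = load_model(
#       rng=rng,
#       model_config=teacher_config,
#       model_ckpt='/path/to/teacher_ckpt',
#       task=task,
#       load_full_train_state=True)
#   # Teacher parameters, as stored in the TrainState returned above.
#   teacher_params = teacher_state.optimizer.target.params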
def maybe_reset_train_state(self):
  """Optionally re-initializes the model parameters and/or the optimizer."""
  optimizer = jax_utils.unreplicate(self.train_state.optimizer)

  if self.hparams.get('reinitialize_params_at_each_step', False):
    # Draw fresh parameters for the existing module definition.
    (flax_model, _, _) = pipeline_utils.create_flax_module(
        optimizer.target.module, self.task.dataset.meta_data['input_shape'],
        self.hparams, nn.make_rng(),
        self.task.dataset.meta_data.get('input_dtype', jnp.float32))
  else:
    flax_model = optimizer.target

  # Reset the optimizer.
  if self.hparams.get('reinitialize_optimizer_at_each_step', False):
    optimizer = optimizers.get_optimizer(self.hparams).create(flax_model)
  else:
    optimizer = optimizer.replace(target=flax_model)

  optimizer = jax_utils.replicate(optimizer)
  self.train_state = self.train_state.replace(optimizer=optimizer)
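# Placement sketch (hypothetical; the surrounding trainer loop is not part of
# this excerpt): maybe_reset_train_state is meant to run once per step, so the
# two hparams flags control whether parameters and/or optimizer state are
# re-drawn before each update. `self.train_step`, `total_steps`, and
# `dataset_iterator` are assumed names.
#
#   for step, batch in zip(range(total_steps), dataset_iterator):
#     self.maybe_reset_train_state()
#     self.train_state, train_metrics = self.train_step(self.train_state, batch)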
def set_train_state(self, model_cls, rng):
  """Sets up the train state.

  Args:
    model_cls: Type of the flax module.
    rng: jnp.ndarray; JAX PRNG key.
  """
  # Build the flax module.
  self.hparams.output_dim = self.task.task_params.output_dim
  flax_module, self.hparams = model_cls.build_flax_module(self.hparams)

  # Initialize the flax module.
  rng, dropout_rng = jax.random.split(rng)
  (flax_module, model_state,
   self.num_trainable_params) = pipeline_utils.create_flax_module(
       flax_module, self.task.dataset.meta_data['input_shape'], self.hparams,
       dropout_rng, self.task.dataset.meta_data.get('input_dtype',
                                                    jnp.float32))

  if self.hparams.get('pretrained', None):
    pretrained_config = self.hparams.pretrained.get('config')
    pretrained_checkpoint_path = self.hparams.pretrained.get(
        'checkpoint_path')
    pretrained_checkpoint_step = self.hparams.pretrained.get(
        'checkpoint_step', None)
    rng, new_rng = jax.random.split(rng)

    # Create and load the model from the pretrained path.
    if pretrained_checkpoint_step is not None:
      logging.info('Loading pretrained model at step %d',
                   pretrained_checkpoint_step)
    pretrained_train_state = pipeline_utils.load_model(
        rng=new_rng,
        model_config=pretrained_config,
        model_ckpt=pretrained_checkpoint_path,
        task=self.task,
        load_full_train_state=self.hparams.pretrained.get(
            'full_trainstate_ckpt', True),
        checkpoint_step=pretrained_checkpoint_step)

    if self.hparams.pretrained.get('full_trainstate_ckpt', True):
      pretrained_model = pretrained_train_state.optimizer.target
      pretrained_model_state = pretrained_train_state.model_state
    else:
      (pretrained_model, pretrained_model_state) = pretrained_train_state

    if self.hparams.pretrained.get('only_backbone_pretrained', False):
      # Copy pretrained params into the backbone, leaving the head and any
      # discriminator modules freshly initialized.
      for m_key, m_params in pretrained_model.params.items():
        if m_key not in ['head'] and ('disc' not in m_key):
          logging.info('Loading pretrained params for %s', m_key)
          flax_module.params[m_key] = m_params
        else:
          logging.info('Params of %s not updated!', m_key)

      # Likewise for the model_state: iterate over the freshly initialized
      # state and pull backbone entries from the pretrained state.
      new_state_dict = {}
      for state_key, state_val in model_state.as_dict().items():
        if 'head' not in state_key and ('disc' not in state_key):
          logging.info('Loading pretrained state for %s', state_key)
          new_state_dict[state_key] = pretrained_model_state[state_key]
        else:
          logging.info('State of %s not updated!', state_key)
          new_state_dict[state_key] = state_val
      model_state = nn.Collection(new_state_dict)
    else:
      flax_module = pretrained_model
      model_state = pretrained_model_state

  # Create the optimizer.
  optimizer = optimizers.get_optimizer(self.hparams).create(flax_module)

  # Create the train state.
  rng, train_rng = jax.random.split(rng)
  train_state = pipeline_utils.TrainState(
      global_step=0,
      optimizer=optimizer,
      model_state=model_state,
      rng=train_rng)
  self.start_step = train_state.global_step

  # Reset the gift regularizer's init point.
  if self.hparams.get('gift_factor', None):
    self.task.regularisers = [
        functools.partial(
            metrics.parameter_distance,
            base_params=train_state.optimizer.target.params,
            norm_factor=self.hparams.get('gift_factor'),
            mode='l2')
    ]

  if self.hparams.checkpoint:
    train_state, self.start_step = pipeline_utils.restore_checkpoint(
        self.experiment_dir, train_state)
    logging.info('Loading checkpoint at step %d', self.start_step)

  # Replicate the optimizer, state, and rng across devices.
  self.train_state = jax_utils.replicate(train_state)
  del flax_module  # Do not keep a copy of the initial model.

  # Save the initial state.
  if self.start_step == 0 and self.hparams.checkpoint:
    self.checkpoint(self.train_state, self.start_step)
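# Configuration sketch for the `pretrained` hparams block consumed by
# set_train_state. The keys mirror the .get(...) calls above; the ConfigDict
# container and all values here are hypothetical.
#
#   hparams.pretrained = ml_collections.ConfigDict({
#       'config': pretrained_model_config,  # hparams of the pretrained model
#       'checkpoint_path': '/path/to/pretrained_ckpt',
#       'checkpoint_step': None,  # None loads the most recent checkpoint.
#       'full_trainstate_ckpt': True,  # False restores only (model, state).
#       'only_backbone_pretrained': True,  # Keep head/disc freshly initialized.
#   })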