def test_convert_checkpoint(self):
  """End-to-end check: a pre-Linen train state restores into a Linen model.

  Builds a checkpoint with the deprecated pre-Linen API (flax.nn), converts
  it with `checkpoints.convert_pre_linen`, and verifies the converted
  optimizer state and batch stats can drive a Linen `Model.apply`.
  """
  inputs = jnp.ones([2, 5, 5, 1])
  rng = jax.random.PRNGKey(0)
  # pre-Linen: init inside `stateful()` so mutable state (e.g. batch stats)
  # is captured in `model_state`.
  with flax.nn.stateful() as model_state:
    y, params = ModelPreLinen.init(rng, inputs)
  pre_linen_optimizer = flax.optim.GradientDescent(0.1).create(params)
  train_state = TrainState(
      optimizer=pre_linen_optimizer, model_state=model_state)
  # Serialize the pre-Linen train state into a plain nested dict, as a
  # checkpoint on disk would be.
  state_dict = flax.serialization.to_state_dict(train_state)
  # Linen: build a fresh model/optimizer, then overwrite its state with the
  # converted pre-Linen checkpoint.
  model = Model()
  variables = model.init(rng, inputs)
  optimizer = flax.optim.GradientDescent(0.1).create(variables['params'])
  optimizer = optimizer.restore_state(
      flax.core.unfreeze(
          checkpoints.convert_pre_linen(state_dict['optimizer'])))
  # Smoke-test that the restored optimizer can take a step (params stand in
  # for a gradient pytree of the same structure).
  optimizer = optimizer.apply_gradient(variables['params'])
  # Convert the pre-Linen mutable state: keys are '/'-joined paths whose
  # first component is dropped, then re-nested and renumbered.
  batch_stats = checkpoints.convert_pre_linen(
      flax.traverse_util.unflatten_dict({
          tuple(k.split('/')[1:]): v
          for k, v in model_state.as_dict().items()
      }))
  # Applying the Linen model with the converted variables must not raise.
  y, updated_state = model.apply(
      dict(params=optimizer.target, batch_stats=batch_stats),
      inputs,
      mutable=['batch_stats'])
  del y, updated_state  # not used.
def unbundle(self, checkpoint_dir, iteration_number, bundle_dictionary):
  """Restores the agent from a checkpoint.

  Restores the agent's Python objects to those specified in
  bundle_dictionary, and restores the TensorFlow objects to those specified
  in the checkpoint_dir. If the checkpoint_dir does not exist, will not reset
  the agent's state.

  Args:
    checkpoint_dir: str, path to the checkpoint saved.
    iteration_number: int, checkpoint version, used when restoring the replay
      buffer.
    bundle_dictionary: dict, containing additional Python objects owned by
      the agent.

  Returns:
    bool, True if unbundling was successful.
  """
  try:
    # load() raises NotFoundError when any required replay file is absent.
    self._replay.load(checkpoint_dir, iteration_number)
  except tf.errors.NotFoundError:
    if not self.allow_partial_reload:
      # Partial reloads are disallowed: a missing replay buffer is fatal.
      return False
    logging.warning('Unable to reload replay buffer!')

  if bundle_dictionary is None:
    if not self.allow_partial_reload:
      return False
    logging.warning("Unable to reload the agent's parameters!")
    return True

  self.state = bundle_dictionary['state']
  self.training_steps = bundle_dictionary['training_steps']

  def _as_linen(tree):
    # Converts a pre-Linen parameter tree and nests it under 'params'.
    return core.FrozenDict(
        {'params': checkpoints.convert_pre_linen(tree).unfreeze()})

  if isinstance(bundle_dictionary['online_params'], core.FrozenDict):
    self.online_params = bundle_dictionary['online_params']
    self.target_network_params = bundle_dictionary['target_params']
  else:
    # Checkpoint was produced by a pre-Linen version of the agent.
    self.online_params = _as_linen(bundle_dictionary['online_params'])
    self.target_network_params = _as_linen(bundle_dictionary['target_params'])

  # Recreate the optimizer so it matches the restored online weights.
  self.optimizer = create_optimizer(self._optimizer_name)
  if 'optimizer_state' in bundle_dictionary:
    self.optimizer_state = bundle_dictionary['optimizer_state']
  else:
    # Older checkpoints carry no optimizer state; initialize it fresh.
    self.optimizer_state = self.optimizer.init(self.online_params)
  return True
def test_convert_pre_linen(self):
  """Checks that pre-Linen submodule suffixes are renumbered from zero."""
  # Numeric suffixes within each sibling group are expected to be remapped
  # to a dense 0..n-1 range, ordered by their original numeric value.
  pre_linen = {
      'mod_0': {
          'submod1_0': {},
          'submod2_1': {},
          'submod1_2': {},
      },
      'mod2_2': {
          'submod2_2_0': {}
      },
      'mod2_11': {
          'submod2_11_0': {}
      },
      'mod2_1': {
          'submod2_1_0': {}
      },
  }
  converted = checkpoints.convert_pre_linen(pre_linen)
  expected = {
      'mod_0': {
          'submod1_0': {},
          'submod1_1': {},
          'submod2_0': {},
      },
      'mod2_0': {
          'submod2_1_0': {}
      },
      'mod2_1': {
          'submod2_2_0': {}
      },
      'mod2_2': {
          'submod2_11_0': {}
      },
  }
  self.assertDictEqual(core.unfreeze(converted), expected)
def load(path):
  """Loads params from a checkpoint previously stored with `save()`."""
  with gfile.GFile(path, 'rb') as f:
    # allow_pickle=False: the checkpoint holds plain arrays, never objects.
    flat_ckpt = np.load(f, allow_pickle=False)
    # Materialize names/arrays while the file is still open.
    names, arrays = zip(*flat_ckpt.items())
  params = checkpoints.convert_pre_linen(recover_tree(names, arrays))
  # Callers expect a plain mutable dict rather than a FrozenDict.
  if isinstance(params, flax.core.FrozenDict):
    return params.unfreeze()
  return params
def reload_jax_checkpoint(agent, bundle_dictionary):
  """Reload variables from a fully specified checkpoint."""
  if bundle_dictionary is not None:
    agent.state = bundle_dictionary['state']
    restored = bundle_dictionary['online_params']
    if isinstance(restored, core.FrozenDict):
      # Already a Linen variable tree; use as-is.
      agent.online_params = restored
    else:
      # Pre-linen checkpoint: convert and nest under a 'params' collection.
      agent.online_params = core.FrozenDict({
          'params': flax_checkpoints.convert_pre_linen(restored).unfreeze()
      })
    # Recreate the optimizer so it matches the restored online weights.
    # pylint: disable=protected-access
    agent.optimizer = dqn_agent.create_optimizer(agent._optimizer_name)
    # pylint: enable=protected-access
    if 'optimizer_state' in bundle_dictionary:
      agent.optimizer_state = bundle_dictionary['optimizer_state']
    else:
      # Older checkpoints carry no optimizer state; initialize it fresh.
      agent.optimizer_state = agent.optimizer.init(agent.online_params)
    logging.info('Done restoring!')