def unbundle(self, checkpoint_dir, iteration_number, bundle_dictionary):
    """Restore the agent's state from a checkpoint.

    The replay buffer is reloaded from `checkpoint_dir` first. The agent's
    Python-side attributes (state, training step counter, online and target
    network parameters, optimizer and optimizer state) are then taken from
    `bundle_dictionary`.

    If `self.allow_partial_reload` is false, a replay buffer that cannot be
    found or a `None` bundle makes the call return False. If it is true, the
    missing piece is skipped with a warning and the call still returns True;
    a usable bundle is applied even when the replay buffer was not reloaded.

    Args:
      checkpoint_dir: str, path to the saved checkpoint.
      iteration_number: int, checkpoint version, forwarded to the replay
        buffer loader.
      bundle_dictionary: dict or None, additional Python objects owned by the
        agent. Must provide 'state', 'training_steps', 'online_params' and
        'target_params'; 'optimizer_state' is optional.

    Returns:
      bool, True if unbundling was successful.
    """
    try:
        # The replay loader raises NotFoundError when any of the files it
        # needs is missing.
        self._replay.load(checkpoint_dir, iteration_number)
    except tf.errors.NotFoundError:
        if not self.allow_partial_reload:
            # Partial reloads are not allowed, so this is a failure.
            return False
        logging.warning('Unable to reload replay buffer!')

    if bundle_dictionary is None:
        if not self.allow_partial_reload:
            return False
        logging.warning("Unable to reload the agent's parameters!")
        return True

    def _from_pre_linen(legacy_params):
        # Wrap converted pre-linen parameters under a 'params' key.
        converted = checkpoints.convert_pre_linen(legacy_params)
        return core.FrozenDict({'params': converted.unfreeze()})

    self.state = bundle_dictionary['state']
    self.training_steps = bundle_dictionary['training_steps']

    saved_online = bundle_dictionary['online_params']
    if isinstance(saved_online, core.FrozenDict):
        self.online_params = saved_online
        self.target_network_params = bundle_dictionary['target_params']
    else:
        # Load pre-linen checkpoint.
        self.online_params = _from_pre_linen(saved_online)
        self.target_network_params = _from_pre_linen(
            bundle_dictionary['target_params'])

    # The optimizer is recreated, then paired with either the saved state or
    # a fresh state initialised from the restored online weights.
    self.optimizer = create_optimizer(self._optimizer_name)
    if 'optimizer_state' not in bundle_dictionary:
        self.optimizer_state = self.optimizer.init(self.online_params)
    else:
        self.optimizer_state = bundle_dictionary['optimizer_state']
    return True
def reload_jax_checkpoint(agent, bundle_dictionary):
    """Reload an agent's variables from a fully specified checkpoint bundle.

    Copies the saved state and online network parameters onto `agent`,
    recreates the optimizer from the agent's `_optimizer_name`, and restores
    the optimizer state (or initialises a fresh one from the restored online
    parameters when the bundle carries none).

    When `bundle_dictionary` is None nothing is restored: the agent is left
    untouched and a warning is logged so the skipped reload is visible
    rather than silent.

    Args:
      agent: agent object to restore into; must expose `_optimizer_name`.
      bundle_dictionary: dict or None. Must provide 'state' and
        'online_params'; 'optimizer_state' is optional.
    """
    if bundle_dictionary is None:
        logging.warning(
            'No checkpoint bundle provided; agent variables were not restored.')
        return

    agent.state = bundle_dictionary['state']

    saved_online = bundle_dictionary['online_params']
    if isinstance(saved_online, core.FrozenDict):
        agent.online_params = saved_online
    else:
        # Load pre-linen checkpoint: convert and wrap under a 'params' key.
        converted = flax_checkpoints.convert_pre_linen(saved_online)
        agent.online_params = core.FrozenDict({'params': converted.unfreeze()})

    # The optimizer is recreated, then paired with either the saved state or
    # a fresh state initialised from the restored online weights.
    # pylint: disable=protected-access
    agent.optimizer = dqn_agent.create_optimizer(agent._optimizer_name)
    # pylint: enable=protected-access
    if 'optimizer_state' in bundle_dictionary:
        agent.optimizer_state = bundle_dictionary['optimizer_state']
    else:
        agent.optimizer_state = agent.optimizer.init(agent.online_params)

    # Only reported once a restore has actually happened.
    logging.info('Done restoring!')