Ejemplo n.º 1
0
    def unbundle(self, checkpoint_dir, iteration_number, bundle_dictionary):
        """Restores the agent from a checkpoint.

        The replay buffer is reloaded through `self._replay.load`, and the
        agent's Python-side attributes are taken from `bundle_dictionary`.
        When either source is unavailable and `self.allow_partial_reload` is
        falsy, nothing further is restored and False is returned. Otherwise a
        warning is logged and restoration continues with what is available.

        Args:
          checkpoint_dir: str, path to the checkpoint saved.
          iteration_number: int, checkpoint version, used when restoring the
            replay buffer.
          bundle_dictionary: dict or None, additional Python objects owned by
            the agent. Keys read: 'state', 'training_steps', 'online_params',
            'target_params' and, optionally, 'optimizer_state'.

        Returns:
          bool, True if unbundling was successful (possibly only partially,
          when partial reloads are allowed).
        """
        # The replay loader raises NotFoundError when it cannot find every
        # file it needs.
        try:
            self._replay.load(checkpoint_dir, iteration_number)
        except tf.errors.NotFoundError:
            if not self.allow_partial_reload:
                return False
            logging.warning('Unable to reload replay buffer!')

        # No Python objects to restore: either fail or keep current state.
        if bundle_dictionary is None:
            if not self.allow_partial_reload:
                return False
            logging.warning("Unable to reload the agent's parameters!")
            return True

        def _from_pre_linen(raw_params):
            # Legacy (pre-Linen) parameters are converted and wrapped in a
            # {'params': ...} FrozenDict.
            converted = checkpoints.convert_pre_linen(raw_params).unfreeze()
            return core.FrozenDict({'params': converted})

        self.state = bundle_dictionary['state']
        self.training_steps = bundle_dictionary['training_steps']
        if isinstance(bundle_dictionary['online_params'], core.FrozenDict):
            self.online_params = bundle_dictionary['online_params']
            self.target_network_params = bundle_dictionary['target_params']
        else:
            self.online_params = _from_pre_linen(
                bundle_dictionary['online_params'])
            self.target_network_params = _from_pre_linen(
                bundle_dictionary['target_params'])

        # Rebuild the optimizer for the restored online weights. Reuse its
        # saved state when the bundle carries one, else initialise it afresh.
        self.optimizer = create_optimizer(self._optimizer_name)
        self.optimizer_state = (
            bundle_dictionary['optimizer_state']
            if 'optimizer_state' in bundle_dictionary
            else self.optimizer.init(self.online_params))
        return True
def reload_jax_checkpoint(agent, bundle_dictionary):
    """Reload variables from a fully specified checkpoint.

    Restores the agent's `state` and online network parameters, restores the
    target network parameters when the bundle contains them, and then
    recreates the optimizer for the restored online weights.

    Args:
      agent: the agent to restore in place. Its `_optimizer_name` attribute is
        passed to `dqn_agent.create_optimizer`.
      bundle_dictionary: dict or None, the Python objects saved with the
        checkpoint. Keys read: 'state', 'online_params' and, optionally,
        'target_params' and 'optimizer_state'. When None, the agent is left
        untouched and a warning is logged.
    """
    if bundle_dictionary is None:
        # Previously a silent no-op; make the skipped restore visible.
        logging.warning('No bundle dictionary given; agent was not restored.')
        return

    def _to_linen_params(raw_params):
        # A FrozenDict is used as-is; anything else is treated as a pre-Linen
        # checkpoint and converted into the {'params': ...} layout.
        if isinstance(raw_params, core.FrozenDict):
            return raw_params
        return core.FrozenDict({
            'params':
            flax_checkpoints.convert_pre_linen(raw_params).unfreeze()
        })

    agent.state = bundle_dictionary['state']
    agent.online_params = _to_linen_params(bundle_dictionary['online_params'])
    # Keep the target network in sync with the restored checkpoint instead of
    # leaving stale weights; bundles without target params are still accepted.
    if 'target_params' in bundle_dictionary:
        agent.target_network_params = _to_linen_params(
            bundle_dictionary['target_params'])
    # We recreate the optimizer with the new online weights.
    # pylint: disable=protected-access
    agent.optimizer = dqn_agent.create_optimizer(agent._optimizer_name)
    # pylint: enable=protected-access
    if 'optimizer_state' in bundle_dictionary:
        agent.optimizer_state = bundle_dictionary['optimizer_state']
    else:
        agent.optimizer_state = agent.optimizer.init(agent.online_params)
    logging.info('Done restoring!')