Example #1
  def _build_networks_and_optimizer(self):
    self._rng, init_key = jax.random.split(self._rng)

    # We can reuse init_key safely for the action selection key
    # since it is only used for shape inference during initialization.
    self.network_params = self.network_def.init(init_key, self.state, init_key)
    self.network_optimizer = dqn_agent.create_optimizer(self._optimizer_name)
    self.optimizer_state = self.network_optimizer.init(self.network_params)

    # TODO(joshgreaves): Find a way to just copy the critic params
    self.target_params = self.network_params

    # \alpha network
    self.log_alpha = jnp.zeros(1)
    self.alpha_optimizer = dqn_agent.create_optimizer(self._optimizer_name)
    self.alpha_optimizer_state = self.alpha_optimizer.init(self.log_alpha)
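Each `_build_networks_and_optimizer` in these examples follows the same JAX recipe: split the agent's RNG, initialize the Flax network definition with a dummy observation, create an optimizer via `dqn_agent.create_optimizer`, initialize the optimizer state from the freshly created parameters, and point the target parameters at the online ones. The sketch below shows that recipe in isolation; `TinyQNetwork`, the observation shape, and the plain `optax.adam` optimizer are illustrative stand-ins, not Dopamine's actual classes.

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax


class TinyQNetwork(nn.Module):
  """Hypothetical stand-in for the agent's network_def."""
  num_actions: int = 4

  @nn.compact
  def __call__(self, x):
    x = x.reshape(-1)                            # flatten the observation
    x = nn.relu(nn.Dense(32)(x))
    return nn.Dense(self.num_actions)(x)


rng = jax.random.PRNGKey(0)
rng, init_key = jax.random.split(rng)            # 1. split the agent's RNG
state = jnp.zeros((84, 84, 4))                   #    dummy observation, shapes only
network_def = TinyQNetwork()
online_params = network_def.init(init_key, state)      # 2. build the parameters
optimizer = optax.adam(learning_rate=1e-3)       # 3. stand-in for create_optimizer
optimizer_state = optimizer.init(online_params)         # 4. optimizer state
target_params = online_params                    # 5. target starts as the online copy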
Example #2
 def _build_networks_and_optimizer(self):
     self._rng, rng = jax.random.split(self._rng)
     self.online_params = self.network_def.init(rng,
                                                x=self.state,
                                                support=self._support)
     self.optimizer = dqn_agent.create_optimizer(self._optimizer_name)
     self.optimizer_state = self.optimizer.init(self.online_params)
     self.target_network_params = self.online_params
Example #3
 def _build_networks_and_optimizer(self):
     self._rng, rng = jax.random.split(self._rng)
     online_network_params = self.network_def.init(rng,
                                                   x=self.state,
                                                   rng=self._rng)
     optimizer_def = dqn_agent.create_optimizer(self._optimizer_name)
     self.optimizer = optimizer_def.create(online_network_params)
     self.target_network_params = copy.deepcopy(online_network_params)
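Example #3 appears to predate the switch to optax: the `optimizer_def.create(...)` call is the older `flax.optim` pattern, where an optimizer definition wraps the parameters, instead of calling `.init(...)` on the parameters as in the other examples, and the target copy is made explicit with `copy.deepcopy` (which assumes `import copy` at module level). This reading is inferred from the API shape, not stated in the snippet itself.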
Example #4
    def unbundle(self, checkpoint_dir, iteration_number, bundle_dictionary):
        """Restores the agent from a checkpoint.

        Restores the agent's Python objects to those specified in
        bundle_dictionary, and restores the TensorFlow objects to those
        specified in the checkpoint_dir. If the checkpoint_dir does not exist,
        will not reset the agent's state.

        Args:
          checkpoint_dir: str, path to the checkpoint saved.
          iteration_number: int, checkpoint version, used when restoring the
            replay buffer.
          bundle_dictionary: dict, containing additional Python objects owned
            by the agent.

        Returns:
          bool, True if unbundling was successful.
        """
        try:
            # self._replay.load() will throw a NotFoundError if it does not find all
            # the necessary files.
            self._replay.load(checkpoint_dir, iteration_number)
        except tf.errors.NotFoundError:
            if not self.allow_partial_reload:
                # If we don't allow partial reloads, we will return False.
                return False
            logging.warning('Unable to reload replay buffer!')
        if bundle_dictionary is not None:
            self.state = bundle_dictionary['state']
            self.training_steps = bundle_dictionary['training_steps']

            self.network_params = bundle_dictionary['network_params']
            self.network_optimizer = dqn_agent.create_optimizer(
                self._optimizer_name)
            self.optimizer_state = bundle_dictionary['optimizer_state']
            self.target_params = bundle_dictionary['target_params']
            self.log_alpha = bundle_dictionary['log_alpha']
            self.alpha_optimizer = dqn_agent.create_optimizer(
                self._optimizer_name)
            self.alpha_optimizer_state = bundle_dictionary[
                'alpha_optimizer_state']
        elif not self.allow_partial_reload:
            return False
        else:
            logging.warning("Unable to reload the agent's parameters!")
        return True
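The `bundle_dictionary` read here is produced by the agent's companion `bundle_and_checkpoint` method. Below is a minimal sketch of what that method would look like for this agent, packing exactly the keys this `unbundle` consumes; the field names come from the snippet above, but the method body itself is an illustration, not the library's verbatim code.

    def bundle_and_checkpoint(self, checkpoint_dir, iteration_number):
        """Sketch of the companion method that produces bundle_dictionary."""
        if not tf.io.gfile.exists(checkpoint_dir):
            return None
        # The replay buffer is checkpointed separately, mirroring _replay.load().
        self._replay.save(checkpoint_dir, iteration_number)
        return {
            'state': self.state,
            'training_steps': self.training_steps,
            'network_params': self.network_params,
            'optimizer_state': self.optimizer_state,
            'target_params': self.target_params,
            'log_alpha': self.log_alpha,
            'alpha_optimizer_state': self.alpha_optimizer_state,
        }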
Example #5
 def _build_networks_and_optimizer(self):
     self._rng, active_rng, passive_rng = jax.random.split(self._rng, 3)
     # Initialize active networks.
     self.active_online_params = self.network_def.init(active_rng,
                                                       x=self.state)
     self.active_optimizer = dqn_agent.create_optimizer(
         self._optimizer_name)
     self.active_optimizer_state = self.active_optimizer.init(
         self.active_online_params)
     self.active_target_params = self.active_online_params
     # Initialize passive network with the regular network.
     self.passive_online_params = self.network_def.init(passive_rng,
                                                        x=self.state)
     self.passive_optimizer = dqn_agent.create_optimizer(
         self._optimizer_name)
     self.passive_optimizer_state = self.passive_optimizer.init(
         self.passive_online_params)
     self.passive_target_params = self.passive_online_params
Example #6
    def _build_networks_and_optimizer(self):
        rngs = jax.random.split(self._rng, num=5)
        self._rng, encoder_key, network_key, reward_key, dynamics_key = rngs
        self.network_optimizer = dqn_agent.create_optimizer(
            self._optimizer_name)
        self.encoder_optimizer = dqn_agent.create_optimizer(
            self._optimizer_name)
        self.reward_optimizer = dqn_agent.create_optimizer(
            self._optimizer_name)
        self.dynamics_optimizer = dqn_agent.create_optimizer(
            self._optimizer_name)
        self.alpha_optimizer = dqn_agent.create_optimizer(self._optimizer_name)

        # Initialize encoder network.
        self.encoder_params = self.encoder_network_def.init(
            encoder_key, self.state)
        self.encoder_optimizer_state = self.encoder_optimizer.init(
            self.encoder_params)

        # Create a sample latent state for initializing the SAC network.
        sample_z = jnp.zeros_like(
            self.encoder_network_def.apply(self.encoder_params,
                                           self.state).critic_z)
        # We can reuse network_key safely for the action selection key
        # since it is only used for shape inference during initialization.
        self.network_params = self.network_def.init(network_key, sample_z,
                                                    network_key)
        self.optimizer_state = self.network_optimizer.init(self.network_params)

        # Initialize reward and dynamics models.
        self.reward_params = self.reward_model_def.init(reward_key, sample_z)
        self.reward_optimizer_state = self.reward_optimizer.init(
            self.reward_params)
        # Sending a dummy action and key for initialization.
        self.dynamics_params = self.dynamics_model_def.init(
            dynamics_key, sample_z, jnp.zeros(self.action_shape), dynamics_key)
        self.dynamics_optimizer_state = self.dynamics_optimizer.init(
            self.dynamics_params)

        self.encoder_target_params = self.encoder_params
        self.target_params = self.network_params

        # \alpha network
        self.log_alpha = jnp.zeros(1)
        self.alpha_optimizer_state = self.alpha_optimizer.init(self.log_alpha)
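Examples #1 and #6 both keep the SAC entropy temperature as `log_alpha = jnp.zeros(1)` with its own optimizer. Assuming the standard SAC recipe, training the temperature in log space keeps alpha = exp(log_alpha) positive and starts it at 1. The fragment below sketches what one gradient step on that parameter looks like with optax; `target_entropy`, the learning rate, and the sampled `log_pi` value are hypothetical.

import jax
import jax.numpy as jnp
import optax

target_entropy = -1.0            # hypothetical; often -|action_dim| in SAC
log_alpha = jnp.zeros(1)         # alpha = exp(0) = 1 at the start of training
alpha_optimizer = optax.adam(3e-4)
alpha_optimizer_state = alpha_optimizer.init(log_alpha)

def alpha_loss(log_alpha, log_pi):
  # Standard SAC temperature loss: drive policy entropy toward target_entropy.
  return -(jnp.exp(log_alpha) * (log_pi + target_entropy)).mean()

log_pi = jnp.array([-1.3])       # hypothetical log-prob of a sampled action
grads = jax.grad(alpha_loss)(log_alpha, log_pi)
updates, alpha_optimizer_state = alpha_optimizer.update(grads, alpha_optimizer_state)
log_alpha = optax.apply_updates(log_alpha, updates)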
Example #7
 def _build_networks_and_optimizer(self):
     self._rng, rng = jax.random.split(self._rng)
     self.online_params = self.network_def.init(
         rng,
         x=self.state,
         num_quantiles=self.num_tau_samples,
         rng=self._rng)
     self.optimizer = dqn_agent.create_optimizer(self._optimizer_name)
     self.optimizer_state = self.optimizer.init(self.online_params)
     self.target_network_params = self.online_params
def reload_jax_checkpoint(agent, bundle_dictionary):
    """Reload variables from a fully specified checkpoint."""
    if bundle_dictionary is not None:
        agent.state = bundle_dictionary['state']
        if isinstance(bundle_dictionary['online_params'], core.FrozenDict):
            agent.online_params = bundle_dictionary['online_params']
        else:  # Load pre-linen checkpoint.
            agent.online_params = core.FrozenDict({
                'params':
                flax_checkpoints.convert_pre_linen(
                    bundle_dictionary['online_params']).unfreeze()
            })
        # We recreate the optimizer with the new online weights.
        # pylint: disable=protected-access
        agent.optimizer = dqn_agent.create_optimizer(agent._optimizer_name)
        # pylint: enable=protected-access
        if 'optimizer_state' in bundle_dictionary:
            agent.optimizer_state = bundle_dictionary['optimizer_state']
        else:
            agent.optimizer_state = agent.optimizer.init(agent.online_params)
        logging.info('Done restoring!')
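`reload_jax_checkpoint` factors the parameter-restoring half of `unbundle` out into a reusable helper. A sketch of how an agent's `unbundle` could delegate to it, mirroring the structure of Example #4, is shown below; this wiring is illustrative and assumes the agent exposes the same `_replay` and `allow_partial_reload` attributes.

def unbundle(self, checkpoint_dir, iteration_number, bundle_dictionary):
    """Sketch: restore the replay buffer, then delegate parameter reload."""
    try:
        self._replay.load(checkpoint_dir, iteration_number)
    except tf.errors.NotFoundError:
        if not self.allow_partial_reload:
            return False
        logging.warning('Unable to reload replay buffer!')
    if bundle_dictionary is None and not self.allow_partial_reload:
        return False
    reload_jax_checkpoint(self, bundle_dictionary)
    return True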