Example #1
    def __call__(
        self,
        network: networks_lib.FeedForwardNetwork,
        params: networks_lib.Params,
        target_params: networks_lib.Params,
        batch: reverb.ReplaySample,
        key: networks_lib.PRNGKey,
    ) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
        """Calculate a loss on a single batch of data."""
        del key
        transitions: types.Transition = batch.data

        # Forward pass.
        q_tm1 = network.apply(params, transitions.observation)
        q_t = network.apply(target_params, transitions.next_observation)

        # Cast and clip rewards.
        d_t = (transitions.discount * self.discount).astype(jnp.float32)
        r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                       self.max_abs_reward).astype(jnp.float32)

        # Compute Q-learning TD-error.
        batch_error = jax.vmap(rlax.q_learning)
        td_error = batch_error(q_tm1, transitions.action, r_t, d_t, q_t)
        batch_loss = rlax.huber_loss(td_error, self.huber_loss_parameter)

        loss = jnp.mean(batch_loss)
        extra = learning_lib.LossExtra(metrics={})
        return loss, extra
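
A loss like this is typically differentiated with respect to the online parameters only and then fed to an optimizer. The following is a minimal sketch of that wiring, assuming loss_fn is an instance of the class above and that network, params, target_params, batch, key, optimizer and opt_state already exist in the surrounding learner:

import jax
import optax

def objective(p):
    # Differentiate with respect to the online params; target params stay fixed.
    loss, extra = loss_fn(network, p, target_params, batch, key)
    return loss, extra

(loss_value, extra), grads = jax.value_and_grad(objective, has_aux=True)(params)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)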
Example #2
        def loss(params: hk.Params, target_params: hk.Params,
                 sample: reverb.ReplaySample):
            o_tm1, a_tm1, r_t, d_t, o_t = sample.data
            keys, probs = sample.info[:2]

            # Forward pass.
            q_tm1 = network.apply(params, o_tm1)
            q_t_value = network.apply(target_params, o_t)
            q_t_selector = network.apply(params, o_t)

            # Cast and clip rewards.
            d_t = (d_t * discount).astype(jnp.float32)
            r_t = jnp.clip(r_t, -max_abs_reward,
                           max_abs_reward).astype(jnp.float32)

            # Compute double Q-learning n-step TD-error.
            batch_error = jax.vmap(rlax.double_q_learning)
            td_error = batch_error(q_tm1, a_tm1, r_t, d_t, q_t_value,
                                   q_t_selector)
            batch_loss = rlax.huber_loss(td_error, huber_loss_parameter)

            # Importance weighting.
            importance_weights = (1. / probs).astype(jnp.float32)
            importance_weights **= importance_sampling_exponent
            importance_weights /= jnp.max(importance_weights)

            # Reweight.
            mean_loss = jnp.mean(importance_weights * batch_loss)  # []

            priorities = jnp.abs(td_error).astype(jnp.float64)

            return mean_loss, (keys, priorities)
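
Because this loss returns (mean_loss, (keys, priorities)), a learner would differentiate it with has_aux=True and then push the fresh priorities back to the replay table. A sketch under those assumptions (replay_client and the table name are hypothetical):

grads, (keys, priorities) = jax.grad(loss, has_aux=True)(
    params, target_params, sample)

# Hypothetical write-back of the per-item priorities to a Reverb table:
# replay_client.mutate_priorities(
#     table='priority_table',
#     updates=dict(zip(keys.tolist(), priorities.tolist())))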
Example #3
    def __call__(
        self,
        network: networks_lib.FeedForwardNetwork,
        params: networks_lib.Params,
        target_params: networks_lib.Params,
        batch: reverb.ReplaySample,
        key: networks_lib.PRNGKey,
    ) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
        """Calculate a loss on a single batch of data."""
        del key
        transitions: types.Transition = batch.data

        # Forward pass.
        q_online_s = network.apply(params, transitions.observation)
        action_one_hot = jax.nn.one_hot(transitions.action,
                                        q_online_s.shape[-1])
        q_online_sa = jnp.sum(action_one_hot * q_online_s, axis=-1)
        q_target_s = network.apply(target_params, transitions.observation)
        q_target_next = network.apply(target_params,
                                      transitions.next_observation)

        # Cast and clip rewards.
        d_t = (transitions.discount * self.discount).astype(jnp.float32)
        r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                       self.max_abs_reward).astype(jnp.float32)

        # Munchausen term : tau * log_pi(a|s)
        munchausen_term = self.entropy_temperature * jax.nn.log_softmax(
            q_target_s / self.entropy_temperature, axis=-1)
        munchausen_term_a = jnp.sum(action_one_hot * munchausen_term, axis=-1)
        munchausen_term_a = jnp.clip(munchausen_term_a,
                                     a_min=self.clip_value_min,
                                     a_max=0.)

        # Soft Bellman operator applied to q
        next_v = self.entropy_temperature * jax.nn.logsumexp(
            q_target_next / self.entropy_temperature, axis=-1)
        target_q = jax.lax.stop_gradient(r_t + self.munchausen_coefficient *
                                         munchausen_term_a + d_t * next_v)

        batch_loss = rlax.huber_loss(target_q - q_online_sa,
                                     self.huber_loss_parameter)
        loss = jnp.mean(batch_loss)

        extra = learning_lib.LossExtra(metrics={})
        return loss, extra
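
For readability, the regression target assembled above can be restated as a standalone function. This is only a sketch that mirrors the computation, with tau standing for entropy_temperature and alpha for munchausen_coefficient:

import jax
import jax.numpy as jnp

def munchausen_target(r_t, d_t, q_target_s, q_target_next, action_one_hot,
                      tau, alpha, clip_value_min):
    # Munchausen bonus: tau * log pi(a|s) under the target network's softmax policy.
    log_pi = jax.nn.log_softmax(q_target_s / tau, axis=-1)
    m_term = jnp.clip(jnp.sum(action_one_hot * tau * log_pi, axis=-1),
                      a_min=clip_value_min, a_max=0.)
    # Soft value of the next state: tau * logsumexp(q / tau).
    next_v = tau * jax.nn.logsumexp(q_target_next / tau, axis=-1)
    return jax.lax.stop_gradient(r_t + alpha * m_term + d_t * next_v)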
Example #4
  def __call__(
      self,
      network: networks_lib.FeedForwardNetwork,
      params: networks_lib.Params,
      target_params: networks_lib.Params,
      batch: reverb.ReplaySample,
      key: networks_lib.PRNGKey,
  ) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
    """Calculate a loss on a single batch of data."""
    transitions: types.Transition = batch.data
    keys, probs, *_ = batch.info

    # Forward pass.
    if self.stochastic_network:
      q_tm1 = network.apply(params, key, transitions.observation)
      q_t_value = network.apply(target_params, key,
                                transitions.next_observation)
      q_t_selector = network.apply(params, key, transitions.next_observation)
    else:
      q_tm1 = network.apply(params, transitions.observation)
      q_t_value = network.apply(target_params, transitions.next_observation)
      q_t_selector = network.apply(params, transitions.next_observation)

    # Cast and clip rewards.
    d_t = (transitions.discount * self.discount).astype(jnp.float32)
    r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                   self.max_abs_reward).astype(jnp.float32)

    # Compute double Q-learning n-step TD-error.
    batch_error = jax.vmap(rlax.double_q_learning)
    td_error = batch_error(q_tm1, transitions.action, r_t, d_t, q_t_value,
                           q_t_selector)
    batch_loss = rlax.huber_loss(td_error, self.huber_loss_parameter)

    # Importance weighting.
    importance_weights = (1. / probs).astype(jnp.float32)
    importance_weights **= self.importance_sampling_exponent
    importance_weights /= jnp.max(importance_weights)

    # Reweight.
    loss = jnp.mean(importance_weights * batch_loss)  # []
    reverb_update = learning_lib.ReverbUpdate(
        keys=keys, priorities=jnp.abs(td_error).astype(jnp.float64))
    extra = learning_lib.LossExtra(metrics={}, reverb_update=reverb_update)
    return loss, extra
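
A common refinement in prioritized-replay setups, not part of the snippet above, is to anneal the importance-sampling exponent towards 1.0 over training instead of keeping it fixed. A hypothetical sketch using an optax schedule:

import optax

# Hypothetical schedule: the exponent grows linearly from 0.6 to 1.0.
importance_sampling_exponent = optax.linear_schedule(
    init_value=0.6, end_value=1.0, transition_steps=1_000_000)

# At learner step t the reweighting above would then use:
#   importance_weights **= importance_sampling_exponent(t)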
Example #5
    def __call__(
        self,
        network: hk.Transformed,
        params: hk.Params,
        target_params: hk.Params,
        batch: reverb.ReplaySample,
        key: jnp.DeviceArray,
    ) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
        """Calculate a loss on a single batch of data."""
        del key
        o_tm1, a_tm1, r_t, d_t, o_t = batch.data
        keys, probs, *_ = batch.info

        # Forward pass.
        q_tm1 = network.apply(params, o_tm1)
        q_t_value = network.apply(target_params, o_t)
        q_t_selector = network.apply(params, o_t)

        # Cast and clip rewards.
        d_t = (d_t * self.discount).astype(jnp.float32)
        r_t = jnp.clip(r_t, -self.max_abs_reward,
                       self.max_abs_reward).astype(jnp.float32)

        # Compute double Q-learning n-step TD-error.
        batch_error = jax.vmap(rlax.double_q_learning)
        td_error = batch_error(q_tm1, a_tm1, r_t, d_t, q_t_value, q_t_selector)
        batch_loss = rlax.huber_loss(td_error, self.huber_loss_parameter)

        # Importance weighting.
        importance_weights = (1. / probs).astype(jnp.float32)
        importance_weights **= self.importance_sampling_exponent
        importance_weights /= jnp.max(importance_weights)

        # Reweight.
        loss = jnp.mean(importance_weights * batch_loss)  # []
        reverb_update = learning_lib.ReverbUpdate(
            keys=keys, priorities=jnp.abs(td_error).astype(jnp.float64))
        extra = learning_lib.LossExtra(metrics={}, reverb_update=reverb_update)
        return loss, extra
Example #6
    def __init__(self,
                 player_id,
                 state_representation_size,
                 num_actions,
                 hidden_layers_sizes=128,
                 replay_buffer_capacity=10000,
                 batch_size=128,
                 replay_buffer_class=ReplayBuffer,
                 learning_rate=0.01,
                 update_target_network_every=1000,
                 learn_every=10,
                 discount_factor=1.0,
                 min_buffer_size_to_learn=1000,
                 epsilon_start=1.0,
                 epsilon_end=0.1,
                 epsilon_decay_duration=int(1e6),
                 optimizer_str="sgd",
                 loss_str="mse",
                 huber_loss_parameter=1.0):
        """Initialize the DQN agent."""

        # This call to locals() is used to store every argument used to initialize
        # the class instance, so it can be copied with no hyperparameter change.
        self._kwargs = locals()

        self.player_id = player_id
        self._num_actions = num_actions
        if isinstance(hidden_layers_sizes, int):
            hidden_layers_sizes = [hidden_layers_sizes]
        self._layer_sizes = hidden_layers_sizes
        self._batch_size = batch_size
        self._update_target_network_every = update_target_network_every
        self._learn_every = learn_every
        self._min_buffer_size_to_learn = min_buffer_size_to_learn
        self._discount_factor = discount_factor
        self.huber_loss_parameter = huber_loss_parameter

        self._epsilon_start = epsilon_start
        self._epsilon_end = epsilon_end
        self._epsilon_decay_duration = epsilon_decay_duration

        # TODO(author6) Allow for optional replay buffer config.
        if not isinstance(replay_buffer_capacity, int):
            raise ValueError("Replay buffer capacity not an integer.")
        self._replay_buffer = replay_buffer_class(replay_buffer_capacity)
        self._prev_timestep = None
        self._prev_action = None

        # Step counter to keep track of learning, eps decay and target network.
        self._step_counter = 0

        # Keep track of the last training loss achieved in an update step.
        self._last_loss_value = None

        # Create the Q-network instances

        def network(x):
            mlp = hk.nets.MLP(self._layer_sizes + [num_actions])
            return mlp(x)

        self.hk_network = hk.without_apply_rng(hk.transform(network))
        self.hk_network_apply = jax.jit(self.hk_network.apply)

        rng = jax.random.PRNGKey(42)
        x = jnp.ones([1, state_representation_size])
        self.params_q_network = self.hk_network.init(rng, x)
        self.params_target_q_network = self.hk_network.init(rng, x)

        if loss_str == "mse":
            self.loss_func = lambda x: jnp.mean(x**2)
        elif loss_str == "huber":
            # pylint: disable=g-long-lambda
            self.loss_func = lambda x: jnp.mean(
                rlax.huber_loss(x, self.huber_loss_parameter))
        else:
            raise ValueError("Not implemented, choose from 'mse', 'huber'.")
        if optimizer_str == "adam":
            opt_init, opt_update = optax.chain(
                optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
                optax.scale(learning_rate))
        elif optimizer_str == "sgd":
            opt_init, opt_update = optax.sgd(learning_rate)
        else:
            raise ValueError("Not implemented, choose from 'adam' and 'sgd'.")
        self._opt_update_fn = self._get_update_func(opt_update)
        self._opt_state = opt_init(self.params_q_network)
        self._loss_and_grad = jax.value_and_grad(self._loss, has_aux=False)
        self._jit_update = jax.jit(self.get_update())
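
A hypothetical construction of this agent is shown below; the enclosing class is assumed to be called DQN here (adjust the name to the actual class in your code base), and every keyword argument appears in the constructor signature above:

agent = DQN(
    player_id=0,
    state_representation_size=64,
    num_actions=4,
    hidden_layers_sizes=[128, 128],
    batch_size=128,
    learning_rate=0.01,
    optimizer_str="adam",
    loss_str="huber")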
Example #7
def loss_fn(target, q_val):
    return huber_loss(target - q_val).mean()
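
For reference, a self-contained Huber loss compatible with the call above could look as follows (delta=1.0 is an assumption; rlax.huber_loss in the earlier examples implements the same idea):

import jax.numpy as jnp

def huber_loss(x, delta: float = 1.0):
    # Quadratic for |x| <= delta, linear beyond it, with matching value and
    # slope at the boundary.
    abs_x = jnp.abs(x)
    quadratic = 0.5 * jnp.square(x)
    linear = delta * (abs_x - 0.5 * delta)
    return jnp.where(abs_x <= delta, quadratic, linear)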