Example #1
        def loss_fn(online_params, target_params, transitions, rng_key):
            """Calculates loss given network parameters and transitions."""
            _, online_key, target_key, shaping_key = jax.random.split(
                rng_key, 4)
            q_tm1 = network.apply(online_params, online_key,
                                  transitions.s_tm1).q_values
            q_target_t = network.apply(target_params, target_key,
                                       transitions.s_t).q_values

            # compute shaping function F(s, a, s')
            shaped_rewards = shaping_function(q_target_t, transitions,
                                              shaping_key)
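            # Note: shaped_rewards is computed here but not used below; the
            # TD update in this variant still takes the raw transitions.r_t.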

            td_errors = _batch_q_learning(
                q_tm1,
                transitions.a_tm1,
                transitions.r_t,
                transitions.discount_t,
                q_target_t,
            )
            td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                           grad_error_bound)
            losses = rlax.l2_loss(td_errors)
            assert losses.shape == (self._batch_size, )
            loss = jnp.mean(losses)
            return loss
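The helper _batch_q_learning is a free name in this snippet. A minimal sketch of a compatible definition, assuming it is simply the per-transition Q-learning rule from rlax vectorized over the batch, as is common in similar agents:

import jax
import rlax

# Assumed definition: lift rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t),
# which returns one TD error per transition, over the leading batch axis.
_batch_q_learning = jax.vmap(rlax.q_learning)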
Example #2
 def loss_fn(online_params, target_params, transitions, weights,
             rng_key):
     """Calculates loss given network parameters and transitions."""
     _, *apply_keys = jax.random.split(rng_key, 4)
     q_tm1 = network.apply(online_params, apply_keys[0],
                           transitions.s_tm1).q_values
     q_t = network.apply(online_params, apply_keys[1],
                         transitions.s_t).q_values
     q_target_t = network.apply(target_params, apply_keys[2],
                                transitions.s_t).q_values
     td_errors = _batch_double_q_learning(
         q_tm1,
         transitions.a_tm1,
         transitions.r_t,
         transitions.discount_t,
         q_target_t,
         q_t,
     )
     td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                    grad_error_bound)
     losses = rlax.l2_loss(td_errors)
     chex.assert_shape((losses, weights), (self._batch_size, ))
     # This is not the same as using a huber loss and multiplying by weights.
     loss = jnp.mean(losses * weights)
     return loss, td_errors
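As above, the batched TD helper is not shown. A compatible definition, assuming the single-transition double Q-learning rule from rlax lifted over the batch:

import jax
import rlax

# Assumed definition: rlax.double_q_learning takes
# (q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector); in the call above the
# target network supplies q_t_value and the online network supplies the selector q_t.
_batch_double_q_learning = jax.vmap(rlax.double_q_learning)

The per-transition td_errors returned alongside the loss are what a caller would typically feed back to a prioritized replay buffer to refresh priorities.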
Example #3
def _sarsa_loss(q_tm1, q_t, transitions, rng_key):
    """Calculates SARSA loss from network outputs and transitions."""
    del rng_key  # Unused.
    grad_error_bound = 1. / 32
    td_errors = batch_sarsa_learning(q_tm1.q_values, transitions.a_tm1,
                                     transitions.r_t, transitions.discount_t,
                                     q_t.q_values, transitions.a_t)
    td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                   grad_error_bound)
    losses = rlax.l2_loss(td_errors)
    loss = jnp.mean(losses)
    return loss
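Here batch_sarsa_learning is again a free helper; a compatible definition would vectorize rlax.sarsa, which additionally consumes the next action a_t, over the batch, analogous to the sketches above.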
Example #4
        def loss_fn(online_params, target_params, transitions, rng_key):
            """Calculates loss given network parameters and transitions."""
            _, online_key, target_key, shaping_key = jax.random.split(
                rng_key, 4)
            q_tm1 = network.apply(online_params, online_key,
                                  transitions.s_tm1).multi_head_output
            q_target_t = network.apply(target_params, target_key,
                                       transitions.s_t).multi_head_output

            # batch by num_heads -> batch by num_heads by num_actions
            mask = jnp.einsum('ij,k->ijk', transitions.mask_t,
                              jnp.ones(q_tm1.shape[-1]))

            masked_q = jnp.multiply(mask, q_tm1)
            masked_q_target = jnp.multiply(mask, q_target_t)
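            # Note: masked_q and masked_q_target are not used below; the
            # flattened tensors are built from the unmasked q_tm1 and q_target_t.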

            flattened_q = jnp.reshape(q_tm1, (-1, q_tm1.shape[-1]))
            flattened_q_target = jnp.reshape(q_target_t,
                                             (-1, q_target_t.shape[-1]))

            # compute shaping function F(s, a, s')
            shaped_rewards = shaping_function(q_target_t, transitions,
                                              shaping_key)

            repeated_actions = jnp.repeat(transitions.a_tm1, num_heads)
            repeated_rewards = jnp.repeat(shaped_rewards, num_heads)
            repeated_discounts = jnp.repeat(transitions.discount_t, num_heads)

            td_errors = _batch_q_learning(
                flattened_q,
                repeated_actions,
                repeated_rewards,
                repeated_discounts,
                flattened_q_target,
            )

            td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                           grad_error_bound)
            losses = rlax.l2_loss(td_errors)
            assert losses.shape == (self._batch_size * num_heads, )
            loss = jnp.mean(losses)
            return loss
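Two details of this multi-head variant are easy to misread: the einsum 'ij,k->ijk' only broadcasts the (batch, num_heads) mask across the action dimension, and the flat jnp.repeat keeps each transition's action, reward and discount aligned with its num_heads consecutive rows after the reshape. A small self-contained check of both facts (the shapes are illustrative assumptions):

import jax.numpy as jnp

batch, num_heads, num_actions = 2, 3, 4
mask_t = jnp.arange(batch * num_heads, dtype=jnp.float32).reshape(batch, num_heads)

# 'ij,k->ijk' copies the per-head mask across the action dimension.
mask = jnp.einsum('ij,k->ijk', mask_t, jnp.ones(num_actions))
assert jnp.array_equal(
    mask, jnp.broadcast_to(mask_t[:, :, None], (batch, num_heads, num_actions)))

# Reshaping (batch, num_heads, A) to (batch * num_heads, A) orders rows as
# b * num_heads + h, so repeating per-transition quantities num_heads times
# keeps them aligned with the rows of their own transition.
a_tm1 = jnp.array([0, 2])
assert jnp.array_equal(jnp.repeat(a_tm1, num_heads), jnp.array([0, 0, 0, 2, 2, 2]))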
Example #5
 def loss_fn(online_params, target_params, transitions, rng_key):
     """Calculates loss given network parameters and transitions."""
     _, online_key, target_key = jax.random.split(rng_key, 3)
     q_tm1 = network.apply(online_params, online_key,
                           transitions.s_tm1).q_values
     q_target_t = network.apply(target_params, target_key,
                                transitions.s_t).q_values
     td_errors = _batch_q_learning(
         q_tm1,
         transitions.a_tm1,
         transitions.r_t,
         transitions.discount_t,
         q_target_t,
     )
     td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                    grad_error_bound)
     losses = rlax.l2_loss(td_errors)
     chex.assert_shape(losses, (self._batch_size, ))
     loss = jnp.mean(losses)
     return loss
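For context, a minimal self-contained sketch of how a loss_fn of this shape is usually wrapped with jax.grad and an optax optimizer. The tiny linear Q-network, the toy batch and the hyperparameters below are illustrative assumptions, not part of the example above (the gradient clipping and shape assertion are omitted):

import jax
import jax.numpy as jnp
import optax
import rlax

batch_size, obs_dim, num_actions = 32, 8, 4
_batch_q_learning = jax.vmap(rlax.q_learning)

def q_values(params, obs):
    # Illustrative stand-in for network.apply(...).q_values.
    return obs @ params['w'] + params['b']

def loss_fn(params, target_params, s_tm1, a_tm1, r_t, discount_t, s_t):
    q_tm1 = q_values(params, s_tm1)
    q_target_t = q_values(target_params, s_t)
    td_errors = _batch_q_learning(q_tm1, a_tm1, r_t, discount_t, q_target_t)
    return jnp.mean(rlax.l2_loss(td_errors))

params = {'w': jnp.zeros((obs_dim, num_actions)), 'b': jnp.zeros(num_actions)}
target_params = params
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)

key_tm1, key_t = jax.random.split(jax.random.PRNGKey(0))
s_tm1 = jax.random.normal(key_tm1, (batch_size, obs_dim))
s_t = jax.random.normal(key_t, (batch_size, obs_dim))
a_tm1 = jnp.zeros((batch_size,), dtype=jnp.int32)
r_t = jnp.ones((batch_size,))
discount_t = jnp.full((batch_size,), 0.99)

grads = jax.grad(loss_fn)(params, target_params, s_tm1, a_tm1, r_t, discount_t, s_t)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)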
Example #6
        def loss_fn(online_params, shaped_rewards, flattened_q_target,
                    transitions, rng_key):
            """Calculates loss given network parameters and transitions."""
            _, *apply_keys = jax.random.split(rng_key, 4)
            q_tm1 = network.apply(online_params, apply_keys[0],
                                  transitions.s_tm1).multi_head_output
            q_t = network.apply(online_params, apply_keys[1],
                                transitions.s_t).multi_head_output
            # q_target_t = network.apply(target_params, apply_keys[2],
            # transitions.s_t).multi_head_output

            flattened_q = jnp.reshape(q_tm1, (-1, q_tm1.shape[-1]))
            flattened_q_t = jnp.reshape(q_t, (-1, q_t.shape[-1]))
            # flattened_q_target = jnp.reshape(q_target_t, (-1, q_target_t.shape[-1]))

            # compute shaping function F(s, a, s')
            # shaped_rewards = shaping_function(q_target_t, transitions, apply_keys[2])

            repeated_actions = jnp.repeat(transitions.a_tm1, num_heads)
            repeated_rewards = jnp.repeat(shaped_rewards, num_heads)
            repeated_discounts = jnp.repeat(transitions.discount_t, num_heads)

            td_errors = _batch_double_q_learning(
                flattened_q,
                repeated_actions,
                repeated_rewards,
                repeated_discounts,
                flattened_q_target,
                flattened_q_t,
            )

            td_errors = rlax.clip_gradient(td_errors,
                                           -grad_error_bound / num_heads,
                                           grad_error_bound / num_heads)
            losses = rlax.l2_loss(td_errors)
            assert losses.shape == (self._batch_size * num_heads, )

            mask = jax.lax.stop_gradient(
                jnp.reshape(transitions.mask_t, (-1, )))
            loss = jnp.sum(mask * losses) / jnp.sum(mask)
            return loss
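Note that the final reduction divides by jnp.sum(mask) rather than by the batch size, so only head/transition pairs whose bootstrap mask is set contribute to the loss; the clipping bound is also divided by num_heads in this variant, presumably to keep the overall gradient scale comparable to the single-head losses above.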
Example #7
 def loss(online_params, trg_params, obs_tm1, a_tm1, r_t, obs_t, lm_t, term_t,
          discount_t, weights_is):
     #  idxes = self._sample_proportional(batch_size)
     #  weights = []
     #  p_min = self._it_min.min() / self._it_sum.sum()
     #  max_weight = (p_min * len(self._storage)) ** (-beta)
     #  p_sample = self._it_sum[idxes] / self._it_sum.sum()
     #  weights = (p_sample * len(self._storage)) ** (-beta) / max_weight
     #  weights_is = jnp.power(
     #      priorities * transitions.observation_tm1.shape[0],
     #      -importance_beta)
     #  weights_is = weights_is * priorities
     #  td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
     #  errors = tf_util.huber_loss(td_error)
     #  weighted_error = tf.reduce_mean(importance_weights_ph * errors)
     #  gradients = optimizer.compute_gradients(weighted_error, var_list=q_func_vars)
     td_errors = double_q_learning_td(online_params, trg_params, obs_tm1, a_tm1,
                                      r_t, obs_t, lm_t, term_t, discount_t)
     weighted_loss = jnp.mean(weights_is * rlax.l2_loss(td_errors))
     return rlax.clip_gradient(weighted_loss, -1, 1)
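Unlike the previous examples, which clip the per-transition TD errors before the l2 loss, this variant clips the gradient of the already-reduced scalar loss to [-1, 1] after weighting by the importance-sampling weights weights_is.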