Example #1
def _mc_loss(q_tm1, transitions, rng_key):
    """Calculates Monte-Carlo return loss, i.e. regression towards MC return."""
    del rng_key  # Unused.
    errors = batch_mc_learning(q_tm1.q_values, transitions.a_tm1,
                               transitions.mc_return_tm1)
    loss = jnp.mean(rlax.l2_loss(errors))
    return loss
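
The batch_mc_learning helper is defined outside this snippet. A minimal sketch of how such a helper could be built (the per-example _mc_learning function below is hypothetical, not from the source) is to vmap a rule that compares the Q-value of the taken action against the Monte-Carlo return:

import jax
import jax.numpy as jnp

# Hypothetical per-example rule: error between the MC return and the Q-value
# of the action actually taken (the sign is irrelevant under an l2 loss).
def _mc_learning(q_tm1, a_tm1, mc_return_tm1):
    return mc_return_tm1 - q_tm1[a_tm1]

batch_mc_learning = jax.vmap(_mc_learning)

# Toy shapes: q_values [batch, num_actions], actions [batch], returns [batch].
errors = batch_mc_learning(jnp.zeros((4, 3)),
                           jnp.zeros((4,), dtype=jnp.int32),
                           jnp.ones((4,)))
assert errors.shape == (4,)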
Example #2
    def _loss(self, all_params, batch):
        obs_tm1 = batch["observations"]
        a_tm1 = batch["actions"]
        r_t = batch["rewards"]
        discount_t = batch["discounts"]
        obs_t = batch["next_observations"]

        if self._lambda is None:
            # remove time dim (batch has shape [batch, chunk_size, ...])
            a_tm1 = a_tm1.flatten()
            r_t = r_t.flatten()
            discount_t = discount_t.flatten()
            obs_tm1 = jnp.reshape(obs_tm1, (-1, obs_tm1.shape[-1]))
            obs_t = jnp.reshape(obs_t, (-1, obs_t.shape[-1]))

        q_tm1 = self._q_net.apply(all_params.online, obs_tm1)
        q_t_val = self._q_net.apply(all_params.target, obs_t)
        q_t_select = self._q_net.apply(all_params.online, obs_t)

        if self._lambda is None:
            batched_loss = jax.vmap(rlax.double_q_learning)
            td_error = batched_loss(q_tm1, a_tm1, r_t, discount_t, q_t_val, q_t_select)
        else:
            batched_loss = jax.vmap(rlax.q_lambda)
            batch_lambda = self._lambda * jnp.ones(r_t.shape)
            td_error = batched_loss(
                q_tm1, a_tm1, r_t, discount_t, q_t_val, batch_lambda
            )

        loss = jnp.mean(rlax.l2_loss(td_error))

        info = dict(loss=loss)
        return loss, info
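
For reference, a batch layout consistent with the reshaping above could look as follows (all sizes are illustrative, not from the source). The one-step branch flattens away the [batch, chunk_size] prefix for rlax.double_q_learning, while the q-lambda branch keeps the time axis so rlax.q_lambda sees whole sequences:

import jax.numpy as jnp

# Illustrative shapes only: every entry carries a [batch, chunk_size] prefix.
batch = dict(
    observations=jnp.zeros((32, 10, 8)),            # [batch, chunk, obs_dim]
    actions=jnp.zeros((32, 10), dtype=jnp.int32),    # [batch, chunk]
    rewards=jnp.zeros((32, 10)),                     # [batch, chunk]
    discounts=jnp.ones((32, 10)),                    # [batch, chunk]
    next_observations=jnp.zeros((32, 10, 8)),        # [batch, chunk, obs_dim]
)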
Example #3
    def _loss(self, params, pop_art_state, obs_tm1, a_tm1, r_t, discount_t,
              obs_t):
        """Loss function."""
        indices = jnp.array(0)  # Only one output for normalization.

        # Calculate targets by unnormalizing Q-values output by network.
        norm_q_t = self._network.apply(params, obs_t)
        q_t = rlax.unnormalize(pop_art_state, norm_q_t, indices)
        target_tm1 = r_t + discount_t * jnp.max(q_t)

        # Update PopArt statistics and use them to update the network parameters to
        # POP (preserve outputs precisely). If there were target networks, the
        # parameters for these would also need to be updated.
        final_linear_module_name = "mlp/~/linear_1"
        mutable_params = hk.data_structures.to_mutable_dict(params)
        linear_params = mutable_params[final_linear_module_name]
        popped_linear_params, new_pop_art_state = self._pop_art_update(
            params=linear_params,
            state=pop_art_state,
            targets=target_tm1,
            indices=indices)
        mutable_params[final_linear_module_name] = popped_linear_params
        popped_params = hk.data_structures.to_immutable_dict(mutable_params)

        # Normalize target with updated PopArt statistics.
        norm_target_tm1 = rlax.normalize(new_pop_art_state, target_tm1,
                                         indices)

        # Calculate parameter update with normalized target and popped parameters.
        norm_q_t = self._network.apply(popped_params, obs_t)
        norm_q_tm1 = self._network.apply(popped_params, obs_tm1)
        td_error = jax.lax.stop_gradient(norm_target_tm1) - norm_q_tm1[a_tm1]
        return rlax.l2_loss(td_error), new_pop_art_state
Example #4
 def loss_fn(online_params, target_params, transitions, weights,
             rng_key):
     """Calculates loss given network parameters and transitions."""
     _, *apply_keys = jax.random.split(rng_key, 4)
     q_tm1 = network.apply(online_params, apply_keys[0],
                           transitions.s_tm1).q_values
     q_t = network.apply(online_params, apply_keys[1],
                         transitions.s_t).q_values
     q_target_t = network.apply(target_params, apply_keys[2],
                                transitions.s_t).q_values
     td_errors = _batch_double_q_learning(
         q_tm1,
         transitions.a_tm1,
         transitions.r_t,
         transitions.discount_t,
         q_target_t,
         q_t,
     )
     td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                    grad_error_bound)
     losses = rlax.l2_loss(td_errors)
     chex.assert_shape((losses, weights), (self._batch_size, ))
     # This is not the same as using a huber loss and multiplying by weights.
     loss = jnp.mean(losses * weights)
     return loss, td_errors
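
The comment above is easy to gloss over: with importance weights, clipping the gradient of the TD error inside an l2 loss is not equivalent to weighting a Huber loss, because the weight scales the gradient before it reaches the clip. A toy check (constants made up) shows the two gradients differ:

import jax
import jax.numpy as jnp
import rlax

delta = 1.0 / 32                      # gradient bound / Huber delta
td = jnp.array([2.0])                 # TD error
w = jnp.array([0.01])                 # importance weight

# Weighted l2 loss of a gradient-clipped TD error: the clip sees w * td.
g_clip = jax.grad(
    lambda x: jnp.sum(w * rlax.l2_loss(rlax.clip_gradient(x, -delta, delta))))(td)
# Weighted Huber loss of the raw TD error: the delta acts on td alone.
g_huber = jax.grad(lambda x: jnp.sum(w * rlax.huber_loss(x, delta)))(td)
print(g_clip, g_huber)                # [0.02] vs. [0.0003125]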
Example #5
        def loss_fn(online_params, target_params, transitions, rng_key):
            """Calculates loss given network parameters and transitions."""
            _, online_key, target_key, shaping_key = jax.random.split(
                rng_key, 4)
            q_tm1 = network.apply(online_params, online_key,
                                  transitions.s_tm1).q_values
            q_target_t = network.apply(target_params, target_key,
                                       transitions.s_t).q_values

            # compute shaping function F(s, a, s')
            shaped_rewards = shaping_function(q_target_t, transitions,
                                              shaping_key)

            td_errors = _batch_q_learning(
                q_tm1,
                transitions.a_tm1,
                transitions.r_t,
                transitions.discount_t,
                q_target_t,
            )
            td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                           grad_error_bound)
            losses = rlax.l2_loss(td_errors)
            assert losses.shape == (self._batch_size, )
            loss = jnp.mean(losses)
            return loss
Example #6
 def _loss(self, online_params, target_params, obs_tm1, a_tm1, r_t,
           discount_t, obs_t):
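     """Double Q-learning loss: mean l2 over the batched TD errors."""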
     q_tm1 = self._network.apply(online_params, obs_tm1)
     q_t_val = self._network.apply(target_params, obs_t)
     q_t_select = self._network.apply(online_params, obs_t)
     batched_loss = jax.vmap(rlax.double_q_learning)
     td_error = batched_loss(q_tm1, a_tm1, r_t, discount_t, q_t_val,
                             q_t_select)
     return jnp.mean(rlax.l2_loss(td_error))
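
Taken out of its class, the pattern above can be reproduced end to end. The following self-contained sketch (network size, batch shapes and optimizer are assumptions, not taken from the source) wires the same double-Q l2 loss into a single optax gradient step:

import haiku as hk
import jax
import jax.numpy as jnp
import optax
import rlax

def q_net(obs):
    return hk.nets.MLP([32, 4])(obs)  # 4 actions (illustrative)

network = hk.without_apply_rng(hk.transform(q_net))

def loss_fn(online_params, target_params, obs_tm1, a_tm1, r_t, discount_t, obs_t):
    q_tm1 = network.apply(online_params, obs_tm1)
    q_t_val = network.apply(target_params, obs_t)
    q_t_select = network.apply(online_params, obs_t)
    td_error = jax.vmap(rlax.double_q_learning)(
        q_tm1, a_tm1, r_t, discount_t, q_t_val, q_t_select)
    return jnp.mean(rlax.l2_loss(td_error))

# Dummy transition batch (shapes are illustrative).
obs_tm1 = jnp.zeros((8, 5))
obs_t = jnp.zeros((8, 5))
a_tm1 = jnp.zeros((8,), dtype=jnp.int32)
r_t = jnp.ones((8,))
discount_t = 0.99 * jnp.ones((8,))

online_params = network.init(jax.random.PRNGKey(0), obs_tm1)
target_params = online_params
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(online_params)

# One gradient step on the online parameters only.
grads = jax.grad(loss_fn)(online_params, target_params,
                          obs_tm1, a_tm1, r_t, discount_t, obs_t)
updates, opt_state = optimizer.update(grads, opt_state)
online_params = optax.apply_updates(online_params, updates)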
Example #7
        def dqn_learning_loss(net_params, target_params, batch):
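            """Double-DQN-style loss: the online net selects the next action,
            the target net evaluates it."""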
            obs_tm1, obs_t, a_tm1, r_t, discount_t = batch
            q_tm1 = network.apply(net_params, obs_tm1)
            q_t_value = network.apply(target_params, obs_t)
            q_t_selector = network.apply(net_params, obs_t)

            td_error = batched_loss(q_tm1, a_tm1, r_t, discount_t, q_t_value,
                                    q_t_selector)
            return jnp.mean(rlax.l2_loss(td_error))
Example #8
def _sarsa_loss(q_tm1, q_t, transitions, rng_key):
    """Calculates SARSA loss from network outputs and transitions."""
    del rng_key  # Unused.
    grad_error_bound = 1. / 32
    td_errors = batch_sarsa_learning(q_tm1.q_values, transitions.a_tm1,
                                     transitions.r_t, transitions.discount_t,
                                     q_t.q_values, transitions.a_t)
    td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                   grad_error_bound)
    losses = rlax.l2_loss(td_errors)
    loss = jnp.mean(losses)
    return loss
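
batch_sarsa_learning is defined elsewhere; given the argument order above it is presumably a batched rlax.sarsa, e.g. (an assumption, not shown in the source):

import jax
import rlax

batch_sarsa_learning = jax.vmap(rlax.sarsa)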
Example #9
    def critic_loss(self, critic_params: hk.Params,
                    target_critic_params: hk.Params,
                    target_actor_params: hk.Params, state: np.ndarray,
                    action: np.ndarray, next_state: np.ndarray,
                    reward: np.ndarray, not_done: np.ndarray,
                    rng: jnp.ndarray) -> jnp.DeviceArray:
        """
            TD3 adds truncated Gaussian noise to the policy while training the critic.
            Can be seen as a form of 'Exploration Consciousness' https://arxiv.org/abs/1812.05551 or simply as a
            regularization scheme.
            As this helps stabilize the critic, we also use this for the DDPG update rule.
        """
        noise = (jax.random.normal(rng, shape=action.shape) *
                 self.policy_noise).clip(-self.noise_clip, self.noise_clip)

        # Make sure the noisy action is within the valid bounds.
        next_action = (self.actor.apply(target_actor_params, next_state) +
                       noise).clip(-self.max_action, self.max_action)

        next_q_1, next_q_2 = self.critic.apply(
            target_critic_params, jnp.concatenate((next_state, next_action),
                                                  1))
        if self.td3_update:
            next_q = jax.lax.min(next_q_1, next_q_2)
        else:
            # Since the actor uses Q_1 for training, setting this as the target for the critic updates is sufficient to
            # obtain an equivalent update.
            next_q = next_q_1
        # Stop gradients from flowing through the target critic; this is also
        # computationally more efficient.
        target_q = jax.lax.stop_gradient(reward +
                                         self.discount * next_q * not_done)
        q_1, q_2 = self.critic.apply(critic_params,
                                     jnp.concatenate((state, action), 1))

        return jnp.mean(
            rlax.l2_loss(q_1, target_q) + rlax.l2_loss(q_2, target_q))
Example #10
        def loss_fn(online_params, target_params, transitions, rng_key):
            """Calculates loss given network parameters and transitions."""
            _, online_key, target_key, shaping_key = jax.random.split(
                rng_key, 4)
            q_tm1 = network.apply(online_params, online_key,
                                  transitions.s_tm1).multi_head_output
            q_target_t = network.apply(target_params, target_key,
                                       transitions.s_t).multi_head_output

            # Broadcast the mask from [batch, num_heads] to
            # [batch, num_heads, num_actions].
            mask = jnp.einsum('ij,k->ijk', transitions.mask_t,
                              jnp.ones(q_tm1.shape[-1]))

            masked_q = jnp.multiply(mask, q_tm1)
            masked_q_target = jnp.multiply(mask, q_target_t)

            flattened_q = jnp.reshape(q_tm1, (-1, q_tm1.shape[-1]))
            flattened_q_target = jnp.reshape(q_target_t,
                                             (-1, q_target_t.shape[-1]))

            # compute shaping function F(s, a, s')
            shaped_rewards = shaping_function(q_target_t, transitions,
                                              shaping_key)

            repeated_actions = jnp.repeat(transitions.a_tm1, num_heads)
            repeated_rewards = jnp.repeat(shaped_rewards, num_heads)
            repeated_discounts = jnp.repeat(transitions.discount_t, num_heads)

            td_errors = _batch_q_learning(
                flattened_q,
                repeated_actions,
                repeated_rewards,
                repeated_discounts,
                flattened_q_target,
            )

            td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                           grad_error_bound)
            losses = rlax.l2_loss(td_errors)
            assert losses.shape == (self._batch_size * num_heads, )
            loss = jnp.mean(losses)
            return loss
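
The correctness of the flattening above hinges on jnp.reshape and jnp.repeat walking the heads in the same order. A small check (sizes made up) confirms that a row-major reshape of [batch, num_heads, num_actions] varies heads fastest, which matches jnp.repeat(..., num_heads) on the per-batch quantities:

import jax.numpy as jnp

batch_size, num_heads, num_actions = 2, 3, 4
q = jnp.arange(batch_size * num_heads * num_actions).reshape(
    (batch_size, num_heads, num_actions))
flat_q = jnp.reshape(q, (-1, num_actions))     # rows: (b0,h0), (b0,h1), ...
a_tm1 = jnp.array([0, 1])                      # one action per batch element
repeated_a = jnp.repeat(a_tm1, num_heads)      # [0, 0, 0, 1, 1, 1]
assert flat_q.shape == (batch_size * num_heads, num_actions)
assert (repeated_a == jnp.array([0, 0, 0, 1, 1, 1])).all()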
Example #11
 def loss_fn(online_params, target_params, transitions, rng_key):
     """Calculates loss given network parameters and transitions."""
     _, online_key, target_key = jax.random.split(rng_key, 3)
     q_tm1 = network.apply(online_params, online_key,
                           transitions.s_tm1).q_values
     q_target_t = network.apply(target_params, target_key,
                                transitions.s_t).q_values
     td_errors = _batch_q_learning(
         q_tm1,
         transitions.a_tm1,
         transitions.r_t,
         transitions.discount_t,
         q_target_t,
     )
     td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                    grad_error_bound)
     losses = rlax.l2_loss(td_errors)
     chex.assert_shape(losses, (self._batch_size, ))
     loss = jnp.mean(losses)
     return loss
Example #12
    def _loss(self, params, actions, timesteps):
        """Calculates Q-lambda loss given parameters, actions and timesteps."""
        network_apply_sequence = jax.vmap(self._network.apply,
                                          in_axes=(None, 0))
        q = network_apply_sequence(params, timesteps.observation)

        # Use a mask since the sequence could cross episode boundaries.
        mask = jnp.not_equal(timesteps.step_type, int(dm_env.StepType.LAST))
        a_tm1 = actions[1:]
        r_t = timesteps.reward[1:]
        # The discount ought to be zero on a LAST timestep; use the mask to ensure this.
        discount_t = timesteps.discount[1:] * mask[1:]
        q_tm1 = q[:-1]
        q_t = q[1:]
        mask_tm1 = mask[:-1]

        # Mask out TD errors for the last state in an episode.
        td_error_tm1 = mask_tm1 * rlax.q_lambda(
            q_tm1, a_tm1, r_t, discount_t, q_t, lambda_=self._lambda)
        return jnp.sum(rlax.l2_loss(td_error_tm1)) / jnp.sum(mask_tm1)
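
Unlike the batched examples elsewhere on this page, rlax.q_lambda here consumes a single un-batched sequence. A toy call (values are illustrative) shows the expected shapes:

import jax.numpy as jnp
import rlax

T, num_actions = 5, 3
td = rlax.q_lambda(
    jnp.zeros((T, num_actions)),           # q_tm1
    jnp.zeros((T,), dtype=jnp.int32),      # a_tm1
    jnp.ones((T,)),                        # r_t
    0.9 * jnp.ones((T,)),                  # discount_t
    jnp.zeros((T, num_actions)),           # q_t
    lambda_=0.8,
)
assert td.shape == (T,)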
Example #13
        def loss_fn(online_params, shaped_rewards, flattened_q_target,
                    transitions, rng_key):
            """Calculates loss given network parameters and transitions."""
            _, *apply_keys = jax.random.split(rng_key, 4)
            q_tm1 = network.apply(online_params, apply_keys[0],
                                  transitions.s_tm1).multi_head_output
            q_t = network.apply(online_params, apply_keys[1],
                                transitions.s_t).multi_head_output
            # q_target_t = network.apply(target_params, apply_keys[2],
            # transitions.s_t).multi_head_output

            flattened_q = jnp.reshape(q_tm1, (-1, q_tm1.shape[-1]))
            flattened_q_t = jnp.reshape(q_t, (-1, q_t.shape[-1]))
            # flattened_q_target = jnp.reshape(q_target_t, (-1, q_target_t.shape[-1]))

            # compute shaping function F(s, a, s')
            # shaped_rewards = shaping_function(q_target_t, transitions, apply_keys[2])

            repeated_actions = jnp.repeat(transitions.a_tm1, num_heads)
            repeated_rewards = jnp.repeat(shaped_rewards, num_heads)
            repeated_discounts = jnp.repeat(transitions.discount_t, num_heads)

            td_errors = _batch_double_q_learning(
                flattened_q,
                repeated_actions,
                repeated_rewards,
                repeated_discounts,
                flattened_q_target,
                flattened_q_t,
            )

            td_errors = rlax.clip_gradient(td_errors,
                                           -grad_error_bound / num_heads,
                                           grad_error_bound / num_heads)
            losses = rlax.l2_loss(td_errors)
            assert losses.shape == (self._batch_size * num_heads, )

            mask = jax.lax.stop_gradient(
                jnp.reshape(transitions.mask_t, (-1, )))
            loss = jnp.sum(mask * losses) / jnp.sum(mask)
            return loss
Example #14
 def loss(online_params, trg_params, obs_tm1, a_tm1, r_t, obs_t, lm_t, term_t, discount_t, weights_is):
     #  idxes = self._sample_proportional(batch_size)
     #  weights = []
     #  p_min = self._it_min.min() / self._it_sum.sum()
     #  max_weight = (p_min * len(self._storage)) ** (-beta)
     #  p_sample = self._it_sum[idxes] / self._it_sum.sum()
     #  weights = (p_sample * len(self._storage)) ** (-beta) / max_weight
     #  weights_is = jnp.power(
     #      priorities * transitions.observation_tm1.shape[0],
     #      -importance_beta)
     #  weights_is = weights_is * priorities
     #  td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
     #  errors = tf_util.huber_loss(td_error)
     #  weighted_error = tf.reduce_mean(importance_weights_ph * errors)
     #  gradients = optimizer.compute_gradients(weighted_error, var_list=q_func_vars)
     return rlax.clip_gradient(
         jnp.mean(weights_is * rlax.l2_loss(
             double_q_learning_td(online_params, trg_params, obs_tm1, a_tm1,
                                  r_t, obs_t, lm_t, term_t, discount_t))),
         -1, 1)
Example #15
 def q_learning_loss(net_params, obs_tm1, a_tm1, r_t, discount_t, q_t):
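     """Per-transition Q-learning l2 loss; next-step values q_t are passed in."""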
     q_tm1 = network.apply(net_params, obs_tm1)
     td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
     return rlax.l2_loss(td_error)
Example #16
 def _loss(self, params, obs_tm1, a_tm1, r_t, discount_t, obs_t):
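     """Per-transition Q-learning l2 loss using a single online network."""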
     q_tm1 = self._network.apply(params, obs_tm1)
     q_t = self._network.apply(params, obs_t)
     td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
     return rlax.l2_loss(td_error)
Example #17
def _q_regression_loss(q_tm1, q_tm1_target):
    """Loss for regression of all action values towards targets."""
    errors = q_tm1.q_values - jax.lax.stop_gradient(q_tm1_target.q_values)
    loss = jnp.mean(rlax.l2_loss(errors))
    return loss
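
All of the examples above share the same convention: rlax.l2_loss returns 0.5 * error**2 elementwise, with the optional second argument acting as the regression target (it defaults to zero). A quick check (values made up):

import jax.numpy as jnp
import rlax

errors = jnp.array([-2.0, 0.0, 1.0])
print(rlax.l2_loss(errors))               # [2.0, 0.0, 0.5]
print(rlax.l2_loss(errors, jnp.ones(3)))  # 0.5 * (errors - targets)**2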