# Shared imports for the snippets below.
import chex
import dm_env
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import rlax


def _mc_loss(q_tm1, transitions, rng_key):
  """Calculates Monte-Carlo return loss, i.e. regression towards MC return."""
  del rng_key  # Unused.
  errors = batch_mc_learning(q_tm1.q_values, transitions.a_tm1,
                             transitions.mc_return_tm1)
  loss = jnp.mean(rlax.l2_loss(errors))
  return loss

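# `batch_mc_learning` is not defined in this snippet; a minimal sketch of the
# assumed helper, i.e. the per-transition regression error of the chosen
# action's value against the Monte-Carlo return, batched with jax.vmap:
def _mc_learning(q_values, a_tm1, mc_return_tm1):
  # Error between the MC return target and the Q-value of the taken action.
  return mc_return_tm1 - q_values[a_tm1]

batch_mc_learning = jax.vmap(_mc_learning)
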
def _loss(self, all_params, batch):
  obs_tm1 = batch["observations"]
  a_tm1 = batch["actions"]
  r_t = batch["rewards"]
  discount_t = batch["discounts"]
  obs_t = batch["next_observations"]

  if self._lambda is None:
    # Remove time dim (batch has shape [batch, chunk_size, ...]).
    a_tm1 = a_tm1.flatten()
    r_t = r_t.flatten()
    discount_t = discount_t.flatten()
    obs_tm1 = jnp.reshape(obs_tm1, (-1, obs_tm1.shape[-1]))
    obs_t = jnp.reshape(obs_t, (-1, obs_t.shape[-1]))

  q_tm1 = self._q_net.apply(all_params.online, obs_tm1)
  q_t_val = self._q_net.apply(all_params.target, obs_t)
  q_t_select = self._q_net.apply(all_params.online, obs_t)

  if self._lambda is None:
    batched_loss = jax.vmap(rlax.double_q_learning)
    td_error = batched_loss(q_tm1, a_tm1, r_t, discount_t, q_t_val,
                            q_t_select)
  else:
    batched_loss = jax.vmap(rlax.q_lambda)
    batch_lambda = self._lambda * jnp.ones(r_t.shape)
    td_error = batched_loss(q_tm1, a_tm1, r_t, discount_t, q_t_val,
                            batch_lambda)

  loss = jnp.mean(rlax.l2_loss(td_error))
  info = dict(loss=loss)
  return loss, info

def _loss(self, params, pop_art_state, obs_tm1, a_tm1, r_t, discount_t,
          obs_t):
  """Loss function."""
  indices = jnp.array(0)  # Only one output for normalization.

  # Calculate targets by unnormalizing Q-values output by network.
  norm_q_t = self._network.apply(params, obs_t)
  q_t = rlax.unnormalize(pop_art_state, norm_q_t, indices)
  target_tm1 = r_t + discount_t * jnp.max(q_t)

  # Update PopArt statistics and use them to update the network parameters to
  # POP (preserve outputs precisely). If there were target networks, the
  # parameters for these would also need to be updated.
  final_linear_module_name = "mlp/~/linear_1"
  mutable_params = hk.data_structures.to_mutable_dict(params)
  linear_params = mutable_params[final_linear_module_name]
  popped_linear_params, new_pop_art_state = self._pop_art_update(
      params=linear_params, state=pop_art_state, targets=target_tm1,
      indices=indices)
  mutable_params[final_linear_module_name] = popped_linear_params
  popped_params = hk.data_structures.to_immutable_dict(mutable_params)

  # Normalize target with updated PopArt statistics.
  norm_target_tm1 = rlax.normalize(new_pop_art_state, target_tm1, indices)

  # Calculate parameter update with normalized target and popped parameters.
  norm_q_t = self._network.apply(popped_params, obs_t)
  norm_q_tm1 = self._network.apply(popped_params, obs_tm1)
  td_error = jax.lax.stop_gradient(norm_target_tm1) - norm_q_tm1[a_tm1]
  return rlax.l2_loss(td_error), new_pop_art_state

def loss_fn(online_params, target_params, transitions, weights, rng_key):
  """Calculates loss given network parameters and transitions."""
  _, *apply_keys = jax.random.split(rng_key, 4)
  q_tm1 = network.apply(online_params, apply_keys[0],
                        transitions.s_tm1).q_values
  q_t = network.apply(online_params, apply_keys[1], transitions.s_t).q_values
  q_target_t = network.apply(target_params, apply_keys[2],
                             transitions.s_t).q_values
  td_errors = _batch_double_q_learning(
      q_tm1,
      transitions.a_tm1,
      transitions.r_t,
      transitions.discount_t,
      q_target_t,
      q_t,
  )
  td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                 grad_error_bound)
  losses = rlax.l2_loss(td_errors)
  chex.assert_shape((losses, weights), (self._batch_size,))
  # This is not the same as using a huber loss and multiplying by weights.
  loss = jnp.mean(losses * weights)
  return loss, td_errors

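# `_batch_double_q_learning` and `grad_error_bound` come from the enclosing
# scope; a plausible sketch of their definitions (the bound value is an
# assumption, matching the SARSA snippet below):
_batch_double_q_learning = jax.vmap(rlax.double_q_learning)
grad_error_bound = 1. / 32
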
def loss_fn(online_params, target_params, transitions, rng_key):
  """Calculates loss given network parameters and transitions."""
  _, online_key, target_key, shaping_key = jax.random.split(rng_key, 4)
  q_tm1 = network.apply(online_params, online_key,
                        transitions.s_tm1).q_values
  q_target_t = network.apply(target_params, target_key,
                             transitions.s_t).q_values

  # Compute shaping function F(s, a, s'); the shaped rewards replace the
  # environment rewards in the TD target.
  shaped_rewards = shaping_function(q_target_t, transitions, shaping_key)

  td_errors = _batch_q_learning(
      q_tm1,
      transitions.a_tm1,
      shaped_rewards,
      transitions.discount_t,
      q_target_t,
  )
  td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                 grad_error_bound)
  losses = rlax.l2_loss(td_errors)
  assert losses.shape == (self._batch_size,)
  loss = jnp.mean(losses)
  return loss

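# `_batch_q_learning` and `shaping_function` are also defined outside this
# snippet. The batched learner is presumably the usual vmap construction; the
# shaping function below is a hypothetical stand-in that leaves the rewards
# unshaped, just to make the snippet runnable:
_batch_q_learning = jax.vmap(rlax.q_learning)

def shaping_function(q_target_t, transitions, rng_key):
  del q_target_t, rng_key  # Unused by this no-op sketch.
  return transitions.r_t
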
def _loss(self, online_params, target_params, obs_tm1, a_tm1, r_t,
          discount_t, obs_t):
  q_tm1 = self._network.apply(online_params, obs_tm1)
  q_t_val = self._network.apply(target_params, obs_t)
  q_t_select = self._network.apply(online_params, obs_t)
  batched_loss = jax.vmap(rlax.double_q_learning)
  td_error = batched_loss(q_tm1, a_tm1, r_t, discount_t, q_t_val, q_t_select)
  return jnp.mean(rlax.l2_loss(td_error))

def dqn_learning_loss(net_params, target_params, batch):
  obs_tm1, obs_t, a_tm1, r_t, discount_t = batch
  q_tm1 = network.apply(net_params, obs_tm1)
  q_t_value = network.apply(target_params, obs_t)
  q_t_selector = network.apply(net_params, obs_t)
  # Batched double Q-learning TD errors.
  batched_loss = jax.vmap(rlax.double_q_learning)
  td_error = batched_loss(q_tm1, a_tm1, r_t, discount_t, q_t_value,
                          q_t_selector)
  return jnp.mean(rlax.l2_loss(td_error))

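# A minimal usage sketch for the loss above, assuming an optax optimizer;
# `optimizer` and `update` are illustrative names, not from the source.
import optax

optimizer = optax.adam(1e-4)

def update(net_params, target_params, opt_state, batch):
  # Differentiate the loss w.r.t. the online parameters and apply the update.
  grads = jax.grad(dqn_learning_loss)(net_params, target_params, batch)
  updates, opt_state = optimizer.update(grads, opt_state)
  return optax.apply_updates(net_params, updates), opt_state
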
def _sarsa_loss(q_tm1, q_t, transitions, rng_key):
  """Calculates SARSA loss from network outputs and transitions."""
  del rng_key  # Unused.
  grad_error_bound = 1. / 32
  td_errors = batch_sarsa_learning(q_tm1.q_values, transitions.a_tm1,
                                   transitions.r_t, transitions.discount_t,
                                   q_t.q_values, transitions.a_t)
  td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                 grad_error_bound)
  losses = rlax.l2_loss(td_errors)
  loss = jnp.mean(losses)
  return loss

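# `batch_sarsa_learning` is assumed to be the batched SARSA TD-error helper,
# presumably the usual vmap construction over rlax.sarsa:
batch_sarsa_learning = jax.vmap(rlax.sarsa)
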
def critic_loss(self, critic_params: hk.Params,
                target_critic_params: hk.Params,
                target_actor_params: hk.Params, state: np.ndarray,
                action: np.ndarray, next_state: np.ndarray,
                reward: np.ndarray, not_done: np.ndarray,
                rng: jnp.ndarray) -> jnp.ndarray:
  """TD3 adds clipped Gaussian noise to the policy while training the critic.

  This can be seen as a form of 'Exploration Consciousness'
  (https://arxiv.org/abs/1812.05551) or simply as a regularization scheme.
  As it helps stabilize the critic, we also use it for the DDPG update rule.
  """
  noise = (jax.random.normal(rng, shape=action.shape) *
           self.policy_noise).clip(-self.noise_clip, self.noise_clip)

  # Make sure the noisy action is within the valid bounds.
  next_action = (self.actor.apply(target_actor_params, next_state) +
                 noise).clip(-self.max_action, self.max_action)

  next_q_1, next_q_2 = self.critic.apply(
      target_critic_params, jnp.concatenate((next_state, next_action), 1))
  if self.td3_update:
    next_q = jax.lax.min(next_q_1, next_q_2)
  else:
    # Since the actor uses Q_1 for training, setting this as the target for
    # the critic updates is sufficient to obtain an equivalent update.
    next_q = next_q_1

  # Stop the gradient from flowing through the target critic; this is also
  # more efficient computationally.
  target_q = jax.lax.stop_gradient(reward + self.discount * next_q * not_done)
  q_1, q_2 = self.critic.apply(critic_params,
                               jnp.concatenate((state, action), 1))
  return jnp.mean(rlax.l2_loss(q_1, target_q) + rlax.l2_loss(q_2, target_q))

def loss_fn(online_params, target_params, transitions, rng_key):
  """Calculates loss given network parameters and transitions."""
  _, online_key, target_key, shaping_key = jax.random.split(rng_key, 4)
  q_tm1 = network.apply(online_params, online_key,
                        transitions.s_tm1).multi_head_output
  q_target_t = network.apply(target_params, target_key,
                             transitions.s_t).multi_head_output

  # Broadcast the head mask from [batch, num_heads] to
  # [batch, num_heads, num_actions], then zero out the Q-values of inactive
  # heads.
  mask = jnp.einsum('ij,k->ijk', transitions.mask_t,
                    jnp.ones(q_tm1.shape[-1]))
  masked_q = jnp.multiply(mask, q_tm1)
  masked_q_target = jnp.multiply(mask, q_target_t)

  # Flatten the heads into the batch dimension.
  flattened_q = jnp.reshape(masked_q, (-1, masked_q.shape[-1]))
  flattened_q_target = jnp.reshape(masked_q_target,
                                   (-1, masked_q_target.shape[-1]))

  # Compute shaping function F(s, a, s').
  shaped_rewards = shaping_function(q_target_t, transitions, shaping_key)

  repeated_actions = jnp.repeat(transitions.a_tm1, num_heads)
  repeated_rewards = jnp.repeat(shaped_rewards, num_heads)
  repeated_discounts = jnp.repeat(transitions.discount_t, num_heads)

  td_errors = _batch_q_learning(
      flattened_q,
      repeated_actions,
      repeated_rewards,
      repeated_discounts,
      flattened_q_target,
  )
  td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                 grad_error_bound)
  losses = rlax.l2_loss(td_errors)
  assert losses.shape == (self._batch_size * num_heads,)
  loss = jnp.mean(losses)
  return loss

def loss_fn(online_params, target_params, transitions, rng_key):
  """Calculates loss given network parameters and transitions."""
  _, online_key, target_key = jax.random.split(rng_key, 3)
  q_tm1 = network.apply(online_params, online_key,
                        transitions.s_tm1).q_values
  q_target_t = network.apply(target_params, target_key,
                             transitions.s_t).q_values
  td_errors = _batch_q_learning(
      q_tm1,
      transitions.a_tm1,
      transitions.r_t,
      transitions.discount_t,
      q_target_t,
  )
  td_errors = rlax.clip_gradient(td_errors, -grad_error_bound,
                                 grad_error_bound)
  losses = rlax.l2_loss(td_errors)
  chex.assert_shape(losses, (self._batch_size,))
  loss = jnp.mean(losses)
  return loss

def _loss(self, params, actions, timesteps):
  """Calculates Q-lambda loss given parameters, actions and timesteps."""
  network_apply_sequence = jax.vmap(self._network.apply, in_axes=(None, 0))
  q = network_apply_sequence(params, timesteps.observation)

  # Use a mask since the sequence could cross episode boundaries.
  mask = jnp.not_equal(timesteps.step_type, int(dm_env.StepType.LAST))
  a_tm1 = actions[1:]
  r_t = timesteps.reward[1:]
  # Discount ought to be zero on a LAST timestep; use the mask to ensure this.
  discount_t = timesteps.discount[1:] * mask[1:]
  q_tm1 = q[:-1]
  q_t = q[1:]
  mask_tm1 = mask[:-1]

  # Mask out TD errors for the last state in an episode.
  td_error_tm1 = mask_tm1 * rlax.q_lambda(
      q_tm1, a_tm1, r_t, discount_t, q_t, lambda_=self._lambda)
  return jnp.sum(rlax.l2_loss(td_error_tm1)) / jnp.sum(mask_tm1)

def loss_fn(online_params, shaped_rewards, flattened_q_target, transitions,
            rng_key):
  """Calculates loss given network parameters and transitions."""
  _, *apply_keys = jax.random.split(rng_key, 3)
  q_tm1 = network.apply(online_params, apply_keys[0],
                        transitions.s_tm1).multi_head_output
  q_t = network.apply(online_params, apply_keys[1],
                      transitions.s_t).multi_head_output

  # Flatten the heads into the batch dimension; the shaped rewards F(s, a, s')
  # and the flattened target Q-values are computed outside and passed in.
  flattened_q = jnp.reshape(q_tm1, (-1, q_tm1.shape[-1]))
  flattened_q_t = jnp.reshape(q_t, (-1, q_t.shape[-1]))

  repeated_actions = jnp.repeat(transitions.a_tm1, num_heads)
  repeated_rewards = jnp.repeat(shaped_rewards, num_heads)
  repeated_discounts = jnp.repeat(transitions.discount_t, num_heads)

  td_errors = _batch_double_q_learning(
      flattened_q,
      repeated_actions,
      repeated_rewards,
      repeated_discounts,
      flattened_q_target,
      flattened_q_t,
  )
  td_errors = rlax.clip_gradient(td_errors, -grad_error_bound / num_heads,
                                 grad_error_bound / num_heads)
  losses = rlax.l2_loss(td_errors)
  assert losses.shape == (self._batch_size * num_heads,)

  # Average only over the active heads.
  mask = jax.lax.stop_gradient(jnp.reshape(transitions.mask_t, (-1,)))
  loss = jnp.sum(mask * losses) / jnp.sum(mask)
  return loss

def loss(online_params, trg_params, obs_tm1, a_tm1, r_t, obs_t, lm_t, term_t,
         discount_t, weights_is):
  """Importance-weighted double-Q loss with gradient clipping.

  `weights_is` are the prioritized-replay importance-sampling weights; the
  gradient of the mean loss is clipped to [-1, 1].
  """
  td_error = double_q_learning_td(online_params, trg_params, obs_tm1, a_tm1,
                                  r_t, obs_t, lm_t, term_t, discount_t)
  return rlax.clip_gradient(
      jnp.mean(weights_is * rlax.l2_loss(td_error)), -1, 1)

def q_learning_loss(net_params, obs_tm1, a_tm1, r_t, discount_t, q_t):
  q_tm1 = network.apply(net_params, obs_tm1)
  td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
  return rlax.l2_loss(td_error)

def _loss(self, params, obs_tm1, a_tm1, r_t, discount_t, obs_t):
  q_tm1 = self._network.apply(params, obs_tm1)
  q_t = self._network.apply(params, obs_t)
  td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
  return rlax.l2_loss(td_error)

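# A minimal usage sketch for simple losses like the two above, assuming the
# `_loss` above is exposed as a pure function
# `q_loss(params, obs_tm1, a_tm1, r_t, discount_t, obs_t)` and using plain
# SGD; `learning_rate` is an illustrative constant.
learning_rate = 1e-3

@jax.jit
def sgd_step(params, obs_tm1, a_tm1, r_t, discount_t, obs_t):
  grads = jax.grad(q_loss)(params, obs_tm1, a_tm1, r_t, discount_t, obs_t)
  # Gradient descent on every leaf of the parameter pytree.
  return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g,
                                params, grads)
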
def _q_regression_loss(q_tm1, q_tm1_target):
  """Loss for regression of all action values towards targets."""
  errors = q_tm1.q_values - jax.lax.stop_gradient(q_tm1_target.q_values)
  loss = jnp.mean(rlax.l2_loss(errors))
  return loss