def __call__(
    self,
    network: networks_lib.FeedForwardNetwork,
    params: networks_lib.Params,
    target_params: networks_lib.Params,
    batch: reverb.ReplaySample,
    key: networks_lib.PRNGKey,
) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
  """Calculate a loss on a single batch of data."""
  transitions: types.Transition = batch.data
  keys, probs, *_ = batch.info

  # Forward pass.
  if self.stochastic_network:
    q_tm1 = network.apply(params, key, transitions.observation)
    q_t_value = network.apply(target_params, key,
                              transitions.next_observation)
    q_t_selector = network.apply(params, key, transitions.next_observation)
  else:
    q_tm1 = network.apply(params, transitions.observation)
    q_t_value = network.apply(target_params, transitions.next_observation)
    q_t_selector = network.apply(params, transitions.next_observation)

  # Cast and clip rewards.
  d_t = (transitions.discount * self.discount).astype(jnp.float32)
  r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                 self.max_abs_reward).astype(jnp.float32)

  # Compute double Q-learning n-step TD-error.
  batch_error = jax.vmap(rlax.double_q_learning)
  td_error = batch_error(q_tm1, transitions.action, r_t, d_t, q_t_value,
                         q_t_selector)
  batch_loss = rlax.huber_loss(td_error, self.huber_loss_parameter)

  # Importance weighting.
  importance_weights = (1. / probs).astype(jnp.float32)
  importance_weights **= self.importance_sampling_exponent
  importance_weights /= jnp.max(importance_weights)

  # Reweight.
  loss = jnp.mean(importance_weights * batch_loss)  # []

  reverb_update = learning_lib.ReverbUpdate(
      keys=keys, priorities=jnp.abs(td_error).astype(jnp.float64))
  extra = learning_lib.LossExtra(metrics={}, reverb_update=reverb_update)
  return loss, extra
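
# Illustrative sketch, not part of the loss above: the prioritized-replay
# importance weighting in isolation. The probabilities and exponent are made-up
# values, and `_sketch_importance_weights` is a hypothetical helper (it assumes
# the module's existing `jnp` import and does not exist in the library).
def _sketch_importance_weights():
  probs = jnp.array([0.4, 0.1, 0.3, 0.2])  # stand-in Reverb sampling probabilities
  importance_sampling_exponent = 0.6  # assumed exponent
  # w_i = (1 / p_i) ** exponent, normalised by the max so weights lie in (0, 1].
  weights = (1. / probs) ** importance_sampling_exponent
  weights /= jnp.max(weights)
  return weights  # Per-example Huber losses are scaled by these before the mean.
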
def __call__(
    self,
    network: networks_lib.FeedForwardNetwork,
    params: networks_lib.Params,
    target_params: networks_lib.Params,
    batch: reverb.ReplaySample,
    key: networks_lib.PRNGKey,
) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
  """Calculate a loss on a single batch of data."""
  del key
  transitions: types.Transition = batch.data

  # Forward pass.
  q_online_s = network.apply(params, transitions.observation)
  action_one_hot = jax.nn.one_hot(transitions.action, q_online_s.shape[-1])
  q_online_sa = jnp.sum(action_one_hot * q_online_s, axis=-1)
  q_target_s = network.apply(target_params, transitions.observation)
  q_target_next = network.apply(target_params, transitions.next_observation)

  # Cast and clip rewards.
  d_t = (transitions.discount * self.discount).astype(jnp.float32)
  r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                 self.max_abs_reward).astype(jnp.float32)

  # Munchausen term: tau * log_pi(a|s).
  munchausen_term = self.entropy_temperature * jax.nn.log_softmax(
      q_target_s / self.entropy_temperature, axis=-1)
  munchausen_term_a = jnp.sum(action_one_hot * munchausen_term, axis=-1)
  munchausen_term_a = jnp.clip(munchausen_term_a,
                               a_min=self.clip_value_min,
                               a_max=0.)

  # Soft Bellman operator applied to q.
  next_v = self.entropy_temperature * jax.nn.logsumexp(
      q_target_next / self.entropy_temperature, axis=-1)
  target_q = jax.lax.stop_gradient(
      r_t + self.munchausen_coefficient * munchausen_term_a + d_t * next_v)

  batch_loss = rlax.huber_loss(target_q - q_online_sa,
                               self.huber_loss_parameter)
  loss = jnp.mean(batch_loss)

  extra = learning_lib.LossExtra(metrics={})
  return loss, extra
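
# Illustrative sketch, not part of the loss above: the Munchausen target for a
# single transition, with made-up Q-values and hyperparameters. The helper
# `_sketch_munchausen_target` is hypothetical and assumes the module's existing
# `jax`/`jnp` imports.
def _sketch_munchausen_target():
  tau = 0.03  # assumed entropy temperature
  alpha = 0.9  # assumed Munchausen coefficient
  clip_value_min = -1.
  r, gamma, action = 1., 0.99, 0
  q_target_s = jnp.array([1., 0.5, -0.2])  # target-network Q(s, .)
  q_target_next = jnp.array([0.8, 1.2, 0.1])  # target-network Q(s', .)

  # Munchausen bonus: tau * log_pi(a|s), clipped to [clip_value_min, 0].
  log_pi = jax.nn.log_softmax(q_target_s / tau)
  munchausen_term = jnp.clip(tau * log_pi[action], clip_value_min, 0.)

  # Soft value of s': tau * logsumexp(Q(s', .) / tau).
  next_v = tau * jax.nn.logsumexp(q_target_next / tau)
  return r + alpha * munchausen_term + gamma * next_v
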
def __call__(
    self,
    network: networks_lib.FeedForwardNetwork,
    params: networks_lib.Params,
    target_params: networks_lib.Params,
    batch: reverb.ReplaySample,
    key: networks_lib.PRNGKey,
) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
  """Calculate a loss on a single batch of data."""
  del key
  transitions: types.Transition = batch.data
  keys, probs, *_ = batch.info

  # Forward pass.
  _, logits_tm1, atoms_tm1 = network.apply(params, transitions.observation)
  _, logits_t, atoms_t = network.apply(target_params,
                                       transitions.next_observation)
  q_t_selector, _, _ = network.apply(params, transitions.next_observation)

  # Cast and clip rewards.
  d_t = (transitions.discount * self.discount).astype(jnp.float32)
  r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                 self.max_abs_reward).astype(jnp.float32)

  # Compute categorical double Q-learning loss.
  batch_loss_fn = jax.vmap(
      rlax.categorical_double_q_learning,
      in_axes=(None, 0, 0, 0, 0, None, 0, 0))
  batch_loss = batch_loss_fn(atoms_tm1, logits_tm1, transitions.action, r_t,
                             d_t, atoms_t, logits_t, q_t_selector)

  # Importance weighting.
  importance_weights = (1. / probs).astype(jnp.float32)
  importance_weights **= self.importance_sampling_exponent
  importance_weights /= jnp.max(importance_weights)

  # Reweight.
  loss = jnp.mean(importance_weights * batch_loss)  # []

  reverb_update = learning_lib.ReverbUpdate(
      keys=keys, priorities=jnp.abs(batch_loss).astype(jnp.float64))
  extra = learning_lib.LossExtra(metrics={}, reverb_update=reverb_update)
  return loss, extra
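
# Illustrative sketch, not part of the loss above: the shapes that the vmapped
# `rlax.categorical_double_q_learning` call expects for a single transition
# (toy sizes and placeholder values; `_sketch_categorical_double_q` is a
# hypothetical helper assuming the module's `jnp`/`rlax` imports).
def _sketch_categorical_double_q():
  num_actions, num_atoms = 3, 5
  atoms = jnp.linspace(-1., 1., num_atoms)  # shared support, hence in_axes=None
  logits_tm1 = jnp.zeros((num_actions, num_atoms))  # online logits at s
  logits_t = jnp.zeros((num_actions, num_atoms))  # target logits at s'
  q_t_selector = jnp.array([0.1, 0.3, -0.2])  # online Q(s', .) selects a'
  a_tm1 = jnp.asarray(1)  # action taken at s
  r_t, d_t = jnp.asarray(1.), jnp.asarray(0.99)
  return rlax.categorical_double_q_learning(
      atoms, logits_tm1, a_tm1, r_t, d_t, atoms, logits_t, q_t_selector)
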
def __call__(
    self,
    network: hk.Transformed,
    params: hk.Params,
    target_params: hk.Params,
    batch: reverb.ReplaySample,
    key: jnp.DeviceArray,
) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
  """Calculate a loss on a single batch of data."""
  del key
  o_tm1, a_tm1, r_t, d_t, o_t = batch.data
  keys, probs, *_ = batch.info

  # Forward pass.
  q_tm1 = network.apply(params, o_tm1)
  q_t_value = network.apply(target_params, o_t)
  q_t_selector = network.apply(params, o_t)

  # Cast and clip rewards.
  d_t = (d_t * self.discount).astype(jnp.float32)
  r_t = jnp.clip(r_t, -self.max_abs_reward,
                 self.max_abs_reward).astype(jnp.float32)

  # Compute double Q-learning n-step TD-error.
  batch_error = jax.vmap(rlax.double_q_learning)
  td_error = batch_error(q_tm1, a_tm1, r_t, d_t, q_t_value, q_t_selector)
  batch_loss = rlax.huber_loss(td_error, self.huber_loss_parameter)

  # Importance weighting.
  importance_weights = (1. / probs).astype(jnp.float32)
  importance_weights **= self.importance_sampling_exponent
  importance_weights /= jnp.max(importance_weights)

  # Reweight.
  loss = jnp.mean(importance_weights * batch_loss)  # []

  reverb_update = learning_lib.ReverbUpdate(
      keys=keys, priorities=jnp.abs(td_error).astype(jnp.float64))
  extra = learning_lib.LossExtra(metrics={}, reverb_update=reverb_update)
  return loss, extra
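
# Illustrative sketch, not part of the loss above: the un-vmapped
# `rlax.double_q_learning` call the loss batches over, for one transition with
# made-up values (`_sketch_double_q_td_error` is a hypothetical helper assuming
# the module's `jnp`/`rlax` imports).
def _sketch_double_q_td_error():
  q_tm1 = jnp.array([1., 2., 3.])  # online Q(s, .)
  q_t_value = jnp.array([1.5, 0.5, 1.])  # target Q(s', .), evaluates the action
  q_t_selector = jnp.array([2., 1., 0.1])  # online Q(s', .), selects the action
  a_tm1, r_t, d_t = jnp.asarray(1), jnp.asarray(0.5), jnp.asarray(0.99)
  # TD-error = r_t + d_t * q_t_value[argmax(q_t_selector)] - q_tm1[a_tm1].
  return rlax.double_q_learning(q_tm1, a_tm1, r_t, d_t, q_t_value, q_t_selector)
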
def __call__(
    self,
    network: networks_lib.FeedForwardNetwork,
    params: networks_lib.Params,
    target_params: networks_lib.Params,
    batch: reverb.ReplaySample,
    key: networks_lib.PRNGKey,
) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
  """Calculate a loss on a single batch of data."""
  del key
  transitions: types.Transition = batch.data
  dist_q_tm1 = network.apply(params, transitions.observation)['q_dist']
  dist_q_target_t = network.apply(target_params,
                                  transitions.next_observation)['q_dist']

  # Swap distribution and action dimension, since
  # rlax.quantile_q_learning expects it that way.
  dist_q_tm1 = jnp.swapaxes(dist_q_tm1, 1, 2)
  dist_q_target_t = jnp.swapaxes(dist_q_target_t, 1, 2)
  quantiles = (
      (jnp.arange(self.num_atoms, dtype=jnp.float32) + 0.5) / self.num_atoms)
  batch_quantile_q_learning = jax.vmap(
      rlax.quantile_q_learning, in_axes=(0, None, 0, 0, 0, 0, 0, None))
  losses = batch_quantile_q_learning(
      dist_q_tm1,
      quantiles,
      transitions.action,
      transitions.reward,
      transitions.discount,
      dist_q_target_t,  # No double Q-learning here.
      dist_q_target_t,
      self.huber_param,
  )
  loss = jnp.mean(losses)

  extra = learning_lib.LossExtra(metrics={'mean_loss': loss})
  return loss, extra
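
# Illustrative sketch, not part of the loss above: the axis swap and quantile
# midpoints with toy sizes. It assumes the network's 'q_dist' output has shape
# [batch, num_actions, num_atoms]; `_sketch_quantile_inputs` is a hypothetical
# helper relying on the module's `jnp` import.
def _sketch_quantile_inputs():
  batch_size, num_actions, num_atoms = 2, 4, 8
  q_dist = jnp.zeros((batch_size, num_actions, num_atoms))  # stand-in network output
  dist_q = jnp.swapaxes(q_dist, 1, 2)  # -> [batch, num_atoms, num_actions]
  # Quantile midpoints tau_i = (i + 0.5) / N for i = 0..N-1.
  quantiles = (jnp.arange(num_atoms, dtype=jnp.float32) + 0.5) / num_atoms
  return dist_q, quantiles
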