def __call__(
    self,
    network: networks_lib.FeedForwardNetwork,
    params: networks_lib.Params,
    target_params: networks_lib.Params,
    batch: reverb.ReplaySample,
    key: networks_lib.PRNGKey,
) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
  """Calculate a loss on a single batch of data."""
  transitions: types.Transition = batch.data
  keys, probs, *_ = batch.info

  # Forward pass.
  if self.stochastic_network:
    q_tm1 = network.apply(params, key, transitions.observation)
    q_t_value = network.apply(target_params, key, transitions.next_observation)
    q_t_selector = network.apply(params, key, transitions.next_observation)
  else:
    q_tm1 = network.apply(params, transitions.observation)
    q_t_value = network.apply(target_params, transitions.next_observation)
    q_t_selector = network.apply(params, transitions.next_observation)

  # Cast and clip rewards.
  d_t = (transitions.discount * self.discount).astype(jnp.float32)
  r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                 self.max_abs_reward).astype(jnp.float32)

  # Compute double Q-learning n-step TD-error.
  batch_error = jax.vmap(rlax.double_q_learning)
  td_error = batch_error(q_tm1, transitions.action, r_t, d_t, q_t_value,
                         q_t_selector)
  batch_loss = rlax.huber_loss(td_error, self.huber_loss_parameter)

  # Importance weighting.
  importance_weights = (1. / probs).astype(jnp.float32)
  importance_weights **= self.importance_sampling_exponent
  importance_weights /= jnp.max(importance_weights)

  # Reweight.
  loss = jnp.mean(importance_weights * batch_loss)  # []

  reverb_update = learning_lib.ReverbUpdate(
      keys=keys, priorities=jnp.abs(td_error).astype(jnp.float64))
  extra = learning_lib.LossExtra(metrics={}, reverb_update=reverb_update)
  return loss, extra
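
# A minimal sketch of the per-transition TD error that
# `jax.vmap(rlax.double_q_learning)` computes above, assuming unbatched
# inputs. Illustrative only; rlax is the source of truth (e.g. for its
# `stop_target_gradients` flag, which defaults to True).
import jax
import jax.numpy as jnp


def _double_q_learning_sketch(q_tm1, a_tm1, r_t, d_t, q_t_value,
                              q_t_selector):
  a_t = jnp.argmax(q_t_selector)       # The online network picks the action.
  target = r_t + d_t * q_t_value[a_t]  # The target network evaluates it.
  return jax.lax.stop_gradient(target) - q_tm1[a_tm1]
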
def __call__(
    self,
    network: networks_lib.FeedForwardNetwork,
    params: networks_lib.Params,
    target_params: networks_lib.Params,
    batch: reverb.ReplaySample,
    key: networks_lib.PRNGKey,
) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
  """Calculate a loss on a single batch of data."""
  del key
  transitions: types.Transition = batch.data
  keys, probs, *_ = batch.info

  # Forward pass.
  _, logits_tm1, atoms_tm1 = network.apply(params, transitions.observation)
  _, logits_t, atoms_t = network.apply(target_params,
                                       transitions.next_observation)
  q_t_selector, _, _ = network.apply(params, transitions.next_observation)

  # Cast and clip rewards.
  d_t = (transitions.discount * self.discount).astype(jnp.float32)
  r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                 self.max_abs_reward).astype(jnp.float32)

  # Compute categorical double Q-learning loss.
  batch_loss_fn = jax.vmap(
      rlax.categorical_double_q_learning,
      in_axes=(None, 0, 0, 0, 0, None, 0, 0))
  batch_loss = batch_loss_fn(atoms_tm1, logits_tm1, transitions.action, r_t,
                             d_t, atoms_t, logits_t, q_t_selector)

  # Importance weighting.
  importance_weights = (1. / probs).astype(jnp.float32)
  importance_weights **= self.importance_sampling_exponent
  importance_weights /= jnp.max(importance_weights)

  # Reweight.
  loss = jnp.mean(importance_weights * batch_loss)  # []

  reverb_update = learning_lib.ReverbUpdate(
      keys=keys, priorities=jnp.abs(batch_loss).astype(jnp.float64))
  extra = learning_lib.LossExtra(metrics={}, reverb_update=reverb_update)
  return loss, extra
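
# The categorical loss above assumes `network.apply` returns a
# (q_values, logits, atoms) triple, i.e. a C51-style distributional head
# with a fixed support shared across the batch (hence the `None` entries in
# `in_axes`). A minimal haiku sketch under that assumption; the action count,
# support bounds and layer shape below are illustrative, not Acme's network.
# To use it, wrap it with hk.transform, e.g.
# hk.without_apply_rng(hk.transform(_categorical_head_sketch)).
import haiku as hk
import jax
import jax.numpy as jnp


def _categorical_head_sketch(inputs, num_actions=4, num_atoms=51,
                             v_min=-1., v_max=1.):
  atoms = jnp.linspace(v_min, v_max, num_atoms)                # Fixed support.
  logits = hk.Linear(num_actions * num_atoms)(inputs)
  logits = logits.reshape(-1, num_actions, num_atoms)          # [B, A, atoms]
  q_values = jnp.sum(jax.nn.softmax(logits) * atoms, axis=-1)  # [B, A]
  return q_values, logits, atoms
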
def __call__(
    self,
    network: hk.Transformed,
    params: hk.Params,
    target_params: hk.Params,
    batch: reverb.ReplaySample,
    key: jnp.DeviceArray,
) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
  """Calculate a loss on a single batch of data."""
  del key
  o_tm1, a_tm1, r_t, d_t, o_t = batch.data
  keys, probs, *_ = batch.info

  # Forward pass.
  q_tm1 = network.apply(params, o_tm1)
  q_t_value = network.apply(target_params, o_t)
  q_t_selector = network.apply(params, o_t)

  # Cast and clip rewards.
  d_t = (d_t * self.discount).astype(jnp.float32)
  r_t = jnp.clip(r_t, -self.max_abs_reward,
                 self.max_abs_reward).astype(jnp.float32)

  # Compute double Q-learning n-step TD-error.
  batch_error = jax.vmap(rlax.double_q_learning)
  td_error = batch_error(q_tm1, a_tm1, r_t, d_t, q_t_value, q_t_selector)
  batch_loss = rlax.huber_loss(td_error, self.huber_loss_parameter)

  # Importance weighting.
  importance_weights = (1. / probs).astype(jnp.float32)
  importance_weights **= self.importance_sampling_exponent
  importance_weights /= jnp.max(importance_weights)

  # Reweight.
  loss = jnp.mean(importance_weights * batch_loss)  # []

  reverb_update = learning_lib.ReverbUpdate(
      keys=keys, priorities=jnp.abs(td_error).astype(jnp.float64))
  extra = learning_lib.LossExtra(metrics={}, reverb_update=reverb_update)
  return loss, extra
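
# The legacy loss above unpacks `batch.data` as an (o_tm1, a_tm1, r_t, d_t,
# o_t) tuple and calls `network.apply(params, obs)` with no rng argument, so
# the haiku network is expected to be wrapped with `hk.without_apply_rng`.
# A minimal sketch of building such a network; the MLP sizes are illustrative.
import haiku as hk
import jax


def _make_network_sketch(num_actions: int) -> hk.Transformed:
  def forward(obs):
    return hk.nets.MLP([64, 64, num_actions])(obs)
  return hk.without_apply_rng(hk.transform(forward))

# Usage sketch (dummy_obs is a placeholder batch of observations):
#   network = _make_network_sketch(num_actions=4)
#   params = network.init(jax.random.PRNGKey(0), dummy_obs)
#   q_values = network.apply(params, dummy_obs)  # No rng needed.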