def loss(trajectory: buffer.Trajectory, rnn_unroll_state: RNNState):
  """Computes a linear combination of the policy gradient loss and value loss
  and regularizes it with an entropy term."""
  inputs = pack(trajectory)
  # Dynamically unroll the network. This Haiku utility function unpacks the
  # list of input tensors such that the i^{th} row from each input tensor
  # is presented to the i^{th} unrolled RNN module.
  (logits, values, _, _, state_embeddings), new_rnn_unroll_state = hk.dynamic_unroll(
      network, inputs, rnn_unroll_state)
  trajectory_len = trajectory.actions.shape[0]

  # Compute the combined loss given the output of the model.
  td_errors = rlax.td_lambda(
      v_tm1=values[:-1, 0],
      r_t=jnp.squeeze(trajectory.rewards, -1),
      discount_t=trajectory.discounts * discount,
      v_t=values[1:, 0],
      lambda_=jnp.array(td_lambda))
  critic_loss = jnp.mean(td_errors**2)
  actor_loss = rlax.policy_gradient_loss(
      logits_t=logits[:-1, 0],
      a_t=jnp.squeeze(trajectory.actions, 1),
      adv_t=td_errors,
      w_t=jnp.ones(trajectory_len))
  entropy_loss = jnp.mean(
      rlax.entropy_loss(logits[:-1, 0], jnp.ones(trajectory_len)))

  combined_loss = (actor_loss +
                   critic_cost * critic_loss +
                   entropy_cost * entropy_loss)
  return combined_loss, new_rnn_unroll_state
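# A minimal, self-contained sketch (an illustration, not part of the snippet
# above) of the hk.dynamic_unroll pattern used in `loss`: a Haiku RNN core is
# applied to a [T, B, ...] input sequence one time step at a time while the
# recurrent state is threaded through, returning per-step outputs plus the
# final state. All names here (toy_unroll, dummy_inputs, ...) are hypothetical.
import haiku as hk
import jax
import jax.numpy as jnp


def toy_unroll(inputs, state):
  core = hk.LSTM(hidden_size=8)
  return hk.dynamic_unroll(core, inputs, state)


toy_unroll_fn = hk.without_apply_rng(hk.transform(toy_unroll))
dummy_inputs = jnp.zeros([5, 1, 3])  # [T=5, B=1, features=3]
dummy_state = hk.LSTMState(hidden=jnp.zeros([1, 8]), cell=jnp.zeros([1, 8]))
toy_params = toy_unroll_fn.init(jax.random.PRNGKey(0), dummy_inputs, dummy_state)
outputs, final_state = toy_unroll_fn.apply(toy_params, dummy_inputs, dummy_state)
# outputs.shape == (5, 1, 8): one LSTM output per unrolled time step.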
def loss_fn(params: hk.Params, sample: reverb.ReplaySample) -> jnp.DeviceArray:
  """Batched, entropy-regularised actor-critic loss with V-trace."""

  # Extract the data.
  data = sample.data
  observations, actions, rewards, discounts, extra = (data.observation,
                                                      data.action,
                                                      data.reward,
                                                      data.discount,
                                                      data.extras)
  initial_state = tree.map_structure(lambda s: s[0], extra['core_state'])
  behaviour_logits = extra['logits']

  # Apply reward clipping.
  rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)

  # Unroll current policy over observations.
  (logits, values), _ = unroll_fn(params, observations, initial_state)

  # Compute importance sampling weights: current policy / behavior policy.
  rhos = rlax.categorical_importance_sampling_ratios(
      logits[:-1], behaviour_logits[:-1], actions[:-1])

  # Critic loss.
  vtrace_returns = rlax.vtrace_td_error_and_advantage(
      v_tm1=values[:-1],
      v_t=values[1:],
      r_t=rewards[:-1],
      discount_t=discounts[:-1] * discount,
      rho_tm1=rhos)
  critic_loss = jnp.square(vtrace_returns.errors)

  # Policy gradient loss.
  policy_gradient_loss = rlax.policy_gradient_loss(
      logits_t=logits[:-1],
      a_t=actions[:-1],
      adv_t=vtrace_returns.pg_advantage,
      w_t=jnp.ones_like(rewards[:-1]))

  # Entropy regulariser.
  entropy_loss = rlax.entropy_loss(logits[:-1], jnp.ones_like(rewards[:-1]))

  # Combine weighted sum of actor & critic losses, averaged over the sequence.
  mean_loss = jnp.mean(policy_gradient_loss +
                       baseline_cost * critic_loss +
                       entropy_cost * entropy_loss)  # []

  metrics = {
      'policy_loss': jnp.mean(policy_gradient_loss),
      'critic_loss': jnp.mean(baseline_cost * critic_loss),
      'entropy_loss': jnp.mean(entropy_cost * entropy_loss),
      'entropy': jnp.mean(entropy_loss),
  }
  return mean_loss, metrics
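# A hedged sketch of averaging the per-sequence `loss_fn` above over a leading
# batch axis with jax.vmap. It assumes every leaf of the replay sample carries
# a leading batch dimension; `batched_loss` and `batched_sample` are
# illustrative names, not part of the original snippet.
import jax
import jax.numpy as jnp


def batched_loss(params, batched_sample):
  # Map the per-sequence loss over the batch axis of the sample while holding
  # the parameters fixed, then average both the scalar losses and the metrics.
  losses, metrics = jax.vmap(loss_fn, in_axes=(None, 0))(params, batched_sample)
  return jnp.mean(losses), jax.tree_util.tree_map(jnp.mean, metrics)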
def loss(params: hk.Params, sample: reverb.ReplaySample) -> jnp.ndarray:
  """Entropy-regularised actor-critic loss."""

  # Extract the data.
  observations, actions, rewards, discounts, extra = sample.data
  initial_state = tree.map_structure(lambda s: s[0], extra['core_state'])
  behaviour_logits = extra['logits']

  actions = actions[:-1]      # [T-1]
  rewards = rewards[:-1]      # [T-1]
  discounts = discounts[:-1]  # [T-1]
  rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)

  # Unroll current policy over observations.
  net = functools.partial(network.apply, params)
  (logits, values), _ = hk.static_unroll(net, observations, initial_state)

  # Compute importance sampling weights: current policy / behavior policy.
  rhos = rlax.categorical_importance_sampling_ratios(
      logits[:-1], behaviour_logits[:-1], actions)

  # Critic loss.
  vtrace_returns = rlax.vtrace_td_error_and_advantage(
      v_tm1=values[:-1],
      v_t=values[1:],
      r_t=rewards,
      discount_t=discounts * discount,
      rho_tm1=rhos)
  critic_loss = jnp.square(vtrace_returns.errors)

  # Policy gradient loss.
  policy_gradient_loss = rlax.policy_gradient_loss(
      logits_t=logits[:-1],
      a_t=actions,
      adv_t=vtrace_returns.pg_advantage,
      w_t=jnp.ones_like(rewards))

  # Entropy regulariser.
  entropy_loss = rlax.entropy_loss(logits[:-1], jnp.ones_like(rewards))

  # Combine weighted sum of actor & critic losses.
  mean_loss = jnp.mean(policy_gradient_loss +
                       baseline_cost * critic_loss +
                       entropy_cost * entropy_loss)
  return mean_loss
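# For reference, a small sketch (an illustration, not part of the original
# code) of what rlax.categorical_importance_sampling_ratios computes for the
# V-trace correction above: the per-step ratio pi(a_t | s_t) / mu(a_t | s_t)
# between the current and behaviour policies, evaluated from their logits.
import jax
import jax.numpy as jnp


def importance_ratios(pi_logits, mu_logits, actions):
  log_pi = jax.nn.log_softmax(pi_logits)  # [T, num_actions]
  log_mu = jax.nn.log_softmax(mu_logits)  # [T, num_actions]
  log_pi_a = jnp.take_along_axis(log_pi, actions[:, None], axis=-1)[:, 0]
  log_mu_a = jnp.take_along_axis(log_mu, actions[:, None], axis=-1)[:, 0]
  return jnp.exp(log_pi_a - log_mu_a)  # [T]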
def loss(trajectory: buffer.Trajectory) -> jnp.ndarray:
  """Actor-critic loss."""
  observations, rewards, actions = pack(trajectory)
  logits, values, _, _, _ = network(observations, rewards, actions)
  td_errors = rlax.td_lambda(
      v_tm1=values[:-1],
      r_t=jnp.squeeze(trajectory.rewards, -1),
      discount_t=trajectory.discounts * discount,
      v_t=values[1:],
      lambda_=jnp.array(td_lambda))
  critic_loss = jnp.mean(td_errors**2)
  actor_loss = rlax.policy_gradient_loss(
      logits_t=logits[:-1],
      a_t=jnp.squeeze(trajectory.actions, 1),
      adv_t=td_errors,
      w_t=jnp.ones_like(td_errors))
  entropy_loss = jnp.mean(
      rlax.entropy_loss(logits[:-1], jnp.ones_like(td_errors)))
  return actor_loss + critic_cost * critic_loss + entropy_cost * entropy_loss
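# A hedged sketch (for illustration only) of the recursion behind
# rlax.td_lambda used above: the lambda-returns
#   G_t = r_t + discount_t * ((1 - lambda) * v_t + lambda * G_{t+1}),
# computed backwards over the trajectory, from which td_lambda subtracts the
# baseline v_tm1 to obtain the TD(lambda) errors.
import jax
import jax.numpy as jnp


def lambda_returns_sketch(r_t, discount_t, v_t, lambda_):
  def backward_step(g_next, inputs):
    r, d, v = inputs
    g = r + d * ((1.0 - lambda_) * v + lambda_ * g_next)
    return g, g

  _, returns = jax.lax.scan(
      backward_step, v_t[-1], (r_t, discount_t, v_t), reverse=True)
  return returns  # td_errors would then be returns - v_tm1.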
def loss(trajectory: sequence.Trajectory, rnn_unroll_state: LSTMState):
  """Actor-critic loss."""
  (logits, values), new_rnn_unroll_state = hk.dynamic_unroll(
      network, trajectory.observations[:, None, ...], rnn_unroll_state)
  seq_len = trajectory.actions.shape[0]
  td_errors = rlax.td_lambda(
      v_tm1=values[:-1, 0],
      r_t=trajectory.rewards,
      discount_t=trajectory.discounts * discount,
      v_t=values[1:, 0],
      lambda_=jnp.array(td_lambda),
  )
  critic_loss = jnp.mean(td_errors**2)
  actor_loss = rlax.policy_gradient_loss(
      logits_t=logits[:-1, 0],
      a_t=trajectory.actions,
      adv_t=td_errors,
      w_t=jnp.ones(seq_len))
  entropy_loss = jnp.mean(
      rlax.entropy_loss(logits[:-1, 0], jnp.ones(seq_len)))

  combined_loss = actor_loss + critic_loss + entropy_cost * entropy_loss
  return combined_loss, new_rnn_unroll_state
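# A hedged sketch of threading the returned new_rnn_unroll_state across
# successive updates. It assumes `loss` is defined inside hk.transform (as the
# hk.dynamic_unroll call requires), so that the transformed apply function,
# called `loss_apply` here, takes the parameters as an extra first argument;
# `loss_apply` and `optimizer` are illustrative names, not part of the
# original snippet.
import jax
import optax


def sgd_step(params, opt_state, rnn_unroll_state, trajectory):
  # has_aux=True: the auxiliary output is the LSTM state carried over to the
  # next update, so gradients flow only through the scalar combined loss.
  grads, new_rnn_unroll_state = jax.grad(loss_apply, has_aux=True)(
      params, trajectory, rnn_unroll_state)
  updates, new_opt_state = optimizer.update(grads, opt_state)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_opt_state, new_rnn_unroll_state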