def loss(trajectory: buffer.Trajectory, rnn_unroll_state: RNNState):
  """Computes a linear combination of the policy gradient loss and value loss
  and regularizes it with an entropy term."""
  inputs = pack(trajectory)

  # Dynamically unroll the network. This Haiku utility function unpacks the
  # list of input tensors such that the i^{th} row from each input tensor
  # is presented to the i^{th} unrolled RNN module.
  (logits, values, _, _, state_embeddings), new_rnn_unroll_state = hk.dynamic_unroll(
      network, inputs, rnn_unroll_state)
  trajectory_len = trajectory.actions.shape[0]

  # Compute the combined loss given the output of the model.
  td_errors = rlax.td_lambda(
      v_tm1=values[:-1, 0],
      r_t=jnp.squeeze(trajectory.rewards, -1),
      discount_t=trajectory.discounts * discount,
      v_t=values[1:, 0],
      lambda_=jnp.array(td_lambda))
  critic_loss = jnp.mean(td_errors**2)
  actor_loss = rlax.policy_gradient_loss(
      logits_t=logits[:-1, 0],
      a_t=jnp.squeeze(trajectory.actions, 1),
      adv_t=td_errors,
      w_t=jnp.ones(trajectory_len))
  entropy_loss = jnp.mean(
      rlax.entropy_loss(logits[:-1, 0], jnp.ones(trajectory_len)))

  combined_loss = (actor_loss
                   + critic_cost * critic_loss
                   + entropy_cost * entropy_loss)
  return combined_loss, new_rnn_unroll_state

def loss(trajectory: sequence.Trajectory) -> jnp.ndarray:
  """Actor-critic loss."""
  logits, values = network(trajectory.observations)
  td_errors = rlax.td_lambda(
      v_tm1=values[:-1],
      r_t=trajectory.rewards,
      discount_t=trajectory.discounts * discount,
      v_t=values[1:],
      lambda_=jnp.array(td_lambda),
  )
  critic_loss = jnp.mean(td_errors**2)
  actor_loss = rlax.policy_gradient_loss(
      logits_t=logits[:-1],
      a_t=trajectory.actions,
      adv_t=td_errors,
      w_t=jnp.ones_like(td_errors))
  return actor_loss + critic_loss

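# A minimal sketch (not part of the original snippet) of how a loss like the
# one above is typically differentiated and applied. It assumes the loss has
# been made a pure function of the network parameters, e.g. via
# hk.without_apply_rng(hk.transform(loss)), and that `optimizer` is an optax
# optimizer. The names `loss_fn`, `optimizer`, `opt_state` and `sgd_step` are
# illustrative, not taken from the source.
import jax
import optax

optimizer = optax.adam(3e-3)  # Illustrative learning rate.

def sgd_step(params, opt_state, trajectory):
  """One gradient step on the actor-critic loss above."""
  gradients = jax.grad(loss_fn.apply)(params, trajectory)
  updates, new_opt_state = optimizer.update(gradients, opt_state)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_opt_state
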
def loss(trajectory: buffer.Trajectory) -> jnp.ndarray:
  """Actor-critic loss."""
  observations, rewards, actions = pack(trajectory)
  logits, values, _, _, _ = network(observations, rewards, actions)
  td_errors = rlax.td_lambda(
      v_tm1=values[:-1],
      r_t=jnp.squeeze(trajectory.rewards, -1),
      discount_t=trajectory.discounts * discount,
      v_t=values[1:],
      lambda_=jnp.array(td_lambda))
  critic_loss = jnp.mean(td_errors**2)
  actor_loss = rlax.policy_gradient_loss(
      logits_t=logits[:-1],
      a_t=jnp.squeeze(trajectory.actions, 1),
      adv_t=td_errors,
      w_t=jnp.ones_like(td_errors))
  entropy_loss = jnp.mean(
      rlax.entropy_loss(logits[:-1], jnp.ones_like(td_errors)))
  return actor_loss + critic_cost * critic_loss + entropy_cost * entropy_loss

def loss(
    weights,
    observations,
    actions,
    rewards,
    td_lambda=0.2,
    discount=0.99,
    policy_cost=0.25,
    entropy_cost=1e-3,
):
  """Actor-critic loss."""
  logits, values = network(weights, observations)
  values = jnp.append(values, jnp.sum(rewards))
  # Replace -inf values by a tiny finite value.
  logits = jnp.maximum(logits, MINIMUM_LOGIT)
  td_errors = rlax.td_lambda(
      v_tm1=values[:-1],
      r_t=rewards,
      discount_t=jnp.full_like(rewards, discount),
      v_t=values[1:],
      lambda_=td_lambda,
  )
  critic_loss = jnp.mean(td_errors ** 2)
  if type_ == "a2c":
    actor_loss = rlax.policy_gradient_loss(
        logits_t=logits,
        a_t=actions,
        adv_t=td_errors,
        w_t=jnp.ones(td_errors.shape[0]),
    )
  elif type_ == "supervised":
    actor_loss = jnp.mean(cross_entropy(logits, actions))
  entropy_loss = -jnp.mean(entropy(logits))
  return policy_cost * actor_loss, critic_loss, entropy_cost * entropy_loss

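# Sketch (assumed, not from the source): the variant above returns its three
# weighted components separately, so a caller would sum them before taking
# gradients. `total_loss` is an illustrative wrapper, not part of the
# original code.
def total_loss(weights, observations, actions, rewards):
  actor_loss, critic_loss, entropy_loss = loss(
      weights, observations, actions, rewards)
  return actor_loss + critic_loss + entropy_loss
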
def loss(trajectory: sequence.Trajectory, rnn_unroll_state: LSTMState):
  """Actor-critic loss."""
  (logits, values), new_rnn_unroll_state = hk.dynamic_unroll(
      network, trajectory.observations[:, None, ...], rnn_unroll_state)
  seq_len = trajectory.actions.shape[0]
  td_errors = rlax.td_lambda(
      v_tm1=values[:-1, 0],
      r_t=trajectory.rewards,
      discount_t=trajectory.discounts * discount,
      v_t=values[1:, 0],
      lambda_=jnp.array(td_lambda),
  )
  critic_loss = jnp.mean(td_errors**2)
  actor_loss = rlax.policy_gradient_loss(
      logits_t=logits[:-1, 0],
      a_t=trajectory.actions,
      adv_t=td_errors,
      w_t=jnp.ones(seq_len))
  entropy_loss = jnp.mean(
      rlax.entropy_loss(logits[:-1, 0], jnp.ones(seq_len)))
  combined_loss = actor_loss + critic_loss + entropy_cost * entropy_loss
  return combined_loss, new_rnn_unroll_state

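# Sketch (assumed, not from the source): because the recurrent variant above
# returns (loss, new_rnn_unroll_state), a gradient step would thread the RNN
# state through as an auxiliary output using has_aux=True. `loss_fn`,
# `optimizer`, `opt_state` and `sgd_step` are illustrative names for a
# Haiku-transformed loss and an optax optimizer.
def sgd_step(params, opt_state, trajectory, rnn_unroll_state):
  """One gradient step that also carries the recurrent state forward."""
  grads, new_rnn_unroll_state = jax.grad(
      loss_fn.apply, has_aux=True)(params, trajectory, rnn_unroll_state)
  updates, new_opt_state = optimizer.update(grads, opt_state)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_opt_state, new_rnn_unroll_state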