def loss_fn(
    params: hk.Params,
    sample: reverb.ReplaySample,
) -> 'tuple[jnp.ndarray, dict[str, jnp.ndarray]]':
  """Batched, entropy-regularised actor-critic loss with V-trace.

  Args:
    params: Haiku parameters of the agent network.
    sample: Reverb replay sample whose `data` holds time-major sequences of
      observations, actions, rewards and discounts, plus `extras` containing
      the behaviour policy's 'logits' and the recurrent 'core_state'.

  Returns:
    A `(mean_loss, metrics)` pair: the scalar loss to minimise (policy
    gradient + weighted critic + weighted entropy terms, averaged over the
    sequence) and a dict of per-component diagnostics.
  """
  # Extract the data.
  data = sample.data
  observations, actions, rewards, discounts, extra = (data.observation,
                                                      data.action,
                                                      data.reward,
                                                      data.discount,
                                                      data.extras)
  # The recurrent state at the first step of the sequence.
  initial_state = tree.map_structure(lambda s: s[0], extra['core_state'])
  behaviour_logits = extra['logits']

  # Apply reward clipping.
  rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)

  # Unroll current policy over observations.
  (logits, values), _ = unroll_fn(params, observations, initial_state)

  # Compute importance sampling weights: current policy / behavior policy.
  rhos = rlax.categorical_importance_sampling_ratios(logits[:-1],
                                                     behaviour_logits[:-1],
                                                     actions[:-1])

  # Critic loss.
  vtrace_returns = rlax.vtrace_td_error_and_advantage(
      v_tm1=values[:-1],
      v_t=values[1:],
      r_t=rewards[:-1],
      discount_t=discounts[:-1] * discount,
      rho_tm1=rhos)
  critic_loss = jnp.square(vtrace_returns.errors)

  # Policy gradient loss.
  policy_gradient_loss = rlax.policy_gradient_loss(
      logits_t=logits[:-1],
      a_t=actions[:-1],
      adv_t=vtrace_returns.pg_advantage,
      w_t=jnp.ones_like(rewards[:-1]))

  # Entropy regulariser.
  entropy_loss = rlax.entropy_loss(logits[:-1], jnp.ones_like(rewards[:-1]))

  # Combine weighted sum of actor & critic losses, averaged over the sequence.
  mean_loss = jnp.mean(policy_gradient_loss +
                       baseline_cost * critic_loss +
                       entropy_cost * entropy_loss)  # []

  # NOTE(review): rlax.entropy_loss returns the *negative* entropy, so the
  # 'entropy' metric below is presumably -H(pi); confirm the intended sign
  # with the consumers of these metrics before flipping it.
  metrics = {
      'policy_loss': jnp.mean(policy_gradient_loss),
      'critic_loss': jnp.mean(baseline_cost * critic_loss),
      'entropy_loss': jnp.mean(entropy_cost * entropy_loss),
      'entropy': jnp.mean(entropy_loss),
  }

  return mean_loss, metrics
def loss(params: hk.Params, sample: reverb.ReplaySample) -> jnp.ndarray:
  """Entropy-regularised actor-critic loss.

  Args:
    params: Haiku parameters of `network`.
    sample: Reverb replay sample whose `data` unpacks to time-major
      (observations, actions, rewards, discounts, extras); `extras` must
      contain the behaviour policy's 'logits' and the recurrent 'core_state'.

  Returns:
    The scalar mean loss (policy gradient + weighted critic + weighted
    entropy terms) to minimise.
  """
  # Extract the data.
  observations, actions, rewards, discounts, extra = sample.data
  # The recurrent state at the first step of the sequence.
  initial_state = tree.map_structure(lambda s: s[0], extra['core_state'])
  behaviour_logits = extra['logits']

  # Drop the final transition so every tensor below is aligned as [T-1, ...].
  actions = actions[:-1]      # [T-1]
  rewards = rewards[:-1]      # [T-1]
  discounts = discounts[:-1]  # [T-1]

  # Apply reward clipping.
  rewards = jnp.clip(rewards, -max_abs_reward, max_abs_reward)

  # Unroll current policy over observations.
  net = functools.partial(network.apply, params)
  (logits, values), _ = hk.static_unroll(net, observations, initial_state)

  # Compute importance sampling weights: current policy / behavior policy.
  rhos = rlax.categorical_importance_sampling_ratios(
      logits[:-1], behaviour_logits[:-1], actions)

  # Critic loss. NOTE(review): current rlax names this argument `rho_tm1`
  # (older releases called it `rho_t`); renamed here for consistency with
  # `loss_fn` elsewhere in this file, which already passes `rho_tm1`.
  vtrace_returns = rlax.vtrace_td_error_and_advantage(
      v_tm1=values[:-1],
      v_t=values[1:],
      r_t=rewards,
      discount_t=discounts * discount,
      rho_tm1=rhos)
  critic_loss = jnp.square(vtrace_returns.errors)

  # Policy gradient loss.
  policy_gradient_loss = rlax.policy_gradient_loss(
      logits_t=logits[:-1],
      a_t=actions,
      adv_t=vtrace_returns.pg_advantage,
      w_t=jnp.ones_like(rewards))

  # Entropy regulariser.
  entropy_loss = rlax.entropy_loss(logits[:-1], jnp.ones_like(rewards))

  # Combine weighted sum of actor & critic losses.
  mean_loss = jnp.mean(policy_gradient_loss +
                       baseline_cost * critic_loss +
                       entropy_cost * entropy_loss)

  return mean_loss