def update_policy_step():
  policy_updates, policy_opt_state = policy_optimizer.update(
      policy_gradients, state.policy_opt_state)
  policy_params = optax.apply_updates(state.policy_params, policy_updates)
  target_policy_params = optax.incremental_update(
      new_tensors=policy_params,
      old_tensors=state.target_policy_params,
      step_size=tau)
  return policy_params, target_policy_params, policy_opt_state
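# A minimal, self-contained sketch (not part of the learner code above)
# illustrating what optax.incremental_update computes for the target update:
# target <- step_size * online + (1 - step_size) * target, applied leaf-wise.
# The parameter values and `tau` below are illustrative assumptions.
import jax.numpy as jnp
import optax

online_params = {'w': jnp.array([1.0, 2.0])}
target_params = {'w': jnp.array([0.0, 0.0])}
tau = 0.005  # illustrative Polyak step size

new_target_params = optax.incremental_update(
    new_tensors=online_params, old_tensors=target_params, step_size=tau)
# new_target_params['w'] == tau * online_params['w']
#                           + (1 - tau) * target_params['w']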
def ema_update(params, avg_params):
  # Exponential moving average of the online parameters:
  # avg_params <- 0.001 * params + 0.999 * avg_params.
  return optax.incremental_update(params, avg_params, step_size=0.001)
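# Hypothetical usage sketch for the `ema_update` helper above: after each
# optimizer step, the averaged parameters are nudged towards the online
# parameters with step_size=0.001. `loss_fn`, `optimizer`, `train_step` and
# the toy data below are assumptions for illustration, not part of the
# original code.
import jax
import jax.numpy as jnp
import optax

def loss_fn(params, batch):
  x, y = batch
  pred = x @ params['w']
  return jnp.mean((pred - y) ** 2)

optimizer = optax.sgd(1e-2)

def train_step(params, avg_params, opt_state, batch):
  grads = jax.grad(loss_fn)(params, batch)
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  avg_params = ema_update(params, avg_params)  # slow EMA of the online params
  return params, avg_params, opt_state

params = {'w': jnp.zeros((3,))}
avg_params = params
opt_state = optimizer.init(params)
batch = (jnp.ones((4, 3)), jnp.ones((4,)))
params, avg_params, opt_state = train_step(params, avg_params, opt_state, batch)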
def update_step(
    state: TrainingState,
    transitions: types.Transition,
) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

  random_key, key_critic, key_twin = jax.random.split(state.random_key, 3)

  # Updates on the critic: compute the gradients, and update using
  # Polyak averaging.
  critic_loss_and_grad = jax.value_and_grad(critic_loss)
  critic_loss_value, critic_gradients = critic_loss_and_grad(
      state.critic_params, state, transitions, key_critic)
  critic_updates, critic_opt_state = critic_optimizer.update(
      critic_gradients, state.critic_opt_state)
  critic_params = optax.apply_updates(state.critic_params, critic_updates)
  # In the original authors' implementation the critic target update is
  # delayed similarly to the policy update, which we found empirically to
  # perform slightly worse.
  target_critic_params = optax.incremental_update(
      new_tensors=critic_params,
      old_tensors=state.target_critic_params,
      step_size=tau)

  # Updates on the twin critic: compute the gradients, and update using
  # Polyak averaging.
  twin_critic_loss_value, twin_critic_gradients = critic_loss_and_grad(
      state.twin_critic_params, state, transitions, key_twin)
  twin_critic_updates, twin_critic_opt_state = twin_critic_optimizer.update(
      twin_critic_gradients, state.twin_critic_opt_state)
  twin_critic_params = optax.apply_updates(state.twin_critic_params,
                                           twin_critic_updates)
  # In the original authors' implementation the twin critic target update is
  # delayed similarly to the policy update, which we found empirically to
  # perform slightly worse.
  target_twin_critic_params = optax.incremental_update(
      new_tensors=twin_critic_params,
      old_tensors=state.target_twin_critic_params,
      step_size=tau)

  # Updates on the policy: compute the gradients, and update using
  # Polyak averaging (if delay enabled, the update might not be applied).
  policy_loss_and_grad = jax.value_and_grad(policy_loss)
  policy_loss_value, policy_gradients = policy_loss_and_grad(
      state.policy_params, state.critic_params, transitions)

  def update_policy_step():
    policy_updates, policy_opt_state = policy_optimizer.update(
        policy_gradients, state.policy_opt_state)
    policy_params = optax.apply_updates(state.policy_params, policy_updates)
    target_policy_params = optax.incremental_update(
        new_tensors=policy_params,
        old_tensors=state.target_policy_params,
        step_size=tau)
    return policy_params, target_policy_params, policy_opt_state

  # The update on the policy is applied every `delay` steps.
  current_policy_state = (state.policy_params, state.target_policy_params,
                          state.policy_opt_state)
  policy_params, target_policy_params, policy_opt_state = jax.lax.cond(
      state.steps % delay == 0,
      lambda _: update_policy_step(),
      lambda _: current_policy_state,
      operand=None)

  steps = state.steps + 1

  new_state = TrainingState(
      policy_params=policy_params,
      critic_params=critic_params,
      twin_critic_params=twin_critic_params,
      target_policy_params=target_policy_params,
      target_critic_params=target_critic_params,
      target_twin_critic_params=target_twin_critic_params,
      policy_opt_state=policy_opt_state,
      critic_opt_state=critic_opt_state,
      twin_critic_opt_state=twin_critic_opt_state,
      steps=steps,
      random_key=random_key,
  )

  metrics = {
      'policy_loss': policy_loss_value,
      'critic_loss': critic_loss_value,
      'twin_critic_loss': twin_critic_loss_value,
  }

  return new_state, metrics
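# A minimal, self-contained sketch of the delayed-update pattern used in
# update_step above: jax.lax.cond selects between applying the update and
# keeping the current values, so both branches must return pytrees of
# identical structure. `delay`, the fake "optimizer step", and the scalar
# parameters below are illustrative assumptions only.
import jax
import jax.numpy as jnp

delay = 2

def delayed_update(step, params, target_params):

  def apply_update(_):
    new_params = params + 1.0  # stand-in for a real optimizer step
    new_target = 0.005 * new_params + 0.995 * target_params  # Polyak average
    return new_params, new_target

  def keep_current(_):
    return params, target_params

  return jax.lax.cond(step % delay == 0, apply_update, keep_current, None)

params, target = jnp.zeros(()), jnp.zeros(())
for step in range(4):
  params, target = delayed_update(step, params, target)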