Example 1
 def update_policy_step():
     # Apply the optimizer update to the live policy parameters.
     policy_updates, policy_opt_state = policy_optimizer.update(
         policy_gradients, state.policy_opt_state)
     policy_params = optax.apply_updates(state.policy_params,
                                         policy_updates)
     # Polyak-average the new parameters into the target policy:
     # target <- tau * new + (1 - tau) * target.
     target_policy_params = optax.incremental_update(
         new_tensors=policy_params,
         old_tensors=state.target_policy_params,
         step_size=tau)
     return policy_params, target_policy_params, policy_opt_state
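The fragment above is not self-contained: `policy_optimizer`, `policy_gradients`, `state`, and `tau` come from the enclosing TD3-style update step shown in Example 3. A minimal, runnable sketch of the same pattern, with illustrative stand-in values for those names, might look like this:

 import jax.numpy as jnp
 import optax

 # Illustrative stand-ins for values defined in the enclosing learner.
 tau = 0.005
 policy_params = {'w': jnp.ones((4, 2))}
 target_policy_params = {'w': jnp.zeros((4, 2))}
 policy_gradients = {'w': jnp.full((4, 2), 0.1)}

 policy_optimizer = optax.sgd(1e-3)
 policy_opt_state = policy_optimizer.init(policy_params)

 # Optimizer step on the live policy parameters.
 policy_updates, policy_opt_state = policy_optimizer.update(
     policy_gradients, policy_opt_state)
 policy_params = optax.apply_updates(policy_params, policy_updates)

 # Polyak averaging of the target parameters, applied leaf-wise:
 # target <- tau * new + (1 - tau) * target.
 target_policy_params = optax.incremental_update(
     new_tensors=policy_params,
     old_tensors=target_policy_params,
     step_size=tau)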
Example 2
 def ema_update(params, avg_params):
     # Exponential moving average: avg <- 0.001 * params + 0.999 * avg.
     return optax.incremental_update(params, avg_params, step_size=0.001)
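As a hedged sketch of how such a helper is typically used (assuming the `ema_update` definition above is in scope; the toy model, `loss_fn`, and data below are illustrative, not from the original source): after each optimizer step, the averaged parameters are nudged toward the live ones, so they track a slowly moving average.

 import jax
 import jax.numpy as jnp
 import optax

 def loss_fn(params, x, y):
     pred = x @ params['w'] + params['b']
     return jnp.mean((pred - y) ** 2)

 optimizer = optax.sgd(1e-2)
 params = {'w': jnp.zeros((3, 1)), 'b': jnp.zeros((1,))}
 avg_params = params  # The moving average starts equal to the live params.
 opt_state = optimizer.init(params)

 x, y = jnp.ones((8, 3)), jnp.ones((8, 1))

 for _ in range(100):
     grads = jax.grad(loss_fn)(params, x, y)
     updates, opt_state = optimizer.update(grads, opt_state)
     params = optax.apply_updates(params, updates)
     # avg <- 0.001 * params + 0.999 * avg, computed leaf-wise.
     avg_params = ema_update(params, avg_params)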
Example 3
        def update_step(
            state: TrainingState,
            transitions: types.Transition,
        ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

            random_key, key_critic, key_twin = jax.random.split(
                state.random_key, 3)

            # Updates on the critic: compute the gradients, and update using
            # Polyak averaging.
            critic_loss_and_grad = jax.value_and_grad(critic_loss)
            critic_loss_value, critic_gradients = critic_loss_and_grad(
                state.critic_params, state, transitions, key_critic)
            critic_updates, critic_opt_state = critic_optimizer.update(
                critic_gradients, state.critic_opt_state)
            critic_params = optax.apply_updates(state.critic_params,
                                                critic_updates)
            # In the original authors' implementation the critic target update
            # is delayed similarly to the policy update, which we found
            # empirically to perform slightly worse.
            target_critic_params = optax.incremental_update(
                new_tensors=critic_params,
                old_tensors=state.target_critic_params,
                step_size=tau)

            # Updates on the twin critic: compute the gradients, and update using
            # Polyak averaging.
            twin_critic_loss_value, twin_critic_gradients = critic_loss_and_grad(
                state.twin_critic_params, state, transitions, key_twin)
            twin_critic_updates, twin_critic_opt_state = twin_critic_optimizer.update(
                twin_critic_gradients, state.twin_critic_opt_state)
            twin_critic_params = optax.apply_updates(state.twin_critic_params,
                                                     twin_critic_updates)
            # In the original authors' implementation the twin critic target
            # update is delayed similarly to the policy update, which we found
            # empirically to perform slightly worse.
            target_twin_critic_params = optax.incremental_update(
                new_tensors=twin_critic_params,
                old_tensors=state.target_twin_critic_params,
                step_size=tau)

            # Updates on the policy: compute the gradients, and update using
            # Polyak averaging (if delay enabled, the update might not be applied).
            policy_loss_and_grad = jax.value_and_grad(policy_loss)
            policy_loss_value, policy_gradients = policy_loss_and_grad(
                state.policy_params, state.critic_params, transitions)

            def update_policy_step():
                policy_updates, policy_opt_state = policy_optimizer.update(
                    policy_gradients, state.policy_opt_state)
                policy_params = optax.apply_updates(state.policy_params,
                                                    policy_updates)
                target_policy_params = optax.incremental_update(
                    new_tensors=policy_params,
                    old_tensors=state.target_policy_params,
                    step_size=tau)
                return policy_params, target_policy_params, policy_opt_state

            # The update on the policy is applied every `delay` steps.
            current_policy_state = (state.policy_params,
                                    state.target_policy_params,
                                    state.policy_opt_state)
            policy_params, target_policy_params, policy_opt_state = jax.lax.cond(
                state.steps % delay == 0,
                lambda _: update_policy_step(),
                lambda _: current_policy_state,
                operand=None)

            steps = state.steps + 1

            new_state = TrainingState(
                policy_params=policy_params,
                critic_params=critic_params,
                twin_critic_params=twin_critic_params,
                target_policy_params=target_policy_params,
                target_critic_params=target_critic_params,
                target_twin_critic_params=target_twin_critic_params,
                policy_opt_state=policy_opt_state,
                critic_opt_state=critic_opt_state,
                twin_critic_opt_state=twin_critic_opt_state,
                steps=steps,
                random_key=random_key,
            )

            metrics = {
                'policy_loss': policy_loss_value,
                'critic_loss': critic_loss_value,
                'twin_critic_loss': twin_critic_loss_value,
            }

            return new_state, metrics
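A brief usage note, as an assumption about the surrounding learner rather than something shown in the source: because the policy delay is expressed with `jax.lax.cond` instead of a Python `if` on the traced step counter, `update_step` remains a pure, traceable function and can be compiled once with `jax.jit` and reused, for example:

 # Hypothetical surrounding code: `initial_state` (a TrainingState) and
 # `transitions` (a batch of types.Transition) are assumed to exist.
 jitted_update_step = jax.jit(update_step)
 state, metrics = jitted_update_step(initial_state, transitions)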