Example #1
        def sgd_step(
            state: TrainingState, samples: reverb.ReplaySample
        ) -> Tuple[TrainingState, jnp.ndarray, Dict[str, jnp.ndarray]]:
            """Performs an update step, averaging over pmap replicas."""

            # Compute loss and gradients.
            grad_fn = jax.value_and_grad(loss, has_aux=True)
            key, key_grad = jax.random.split(state.random_key)
            (loss_value,
             priorities), gradients = grad_fn(state.params,
                                              state.target_params, key_grad,
                                              samples)

            # Average gradients over pmap replicas before optimizer update.
            gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME)

            # Apply optimizer updates.
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optax.apply_updates(state.params, updates)

            # Periodically update target networks.
            steps = state.steps + 1
            target_params = rlax.periodic_update(new_params,
                                                 state.target_params, steps,
                                                 self._target_update_period)

            new_state = TrainingState(params=new_params,
                                      target_params=target_params,
                                      opt_state=new_opt_state,
                                      steps=steps,
                                      random_key=key)
            return new_state, priorities, {'loss': loss_value}
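For context: the jax.lax.pmean call above only averages across devices when sgd_step runs inside a pmap-transformed function whose axis name matches _PMAP_AXIS_NAME. A minimal sketch of how such a step is typically wrapped (the name initial_state is hypothetical, not from the example):

# Sketch only: wrap the step so the pmean above has a named axis to reduce over.
sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME)
# Inputs must then carry a leading device axis, e.g. via replication:
replicated_state = jax.device_put_replicated(initial_state, jax.local_devices())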
Example #2
        def sgd_step(
            state: TrainingState, samples: reverb.ReplaySample
        ) -> Tuple[TrainingState, LearnerOutputs]:
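            # Compute gradients; the loss also returns replay keys and
            # priorities as auxiliary outputs.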
            grad_fn = jax.grad(loss, has_aux=True)
            gradients, (keys, priorities) = grad_fn(state.params,
                                                    state.target_params,
                                                    samples)
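            # Apply optimizer updates.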
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optax.apply_updates(state.params, updates)

            steps = state.steps + 1

            # Periodically update target networks.
            target_params = rlax.periodic_update(new_params,
                                                 state.target_params, steps,
                                                 self._target_update_period)

            new_state = TrainingState(params=new_params,
                                      target_params=target_params,
                                      opt_state=new_opt_state,
                                      steps=steps)

            outputs = LearnerOutputs(keys=keys, priorities=priorities)

            return new_state, outputs
Example #3
 def learner_step(self, params, data, learner_state, unused_key):
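     # Periodically copy the online parameters into the target network.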
     target_params = rlax.periodic_update(params.online, params.target,
                                          learner_state.count,
                                          self._target_period)
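     # Differentiate the loss with respect to the online parameters.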
     dloss_dtheta = jax.grad(self._loss)(params.online, target_params,
                                         *data)
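     # Apply the optimizer update.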
     updates, opt_state = self._optimizer.update(dloss_dtheta,
                                                 learner_state.opt_state)
     online_params = optax.apply_updates(params.online, updates)
     return (Params(online_params, target_params),
             LearnerState(learner_state.count + 1, opt_state))
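Since this learner step is a pure function of its explicit inputs, it can be compiled once with jax.jit. A minimal sketch of how that is typically done in the agent's constructor (the attribute name _learner_step is an assumption, not from the example):

# Sketch only: compile the pure learner step once, then call it every iteration.
self._learner_step = jax.jit(self.learner_step)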
Example #4
        def sgd_step(
            state: TrainingState,
            transitions: types.Transition,
        ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

            # TODO(jaslanides): Use a shared forward pass for efficiency.
            policy_loss_and_grad = jax.value_and_grad(policy_loss)
            critic_loss_and_grad = jax.value_and_grad(critic_loss)

            # Compute losses and their gradients.
            policy_loss_value, policy_gradients = policy_loss_and_grad(
                state.policy_params, state.critic_params,
                transitions.next_observation)
            critic_loss_value, critic_gradients = critic_loss_and_grad(
                state.critic_params, state, transitions)

            # Get optimizer updates and state.
            policy_updates, policy_opt_state = policy_optimizer.update(  # pytype: disable=attribute-error
                policy_gradients, state.policy_opt_state)
            critic_updates, critic_opt_state = critic_optimizer.update(  # pytype: disable=attribute-error
                critic_gradients, state.critic_opt_state)

            # Apply optimizer updates to parameters.
            policy_params = optax.apply_updates(state.policy_params,
                                                policy_updates)
            critic_params = optax.apply_updates(state.critic_params,
                                                critic_updates)

            steps = state.steps + 1

            # Periodically update target networks.
            target_policy_params, target_critic_params = rlax.periodic_update(
                (policy_params, critic_params),
                (state.target_policy_params, state.target_critic_params),
                steps, self._target_update_period)

            new_state = TrainingState(
                policy_params=policy_params,
                critic_params=critic_params,
                target_policy_params=target_policy_params,
                target_critic_params=target_critic_params,
                policy_opt_state=policy_opt_state,
                critic_opt_state=critic_opt_state,
                steps=steps,
            )

            metrics = {
                'policy_loss': policy_loss_value,
                'critic_loss': critic_loss_value,
            }

            return new_state, metrics
Example #5
        def sgd_step(state: TrainingState, batch: reverb.ReplaySample,
                     key: jnp.DeviceArray) -> Tuple[TrainingState, LossExtra]:
            """Performs one SGD step and updates the training state."""
            (loss, extra), grads = jax.value_and_grad(
                self._loss, has_aux=True)(state.params, state.target_params,
                                          batch, key)
            extra.metrics.update({'total_loss': loss})

            # Apply the optimizer updates
            updates, new_opt_state = optimizer.update(grads, state.opt_state)
            new_params = optax.apply_updates(state.params, updates)

            # Periodically update target networks.
            steps = state.steps + 1
            target_params = rlax.periodic_update(new_params,
                                                 state.target_params, steps,
                                                 target_update_period)
            new_training_state = TrainingState(new_params, target_params,
                                               new_opt_state, steps)
            return new_training_state, extra
Example #6
 def _learner_step(self, all_params, all_states, batch):
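     # Periodically copy the online parameters into the target network.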
     target_params = rlax.periodic_update(
         all_params.online,
         all_params.target,
         all_states.learner_steps,
         self._target_update_interval,
     )
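     # Differentiate the loss; its auxiliary output info is returned below.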
     grad, info = jax.grad(self._loss, has_aux=True)(all_params, batch)
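     # Apply the optimizer update to the online parameters only.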
     updates, optimizer_state = self._optimizer.update(
         grad.online, all_states.optimizer
     )
     online_params = optax.apply_updates(all_params.online, updates)
     return (
         AllParams(online=online_params, target=target_params),
         AllStates(
             optimizer=optimizer_state,
             learner_steps=all_states.learner_steps + 1,
             actor_steps=all_states.actor_steps,
         ),
         info,
     )
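All six examples share the same skeleton: differentiate a loss with jax.grad or jax.value_and_grad, apply the gradients with an optax optimizer, and refresh the target parameters every few steps with rlax.periodic_update. The following is a minimal, self-contained sketch of that skeleton on a toy objective; the TrainingState fields mirror the examples above, but the loss, the Adam learning rate, and TARGET_UPDATE_PERIOD are placeholder assumptions rather than values taken from any of the snippets.

from typing import NamedTuple

import jax
import jax.numpy as jnp
import optax
import rlax


class TrainingState(NamedTuple):
    params: jnp.ndarray
    target_params: jnp.ndarray
    opt_state: optax.OptState
    steps: jnp.ndarray


optimizer = optax.adam(1e-3)    # assumed optimizer and learning rate
TARGET_UPDATE_PERIOD = 100      # assumed update period


def loss(params, target_params, x):
    # Toy objective: regress the online projection onto the frozen target one.
    online = x @ params
    target = jax.lax.stop_gradient(x @ target_params)
    return jnp.mean((online - target) ** 2)


@jax.jit
def sgd_step(state: TrainingState, x: jnp.ndarray) -> TrainingState:
    # Differentiate the loss with respect to the online parameters only.
    gradients = jax.grad(loss)(state.params, state.target_params, x)
    # Apply the optimizer update.
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optax.apply_updates(state.params, updates)
    # Copy the online parameters into the target slot every N steps.
    steps = state.steps + 1
    target_params = rlax.periodic_update(
        new_params, state.target_params, steps, TARGET_UPDATE_PERIOD)
    return TrainingState(new_params, target_params, new_opt_state, steps)


params = jnp.zeros(4)
state = TrainingState(params, params, optimizer.init(params),
                      jnp.zeros([], jnp.int32))
state = sgd_step(state, jnp.ones((8, 4)))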