Example #1
    def sgd_step(
        state: TrainingState,
        samples: reverb.ReplaySample
    ) -> Tuple[TrainingState, jnp.ndarray, Dict[str, jnp.ndarray]]:
      """Performs an update step, averaging over pmap replicas."""

      # Compute loss and gradients.
      grad_fn = jax.value_and_grad(loss, has_aux=True)
      key, key_grad = jax.random.split(state.random_key)
      (loss_value, priorities), gradients = grad_fn(state.params,
                                                    state.target_params,
                                                    key_grad,
                                                    samples)

      # Average gradients over pmap replicas before optimizer update.
      gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME)

      # Apply optimizer updates.
      updates, new_opt_state = optimizer.update(gradients, state.opt_state)
      new_params = optax.apply_updates(state.params, updates)

      # Periodically update target networks.
      steps = state.steps + 1
      target_params = optax.periodic_update(new_params, state.target_params,
                                            steps, self._target_update_period)

      new_state = TrainingState(
          params=new_params,
          target_params=target_params,
          opt_state=new_opt_state,
          steps=steps,
          random_key=key)
      return new_state, priorities, {'loss': loss_value}
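Because this sgd_step averages gradients with jax.lax.pmean over a named axis, it is meant to be wrapped in jax.pmap with that same axis name. A minimal sketch of the wrapping, assuming a replicated TrainingState and samples carrying a leading device axis (the axis-name value below is illustrative, not taken from the example):

    _PMAP_AXIS_NAME = 'devices'
    # Replicate the step across local devices; gradients are then averaged
    # across replicas by the jax.lax.pmean call inside sgd_step.
    pmapped_sgd_step = jax.pmap(sgd_step, axis_name=_PMAP_AXIS_NAME)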
Example #2
 def learner_step(self, params, data, learner_state, unused_key):
     # Periodically refresh the target parameters from the online ones.
     target_params = optax.periodic_update(params.online, params.target,
                                           learner_state.count,
                                           self._target_period)
     # Differentiate the loss with respect to the online parameters.
     dloss_dtheta = jax.grad(self._loss)(params.online, target_params,
                                         *data)
     # Compute the optimizer update and apply it to the online parameters.
     updates, opt_state = self._optimizer.update(dloss_dtheta,
                                                 learner_state.opt_state)
     online_params = optax.apply_updates(params.online, updates)
     # Return the new parameter pair and the advanced learner state.
     return (Params(online_params, target_params),
             LearnerState(learner_state.count + 1, opt_state))
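This learner_step relies on two small containers defined elsewhere in the original code: Params with online and target fields, and LearnerState with count and opt_state fields. A minimal sketch of compatible definitions (the NamedTuple form is an assumption; only the field names come from the example):

    from typing import Any, NamedTuple

    class Params(NamedTuple):
        online: Any
        target: Any

    class LearnerState(NamedTuple):
        count: Any
        opt_state: Any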
Example #3
        def sgd_step(
            state: TrainingState,
            transitions: types.Transition,
        ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

            key, key_policy, key_critic = jax.random.split(state.key, 3)

            # Compute losses and their gradients.
            policy_loss_value, policy_gradients = policy_loss_and_grad(
                state.policy_params, state.critic_params, transitions,
                key_policy)
            critic_loss_value, critic_gradients = critic_loss_and_grad(
                state.critic_params, state.target_policy_params,
                state.target_critic_params, transitions, key_critic)

            # Get optimizer updates and state.
            policy_updates, policy_opt_state = policy_optimizer.update(
                policy_gradients, state.policy_opt_state)
            critic_updates, critic_opt_state = critic_optimizer.update(
                critic_gradients, state.critic_opt_state)

            # Apply optimizer updates to parameters.
            policy_params = optax.apply_updates(state.policy_params,
                                                policy_updates)
            critic_params = optax.apply_updates(state.critic_params,
                                                critic_updates)

            steps = state.steps + 1

            # Periodically update target networks.
            target_policy_params, target_critic_params = optax.periodic_update(
                (policy_params, critic_params),
                (state.target_policy_params, state.target_critic_params),
                steps, target_update_period)

            new_state = TrainingState(
                policy_params=policy_params,
                target_policy_params=target_policy_params,
                critic_params=critic_params,
                target_critic_params=target_critic_params,
                policy_opt_state=policy_opt_state,
                critic_opt_state=critic_opt_state,
                steps=steps,
                key=key,
            )

            metrics = {
                'policy_loss': policy_loss_value,
                'critic_loss': critic_loss_value,
            }

            return new_state, metrics
Example #4
    def sgd_step(
        state: TrainingState,
        transitions: types.Transition,
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:

      # TODO(jaslanides): Use a shared forward pass for efficiency.
      policy_loss_and_grad = jax.value_and_grad(policy_loss)
      critic_loss_and_grad = jax.value_and_grad(critic_loss)

      # Compute losses and their gradients.
      policy_loss_value, policy_gradients = policy_loss_and_grad(
          state.policy_params, state.critic_params,
          transitions.next_observation)
      critic_loss_value, critic_gradients = critic_loss_and_grad(
          state.critic_params, state, transitions)

      # Get optimizer updates and state.
      policy_updates, policy_opt_state = policy_optimizer.update(  # pytype: disable=attribute-error
          policy_gradients, state.policy_opt_state)
      critic_updates, critic_opt_state = critic_optimizer.update(  # pytype: disable=attribute-error
          critic_gradients, state.critic_opt_state)

      # Apply optimizer updates to parameters.
      policy_params = optax.apply_updates(state.policy_params, policy_updates)
      critic_params = optax.apply_updates(state.critic_params, critic_updates)

      steps = state.steps + 1

      # Periodically update target networks.
      target_policy_params, target_critic_params = optax.periodic_update(
          (policy_params, critic_params),
          (state.target_policy_params, state.target_critic_params), steps,
          self._target_update_period)

      new_state = TrainingState(
          policy_params=policy_params,
          critic_params=critic_params,
          target_policy_params=target_policy_params,
          target_critic_params=target_critic_params,
          policy_opt_state=policy_opt_state,
          critic_opt_state=critic_opt_state,
          steps=steps,
      )

      metrics = {
          'policy_loss': policy_loss_value,
          'critic_loss': critic_loss_value,
      }

      return new_state, metrics
Example #5
        def sgd_step(
                state: TrainingState,
                batch: reverb.ReplaySample) -> Tuple[TrainingState, LossExtra]:
            next_rng_key, rng_key = jax.random.split(state.rng_key)
            # Implements one SGD step of the loss and updates training state
            (loss, extra), grads = jax.value_and_grad(
                self._loss, has_aux=True)(state.params, state.target_params,
                                          batch, rng_key)
            extra.metrics.update({'total_loss': loss})

            # Apply the optimizer updates
            updates, new_opt_state = optimizer.update(grads, state.opt_state)
            new_params = optax.apply_updates(state.params, updates)

            # Periodically update target networks.
            steps = state.steps + 1
            target_params = optax.periodic_update(new_params,
                                                  state.target_params, steps,
                                                  target_update_period)
            new_training_state = TrainingState(new_params, target_params,
                                               new_opt_state, steps,
                                               next_rng_key)
            return new_training_state, extra
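All five examples share the same target-network idiom: optax.periodic_update(new_params, old_params, steps, update_period) returns the freshly updated parameters whenever steps is a multiple of update_period and keeps the old target parameters otherwise. A minimal standalone sketch of that behaviour (the parameter values and the period of 10 are illustrative):

    import jax.numpy as jnp
    import optax

    online = {'w': jnp.full((3,), 2.0)}
    target = {'w': jnp.zeros((3,))}

    # Step 10 is a multiple of the period, so the target is refreshed.
    refreshed = optax.periodic_update(online, target, jnp.asarray(10), 10)
    # Step 11 is not, so the previous target is returned unchanged.
    kept = optax.periodic_update(online, target, jnp.asarray(11), 10)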