Example #1
    def _step(self) -> Dict[str, tf.Tensor]:
        """Does an SGD step on a batch of sequences."""

        # Retrieve a batch of data from replay.
        inputs: reverb.ReplaySample = next(self._iterator)
        data = tf2_utils.batch_to_sequence(inputs.data)
        observations, actions, rewards, discounts, extra = (data.observation,
                                                            data.action,
                                                            data.reward,
                                                            data.discount,
                                                            data.extras)
        core_state = tree.map_structure(lambda s: s[0], extra['core_state'])

        # Drop the final transition: the network's last value estimate is used
        # only as a bootstrap target, so these tensors align with values[:-1].
        actions = actions[:-1]  # [T-1]
        rewards = rewards[:-1]  # [T-1]
        discounts = discounts[:-1]  # [T-1]

        with tf.GradientTape() as tape:
            # Unroll current policy over observations.
            (logits, values), _ = snt.static_unroll(self._network,
                                                    observations, core_state)

            # Compute importance sampling weights: current policy / behavior policy.
            behaviour_logits = extra['logits']
            pi_behaviour = tfd.Categorical(logits=behaviour_logits[:-1])
            pi_target = tfd.Categorical(logits=logits[:-1])
            log_rhos = pi_target.log_prob(actions) - pi_behaviour.log_prob(
                actions)
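            # log_rhos has shape [T-1, B]: the log importance ratio
            # log(pi_target(a_t | x_t) / pi_behaviour(a_t | x_t)) that V-trace
            # uses to correct for the lag between actors and the learner.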

            # Optionally clip rewards.
            rewards = tf.clip_by_value(
                rewards, tf.cast(-self._max_abs_reward, rewards.dtype),
                tf.cast(self._max_abs_reward, rewards.dtype))

            # Critic loss.
            vtrace_returns = trfl.vtrace_from_importance_weights(
                log_rhos=tf.cast(log_rhos, tf.float32),
                discounts=tf.cast(self._discount * discounts, tf.float32),
                rewards=tf.cast(rewards, tf.float32),
                values=tf.cast(values[:-1], tf.float32),
                bootstrap_value=tf.cast(values[-1], tf.float32),
            )
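            # vtrace_returns.vs are the corrected value targets for the critic;
            # vtrace_returns.pg_advantages feed the policy-gradient term below.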
            critic_loss = tf.square(vtrace_returns.vs - values[:-1])

            # Policy-gradient loss.
            policy_gradient_loss = trfl.policy_gradient(
                policies=pi_target,
                actions=actions,
                action_values=vtrace_returns.pg_advantages,
            )

            # Entropy regulariser.
            entropy_loss = trfl.policy_entropy_loss(pi_target).loss
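            # trfl.policy_entropy_loss returns the negative entropy, so adding
            # it with a positive entropy_cost encourages exploration.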

            # Combine weighted sum of actor & critic losses.
            loss = tf.reduce_mean(policy_gradient_loss +
                                  self._baseline_cost * critic_loss +
                                  self._entropy_cost * entropy_loss)

        # Compute gradients and optionally apply clipping.
        gradients = tape.gradient(loss, self._network.trainable_variables)
        gradients, _ = tf.clip_by_global_norm(gradients,
                                              self._max_gradient_norm)
        self._optimizer.apply(gradients, self._network.trainable_variables)

        metrics = {
            'loss': loss,
            'critic_loss': tf.reduce_mean(critic_loss),
            'entropy_loss': tf.reduce_mean(entropy_loss),
            'policy_gradient_loss': tf.reduce_mean(policy_gradient_loss),
        }

        return metrics
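For context, here is a minimal self-contained sketch (not part of the example above; the tensor contents are dummy values chosen only to illustrate the expected shapes) of how trfl.vtrace_from_importance_weights is called with time-major [T, B] inputs:

import tensorflow as tf
import trfl

T, B = 5, 2  # illustrative unroll length and batch size
log_rhos = tf.zeros([T, B])           # log(pi_target / pi_behaviour); zeros = on-policy
discounts = 0.99 * tf.ones([T, B])    # per-step discount (0 at episode ends)
rewards = tf.random.normal([T, B])
values = tf.random.normal([T, B])     # baseline estimates V(x_t)
bootstrap_value = tf.zeros([B])       # V(x_T) for the state after the unroll

vtrace_returns = trfl.vtrace_from_importance_weights(
    log_rhos=log_rhos,
    discounts=discounts,
    rewards=rewards,
    values=values,
    bootstrap_value=bootstrap_value,
)
# vtrace_returns.vs:            [T, B] corrected value targets for the critic
# vtrace_returns.pg_advantages: [T, B] advantages for the policy-gradient term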
Example #2
  def _step(self, data: Step) -> Dict[str, tf.Tensor]:
    """Does an SGD step on a batch of sequences."""
    observations, actions, rewards, discounts, _, extra = data
    core_state = tree.map_structure(lambda s: s[0], extra['core_state'])

    actions = actions[:-1]  # [T-1]
    rewards = rewards[:-1]  # [T-1]
    discounts = discounts[:-1]  # [T-1]

    # Workaround for NO_OP actions
    # In some environments, passing NO_OP(-1) actions would lead to a crash.
    # These actions (at episode boundaries) should be ignored anyway,
    # so we replace NO_OP actions with a valid action index (0).
    actions = (tf.zeros_like(actions) * tf.cast(actions == -1, tf.int32) +
               actions * tf.cast(actions != -1, tf.int32))
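    # Note: an equivalent and arguably clearer way to express this mask is
    #   actions = tf.where(actions == -1, tf.zeros_like(actions), actions)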

    with tf.GradientTape() as tape:
      # Unroll current policy over observations.
      (logits, values), _ = snt.static_unroll(self._network, observations,
                                              core_state)

      pi = tfd.Categorical(logits=logits[:-1])

      # Optionally clip rewards.
      rewards = tf.clip_by_value(rewards,
                                 tf.cast(-self._max_abs_reward, rewards.dtype),
                                 tf.cast(self._max_abs_reward, rewards.dtype))

      # Compute actor & critic losses.
      discounted_returns = trfl.generalized_lambda_returns(
          rewards=tf.cast(rewards, tf.float32),
          pcontinues=tf.cast(self._discount * discounts, tf.float32),
          values=tf.cast(values[:-1], tf.float32),
          bootstrap_value=tf.cast(values[-1], tf.float32)
      )
      advantages = discounted_returns - values[:-1]
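      # With trfl's default lambda_=1 these are the discounted returns over the
      # unroll, bootstrapped with values[-1]; subtracting the baseline gives
      # the advantages used by the critic and policy-gradient losses below.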

      critic_loss = tf.square(advantages)
      policy_gradient_loss = trfl.policy_gradient(
          policies=pi,
          actions=actions,
          action_values=advantages
      )
      entropy_loss = trfl.policy_entropy_loss(pi).loss

      loss = tf.reduce_mean(policy_gradient_loss +
                            self._baseline_cost * critic_loss +
                            self._entropy_cost * entropy_loss)

    # Compute gradients and optionally apply clipping.
    gradients = tape.gradient(loss, self._network.trainable_variables)
    gradients, _ = tf.clip_by_global_norm(gradients, self._max_gradient_norm)
    self._optimizer.apply(gradients, self._network.trainable_variables)

    metrics = {
        'loss': loss,
        'critic_loss': tf.reduce_mean(critic_loss),
        'entropy_loss': tf.reduce_mean(entropy_loss),
        'policy_gradient_loss': tf.reduce_mean(policy_gradient_loss),
    }

    return metrics
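As a reminder of the target being computed here, trfl.generalized_lambda_returns follows the standard multistep lambda-return recursion (a sketch of the math, with p_t denoting self._discount * discounts[t] and V the baseline produced by the network):

    G_t^{\lambda} = r_t + p_t \bigl[(1 - \lambda)\, V(x_{t+1}) + \lambda\, G_{t+1}^{\lambda}\bigr],
    \qquad G_T^{\lambda} = V(x_T)\ \text{(the bootstrap value)}.

With trfl's default lambda = 1 this collapses to G_t = r_t + p_t G_{t+1}, so advantages = G_t - V(x_t), which is exactly the quantity squared for the critic loss and passed as action_values to trfl.policy_gradient above.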