Example #1
def Max(axis=-1, keepdims=False):
    """Returns a layer that applies max along one tensor axis.

    Args:
      axis: Axis along which values are grouped for computing maximum.
      keepdims: If `True`, keep the resulting size 1 axis as a separate tensor
          axis; else, remove that axis.
    """
    return Fn('Max', lambda x: jnp.max(x, axis, keepdims=keepdims))
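For reference, a minimal sketch (not part of the trax source) of what the wrapped jnp.max call computes; it assumes only jax is installed and uses made-up values:

import jax.numpy as jnp

x = jnp.array([[1., 5., 3.],
               [4., 2., 6.]])
jnp.max(x, axis=-1)                 # -> [5. 6.], the reduced axis is removed
jnp.max(x, axis=-1, keepdims=True)  # -> [[5.] [6.]], a size-1 axis is kept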
Example #2
    def _aggregate_values(self, values, aggregate_max, act_log_probs):
        # Aggregate Q-values over the sampled-actions axis: either the max,
        # a probability-weighted sum, or the mean.
        if self._q_value:
            if aggregate_max:
                values = jnp.max(values, axis=1)
            elif self._sample_all_discrete_actions:
                values = jnp.sum(values * jnp.exp(act_log_probs), axis=1)
            else:
                values = jnp.mean(values, axis=1)
        return np.array(values)  # Move the values to CPU.
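A self-contained sketch (assuming only jax and numpy, with made-up shapes and values) of the three aggregation branches above, applied to a toy (batch, n_samples) Q-value array:

import jax.numpy as jnp
import numpy as np

q = jnp.array([[1.0, 3.0, 2.0],
               [0.5, 0.5, 4.0]])              # (batch=2, n_samples=3)
log_p = jnp.log(jnp.full_like(q, 1.0 / 3.0))  # uniform action log-probs

jnp.max(q, axis=1)                   # aggregate_max: [3. 4.]
jnp.sum(q * jnp.exp(log_p), axis=1)  # probability-weighted sum over sampled actions
jnp.mean(q, axis=1)                  # plain mean over sampled actions
np.array(jnp.mean(q, axis=1))        # np.array(...) moves the result to host memory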
Example #3
File: training.py Project: yliu45/trax
  def value_batches_stream(self):
    """Use the RLTask self._task to create inputs to the value model."""
    max_slice_length = self._max_slice_length
    min_slice_length = 1
    for np_trajectory in self._task.trajectory_batch_stream(
        self._value_batch_size,
        max_slice_length=max_slice_length,
        min_slice_length=min_slice_length,
        margin=0,
        epochs=self._replay_epochs,
    ):
      values_target = self._run_value_model(
          np_trajectory.observations, use_eval_model=True)
      if self._double_dqn:
        values = self._run_value_model(
            np_trajectory.observations, use_eval_model=False
        )
        index_max = np.argmax(values, axis=-1)
        ind_0, ind_1 = np.indices(index_max.shape)
        values_max = values_target[ind_0, ind_1, index_max]
      else:
        values_max = np.array(jnp.max(values_target, axis=-1))

      # The advantage_estimator returns
      #     gamma^n_steps * values_max(s_{i + n_steps}) + discounted_rewards
      #        - values_max(s_i)
      # hence we have to add values_max(s_i) in order to get n-step returns:
      #     gamma^n_steps * values_max(s_{i + n_steps}) + discounted_rewards
      # Notice that in DQN the tensor values_max[:, :-self._margin]
      # is the same as values_max[:, :-1].
      n_step_returns = values_max[:, :-self._margin] + \
          self._advantage_estimator(
              rewards=np_trajectory.rewards,
              returns=np_trajectory.returns,
              values=values_max,
              dones=np_trajectory.dones
              )

      length = n_step_returns.shape[1]
      target_returns = n_step_returns[:, :length]
      inputs = np_trajectory.observations[:, :length, :]

      yield (
          # Inputs are observations
          # (batch, length, obs)
          inputs,
          # the max indices will be needed to compute the loss
          np_trajectory.actions[:, :length],  # index_max,
          # Targets: the computed returns (target_returns); we expect
          # shapes such as (batch, length, num_actions).
          target_returns / self._value_network_scale,
          # TODO(henrykm): mask has the shape (batch, max_slice_length),
          # i.e. it misses the action dimension; the preferred format
          # would be np_trajectory.mask[:, :length, :] but for now we pass:
          np.ones(shape=target_returns.shape),
      )
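The double-DQN branch above gathers target-network values at the online network's argmax actions via np.argmax plus np.indices. A self-contained numpy sketch of the same index trick, with made-up shapes and random values (not from the trax source):

import numpy as np

rng = np.random.default_rng(0)
batch, length, num_actions = 2, 4, 3
values = rng.normal(size=(batch, length, num_actions))         # stand-in "online" network output
values_target = rng.normal(size=(batch, length, num_actions))  # stand-in "target" network output

index_max = np.argmax(values, axis=-1)                # (batch, length) chosen actions
ind_0, ind_1 = np.indices(index_max.shape)            # batch / time index grids
values_max = values_target[ind_0, ind_1, index_max]   # gather -> (batch, length)
assert values_max.shape == (batch, length)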
Example #4
    def value_batches_stream(self):
        """Use the RLTask self._task to create inputs to the value model."""
        max_slice_length = self._max_slice_length
        min_slice_length = 1
        for np_trajectory in self._task.trajectory_batch_stream(
                self._value_batch_size,
                max_slice_length=max_slice_length,
                min_slice_length=min_slice_length,
                margin=0,
                epochs=self._replay_epochs,
        ):
            values = self._run_value_model(np_trajectory.observations)
            values_max = np.array(jnp.max(values, axis=-1))

            adv = self._advantage_estimator(
                rewards=np_trajectory.rewards,
                returns=np_trajectory.returns,
                values=values_max,
                dones=np_trajectory.dones,
            )

            length = adv.shape[1]
            values = values[:, :length, :]
            indices_max = (np.arange(values.shape[0])[:, None],  # (batch, 1)
                           np.arange(values.shape[1])[None, :],  # (1, length)
                           np.argmax(values, axis=-1))           # (batch, length)
            # TODO(henrykm): change it to fastmath instead of jax.ops
            target_returns = jax.ops.index_add(values, indices_max, adv)
            inputs = np_trajectory.observations[:, :length, :]

            yield (
                # Inputs are observations
                # (batch, length, obs)
                inputs,
                # Targets: the computed returns (target_returns); we expect
                # shapes such as (batch, length, num_actions).
                target_returns / self._value_network_scale,
                # TODO(henrykm): mask has the shape (batch, max_slice_length),
                # i.e. it misses the action dimension; the preferred format
                # would be np_trajectory.mask[:, :length, :] but for now we pass:
                np.ones(shape=target_returns.shape))
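Regarding the TODO above: jax.ops.index_add has since been removed from JAX in favor of the .at[...].add(...) syntax. A hedged sketch of the equivalent update, with toy tensors standing in for values and adv (shapes assumed from the surrounding code):

import numpy as np
import jax.numpy as jnp

batch, length, num_actions = 2, 3, 4
values = jnp.zeros((batch, length, num_actions))   # stand-in Q-values
adv = jnp.ones((batch, length))                    # stand-in advantages

indices_max = (np.arange(batch)[:, None],          # (batch, 1)
               np.arange(length)[None, :],         # (1, length)
               jnp.argmax(values, axis=-1))        # (batch, length)

# Equivalent of jax.ops.index_add(values, indices_max, adv) in current JAX:
target_returns = values.at[indices_max].add(adv)
assert target_returns.shape == values.shape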
Example #5
    def _aggregate_values(self, values, aggregate, act_log_probs):
        # Normalize the Q-values before aggregation, so it can adapt to the scale
        # of the returns. This does not affect mean and max aggregation.
        scale = 1
        epsilon = 1e-5
        if self._q_value_normalization == 'std':
            scale = jnp.std(values) + epsilon
        elif self._q_value_normalization == 'abs':
            scale = jnp.mean(jnp.abs(values - jnp.mean(values))) + epsilon
        values /= scale

        temp = self._q_value_temperature
        if self._q_value:
            assert values.shape[:2] == (self._value_batch_size,
                                        self._q_value_n_samples)
            if aggregate == 'max':
                # max_a Q(s, a)
                values = jnp.max(values, axis=1)
            elif aggregate == 'softmax':
                # sum_a (Q(s, a) * w(s, a))
                # where w(s, .) = softmax (Q(s, .) / T)
                weights = tl.Softmax(axis=1)(values / temp)
                values = jnp.sum(values * weights, axis=1)
            elif aggregate == 'logsumexp':
                # log(mean_a exp(Q(s, a) / T)) * T
                n = values.shape[1]
                values = (fastmath.logsumexp(values / temp, axis=1) -
                          jnp.log(n)) * temp
            else:
                assert aggregate == 'mean'
                # mean_a Q(s, a)
                if self._sample_all_discrete_actions:
                    values = jnp.sum(values * jnp.exp(act_log_probs), axis=1)
                else:
                    values = jnp.mean(values, axis=1)

        # Re-scale the Q-values after aggregation.
        values *= scale
        return np.array(values)  # Move the values to CPU.
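A self-contained sketch of the 'softmax' and 'logsumexp' aggregation branches above, using plain jax calls instead of trax layers; the Q-values and temperature are made up:

import jax.numpy as jnp
from jax.nn import softmax
from jax.scipy.special import logsumexp

q = jnp.array([[1.0, 3.0, 2.0],
               [0.5, 0.5, 4.0]])   # (batch, n_samples) Q-value samples
temp = 0.5

# 'softmax' aggregation: sum_a Q(s, a) * softmax(Q(s, .) / T)_a
weights = softmax(q / temp, axis=1)
soft_agg = jnp.sum(q * weights, axis=1)

# 'logsumexp' aggregation: log(mean_a exp(Q(s, a) / T)) * T
n = q.shape[1]
lse_agg = (logsumexp(q / temp, axis=1) - jnp.log(n)) * temp

# Both interpolate between mean (large T) and max (small T) aggregation.
print(soft_agg, lse_agg, jnp.max(q, axis=1), jnp.mean(q, axis=1))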