Example #1
        def f(dist_inputs, values, returns, dones, rewards, actions,
              old_log_probs, mask):
            """Definition of the A2C loss."""
            del old_log_probs

            # Typically we have dist_inputs of the shape float32[128,1,18]
            assert len(dist_inputs.shape) == 3, (
                f'dist_inputs.shape was {dist_inputs.shape} '
                f'but expected length of the tensor shape is 3')
            # values of the shape float32[128,1,1]
            # returns of the shape float32[128,1,1]
            assert values.shape == returns.shape, (
                f'values.shape was {values.shape} and '
                f'returns.shape was {returns.shape}')
            # actions of the shape int32[128,1] in the case of discrete actions
            # and float32[128,1,6] in the case of half-cheetah
            # actions agree with returns/values on the first two coordinates
            assert actions.shape[0:2] == returns.shape[0:2], (
                f'actions.shape was {actions.shape} and '
                f'returns.shape was {returns.shape}')
            # and mask of the shape float32[128,1]
            assert len(mask.shape) == 2, f'mask.shape was {mask.shape}'
            # which agrees with returns/values/actions on the first two coordinates
            assert mask.shape[0:2] == returns.shape[0:2], (
                f'mask.shape was {mask.shape} and '
                f'returns.shape was {returns.shape}')

            a2c_objective = rl_layers.A2CObjective(
                dist_inputs,
                stop_gradient(values),
                returns,
                dones,
                rewards,
                actions,
                mask,
                log_prob_fun=self._policy_dist.log_prob,
                normalize_advantages=self._normalize_advantages)

            # we insist that a2c_objective is a scalar
            assert jnp.ndim(
                a2c_objective) == 0, f'a2c_objective was {a2c_objective}'

            entropy_loss = rl_layers.EntropyLoss(
                dist_inputs,
                distribution=self._policy_dist,
                coeff=self._entropy_coeff,
            )

            assert jnp.ndim(
                entropy_loss) == 0, f'entropy_loss was {entropy_loss}'

            l2_value_loss = rl_layers.ValueLoss(
                values, returns, value_loss_coeff=self._value_loss_coeff)

            assert jnp.ndim(
                l2_value_loss) == 0, f'l2_value_loss was {l2_value_loss}'

            combined_loss = a2c_objective + l2_value_loss - entropy_loss

            return combined_loss
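
The combined loss above is the standard A2C recipe: a policy-gradient term computed against stop-gradient'd value baselines, plus an L2 value loss, minus an entropy bonus. Below is a minimal, self-contained sketch of that combination in plain jax.numpy, assuming categorical (discrete-action) logits. The helper name `a2c_loss_sketch`, the coefficient defaults, and the shape handling are illustrative only and are not part of `trax.rl_layers`; in particular, the real `A2CObjective` also consumes `dones` and `rewards` when forming advantages.

import jax
import jax.numpy as jnp

def a2c_loss_sketch(logits, values, returns, actions, mask,
                    value_loss_coeff=0.5, entropy_coeff=0.01):
  """Combined A2C-style loss: policy term + value loss - entropy bonus."""
  # Treat the value predictions as fixed baselines, mirroring the
  # stop_gradient(values) call in the example above.
  baselines = jax.lax.stop_gradient(values[..., 0])          # [batch, time]
  advantages = returns[..., 0] - baselines                   # [batch, time]
  log_probs = jax.nn.log_softmax(logits)                     # [batch, time, n_actions]
  act_log_probs = jnp.take_along_axis(
      log_probs, actions[..., None], axis=-1)[..., 0]        # [batch, time]
  denom = jnp.sum(mask)
  policy_loss = -jnp.sum(act_log_probs * advantages * mask) / denom
  value_loss = value_loss_coeff * jnp.sum(
      (values[..., 0] - returns[..., 0]) ** 2 * mask) / denom
  entropy = -jnp.sum(jnp.exp(log_probs) * log_probs, axis=-1)  # [batch, time]
  entropy_bonus = entropy_coeff * jnp.sum(entropy * mask) / denom
  return policy_loss + value_loss - entropy_bonus

With the shapes from the comments above (logits float32[128,1,18], values/returns float32[128,1,1], actions int32[128,1], mask float32[128,1]), the result is a scalar, matching the `jnp.ndim(...) == 0` assertions in the example.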
Example #2
 def _beta_gamma_with_correct_axes(self, x, weights):
     # Expand the parameters to have the right axes.
     beta, gamma = weights
     # TODO(phawkins): jnp.expand_dims should accept an axis tuple.
     # (https://github.com/numpy/numpy/issues/12290)
     ed = tuple(None if i in self._axis else slice(None)
                for i in range(jnp.ndim(x)))
     beta = beta[ed]
     gamma = gamma[ed]
     return beta, gamma
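
The `ed` tuple mixes `None` (which inserts a new length-1 axis, like `np.newaxis`) with `slice(None)` (which keeps an axis unchanged), so indexing `beta[ed]` reshapes the per-feature parameters to broadcast against `x` along the normalized axes. A standalone illustration of the trick; the shapes and names here are made up for the example:

import numpy as np

x = np.zeros((8, 16, 32))          # e.g. [batch, length, features]
axis = (0, 1)                      # axes that were normalized over
beta = np.ones((32,))              # one parameter per feature

ed = tuple(None if i in axis else slice(None) for i in range(x.ndim))
# ed == (None, None, slice(None)), so beta[ed] has shape (1, 1, 32),
# which broadcasts against x of shape (8, 16, 32).
print(beta[ed].shape)              # -> (1, 1, 32)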
Example #3
    def f(dist_inputs, values, returns, dones, rewards,
          actions, old_log_probs, mask):
      """Definition of the Proximal Policy Optimization loss."""
      del mask  # TODO(lukaszkaiser): make PPO work with Transformer
      # We have dist_inputs of the shape float32[128,1,18]
      assert len(dist_inputs.shape) == 3, (
          f'dist_inputs.shape was {dist_inputs.shape} '
          f'but expected length of the tensor shape is 3')
      # values of the shape float32[128,1,1]
      # returns of the shape float32[128,1,1]
      # dones of the shape int32[128,1,1]
      # rewards of the shape float32[128,1,1]
      # and old_log_probs of the shape float32[128,1]
      assert values.shape == returns.shape, (
          f'values.shape was {values.shape} and '
          f'returns.shape was {returns.shape}')
      assert values.shape == dones.shape, (
          f'values.shape was {values.shape} and '
          f'dones.shape was {dones.shape}')
      assert rewards.shape == dones.shape, (
          f'rewards.shape was {rewards.shape} and '
          f'dones.shape was {dones.shape}')
      assert returns.shape[0:2] == old_log_probs.shape, (
          f'returns.shape was {returns.shape} and '
          f'old_log_probs.shape was {old_log_probs.shape}')

      # actions is a tensor of the shape int32[128,1] in the case
      # of discrete actions and float32[128,1,6] in the case of
      # half-cheetah and other continuous actions
      # actions agree with returns/values on the first two coordinates
      # meaning batch and time
      assert actions.shape[0:2] == returns.shape[0:2], (
          f'actions.shape was {actions.shape} and '
          f'returns.shape was {returns.shape}')

      ppo_objective = rl_layers.PPOObjective(
          dist_inputs, stop_gradient(values), returns, dones, rewards,
          actions, old_log_probs,
          log_prob_fun=self._policy_dist.log_prob,
          epsilon=self._epsilon,
          normalize_advantages=self._normalize_advantages)

      # we insist that ppo_objective is a vector of shape [128,1]
      assert len(ppo_objective.shape) == 2, (
          f'ppo_objective was {ppo_objective}')
      # which agrees with returns/values/actions on the first two coordinates
      assert ppo_objective.shape[0:2] == values.shape[0:2], (
          f'ppo_objective.shape was {ppo_objective.shape} and '
          f'values.shape was {values.shape}')

      entropy_loss = rl_layers.EntropyLoss(
          dist_inputs,
          distribution=self._policy_dist,
          coeff=self._entropy_coeff,
      )

      assert jnp.ndim(entropy_loss) == 0, f'entropy_loss was {entropy_loss}'

      l2_value_loss = rl_layers.ValueLoss(
          values, returns, value_loss_coeff=self._value_loss_coeff)

      assert jnp.ndim(l2_value_loss) == 0, f'l2_value_loss was {l2_value_loss}'

      return -ppo_objective.mean() + l2_value_loss - entropy_loss
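
For reference, the per-time-step quantity that `PPOObjective` returns is, up to details such as advantage normalization, masking, and the handling of dones, the clipped surrogate from the PPO paper; the final line then minimizes its negated mean plus the value loss minus the entropy bonus. A rough, illustrative sketch of that surrogate (not the actual `rl_layers` implementation) follows, assuming per-step log-probabilities and advantages are already available:

import jax.numpy as jnp

def ppo_surrogate_sketch(new_log_probs, old_log_probs, advantages, epsilon):
  """Clipped PPO surrogate per step: min(r*A, clip(r, 1-eps, 1+eps)*A)."""
  ratios = jnp.exp(new_log_probs - old_log_probs)            # [batch, time]
  clipped = jnp.clip(ratios, 1.0 - epsilon, 1.0 + epsilon)   # [batch, time]
  return jnp.minimum(ratios * advantages, clipped * advantages)

This matches the shape assertions above: the surrogate is still [batch, time], and only after the final mean does the training loss become a scalar.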