Beispiel #1
0
        def f(dist_inputs, values, returns, dones, rewards, actions,
              old_log_probs, mask):
            """Definition of the A2C loss.

            Combines the A2C policy objective, the value loss and the entropy
            term into a single scalar loss.

            Args:
              dist_inputs: Inputs to the policy distribution; must be rank 3
                (typically float32[128,1,18]).
              values: Value estimates; must have the same shape as `returns`.
              returns: Target returns (typically float32[128,1,1]).
              dones: Unused.
              rewards: Unused.
              actions: Actions taken; must agree with `returns` on the first
                two dimensions.
              old_log_probs: Unused.
              mask: Rank-2 mask; must agree with `returns` on the first two
                dimensions.

            Returns:
              The scalar `a2c_objective + l2_value_loss - entropy_loss`.
            """
            del dones, rewards, old_log_probs

            # Typically we have dist_inputs of the shape float32[128,1,18]
            assert len(dist_inputs.shape) == 3, (
                f'dist_inputs.shape was {dist_inputs.shape} '
                f'but expected length of the tensor shape is 3')
            # values of the shape float32[128,1,1]
            # returns of the shape float32[128,1,1]
            assert values.shape == returns.shape, (
                f'values.shape was {values.shape} and '
                f'returns.shape was {returns.shape}')
            # actions of the shape int32[128,1] in the case of discrete actions
            # and float32[128,1,6] in the case of half-cheetah;
            # actions agree with returns/values on the first two coordinates
            assert actions.shape[0:2] == returns.shape[0:2], (
                f'actions.shape was {actions.shape} and '
                f'returns.shape was {returns.shape}')
            # and mask of the shape float32[128,1]
            assert len(mask.shape) == 2, f'mask.shape was {mask.shape}'
            # which agrees with returns/values/actions on the first two coordinates
            assert mask.shape[0:2] == returns.shape[0:2], (
                f'mask.shape was {mask.shape} and '
                f'returns.shape was {returns.shape}')

            # `values` goes through stop_gradient here so the policy objective
            # does not push gradients into the value estimates; those are
            # trained only through ValueLoss below.
            a2c_objective = rl_layers.A2CObjective(
                dist_inputs,
                stop_gradient(values),
                returns,
                actions,
                mask,
                log_prob_fun=self._policy_dist.log_prob,
                normalize_advantages=self._normalize_advantages)

            # we insist that a2c_objective is a scalar
            assert jnp.ndim(
                a2c_objective) == 0, f'a2c_objective was {a2c_objective}'

            entropy_loss = rl_layers.EntropyLoss(
                dist_inputs,
                actions,
                log_prob_fun=self._policy_dist.log_prob,
                entropy_coeff=self._entropy_coeff,
                entropy_fun=self._policy_dist.entropy)

            assert jnp.ndim(
                entropy_loss) == 0, f'entropy_loss was {entropy_loss}'

            l2_value_loss = rl_layers.ValueLoss(
                values, returns, value_loss_coeff=self._value_loss_coeff)

            assert jnp.ndim(
                l2_value_loss) == 0, f'l2_value_loss was {l2_value_loss}'

            combined_loss = a2c_objective + l2_value_loss - entropy_loss

            return combined_loss
Beispiel #2
0
        def f(dist_inputs, values, returns, actions, old_log_probs, mask):
            """Definition of the Proximal Policy Optimization loss.

            Args:
              dist_inputs: Inputs to the policy distribution; must be rank 3
                (e.g. float32[128,1,18]).
              values: Value estimates; must have the same shape as `returns`.
              returns: Target returns (e.g. float32[128,1,1]).
              actions: Rank-2 actions tensor (e.g. int32[128,1]); must agree
                with `returns` on the first two dimensions.
              old_log_probs: Log-probabilities of `actions`, whose shape must
                equal `returns.shape[0:2]`.
              mask: Currently ignored.

            Returns:
              The scalar `-mean(ppo_objective) + l2_value_loss - entropy_loss`.
            """
            del mask  # TODO(lukaszkaiser): make PPO work with Transformer
            # We have dist_inputs of the shape float32[128,1,18]
            assert len(dist_inputs.shape) == 3, (
                f'dist_inputs.shape was {dist_inputs.shape} '
                f'but expected length of the tensor shape is 3')
            # values of the shape float32[128,1,1]
            # returns of the shape float32[128,1,1]
            # and old_log_probs of the shape float32[128,1]
            assert values.shape == returns.shape, (
                f'values.shape was {values.shape} and '
                f'returns.shape was {returns.shape}')
            assert returns.shape[0:2] == old_log_probs.shape, (
                f'returns.shape was {returns.shape} and '
                f'old_log_probs.shape was {old_log_probs.shape}')
            # actions is a tensor of the shape int32[128,1]
            assert len(
                actions.shape) == 2, f'actions.shape was {actions.shape}'
            # which agrees with returns/values on the first two coordinates
            assert actions.shape[0:2] == returns.shape[0:2], (
                f'actions.shape was {actions.shape} and '
                f'returns.shape was {returns.shape}')

            # `values` goes through stop_gradient here so the policy objective
            # does not push gradients into the value estimates; those are
            # trained only through ValueLoss below.
            ppo_objective = rl_layers.PPOObjective(
                dist_inputs,
                stop_gradient(values),
                returns,
                actions,
                old_log_probs,
                log_prob_fun=self._policy_dist.log_prob,
                epsilon=self._epsilon,
                normalize_advantages=self._normalize_advantages)

            # we insist that ppo_objective is a vector of shape [128,1]
            assert len(ppo_objective.shape) == 2, (
                f'ppo_objective was {ppo_objective}')
            # which agrees with returns/values/actions on the first two coordinates
            assert ppo_objective.shape[0:2] == values.shape[0:2], (
                f'ppo_objective.shape was {ppo_objective.shape} and '
                f'values.shape was {values.shape}')

            entropy_loss = rl_layers.EntropyLoss(
                dist_inputs,
                actions,
                log_prob_fun=self._policy_dist.log_prob,
                entropy_coeff=self._entropy_coeff,
                entropy_fun=self._policy_dist.entropy)

            assert jnp.ndim(
                entropy_loss) == 0, f'entropy_loss was {entropy_loss}'

            l2_value_loss = rl_layers.ValueLoss(
                values, returns, value_loss_coeff=self._value_loss_coeff)

            assert jnp.ndim(
                l2_value_loss) == 0, f'l2_value_loss was {l2_value_loss}'

            # The PPO objective is negated because the returned value is
            # minimized.
            return -ppo_objective.mean() + l2_value_loss - entropy_loss
Beispiel #3
0
 def _beta_gamma_with_correct_axes(self, x, weights):
   # Expand the parameters to have the right axes.
   beta, gamma = weights
   # TODO(phawkins): np.expand_dims should accept an axis tuple.
   # (https://github.com/numpy/numpy/issues/12290)
   ed = tuple(None if i in self._axis else slice(None)
              for i in range(np.ndim(x)))
   beta = beta[ed]
   gamma = gamma[ed]
   return beta, gamma