Esempio n. 1
0
        def f(dist_inputs, values, returns, actions, old_log_probs, mask):
            """Combine the PPO objective, value loss and entropy loss.

            The result is `-mean(PPO objective) + value loss - entropy loss`,
            where each term comes from the matching `rl_layers` helper.
            """
            # The mask is accepted as an input but deliberately discarded.
            del mask  # TODO(lukaszkaiser): make PPO work with Transformer

            log_prob = self._policy_dist.log_prob

            # The objective sees `values` wrapped in stop_gradient, whereas the
            # value loss below receives the raw `values`.
            clipped_objective = rl_layers.PPOObjective(
                dist_inputs,
                stop_gradient(values),
                returns,
                actions,
                old_log_probs,
                log_prob_fun=log_prob,
                epsilon=self._epsilon,
                normalize_advantages=self._normalize_advantages,
            )

            entropy_term = rl_layers.EntropyLoss(
                dist_inputs,
                actions,
                log_prob_fun=log_prob,
                entropy_coeff=self._entropy_coeff,
                entropy_fun=self._policy_dist.entropy,
            )

            value_term = rl_layers.ValueLoss(
                values,
                returns,
                value_loss_coeff=self._value_loss_coeff,
            )

            # The objective is negated so that it can be combined with the
            # other two terms into a single loss.
            policy_term = -clipped_objective.mean()
            return policy_term + value_term - entropy_term
Esempio n. 2
0
        def f(dist_inputs, values, returns, dones, rewards, actions,
              old_log_probs, mask):
            """Definition of the A2C loss.

            Checks the shapes of the inputs, then combines the A2C objective,
            the value loss and the entropy loss into a single scalar.

            Args:
              dist_inputs: rank-3 tensor, typically float32[128,1,18].
              values: tensor with the same shape as `returns`, e.g.
                float32[128,1,1].
              returns: tensor, e.g. float32[128,1,1].
              dones: forwarded unchanged to `rl_layers.A2CObjective`.
              rewards: forwarded unchanged to `rl_layers.A2CObjective`.
              actions: tensor agreeing with `returns` on the first two
                dimensions, e.g. int32[128,1] or float32[128,1,6].
              old_log_probs: unused; discarded immediately.
              mask: rank-2 tensor agreeing with `returns` on the first two
                dimensions, e.g. float32[128,1].

            Returns:
              The scalar `a2c_objective + l2_value_loss - entropy_loss`.

            Raises:
              AssertionError: if an input shape check fails or one of the
                three loss terms is not a scalar.
            """
            del old_log_probs

            # Typically we have dist_inputs of the shape float32[128,1,18]
            assert len(dist_inputs.shape) == 3, (
                f'dist_inputs.shape was {dist_inputs.shape} '
                f'but expected length of the tensor shape is 3')
            # values of the shape float32[128,1,1]
            # returns of the shape float32[128,1,1]
            assert values.shape == returns.shape, (
                f'values.shape was {values.shape} and '
                f'returns.shape was {returns.shape}')
            # actions of the shape int32[128,1] in the case of discrete actions
            # and float32[128,1,6] in the case of half-cheetah
            # actions agree with returns/values on the first two coordinates
            assert actions.shape[0:2] == returns.shape[0:2], (
                f'actions.shape was {actions.shape} and '
                f'returns.shape was {returns.shape}')
            # and mask of the shape float32[128,1]
            assert len(mask.shape) == 2, f'mask.shape was {mask.shape}'
            # which agrees with returns/values/actions on the first two coordinates
            assert mask.shape[0:2] == returns.shape[0:2], (
                f'mask.shape was {mask.shape} and '
                f'returns.shape was {returns.shape}')

            # The objective receives `values` wrapped in stop_gradient, while
            # the value loss below receives the raw `values`.
            a2c_objective = rl_layers.A2CObjective(
                dist_inputs,
                stop_gradient(values),
                returns,
                dones,
                rewards,
                actions,
                mask,
                log_prob_fun=self._policy_dist.log_prob,
                normalize_advantages=self._normalize_advantages)

            # we insist that a2c_objective is a scalar
            assert jnp.ndim(
                a2c_objective) == 0, f'a2c_objective was {a2c_objective}'

            entropy_loss = rl_layers.EntropyLoss(
                dist_inputs,
                distribution=self._policy_dist,
                coeff=self._entropy_coeff,
            )

            assert jnp.ndim(
                entropy_loss) == 0, f'entropy_loss was {entropy_loss}'

            l2_value_loss = rl_layers.ValueLoss(
                values, returns, value_loss_coeff=self._value_loss_coeff)

            assert jnp.ndim(
                l2_value_loss) == 0, f'l2_value_loss was {l2_value_loss}'

            combined_loss = a2c_objective + l2_value_loss - entropy_loss

            return combined_loss
Esempio n. 3
0
        def f(dist_inputs, values, returns, actions, old_log_probs, mask):
            """Definition of the Proximal Policy Optimization loss.

            Checks the shapes of the inputs, then combines the PPO objective,
            the value loss and the entropy loss.

            Args:
              dist_inputs: rank-3 tensor, e.g. float32[128,1,18].
              values: tensor with the same shape as `returns`, e.g.
                float32[128,1,1].
              returns: tensor, e.g. float32[128,1,1].
              actions: rank-2 tensor agreeing with `returns` on the first two
                dimensions, e.g. int32[128,1].
              old_log_probs: tensor whose shape equals the first two
                dimensions of `returns`, e.g. float32[128,1].
              mask: unused; discarded immediately.

            Returns:
              `-ppo_objective.mean() + l2_value_loss - entropy_loss`.

            Raises:
              AssertionError: if an input shape check fails, the PPO
                objective is not rank 2, or the entropy/value losses are not
                scalars.
            """
            del mask  # TODO(lukaszkaiser): make PPO work with Transformer
            # We have dist_inputs of the shape float32[128,1,18]
            assert len(dist_inputs.shape) == 3, (
                f'dist_inputs.shape was {dist_inputs.shape} '
                f'but expected length of the tensor shape is 3')
            # values of the shape float32[128,1,1]
            # returns of the shape float32[128,1,1]
            # and old_log_probs of the shape float32[128,1]
            assert values.shape == returns.shape, (
                f'values.shape was {values.shape} and '
                f'returns.shape was {returns.shape}')
            assert returns.shape[0:2] == old_log_probs.shape, (
                f'returns.shape was {returns.shape} and '
                f'old_log_probs.shape was {old_log_probs.shape}')
            # actions is a tensor of the shape int32[128,1]
            assert len(
                actions.shape) == 2, f'actions.shape was {actions.shape}'
            # which agrees with returns/values on the first two coordinates
            assert actions.shape[0:2] == returns.shape[0:2], (
                f'actions.shape was {actions.shape} and '
                f'returns.shape was {returns.shape}')

            # The objective receives `values` wrapped in stop_gradient, while
            # the value loss below receives the raw `values`.
            ppo_objective = rl_layers.PPOObjective(
                dist_inputs,
                stop_gradient(values),
                returns,
                actions,
                old_log_probs,
                log_prob_fun=self._policy_dist.log_prob,
                epsilon=self._epsilon,
                normalize_advantages=self._normalize_advantages)

            # we insist that ppo_objective is a vector of shape [128,1]
            assert len(ppo_objective.shape) == 2, (
                f'ppo_objective was {ppo_objective}')
            # which agrees with returns/values/actions on the first two coordinates
            assert ppo_objective.shape[0:2] == values.shape[0:2], (
                f'ppo_objective.shape was {ppo_objective.shape} and '
                f'values.shape was {values.shape}')

            entropy_loss = rl_layers.EntropyLoss(
                dist_inputs,
                actions,
                log_prob_fun=self._policy_dist.log_prob,
                entropy_coeff=self._entropy_coeff,
                entropy_fun=self._policy_dist.entropy)

            assert jnp.ndim(
                entropy_loss) == 0, f'entropy_loss was {entropy_loss}'

            l2_value_loss = rl_layers.ValueLoss(
                values, returns, value_loss_coeff=self._value_loss_coeff)

            assert jnp.ndim(
                l2_value_loss) == 0, f'l2_value_loss was {l2_value_loss}'

            return -ppo_objective.mean() + l2_value_loss - entropy_loss
Esempio n. 4
0
        def f(dist_inputs, values, returns, actions, old_log_probs, mask):
            """Combine the A2C objective, value loss and entropy loss.

            The result is `mean(A2C objective) + value loss - entropy loss`,
            where each term comes from the matching `rl_layers` helper.
            """
            # The old log-probabilities are accepted as an input but unused.
            del old_log_probs

            log_prob = self._policy_dist.log_prob

            # The objective sees `values` wrapped in stop_gradient, whereas the
            # value loss below receives the raw `values`.
            objective = rl_layers.A2CObjective(
                dist_inputs,
                stop_gradient(values),
                returns,
                actions,
                mask,
                log_prob_fun=log_prob,
                normalize_advantages=self._normalize_advantages,
            )

            entropy_term = rl_layers.EntropyLoss(
                dist_inputs,
                actions,
                log_prob_fun=log_prob,
                entropy_coeff=self._entropy_coeff,
                entropy_fun=self._policy_dist.entropy,
            )

            value_term = rl_layers.ValueLoss(
                values,
                returns,
                value_loss_coeff=self._value_loss_coeff,
            )

            policy_term = objective.mean()
            return policy_term + value_term - entropy_term
Esempio n. 5
0
 def f(dist_inputs, values, returns):
   """Return `rl_layers.ValueLoss` of `values` against `returns`.

   `dist_inputs` is accepted as an input but discarded; the third argument
   passed to `ValueLoss` is `self._value_loss_coeff`.
   """
   del dist_inputs
   coeff = self._value_loss_coeff
   loss = rl_layers.ValueLoss(values, returns, coeff)
   return loss
Esempio n. 6
0
 def f(dist_inputs, values, returns):
     """Return `rl_layers.ValueLoss` of `values` against `returns`.

     `dist_inputs` is accepted as an input but discarded; the third argument
     passed to `ValueLoss` is the constant 1.
     """
     del dist_inputs
     loss = rl_layers.ValueLoss(values, returns, 1)
     return loss
Esempio n. 7
0
 def value_loss(self):
     """Build the value-loss layer - so far generic for all A2C.

     Returns:
       The result of `tl.Fn(..., n_out=1)` wrapping a function that calls
       `rl_layers.ValueLoss(values, returns, self._value_loss_coeff)`.
     """
     # `dist_inputs` is part of the signature but is not used in the body.
     compute_loss = lambda dist_inputs, values, returns: (
         rl_layers.ValueLoss(values, returns, self._value_loss_coeff))
     return tl.Fn(compute_loss, n_out=1)
Esempio n. 8
0
 def value_loss(self):
   """Return a factory for the value-loss layer - so far generic for all A2C.

   The layer is created once, via `tl.Fn(..., n_in=3, n_out=1)`, around a
   function that calls
   `rl_layers.ValueLoss(values, returns, self._value_loss_coeff)`.

   Returns:
     A callable that ignores any keyword arguments and returns that same
     layer object on every call.
   """
   # `dist_inputs` is part of the signature but is not used in the body.
   compute_loss = lambda dist_inputs, values, returns: (
       rl_layers.ValueLoss(values, returns, self._value_loss_coeff))
   layer = tl.Fn(compute_loss, n_in=3, n_out=1)

   def layer_factory(**unused_kwargs):
     return layer

   return layer_factory