Example #1
0
        def f(dist_inputs, values, returns, dones, rewards, actions,
              old_log_probs, mask):
            """Definition of the A2C loss.

            Checks the shapes of the batch inputs, then combines the A2C
            policy objective, the value loss and the entropy term into
            `a2c_objective + l2_value_loss - entropy_loss`.

            Args:
              dist_inputs: inputs to the policy distribution; must be rank 3.
              values: value predictions; must have the same shape as
                `returns`. Gradients are stopped before `values` enters the
                policy objective, but not for the value loss.
              returns: target returns.
              dones: forwarded to `rl_layers.A2CObjective`.
              rewards: forwarded to `rl_layers.A2CObjective`.
              actions: taken actions; must agree with `returns` on the first
                two dimensions.
              old_log_probs: unused.
              mask: rank-2 mask; must agree with `returns` on the first two
                dimensions.

            Returns:
              The combined loss, a scalar.

            Raises:
              AssertionError: if a shape check fails or if any loss term is
                not a scalar.
            """
            del old_log_probs

            # Typically we have dist_inputs of the shape float32[128,1,18].
            assert len(dist_inputs.shape) == 3, (
                f'dist_inputs.shape was {dist_inputs.shape} '
                f'but expected length of the tensor shape is 3')
            # values of the shape float32[128,1,1]
            # returns of the shape float32[128,1,1]
            assert values.shape == returns.shape, (
                f'values.shape was {values.shape}; '
                f'returns.shape was {returns.shape}')
            # actions of the shape int32[128,1] in the case of discrete actions
            # and float32[128,1,6] in the case of half-cheetah;
            # actions agree with returns/values on the first two coordinates.
            assert actions.shape[0:2] == returns.shape[0:2], (
                f'actions.shape was {actions.shape}; '
                f'returns.shape was {returns.shape}')
            # mask of the shape float32[128,1] ...
            assert len(mask.shape) == 2, f'mask.shape was {mask.shape}'
            # ... which agrees with returns/values/actions on the first two
            # coordinates.
            assert mask.shape[0:2] == returns.shape[0:2], (
                f'mask.shape was {mask.shape}; '
                f'returns.shape was {returns.shape}')

            # stop_gradient keeps the policy objective from pushing gradients
            # into the value predictions.
            a2c_objective = rl_layers.A2CObjective(
                dist_inputs,
                stop_gradient(values),
                returns,
                dones,
                rewards,
                actions,
                mask,
                log_prob_fun=self._policy_dist.log_prob,
                normalize_advantages=self._normalize_advantages)
            # We insist that every loss term is a scalar.
            assert jnp.ndim(a2c_objective) == 0, (
                f'a2c_objective was {a2c_objective}')

            entropy_loss = rl_layers.EntropyLoss(
                dist_inputs,
                distribution=self._policy_dist,
                coeff=self._entropy_coeff,
            )
            assert jnp.ndim(entropy_loss) == 0, (
                f'entropy_loss was {entropy_loss}')

            l2_value_loss = rl_layers.ValueLoss(
                values, returns, value_loss_coeff=self._value_loss_coeff)
            assert jnp.ndim(l2_value_loss) == 0, (
                f'l2_value_loss was {l2_value_loss}')

            return a2c_objective + l2_value_loss - entropy_loss
Example #2
0
 def f(dist_inputs, values, returns, actions, old_log_probs, mask):
   """Return the batch mean of the A2C objective.

   `old_log_probs` is accepted but not used.
   """
   del old_log_probs  # Unused here.
   policy_log_prob = self._policy_dist.log_prob
   objective = rl_layers.A2CObjective(
       dist_inputs,
       values,
       returns,
       actions,
       mask,
       log_prob_fun=policy_log_prob,
       normalize_advantages=self._normalize_advantages,
   )
   mean_objective = jnp.mean(objective)
   return mean_objective
Example #3
0
 def a2c_objective(self):
   """Build a layer computing the A2C objective with this object's settings.

   Returns:
     A `tl.Fn` layer named 'A2CObjective' with one output. It wraps a
     function of (dist_inputs, values, returns, actions, old_log_probs,
     mask) that ignores `old_log_probs` and uses this object's policy
     distribution and advantage-normalization flag.
   """
   def _objective(dist_inputs, values, returns, actions, old_log_probs, mask):
     # old_log_probs is part of the signature but not used by A2C.
     return rl_layers.A2CObjective(
         dist_inputs,
         values,
         returns,
         actions,
         mask,
         log_prob_fun=self._policy_dist.log_prob,
         normalize_advantages=self._normalize_advantages,
     )

   return tl.Fn('A2CObjective', _objective, n_out=1)
Example #4
0
        def f(dist_inputs, values, returns, actions, old_log_probs, mask):
            """Combined A2C loss: mean policy objective + value loss - entropy.

            Args:
              dist_inputs: inputs to the policy distribution.
              values: value predictions; gradients are stopped before they
                enter the policy objective, but not for the value loss.
              returns: target returns.
              actions: actions that were taken.
              old_log_probs: unused.
              mask: mask forwarded to the policy objective.

            Returns:
              `mean(a2c_objective) + l2_value_loss - entropy_loss`.
            """
            del old_log_probs

            # Policy term; the detached values keep this term's gradients
            # out of the value predictions.
            detached_values = stop_gradient(values)
            objective = rl_layers.A2CObjective(
                dist_inputs, detached_values, returns, actions, mask,
                log_prob_fun=self._policy_dist.log_prob,
                normalize_advantages=self._normalize_advantages)

            # Entropy term, parameterized by the configured entropy_coeff.
            entropy_term = rl_layers.EntropyLoss(
                dist_inputs, actions,
                log_prob_fun=self._policy_dist.log_prob,
                entropy_coeff=self._entropy_coeff,
                entropy_fun=self._policy_dist.entropy)

            # Value term; here `values` keeps its gradient.
            value_term = rl_layers.ValueLoss(
                values, returns, value_loss_coeff=self._value_loss_coeff)

            mean_objective = objective.mean()
            return mean_objective + value_term - entropy_term