コード例 #1
0
ファイル: actor_critic.py プロジェクト: ulrikSebastienR/trax
 def _preprocess_advantages(self, advantages):
   """Optionally standardizes the advantages.

   When `self._advantage_normalization` is falsy the input is returned
   untouched. Otherwise the global mean is subtracted and the result is
   divided by the global standard deviation plus
   `self._advantage_normalization_epsilon`.

   Args:
     advantages: array of advantages.

   Returns:
     The advantages, standardized if normalization is enabled.
   """
   if not self._advantage_normalization:
     return advantages
   centered = advantages - jnp.mean(advantages)
   # The epsilon keeps the denominator away from zero when all advantages
   # are (nearly) equal.
   spread = jnp.std(advantages) + self._advantage_normalization_epsilon
   return centered / spread
コード例 #2
0
def PPOObjective(dist_inputs, values, returns, dones, rewards, actions,
                 old_log_probs, log_prob_fun, epsilon, normalize_advantages):
    """Computes the elementwise PPO objective.

    The shapes quoted below are the example shapes noted alongside the
    original code (a batch of 128 with a single timestep).

    Args:
      dist_inputs: distribution inputs, e.g. float32[128,1,18].
      values: value estimates, e.g. float32[128,1,1].
      returns: returns, e.g. float32[128,1,1].
      dones: termination flags, e.g. float32[128,1,1]; non-zero entries select
        the terminal-step substitutions described below.
      rewards: rewards, e.g. int32[128,1,1].
      actions: actions, e.g. int32[128,1].
      old_log_probs: log-probabilities to compare against, e.g.
        float32[128,1].
      log_prob_fun: forwarded to ProbsRatio together with dist_inputs, actions
        and old_log_probs.
      epsilon: forwarded to ClippedObjective.
      normalize_advantages: if truthy, advantages are shifted to zero mean and
        divided by their standard deviation plus 1e-8.

    Returns:
      The elementwise minimum of the unclipped and clipped objectives, with
      the same shape as the advantages (and old_log_probs).

    Raises:
      AssertionError: if any of the intermediate shapes disagree.
    """
    # Drop the trailing singleton axis so that everything lines up with
    # old_log_probs.
    returns = returns.squeeze(axis=2)
    values = values.squeeze(axis=2)
    dones = dones.squeeze(axis=2)
    rewards = rewards.squeeze(axis=2)
    assert rewards.shape == dones.shape, (
        f'rewards.shape was {rewards.shape} and dones.shape was {dones.shape}')
    assert dones.shape == values.shape, (
        f'dones.shape was {dones.shape} and values.shape was {values.shape}')
    assert returns.shape == values.shape, (
        f'returns.shape was {returns.shape} and values.shape was {values.shape}'
    )
    assert returns.shape == old_log_probs.shape, (
        f'returns.shape was {returns.shape} and '
        f'old_log_probs.shape was {old_log_probs.shape}')

    probs_ratio = ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun)
    assert probs_ratio.shape == old_log_probs.shape, (
        f'probs_ratio.shape was {probs_ratio.shape} and '
        f'old_log_probs.shape was {old_log_probs.shape}')

    # jaxified versions of
    # returns[dones] = rewards[dones]
    # values[dones] = 0
    returns = jnp.where(dones, rewards, returns)
    values = jnp.where(dones, jnp.zeros_like(values), values)
    advantages = returns - values
    if normalize_advantages:
        advantages = advantages - jnp.mean(advantages)
        # The small constant guards against a zero standard deviation.
        advantages /= jnp.std(advantages) + 1e-8
    assert old_log_probs.shape == advantages.shape, (
        f'old_log_probs.shape was {old_log_probs.shape} and advantages.shape was '
        f'{advantages.shape}')

    unclipped_objective = UnclippedObjective(probs_ratio, advantages)
    assert unclipped_objective.shape == advantages.shape, (
        f'advantages.shape was {advantages.shape} and '
        f'unclipped_objective.shape was {unclipped_objective.shape}')

    clipped_objective = ClippedObjective(probs_ratio, advantages, epsilon)
    assert clipped_objective.shape == advantages.shape, (
        f'clipped_objective.shape was {clipped_objective.shape} and '
        f'advantages.shape was {advantages.shape}')

    ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
    assert ppo_objective.shape == advantages.shape, (
        f'ppo_objective.shape was {ppo_objective.shape} and '
        f'advantages.shape was {advantages.shape}')

    return ppo_objective
コード例 #3
0
ファイル: policy_tasks.py プロジェクト: ixxxxu/trax
 def calculate_weights(self, advantages):
     """Turns advantages into per-example weights for the policy log loss.

     If `self._advantage_normalization` is set, the advantages are first
     centered on their mean and then divided by their standard deviation plus
     `self._advantage_normalization_epsilon`. The (possibly normalized)
     advantages are then passed through `self._weight_fn`.

     Args:
       advantages: array of advantages.

     Returns:
       Weights with the same shape as the advantages.

     Raises:
       AssertionError: if `self._weight_fn` changes the shape.
     """
     if self._advantage_normalization:
         # Center first; the spread below is measured on the centered values.
         advantages -= jnp.mean(advantages)
         spread = jnp.std(advantages) + self._advantage_normalization_epsilon
         advantages /= spread
     result = self._weight_fn(advantages)
     assert result.shape == advantages.shape
     return result
コード例 #4
0
def A2CObjective(dist_inputs, values, returns, dones, rewards, actions, mask,
                 log_prob_fun, normalize_advantages):
    """Advantage Actor Critic (A2C) loss, averaged over the masked entries.

    The shapes quoted below are the example shapes noted alongside the
    original code (a batch of 128 with a single timestep).

    Args:
      dist_inputs: distribution inputs, e.g. float32[128,1,18].
      values: value estimates, e.g. float32[128,1,1].
      returns: returns, e.g. float32[128,1,1].
      dones: termination flags, e.g. int32[128,1,1].
      rewards: rewards; must match dones in shape after squeezing.
      actions: actions, e.g. int32[128,1].
      mask: weights of the individual entries, e.g. float32[128,1].
      log_prob_fun: forwarded to NewLogProbs together with dist_inputs and
        actions.
      normalize_advantages: if truthy, advantages are shifted to zero mean and
        divided by their standard deviation plus 1e-8.

    Returns:
      -sum(new_log_probs * advantages * mask) / sum(mask).

    Raises:
      AssertionError: if any of the intermediate shapes disagree.
    """
    # The loss multiplies (returns - values), the log-probs and the mask
    # elementwise, so the trailing singleton axis is removed up front to give
    # all of them the same rank.
    values, returns, dones, rewards = (
        x.squeeze(axis=2) for x in (values, returns, dones, rewards))

    assert rewards.shape == dones.shape, (
        f'rewards.shape was {rewards.shape} '
        f'and dones.shape was {dones.shape}')
    assert dones.shape == values.shape, (
        f'dones.shape was {dones.shape} '
        f'and values.shape was {values.shape}')
    assert returns.shape == values.shape, (
        f'returns.shape was {returns.shape} '
        f'and values.shape was {values.shape}')
    assert values.shape == mask.shape, (
        f'values.shape was {values.shape} '
        f'and mask.shape was {mask.shape}')
    assert returns.shape[0] == dist_inputs.shape[0], (
        f'returns.shape[0] was {returns.shape[0]} '
        f'and dist_inputs.shape[0] was {dist_inputs.shape[0]}')

    new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
    assert new_log_probs.shape == mask.shape, (
        f'new_log_probs.shape was {new_log_probs.shape} '
        f'and mask.shape was {mask.shape}')

    # Functional equivalents of the in-place updates
    #   returns[dones] = rewards[dones]
    #   values[dones] = 0
    terminal_returns = jnp.where(dones, rewards, returns)
    terminal_values = jnp.where(dones, jnp.zeros_like(values), values)
    advantages = terminal_returns - terminal_values
    if normalize_advantages:
        centered = advantages - jnp.mean(advantages)
        advantages = centered / (jnp.std(centered) + 1e-8)
    assert new_log_probs.shape == advantages.shape, (
        f'new_log_probs.shape was {new_log_probs.shape} '
        f'and advantages.shape was {advantages.shape}')

    # The squeezes and assertions above exist to rule out accidental
    # broadcasting such as [128,1] * [128,1,1] * [128] in this product.
    weighted_log_probs = new_log_probs * advantages * mask
    return -jnp.sum(weighted_log_probs) / jnp.sum(mask)
コード例 #5
0
ファイル: actor_critic.py プロジェクト: ulrikSebastienR/trax
 def policy_metrics(self):
   """Assembles the dict of metrics for policy training.

   The dict holds `self.policy_loss`, two layers that reduce the output of
   `self._policy_inputs_to_advantages(False)` to its mean and standard
   deviation, and whatever `awr_metrics` returns for `self._beta` with
   `self._policy_inputs_to_advantages(True)` as the preprocess layer.

   NOTE(review): the boolean passed to `_policy_inputs_to_advantages`
   presumably toggles advantage preprocessing - confirm against its
   definition.

   Returns:
     A dict mapping metric names to metric layers.
   """
   policy_loss = self.policy_loss
   # The lambdas are deliberate: they expose exactly one positional argument
   # to tl.Fn, which the bare jnp functions would not.
   advantage_mean = tl.Serial(
       self._policy_inputs_to_advantages(False),
       tl.Fn('Mean', lambda x: jnp.mean(x)),  # pylint: disable=unnecessary-lambda
   )
   advantage_std = tl.Serial(
       self._policy_inputs_to_advantages(False),
       tl.Fn('Std', lambda x: jnp.std(x)),  # pylint: disable=unnecessary-lambda
   )
   metrics = {
       'policy_loss': policy_loss,
       'advantage_mean': advantage_mean,
       'advantage_std': advantage_std,
   }
   extra_metrics = awr_metrics(
       self._beta, preprocess_layer=self._policy_inputs_to_advantages(True))
   metrics.update(extra_metrics)
   return metrics
コード例 #6
0
ファイル: actor_critic.py プロジェクト: tvjoseph/trax
    def _aggregate_values(self, values, aggregate, act_log_probs):
        """Aggregates (Q-)values and returns them as a NumPy array.

        The values are divided by a scale before aggregation and multiplied by
        the same scale afterwards, so that the aggregation adapts to the scale
        of the returns. The rescaling has no effect on the 'mean' and 'max'
        aggregations. The scale depends on `self._q_value_normalization`:
        'std' uses the standard deviation, 'abs' uses the mean absolute
        deviation from the mean (each plus 1e-5), anything else uses 1.

        Aggregation over axis 1 only happens when `self._q_value` is truthy.

        Args:
          values: array of values; when `self._q_value` is set, its first two
            dimensions must be
            (`self._value_batch_size`, `self._q_value_n_samples`).
          aggregate: one of 'max', 'softmax', 'logsumexp' or 'mean'.
          act_log_probs: action log-probabilities; only used for 'mean' when
            `self._sample_all_discrete_actions` is set.

        Returns:
          The aggregated values as a `np.ndarray`.

        Raises:
          AssertionError: if the leading shape is unexpected or `aggregate`
            is not one of the supported modes.
        """
        eps = 1e-5
        normalization = self._q_value_normalization
        if normalization == 'std':
            scale = jnp.std(values) + eps
        elif normalization == 'abs':
            deviations = jnp.abs(values - jnp.mean(values))
            scale = jnp.mean(deviations) + eps
        else:
            scale = 1
        values /= scale

        temperature = self._q_value_temperature
        if self._q_value:
            expected_leading_shape = (self._value_batch_size,
                                      self._q_value_n_samples)
            assert values.shape[:2] == expected_leading_shape
            if aggregate == 'max':
                # max_a Q(s, a)
                values = jnp.max(values, axis=1)
            elif aggregate == 'softmax':
                # sum_a (Q(s, a) * w(s, a)),
                # where w(s, .) = softmax(Q(s, .) / T)
                softmax_weights = tl.Softmax(axis=1)(values / temperature)
                values = jnp.sum(values * softmax_weights, axis=1)
            elif aggregate == 'logsumexp':
                # log(mean_a exp(Q(s, a) / T)) * T
                n_samples = values.shape[1]
                log_sum = fastmath.logsumexp(values / temperature, axis=1)
                values = (log_sum - jnp.log(n_samples)) * temperature
            else:
                assert aggregate == 'mean'
                # mean_a Q(s, a)
                if self._sample_all_discrete_actions:
                    # Weight each action's value by its probability.
                    action_probs = jnp.exp(act_log_probs)
                    values = jnp.sum(values * action_probs, axis=1)
                else:
                    values = jnp.mean(values, axis=1)

        # Undo the normalization applied before the aggregation.
        values *= scale
        return np.array(values)  # Move the values to CPU.
コード例 #7
0
ファイル: actor_critic.py プロジェクト: tvjoseph/trax
 def advantage_std(self):
     """Returns a layer computing the standard deviation of the advantages.

     The layer's inputs are (dist_inputs, advantages, old_dist_inputs, mask);
     only the advantages (index 1) are kept and reduced with `jnp.std`.
     """
     keep_advantages = tl.Select([1])
     # The lambda is deliberate: it exposes exactly one positional argument
     # to tl.Fn.
     reduce_to_std = tl.Fn('AdvantageStd', lambda x: jnp.std(x))  # pylint: disable=unnecessary-lambda
     return tl.Serial([keep_advantages, reduce_to_std])