Example #1
 def forward_with_state(self, x, weights, state, rng):
     batch_size, length = x.shape[0], x.shape[1]
     max_pos = min(self._bases)**self._n_digits
     rng1, rng2, rng3 = math.random.split(rng, 3)
     assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length,
                                                               max_pos)
     positions = jnp.arange(0, length)[None, :]
     if self._mode == 'train':
         # In 1% of training cases still start from 0 to be exactly as in eval.
         start_from_nonzero = jax.random.randint(
             rng1, (batch_size, ), 0, self._start_from_zero_one_in)
         start_from_nonzero = jnp.minimum(1, start_from_nonzero)
         random_start = jax.random.randint(rng2, (batch_size, ), 0,
                                           max_pos - length)
         random_start *= start_from_nonzero
         positions += random_start[:, None]
     res = []
     for bn, base in enumerate(self._bases):
         pos_embeddings = []
         cur_positions = positions
         for i in range(self._n_digits):
             cur_indices = jnp.mod(cur_positions, base)
             cur_positions = cur_positions // base
             s = weights[bn][i]
             pos_embeddings.append(
                 cur_indices.astype(jnp.float32)[:, :, None] * s)
         embeddings = jnp.concatenate(pos_embeddings, axis=-1)
         if self._mode == 'train':
             base_dropout = jax.random.randint(rng3, (batch_size, ), 0,
                                               self._base_dropout_one_in)
             base_dropout = jnp.minimum(1, base_dropout).astype(jnp.float32)
             embeddings *= base_dropout[:, None, None]
         res.append(embeddings)
     res = sum(res) + jnp.zeros_like(x)
     return jnp.concatenate([x, res], axis=-1), state
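To make the digit decomposition above concrete, here is a minimal standalone sketch (plain jax.numpy, illustrative names; it is not part of the layer): each position is rewritten as digits of one base, least-significant digit first, and in the layer each digit then scales its own learned embedding row.

import jax.numpy as jnp

positions = jnp.arange(0, 6)          # a toy sequence of length 6
base, n_digits = 3, 2                 # e.g. one entry of self._bases
cur_positions = positions
digits = []
for _ in range(n_digits):
    digits.append(jnp.mod(cur_positions, base))   # least-significant digit first
    cur_positions = cur_positions // base
# digits[0] == [0, 1, 2, 0, 1, 2] and digits[1] == [0, 0, 0, 1, 1, 1],
# i.e. positions 0..5 written in base 3.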
Example #2
        def PPOJointLoss(x, **unused_kwargs):
            """Definition of the Proximal Policy Optimization loss."""
            dist_inputs, values, returns, actions, old_log_probs, mask = x
            del mask  # TODO(lukaszkaiser): make PPO work with Transformer
            new_log_probs = self._policy_dist.log_prob(dist_inputs, actions)

            advantages = returns - values
            l2_value_loss = jnp.sum(advantages**2) * self._value_loss_coeff

            # Old log probs have an undesirable extra dimension which we remove here
            old_log_probs = jnp.array(old_log_probs.squeeze(axis=-1),
                                      dtype=jnp.float32)
            new_log_probs = jnp.array(new_log_probs.squeeze(axis=-1))

            # The ratio between new_probs and old_probs expressed
            # using log_probs and exponentiation
            probs_ratio = jnp.exp(new_log_probs - old_log_probs)
            unclipped_objective = probs_ratio * advantages
            clipped_objective = jnp.clip(probs_ratio, 1 - self._epsilon,
                                         1 + self._epsilon) * advantages
            ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)

            entropy_loss = self._policy_dist.entropy(new_log_probs) *\
                self._entropy_coeff

            return -ppo_objective.mean() + l2_value_loss - entropy_loss
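As a quick illustration of the clipped-ratio term computed above, here is a self-contained sketch with made-up numbers (jax.numpy only; epsilon and the probabilities are chosen for readability, not taken from the source):

import jax.numpy as jnp

epsilon = 0.2
old_log_probs = jnp.log(jnp.array([0.5, 0.5, 0.5]))
new_log_probs = jnp.log(jnp.array([0.9, 0.5, 0.1]))
advantages = jnp.array([1.0, 1.0, -1.0])

probs_ratio = jnp.exp(new_log_probs - old_log_probs)              # ~[1.8, 1.0, 0.2]
unclipped_objective = probs_ratio * advantages                    # ~[1.8, 1.0, -0.2]
clipped_objective = jnp.clip(probs_ratio, 1 - epsilon,
                             1 + epsilon) * advantages            # ~[1.2, 1.0, -0.8]
ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
# ~[1.2, 1.0, -0.8]: ratios outside [0.8, 1.2] stop improving the objective,
# which limits how far a single update can move away from the old policy.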
Example #3
 def f(log_probs, advantages, old_log_probs, mask):
     if reweight:  # Use new policy weights for sampled actions instead.
         mask *= jnp.exp(math.stop_gradient(log_probs) - old_log_probs)
     if sampled_all_discrete:  # Actions were sampled uniformly; weight them.
         mask *= jnp.exp(old_log_probs)
     weights = jnp.minimum(awr_weights(advantages, beta), w_max)
     return -jnp.sum(log_probs * weights * mask) / jnp.sum(mask)
Example #4
def PPOObjective(dist_inputs, values, returns, dones, rewards, actions,
                 old_log_probs, log_prob_fun, epsilon, normalize_advantages):
    """PPO Objective."""
    # dist_inputs of the shape float32[128,1,18]
    # values of the shape float32[128,1,1]
    # returns of the shape float32[128,1,1]
    # dones of the shape float32[128,1,1]
    # rewards of the shape int32[128,1,1]
    # actions of the shape int32[128,1]
    # and old_log_probs of the shape float32[128,1]
    returns = returns.squeeze(axis=2)
    values = values.squeeze(axis=2)
    dones = dones.squeeze(axis=2)
    rewards = rewards.squeeze(axis=2)
    assert rewards.shape == dones.shape, (
        f'rewards.shape was {rewards.shape} and dones.shape was {dones.shape}')
    assert dones.shape == values.shape, (
        f'dones.shape was {dones.shape} and values.shape was {values.shape}')
    assert returns.shape == values.shape, (
        f'returns.shape was {returns.shape} and values.shape was {values.shape}'
    )
    assert returns.shape == old_log_probs.shape, (
        f'returns.shape was {returns.shape} and '
        f'old_log_probs.shape was {old_log_probs.shape}')

    probs_ratio = ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun)
    assert probs_ratio.shape == old_log_probs.shape, (
        f'probs_ratio.shape was {probs_ratio.shape} and '
        f'old_log_probs.shape was {old_log_probs.shape}')

    # jaxified versions of
    # returns[dones] = rewards[dones]
    # values[dones] = 0
    returns = jnp.where(dones, rewards, returns)
    values = jnp.where(dones, jnp.zeros_like(values), values)
    advantages = returns - values
    if normalize_advantages:
        advantages = advantages - jnp.mean(advantages)
        advantages /= jnp.std(advantages) + 1e-8
    assert old_log_probs.shape == advantages.shape, (
        f'old_log_probs.shape was {old_log_probs.shape} and advantages.shape was '
        f'{advantages.shape}')

    unclipped_objective = UnclippedObjective(probs_ratio, advantages)
    assert unclipped_objective.shape == advantages.shape, (
        f'unclipped_objective.shape was {unclipped_objective.shape} and '
        f'advantages.shape was {advantages.shape}')

    clipped_objective = ClippedObjective(probs_ratio, advantages, epsilon)
    assert clipped_objective.shape == advantages.shape, (
        f'clipped_objective.shape was {clipped_objective.shape} and '
        f'advantages.shape was {advantages.shape}')

    ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
    assert ppo_objective.shape == advantages.shape, (
        f'ppo_objective.shape was {ppo_objective.shape} and '
        f'advantages.shape was {advantages.shape}')

    return ppo_objective
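The jnp.where calls above replace the in-place indexing (returns[dones] = rewards[dones]) that is not available under jit. A minimal sketch with toy arrays (jax.numpy only, numbers made up): where an episode ended, the return is replaced by the final reward and the bootstrap value is zeroed out.

import jax.numpy as jnp

dones = jnp.array([[False, True, False]])
rewards = jnp.array([[0.1, 1.0, 0.2]])
returns = jnp.array([[2.0, 3.0, 4.0]])
values = jnp.array([[1.5, 2.5, 3.5]])

returns = jnp.where(dones, rewards, returns)               # [[2.0, 1.0, 4.0]]
values = jnp.where(dones, jnp.zeros_like(values), values)  # [[1.5, 0.0, 3.5]]
advantages = returns - values                              # [[0.5, 1.0, 0.5]]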
Example #5
    def _calc_adv_weights(self, adv, valid_mask):
        weights = jnp.exp(adv / self._temperature)

        valid_weights = weights[valid_mask]
        weights_mean = jnp.mean(valid_weights)
        weights_min = jnp.min(valid_weights)
        weights_max = jnp.max(valid_weights)

        weights = jnp.minimum(weights, self._weight_clip)
        return weights, weights_mean, weights_min, weights_max
Example #6
 def f(values, weights):  # pylint: disable=invalid-name
     # This function assumes weights are 0 or 1.
     # Then compute 1: not-correct, 0: correct or masked
     not_correct = (1.0 - values) * weights
     axis_to_sum = list(range(1, len(not_correct.shape)))
     # Summing not-correct on all axes but batch. We're summing 0s and 1s,
     # so the sum is 0 if it's all 0 and >=1 in all other cases.
     not_correct_seq = np.sum(not_correct, axis=axis_to_sum)
     # Sequence is correct if not_correct_seq is 0, reverting here.
     correct_seq = 1.0 - np.minimum(1.0, not_correct_seq)
     return np.mean(correct_seq)  # Mean over batch.
Example #7
def _WeightedSequenceMean(inputs, **unused_kwargs):
  """Returns a layer to compute weighted seqeunce accuracy mean."""
  values, weights = inputs  # This function assumes weights are 0 or 1.
  not_correct = (1.0 - values) * weights  # 1: not-correct, 0: correct or masked
  axis_to_sum = list(range(1, len(not_correct.shape)))
  # Summing not-correct on all axes but batch. We're summing 0s and 1s,
  # so the sum is 0 if it's all 0 and >=1 in all other cases.
  not_correct_seq = np.sum(not_correct, axis=axis_to_sum)
  # Sequence is correct if not_correct_seq is 0, reverting here.
  correct_seq = 1.0 - np.minimum(1.0, not_correct_seq)
  return np.mean(correct_seq)  # Mean over batch.
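A toy run of the same computation (plain numpy, made-up values and weights) shows that a sequence only counts as correct when every unmasked position is correct:

import numpy as np

values = np.array([[1.0, 1.0, 1.0],      # all positions correct
                   [1.0, 0.0, 1.0]])     # one position wrong
weights = np.array([[1.0, 1.0, 0.0],     # last position masked out
                    [1.0, 1.0, 1.0]])

not_correct = (1.0 - values) * weights
not_correct_seq = np.sum(not_correct, axis=1)          # [0.0, 1.0]
correct_seq = 1.0 - np.minimum(1.0, not_correct_seq)   # [1.0, 0.0]
print(np.mean(correct_seq))                            # 0.5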
Example #8
def HardTanh():
  r"""Returns a layer that computes a linear approximation to `Tanh`.

  .. math::
      f(x) = \left\{ \begin{array}{cl}
          -1 & \text{if}\ x \leq -1, \\
          x  & \text{if}\ -1 < x < 1, \\
          1  & \text{otherwise}.
      \end{array} \right.
  """
  return Fn('HardTanh', lambda x: np.maximum(-1, np.minimum(1, x)))
Example #9
def HardSigmoid():
  r"""Returns a layer that computes a linear approximation to `Sigmoid`.

  .. math::
      f(x) = \left\{ \begin{array}{cl}
          0 & \text{if}\ x \leq -1, \\
          x + 1 & \text{if}\ -1 < x < 0, \\
          1 & \text{otherwise}.
      \end{array} \right.
  """
  return Fn('HardSigmoid', lambda x: np.maximum(0, np.minimum(1, (1 + x))))
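A quick numeric check of the two piecewise-linear activations above (plain numpy, sample inputs chosen for illustration):

import numpy as np

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
hard_tanh = np.maximum(-1, np.minimum(1, x))         # [-1. , -0.5,  0. ,  0.5,  1. ]
hard_sigmoid = np.maximum(0, np.minimum(1, 1 + x))   # [ 0. ,  0.5,  1. ,  1. ,  1. ]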
Example #10
 def AWRLoss(x, **unused_kwargs):  # pylint: disable=invalid-name
     logps, values, returns, actions = x
     advantage = returns - values
     l2_value_loss = jnp.sum(
         (returns - values)**2) * self._value_loss_coeff
     awr_weights = jnp.minimum(jnp.exp(advantage / self._beta),
                               self._w_max)
     log_loss = -1.0 * self._policy_dist.log_prob(logps, actions)
     policy_loss = jnp.sum(
         log_loss * awr_weights) / jnp.sum(awr_weights)
     return policy_loss + l2_value_loss
Example #11
    def f(new_log_probs, advantages, old_log_probs, mask):
        # Old log probs have an undesirable extra dimension which we remove here
        old_log_probs = old_log_probs.squeeze(axis=-1)

        # The ratio between new_probs and old_probs expressed
        # using log_probs and exponentiation
        probs_ratio = jnp.exp(new_log_probs - old_log_probs)
        unclipped_objective = probs_ratio * advantages
        clipped_objective = jnp.clip(probs_ratio, 1 - epsilon,
                                     1 + epsilon) * advantages
        ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
        return -np.sum(ppo_objective * mask) / np.sum(mask)
Example #12
def PPOLoss(x, epsilon, **unused_kwargs):
    """Definition of the Proximal Policy Optimization loss."""
    (new_log_probs, advantages, old_log_probs, mask) = x
    # Old log probs have an undesirable extra dimension which we remove here
    old_log_probs = old_log_probs.squeeze(axis=-1)

    # The ratio between new_probs and old_probs expressed
    # using log_probs and exponentiation
    probs_ratio = jnp.exp(new_log_probs - old_log_probs)
    unclipped_objective = probs_ratio * advantages
    clipped_objective = jnp.clip(probs_ratio, 1 - epsilon,
                                 1 + epsilon) * advantages
    ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
    return -np.sum(ppo_objective * mask) / np.sum(mask)
Example #13
    def f(new_log_probs, advantages, old_log_probs, mask):
      # new_log_probs of the shape float32[128,1]
      # advantages of the shape int32[128,1]
      # old_log_probs of the shape int32[128,1]
      # mask of the shape int32[128,1]
      if new_log_probs.shape != advantages.shape:
        raise ValueError('New log-probs and advantages shapes '
                         'should be the same, %s != %s' % (new_log_probs.shape,
                                                           advantages.shape))
      if new_log_probs.shape != old_log_probs.shape:
        raise ValueError('New log-probs and old log-probs shapes '
                         'should be the same, %s != %s' % (new_log_probs.shape,
                                                           old_log_probs.shape))
      if new_log_probs.shape != mask.shape:
        raise ValueError('New log-probs and mask shapes should be the same'
                         ', %s != %s' % (new_log_probs.shape, mask.shape))

      # The ratio between new_probs and old_probs expressed
      # using log_probs and exponentiation
      probs_ratio = jnp.exp(new_log_probs - old_log_probs)
      if advantages.shape != probs_ratio.shape:
        raise ValueError('Advantages and probs_ratio shapes '
                         'should be the same, %s != %s' % (advantages.shape,
                                                           probs_ratio.shape))
      unclipped_objective = probs_ratio * advantages
      clipped_objective = jnp.clip(probs_ratio,
                                   1 - self._epsilon,
                                   1 + self._epsilon) * advantages

      if unclipped_objective.shape != probs_ratio.shape:
        raise ValueError('unclipped_objective and probs_ratio shapes '
                         'should be the same, %s != %s' % (
                             unclipped_objective.shape,
                             probs_ratio.shape))

      ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)

      if ppo_objective.shape != mask.shape:
        raise ValueError('ppo_objective and mask shapes '
                         'should be the same, %s != %s' % (
                             ppo_objective.shape,
                             mask.shape))

      ppo_loss = -jnp.sum(ppo_objective * mask) / jnp.sum(mask)
      entropy_vec = self._policy_dist.entropy(
          new_log_probs) * self._entropy_coeff
      entropy_loss = jnp.mean(entropy_vec)
      combined_loss = ppo_loss - entropy_loss

      return combined_loss
Example #14
def PPOObjective(dist_inputs, values, returns, actions, old_log_probs,
                 log_prob_fun, epsilon, normalize_advantages):
    """PPO Objective."""
    # Returns and values are arriving with two extra dimensions
    # TODO(henrykm): remove these dimensions at an earlier stage?
    returns = returns.squeeze()
    values = values.squeeze()
    probs_ratio = ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun)
    advantages = returns - values
    if normalize_advantages:
        advantages = advantages - jnp.mean(advantages)
        advantages /= jnp.std(advantages) + 1e-8
    unclipped_objective = UnclippedObjective(probs_ratio, advantages)
    clipped_objective = ClippedObjective(probs_ratio, advantages, epsilon)
    ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
    return ppo_objective
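The normalization branch above standardizes advantages before the clipped objective is formed; a minimal sketch with made-up numbers (jax.numpy only):

import jax.numpy as jnp

advantages = jnp.array([2.0, 0.0, -2.0, 4.0])
advantages = advantages - jnp.mean(advantages)   # [ 1., -1., -3.,  3.]
advantages /= jnp.std(advantages) + 1e-8         # ~[ 0.45, -0.45, -1.34,  1.34]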
Example #15
 def learning_rate(step):
     """Step to learning rate function."""
     ret = 1.0
     for name in factors:
         if name == 'constant':
             ret *= constant
         elif name == 'linear_warmup':
             ret *= np.minimum(1.0, step / warmup_steps)
         elif name == 'rsqrt_decay':
             ret /= np.sqrt(np.maximum(step, warmup_steps))
         elif name == 'rsqrt_normalized_decay':
             ret *= np.sqrt(warmup_steps)
             ret /= np.sqrt(np.maximum(step, warmup_steps))
         elif name == 'decay_every':
             ret *= (decay_factor**(step // steps_per_decay))
         elif name == 'cosine_decay':
             progress = np.maximum(0.0, (step - warmup_steps) /
                                   float(steps_per_cycle))
             ret *= (0.5 * (1.0 + np.cos(np.pi * (progress % 1.0))))
         else:
             raise ValueError('Unknown factor %s.' % name)
     ret = np.asarray(ret, dtype=np.float32)
     return {'learning_rate': ret}
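For intuition, here is a minimal sketch of the common 'constant * linear_warmup * rsqrt_decay' combination that this function composes factor by factor (plain numpy; the constants are illustrative, not from the source):

import numpy as np

constant, warmup_steps = 0.1, 100.0

def lr(step):
    ret = 1.0
    ret *= constant
    ret *= np.minimum(1.0, step / warmup_steps)      # linear warmup
    ret /= np.sqrt(np.maximum(step, warmup_steps))   # rsqrt decay after warmup
    return ret

print(lr(10.0), lr(100.0), lr(10000.0))   # 0.001, 0.01, 0.001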
Example #16
 def f(log_probs, advantages, old_log_probs, mask):
     if reweight:  # Use new policy weights for sampled actions instead.
         mask *= jnp.exp(math.stop_gradient(log_probs) - old_log_probs)
     weights = jnp.minimum(awr_weights(advantages, beta), w_max)
     return -jnp.sum(log_probs * weights * mask) / jnp.sum(mask)
Example #17
 def f(log_probs, advantages, old_log_probs, mask):
     del old_log_probs  # Not used in AWR.
     weights = jnp.minimum(awr_weights(advantages, beta), w_max)
     return -jnp.sum(log_probs * weights * mask) / jnp.sum(mask)
Example #18
def SaturationCost(x, limit=0.9):
  return np.minimum(0, np.abs(x) - limit)
Example #19
File: sm3.py  Project: yazinsai/trax
 def _minimum(self, tensor_list):
   minimum = tensor_list[0]
   for i in range(1, len(tensor_list)):
     minimum = np.minimum(minimum, tensor_list[i])
   return minimum
Example #20
def HardTanh():
    """Computes a linear approximation to tanh."""
    return Fn('HardTanh', lambda x: np.maximum(-1, np.minimum(1, x)))
Example #21
def HardSigmoid():
    """Computes a linear approximation to sigmoid."""
    return Fn('HardSigmoid', lambda x: np.maximum(0, np.minimum(1, (1 + x))))
Example #22
def AWRLoss(x, beta, w_max, **unused_kwargs):
    """Definition of the Advantage Weighted Regression (AWR) loss."""
    (log_probs, advantages, _) = x
    weights = jnp.minimum(jnp.exp(advantages / beta), w_max)
    return -(log_probs * weights).mean()
Example #23
def AWRLoss(x, beta, w_max, log_prob_fn, **unused_kwargs):
  """Definition of the Advantage Weighted Regression (AWR) loss."""
  (predictions, actions, advantages, _) = x
  action_log_probs = log_prob_fn(predictions, actions)
  awr_weights = jnp.minimum(jnp.exp(advantages / beta), w_max)
  return -(action_log_probs * awr_weights).mean()
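A minimal sketch of the AWR weighting used in the examples above (jax.numpy only, made-up numbers): exponentiated advantages are capped at w_max so that one very large advantage cannot dominate the weighted regression.

import jax.numpy as jnp

beta, w_max = 1.0, 20.0
advantages = jnp.array([-1.0, 0.0, 1.0, 5.0])
awr_weights = jnp.minimum(jnp.exp(advantages / beta), w_max)
# ~[0.37, 1.0, 2.72, 20.0] -- exp(5) ~ 148 is clipped down to w_max.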
Example #24
def HardTanh(x):
    """Computes a linear approximation to tanh."""
    return np.maximum(-1, np.minimum(1, x))
Example #25
 def f(log_probs, advantages, old_log_probs, mask):
     del old_log_probs  # Not used in AWR.
     weights = jnp.minimum(jnp.exp(advantages / beta), w_max)
     return -np.sum(log_probs * weights * mask) / np.sum(mask)
Example #26
def HardSigmoid(x, **unused_kwargs):
    """Linear approximation to sigmoid."""
    return np.maximum(0, np.minimum(1, (1 + x)))
Example #27
def AWRLoss(x, beta, w_max, **unused_kwargs):
    """Definition of the Advantage Weighted Regression (AWR) loss."""
    (log_probs, advantages, old_log_probs, mask) = x
    del old_log_probs  # Not used in AWR.
    weights = jnp.minimum(jnp.exp(advantages / beta), w_max)
    return -np.sum(log_probs * weights * mask) / np.sum(mask)
Example #28
def HardTanh(x, **unused_kwargs):
    """Linear approximation to tanh."""
    return np.maximum(-1, np.minimum(1, x))
Example #29
def HardSigmoid(x):
    """Computes a linear approximation to sigmoid."""
    return np.maximum(0, np.minimum(1, (1 + x)))