Example #1
 def forward(self, x):
   rng = self.rng
   batch_size, length = x.shape[0], x.shape[1]
   max_pos = min(self._bases)**self._n_digits
   rng1, rng2, rng3 = fastmath.random.split(rng, 3)
   assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length, max_pos)
   positions = jnp.arange(0, length)[None, :]
   if self._mode == 'train':
     # In 1% of training cases still start from 0 to be exactly as in eval.
     start_from_nonzero = fastmath.random.randint(
         rng1, (batch_size,), 0, self._start_from_zero_one_in)
     start_from_nonzero = jnp.minimum(1, start_from_nonzero)
     random_start = fastmath.random.randint(
         rng2, (batch_size,), 0, max_pos-length)
     random_start *= start_from_nonzero
     positions += random_start[:, None]
   res = []
   for bn, base in enumerate(self._bases):
     pos_embeddings = []
     cur_positions = positions
     for i in range(self._n_digits):
       cur_indices = jnp.mod(cur_positions, base)  # digit i of each position
       cur_positions = cur_positions // base
       s = self.weights[bn][i]  # learned embedding for digit slot i, this base
       pos_embeddings.append(cur_indices.astype(jnp.float32)[:, :, None] * s)
     embeddings = jnp.concatenate(pos_embeddings, axis=-1)
     if self._mode == 'train':
       base_dropout = fastmath.random.randint(
           rng3, (batch_size,), 0, self._base_dropout_one_in)
       base_dropout = jnp.minimum(1, base_dropout).astype(jnp.float32)
       embeddings *= base_dropout[:, None, None]
     res.append(embeddings)
   res = sum(res) + jnp.zeros_like(x)  # sum over bases; broadcast to x's shape
   return x + res
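
A quick standalone sketch of the digit decomposition the inner loop performs. The bases and position below are made-up values for illustration only:

bases, n_digits = (11, 13), 2  # hypothetical settings
position = 30

for base in bases:
    digits, p = [], position
    for _ in range(n_digits):
        digits.append(p % base)  # least-significant digit first
        p //= base
    # 30 decomposes to digits (8, 2) in base 11 and (4, 2) in base 13;
    # each digit scales the learned embedding for that digit slot.
    print(base, digits)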
Example #2
 def forward(self, x):
     rng = self.rng
     base_weights, start_vec = self.weights
     batch_size, length = x.shape[0], x.shape[1]
     max_pos = min(self._bases)**self._n_digits
     rng1, rng2, rng3 = fastmath.random.split(rng, 3)
     assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length,
                                                               max_pos)
     positions = jnp.arange(0, length)[None, :]
     # In training we randomize starts for better generalization.
     # The trainable start_vec compensates, giving the model a way to
     # learn where a sequence starts.
     if self._mode == 'train':
         # In 1% of training cases still start from 0 to be exactly as in eval.
         start_from_nonzero = fastmath.random.randint(
             rng1, (batch_size, ), 0, self._start_from_zero_one_in)
         start_from_nonzero = jnp.minimum(1, start_from_nonzero)
         random_start = fastmath.random.randint(rng2, (batch_size, ), 0,
                                                max_pos - length)
         random_start *= start_from_nonzero
         positions += random_start[:, None]
     if self._mode == 'predict':
         positions += self.state
     res = []
     for bn, base in enumerate(self._bases):
         pos_embeddings = []
         cur_positions = positions
         for i in range(self._n_digits):
             cur_indices = jnp.mod(cur_positions, base)
             cur_positions = cur_positions // base
             s = base_weights[bn][i]
             pos_embeddings.append(
                 cur_indices.astype(jnp.float32)[:, :, None] * s)
         embeddings = jnp.concatenate(pos_embeddings, axis=-1)
         if self._mode == 'train':
             base_dropout = fastmath.random.randint(
                 rng3, (batch_size, ), 0, self._base_dropout_one_in)
             base_dropout = jnp.minimum(1, base_dropout).astype(jnp.float32)
             embeddings *= base_dropout[:, None, None]
         res.append(embeddings)
     res = sum(res)  # Sum embeddings from all bases.
     # Add start_vec to the first position only to mark it as starting.
     res0 = res[:, 0, :][:, None, :]
     start_pos = res0 + start_vec
     if self._mode == 'predict':
         start_pos = jnp.where(jnp.equal(self.state, 0), start_pos, res0)
         self.state += length  # Add input length to state.
     res = jnp.concatenate([start_pos, res[:, 1:, :]], axis=1)
     return x + res
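
A standalone sketch of the predict-mode gating above: start_vec is added only on the very first call (state == 0), and later calls leave the embeddings unmarked. Shapes and values here are toy assumptions:

import jax.numpy as jnp

state = jnp.array(0)                # decoding position counter
res0 = jnp.ones((2, 1, 4))          # embedding of the first position
start_vec = 0.5 * jnp.ones((1, 4))  # trainable start marker (toy values)

start_pos = jnp.where(jnp.equal(state, 0), res0 + start_vec, res0)
state += 1  # subsequent calls see state > 0 and skip the marker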
Example #3
 def f(log_probs, advantages, old_log_probs, mask):
     if reweight:  # Use new policy weights for sampled actions instead.
         mask *= jnp.exp(fastmath.stop_gradient(log_probs) - old_log_probs)
     if sampled_all_discrete:  # Actions were sampled uniformly; weight them.
         mask *= jnp.exp(old_log_probs)
     weights = jnp.minimum(awr_weights(advantages, beta), w_max)
     return -jnp.sum(log_probs * weights * mask) / jnp.sum(mask)
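
awr_weights is defined elsewhere; in advantage-weighted regression it is typically the exponentiated, temperature-scaled advantage. A minimal sketch under that assumption:

import jax.numpy as jnp

def awr_weights(advantages, beta):
    # Assumed AWR weighting: exp(advantage / temperature). The actual
    # helper may differ; this is only the textbook form.
    return jnp.exp(advantages / beta)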
Example #4
 def learning_rate(step):
     """Step to learning rate function."""
     ret = 1.0
     for name in factors:
         if name == 'constant':
             ret *= constant
         elif name == 'linear_warmup':
             ret *= jnp.minimum(1.0, step / warmup_steps)
         elif name == 'rsqrt_decay':
             ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
         elif name == 'rsqrt_normalized_decay':
             ret *= jnp.sqrt(warmup_steps)
             ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
         elif name == 'decay_every':
             ret *= (decay_factor**(step // steps_per_decay))
         elif name == 'cosine_decay':
             progress = jnp.maximum(0.0, (step - warmup_steps) /
                                    float(steps_per_cycle))
             ret *= (0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))
         else:
             raise ValueError('Unknown factor %s.' % name)
     # TODO(henrykm): return float(jnp.maximum(minimum, ret)) would be
     # cleaner; note that jnp.max is a reduction whose second argument is
     # an axis, so jnp.max(minimum, ret) raises TypeError: 'numpy.float64'
     # object cannot be interpreted as an integer.
     if ret <= minimum:
         return minimum
     return ret
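
The closure variables (factors, constant, warmup_steps, minimum, ...) come from the enclosing factory, which is not shown. A sketch of driving the same factor logic with made-up values:

import jax.numpy as jnp

factors = ['constant', 'linear_warmup', 'rsqrt_decay']  # assumed schedule
constant, warmup_steps, minimum = 0.1, 1000, 1e-5       # assumed values

for step in [1, 500, 1000, 4000]:
    rate = constant                                     # 'constant'
    rate *= jnp.minimum(1.0, step / warmup_steps)       # 'linear_warmup'
    rate /= jnp.sqrt(jnp.maximum(step, warmup_steps))   # 'rsqrt_decay'
    print(step, max(float(rate), minimum))  # clamp as in the code above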
Example #5
def PPOObjective(dist_inputs, values, returns, dones, rewards, actions,
                 old_log_probs, log_prob_fun, epsilon, normalize_advantages):
    """PPO Objective."""
    # dist_inputs of the shape float32[128,1,18]
    # values of the shape float32[128,1,1]
    # returns of the shape float32[128,1,1]
    # dones of the shape float32[128,1,1]
    # rewards of the shape int32[128,1,1]
    # actions of the shape int32[128,1]
    # and old_log_probs of the shape float32[128,1]
    returns = returns.squeeze(axis=2)
    values = values.squeeze(axis=2)
    dones = dones.squeeze(axis=2)
    rewards = rewards.squeeze(axis=2)
    assert rewards.shape == dones.shape, (
        f'rewards.shape was {rewards.shape} and dones.shape was {dones.shape}')
    assert dones.shape == values.shape, (
        f'dones.shape was {dones.shape} and values.shape was {values.shape}')
    assert returns.shape == values.shape, (
        f'returns.shape was {returns.shape} and values.shape was {values.shape}'
    )
    assert returns.shape == old_log_probs.shape, (
        f'returns.shape was {returns.shape} and '
        f'old_log_probs.shape was {old_log_probs.shape}')

    probs_ratio = ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun)
    assert probs_ratio.shape == old_log_probs.shape, (
        f'probs_ratio.shape was {probs_ratio.shape} and '
        f'old_log_probs.shape was {old_log_probs.shape}')

    # jaxified versions of
    # returns[dones] = rewards[dones]
    # values[dones] = 0
    returns = jnp.where(dones, rewards, returns)
    values = jnp.where(dones, jnp.zeros_like(values), values)
    advantages = returns - values
    if normalize_advantages:
        advantages = advantages - jnp.mean(advantages)
        advantages /= jnp.std(advantages) + 1e-8
    assert old_log_probs.shape == advantages.shape, (
        f'old_log_probs.shape was {old_log_probs.shape} and advantages.shape was '
        f'{advantages.shape}')

    unclipped_objective = UnclippedObjective(probs_ratio, advantages)
    assert unclipped_objective.shape == advantages.shape, (
        f'unclipped_objective.shape was {unclipped_objective.shape} and '
        f'advantages.shape was {advantages.shape}')

    clipped_objective = ClippedObjective(probs_ratio, advantages, epsilon)
    assert clipped_objective.shape == advantages.shape, (
        f'clipped_objective.shape was {clipped_objective.shape} and '
        f'advantages.shape was {advantages.shape}')

    ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
    assert ppo_objective.shape == advantages.shape, (
        f'ppo_objective.shape was {ppo_objective.shape} and '
        f'advantages.shape was {advantages.shape}')

    return ppo_objective
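
ProbsRatio, UnclippedObjective and ClippedObjective are imported from elsewhere. Sketches consistent with how they are used above (assumptions, not the library's definitions):

import jax.numpy as jnp

def ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun):
    # pi_new(a|s) / pi_old(a|s), computed in log space for stability.
    new_log_probs = log_prob_fun(dist_inputs, actions)
    return jnp.exp(new_log_probs - old_log_probs)

def UnclippedObjective(probs_ratio, advantages):
    # Plain importance-weighted policy-gradient surrogate.
    return probs_ratio * advantages

def ClippedObjective(probs_ratio, advantages, epsilon):
    # PPO's clipped surrogate: bound the ratio to [1 - eps, 1 + eps].
    return jnp.clip(probs_ratio, 1 - epsilon, 1 + epsilon) * advantages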
Example #6
 def f(values, weights):  # pylint: disable=invalid-name
     # This function assumes weights are 0 or 1.
     # Compute 1 for not-correct positions, 0 for correct or masked ones.
     not_correct = (1.0 - values) * weights
     axis_to_sum = list(range(1, len(not_correct.shape)))
     # Summing not-correct on all axes but batch. We're summing 0s and 1s,
     # so the sum is 0 if it's all 0 and >=1 in all other cases.
     not_correct_seq = jnp.sum(not_correct, axis=axis_to_sum)
     # A sequence is correct iff not_correct_seq is 0; invert that here.
     correct_seq = 1.0 - jnp.minimum(1.0, not_correct_seq)
     return jnp.mean(correct_seq)  # Mean over batch.
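
A worked toy case (two sequences of three tokens; weights mask padding):

import jax.numpy as jnp

values = jnp.array([[1.0, 1.0, 1.0],    # all correct -> sequence correct
                    [1.0, 0.0, 1.0]])   # one wrong   -> sequence incorrect
weights = jnp.array([[1.0, 1.0, 0.0],   # last token is padding
                     [1.0, 1.0, 1.0]])

not_correct = (1.0 - values) * weights            # [[0,0,0],[0,1,0]]
not_correct_seq = jnp.sum(not_correct, axis=1)    # [0., 1.]
correct_seq = 1.0 - jnp.minimum(1.0, not_correct_seq)
print(jnp.mean(correct_seq))  # 0.5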
Example #7
def HardTanh():
  r"""Returns a layer that computes a linear approximation to `Tanh`.

  .. math::
      f(x) = \left\{ \begin{array}{cl}
          -1 & \text{if}\ x \leq -1, \\
          x  & \text{if}\ -1 < x < 1, \\
          1  & \text{otherwise}.
      \end{array} \right.
  """
  return Fn('HardTanh', lambda x: jnp.maximum(-1, jnp.minimum(1, x)))
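
Usage sketch, assuming the Trax-style Fn wrapper above returns a directly callable layer:

import jax.numpy as jnp

layer = HardTanh()
x = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])
print(layer(x))  # [-1.  -0.5  0.   0.5  1. ]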
Example #8
def HardSigmoid():
  r"""Returns a layer that computes a linear approximation to `Sigmoid`.

  .. math::
      f(x) = \left\{ \begin{array}{cl}
          0 & \text{if}\ x \leq -1, \\
          x + 1 & \text{if}\ -1 < x < 0, \\
          1 & \text{otherwise}.
      \end{array} \right.
  """
  return Fn('HardSigmoid', lambda x: jnp.maximum(0, jnp.minimum(1, (1 + x))))
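
Note that the lambda computes clip(x + 1, 0, 1), so the kinks sit at -1 and 0, matching the formula above. A quick check under the same callable-layer assumption as before:

import jax.numpy as jnp

layer = HardSigmoid()
x = jnp.array([-2.0, -0.5, 0.0, 2.0])
print(layer(x))  # [0.   0.5  1.   1. ]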
Example #9
    def f(new_log_probs, advantages, old_log_probs, mask):
      # new_log_probs of the shape float32[128,1]
      # advantages of the shape int32[128,1]
      # old_log_probs of the shape int32[128,1]
      # mask of the shape int32[128,1]
      if new_log_probs.shape != advantages.shape:
        raise ValueError('New log-probs and advantages shapes '
                         'should be the same, %s != %s' % (new_log_probs.shape,
                                                           advantages.shape))
      if new_log_probs.shape != old_log_probs.shape:
        raise ValueError('New log-probs and old log-probs shapes '
                         'should be the same, %s != %s' % (new_log_probs.shape,
                                                           old_log_probs.shape))
      if new_log_probs.shape != mask.shape:
        raise ValueError('New log-probs and mask shapes should be the same'
                         ', %s != %s' % (new_log_probs.shape, mask.shape))

      # The ratio between new_probs and old_probs expressed
      # using log_probs and exponentiation.
      probs_ratio = jnp.exp(new_log_probs - old_log_probs)
      if advantages.shape != probs_ratio.shape:
        raise ValueError('Advantages and probs_ratio shapes '
                         'should be the same, %s != %s' % (advantages.shape,
                                                           probs_ratio.shape))
      unclipped_objective = probs_ratio * advantages
      clipped_objective = jnp.clip(probs_ratio,
                                   1 - self._epsilon,
                                   1 + self._epsilon) * advantages

      if unclipped_objective.shape != clipped_objective.shape:
        raise ValueError('unclipped_objective and clipped_objective shapes '
                         'should be the same, %s != %s' % (
                             unclipped_objective.shape,
                             clipped_objective.shape))

      ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)

      if ppo_objective.shape != mask.shape:
        raise ValueError('ppo_objective and mask shapes '
                         'should be the same, %s != %s' % (
                             ppo_objective.shape,
                             mask.shape))

      ppo_loss = -jnp.sum(ppo_objective * mask) / jnp.sum(mask)
      entropy_vec = self._policy_dist.entropy(
          new_log_probs) * self._entropy_coeff
      entropy_loss = jnp.mean(entropy_vec)
      combined_loss = ppo_loss - entropy_loss

      return combined_loss
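
A toy illustration of the clipping, with epsilon assumed to be 0.2: once the ratio leaves [0.8, 1.2], the clipped branch caps the objective from above:

import jax.numpy as jnp

epsilon = 0.2  # assumed value of self._epsilon
probs_ratio = jnp.array([[0.5], [1.0], [1.5]])
advantages = jnp.ones((3, 1))

unclipped = probs_ratio * advantages
clipped = jnp.clip(probs_ratio, 1 - epsilon, 1 + epsilon) * advantages
print(jnp.minimum(unclipped, clipped))  # [[0.5], [1.], [1.2]]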
Example #10
 def learning_rate(step):
     """Step to learning rate function."""
     ret = 1.0
     for name in factors:
         if name == 'constant':
             ret *= constant
         elif name == 'linear_warmup':
             ret *= jnp.minimum(1.0, step / warmup_steps)
         elif name == 'rsqrt_decay':
             ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
         elif name == 'rsqrt_normalized_decay':
             ret *= jnp.sqrt(warmup_steps)
             ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
         elif name == 'decay_every':
             ret *= (decay_factor**(step // steps_per_decay))
         elif name == 'cosine_decay':
             progress = jnp.maximum(0.0, (step - warmup_steps) /
                                    float(steps_per_cycle))
             ret *= (0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))
         else:
             raise ValueError('Unknown factor %s.' % name)
     return float(ret)
Example #11
    def forward(self, inputs):
        """Returns the input activations, with added positional information."""
        if self._mode != 'predict':
            x = inputs
            length = jnp.shape(x)[1]
            if self._mode != 'train':
                start = 0
            else:
                rng1, rng2 = fastmath.random.split(self.rng, 2)
                start = fastmath.random.randint(rng1, (), 0, self._add_offset)
                start_from_nonzero = fastmath.random.randint(
                    rng2, (), 0, self._start_from_zero_one_in)
                start_from_nonzero = jnp.minimum(1, start_from_nonzero)
                start *= start_from_nonzero
            px = self._sincos(start, length, inputs.shape[2])
            if self._dropout == 0:
                return x + px
            else:
                noise_shape = list(px.shape)
                for dim in self._dropout_broadcast_dims:
                    noise_shape[dim] = 1
                keep_prob = 1.0 - self._dropout
                keep = fastmath.random.bernoulli(self.rng, keep_prob,
                                                 tuple(noise_shape))
                multiplier = keep.astype(x.dtype) / keep_prob
                return x + px * multiplier
        else:
            if self._dropout != 0:
                raise ValueError(f'In predict mode, but dropout rate '
                                 f'({self._dropout}) is not zero.')

            # State in this class is only used for fast inference. In that case,
            # the model is called with consecutive elements position-by-position.
            # This positional encoding layer needs to store the index of the current
            # position then and increment it on each call -- that's how state is used
            # and updated below.
            pe = self._sincos(self.state, inputs.shape[1], inputs.shape[2])
            self.state += inputs.shape[1]
            return inputs + pe
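
self._sincos is not shown. A sketch of a standard Transformer sin/cos encoding with the same signature (start offset, length, feature depth), assuming that is what it computes and that d_feature is even:

import jax.numpy as jnp

def sincos(start, length, d_feature):
    # Classic sinusoidal positional encoding, offset by `start`.
    position = jnp.arange(length)[:, None] + start
    div_term = jnp.exp(
        jnp.arange(0, d_feature, 2) * -(jnp.log(10000.0) / d_feature))
    pe = jnp.concatenate([jnp.sin(position * div_term),
                          jnp.cos(position * div_term)], axis=1)
    return pe[None, :, :]  # shape (1, length, d_feature)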
Example #12
 def _minimum(self, tensor_list):
   minimum = tensor_list[0]
   for i in range(1, len(tensor_list)):
     minimum = jnp.minimum(minimum, tensor_list[i])
   return minimum
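
An equivalent one-liner with functools.reduce:

import functools
import jax.numpy as jnp

def minimum(tensor_list):
    # Elementwise minimum across a list of same-shape tensors.
    return functools.reduce(jnp.minimum, tensor_list)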
Example #13
 def f(log_probs, advantages, old_log_probs, mask):
     del old_log_probs  # Not used in AWR.
     weights = jnp.minimum(awr_weights(advantages, beta), w_max)
     return -jnp.sum(log_probs * weights * mask) / jnp.sum(mask)
Example #14
def SaturationCost(x, limit=0.9):
  return jnp.minimum(0, jnp.abs(x) - limit)
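
As written, jnp.minimum caps the cost at zero from above, so the result is never positive; a hinge-style penalty that grows once |x| exceeds the limit would use jnp.maximum instead. Evaluating the function as given:

import jax.numpy as jnp

x = jnp.array([-1.0, 0.0, 0.5, 0.95])
print(SaturationCost(x))  # [ 0.  -0.9 -0.4  0. ]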