Example #1
 def init(self, params):
     shape = params.shape
     slots = []
     if self._factored and len(shape) >= 2:
         v_row = np.zeros(shape[:-1], dtype=np.float32)
         v_col = np.zeros(shape[:-2] + shape[-1:], dtype=np.float32)
         slots.extend([v_row, v_col])
     else:
         v = np.zeros_like(params)
         slots.append(v)
     if self._do_momentum:
         m = np.zeros_like(params)
         slots.append(m)
     return slots
Example #2
def DotProductAttention(query, key, value, mask, dropout, mode, rng):
    """Core dot product self-attention.

  Args:
    query: array of representations
    key: array of representations
    value: array of representations
    mask: attention-mask, gates attention
    dropout: float: dropout rate
    mode: 'eval' or 'train': whether to use dropout
    rng: JAX PRNGKey: subkey for disposable use

  Returns:
    Self attention for q, k, v arrays.
  """
    depth = np.shape(query)[-1]
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        if math.backend_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        # JAX's `full_like` already ties in -1e9 to dots.
        dots = np.where(mask, dots, np.full_like(dots, -1e9))
    # Softmax.
    dots = np.exp(dots - math.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = math.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = np.where(keep, dots / (1.0 - dropout), np.zeros_like(dots))
    out = np.matmul(dots, value)
    return out
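Stripped of the masking, dropout and JAX-specific workaround, the core computation above is scaled dot-product attention. A minimal standalone sketch with plain NumPy (toy shapes, illustrative only):

import numpy as np

def dot_product_attention_sketch(query, key, value):
    depth = query.shape[-1]
    # Scaled dot-product scores, shape (..., L_q, L_k).
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth)
    # Numerically stable softmax over the key axis.
    dots = np.exp(dots - dots.max(axis=-1, keepdims=True))
    weights = dots / dots.sum(axis=-1, keepdims=True)
    # Attention-weighted combination of values, shape (..., L_q, d_value).
    return np.matmul(weights, value)

q = np.ones((2, 3, 4))   # (batch, L_q, depth)
k = np.ones((2, 5, 4))   # (batch, L_k, depth)
v = np.ones((2, 5, 6))   # (batch, L_k, d_value)
print(dot_product_attention_sketch(q, k, v).shape)  # (2, 3, 6)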
Example #3
 def forward_with_state(self, x, weights, state, rng):
     batch_size, length = x.shape[0], x.shape[1]
     max_pos = min(self._bases)**self._n_digits
     rng1, rng2, rng3 = math.random.split(rng, 3)
     assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length,
                                                               max_pos)
     positions = jnp.arange(0, length)[None, :]
     if self._mode == 'train':
         # In 1% of training cases still start from 0 to be exactly as in eval.
         start_from_nonzero = jax.random.randint(
             rng1, (batch_size, ), 0, self._start_from_zero_one_in)
         start_from_nonzero = jnp.minimum(1, start_from_nonzero)
         random_start = jax.random.randint(rng2, (batch_size, ), 0,
                                           max_pos - length)
         random_start *= start_from_nonzero
         positions += random_start[:, None]
     res = []
     for bn, base in enumerate(self._bases):
         pos_embeddings = []
         cur_positions = positions
         for i in range(self._n_digits):
             cur_indices = jnp.mod(cur_positions, base)
             cur_positions = cur_positions // base
             s = weights[bn][i]
             pos_embeddings.append(
                 cur_indices.astype(jnp.float32)[:, :, None] * s)
         embeddings = jnp.concatenate(pos_embeddings, axis=-1)
         if self._mode == 'train':
             base_dropout = jax.random.randint(rng3, (batch_size, ), 0,
                                               self._base_dropout_one_in)
             base_dropout = jnp.minimum(1, base_dropout).astype(jnp.float32)
             embeddings *= base_dropout[:, None, None]
         res.append(embeddings)
     res = sum(res) + jnp.zeros_like(x)
     return jnp.concatenate([x, res], axis=-1), state
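The loop over `self._n_digits` above decomposes each position into its digits in a given base (least-significant digit first) and embeds each digit separately. A self-contained sketch of just the decomposition, with arbitrary example base and positions:

import numpy as np

def position_digits(positions, base, n_digits):
    # Digits of `positions` in the given base, least-significant first.
    digits = []
    cur = positions
    for _ in range(n_digits):
        digits.append(np.mod(cur, base))
        cur = cur // base
    return np.stack(digits, axis=-1)

print(position_digits(np.array([0, 7, 42]), base=5, n_digits=3))
# [[0 0 0]
#  [2 1 0]
#  [2 3 1]]   e.g. 42 = 2 + 3*5 + 1*25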
Example #4
 def f(preds, values, returns, actions, mask):
   advantages = jnp.squeeze(returns - stop_gradient(values), axis=-1)
   logps = self._policy_dist.log_prob(preds, actions)
   awr_loss = actor_critic.AWRLoss(beta=self._beta, w_max=self._w_max)(
       (logps, advantages, jnp.zeros_like(logps), mask))
   l2_value_loss = jnp.mean((returns - values)**2) * self._value_loss_coeff
   return awr_loss + l2_value_loss
Example #5
def PPOObjective(dist_inputs, values, returns, dones, rewards, actions,
                 old_log_probs, log_prob_fun, epsilon, normalize_advantages):
    """PPO Objective."""
    # dist_inputs of the shape float32[128,1,18]
    # values of the shape float32[128,1,1]
    # returns of the shape float32[128,1,1]
    # dones of the shape float32[128,1,1]
    # rewards of the shape int32[128,1,1]
    # actions of the shape int32[128,1]
    # and old_log_probs of the shape float32[128,1]
    returns = returns.squeeze(axis=2)
    values = values.squeeze(axis=2)
    dones = dones.squeeze(axis=2)
    rewards = rewards.squeeze(axis=2)
    assert rewards.shape == dones.shape, (
        f'rewards.shape was {rewards.shape} and dones.shape was {dones.shape}')
    assert dones.shape == values.shape, (
        f'dones.shape was {dones.shape} and values.shape was {values.shape}')
    assert returns.shape == values.shape, (
        f'returns.shape was {returns.shape} and values.shape was {values.shape}'
    )
    assert returns.shape == old_log_probs.shape, (
        f'returns.shape was {returns.shape} and'
        f'old_log_probs.shape was {old_log_probs.shape}')

    probs_ratio = ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun)
    assert probs_ratio.shape == old_log_probs.shape, (
        f'probs_ratio.shape was {probs_ratio.shape} and'
        f'old_log_probs.shape was {old_log_probs.shape}')

    # jaxified versions of
    # returns[dones] = rewards[dones]
    # values[dones] = 0
    returns = jnp.where(dones, rewards, returns)
    values = jnp.where(dones, jnp.zeros_like(values), values)
    advantages = returns - values
    if normalize_advantages:
        advantages = advantages - jnp.mean(advantages)
        advantages /= jnp.std(advantages) + 1e-8
    assert old_log_probs.shape == advantages.shape, (
        f'old_log_probs.shape was {old_log_probs.shape} and advantages.shape was '
        f'{advantages.shape}')

    unclipped_objective = UnclippedObjective(probs_ratio, advantages)
    assert unclipped_objective.shape == advantages.shape, (
        f'unclipped_objective.shape was {unclipped_objective.shape} and '
        f'advantages.shape was {advantages.shape}')

    clipped_objective = ClippedObjective(probs_ratio, advantages, epsilon)
    assert clipped_objective.shape == advantages.shape, (
        f'clipped_objective.shape was {clipped_objective.shape} and'
        f'advantages.shape was {advantages.shape}')

    ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
    assert ppo_objective.shape == advantages.shape, (
        f'ppo_objective.shape was {ppo_objective.shape} and'
        f'advantages.shape was {advantages.shape}')

    return ppo_objective
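For reference, the quantity returned here is the standard PPO surrogate: the elementwise minimum of the unclipped and clipped ratio-weighted advantages. A NumPy sketch, assuming `UnclippedObjective` and `ClippedObjective` compute `ratio * advantage` and `clip(ratio, 1 - epsilon, 1 + epsilon) * advantage` respectively (toy inputs, illustrative only):

import numpy as np

def ppo_surrogate_sketch(probs_ratio, advantages, epsilon):
    unclipped = probs_ratio * advantages
    clipped = np.clip(probs_ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages
    # PPO takes the elementwise minimum of the two objectives.
    return np.minimum(unclipped, clipped)

ratio = np.array([0.5, 1.0, 1.5])
adv = np.array([1.0, -2.0, 0.5])
print(ppo_surrogate_sketch(ratio, adv, epsilon=0.2))  # elementwise: [0.5, -2.0, 0.6]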
Example #6
 def AWRJointLoss(x, **unused_kwargs):  # pylint: disable=invalid-name
   preds, values, returns, actions, mask = x
   advantages = jnp.squeeze(returns - values, axis=-1)
   logps = self._policy_dist.log_prob(preds, actions)
   awr_loss = actor_critic.AWRLoss(beta=self._beta, w_max=self._w_max)(
       (logps, advantages, jnp.zeros_like(logps), mask))
   l2_value_loss = jnp.mean((returns - values)**2) * self._value_loss_coeff
   return awr_loss + l2_value_loss
Example #7
 def _UpdateRow(x):
     # row_e - (L1, H), row_d - (L2, H), row_mask_e - (L1,)
     row_e, row_d, row_mask_e = x
     # final_row - (L1+L2, H)
     final_row = jnp.concatenate([row_e, jnp.zeros_like(row_d)], axis=0)
     # Find the last real token/vector of the encoder.
     e_idx = jnp.sum(row_mask_e, dtype=jnp.int32)
     # Starting after that index, update with the decoder row.
     return jax.lax.dynamic_update_slice(final_row, row_d, (e_idx, 0))
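The same packing written with plain NumPy slicing instead of `jax.lax.dynamic_update_slice`, on made-up toy shapes: the decoder row is written directly after the last real encoder token.

import numpy as np

row_e = np.ones((3, 2))              # encoder row, (L1, H)
row_d = 2 * np.ones((2, 2))          # decoder row, (L2, H)
row_mask_e = np.array([1, 1, 0])     # only the first 2 encoder tokens are real

final_row = np.concatenate([row_e, np.zeros_like(row_d)], axis=0)  # (L1+L2, H)
e_idx = int(row_mask_e.sum())        # index of the first padding position
final_row[e_idx:e_idx + row_d.shape[0]] = row_d  # overwrite starting there
print(final_row[:, 0])               # [1. 1. 2. 2. 0.]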
Example #8
def Relu():
  r"""Returns a layer that computes the Rectified Linear Unit (ReLU) function.

  .. math::
      f(x) = \left\{ \begin{array}{cl}
          0 & \text{if}\ x \leq 0, \\
          x & \text{otherwise}.
      \end{array} \right.
  """
  return Fn('Relu', lambda x: np.where(x <= 0, np.zeros_like(x), x))
Example #9
 def _update_diagonal(self, grads, weights, m, v, opt_params):
   learning_rate = opt_params['learning_rate']
   momentum = opt_params['momentum']
   v[0] += grads * grads
   preconditioner = np.where(v[0] > 0, 1.0 / np.sqrt(v[0]),
                             np.zeros_like(v[0]))
   preconditioned_grads = preconditioner * grads
   m = (1 - momentum) * preconditioned_grads + momentum * m
   weights = weights - (learning_rate * m).astype(weights.dtype)
   return weights, (m, v)
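This is an Adagrad-style update: accumulate squared gradients, precondition the gradient by the inverse square root of the accumulator, then apply momentum. A minimal NumPy sketch of one step with made-up hyperparameters:

import numpy as np

def diagonal_update_sketch(grads, weights, m, v, learning_rate=0.1, momentum=0.9):
    v = v + grads * grads                            # accumulate squared gradients
    precond = np.where(v > 0, 1.0 / np.sqrt(v), 0.0)
    m = (1 - momentum) * (precond * grads) + momentum * m
    return weights - learning_rate * m, m, v

w = np.array([1.0, -2.0])
m = np.zeros_like(w)
v = np.zeros_like(w)
g = np.array([0.5, 0.5])
w, m, v = diagonal_update_sketch(g, w, m, v)
print(w)   # each coordinate moves by learning_rate * (1 - momentum) * sign(g) = 0.01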
Example #10
 def forward(self, x, weights):
     """Execute dropout."""
     if self._mode != 'train':
         return x
     state, rng = self.state, self.rng
     rate = self._initial_rate
     if isinstance(state, dict) and self._name in state:
         rate = state[self._name]
     keep = math.random.bernoulli(rng, 1.0 - rate, x.shape)
     return jnp.where(keep, x / (1.0 - rate), jnp.zeros_like(x))
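In isolation, this is standard inverted dropout: each activation survives with probability 1 - rate and the survivors are rescaled by 1 / (1 - rate) so the expected value is preserved. A plain NumPy sketch (rate and shape are arbitrary):

import numpy as np

rng = np.random.default_rng(0)

def dropout_sketch(x, rate, rng):
    keep = rng.random(x.shape) >= rate          # True with probability 1 - rate
    return np.where(keep, x / (1.0 - rate), 0.0)

print(dropout_sketch(np.ones((2, 4)), 0.5, rng))  # surviving entries are scaled to 2.0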
Example #11
 def batches_stream(self):
   """Use the RLTask self._task to create inputs to the value model."""
   for np_trajectory in self._task.trajectory_batch_stream(
       self._batch_size, max_slice_length=self._max_slice_length, epochs=[-1]):
     # Insert an extra depth dimension, so the target shape is consistent with
     # the network output shape.
     yield (np_trajectory.observations,         # Inputs to the value model.
            np_trajectory.returns[:, :, None],
            np_trajectory.actions,
            jnp.zeros_like(np_trajectory.mask),
            np_trajectory.mask)
Example #12
def A2CObjective(dist_inputs, values, returns, dones, rewards, actions, mask,
                 log_prob_fun, normalize_advantages):
    """Definition of the Advantage Actor Critic (A2C) loss."""
    # dist_inputs of the shape float32[128,1,18]
    # values of the shape float32[128,1,1]
    # returns of the shape float32[128,1,1]
    # dones of the shape int32[128,1,1]
    # actions of the shape int32[128,1]
    # and mask of the shape float32[128,1]
    # We have to squeeze values and returns, because we
    # are planning to compute (return - values) * new_log_probs * mask
    # and all of them should be of the same dimension
    values = values.squeeze(axis=2)
    returns = returns.squeeze(axis=2)
    dones = dones.squeeze(axis=2)
    rewards = rewards.squeeze(axis=2)
    assert rewards.shape == dones.shape, (
        f'rewards.shape was {rewards.shape} and dones.shape was {dones.shape}')
    assert dones.shape == values.shape, (
        f'dones.shape was {dones.shape} and values.shape was {values.shape}')
    assert returns.shape == values.shape, (
        f'returns.shape was {returns.shape} and values.shape was {values.shape}'
    )
    assert values.shape == mask.shape, (
        f'values.shape was {values.shape} and mask.shape was {mask.shape}')
    assert returns.shape[0] == dist_inputs.shape[0], (
        f'returns.shape[0] was {returns.shape[0]} and dist_inputs.shape[0] was '
        f'{dist_inputs.shape[0]}')

    new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
    assert new_log_probs.shape == mask.shape, (
        f'new_log_probs.shape was {new_log_probs.shape} and mask.shape was '
        f'{mask.shape}')

    # jaxified versions of
    # returns[dones] = rewards[dones]
    # values[dones] = 0
    returns = jnp.where(dones, rewards, returns)
    values = jnp.where(dones, jnp.zeros_like(values), values)
    advantages = returns - values
    if normalize_advantages:
        advantages = advantages - jnp.mean(advantages)
        advantages /= jnp.std(advantages) + 1e-8
    assert new_log_probs.shape == advantages.shape, (
        f'new_log_probs.shape was {new_log_probs.shape} and advantages.shape was '
        f'{advantages.shape}')

    # One of the motivation to the squeezes and assertions is to
    # avoid [128,1] * [128,1,1] * [128] multiplications in the definition
    # of the a2c objective - we insist on the same shapes
    a2c_objective = -jnp.sum(new_log_probs * advantages * mask) / jnp.sum(mask)
    return a2c_objective
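Ignoring the assertions and the `dones` handling, the returned value is the negative advantage-weighted log-likelihood averaged over non-padding positions. A tiny NumPy illustration with made-up numbers:

import numpy as np

log_probs = np.array([[-0.1], [-2.3]])    # log pi(a|s), shape (batch, time)
advantages = np.array([[1.0], [-0.5]])
mask = np.array([[1.0], [1.0]])

a2c_objective = -np.sum(log_probs * advantages * mask) / np.sum(mask)
print(a2c_objective)   # -0.525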
Example #13
def ParametricRelu(a=1.):
  r"""Returns a layer that computes a ReLU function with the given slope.

  .. math::
      f(x) = \left\{ \begin{array}{cl}
          0  & \text{if}\ x \leq 0, \\
          ax & \text{otherwise}.
      \end{array} \right.

  Args:
    a: Slope of line for positive inputs.
  """
  return Fn('ParametricRelu', lambda x: np.maximum(a * x, np.zeros_like(x)))
Example #14
 def forward_with_state(self, x, weights, state, rng):
   """Execute dropout."""
   if self._mode != 'train':
     return x, state
   rate = self._initial_rate
   if isinstance(state, dict) and self._name in state:
     rate = state[self._name]
   if rng is None:
     msg = ('Dropout layer requires apply_fn to be called with a rng keyword '
            'argument. That is, instead of `Dropout(weights, inputs)`, call '
            'it like `Dropout(weights, inputs, rng=key)`.')
     raise ValueError(msg)
   keep = math.random.bernoulli(rng, 1.0 - rate, x.shape)
   return jnp.where(keep, x / (1.0 - rate), jnp.zeros_like(x)), state
Example #15
def mask_self_attention(
    dots, q_info, kv_info, causal=True, exclude_self=True, masked=False):
  """Performs masking for self-attention."""
  if causal:
    mask = jax.lax.convert_element_type(jax.lax.lt(q_info, kv_info), np.float32)
    dots = dots - 1e9 * mask
  if exclude_self:
    mask = jax.lax.convert_element_type(jax.lax.eq(q_info, kv_info), np.float32)
    dots = dots - 1e5 * mask
  if masked:
    zeros_like_kv_info = jax.lax.tie_in(kv_info, np.zeros_like(kv_info))
    mask = jax.lax.convert_element_type(
        jax.lax.lt(kv_info, zeros_like_kv_info), np.float32)
    dots = dots - 1e9 * mask
  return dots
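The causal branch builds a mask that is 1.0 wherever a query would attend to a later key position and subtracts a large constant from those scores. A NumPy sketch of that step on explicit position indices (toy sizes):

import numpy as np

q_info = np.arange(4)[:, None]    # query positions, shape (4, 1)
kv_info = np.arange(4)[None, :]   # key positions,   shape (1, 4)
dots = np.zeros((4, 4))

causal_mask = (q_info < kv_info).astype(np.float32)  # 1.0 where attending ahead
dots = dots - 1e9 * causal_mask
print(dots[0])   # only the first score of row 0 survives unmasked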
Example #16
  def forward(self, x):
    """Executes this layer as part of a forward pass through the model.

    Args:
      x: Tensor of activations.

    Returns:
      Tensor of same shape and dtype as the input.
    """
    if self._mode != 'train':
      return x
    state, rng = self.state, self.rng
    rate = self._initial_rate
    if isinstance(state, dict) and self._name in state:
      rate = state[self._name]
    keep = math.random.bernoulli(rng, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), jnp.zeros_like(x))
Example #17
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
    """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head `(queries, keys)`,
    - applies `mask` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head `values` vectors.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention weights.
    mask: Mask that distinguishes positions with real content vs. padding.
    dropout: Probabilistic rate for dropout applied to attention activations
        (based on query-key pairs) before dotting them with values.
    mode: Either 'train' or 'eval'. Dropout applies only in 'train' mode.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Per-head activations resulting from masked per-head attention-weighted
    sum of per-head values.
  """
    d_feature = queries.shape[-1]
    dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
    if mask is not None:
        # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
        # We must ensure that both mask and the -1e9 constant have a data dependency
        # on the input. Broadcasted copies of these use a lot of memory, so they
        # should be computed at runtime (rather than being global constants).
        if math.backend_name() == 'jax':
            mask = jax.lax.tie_in(dots, mask)
        # JAX's `full_like` already ties in -1e9 to dots.
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
    # Softmax.
    dots = jnp.exp(dots - math.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = math.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
    out = jnp.matmul(dots, values)
    return out
Example #18
 def _update_sketched(self, grads, weights, m, v, opt_params):
   """Update for higher-rank parameters."""
   learning_rate = opt_params['learning_rate']
   momentum = opt_params['momentum']
   shape = weights.shape
   rank = len(shape)
   reshaped_accumulators = [np.reshape(v[i], self._expanded_shape(shape, i))
                            for i in range(rank)]
   current_accumulator = self._minimum(reshaped_accumulators)
   current_accumulator += grads * grads
   accumulator_inv_sqrt = np.where(current_accumulator > 0.0,
                                   1.0 / np.sqrt(current_accumulator),
                                   np.zeros_like(current_accumulator))
   preconditioned_gradient = grads * accumulator_inv_sqrt
   m = (1.0 - momentum) * preconditioned_gradient + momentum * m
   weights = weights - (learning_rate * m).astype(weights.dtype)
   for i in range(len(v)):
     axes = list(range(int(i))) + list(range(int(i) + 1, rank))
     dim_accumulator = np.amax(current_accumulator, axis=axes)
     v[i] = dim_accumulator
   return weights, (m, v)
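The "sketched" scheme stores one max-reduced accumulator per tensor axis and reconstructs a full preconditioner as the broadcasted elementwise minimum of those accumulators, trading exactness for memory. A NumPy sketch of the reconstruction for a rank-2 parameter (values are illustrative):

import numpy as np

full = np.array([[1.0, 4.0], [9.0, 16.0]])   # imagine this is the sum of squared grads

# Per-dimension accumulators: max over all the other axes.
v_rows = full.max(axis=1)                    # shape (2,)
v_cols = full.max(axis=0)                    # shape (2,)

# Reconstruction: broadcasted elementwise minimum of the expanded accumulators.
reconstructed = np.minimum(v_rows[:, None], v_cols[None, :])
print(reconstructed)
# [[ 4.  4.]
#  [ 9. 16.]]   an elementwise upper bound on `full`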
Example #19
  def F(vec_e, vec_d, mask_e, mask_d):
    # pylint: disable=invalid-name
    L1 = mask_e.shape[1]
    L2 = mask_d.shape[1]
    # pylint: enable=invalid-name

    # [-(L1+L2), -L2) but with padding 0-ed out - (B, L1).
    mask_e_key = jnp.arange(-(L1 + L2), -L2) * mask_e
    # [-L2,0) but with padding 0-ed out - (B, L2).
    mask_d_key = jnp.arange(-L2, 0) * mask_d

    # Shape (B, L1+L2, H)
    enc_dec_concat = jnp.concatenate([vec_e, vec_d], axis=1)
    # Shape (B, L1+L2)
    mask_concat = jnp.concatenate([mask_e_key, mask_d_key], axis=1)
    # Make `mask_concat` the same shape as `enc_dec_concat`
    mask_concat = (
        mask_concat[..., jnp.newaxis] +
        jnp.zeros_like(enc_dec_concat, dtype=jnp.int32))
    # Sort on `mask_concat` so padding with key=0 goes to the right end, axis=1.
    _, enc_dec_pad = math.sort_key_val(mask_concat, enc_dec_concat, 1)

    return enc_dec_pad
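The sort at the end works because every real (non-padding) position gets a negative key while padding keeps key 0, so an ascending stable sort pushes padding to the end of the concatenated sequence. A NumPy sketch of the same idea with `argsort` (toy values):

import numpy as np

keys = np.array([-5, 0, -4, -1, 0])        # 0 marks padding, negatives are real tokens
vals = np.array([10, 99, 20, 30, 99])

order = np.argsort(keys, kind='stable')    # ascending: real tokens first, padding last
print(vals[order])                         # [10 20 30 99 99]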
Example #20
def threefry_2x32_prange(key, lo: int = 0, hi: int = 2):
    """Splits a key into a stream of random keys.

  This uses the little-endian counter mode.

  Args:
    key: uint32[2] the key to split
    lo: the range to start extracting from
    hi: the range to stop extracting from

  Returns:
    keys: uint32[hi - lo, 2] the split keys
  """
    if not (key.shape == (2, ) and key.dtype == np.uint32):
        raise ValueError('key must be uint32[2]')
    if not hi < 2**32:
        # You shouldn't really be using more than half the key size anyways.
        raise NotImplementedError('only 32-bit sizes are supported')
    # Create a 64-bit counter:
    i_lo = np.arange(lo, hi, dtype=np.uint32)
    i_hi = np.zeros_like(i_lo)
    i = np.stack([i_lo, i_hi], axis=-1)
    return threefry_2x32_prf(key, i)
Example #21
def Relu(x, **unused_kwargs):
    return np.maximum(x, np.zeros_like(x))
Example #22
 def backward(self, inputs, output, grad, weights, state, new_state,
              rng):
     return (np.zeros_like(grad), ())
Example #23
 def init(self, weights):
   m = np.zeros_like(weights)
   v = np.zeros_like(weights)
   return m, v
Example #24
def ParametricRelu(x, a=1., **unused_kwargs):
    return np.maximum(a * x, np.zeros_like(x))
Example #25
def Relu():
    return Fn('Relu', lambda x: np.maximum(x, np.zeros_like(x)))
Example #26
 def init(self, params):
   vs = [np.zeros(sz, dtype=params.dtype) for sz in params.shape]
   return (np.zeros_like(params), vs)
Example #27
def ParametricRelu(a=1.):
    return Fn('ParametricRelu', lambda x: np.maximum(a * x, np.zeros_like(x)))
Example #28
 def init(self, weights):
     return np.zeros_like(weights)
Example #29
 def backward(self, inputs, output, ct, weights, state, new_state,
              **kwargs):
     return (np.zeros_like(ct), ())
Example #30
 def init(self, params):
     return np.zeros_like(params)