Example #1
def Mean(axis=-1, keepdims=False):
  """Returns a layer that computes mean values using one tensor axis.

  `Mean` uses one tensor axis to form groups of values and replaces each group
  with the mean value of that group. The resulting values can either remain
  in their own size 1 axis (`keepdims=True`), or that axis can be removed from
  the overall tensor (default `keepdims=False`), lowering the rank of the
  tensor by one.

  Args:
    axis: Axis along which values are grouped for computing a mean.
    keepdims: If `True`, keep the resulting size 1 axis as a separate tensor
        axis; else, remove that axis.
  """
  return Fn('Mean', lambda x: jnp.mean(x, axis=axis, keepdims=keepdims))
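A minimal usage sketch, assuming this layer is exposed as `tl.Mean` via `trax.layers` and that Trax layers accept direct calls on arrays:

import numpy as np
from trax import layers as tl

x = np.array([[1., 2., 3.],
              [4., 5., 6.]])
layer = tl.Mean(axis=-1)   # group values along the last axis
print(layer(x))            # [2. 5.] -- shape (2,), rank lowered by one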
Example #2
def _WeightedSequenceMean():
    """Returns a layer that computes a weighted sequence accuracy mean."""
    def f(values, weights):  # pylint: disable=invalid-name
        # This function assumes weights are 0 or 1.
        # Compute 1 for not-correct positions, 0 for correct or masked ones.
        not_correct = (1.0 - values) * weights
        axis_to_sum = list(range(1, len(not_correct.shape)))
        # Summing not-correct on all axes but batch. We're summing 0s and 1s,
        # so the sum is 0 if it's all 0 and >=1 in all other cases.
        not_correct_seq = jnp.sum(not_correct, axis=axis_to_sum)
        # Sequence is correct if not_correct_seq is 0; invert that here.
        correct_seq = 1.0 - jnp.minimum(1.0, not_correct_seq)
        return jnp.mean(correct_seq)  # Mean over batch.

    return Fn('_WeightedSequenceMean', f)
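A small worked check of the same computation, done directly with `jax.numpy` (the values are per-token correctness indicators in {0, 1}; the weights mask out padding):

import jax.numpy as jnp

values  = jnp.array([[1., 1., 1.], [1., 0., 1.]])      # per-token correctness
weights = jnp.array([[1., 1., 0.], [1., 1., 1.]])      # 0 marks padding
not_correct = (1.0 - values) * weights                 # 1 only for real, wrong tokens
not_correct_seq = jnp.sum(not_correct, axis=1)         # [0. 1.]
correct_seq = 1.0 - jnp.minimum(1.0, not_correct_seq)  # [1. 0.]
print(jnp.mean(correct_seq))                           # 0.5 -- one of two sequences fully correct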
Example #3
def Sum(axis=None, keepdims=False):
  """Returns a layer that computes sums using one tensor axis.

  `Sum` uses one tensor axis to form groups of values and replaces each group
  with the sum of that group. The resulting sum values can either remain in
  their own size 1 axis (`keepdims=True`), or that axis can be removed from the
  overall tensor (default `keepdims=False`), lowering the rank of the tensor by
  one.

  Args:
    axis: Axis along which values are grouped for computing a sum; if None,
        compute sum over all elements in tensor.
    keepdims: If `True`, keep the resulting size 1 axis as a separate tensor
        axis; else, remove that axis.
  """
  return Fn('Sum', lambda x: jnp.sum(x, axis=axis, keepdims=keepdims))
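A usage sketch (assuming `tl.Sum` from `trax.layers`); note that the default `axis=None` sums every element of the tensor:

import numpy as np
from trax import layers as tl

x = np.array([[1, 2, 3],
              [4, 5, 6]])
print(tl.Sum()(x))        # 21 -- sum over all elements
print(tl.Sum(axis=0)(x))  # [5 7 9] -- sum over the leading (batch) axis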
Example #4
def ShiftRight(n_positions=1, mode='train'):
  """Returns a layer that can insert padding to shift the input sequence.

  Args:
    n_positions: Number of positions to shift the input sequence rightward;
        initial positions freed by the shift get padded with zeros. Applies
        only if layer is created in a non-``'predict'`` mode.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
  """
  # TODO(jonni): Include pad arg, like PaddingMask, to allow non-default pads?
  def f(x):
    if mode == 'predict':
      return x
    padded = _zero_pad(x, (n_positions, 0), 1)
    return padded[:, :-n_positions]
  return Fn(f'ShiftRight({n_positions})', f)
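A usage sketch (assuming `tl.ShiftRight` from `trax.layers`); outside of `'predict'` mode the sequence moves one position to the right and a zero pad enters at position 0:

import numpy as np
from trax import layers as tl

tokens = np.array([[5, 6, 7, 8]])  # (batch, seq_len)
print(tl.ShiftRight(n_positions=1)(tokens))                  # [[0 5 6 7]]
print(tl.ShiftRight(n_positions=1, mode='predict')(tokens))  # [[5 6 7 8]] -- no-op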
Example #5
def Selu(alpha=1.6732632423543772848170429916717,
         lmbda=1.0507009873554804934193349852946):
  r"""Returns an `Elu`-like layer with an additional scaling/slope parameter.

  .. math::
      f(x) = \left\{ \begin{array}{cl}
          \lambda \cdot \alpha \cdot (e^x - 1) & \text{if}\ x \leq 0, \\
          \lambda \cdot x                      & \text{otherwise}.
      \end{array} \right.

  Args:
    alpha: Coefficient multiplying the exponential, for negative inputs.
    lmbda: Coefficient scaling the whole function.
  """
  return Fn('Selu',
            lambda x: lmbda * jnp.where(x > 0, x, alpha * jnp.expm1(x)))
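A quick numeric check of the formula, mirroring the lambda above with `jax.numpy` directly:

import jax.numpy as jnp

alpha, lmbda = 1.6732632423543772, 1.0507009873554805
x = jnp.array([-1.0, 0.0, 2.0])
print(lmbda * jnp.where(x > 0, x, alpha * jnp.expm1(x)))
# approx [-1.1113  0.      2.1014]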
Example #6
def LogSoftmax(axis=-1):
  """Returns a layer that applies log softmax along one tensor axis.

  Note that the implementation actually computes x - LogSumExp(x),
  which is mathematically equal to LogSoftmax(x).

  `LogSoftmax` acts on a group of values and normalizes them to look like a set
  of log probability values. (Probability values must be non-negative, and as
  a set must sum to 1. A group of log probability values can be seen as the
  natural logarithm function applied to a set of probability values.)

  Args:
    axis: Axis along which values are grouped for computing log softmax.
  """
  return Fn('LogSoftmax',
            lambda x: x - fastmath.logsumexp(x, axis, keepdims=True))
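A usage sketch (assuming `tl.LogSoftmax` from `trax.layers`); exponentiating the output recovers probabilities that sum to 1 along the chosen axis:

import numpy as np
from trax import layers as tl

logits = np.array([[2.0, 1.0, 0.1]])
log_probs = tl.LogSoftmax()(logits)
print(np.exp(log_probs).sum(axis=-1))  # [1.] (up to float rounding)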
Example #7
def MergeHeads(n_heads, merged_batch_and_head=True):
    """Returns a layer that undoes splitting, after multi-head computation."""
    def f(x):
        if merged_batch_and_head:
            batchheads, seq_len, d_head = x.shape
            assert batchheads % n_heads == 0
            batch_size = batchheads // n_heads
            x = x.reshape((batch_size, n_heads, seq_len, d_head))
        else:
            batch_size, _, seq_len, d_head = x.shape

        # (b_size, n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
        x = x.transpose((0, 2, 1, 3))
        x = x.reshape((batch_size, seq_len, n_heads * d_head))
        return x

    return Fn('MergeHeads', f)
Example #8
def ShiftRight(n_positions=1, mode='train'):
  """Returns a layer that can insert padding to shift the input sequence.

  Args:
    n_positions: Number of positions to shift the input sequence rightward;
        initial positions freed by the shift get padded with zeros.
    mode: If `'predict'`, do nothing (the layer is an identity); in any other
        mode (e.g., `'train'` or `'eval'`), perform the specified shift.
  """

  # TODO(jonni): Include pad arg, like PaddingMask, to allow non-default pads?
  def f(x):
    if mode == 'predict':
      return x
    padded = _zero_pad(x, (n_positions, 0), 1)
    return padded[:, :-n_positions]

  return Fn(f'ShiftRight({n_positions})', f)
Example #9
def L2Loss():
  """Returns a layer that computes total L2 loss for one batch."""
  def f(model_output, targets, weights):  # pylint: disable=invalid-name
    """Returns elementwise-weighted L2 norm of `model_output - targets`.

    Args:
      model_output: Output from one batch, treated as an unanalyzed tensor.
      targets: Tensor of same shape as `model_output` containing element-wise
          target values.
      weights: Tensor of same shape as `model_output` and `targets`.
    """
    shapes.assert_same_shape(model_output, targets)
    shapes.assert_same_shape(targets, weights)
    l2 = weights * (model_output - targets)**2
    return jnp.sum(l2) / jnp.sum(weights)

  return Fn('L2Loss', f)
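A small worked example of the weighted average (a sketch assuming `tl.L2Loss` from `trax.layers`; a three-input layer is called on a tuple of inputs):

import numpy as np
from trax import layers as tl

model_output = np.array([[1.0, 2.0]])
targets      = np.array([[1.0, 4.0]])
weights      = np.array([[1.0, 1.0]])
print(tl.L2Loss()((model_output, targets, weights)))  # 2.0 = (0 + 4) / 2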
Example #10
    def test_output_signature(self):
        input_signature = (ShapeDtype((2, 3, 5)), ShapeDtype((2, 3, 5)))
        layer = Fn('2in1out', lambda x, y: x + y)
        output_signature = layer.output_signature(input_signature)
        self.assertEqual(output_signature, ShapeDtype((2, 3, 5)))

        input_signature = ShapeDtype((5, 7))
        layer = Fn('1in3out', lambda x: (x, 2 * x, 3 * x), n_out=3)
        output_signature = layer.output_signature(input_signature)
        self.assertEqual(output_signature, (ShapeDtype((5, 7)), ) * 3)
        self.assertNotEqual(output_signature, (ShapeDtype((4, 7)), ) * 3)
        self.assertNotEqual(output_signature, (ShapeDtype((5, 7)), ) * 2)
Example #11
def SplitIntoHeads(n_heads, merged_batch_and_head=True):
  """Returns a layer that reshapes an array for multi-head computation."""
  def f(x):
    batch_size, seq_len, d_feature = x.shape
    if d_feature % n_heads != 0:
      raise ValueError(
          f'Feature embedding dimensionality ({d_feature}) is not a multiple'
          f' of the requested number of attention heads ({n_heads}).')

    d_head = d_feature // n_heads

    # (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head)
    x = x.reshape((batch_size, seq_len, n_heads, d_head))
    x = x.transpose((0, 2, 1, 3))
    if merged_batch_and_head:
      x = x.reshape((batch_size * n_heads, seq_len, d_head))
    return x
  return Fn('SplitIntoHeads', f)
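A shape round-trip sketch, assuming both layers are available via `trax.layers` as `tl.SplitIntoHeads` and `tl.MergeHeads`:

import numpy as np
from trax import layers as tl

x = np.zeros((2, 7, 16))                  # (batch, seq_len, d_feature)
split = tl.SplitIntoHeads(n_heads=4)(x)   # (8, 7, 4) = (batch*heads, seq_len, d_head)
merged = tl.MergeHeads(n_heads=4)(split)  # (2, 7, 16), back to the original shape
print(split.shape, merged.shape)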
Example #12
def L2Loss():
  """Returns a layer that computes an L2-like loss for one batch."""
  def f(model_output, targets, weights):  # pylint: disable=invalid-name
    """Returns weighted sum-of-squared-errors for `model_output` vs. `targets`.

    Args:
      model_output: Output from one batch, typically a 2- or 3-d array of
          float-valued elements.
      targets: Tensor of same shape as `model_output` containing element-wise
          target values.
      weights: Tensor of same shape as `model_output` and `targets`, containing
          element-wise weight values.
    """
    shapes.assert_same_shape(model_output, targets)
    shapes.assert_same_shape(targets, weights)
    weighted_sse = weights * (model_output - targets)**2
    return jnp.sum(weighted_sse) / jnp.sum(weights)

  return Fn('L2Loss', f)
Example #13
def Flatten(n_axes_to_keep=1):
  """Returns a layer that combines one or more trailing axes of a tensor.

  Flattening keeps all the values of the input tensor, but reshapes it by
  collapsing one or more trailing axes into a single axis. For example, a
  `Flatten(n_axes_to_keep=2)` layer would map a tensor with shape
  `(2, 3, 5, 7, 11)` to the same values with shape `(2, 3, 385)`.

  Args:
    n_axes_to_keep: Number of leading axes to leave unchanged when reshaping;
        collapse only the axes after these.
  """
  layer_name = f'Flatten_keep{n_axes_to_keep}'
  def f(x):  # pylint: disable=invalid-name
    in_rank = len(x.shape)
    if in_rank <= n_axes_to_keep:
      raise ValueError(f'Input rank ({in_rank}) must exceed the number of '
                       f'axes to keep ({n_axes_to_keep}) after flattening.')
    return jnp.reshape(x, (x.shape[:n_axes_to_keep] + (-1,)))
  return Fn(layer_name, f)
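A usage sketch matching the docstring's example (assuming `tl.Flatten` from `trax.layers`):

import numpy as np
from trax import layers as tl

x = np.zeros((2, 3, 5, 7, 11))
print(tl.Flatten(n_axes_to_keep=2)(x).shape)  # (2, 3, 385), since 5 * 7 * 11 = 385
print(tl.Flatten()(x).shape)                  # (2, 1155), keeping only the leading axis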
Example #14
def MergeHeads(n_heads, merged_batch_and_head=True):
  """Returns a layer that rejoins heads, after multi-head computation."""
  def f(x):
    if merged_batch_and_head:
      dim_0, seq_len, d_head = x.shape
      if dim_0 % n_heads != 0:
        raise ValueError(
            f"Array's leading dimension ({dim_0}) is not a multiple of the"
            f" number of attention heads ({n_heads}).")

      batch_size = dim_0 // n_heads
      x = x.reshape((batch_size, n_heads, seq_len, d_head))
    else:
      batch_size, _, seq_len, d_head = x.shape

    # (b_size, n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
    x = x.transpose((0, 2, 1, 3))
    x = x.reshape((batch_size, seq_len, n_heads * d_head))
    return x
  return Fn('MergeHeads', f)
Example #15
def PaddingMask(pad=0):
  """Returns a layer that maps integer sequences to padding masks.

  The layer expects as input a batch of integer sequences. The layer output is
  a tensor that marks for each sequence position whether the integer (e.g., a
  token ID) in that position represents padding -- value `pad` -- versus
  text/content -- all other values. The padding mask shape is
  (batch_size, 1, 1, encoder_sequence_length), such that axis 1 will broadcast
  to cover any number of attention heads and axis 2 will broadcast to cover
  decoder sequence positions.

  Args:
    pad: Integer that represents padding rather than a token/content ID.
  """
  def f(x):
    if len(x.shape) != 2:
      raise ValueError(
          f'Input to PaddingMask must be a rank 2 tensor with shape '
          f'(batch_size, sequence_length); instead got shape {x.shape}.')
    batch_size = x.shape[0]
    sequence_length = x.shape[1]
    content_positions = (x != pad)
    return content_positions.reshape((batch_size, 1, 1, sequence_length))
  return Fn(f'PaddingMask({pad})', f)
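A usage sketch (assuming `tl.PaddingMask` from `trax.layers`), with token ID 0 marking padded positions:

import numpy as np
from trax import layers as tl

token_ids = np.array([[7, 5, 2, 0, 0],
                      [8, 3, 0, 0, 0]])
mask = tl.PaddingMask()(token_ids)
print(mask.shape)     # (2, 1, 1, 5)
print(mask[0, 0, 0])  # [ True  True  True False False]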
Example #16
def Multiply():
    """Multiplies two tensors."""
    return Fn('Multiply', lambda x0, x1: x0 * x1)
Example #17
def SubtractTop():
    """Subtracts the first tensor from the second."""
    return Fn('SubtractTop', lambda x0, x1: x1 - x0)
Example #18
def Add():
    """Adds two tensors."""
    return Fn('Add', lambda x0, x1: x0 + x1)
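These two-input layers (`Multiply`, `SubtractTop`, `Add`) each consume the top two items of the data stack; a sketch of calling them directly on a tuple of inputs (assuming they are exposed via `trax.layers`):

import numpy as np
from trax import layers as tl

x0 = np.array([10., 20.])
x1 = np.array([1., 2.])
print(tl.Add()((x0, x1)))          # [11. 22.]
print(tl.Multiply()((x0, x1)))     # [10. 40.]
print(tl.SubtractTop()((x0, x1)))  # [-9. -18.] -- x1 - x0, i.e. the top element is subtracted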
Example #19
def FlattenList():
    """Flatten lists."""
    # TODO(jonni): Consider renaming layer to DeepFlatten.
    return Fn('FlattenList', lambda x: tuple(_deep_flatten(x)))
Example #20
def Swap():
    """Swaps the top two stack elements."""
    return Fn('Swap', lambda x0, x1: (x1, x0), n_out=2)
Example #21
def Dup():
    """Duplicates (copies) the top element on the data stack."""
    return Fn('Dup', lambda x: (x, x), n_out=2)
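`Swap` and `Dup` rearrange the data stack rather than compute values; both declare `n_out=2`, so a direct call returns a tuple (a sketch assuming `tl.Swap` and `tl.Dup` from `trax.layers`):

import numpy as np
from trax import layers as tl

x0 = np.array([1, 2])
x1 = np.array([3, 4])
print(tl.Swap()((x0, x1)))  # (x1, x0): the top two elements exchanged
print(tl.Dup()(x0))         # (x0, x0): the top element duplicated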
Example #22
def Negate():
    """Returns a layer that computes the element-wise negation of a tensor."""
    return Fn('Negate', lambda x: -x)
Example #23
def ThresholdToBinary(threshold=.5):
    """Returns a layer that thresholds inputs to yield outputs in {0, 1}."""
    def f(model_output):  # pylint: disable=invalid-name
        return (model_output > threshold).astype(jnp.int32)

    return Fn('ThresholdToBinary', f)
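A usage sketch (assuming `tl.ThresholdToBinary` from `trax.layers`), e.g. for turning sigmoid outputs into hard 0/1 labels:

import numpy as np
from trax import layers as tl

probs = np.array([0.1, 0.5, 0.7, 0.95])
print(tl.ThresholdToBinary()(probs))               # [0 0 1 1] -- strict '>' at the 0.5 default
print(tl.ThresholdToBinary(threshold=0.9)(probs))  # [0 0 0 1]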
Example #24
def ToFloat():
    """Returns a layer that changes the dtype of a tensor to `float32`."""
    return Fn('ToFloat', lambda x: x.astype(np.float32))
Example #25
def Negate():
    """Returns a layer that computes the element-wise negation of a tensor."""
    return Fn('Negate', lambda x: -x)
Example #26
def Drop():
    """Drops the top stack element."""
    return Fn('Drop', lambda x: (), n_out=0)
Example #27
def ArgMax(axis=-1):
    """Returns a layer that calculates argmax along the given axis."""
    def f(model_output):  # pylint: disable=invalid-name
        return jnp.argmax(model_output, axis=axis)

    return Fn('ArgMax', f)
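A usage sketch (assuming `tl.ArgMax` from `trax.layers`), typical for picking the highest-scoring class per example:

import numpy as np
from trax import layers as tl

logits = np.array([[0.1, 2.0, 0.3],
                   [5.0, 1.0, 0.2]])
print(tl.ArgMax()(logits))        # [1 0] -- index of the max along the last axis
print(tl.ArgMax(axis=0)(logits))  # [1 0 0] -- max index down each column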
Example #28
def Mean(axis=-1, keepdims=False):
    """Returns a layer that computes mean values using one tensor axis."""
    return Fn('Mean', lambda x: jnp.mean(x, axis=axis, keepdims=keepdims))
Example #29
def StopGradient():
    """Returns an identity layer with a stop gradient."""
    return Fn('StopGradient', lambda x: fastmath.stop_gradient(x))  # pylint: disable=unnecessary-lambda
Example #30
def Sum(axis=-1, keepdims=False):
    """Returns a layer that computes sums using one tensor axis."""
    return Fn('Sum', lambda x: jnp.sum(x, axis=axis, keepdims=keepdims))