# Imports assumed from the surrounding Trax modules (e.g. trax.layers.core and
# trax.layers.attention), so that the snippets below are self-contained.
# Private helpers such as `_zero_pad` and `_deep_flatten` are defined elsewhere
# in those modules.
import numpy as np

from trax import fastmath
from trax import shapes
from trax.fastmath import numpy as jnp
from trax.layers.base import Fn
from trax.shapes import ShapeDtype


def Mean(axis=-1, keepdims=False):
  """Returns a layer that computes mean values using one tensor axis.

  `Mean` uses one tensor axis to form groups of values and replaces each group
  with the mean value of that group. The resulting values can either remain in
  their own size 1 axis (`keepdims=True`), or that axis can be removed from
  the overall tensor (default `keepdims=False`), lowering the rank of the
  tensor by one.

  Args:
    axis: Axis along which values are grouped for computing a mean.
    keepdims: If `True`, keep the resulting size 1 axis as a separate tensor
        axis; else, remove that axis.
  """
  return Fn('Mean', lambda x: jnp.mean(x, axis=axis, keepdims=keepdims))
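# Illustrative usage sketch (assumes the definitions above and a JAX backend):
# `Mean` collapses the chosen axis, optionally keeping it as a size-1 axis.
def _example_mean():
  x = np.array([[1., 2., 3.],
                [4., 5., 6.]])                    # shape (2, 3)
  per_row = Mean(axis=-1)(x)                      # [2., 5.], shape (2,)
  per_row_kept = Mean(axis=-1, keepdims=True)(x)  # [[2.], [5.]], shape (2, 1)
  return per_row, per_row_kept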
def _WeightedSequenceMean():
  """Returns a layer that computes a weighted sequence accuracy mean."""
  def f(values, weights):  # pylint: disable=invalid-name
    # This function assumes weights are 0 or 1.
    # Then compute 1: not-correct, 0: correct or masked.
    not_correct = (1.0 - values) * weights
    axis_to_sum = list(range(1, len(not_correct.shape)))
    # Summing not-correct on all axes but batch. We're summing 0s and 1s,
    # so the sum is 0 if it's all 0 and >= 1 in all other cases.
    not_correct_seq = jnp.sum(not_correct, axis=axis_to_sum)
    # Sequence is correct if not_correct_seq is 0, reverting here.
    correct_seq = 1.0 - jnp.minimum(1.0, not_correct_seq)
    return jnp.mean(correct_seq)  # Mean over batch.
  return Fn('_WeightedSequenceMean', f)
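# Illustrative usage sketch (assumes the definitions above and a JAX backend):
# `_WeightedSequenceMean` counts a sequence as correct only if every unmasked
# position is correct, then averages over the batch.
def _example_weighted_sequence_mean():
  values = np.array([[1., 1., 1.],    # every token correct
                     [1., 0., 1.]])   # one token wrong
  weights = np.array([[1., 1., 0.],   # last position is padding
                      [1., 1., 1.]])
  # The first sequence is correct on all unmasked positions, the second is
  # not, so the sequence-level accuracy is 0.5.
  return _WeightedSequenceMean()((values, weights))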
def Sum(axis=None, keepdims=False):
  """Returns a layer that computes sums using one tensor axis.

  `Sum` uses one tensor axis to form groups of values and replaces each group
  with the sum of that group. The resulting sum values can either remain in
  their own size 1 axis (`keepdims=True`), or that axis can be removed from
  the overall tensor (default `keepdims=False`), lowering the rank of the
  tensor by one.

  Args:
    axis: Axis along which values are grouped for computing a sum; if None,
        compute sum over all elements in tensor.
    keepdims: If `True`, keep the resulting size 1 axis as a separate tensor
        axis; else, remove that axis.
  """
  return Fn('Sum', lambda x: jnp.sum(x, axis=axis, keepdims=keepdims))
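# Illustrative usage sketch (assumes the definitions above and a JAX backend):
# with the default `axis=None`, `Sum` reduces over every element; with an
# explicit axis it reduces only that axis.
def _example_sum():
  x = np.array([[1, 2, 3],
                [4, 5, 6]])
  total = Sum()(x)             # 21
  per_column = Sum(axis=0)(x)  # [5, 7, 9]
  return total, per_column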
def ShiftRight(n_positions=1, mode='train'):
  """Returns a layer that can insert padding to shift the input sequence.

  Args:
    n_positions: Number of positions to shift the input sequence rightward;
        initial positions freed by the shift get padded with zeros. Applies
        only if layer is created in a non-``'predict'`` mode.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
  """
  # TODO(jonni): Include pad arg, like PaddingMask, to allow non-default pads?
  def f(x):
    if mode == 'predict':
      return x
    padded = _zero_pad(x, (n_positions, 0), 1)
    return padded[:, :-n_positions]
  return Fn(f'ShiftRight({n_positions})', f)
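# Illustrative usage sketch (assumes the definitions above, including the
# private `_zero_pad` helper, and a JAX backend): `ShiftRight` pads the start
# of each sequence with zeros and drops the same number of trailing positions,
# as used for teacher forcing in autoregressive decoders.
def _example_shift_right():
  x = np.array([[1, 2, 3, 4],
                [5, 6, 7, 8]])
  shifted = ShiftRight(n_positions=1)(x)                    # [[0, 1, 2, 3], [0, 5, 6, 7]]
  unchanged = ShiftRight(n_positions=1, mode='predict')(x)  # identity in 'predict' mode
  return shifted, unchanged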
def Selu(alpha=1.6732632423543772848170429916717,
         lmbda=1.0507009873554804934193349852946):
  r"""Returns an `Elu`-like layer with an additional scaling/slope parameter.

  .. math::
      f(x) = \left\{ \begin{array}{cl}
          \lambda \cdot \alpha \cdot (e^x - 1) & \text{if}\ x \leq 0, \\
          \lambda \cdot x                      & \text{otherwise}.
      \end{array} \right.

  Args:
    alpha: Coefficient multiplying the exponential, for negative inputs.
    lmbda: Coefficient scaling the whole function.
  """
  return Fn('Selu',
            lambda x: lmbda * jnp.where(x > 0, x, alpha * jnp.expm1(x)))
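# Illustrative usage sketch (assumes the definitions above and a JAX backend):
# `Selu` scales positive inputs by `lmbda` and maps negative inputs through a
# scaled, shifted exponential.
def _example_selu():
  x = np.array([-1.0, 0.0, 2.0])
  y = Selu()(x)
  # y[2] equals 1.0507... * 2.0; y[0] equals 1.0507... * 1.6732... * (e**-1 - 1).
  return y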
def LogSoftmax(axis=-1):
  """Returns a layer that applies log softmax along one tensor axis.

  Note that the implementation actually computes x - LogSumExp(x), which is
  mathematically equal to LogSoftmax(x).

  `LogSoftmax` acts on a group of values and normalizes them to look like a
  set of log probability values. (Probability values must be non-negative, and
  as a set must sum to 1. A group of log probability values can be seen as the
  natural logarithm function applied to a set of probability values.)

  Args:
    axis: Axis along which values are grouped for computing log softmax.
  """
  return Fn('LogSoftmax',
            lambda x: x - fastmath.logsumexp(x, axis, keepdims=True))
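# Illustrative usage sketch (assumes the definitions above and a JAX backend):
# exponentiating the outputs of `LogSoftmax` gives probabilities that sum to 1
# along the chosen axis.
def _example_log_softmax():
  x = np.array([[1.0, 2.0, 3.0]])
  log_probs = LogSoftmax()(x)
  # np.exp(log_probs) sums to 1 along the last axis.
  return log_probs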
def test_output_signature(self):
  input_signature = (ShapeDtype((2, 3, 5)), ShapeDtype((2, 3, 5)))
  layer = Fn('2in1out', lambda x, y: x + y)
  output_signature = layer.output_signature(input_signature)
  self.assertEqual(output_signature, ShapeDtype((2, 3, 5)))

  input_signature = ShapeDtype((5, 7))
  layer = Fn('1in3out', lambda x: (x, 2 * x, 3 * x), n_out=3)
  output_signature = layer.output_signature(input_signature)
  self.assertEqual(output_signature, (ShapeDtype((5, 7)),) * 3)
  self.assertNotEqual(output_signature, (ShapeDtype((4, 7)),) * 3)
  self.assertNotEqual(output_signature, (ShapeDtype((5, 7)),) * 2)
def SplitIntoHeads(n_heads, merged_batch_and_head=True):
  """Returns a layer that reshapes an array for multi-head computation."""
  def f(x):
    batch_size, seq_len, d_feature = x.shape
    if d_feature % n_heads != 0:
      raise ValueError(
          f'Feature embedding dimensionality ({d_feature}) is not a multiple'
          f' of the requested number of attention heads ({n_heads}).')

    d_head = d_feature // n_heads

    # (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head)
    x = x.reshape((batch_size, seq_len, n_heads, d_head))
    x = x.transpose((0, 2, 1, 3))
    if merged_batch_and_head:
      x = x.reshape((batch_size * n_heads, seq_len, d_head))
    return x
  return Fn('SplitIntoHeads', f)
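# Illustrative shape check (assumes the definitions above and a JAX backend):
# with `merged_batch_and_head=True` the heads are folded into the batch
# dimension; otherwise they get their own axis.
def _example_split_into_heads():
  x = np.zeros((2, 7, 8))                # (batch, seq_len, d_feature)
  merged = SplitIntoHeads(n_heads=4)(x)  # shape (8, 7, 2)
  unmerged = SplitIntoHeads(n_heads=4, merged_batch_and_head=False)(x)  # (2, 4, 7, 2)
  return merged.shape, unmerged.shape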
def L2Loss():
  """Returns a layer that computes an L2-like loss for one batch."""
  def f(model_output, targets, weights):  # pylint: disable=invalid-name
    """Returns weighted sum of squared errors, normalized by the total weight.

    Args:
      model_output: Output from one batch, typically a 2- or 3-d array of
          float-valued elements.
      targets: Tensor of same shape as `model_output` containing element-wise
          target values.
      weights: Tensor of same shape as `model_output` and `targets`, containing
          element-wise weight values.
    """
    shapes.assert_same_shape(model_output, targets)
    shapes.assert_same_shape(targets, weights)
    weighted_sse = weights * (model_output - targets)**2
    return jnp.sum(weighted_sse) / jnp.sum(weights)
  return Fn('L2Loss', f)
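# Illustrative usage sketch (assumes the definitions above and a JAX backend):
# `L2Loss` averages squared error over the positions with nonzero weight.
def _example_l2_loss():
  model_output = np.array([[1.0, 2.0], [3.0, 4.0]])
  targets = np.array([[1.0, 0.0], [3.0, 4.0]])
  weights = np.array([[1.0, 1.0], [1.0, 0.0]])
  # Squared errors are [[0, 4], [0, 0]]; the weighted sum is 4 and the total
  # weight is 3, so the loss is 4/3.
  return L2Loss()((model_output, targets, weights))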
def Flatten(n_axes_to_keep=1):
  """Returns a layer that combines one or more trailing axes of a tensor.

  Flattening keeps all the values of the input tensor, but reshapes it by
  collapsing one or more trailing axes into a single axis. For example, a
  `Flatten(n_axes_to_keep=2)` layer would map a tensor with shape
  `(2, 3, 5, 7, 11)` to the same values with shape `(2, 3, 385)`.

  Args:
    n_axes_to_keep: Number of leading axes to leave unchanged when reshaping;
        collapse only the axes after these.
  """
  layer_name = f'Flatten_keep{n_axes_to_keep}'
  def f(x):  # pylint: disable=invalid-name
    in_rank = len(x.shape)
    if in_rank <= n_axes_to_keep:
      raise ValueError(f'Input rank ({in_rank}) must exceed the number of '
                       f'axes to keep ({n_axes_to_keep}) after flattening.')
    return jnp.reshape(x, (x.shape[:n_axes_to_keep] + (-1,)))
  return Fn(layer_name, f)
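# Illustrative shape check (assumes the definitions above and a JAX backend):
# `Flatten` collapses every axis after the first `n_axes_to_keep` into one.
def _example_flatten():
  x = np.zeros((2, 3, 5, 7))
  flat = Flatten()(x)                         # shape (2, 105)
  partly_flat = Flatten(n_axes_to_keep=2)(x)  # shape (2, 3, 35)
  return flat.shape, partly_flat.shape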
def MergeHeads(n_heads, merged_batch_and_head=True):
  """Returns a layer that rejoins heads, after multi-head computation."""
  def f(x):
    if merged_batch_and_head:
      dim_0, seq_len, d_head = x.shape
      if dim_0 % n_heads != 0:
        raise ValueError(
            f"Array's leading dimension ({dim_0}) is not a multiple of the"
            f" number of attention heads ({n_heads}).")

      batch_size = dim_0 // n_heads
      x = x.reshape((batch_size, n_heads, seq_len, d_head))
    else:
      batch_size, _, seq_len, d_head = x.shape

    # (b_size, n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
    x = x.transpose((0, 2, 1, 3))
    x = x.reshape((batch_size, seq_len, n_heads * d_head))
    return x
  return Fn('MergeHeads', f)
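# Illustrative round trip (assumes the definitions above and a JAX backend):
# `MergeHeads` inverts `SplitIntoHeads`, recovering the original
# (batch, seq_len, d_feature) array.
def _example_merge_heads():
  x = np.arange(2 * 7 * 8, dtype=np.float32).reshape((2, 7, 8))
  split = SplitIntoHeads(n_heads=4)(x)   # shape (8, 7, 2)
  merged = MergeHeads(n_heads=4)(split)  # shape (2, 7, 8), equal to x
  return np.array_equal(np.asarray(merged), x)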
def PaddingMask(pad=0):
  """Returns a layer that maps integer sequences to padding masks.

  The layer expects as input a batch of integer sequences. The layer output is
  a tensor that marks for each sequence position whether the integer (e.g., a
  token ID) in that position represents padding -- value `pad` -- versus
  text/content -- all other values. The padding mask shape is
  (batch_size, 1, 1, encoder_sequence_length), such that axis 1 will broadcast
  to cover any number of attention heads and axis 2 will broadcast to cover
  decoder sequence positions.

  Args:
    pad: Integer that represents padding rather than a token/content ID.
  """
  def f(x):
    if len(x.shape) != 2:
      raise ValueError(
          f'Input to PaddingMask must be a rank 2 tensor with shape '
          f'(batch_size, sequence_length); instead got shape {x.shape}.')
    batch_size = x.shape[0]
    sequence_length = x.shape[1]
    content_positions = (x != pad)
    return content_positions.reshape((batch_size, 1, 1, sequence_length))
  return Fn(f'PaddingMask({pad})', f)
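# Illustrative usage sketch (assumes the definitions above and a JAX backend):
# `PaddingMask` marks real-content positions as True and padding as False,
# broadcastable over attention heads and decoder positions.
def _example_padding_mask():
  token_ids = np.array([[3, 7, 5, 0, 0],
                        [4, 9, 0, 0, 0]])
  mask = PaddingMask()(token_ids)  # shape (2, 1, 1, 5); False where token id == 0
  return mask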
def Multiply():
  """Multiplies two tensors."""
  return Fn('Multiply', lambda x0, x1: x0 * x1)


def SubtractTop():
  """Subtracts the first tensor from the second."""
  return Fn('SubtractTop', lambda x0, x1: x1 - x0)


def Add():
  """Adds two tensors."""
  return Fn('Add', lambda x0, x1: x0 + x1)


def FlattenList():
  """Flattens lists."""
  # TODO(jonni): Consider renaming layer to DeepFlatten.
  return Fn('FlattenList', lambda x: tuple(_deep_flatten(x)))


def Swap():
  """Swaps the top two stack elements."""
  return Fn('Swap', lambda x0, x1: (x1, x0), n_out=2)


def Dup():
  """Duplicates (copies) the top element on the data stack."""
  return Fn('Dup', lambda x: (x, x), n_out=2)
def ThresholdToBinary(threshold=.5):
  """Returns a layer that thresholds inputs to yield outputs in {0, 1}."""
  def f(model_output):  # pylint: disable=invalid-name
    return (model_output > threshold).astype(jnp.int32)
  return Fn('ThresholdToBinary', f)


def ToFloat():
  """Returns a layer that changes the dtype of a tensor to `float32`."""
  return Fn('ToFloat', lambda x: x.astype(np.float32))


def Negate():
  """Returns a layer that computes the element-wise negation of a tensor."""
  return Fn('Negate', lambda x: -x)


def Drop():
  """Drops the top stack element."""
  return Fn('Drop', lambda x: (), n_out=0)


def ArgMax(axis=-1):
  """Returns a layer that calculates argmax along the given axis."""
  def f(model_output):  # pylint: disable=invalid-name
    return jnp.argmax(model_output, axis=axis)
  return Fn('ArgMax', f)
def StopGradient():
  """Returns an identity layer with a stop gradient."""
  return Fn('StopGradient', lambda x: fastmath.stop_gradient(x))  # pylint: disable=unnecessary-lambda