def PaddingMask(pad=0):
  """Returns a layer mapping padded sequences to boolean attention masks."""
  def f(x):
    return jnp.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))
  return Fn(f'PaddingMask({pad})', f)
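# A minimal usage sketch (assumes trax-style Fn layers, which can be called
# directly on arrays; the batch below is hypothetical):
#
#   x = jnp.array([[5, 7, 0, 0]])  # one sequence, padded with 0
#   mask = PaddingMask()(x)        # shape (1, 1, 1, 4)
#   # mask == [[[[True, True, False, False]]]]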
def _WeightedMean():
  """Returns a layer that computes a weighted mean of the given values."""
  def f(values, weights):  # pylint: disable=invalid-name
    return jnp.sum(values * weights) / jnp.sum(weights)
  return Fn('_WeightedMean', f)
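# Sketch of the computation itself (plain arithmetic, hypothetical values):
#
#   values  = jnp.array([1.0, 2.0, 3.0])
#   weights = jnp.array([1.0, 1.0, 2.0])
#   # weighted mean = (1*1 + 2*1 + 3*2) / (1 + 1 + 2) = 9 / 4 = 2.25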
def ToFloat():
  """Returns a layer that changes the dtype of a tensor to `float32`."""
  return Fn('ToFloat', lambda x: x.astype(np.float32))
def Exp():
  """Returns a layer that computes the element-wise exponential of a tensor."""
  return Fn('Exp', lambda x: jnp.exp(x))  # pylint: disable=unnecessary-lambda
def FlattenList():
  """Returns a layer that flattens nested lists into a flat sequence."""
  # TODO(jonni): Consider renaming layer to DeepFlatten.
  return Fn('FlattenList', lambda x: tuple(_deep_flatten(x)))
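# Sketch (assumes the `_deep_flatten` helper recursively yields leaves, as
# its name suggests; `a`..`d` are hypothetical stack items):
#
#   tuple(_deep_flatten([a, [b, [c]], d]))   # -> (a, b, c, d)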
def SubtractTop():
  """Subtracts the first tensor from the second."""
  return Fn('SubtractTop', lambda x0, x1: x1 - x0)
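# Stack-order sketch: x0 is the top stack element, x1 the one below it, so
# the result is "second minus top" (call style assumes trax layers accept a
# tuple of stack inputs):
#
#   SubtractTop()((jnp.array(2.0), jnp.array(5.0)))   # -> 3.0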
def Sum(axis=-1, keepdims=False):
  """Returns a layer that computes the sum along one tensor axis."""
  return Fn('Sum', lambda x: jnp.sum(x, axis=axis, keepdims=keepdims))
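# Axis/keepdims sketch (hypothetical 2x2 input):
#
#   x = jnp.array([[1., 2.], [3., 4.]])
#   Sum()(x)                         # last axis: [3., 7.]
#   Sum(axis=0)(x)                   # [4., 6.]
#   Sum(axis=-1, keepdims=True)(x)   # [[3.], [7.]]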
def Dup():
  """Duplicates (copies) the top element on the data stack."""
  return Fn('Dup', lambda x: (x, x), n_out=2)
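# Stack sketch: Dup turns stack [x, ...] into [x, x, ...], e.g. to feed the
# same value to two downstream layers:
#
#   Dup()(jnp.array(1.0))   # -> (1.0, 1.0)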
def Mean(axis=-1, keepdims=False):
  """Returns a layer that computes the mean along one tensor axis."""
  return Fn('Mean', lambda x: jnp.mean(x, axis=axis, keepdims=keepdims))
def Softmax(axis=-1):
  """Layer that applies softmax: exponentiate and normalize along given axis."""
  return Fn('Softmax',
            lambda x: jnp.exp(x - fastmath.logsumexp(x, axis, keepdims=True)))
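# Numerical note: subtracting logsumexp before exponentiating is the
# standard numerically stable softmax, since exp(x - logsumexp(x)) equals
# exp(x) / sum(exp(x)) without overflow for large x:
#
#   Softmax()(jnp.array([1.0, 2.0, 3.0]))   # sums to 1 along the last axis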
def LogSoftmax(axis=-1):
  """Layer that applies log softmax: log-normalize along the given axis."""
  return Fn('LogSoftmax',
            lambda x: x - fastmath.logsumexp(x, axis, keepdims=True))
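# Relation to Softmax: LogSoftmax computes log(Softmax(x)) directly, which
# avoids the underflow of taking log(exp(...)) for very negative inputs:
#
#   jnp.exp(LogSoftmax()(x))   # ~= Softmax()(x)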
def WeightedSum():
  """Returns a layer that computes a weighted sum of the given values."""
  def f(values, weights):  # pylint: disable=invalid-name
    return jnp.sum(values * weights)
  return Fn('WeightedSum', f)
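# Sketch (hypothetical values): weighting per-token losses by a 0/1 padding
# weight is a typical use.
#
#   values  = jnp.array([0.5, 0.7, 0.9])
#   weights = jnp.array([1.0, 1.0, 0.0])
#   # weighted sum = 0.5*1 + 0.7*1 + 0.9*0 = 1.2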
def Drop():
  """Drops the top stack element."""
  return Fn('Drop', lambda x: (), n_out=0)
def Swap():
  """Swaps the top two stack elements."""
  return Fn('Swap', lambda x0, x1: (x1, x0), n_out=2)
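# Stack sketch: Swap turns stack [x0, x1, ...] into [x1, x0, ...] (call
# style assumes a tuple of stack inputs):
#
#   Swap()((jnp.array(1.0), jnp.array(2.0)))   # -> (2.0, 1.0)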
def Add():
  """Adds two tensors."""
  return Fn('Add', lambda x0, x1: x0 + x1)
def ThresholdToBinary(threshold=.5):
  """Returns a layer that thresholds inputs to yield outputs in {0, 1}."""
  def f(model_output):  # pylint: disable=invalid-name
    return (model_output > threshold).astype(jnp.int32)
  return Fn('ThresholdToBinary', f)
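# Sketch (hypothetical sigmoid outputs); note the comparison is strict, so
# a value exactly at the threshold maps to 0:
#
#   ThresholdToBinary()(jnp.array([0.1, 0.5, 0.9]))   # -> [0, 0, 1]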
def Multiply():
  """Multiplies two tensors."""
  return Fn('Multiply', lambda x0, x1: x0 * x1)
def ArgMax(axis=-1):
  """Returns a layer that calculates argmax along the given axis."""
  def f(model_output):  # pylint: disable=invalid-name
    return jnp.argmax(model_output, axis=axis)
  return Fn('ArgMax', f)
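# Sketch: mapping per-class scores to predicted class ids (hypothetical
# logits for one example):
#
#   ArgMax()(jnp.array([[0.1, 2.3, -1.0]]))   # -> [1]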
def Negate():
  """Returns a layer that computes the element-wise negation of a tensor."""
  return Fn('Negate', lambda x: -x)
def StopGradient():
  """Returns an identity layer with a stop gradient."""
  return Fn('StopGradient', lambda x: fastmath.stop_gradient(x))  # pylint: disable=unnecessary-lambda
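# Sketch: the output equals the input, but gradients do not flow back
# through it, e.g. to keep a target network's contribution out of the
# update (hypothetical `q_values`):
#
#   target = StopGradient()(q_values)   # same values, zero gradient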
def Log():
  """Returns a layer that computes the element-wise logarithm of a tensor."""
  return Fn('Log', lambda x: jnp.log(x))  # pylint: disable=unnecessary-lambda
def test_fn_layer_fails_wrong_f(self):
  with self.assertRaisesRegex(ValueError, 'default arg'):
    Fn('', lambda x, sth=None: x)
  with self.assertRaisesRegex(ValueError, 'keyword arg'):
    Fn('', lambda x, **kwargs: x)
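# (Presumably Fn infers a layer's n_in from f's positional signature, so
# default args and **kwargs would make the arity ambiguous; hence the
# ValueError checked above.)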