Example #1
def PaddingMask(pad=0):
  """Returns a layer that marks non-padding positions with a broadcastable mask."""
  def f(x):
    return jnp.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))
  return Fn(f'PaddingMask({pad})', f)
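As a hedged illustration (token values made up), here is what the function inside PaddingMask computes, using plain NumPy in place of jnp:

import numpy as np

x = np.array([[5, 7, 0, 0]])  # one sequence of length 4, pad token = 0
mask = np.reshape(x != 0, (x.shape[0], 1, 1, x.shape[-1]))
print(mask.shape)  # (1, 1, 1, 4) -- broadcastable against attention weights
print(mask)        # [[[[ True  True False False]]]]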
Example #2
def _WeightedMean():
    """Returns a layer that computes a weighted mean of the given values."""
    def f(values, weights):  # pylint: disable=invalid-name
        return jnp.sum(values * weights) / jnp.sum(weights)

    return Fn('_WeightedMean', f)
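A quick NumPy sketch of the same computation (values and weights are illustrative, not from the source):

import numpy as np

values = np.array([1.0, 2.0, 3.0])
weights = np.array([1.0, 0.0, 1.0])   # zero weight ignores the middle value
print(np.sum(values * weights) / np.sum(weights))  # 2.0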
Example #3
def ToFloat():
  """Returns a layer that changes the dtype of a tensor to `float32`."""
  return Fn('ToFloat', lambda x: x.astype(np.float32))
Example #4
def Exp():
  """Returns a layer that computes the element-wise exponential of a tensor."""
  return Fn('Exp', lambda x: jnp.exp(x))  # pylint: disable=unnecessary-lambda
Example #5
def FlattenList():
  """Flatten lists."""
  # TODO(jonni): Consider renaming layer to DeepFlatten.
  return Fn('FlattenList', lambda x: tuple(_deep_flatten(x)))
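The helper _deep_flatten is private and not shown above; a plausible sketch (an assumption, not the verbatim Trax implementation) is a recursive generator:

def _deep_flatten(items):  # hypothetical sketch of the referenced helper
    """Yields the leaves of an arbitrarily nested list/tuple structure."""
    for item in items:
        if isinstance(item, (list, tuple)):
            yield from _deep_flatten(item)
        else:
            yield item

print(tuple(_deep_flatten([1, [2, [3, 4]], (5,)])))  # (1, 2, 3, 4, 5)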
Example #6
def SubtractTop():
  """Subtracts the first tensor from the second."""
  return Fn('SubtractTop', lambda x0, x1: x1 - x0)
Example #7
def Sum(axis=-1, keepdims=False):
    """Returns a layer that computes the sum along the given axis."""
    return Fn('Sum', lambda x: jnp.sum(x, axis=axis, keepdims=keepdims))
Example #8
def Dup():
  """Duplicates (copies) the top element on the data stack."""
  return Fn('Dup', lambda x: (x, x), n_out=2)
Example #9
def ToFloat():
    """Returns a layer that changes the dtype of a tensor to `float32`."""
    return Fn('ToFloat', lambda x: x.astype(np.float32))
Example #10
def Mean(axis=-1, keepdims=False):
    """Returns a layer that computes the mean along the given axis."""
    return Fn('Mean', lambda x: jnp.mean(x, axis=axis, keepdims=keepdims))
Example #11
def Softmax(axis=-1):
    """Layer that applies softmax: exponentiate and normalize along given axis."""
    return Fn('Softmax',
              lambda x: jnp.exp(x - math.logsumexp(x, axis, keepdims=True)))
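In Examples #11 and #12, math is Trax's backend module (later renamed fastmath), not the standard library. The x - logsumexp(x) form is the numerically stable way to compute (log-)softmax; a NumPy/SciPy sketch of why it matters (scipy.special.logsumexp stands in for the backend function):

import numpy as np
from scipy.special import logsumexp

x = np.array([1000.0, 1001.0, 1002.0])  # large logits overflow a naive softmax
naive = np.exp(x) / np.sum(np.exp(x))   # [nan nan nan] -- exp overflows to inf
stable = np.exp(x - logsumexp(x, axis=-1, keepdims=True))
print(stable)  # [0.09003057 0.24472847 0.66524096]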
Example #12
def LogSoftmax(axis=-1):
    """Layer that applies log softmax: log-normalize along the given axis."""
    return Fn('LogSoftmax',
              lambda x: x - math.logsumexp(x, axis, keepdims=True))
Example #13
def Exp():
    """Returns a layer that computes the element-wise exponential of a tensor."""
    return Fn('Exp', lambda x: jnp.exp(x))  # pylint: disable=unnecessary-lambda
Example #14
def WeightedSum():
    """Returns a layer that computes a weighted sum of the given values."""
    def f(values, weights):  # pylint: disable=invalid-name
        return jnp.sum(values * weights)

    return Fn('WeightedSum', f)
Example #15
def Negate():
    """Returns a layer that computes the element-wise negation of a tensor."""
    return Fn('Negate', lambda x: -x)
Example #16
def Drop():
  """Drops the top stack element."""
  return Fn('Drop', lambda x: (), n_out=0)
Example #17
def _WeightedMean():
    """Returns a layer to compute weighted mean over all values in the input."""
    def f(values, weights):  # pylint: disable=invalid-name
        return np.sum(values * weights) / np.sum(weights)

    return Fn('_WeightedMean', f)
Example #18
def Swap():
  """Swaps the top two stack elements."""
  return Fn('Swap', lambda x0, x1: (x1, x0), n_out=2)
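Dup, Drop, and Swap are stack combinators: n_out tells Fn how many results go back on Trax's data stack. A plain-Python sketch of their stack effects (list head = top of stack; this models the semantics, it is not Trax's implementation):

def apply_fn(stack, f, n_in, n_out):
    """Pops n_in items, applies f, pushes the n_out results back."""
    out = f(*stack[:n_in])
    out = out if isinstance(out, tuple) else (out,)
    assert len(out) == n_out
    return list(out) + stack[n_in:]

stack = ['a', 'b', 'c']
print(apply_fn(stack, lambda x: (x, x), 1, 2))         # Dup:  ['a', 'a', 'b', 'c']
print(apply_fn(stack, lambda x: (), 1, 0))             # Drop: ['b', 'c']
print(apply_fn(stack, lambda x0, x1: (x1, x0), 2, 2))  # Swap: ['b', 'a', 'c']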
Example #19
def WeightedSum():
    """Returns a layer to compute weighted sum over all values in the input."""
    def f(values, weights):  # pylint: disable=invalid-name
        return np.sum(values * weights)

    return Fn('WeightedSum', f)
Example #20
def Add():
  """Adds two tensors."""
  return Fn('Add', lambda x0, x1: x0 + x1)
Example #21
def ThresholdToBinary(threshold=.5):
  """Returns a layer that thresholds inputs to yield outputs in {0, 1}."""
  def f(model_output):  # pylint: disable=invalid-name
    return (model_output > threshold).astype(jnp.int32)
  return Fn('ThresholdToBinary', f)
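A NumPy sketch of the thresholding (values illustrative); note the strict >, so an output exactly at the threshold maps to 0:

import numpy as np

probs = np.array([0.1, 0.5, 0.9])      # e.g. sigmoid outputs
print((probs > 0.5).astype(np.int32))  # [0 0 1]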
Example #22
def Multiply():
  """Multiplies two tensors."""
  return Fn('Multiply', lambda x0, x1: x0 * x1)
Example #23
def ArgMax(axis=-1):
  """Returns a layer that calculates argmax along the given axis."""
  def f(model_output):  # pylint: disable=invalid-name
    return jnp.argmax(model_output, axis=axis)
  return Fn('ArgMax', f)
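A NumPy sketch of the argmax step on a batch of logits (values illustrative):

import numpy as np

logits = np.array([[0.1, 2.0, -1.0],
                   [3.0, 0.0,  0.5]])
print(np.argmax(logits, axis=-1))  # [1 0] -- predicted class per example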
Example #24
def Negate():
  """Returns a layer that computes the element-wise negation of a tensor."""
  return Fn('Negate', lambda x: -x)
Example #25
def StopGradient():
    """Returns an identity layer with a stop gradient."""
    return Fn('StopGradient', lambda x: fastmath.stop_gradient(x))  # pylint: disable=unnecessary-lambda
Example #26
def Log():
  """Returns a layer that computes the element-wise logarithm of a tensor."""
  return Fn('Log', lambda x: jnp.log(x))  # pylint: disable=unnecessary-lambda
Example #27
def test_fn_layer_fails_wrong_f(self):
    with self.assertRaisesRegex(ValueError, 'default arg'):
        Fn('', lambda x, sth=None: x)
    with self.assertRaisesRegex(ValueError, 'keyword arg'):
        Fn('', lambda x, **kwargs: x)
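This test documents a real restriction: Fn infers the layer's n_in from f's positional parameters, so default values and **kwargs would make the input count ambiguous. A sketch of how such a check can be written with inspect (an assumption about the mechanism, not Trax's verbatim code):

import inspect

def check_fn_signature(f):  # hypothetical helper, for illustration only
    """Rejects functions whose arity Fn could not infer."""
    for p in inspect.signature(f).parameters.values():
        if p.default is not inspect.Parameter.empty:
            raise ValueError(f'Function has a default arg: {p.name}')
        if p.kind == inspect.Parameter.VAR_KEYWORD:
            raise ValueError(f'Function has a keyword arg: {p.name}')
    return len(inspect.signature(f).parameters)  # becomes the layer's n_in

print(check_fn_signature(lambda x0, x1: x0 + x1))  # 2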