Exemple #1
0
def PaddingMask(pad=0):
  """Returns a layer that marks which elements of its input differ from `pad`.

  The layer computes `x != pad` and reshapes the result to
  `(x.shape[0], 1, 1, x.shape[-1])`.

  Args:
    pad: Value each input element is compared against.
  """
  def mask(x):
    not_pad = x != pad
    target_shape = (x.shape[0], 1, 1, x.shape[-1])
    return jnp.reshape(not_pad, target_shape)

  layer_name = f'PaddingMask({pad})'
  return Fn(layer_name, mask)
Exemple #2
0
def _WeightedMean():
    """Returns a layer computing `sum(values * weights) / sum(weights)`.

    Both sums run over all elements (no axis is given to `jnp.sum`).
    """
    def weighted_mean(values, weights):
        weighted_total = jnp.sum(values * weights)
        weight_total = jnp.sum(weights)
        return weighted_total / weight_total

    return Fn('_WeightedMean', weighted_mean)
Exemple #3
0
def ToFloat():
  """Returns a layer that casts its input to `np.float32` via `astype`."""
  def to_float32(x):
    return x.astype(np.float32)

  return Fn('ToFloat', to_float32)
Exemple #4
0
def Exp():
  """Returns a layer that applies `jnp.exp` to its input tensor."""
  # NOTE(review): the explicit one-argument wrapper is kept on purpose (the
  # original silenced pylint's unnecessary-lambda here); presumably Fn needs a
  # plain Python function to inspect -- confirm before passing jnp.exp directly.
  def exp(x):
    return jnp.exp(x)

  return Fn('Exp', exp)
Exemple #5
0
def FlattenList():
  """Returns a layer that flattens its input into a single tuple.

  The flattening itself is delegated to `_deep_flatten`; its results are
  collected into a tuple.
  """
  # TODO(jonni): Consider renaming layer to DeepFlatten.
  def flatten(x):
    return tuple(_deep_flatten(x))

  return Fn('FlattenList', flatten)
Exemple #6
0
def SubtractTop():
  """Returns a layer computing `x1 - x0` for its two inputs `(x0, x1)`."""
  def subtract_first_from_second(x0, x1):
    return x1 - x0

  return Fn('SubtractTop', subtract_first_from_second)
Exemple #7
0
def Sum(axis=-1, keepdims=False):
    """Returns a layer that sums its input with `jnp.sum`.

    Args:
      axis: Forwarded to `jnp.sum` as `axis`.
      keepdims: Forwarded to `jnp.sum` as `keepdims`.
    """
    def reduce_sum(x):
        return jnp.sum(x, axis=axis, keepdims=keepdims)

    return Fn('Sum', reduce_sum)
Exemple #8
0
def Dup():
  """Returns a layer that outputs its single input twice, as `(x, x)`."""
  def duplicate(x):
    return x, x

  return Fn('Dup', duplicate, n_out=2)
Exemple #9
0
def ToFloat():
    """Returns a layer that casts its input to `np.float32` via `astype`."""
    def to_float32(x):
        return x.astype(np.float32)

    return Fn('ToFloat', to_float32)
Exemple #10
0
def Mean(axis=-1, keepdims=False):
    """Returns a layer that averages its input with `jnp.mean`.

    Args:
      axis: Forwarded to `jnp.mean` as `axis`.
      keepdims: Forwarded to `jnp.mean` as `keepdims`.
    """
    def reduce_mean(x):
        return jnp.mean(x, axis=axis, keepdims=keepdims)

    return Fn('Mean', reduce_mean)
Exemple #11
0
def Softmax(axis=-1):
    """Returns a layer that applies softmax along `axis`.

    The result is computed as `exp(x - logsumexp(x, axis, keepdims=True))`.

    Args:
      axis: Axis passed to `logsumexp` for the normalization term.
    """
    # NOTE(review): `math` must expose `logsumexp`, which the standard-library
    # `math` module does not -- presumably a project module; confirm the import.
    def softmax(x):
        log_normalizer = math.logsumexp(x, axis, keepdims=True)
        return jnp.exp(x - log_normalizer)

    return Fn('Softmax', softmax)
Exemple #12
0
def LogSoftmax(axis=-1):
    """Returns a layer that applies log-softmax along `axis`.

    The result is computed as `x - logsumexp(x, axis, keepdims=True)`.

    Args:
      axis: Axis passed to `logsumexp` for the normalization term.
    """
    # NOTE(review): `math` must expose `logsumexp`, which the standard-library
    # `math` module does not -- presumably a project module; confirm the import.
    def log_softmax(x):
        log_normalizer = math.logsumexp(x, axis, keepdims=True)
        return x - log_normalizer

    return Fn('LogSoftmax', log_softmax)
Exemple #13
0
def Exp():
    """Returns a layer that applies `jnp.exp` to its input tensor."""
    # NOTE(review): the explicit one-argument wrapper is kept on purpose (the
    # original silenced pylint's unnecessary-lambda here); presumably Fn needs
    # a plain Python function to inspect -- confirm before passing jnp.exp.
    def exp(x):
        return jnp.exp(x)

    return Fn('Exp', exp)
Exemple #14
0
def WeightedSum():
    """Returns a layer computing `sum(values * weights)` over all elements."""
    def weighted_sum(values, weights):
        weighted = values * weights
        return jnp.sum(weighted)

    return Fn('WeightedSum', weighted_sum)
Exemple #15
0
def Negate():
    """Returns a layer that applies unary minus (`-x`) to its input."""
    def negate(x):
        return -x

    return Fn('Negate', negate)
Exemple #16
0
def Drop():
  """Returns a layer that consumes one input and produces no outputs."""
  def discard(x):
    del x  # Intentionally unused: the input is simply thrown away.
    return ()

  return Fn('Drop', discard, n_out=0)
Exemple #17
0
def _WeightedMean():
    """Returns a layer computing `sum(values * weights) / sum(weights)`.

    Both sums run over all elements (no axis is given to `np.sum`).
    """
    def weighted_mean(values, weights):
        weighted_total = np.sum(values * weights)
        weight_total = np.sum(weights)
        return weighted_total / weight_total

    return Fn('_WeightedMean', weighted_mean)
Exemple #18
0
def Swap():
  """Returns a layer that outputs its two inputs in reversed order."""
  def swap(x0, x1):
    return x1, x0

  return Fn('Swap', swap, n_out=2)
Exemple #19
0
def WeightedSum():
    """Returns a layer computing `sum(values * weights)` over all elements."""
    def weighted_sum(values, weights):
        weighted = values * weights
        return np.sum(weighted)

    return Fn('WeightedSum', weighted_sum)
Exemple #20
0
def Add():
  """Returns a layer computing `x0 + x1` for its two inputs."""
  def add(x0, x1):
    return x0 + x1

  return Fn('Add', add)
Exemple #21
0
def ThresholdToBinary(threshold=.5):
  """Returns a layer mapping inputs to 1 where they exceed `threshold`, else 0.

  The comparison is strict (`>`), and the result is cast to `jnp.int32`.

  Args:
    threshold: Value each input element is compared against.
  """
  def to_binary(model_output):
    above = model_output > threshold
    return above.astype(jnp.int32)

  return Fn('ThresholdToBinary', to_binary)
Exemple #22
0
def Multiply():
  """Returns a layer computing `x0 * x1` for its two inputs."""
  def multiply(x0, x1):
    return x0 * x1

  return Fn('Multiply', multiply)
Exemple #23
0
def ArgMax(axis=-1):
  """Returns a layer that applies `jnp.argmax` to its input along `axis`.

  Args:
    axis: Forwarded to `jnp.argmax` as `axis`.
  """
  return Fn(
      'ArgMax',
      lambda model_output: jnp.argmax(model_output, axis=axis),
  )
Exemple #24
0
def Negate():
  """Returns a layer that applies unary minus (`-x`) to its input."""
  def negate(x):
    return -x

  return Fn('Negate', negate)
Exemple #25
0
def StopGradient():
    """Returns a layer that passes its input through `fastmath.stop_gradient`."""
    # NOTE(review): the explicit one-argument wrapper is kept on purpose (the
    # original silenced pylint's unnecessary-lambda here); presumably Fn needs
    # a plain Python function to inspect -- confirm before simplifying.
    def stop_gradient(x):
        return fastmath.stop_gradient(x)

    return Fn('StopGradient', stop_gradient)
Exemple #26
0
def Log():
  """Returns a layer that applies `jnp.log` to its input tensor."""
  # NOTE(review): the explicit one-argument wrapper is kept on purpose (the
  # original silenced pylint's unnecessary-lambda here); presumably Fn needs a
  # plain Python function to inspect -- confirm before passing jnp.log directly.
  def log(x):
    return jnp.log(x)

  return Fn('Log', log)
Exemple #27
0
 def test_fn_layer_fails_wrong_f(self):
     """Checks that Fn raises ValueError for default or keyword arguments."""
     def has_default(x, sth=None):  # pylint: disable=unused-argument
         return x

     def has_kwargs(x, **kwargs):  # pylint: disable=unused-argument
         return x

     # Each unsupported signature must be rejected with its own message.
     with self.assertRaisesRegex(ValueError, 'default arg'):
         Fn('', has_default)
     with self.assertRaisesRegex(ValueError, 'keyword arg'):
         Fn('', has_kwargs)