Example #1
0
def Swish():
  r"""Returns a layer that computes the Swish function.

  This is Swish with a fixed slope of 1 (also known as SiLU).

  .. math::
      f(x) = x \cdot \text{sigmoid}(x)
  """
  return Fn('Swish', lambda x: x * fastmath.expit(x))
Example #2
0
def Sigmoid():
  r"""Builds a layer that applies the logistic sigmoid to its input.

  .. math::
      f(x) = \frac{1}{1 + e^{-x}}
  """
  def _sigmoid(x):
    return fastmath.expit(x)

  return Fn('Sigmoid', _sigmoid)
Example #3
0
 def f(model_output, targets):  # pylint: disable=invalid-name
     """Mean binary cross-entropy of `targets` against sigmoid(`model_output`).

     The loss is computed directly from the logits using the identities
       log(sigmoid(x))     = -logaddexp(0, -x)
       log(1 - sigmoid(x)) = -logaddexp(0, x)
     Taking `jnp.log` of `expit(x)` instead saturates to log(0) = -inf once
     |x| is large, giving an infinite loss or `nan` (from 0 * -inf).

     Args:
       model_output: Logits (pre-sigmoid scores).
       targets: Target values broadcastable against `model_output`;
         presumably probabilities in [0, 1] -- confirm against callers.

     Returns:
       The average of the element-wise binary cross-entropies.
     """
     binary_entropies = (targets * jnp.logaddexp(0., -model_output) +
                         (1 - targets) * jnp.logaddexp(0., model_output))
     return jnp.average(binary_entropies)
Example #4
0
File: rse.py Project: yliu45/trax
 def forward(self, x):
   """Gates the input by the sigmoid of this layer's weights.

   Args:
     x: Input array.

   Returns:
     `x` multiplied (via `jnp.multiply`, with broadcasting) by
     `fastmath.expit(self.weights)`.
   """
   gate = fastmath.expit(self.weights)
   return jnp.multiply(x, gate)
Example #5
0
 def _f(x, axis=-1):  # pylint: disable=invalid-name
     """GLU-style gating: splits `x` in half along `axis` and gates one half.

     Args:
       x: Array whose size along `axis` must be even.
       axis: Axis along which `x` is split; defaults to the last axis.

     Returns:
       `a * expit(b)`, where `a` and `b` are the first and second halves of
       `x` along `axis`. The result is half the size of `x` along `axis`.

     Raises:
       AssertionError: If the size of `x` along `axis` is odd. (Under
         `python -O` the assert is stripped and `jnp.split` reports the
         uneven split instead.)
     """
     size = x.shape[axis]
     assert size % 2 == 0, f'axis {axis} of size {size} is not divisible by 2'
     a, b = jnp.split(x, 2, axis)
     return a * fastmath.expit(b)