Exemplo n.º 1
0
Arquivo: nn_test.py Projeto: gtr8/jax
  def testNormalizeWhereMask(self):
    """Masked normalize, read back at the kept positions, matches normalizing only those entries."""
    x = jnp.array([5.5, 1.3, -4.2, 0.9])
    m = jnp.array([True, False, True, True])
    # Indices of the True entries of `m`; used to select the kept values both
    # before normalizing (reference) and after the masked normalize.
    keep = jnp.array([0, 2, 3])

    expected = nn.normalize(jnp.take(x, keep))
    actual = jnp.take(nn.normalize(x, where=m), keep)

    self.assertAllClose(actual, expected)
Exemplo n.º 2
0
 def apply_fun(params, x, **kwargs):
   """Apply `normalize(x, axis, epsilon=epsilon)`, then optionally scale and shift.

   Args:
     params: a `(beta, gamma)` pair. `gamma` is used only when `scale` is
       truthy and `beta` only when `center` is truthy.
     x: input array.
     **kwargs: accepted but unused.

   Returns:
     The normalized array, multiplied by `gamma` when `scale` is set and then
     offset by `beta` when `center` is set. Before use, `beta` and `gamma` get
     a new axis inserted at each position of `range(np.ndim(x))` that appears
     in `axis`.
   """
   beta, gamma = params
   # Broadcasting index for beta/gamma: None (new axis) at each position that
   # appears in `axis`, a full slice elsewhere. Built by hand because
   # np.expand_dims historically rejected an axis tuple
   # (https://github.com/numpy/numpy/issues/12290).
   # NOTE(review): NumPy >= 1.18 accepts a tuple there; confirm for the `np`
   # in use before simplifying. Only non-negative entries of `axis` can match
   # `range(np.ndim(x))` below.
   bcast = tuple(None if dim in axis else slice(None)
                 for dim in range(np.ndim(x)))
   out = normalize(x, axis, epsilon=epsilon)
   if scale:
     out = gamma[bcast] * out
   if center:
     out = out + beta[bcast]
   return out