Example No. 1
0
 def log_prob(self, inputs, point):
     """Return the Gaussian log-density of ``point`` centred at ``inputs``.

     ``point`` is first reshaped to ``inputs.shape[:-1] + (-1,)`` so that it
     lines up with ``inputs``. The density uses a single standard deviation,
     ``self._std``, for every coordinate.

     Args:
       inputs: Mean of the Gaussian; the last axis holds the coordinates.
       point: Value to evaluate; reshaped to match ``inputs``.

     Returns:
       The log-density summed over the last axis:
       ``-||point - inputs||^2 / (2 * std^2)
         - prod(self._shape) * (log(std) + log(sqrt(2*pi)))``.
     """
     aligned = point.reshape(inputs.shape[:-1] + (-1, ))
     # Quadratic (squared-distance) part of the exponent.
     sq_dist = jnp.sum((aligned - inputs)**2, axis=-1)
     quadratic = sq_dist / (2 * self._std**2)
     # Normalizing constant. It is scaled by prod(self._shape), presumably the
     # number of coordinates, i.e. the size of inputs' last axis (confirm).
     per_dim_log_norm = jnp.log(self._std) + jnp.log(jnp.sqrt(2 * jnp.pi))
     log_norm = per_dim_log_norm * jnp.prod(self._shape)
     return -quadratic - log_norm
Example No. 2
0
def _GetFans(shape, out_dim=-1, in_dim=-2):
    """Get the fan-in and fan-out sizes for the given shape and dims."""
    # Temporary fix until numpy.delete supports negative indices.
    if out_dim < 0:
        out_dim += len(shape)
    if in_dim < 0:
        in_dim += len(shape)

    receptive_field = jnp.prod(np.delete(shape, [in_dim, out_dim]))
    if len(shape) >= 2:
        fan_in, fan_out = shape[in_dim], shape[out_dim]
    elif len(shape) == 1:
        fan_in = shape[0]
        fan_out = shape[0]
    else:
        fan_in = 1.
        fan_out = 1.
        fan_in *= receptive_field
        fan_out *= receptive_field
    return fan_in, fan_out
Example No. 3
0
 def test_batch_norm(self):
     """BatchNorm over axes (0, 1, 2): initial state, one update, and output.

     The input is the deterministic sequence 0..23 reshaped to (2, 3, 4). Its
     mean (11.5) and population variance (575/12, about 47.9167) are
     therefore known in closed form.
     """
     shape = (2, 3, 4)
     dtype = np.float32
     eps = 1e-5
     signature = ShapeDtype(shape, dtype)
     flat = np.arange(np.prod(shape), dtype=dtype)
     x = np.reshape(flat, shape)
     mean = 11.5  # Mean of 0..23.
     var = 47.9167  # Population variance of 0..23, rounded to 4 decimals.

     layer = normalization.BatchNorm(axis=(0, 1, 2))
     _, _ = layer.init(signature)

     # Before any forward pass: state[0] == 0, state[1] == 1, state[2] == 0.
     initial = layer.state
     onp.testing.assert_allclose(initial[0], 0)
     onp.testing.assert_allclose(initial[1], 1)
     self.assertEqual(initial[2], 0)

     y = layer(x)

     # The expected values below are consistent with running mean/variance
     # updated with momentum 0.999, and with state[2] counting calls. This is
     # inferred from the asserted numbers, not from BatchNorm's source.
     updated = layer.state
     onp.testing.assert_allclose(updated[0], mean * 0.001)
     onp.testing.assert_allclose(updated[1], 0.999 + var * 0.001, rtol=1e-6)
     self.assertEqual(updated[2], 1)
     onp.testing.assert_allclose(y, (x - mean) / np.sqrt(var + eps),
                                 rtol=1e-6)
Example No. 4
0
 def n_inputs(self):
     """Return prod(self._shape) (computed as int32) times self._n_categories.

     The result is a JAX scalar array, not a Python int. Presumably this is
     the size of a one-hot-style encoding of the input (confirm).
     """
     n_positions = jnp.prod(self._shape, dtype=jnp.int32)
     return n_positions * self._n_categories
Example No. 5
0
 def n_inputs(self):
     """Return the product of the entries of ``self._shape``, as int32.

     The value comes from ``jnp.prod``, so it is a JAX scalar array.
     """
     total = jnp.prod(self._shape, dtype=jnp.int32)
     return total