Beispiel #1
0
  def test_initializers(self):
    """Smoke-tests that every initializer yields the requested shape/dtype.

    This only checks that each initializer can be called via the
    ``init(shape, dtype)`` API and returns an array of the requested shape and
    dtype; the generated values themselves are not inspected.
    """
    as_np_f64 = lambda t: np.array(t, dtype=np.float64)
    # This just makes sure we can call the initializers in accordance with the
    # API and get the right shapes and dtypes out. Scalar hyper-parameters are
    # passed both as Python floats and as NumPy float64 arrays so both
    # argument forms are exercised.
    inits = [
        initializers.Constant(42.0),
        initializers.Constant(as_np_f64(42.0)),
        initializers.RandomNormal(),
        initializers.RandomNormal(2.0),
        initializers.RandomNormal(as_np_f64(2.0)),
        initializers.RandomUniform(),
        initializers.RandomUniform(3.0),
        initializers.RandomUniform(as_np_f64(3.0)),
        initializers.VarianceScaling(),
        initializers.VarianceScaling(2.0),
        initializers.VarianceScaling(as_np_f64(2.0)),
        initializers.VarianceScaling(2.0, mode="fan_in"),
        initializers.VarianceScaling(as_np_f64(2.0), mode="fan_in",
                                     fan_in_axes=[0]),
        initializers.VarianceScaling(2.0, mode="fan_in", fan_in_axes=[0]),
        initializers.VarianceScaling(as_np_f64(2.0), mode="fan_in"),
        initializers.VarianceScaling(2.0, mode="fan_out"),
        initializers.VarianceScaling(as_np_f64(2.0), mode="fan_out"),
        initializers.VarianceScaling(2.0, mode="fan_avg"),
        initializers.VarianceScaling(as_np_f64(2.0), mode="fan_avg"),
        initializers.VarianceScaling(2.0, distribution="truncated_normal"),
        initializers.VarianceScaling(
            as_np_f64(2.0), distribution="truncated_normal"),
        initializers.VarianceScaling(2.0, distribution="normal"),
        initializers.VarianceScaling(as_np_f64(2.0), distribution="normal"),
        initializers.VarianceScaling(2.0, distribution="uniform"),
        initializers.VarianceScaling(as_np_f64(2.0), distribution="uniform"),
        initializers.UniformScaling(),
        initializers.UniformScaling(2.0),
        initializers.UniformScaling(as_np_f64(2.0)),
        initializers.TruncatedNormal(),
        initializers.Orthogonal(),
        initializers.Identity(),
        initializers.Identity(as_np_f64(2.0)),

        # Users are supposed to be able to use these.
        jnp.zeros,
        jnp.ones,
    ]

    # TODO(ibab): Test other shapes as well.
    shape = (20, 42)

    dtype = jnp.float32
    for init in inits:
      # subTest labels any failure with the offending initializer and, under
      # the unittest runner, keeps checking the remaining initializers instead
      # of stopping at the first failure.
      with self.subTest(init=init):
        generated = init(shape, dtype)
        self.assertEqual(generated.shape, shape)
        self.assertEqual(generated.dtype, dtype)
Beispiel #2
0
 def testRange(self, shape, gain, dtype):
     """Checks the mean and max of an ``Identity(gain)`` initialization.

     Args:
       shape: Shape to initialize. Must have at least two dimensions; the
         identity is presumably laid out over the trailing two dimensions
         (batched for higher ranks) -- TODO confirm against the
         parameterization, which is not visible here.
       gain: Scale applied to the identity entries.
       dtype: dtype passed to the initializer.
     """
     init = initializers.Identity(gain)
     value = init(shape, dtype)
     self.assertEqual(value.shape, shape)
     # A (rows, cols) identity holds min(rows, cols) entries equal to `gain`
     # and zeros elsewhere, so its mean is
     #   gain * min(rows, cols) / (rows * cols) == gain / max(rows, cols).
     # Dividing by shape[-1] alone is only correct when cols >= rows.
     rows, cols = shape[-2], shape[-1]
     np.testing.assert_almost_equal(value.mean(),
                                    gain / max(rows, cols),
                                    decimal=4)
     # NOTE(review): assumes gain > 0; for a negative gain the max would be 0.
     np.testing.assert_almost_equal(value.max(), gain, decimal=4)
Beispiel #3
0
  def test_identity_identity(self):
    """Checks that ``Identity()`` output behaves as an identity under matmul."""
    shape = (42, 20)
    rows, cols = shape
    identity = initializers.Identity()(shape, jnp.float32)
    self.assertEqual(identity.shape, shape)
    self.assertEqual(identity.dtype, jnp.float32)

    # Right-multiplying a (62, rows) matrix by the (rows, cols) identity should
    # reproduce its first `cols` columns. The loose rtol presumably tolerates
    # reduced-precision matmuls on some backends.
    lhs = jax.random.normal(jax.random.PRNGKey(42), (62, rows), jnp.float32)
    np.testing.assert_allclose(lhs @ identity, lhs[:, :cols], rtol=1e-2)
Beispiel #4
0
 def test_identity_invalid_shape(self):
     """Checks that ``Identity()`` raises ValueError for a 1D shape."""
     one_dim_shape = (20,)
     identity_init = initializers.Identity()
     # Callable form: assertRaisesRegex invokes identity_init(*args) itself.
     self.assertRaisesRegex(ValueError, "requires at least a 2D shape.",
                            identity_init, one_dim_shape, jnp.float32)