Exemplo n.º 1
0
 def test_orthogonal_invalid_shape(self):
     """Orthogonal() must raise ValueError for a shape with fewer than 2 dims."""
     init = initializers.Orthogonal()
     shape = (20,)
     # assertRaisesRegex treats the expected message as a regular expression,
     # so the trailing period is escaped to require a literal "." instead of
     # "any character".
     with self.assertRaisesRegex(
             ValueError,
             r"Orthogonal initializer requires at least a 2D shape\."):
         init(shape, jnp.float32)
Exemplo n.º 2
0
    def test_initializers(self):
        """Each initializer is callable per the API and honours shape and dtype."""
        # This just makes sure we can call the initializers in accordance to the
        # API and get the right shapes and dtypes out.
        inits = [
            initializers.Constant(42.0),
            initializers.RandomNormal(),
            initializers.RandomNormal(2.0),
            initializers.RandomUniform(),
            initializers.RandomUniform(3.0),
            initializers.VarianceScaling(),
            initializers.VarianceScaling(2.0),
            initializers.VarianceScaling(2.0, mode="fan_in"),
            initializers.VarianceScaling(2.0, mode="fan_out"),
            initializers.VarianceScaling(2.0, mode="fan_avg"),
            initializers.VarianceScaling(2.0, distribution="truncated_normal"),
            initializers.VarianceScaling(2.0, distribution="normal"),
            initializers.VarianceScaling(2.0, distribution="uniform"),
            initializers.UniformScaling(),
            initializers.UniformScaling(2.0),
            initializers.TruncatedNormal(),
            initializers.Orthogonal(),

            # Users are supposed to be able to use these.
            jnp.zeros,
            jnp.ones,
        ]

        # TODO(ibab): Test other shapes as well.
        shape = (20, 42)

        dtype = jnp.float32
        for index, init in enumerate(inits):
            # subTest labels each case, so a failure names the offending
            # initializer. The remaining entries are still checked instead of
            # the loop aborting at the first bad one.
            with self.subTest(index=index, init=init):
                generated = init(shape, dtype)
                self.assertEqual(generated.shape, shape)
                self.assertEqual(generated.dtype, dtype)
Exemplo n.º 3
0
 def test_orthogonal_orthogonal(self):
     """Orthogonal() output has the requested shape/dtype and orthonormal columns."""
     init = initializers.Orthogonal()
     shape = (42, 20)
     generated = init(shape, jnp.float32)
     self.assertEqual(generated.shape, shape)
     self.assertEqual(generated.dtype, jnp.float32)

     # The original test only checked shape/dtype despite its name. With more
     # rows (42) than columns (20), the columns should be orthonormal, so the
     # Gram matrix G^T G should be the 20x20 identity.
     # NOTE(review): assumes Orthogonal() defaults to scale=1.0 -- confirm
     # against the initializer's definition.
     gram = generated.T @ generated
     self.assertTrue(
         jnp.allclose(gram, jnp.eye(shape[1], dtype=jnp.float32), atol=1e-4),
         msg="Orthogonal() output does not have orthonormal columns.")