def test_orthogonal_invalid_shape(self):
    """Calling Orthogonal with a 1-D shape must raise ValueError.

    The expected message is matched with ``assertRaisesRegex`` (i.e. as a
    regular expression against the raised error's text).
    """
    expected_message = "Orthogonal initializer requires at least a 2D shape."
    orthogonal = initializers.Orthogonal()
    with self.assertRaisesRegex(ValueError, expected_message):
        # A single-element shape tuple: one dimension only.
        orthogonal((20,), jnp.float32)
def test_initializers(self):
    """Smoke-test that every initializer follows the ``init(shape, dtype)`` API.

    This just makes sure we can call the initializers in accordance with the
    API and get the right shapes and dtypes out. Each initializer is checked
    inside its own ``subTest`` so that a failure reports which entry broke and
    the remaining entries are still exercised.
    """
    inits = [
        initializers.Constant(42.0),
        initializers.RandomNormal(),
        initializers.RandomNormal(2.0),
        initializers.RandomUniform(),
        initializers.RandomUniform(3.0),
        initializers.VarianceScaling(),
        initializers.VarianceScaling(2.0),
        initializers.VarianceScaling(2.0, mode="fan_in"),
        initializers.VarianceScaling(2.0, mode="fan_out"),
        initializers.VarianceScaling(2.0, mode="fan_avg"),
        initializers.VarianceScaling(2.0, distribution="truncated_normal"),
        initializers.VarianceScaling(2.0, distribution="normal"),
        initializers.VarianceScaling(2.0, distribution="uniform"),
        initializers.UniformScaling(),
        initializers.UniformScaling(2.0),
        initializers.TruncatedNormal(),
        initializers.Orthogonal(),
        # Users are supposed to be able to use these.
        jnp.zeros,
        jnp.ones,
    ]
    # TODO(ibab): Test other shapes as well.
    shape = (20, 42)
    dtype = jnp.float32
    for index, init in enumerate(inits):
        # Several entries share a class (e.g. two RandomNormal instances), so
        # the list index is included to disambiguate failures.
        with self.subTest(index=index, init=init):
            generated = init(shape, dtype)
            self.assertEqual(generated.shape, shape)
            self.assertEqual(generated.dtype, dtype)
def test_orthogonal_orthogonal(self):
    """Check that Orthogonal produces an orthonormal matrix of the right shape/dtype.

    Besides shape and dtype, this verifies the property the test name promises:
    along the smaller dimension the generated matrix is orthonormal, i.e. its
    Gram matrix is the identity. For shape (42, 20) that means the 20 columns
    are orthonormal.
    """
    # Local import: keeps this test self-contained regardless of the
    # module-level imports.
    import numpy as np

    init = initializers.Orthogonal()
    shape = (42, 20)
    generated = init(shape, jnp.float32)
    self.assertEqual(generated.shape, shape)
    self.assertEqual(generated.dtype, jnp.float32)

    # Form the Gram matrix on the host in NumPy so the check is not affected by
    # JAX's default (possibly reduced) matmul precision on accelerators.
    matrix = np.asarray(generated)
    n_rows, n_cols = matrix.shape
    if n_rows >= n_cols:
        gram = matrix.T @ matrix  # columns should be orthonormal
    else:
        gram = matrix @ matrix.T  # rows should be orthonormal
    # NOTE(review): comparing against the identity assumes Orthogonal's
    # default scale is 1.0 — confirm against the initializer's signature.
    np.testing.assert_allclose(
        gram, np.eye(min(n_rows, n_cols)), rtol=1e-5, atol=1e-5)