def test_initializers(self):
  """Smoke-tests that every initializer is callable in accordance with the API.

  Each entry of ``inits`` (including the plain ``jnp.zeros`` / ``jnp.ones``
  functions, which users may pass in place of an initializer object) is called
  as ``init(shape, dtype)``. The result is checked only for the requested shape
  and dtype; the generated values themselves are not inspected.
  """

  def as_np_f64(t):
    # Wrap a scalar as a NumPy float64 array so that each initializer is also
    # exercised with array-valued arguments, not only Python floats.
    return np.array(t, dtype=np.float64)

  inits = [
      initializers.Constant(42.0),
      initializers.Constant(as_np_f64(42.0)),
      initializers.RandomNormal(),
      initializers.RandomNormal(2.0),
      initializers.RandomNormal(as_np_f64(2.0)),
      initializers.RandomUniform(),
      initializers.RandomUniform(3.0),
      initializers.RandomUniform(as_np_f64(3.0)),
      initializers.VarianceScaling(),
      initializers.VarianceScaling(2.0),
      initializers.VarianceScaling(as_np_f64(2.0)),
      initializers.VarianceScaling(2.0, mode="fan_in"),
      initializers.VarianceScaling(as_np_f64(2.0), mode="fan_in",
                                   fan_in_axes=[0]),
      initializers.VarianceScaling(2.0, mode="fan_in", fan_in_axes=[0]),
      initializers.VarianceScaling(as_np_f64(2.0), mode="fan_in"),
      initializers.VarianceScaling(2.0, mode="fan_out"),
      initializers.VarianceScaling(as_np_f64(2.0), mode="fan_out"),
      initializers.VarianceScaling(2.0, mode="fan_avg"),
      initializers.VarianceScaling(as_np_f64(2.0), mode="fan_avg"),
      initializers.VarianceScaling(2.0, distribution="truncated_normal"),
      initializers.VarianceScaling(
          as_np_f64(2.0), distribution="truncated_normal"),
      initializers.VarianceScaling(2.0, distribution="normal"),
      initializers.VarianceScaling(as_np_f64(2.0), distribution="normal"),
      initializers.VarianceScaling(2.0, distribution="uniform"),
      initializers.VarianceScaling(as_np_f64(2.0), distribution="uniform"),
      initializers.UniformScaling(),
      initializers.UniformScaling(2.0),
      initializers.UniformScaling(as_np_f64(2.0)),
      initializers.TruncatedNormal(),
      initializers.Orthogonal(),
      initializers.Identity(),
      initializers.Identity(as_np_f64(2.0)),

      # Users are supposed to be able to use these.
      jnp.zeros,
      jnp.ones,
  ]

  # TODO(ibab): Test other shapes as well.
  shape = (20, 42)
  dtype = jnp.float32

  for i, init in enumerate(inits):
    # Many entries share a class, and the repr may not show constructor
    # arguments, so the list index is what pinpoints a failing entry.
    msg = f"inits[{i}] = {init!r}"
    generated = init(shape, dtype)
    self.assertEqual(generated.shape, shape, msg=msg)
    self.assertEqual(generated.dtype, dtype, msg=msg)
def testRange(self, shape, gain, dtype):
  """Checks the mean and maximum of ``Identity(gain)`` for one shape/dtype.

  An identity over the last two axes of size ``rows x cols`` contains
  ``min(rows, cols)`` entries equal to ``gain`` and zeros elsewhere. Its mean
  is therefore ``gain * min(rows, cols) / (rows * cols)``, which equals
  ``gain / max(rows, cols)``, and its maximum is ``gain``.

  Args:
    shape: Requested shape; must have at least two dimensions, since lower
      ranks are rejected. Any leading axes are assumed to repeat the same
      2D identity (TODO confirm), which leaves both statistics unchanged.
    gain: Scale applied to the identity.
    dtype: Requested dtype of the generated array.
  """
  init = initializers.Identity(gain)
  value = init(shape, dtype)
  self.assertEqual(value.shape, shape)

  rows, cols = shape[-2:]
  # Dividing by shape[-1] alone is only correct when rows <= cols; using the
  # larger of the two trailing dims also covers tall (rows > cols) shapes.
  expected_mean = gain / max(rows, cols)
  np.testing.assert_almost_equal(value.mean(), expected_mean, decimal=4)
  np.testing.assert_almost_equal(value.max(), gain, decimal=4)
def test_identity_identity(self):
  """Right-multiplying by an Identity-initialized matrix keeps leading columns.

  With a (42, 20) identity, ``x @ eye`` should reproduce the first 20 columns
  of a random (62, 42) matrix ``x``.
  """
  in_dim, out_dim = 42, 20
  eye = initializers.Identity()((in_dim, out_dim), jnp.float32)
  self.assertEqual(eye.shape, (in_dim, out_dim))
  self.assertEqual(eye.dtype, jnp.float32)

  x = jax.random.normal(jax.random.PRNGKey(42), (62, in_dim), jnp.float32)
  # The loose tolerance presumably accommodates reduced-precision matmuls on
  # some backends — TODO confirm.
  np.testing.assert_allclose(x @ eye, x[:, :out_dim], rtol=1e-2)
def test_identity_invalid_shape(self):
  """Identity must reject a shape with fewer than two dimensions."""
  identity_init = initializers.Identity()
  with self.assertRaisesRegex(ValueError, "requires at least a 2D shape."):
    identity_init((20,), jnp.float32)