def testVarianceScaling(self, map_in, map_out, fan, distr):
  """Variance of an xmapped variance_scaling init matches the positional one.

  Samples the same initializer twice — once over a plain positional shape,
  once over a NamedShape under xmap with the in/out axes optionally named —
  and checks the empirical variances agree.

  Args:
    map_in: whether the fan-in axis is a named (mapped) axis.
    map_out: whether the fan-out axis is a named (mapped) axis.
    fan: fan mode forwarded to variance_scaling (e.g. 'fan_in').
    distr: distribution name forwarded to variance_scaling.
  """
  shape = (80, 50, 7)
  key = jax.random.PRNGKey(0)
  # A large scale (100) amplifies any variance mismatch between samplers.
  base_scaling = partial(jax.nn.initializers.variance_scaling, 100, fan, distr)
  ref_sampler = lambda: base_scaling(in_axis=0, out_axis=1)(key, shape)
  if map_in and map_out:
    out_axes = ['i', 'o', ...]
    named_shape = NamedShape(shape[2], i=shape[0], o=shape[1])
    xmap_sampler = lambda: base_scaling(in_axis='i', out_axis='o')(
        key, named_shape)
  elif map_in:
    out_axes = ['i', ...]
    named_shape = NamedShape(shape[1], shape[2], i=shape[0])
    xmap_sampler = lambda: base_scaling(in_axis='i', out_axis=0)(
        key, named_shape)
  elif map_out:
    out_axes = [None, 'o', ...]
    named_shape = NamedShape(shape[0], shape[2], o=shape[1])
    xmap_sampler = lambda: base_scaling(in_axis=0, out_axis='o')(
        key, named_shape)
  else:
    # Previously this case fell through and raised an opaque NameError on
    # out_axes; fail loudly with a clear message instead.
    raise ValueError("At least one of map_in/map_out must be True")
  mapped_sampler = xmap(xmap_sampler, in_axes=(), out_axes=out_axes,
                        axis_sizes={'i': shape[0], 'o': shape[1]})
  self.assertAllClose(jnp.var(mapped_sampler()), jnp.var(ref_sampler()),
                      atol=1e-4, rtol=2e-2)
def testSamplerSharding(self, distr_sample):
  """Sampling under xmap replicates over positional shapes, shards over named
  ones, and validates declared axis sizes against the actual shape."""

  def draw(shape, map_size):
    sampler = lambda: distr_sample(jax.random.PRNGKey(0), shape=shape)
    return xmap(sampler, in_axes=(), out_axes=[None, 'i', ...],
                axis_sizes={'i': map_size})()

  # Positional shape: the sample is broadcast, so every mapped slice
  # along 'i' is identical.
  replicated = draw((3,), 4)
  self.assertTrue((replicated[:, [0]] == replicated).all())

  # Named shape: the sample is sharded, so slices along 'i' all differ.
  sharded = draw(NamedShape(3, i=4), 4)
  self.assertFalse((sharded[:, [0]] == sharded[:, 1:]).all(1).any())

  # Declaring the wrong size for a named axis is rejected.
  error = "The shape of axis i was specified as 4, but it really is 5"
  with self.assertRaisesRegex(ValueError, error):
    draw(NamedShape(3, i=4), 5)
def sample(axis_resources):
  # Draw one deterministic sample over a fully named (4, 6)-mapped shape,
  # mapping the named axes onto the given mesh resources.
  named_shape = NamedShape(3, i=4, j=6)
  sampler = lambda: distr_sample(jax.random.PRNGKey(0), shape=named_shape)
  mapped = xmap(sampler, in_axes=(), out_axes=['i', 'j', ...],
                axis_sizes={'i': 4, 'j': 6},
                axis_resources=axis_resources)
  return mapped()