Ejemplo n.º 1
0
 def testVarianceScaling(self, map_in, map_out, fan, distr):
     shape = (80, 50, 7)
     fan_in, fan_out = jax._src.nn.initializers._compute_fans(
         NamedShape(*shape), 0, 1)
     key = jax.random.PRNGKey(0)
     base_scaling = partial(jax.nn.initializers.variance_scaling, 100, fan,
                            distr)
     ref_sampler = lambda: base_scaling(in_axis=0, out_axis=1)(key, shape)
     if map_in and map_out:
         out_axes = ['i', 'o', ...]
         named_shape = NamedShape(shape[2], i=shape[0], o=shape[1])
         xmap_sampler = lambda: base_scaling(in_axis='i', out_axis='o')(
             key, named_shape)
     elif map_in:
         out_axes = ['i', ...]
         named_shape = NamedShape(shape[1], shape[2], i=shape[0])
         xmap_sampler = lambda: base_scaling(in_axis='i', out_axis=0)(
             key, named_shape)
     elif map_out:
         out_axes = [None, 'o', ...]
         named_shape = NamedShape(shape[0], shape[2], o=shape[1])
         xmap_sampler = lambda: base_scaling(in_axis=0, out_axis='o')(
             key, named_shape)
     mapped_sampler = xmap(xmap_sampler,
                           in_axes=(),
                           out_axes=out_axes,
                           axis_sizes={
                               'i': shape[0],
                               'o': shape[1]
                           })
     self.assertAllClose(jnp.var(mapped_sampler()),
                         jnp.var(ref_sampler()),
                         atol=1e-4,
                         rtol=2e-2)
Ejemplo n.º 2
0
 def testSamplerSharding(self, distr_sample):
   """Check how a NamedShape axis affects samples drawn under xmap.

   Every draw uses PRNGKey(0) inside an ``xmap`` that has no inputs and
   places named axis 'i' at output position 1. This test checks three things:

   * a purely positional shape yields identical values along 'i';
   * a shape carrying 'i' yields values that differ along 'i';
   * a NamedShape whose 'i' size disagrees with ``axis_sizes`` is rejected
     with a ValueError.
   """
   def draw(shape, map_size):
     # Sampler traced by xmap; the shape argument is closed over.
     def sampler():
       return distr_sample(jax.random.PRNGKey(0), shape=shape)

     mapped = xmap(sampler,
                   in_axes=(),
                   out_axes=[None, 'i', ...],
                   axis_sizes={'i': map_size})
     return mapped()

   # No 'i' in the shape: every column along 'i' equals the first one.
   replicated = draw((3,), 4)
   first_column = replicated[:, [0]]
   self.assertTrue((first_column == replicated).all())

   # 'i' in the shape: for no positional element are all remaining
   # 'i'-columns equal to the first column.
   sharded = draw(NamedShape(3, i=4), 4)
   rows_all_equal = (sharded[:, [0]] == sharded[:, 1:]).all(1)
   self.assertFalse(rows_all_equal.any())

   # The NamedShape says i=4 but xmap is told the axis has size 5.
   expected_message = (
       "The shape of axis i was specified as 4, but it really is 5")
   with self.assertRaisesRegex(ValueError, expected_message):
     draw(NamedShape(3, i=4), 5)
Ejemplo n.º 3
0
 def sample(axis_resources):
     """Draw one sample from ``distr_sample`` under xmap with axes 'i' and 'j'.

     The sampler draws a NamedShape(3, i=4, j=6) sample from PRNGKey(0).
     ``xmap`` is given no inputs and places 'i' and 'j' at output positions
     0 and 1, followed by the positional dimension. The result is presumably
     of shape (4, 6, 3); verify against the xmap version in use.

     ``distr_sample`` is not defined here and is resolved from an enclosing
     scope that is not visible in this excerpt.

     Args:
       axis_resources: forwarded unchanged as ``xmap``'s ``axis_resources``.
     """
     def draw():
         # The NamedShape is built inside the traced function, as before.
         return distr_sample(jax.random.PRNGKey(0),
                             shape=NamedShape(3, i=4, j=6))

     sizes = {'i': 4, 'j': 6}
     mapped = xmap(draw,
                   in_axes=(),
                   out_axes=['i', 'j', ...],
                   axis_sizes=sizes,
                   axis_resources=axis_resources)
     return mapped()