Exemplo n.º 1
0
 def wrapper(*args):
     """Run `forward` vmapped over `in_axes` and average the losses over axis 0.

     `forward` and `in_axes` are closed over from the enclosing scope (not
     visible here). The 'params' collection is broadcast (axis None) and its
     RNG stream is not split, so every mapped element uses the same weights.

     Returns:
       A pair ``(mean of the mapped losses over axis 0, aux)``.
     """
     batched_forward = nn.vmap(
         forward,
         in_axes=in_axes,
         split_rngs={'params': False},
         variable_axes={'params': None},
     )
     per_example_losses, aux_out = batched_forward(*args)
     mean_loss = per_example_losses.mean(axis=0)
     return mean_loss, aux_out
Exemplo n.º 2
0
 def vmap(cls):
     """Wrap `cls` with `nn.vmap`, mapping axis 0 of its one positional input.

     The 'params' and 'batch_stats' collections are broadcast (axis None)
     rather than mapped, the 'params' RNG stream is not split, and the
     mapped axis is given the name 'batch' so named-axis collectives inside
     the module can refer to it.
     """
     shared_collections = {'params': None, 'batch_stats': None}
     return nn.vmap(
         cls,
         in_axes=(0,),
         variable_axes=shared_collections,
         split_rngs={'params': False},
         axis_name='batch',
     )
Exemplo n.º 3
0
  def test_module_transform_with_setup(self):
    """A param created in `setup` gains a leading axis under class-level vmap."""
    class Foo(nn.Module):

      def setup(self):
        # Scalar parameter initialized to ones.
        self.test = self.param('test', nn.initializers.ones, ())

      def __call__(self, x):
        return x * self.test

    # Map params on axis 0 and split the 'params' RNG, so each of the 4
    # mapped elements gets its own copy of the scalar parameter.
    batched_foo_cls = nn.vmap(
        Foo,
        in_axes=0,
        out_axes=0,
        variable_axes={'params': 0},
        split_rngs={'params': True},
    )
    module = batched_foo_cls()
    rng = random.PRNGKey(0)
    inputs = jnp.ones((4,))
    variables = module.init(rng, inputs)
    test_param = variables['params']['test']
    self.assertEqual(test_param.shape, (4,))
Exemplo n.º 4
0
def concise_vmap(module, in_axes, out_axes, axis_size=None, **var_specs):
    """Shorthand for `vmap` taking per-collection ``(axis, split)`` pairs.

    Each keyword in ``var_specs`` names a variable collection. Its value is a
    sequence whose first item becomes that collection's entry in
    ``variable_axes`` and whose second item becomes its entry in
    ``split_rngs`` (e.g. ``params=(0, True)``). Keyword values that are not
    ``Sequence`` instances are ignored.

    Args:
      module: the target to vectorize; forwarded to ``vmap``.
      in_axes: forwarded unchanged to ``vmap``.
      out_axes: forwarded unchanged to ``vmap``.
      axis_size: forwarded unchanged to ``vmap``.
      **var_specs: collection name -> ``(variable_axis, split_rng)``.

    Returns:
      Whatever ``vmap`` returns for these arguments.

    Raises:
      IndexError: if a sequence spec has fewer than two items.
    """
    variable_axes = {}
    splits = {}
    for collection, spec in var_specs.items():
        if not isinstance(spec, Sequence):
            # Only sequence-valued specs describe an (axis, split) pair.
            continue
        variable_axes[collection] = spec[0]
        splits[collection] = spec[1]
    return vmap(
        module,
        in_axes=in_axes,
        out_axes=out_axes,
        variable_axes=variable_axes,
        split_rngs=splits,
        axis_size=axis_size,
    )
Exemplo n.º 5
0
 def test_partially_applied_module_constructor_transform(self):
     """`nn.vmap` accepts a `functools.partial` module constructor.

     `nn.Dense` is partially applied with ``use_bias=False``, vmapped with
     params mapped on axis 0, then instantiated with 4 features. For a
     (3, 4, 4) input mapped over axis 0, the only parameter (the kernel)
     should come out stacked with shape (3, 4, 4).
     """
     k = random.PRNGKey(0)
     x = jnp.ones((3, 4, 4))
     dense = partial(nn.Dense, use_bias=False)
     vmap_dense = nn.vmap(dense,
                          variable_axes={'params': 0},
                          split_rngs={'params': True})(4)
     init_vars = vmap_dense.init(k, x)
     # `jax.tree_map` is deprecated (and removed in newer JAX releases);
     # `jax.tree_util.tree_map` is the stable spelling.
     init_vars_shapes = jax.tree_util.tree_map(jnp.shape, init_vars)
     ref_var_shapes = freeze({
         'params': {
             'kernel': (3, 4, 4),
         },
     })
     self.assertTrue(tree_equals(init_vars_shapes, ref_var_shapes))
Exemplo n.º 6
0
def vmap_module(module, in_axes=0, out_axes=0, num_batch_dims=1):
    """Vectorize a module over one or more batch dimensions.

    ``nn.vmap`` is applied ``num_batch_dims`` times, each time with the same
    ``in_axes``/``out_axes``. The 'params' collection is broadcast across
    every mapped axis (``variable_axes={'params': None}``) and its RNG stream
    is not split, so all batch elements share one set of weights. If
    ``num_batch_dims`` is 0 (or negative), the module is returned unwrapped.

    Args:
      module: the module to vectorize.
      in_axes: the `in_axes` argument passed to every vmap. See `jax.vmap`.
      out_axes: the `out_axes` argument passed to every vmap. See `jax.vmap`.
      num_batch_dims: the number of batch dimensions, i.e. how many times the
        module is wrapped in vmap.

    Returns:
      The vectorized module.
    """
    vmap_options = dict(
        variable_axes={'params': None},
        split_rngs={'params': False},
        in_axes=in_axes,
        out_axes=out_axes,
    )
    vectorized = module
    for _depth in range(num_batch_dims):
        vectorized = nn.vmap(vectorized, **vmap_options)
    return vectorized
Exemplo n.º 7
0
 def vmap(fn):
     """Wrap `fn` with `nn.vmap`, mapping axis 0 of its one positional input.

     The 'params' collection is broadcast (axis None) and its RNG stream is
     not split, so every mapped element uses the same weights.
     """
     broadcast_params = {'params': None}
     no_param_split = {'params': False}
     return nn.vmap(
         fn,
         in_axes=(0,),
         variable_axes=broadcast_params,
         split_rngs=no_param_split,
     )
Exemplo n.º 8
0
  def test_toplevel_submodule_adoption_transform(self):
    """Submodules built in a parent can be passed into a vmapped module.

    `D` constructs two `A`s and a `B` in its own compact body and hands them
    to `C`, whose `__call__` is wrapped with `nn.vmap` (method transform).
    The test checks two things against `D`'s output. First, applying `C`
    directly reproduces it. Second, a class-level `nn.vmap` of the otherwise
    identical `Csimple` also reproduces it. Both use params re-nested from
    `D`'s init.
    """
    class A(nn.Module):
      @nn.compact
      def __call__(self, x):
        return nn.Dense(3)(x)
    class B(nn.Module):
      A: nn.Module
      @nn.compact
      def __call__(self, x):
        return self.A(x)
    class C(nn.Module):
      A: nn.Module
      B: nn.Module
      # Method-level transform: params mapped on axis 0, 'params' RNG split.
      @partial(
          nn.vmap,
          variable_axes={'params': 0},
          split_rngs={'params': True})
      @nn.compact
      def __call__(self, x):
        return self.B(x) + self.A(x)
    # Same computation as C, but without the method-level vmap, so it can be
    # wrapped with a class-level nn.vmap below.
    class Csimple(nn.Module):
      A: nn.Module
      B: nn.Module
      @nn.compact
      def __call__(self, x):
        return self.B(x) + self.A(x)
    class D(nn.Module):
      @nn.compact
      def __call__(self, x):
        a1 = A()
        a2 = A()
        b = B(a1)
        c = C(a2, b)
        return c(x)

    key = random.PRNGKey(0)
    x = jnp.ones((10, 10))
    p1 = D().init(key, x)
    y1 = D().apply(p1, x)

    a1 = A()
    a2 = A()
    b = B(a1)
    # Re-nest D's top-level 'A_0'/'A_1' params into the {'A', 'B': {'A'}}
    # layout that C / Csimple use when applied on their own.
    p2 = freeze({'params': {
        'A': p1['params']['A_0'],
        'B': {
            'A': p1['params']['A_1'],
        }
    }})

    # Test method wrapper transform.
    y2 = C(a2, b).apply(p2, x)
    np.testing.assert_allclose(y1, y2, atol=1e-7)
    # Test class transform.
    Ctrafo = nn.vmap(Csimple,
                     variable_axes={'params': 0},
                     split_rngs={'params': True})

    y3 = Ctrafo(a2, b).apply(p2, x)
    np.testing.assert_allclose(y1, y3, atol=1e-7)
Exemplo n.º 9
0
 def vmap(cls):
     """Wrap `cls` with `nn.vmap`, mapping axis 0 of its one positional input.

     The 'param' collection is broadcast (axis None) on both the input and
     output side, and its RNG stream is not split.

     NOTE(review): `variable_in_axes`/`variable_out_axes` and the collection
     name 'param' look like an older Flax API; current `flax.linen.vmap`
     takes `variable_axes` and uses 'params'. Confirm the pinned Flax
     version before migrating.
     """
     param_in_axes = {'param': None}
     param_out_axes = {'param': None}
     return nn.vmap(
         cls,
         in_axes=(0,),
         variable_in_axes=param_in_axes,
         variable_out_axes=param_out_axes,
         split_rngs={'param': False},
     )