def wrapper(*args):
    """Apply the vmapped ``forward`` to ``args`` and average the losses.

    ``forward`` and ``in_axes`` are not defined here; they are resolved
    from the enclosing scope. The ``'params'`` collection is lifted with
    axis ``None`` and its RNG stream is not split.

    Args:
        *args: positional arguments forwarded unchanged to the vmapped
            ``forward``.

    Returns:
        A ``(loss, aux)`` pair: the losses reduced with ``mean(axis=0)``
        and the auxiliary output exactly as returned by the vmapped call.
    """
    batched_forward = nn.vmap(
        forward,
        in_axes=in_axes,
        variable_axes={'params': None},
        split_rngs={'params': False},
    )
    batched_losses, aux = batched_forward(*args)
    # Collapse the leading (mapped) axis of the losses only; aux is untouched.
    mean_loss = batched_losses.mean(axis=0)
    return mean_loss, aux
def vmap(cls):
    """Wrap ``cls`` with ``nn.vmap`` over axis 0 of its single argument.

    Both the ``'params'`` and ``'batch_stats'`` collections are lifted with
    axis ``None``, the ``'params'`` RNG is not split, and the mapped axis is
    named ``'batch'``.

    Args:
        cls: the target handed to ``nn.vmap``.

    Returns:
        Whatever ``nn.vmap`` returns for ``cls`` with these settings.
    """
    unmapped_collections = {
        'params': None,
        'batch_stats': None,
    }
    return nn.vmap(
        cls,
        in_axes=(0,),
        variable_axes=unmapped_collections,
        split_rngs={'params': False},
        axis_name='batch',
    )
def test_module_transform_with_setup(self):
    """A vmapped module that declares its parameter in ``setup`` initializes.

    The scalar parameter ``test`` is lifted with ``variable_axes={'params': 0}``
    and a split ``'params'`` RNG, so after ``init`` on an input of shape
    ``(4,)`` the stored parameter is expected to have shape ``(4,)``.
    """

    class Foo(nn.Module):

        def setup(self):
            # Scalar parameter (shape ``()``) initialized to one.
            self.test = self.param('test', nn.initializers.ones, ())

        def __call__(self, x):
            return x * self.test

    FooVmap = nn.vmap(
        Foo,
        in_axes=0,
        out_axes=0,
        variable_axes={'params': 0},
        split_rngs={'params': True},
    )

    rng = random.PRNGKey(0)
    inputs = jnp.ones((4,))
    variables = FooVmap().init(rng, inputs)

    param_shape = variables['params']['test'].shape
    self.assertEqual(param_shape, (4,))
def concise_vmap(module, in_axes, out_axes, axis_size=None, **var_specs):
    """Call ``vmap`` with per-collection specs given as keyword arguments.

    Each keyword in ``var_specs`` names a collection. When its value is a
    ``Sequence``, element 0 becomes that collection's ``variable_axes`` entry
    and element 1 becomes its ``split_rngs`` entry. Values that are not a
    ``Sequence`` are ignored.

    Args:
        module: the target forwarded to ``vmap``.
        in_axes: forwarded to ``vmap`` as ``in_axes``.
        out_axes: forwarded to ``vmap`` as ``out_axes``.
        axis_size: forwarded to ``vmap`` as ``axis_size``.
        **var_specs: ``name=(axis, split_rng)`` pairs as described above.

    Returns:
        The result of the ``vmap`` call.
    """
    variable_axes = {}
    splits = {}
    for collection, spec in var_specs.items():
        if not isinstance(spec, Sequence):
            continue
        variable_axes[collection] = spec[0]
        splits[collection] = spec[1]

    return vmap(
        module,
        in_axes=in_axes,
        out_axes=out_axes,
        variable_axes=variable_axes,
        split_rngs=splits,
        axis_size=axis_size,
    )
def test_partially_applied_module_constructor_transform(self):
    """``nn.vmap`` accepts a ``functools.partial`` of a module constructor.

    A bias-free ``nn.Dense`` constructor is partially applied, lifted with
    ``variable_axes={'params': 0}`` and a split ``'params'`` RNG, and then
    instantiated with 4 features. Initializing on an input of shape
    ``(3, 4, 4)`` is expected to produce a single ``kernel`` parameter of
    shape ``(3, 4, 4)``.
    """
    k = random.PRNGKey(0)
    x = jnp.ones((3, 4, 4))

    dense = partial(nn.Dense, use_bias=False)
    vmap_dense = nn.vmap(
        dense,
        variable_axes={'params': 0},
        split_rngs={'params': True},
    )(4)

    init_vars = vmap_dense.init(k, x)
    # Use jax.tree_util.tree_map: the top-level ``jax.tree_map`` alias was
    # deprecated and later removed from JAX.
    init_vars_shapes = jax.tree_util.tree_map(jnp.shape, init_vars)
    ref_var_shapes = freeze({
        'params': {
            'kernel': (3, 4, 4),
        },
    })
    self.assertTrue(tree_equals(init_vars_shapes, ref_var_shapes))
def vmap_module(module, in_axes=0, out_axes=0, num_batch_dims=1):
    """Wrap a module in ``nn.vmap`` once per batch dimension.

    Every wrapping lifts the ``'params'`` collection with axis ``None`` and
    does not split the ``'params'`` RNG.

    Args:
        module: the module to vectorize.
        in_axes: the ``in_axes`` argument passed to each ``nn.vmap`` call.
            See ``jax.vmap``.
        out_axes: the ``out_axes`` argument passed to each ``nn.vmap`` call.
            See ``jax.vmap``.
        num_batch_dims: how many times ``nn.vmap`` is applied, i.e. the
            number of batch dimensions.

    Returns:
        The module after ``num_batch_dims`` nested applications of
        ``nn.vmap``.
    """

    def lift_once(target):
        # Fresh dict literals on every call, matching one nn.vmap invocation
        # per batch dimension.
        return nn.vmap(
            target,
            variable_axes={'params': None},
            split_rngs={'params': False},
            in_axes=in_axes,
            out_axes=out_axes,
        )

    vectorized = module
    for _ in range(num_batch_dims):
        vectorized = lift_once(vectorized)
    return vectorized
def vmap(fn):
    """Wrap ``fn`` with ``nn.vmap`` over axis 0 of its single argument.

    The ``'params'`` collection is lifted with axis ``None`` and its RNG is
    not split.

    Args:
        fn: the target handed to ``nn.vmap``.

    Returns:
        Whatever ``nn.vmap`` returns for ``fn`` with these settings.
    """
    lift_options = {
        'in_axes': (0,),
        'variable_axes': {'params': None},
        'split_rngs': {'params': False},
    }
    return nn.vmap(fn, **lift_options)
def test_toplevel_submodule_adoption_transform(self):
    """Modules passed in as attributes work under a vmapped method and class.

    ``D`` builds ``A``/``B`` instances inside its own ``__call__`` and hands
    them to ``C``, whose ``__call__`` is wrapped in ``nn.vmap``. The same
    computation is then repeated with ``A``/``B`` instances created at the
    top level of the test, using parameters from ``D`` re-keyed under the
    attribute names ``'A'`` and ``'B'``. Both the method-level transform
    (``C``) and the class-level transform (``nn.vmap(Csimple, ...)``) must
    reproduce the output of ``D``.
    """

    class A(nn.Module):

        @nn.compact
        def __call__(self, x):
            return nn.Dense(3)(x)

    class B(nn.Module):
        A: nn.Module

        @nn.compact
        def __call__(self, x):
            return self.A(x)

    class C(nn.Module):
        A: nn.Module
        B: nn.Module

        @partial(
            nn.vmap,
            variable_axes={'params': 0},
            split_rngs={'params': True})
        @nn.compact
        def __call__(self, x):
            return self.B(x) + self.A(x)

    class Csimple(nn.Module):
        # Same as C but without the method-level transform; it is lifted as
        # a whole class further down.
        A: nn.Module
        B: nn.Module

        @nn.compact
        def __call__(self, x):
            return self.B(x) + self.A(x)

    class D(nn.Module):

        @nn.compact
        def __call__(self, x):
            a1 = A()
            a2 = A()
            b = B(a1)
            c = C(a2, b)
            return c(x)

    key = random.PRNGKey(0)
    x = jnp.ones((10, 10))
    p1 = D().init(key, x)
    y1 = D().apply(p1, x)

    a1 = A()
    a2 = A()
    b = B(a1)
    # Re-key D's parameters under the attribute names that C / Csimple use
    # for the adopted submodules.
    p2 = freeze({'params': {
        'A': p1['params']['A_0'],
        'B': {
            'A': p1['params']['A_1'],
        }
    }})
    # Use jax.tree_util.tree_map: the top-level ``jax.tree_map`` alias was
    # deprecated and later removed from JAX.
    print(jax.tree_util.tree_map(jnp.shape, p1))

    # Test method wrapper transform.
    y2 = C(a2, b).apply(p2, x)
    np.testing.assert_allclose(y1, y2, atol=1e-7)

    # Test class transform.
    Ctrafo = nn.vmap(
        Csimple,
        variable_axes={'params': 0},
        split_rngs={'params': True})
    y3 = Ctrafo(a2, b).apply(p2, x)
    np.testing.assert_allclose(y1, y3, atol=1e-7)
def vmap(cls):
    """Wrap ``cls`` with ``nn.vmap`` over axis 0 of its single argument.

    The ``'params'`` collection is lifted with axis ``None`` and its RNG is
    not split.

    Args:
        cls: the target handed to ``nn.vmap``.

    Returns:
        Whatever ``nn.vmap`` returns for ``cls`` with these settings.
    """
    # ``nn.vmap`` takes a single ``variable_axes`` mapping (not separate
    # ``variable_in_axes`` / ``variable_out_axes``), and the parameter
    # collection is named 'params', not 'param'.
    return nn.vmap(
        cls,
        in_axes=(0,),
        variable_axes={'params': None},
        split_rngs={'params': False},
    )