def test_init_outside_setup_without_compact(self):
  """Creating a param in a non-compact ``__call__`` must raise.

  The local ``DummyModule`` has neither ``setup`` nor ``@compact``, so the
  ``self.param`` call inside ``__call__`` is expected to fail with a
  ``ValueError`` whose message points the user at ``setup``.
  """
  class DummyModule(nn.Module):

    def __call__(self, x):
      bias = self.param('bias', initializers.ones, x.shape)
      return x + bias

  inputs = jnp.array([1.])
  root = Scope({}, {'params': jax.random.PRNGKey(0)})
  with self.assertRaisesRegex(ValueError, 'must be initialized.*setup'):
    DummyModule(parent=root)(inputs)
def test_init_module(self):
  """Initializing ``DummyModule`` and re-applying it gives the same output.

  The first call runs on a scope whose ``'params'`` collection is mutable.
  The test expects this call to store ``{'bias': [1.]}``. A second call on
  the rewound scope must reproduce the first result, ``[2.]``.
  """
  inputs = jnp.array([1.])
  root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])

  first = DummyModule(parent=root)(inputs)
  created = root.variables()['params']
  second = DummyModule(parent=root.rewound())(inputs)

  onp.testing.assert_allclose(first, second)
  onp.testing.assert_allclose(first, jnp.array([2.]))
  self.assertEqual(created, {'bias': jnp.array([1.])})
def test_attr_submodule_name_collision(self):
  """A submodule assigned over a same-named dataclass field must be rejected.

  ``Dummy`` declares a ``bias`` field and then, in ``setup``, assigns
  ``DummyModule(name='bias')`` to ``self.bias``. The test expects a
  ``ValueError`` reporting that ``bias`` already exists.

  NOTE(review): another ``test_attr_submodule_name_collision`` is defined
  later in this file. If both live in the same class, that later definition
  shadows this one and this body never runs. Consider deleting one of them.
  """
  class Dummy(nn.Module):
    bias: bool

    def setup(self):
      self.bias = DummyModule(name='bias')

    def __call__(self, x):
      return self.bias(x)

  inputs = jnp.array([1.])
  root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
  with self.assertRaisesRegex(ValueError, 'bias exists already'):
    Dummy(inputs.shape, parent=root)(inputs)
def test_setattr_name_var_disagreement(self):
  """A param stored on ``self.bias`` must itself be named ``'bias'``.

  Creating it as ``'notbias'`` is expected to raise a ``ValueError`` saying
  the param name must equal the attribute name.
  """
  class Dummy(nn.Module):
    xshape: Tuple[int]

    def setup(self):
      self.bias = self.param('notbias', initializers.ones, self.xshape)

    def __call__(self, x):
      return x + self.bias

  inputs = jnp.array([1.])
  root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
  with self.assertRaisesRegex(ValueError, 'notbias.*must equal.*bias'):
    Dummy(inputs.shape, parent=root)(inputs)
def test_attr_param_name_collision(self):
  """A param named ``'bias'`` assigned over a ``bias`` field must fail.

  NOTE(review): another ``test_attr_param_name_collision`` is defined later
  in this file. If both live in the same class, that later definition
  shadows this one and this body never runs.
  """
  class Dummy(nn.Module):
    bias: bool

    def setup(self):
      self.bias = self.param('bias', initializers.ones, (3, 3))

    def __call__(self, x):
      return x + self.bias

  inputs = jnp.array([1.])
  root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
  with self.assertRaisesRegex(ValueError, 'Name bias already in use'):
    Dummy(inputs.shape, parent=root)(inputs)
def test_attr_submodule_name_collision(self):
  """A submodule assigned over a same-named dataclass field raises NameInUse.

  ``Dummy`` has a ``bias`` field, and its ``setup`` tries to register a
  submodule also called ``bias``. This must raise
  ``errors.NameInUseError`` with the expected message.
  """
  class Dummy(nn.Module):
    bias: bool

    def setup(self):
      self.bias = DummyModule(name='bias')

    def __call__(self, x):
      return self.bias(x)

  inputs = jnp.array([1.])
  root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
  expected = 'Could not create submodule "bias" in Module Dummy: Name in use'
  with self.assertRaisesRegex(errors.NameInUseError, expected):
    Dummy(inputs.shape, parent=root)(inputs)
def test_setattr_name_submodule_redundant(self):
  """Passing ``name=`` to a submodule assigned in ``setup`` must fail.

  In ``setup`` the submodule name comes from the ``self.<name>`` attribute,
  so also passing ``name='bias'`` is expected to raise a ``ValueError``.

  NOTE(review): another ``test_setattr_name_submodule_redundant`` is defined
  later in this file and expects a different (shorter) message. If both live
  in the same class, that later definition shadows this one.
  """
  class Dummy(nn.Module):
    xshape: Tuple[int]

    def setup(self):
      self.bias = DummyModule(name='bias')

    def __call__(self, x):
      return x + self.bias

  inputs = jnp.array([1.])
  root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
  expected = ('In setup, assign names of Modules '
              'via self.<name> and not using keyword argument name="<name>"')
  with self.assertRaisesRegex(ValueError, expected):
    Dummy(inputs.shape, parent=root)(inputs)
def test_call_var_collision(self):
  """Creating two params named ``'bias'`` in one ``@compact`` call must fail.

  NOTE(review): another ``test_call_var_collision`` is defined later in this
  file. If both live in the same class, that later definition shadows this
  one and this body never runs.
  """
  class Dummy(nn.Module):
    xshape: Tuple[int]

    @compact
    def __call__(self, x):
      bias = self.param('bias', initializers.ones, self.xshape)
      bias = self.param('bias', initializers.ones, self.xshape)
      return x + bias

  inputs = jnp.array([1.])
  root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
  with self.assertRaisesRegex(ValueError, 'bias already in use'):
    Dummy(inputs.shape, parent=root)(inputs)
def test_attr_param_name_collision(self):
  """A param named ``'bias'`` assigned over a ``bias`` field raises NameInUse.

  ``Dummy`` has a ``bias`` field, and its ``setup`` creates a param also
  called ``'bias'``. This must raise ``errors.NameInUseError`` with the
  expected message.
  """
  class Dummy(nn.Module):
    bias: bool

    def setup(self):
      self.bias = self.param('bias', initializers.ones, (3, 3))

    def __call__(self, x):
      return x + self.bias

  inputs = jnp.array([1.])
  root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
  expected = 'Could not create param "bias" in Module Dummy: Name in use'
  with self.assertRaisesRegex(errors.NameInUseError, expected):
    Dummy(inputs.shape, parent=root)(inputs)
def test_submodule_var_collision(self):
  """A param and a submodule may not share the name ``bias``.

  Three variants are checked, each against a new, empty root scope:
    1. param and submodule both assigned to ``self.bias`` in ``setup``;
    2. param in ``setup``, submodule ``name='bias'`` built in ``@compact``;
    3. submodule in ``setup``, param ``'bias'`` created in ``@compact``.
  """
  rngkey = jax.random.PRNGKey(0)
  inputs = jnp.array([1.])

  def fresh_scope():
    # Every variant starts from its own empty, params-mutable root scope.
    return Scope({}, {'params': rngkey}, mutable=['params'])

  # Variant 1: both definitions of ``bias`` live in setup.
  class Dummy(nn.Module):
    xshape: Tuple[int]

    def setup(self):
      self.bias = self.param('bias', initializers.ones, self.xshape)
      self.bias = DummyModule()

    def __call__(self, x):
      return x + self.bias

  root = fresh_scope()
  with self.assertRaisesRegex(errors.ScopeNameInUseError,
                              r'Duplicate use of scope name: "bias"'):
    Dummy(inputs.shape, parent=root)(inputs)

  # Variant 2: param from setup, submodule constructed inside @compact.
  class Dummy(nn.Module):
    xshape: Tuple[int]

    def setup(self):
      self.bias = self.param('bias', initializers.ones, self.xshape)

    @compact
    def __call__(self, x):
      # Constructing the named submodule is what is expected to raise.
      _ = DummyModule(name='bias')
      return x + self.bias

  root = fresh_scope()
  with self.assertRaisesRegex(ValueError, 'name bias exists already'):
    Dummy(inputs.shape, parent=root)(inputs)

  # Variant 3: submodule from setup, param created inside @compact.
  class Dummy(nn.Module):
    xshape: Tuple[int]

    def setup(self):
      self.bias = DummyModule()

    @compact
    def __call__(self, x):
      _ = self.param('bias', initializers.ones, self.xshape)
      return x + self.bias

  root = fresh_scope()
  with self.assertRaisesRegex(ValueError, 'bias already'):
    Dummy(inputs.shape, parent=root)(inputs)
def test_call_var_collision(self):
  """Two params named ``'bias'`` in one ``@compact`` call raise NameInUse."""
  class Dummy(nn.Module):
    xshape: Tuple[int]

    @compact
    def __call__(self, x):
      bias = self.param('bias', initializers.ones, self.xshape)
      bias = self.param('bias', initializers.ones, self.xshape)
      return x + bias

  inputs = jnp.array([1.])
  root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
  expected = 'Could not create param "bias" in Module Dummy: Name in use'
  with self.assertRaisesRegex(errors.NameInUseError, expected):
    Dummy(inputs.shape, parent=root)(inputs)
def test_setattr_name_var_disagreement_allowed_in_lists(self):
  """Params held in a list attribute may use names unrelated to it.

  ``self.biases`` stores params named ``bias_0`` .. ``bias_3``. Unlike a
  plain attribute, no name-mismatch error is expected, and the first bias
  (all ones) is added to the input.
  """
  class Dummy(nn.Module):
    xshape: Tuple[int]

    def setup(self):
      self.biases = [
          self.param(f'bias_{i}', initializers.ones, self.xshape)
          for i in range(4)
      ]

    def __call__(self, x):
      return x + self.biases[0]

  inputs = jnp.array([1.])
  root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
  out = Dummy(inputs.shape, parent=root)(inputs)
  self.assertEqual(out, jnp.array([2.]))
def test_setattr_name_submodule_redundant(self):
  """Passing ``name=`` to a submodule assigned in ``setup`` must fail.

  In ``setup`` the submodule name comes from the ``self.<name>`` attribute,
  so also passing ``name='bias'`` is expected to raise a ``ValueError``.

  NOTE(review): an earlier ``test_setattr_name_submodule_redundant`` in this
  file expects the longer message "In setup, assign names of Modules via
  self.<name> ...". The regex below does not match that text. If both tests
  live in the same class, this definition shadows the earlier one. Confirm
  which message the current library emits and keep a single test.
  """
  rngkey = jax.random.PRNGKey(0)

  class Dummy(nn.Module):
    xshape: Tuple[int]

    def setup(self):
      self.bias = DummyModule(name='bias')

    def __call__(self, x):
      return x + self.bias

  x = jnp.array([1.])
  # Same rng name ('params') and mutable collection as the other init tests
  # in this file.
  scope = Scope({}, {'params': rngkey}, mutable=['params'])
  with self.assertRaisesRegex(ValueError, 'assign names via self'):
    Dummy(x.shape, parent=scope)(x)
def test_param_in_setup(self):
  """A param created in ``setup`` is stored and reused on a rewound scope.

  The first call initializes ``bias`` to ones, so the output is ``[2.]``.
  Re-applying the module on the rewound scope must give the same output.
  """
  class DummyModule(nn.Module):
    xshape: Tuple[int]

    def setup(self):
      self.bias = self.param('bias', initializers.ones, self.xshape)

    def __call__(self, x):
      return x + self.bias

  inputs = jnp.array([1.])
  root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])

  first = DummyModule(inputs.shape, parent=root)(inputs)
  created = root.variables()['params']
  second = DummyModule(inputs.shape, parent=root.rewound())(inputs)

  np.testing.assert_allclose(first, second)
  np.testing.assert_allclose(first, jnp.array([2.]))
  self.assertEqual(created, {'bias': jnp.array([1.])})
def test_only_one_compact_method_subclass(self):
  """``@nn.compact`` may decorate the same method in a base and a subclass.

  The @compact annotation must be accepted on both the base class and the
  subclass, as long as it's on the same method. Calling the subclass
  instance must therefore not raise.
  """
  class Dummy(nn.Module):

    @nn.compact
    def __call__(self):
      pass

  class SubDummy(Dummy):

    @nn.compact
    def __call__(self):
      super().__call__()

  sub = SubDummy(parent=Scope(variables={}))
  sub()
def test_setup_var_collision(self):
  """Creating two params named ``'bias'`` in ``setup`` must fail.

  The first ``self.param('bias', ...)`` has to be initialized for the second
  one to hit the name collision. The scope therefore needs a ``'params'`` rng
  and a mutable ``'params'`` collection, matching the compact-method variant
  of this test elsewhere in the file.
  """
  rngkey = jax.random.PRNGKey(0)

  class Dummy(nn.Module):
    xshape: Tuple[int]

    def setup(self):
      self.bias = self.param('bias', initializers.ones, self.xshape)
      self.bias = self.param('bias', initializers.ones, self.xshape)

    def __call__(self, x):
      return x + self.bias

  x = jnp.array([1.])
  # Previously Scope({}, {'param': rngkey}) with nothing mutable. That setup
  # presumably cannot create the first 'bias', so the duplicate-name check
  # was never exercised.
  scope = Scope({}, {'params': rngkey}, mutable=['params'])
  with self.assertRaisesRegex(ValueError, 'bias already in use'):
    Dummy(x.shape, parent=scope)(x)
def test_setattr_name_var_disagreement_allowed_in_dicts(self):
  """Params held in a dict attribute may use names unrelated to it.

  ``self.biases`` maps ``'0'`` .. ``'3'`` to params named ``bias_0`` ..
  ``bias_3``. No name-mismatch error is expected, and the ``'0'`` bias (all
  ones) is added to the input.
  """
  class Dummy(nn.Module):
    xshape: Tuple[int]

    def setup(self):
      self.biases = {
          # NOTE: keys still must be strings. This should make a possible
          # future transition to automatically derived parameter names for
          # dict-assigned params easier (as we already have for submodules).
          # Discussion: https://github.com/google/flax/issues/705#issuecomment-738761853
          str(i): self.param(f'bias_{i}', initializers.ones, self.xshape)
          for i in range(4)
      }

    def __call__(self, x):
      return x + self.biases['0']

  inputs = jnp.array([1.])
  root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
  out = Dummy(inputs.shape, parent=root)(inputs)
  self.assertEqual(out, jnp.array([2.]))
def test_submodule_var_collision_with_scope(self):
  """Assigning a submodule over a param in ``setup`` raises ScopeNameInUse.

  Both the param and the submodule claim the scope name ``bias``. The check
  uses ``errors.ScopeNameInUseError`` with a regex rather than an exact
  whole-message literal match. Any extra text appended to the error message
  (for example a docs link) would otherwise break the test. This also makes
  it consistent with ``test_submodule_var_collision``, which checks the
  same scenario.
  """
  rngkey = jax.random.PRNGKey(0)

  class Dummy(nn.Module):
    xshape: Tuple[int]

    def setup(self):
      self.bias = self.param('bias', initializers.ones, self.xshape)
      self.bias = DummyModule()

    def __call__(self, x):
      return x + self.bias

  x = jnp.array([1.])
  scope = Scope({}, {'params': rngkey}, mutable=['params'])
  msg = r'Duplicate use of scope name: "bias"'
  with self.assertRaisesRegex(errors.ScopeNameInUseError, msg):
    Dummy(x.shape, parent=scope)(x)
def test_util_fun(self):
  """A non-compact helper may build submodules when called from @compact.

  ``_mydense`` is a plain method (no ``@compact``) that creates a ``Dense``
  layer each time it is called. The two layers are expected to be named
  ``Dense_0`` and ``Dense_1`` directly under ``MLP``.
  """
  rngkey = jax.random.PRNGKey(0)

  class MLP(nn.Module):

    @compact
    def __call__(self, x):
      x = self._mydense(x)
      x = self._mydense(x)
      return x

    def _mydense(self, x):
      return Dense(3)(x)

  x = jnp.ones((10,))
  scope = Scope({}, {'params': rngkey}, mutable=['params'])
  y = MLP(parent=scope)(x)
  params = scope.variables()['params']
  y2 = MLP(parent=scope.rewound())(x)
  np.testing.assert_allclose(y, y2)
  # jax.tree_map is deprecated in newer JAX releases; jax.tree_util.tree_map
  # is the long-standing equivalent and works on old and new versions.
  param_shape = jax.tree_util.tree_map(jnp.shape, params)
  self.assertEqual(param_shape, {
      'Dense_0': {'kernel': (10, 3)},
      'Dense_1': {'kernel': (3, 3)},
  })
def test_setup_dict_assignment(self):
  """Submodules held in a dict attribute are named ``<attr>_<key>``.

  ``lyrs1`` maps ``'a'``/``'b'`` to ``Dense`` layers, whose params are
  expected under ``lyrs1_a`` and ``lyrs1_b``. ``lyrs2`` (a list) is
  assigned but never called. The expected param tree below accordingly
  contains no ``lyrs2`` entries.
  """
  rngkey = jax.random.PRNGKey(0)

  class MLP(nn.Module):

    def setup(self):
      self.lyrs1 = {'a': Dense(3), 'b': Dense(3),}
      # Deliberately left uncalled in __call__.
      self.lyrs2 = [Dense(3), Dense(3)]

    def __call__(self, x):
      y = self.lyrs1['a'](x)
      z = self.lyrs1['b'](y)
      return z

  x = jnp.ones((10,))
  scope = Scope({}, {'params': rngkey}, mutable=['params'])
  y = MLP(parent=scope)(x)
  params = scope.variables()['params']
  y2 = MLP(parent=scope.rewound())(x)
  np.testing.assert_allclose(y, y2)
  # jax.tree_map is deprecated in newer JAX releases; use jax.tree_util.
  param_shape = jax.tree_util.tree_map(jnp.shape, params)
  self.assertEqual(param_shape, {
      'lyrs1_a': {'kernel': (10, 3)},
      'lyrs1_b': {'kernel': (3, 3)},
  })
def test_nested_module_reuse(self):
  """Calling one ``MLP`` instance twice shares its parameters.

  ``Top`` invokes the same ``mlp`` object twice. The expected param tree
  therefore holds a single ``MLP_0`` subtree with its two ``Dense`` layers.
  """
  rngkey = jax.random.PRNGKey(0)

  class MLP(nn.Module):

    @compact
    def __call__(self, x):
      x = self._mydense(x)
      x = self._mydense(x)
      return x

    def _mydense(self, x):
      return Dense(3)(x)

  class Top(nn.Module):

    @compact
    def __call__(self, x):
      mlp = MLP()
      y = mlp(x)
      z = mlp(x)
      return y + z

  x = jnp.ones((10, ))
  # Use the 'params' rng/collection and make it mutable, like the other init
  # tests here. Previously this used {'param': ...} with nothing mutable, so
  # presumably the Dense params could not be created on the first call.
  scope = Scope({}, {'params': rngkey}, mutable=['params'])
  y = Top(parent=scope)(x)
  params = scope.variables()['params']
  y2 = Top(parent=scope.rewound())(x)
  onp.testing.assert_allclose(y, y2)
  # jax.tree_map is deprecated in newer JAX releases; use jax.tree_util.
  param_shape = jax.tree_util.tree_map(jnp.shape, params)
  self.assertEqual(param_shape, {
      'MLP_0': {
          'Dense_0': {'kernel': (10, 3)},
          'Dense_1': {'kernel': (3, 3)},
      }
  })
def test_module_with_scope_is_not_hashable(self):
  """Hashing a module constructed with a Scope parent must raise ValueError."""
  bound_dense = nn.Dense(10, parent=Scope({}))
  expected = 'Can\'t call __hash__ on modules that hold variables.'
  with self.assertRaisesWithLiteralMatch(ValueError, expected):
    hash(bound_dense)
def test_setup_cloning(self):
  """``clone()`` must not raise on a ``setup``-style module with a Scope parent.

  This is a smoke test: nothing about the clone is asserted beyond its
  creation succeeding.
  """
  class MLP(nn.Module):

    def setup(self):
      self.dense = Dense(3)

  MLP(parent=Scope({})).clone()