Esempio n. 1
0
 def test_init_outside_setup_without_compact(self):
   """self.param in a __call__ lacking @compact is expected to raise ValueError."""
   x = jnp.array([1.])

   class DummyModule(nn.Module):
     # __call__ is deliberately NOT decorated with @compact, so declaring a
     # param here (instead of in setup) must be rejected.
     def __call__(self, x):
       return x + self.param('bias', initializers.ones, x.shape)

   root = Scope({}, {'params': jax.random.PRNGKey(0)})
   with self.assertRaisesRegex(ValueError, 'must be initialized.*setup'):
     DummyModule(parent=root)(x)
Esempio n. 2
0
 def test_init_module(self):
   """Initializing DummyModule yields a single 'bias' param and a stable output."""
   inputs = jnp.array([1.])
   root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
   first = DummyModule(parent=root)(inputs)
   init_params = root.variables()['params']
   # Re-apply through a rewound scope; the output must match the first call.
   second = DummyModule(parent=root.rewound())(inputs)
   onp.testing.assert_allclose(first, second)
   onp.testing.assert_allclose(first, jnp.array([2.]))
   self.assertEqual(init_params, {'bias': jnp.array([1.])})
Esempio n. 3
0
 def test_attr_submodule_name_collision(self):
   """Assigning a submodule named 'bias' to the `bias` field in setup raises."""
   inputs = jnp.array([1.])

   class Dummy(nn.Module):
     bias: bool  # dataclass field whose name the submodule below reuses

     def setup(self):
       self.bias = DummyModule(name='bias')

     def __call__(self, x):
       return self.bias(x)

   root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
   with self.assertRaisesRegex(ValueError, 'bias exists already'):
     Dummy(inputs.shape, parent=root)(inputs)
Esempio n. 4
0
 def test_setattr_name_var_disagreement(self):
   """A param whose name differs from its setup attribute name raises ValueError."""
   inputs = jnp.array([1.])

   class Dummy(nn.Module):
     xshape: Tuple[int]

     def setup(self):
       # Attribute name 'bias' intentionally disagrees with param name 'notbias'.
       self.bias = self.param('notbias', initializers.ones, self.xshape)

     def __call__(self, x):
       return x + self.bias

   root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
   with self.assertRaisesRegex(ValueError, 'notbias.*must equal.*bias'):
     Dummy(inputs.shape, parent=root)(inputs)
Esempio n. 5
0
 def test_attr_param_name_collision(self):
   """Assigning a param named 'bias' to the `bias` field in setup raises."""
   inputs = jnp.array([1.])

   class Dummy(nn.Module):
     bias: bool  # dataclass field whose name the param below reuses

     def setup(self):
       self.bias = self.param('bias', initializers.ones, (3, 3))

     def __call__(self, x):
       return x + self.bias

   root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
   with self.assertRaisesRegex(ValueError, 'Name bias already in use'):
     Dummy(inputs.shape, parent=root)(inputs)
Esempio n. 6
0
 def test_attr_submodule_name_collision(self):
   """A submodule named like the `bias` field raises errors.NameInUseError."""
   inputs = jnp.array([1.])

   class Dummy(nn.Module):
     bias: bool  # dataclass field whose name the submodule below reuses

     def setup(self):
       self.bias = DummyModule(name='bias')

     def __call__(self, x):
       return self.bias(x)

   root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
   expected = 'Could not create submodule "bias" in Module Dummy: Name in use'
   with self.assertRaisesRegex(errors.NameInUseError, expected):
     Dummy(inputs.shape, parent=root)(inputs)
Esempio n. 7
0
 def test_setattr_name_submodule_redundant(self):
   """Passing name= to a submodule that is also assigned via self.<attr> raises."""
   inputs = jnp.array([1.])

   class Dummy(nn.Module):
     xshape: Tuple[int]

     def setup(self):
       # The explicit name= is what the expected error message objects to.
       self.bias = DummyModule(name='bias')

     def __call__(self, x):
       return x + self.bias

   root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
   expected = ('In setup, assign names of Modules '
               'via self.<name> and not using keyword argument name="<name>"')
   with self.assertRaisesRegex(ValueError, expected):
     Dummy(inputs.shape, parent=root)(inputs)
Esempio n. 8
0
 def test_call_var_collision(self):
   """Declaring param 'bias' twice inside one compact __call__ raises."""
   inputs = jnp.array([1.])

   class Dummy(nn.Module):
     xshape: Tuple[int]

     @compact
     def __call__(self, x):
       self.param('bias', initializers.ones, self.xshape)
       # Second declaration of the same name is the collision under test.
       b = self.param('bias', initializers.ones, self.xshape)
       return x + b

   root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
   with self.assertRaisesRegex(ValueError, 'bias already in use'):
     Dummy(inputs.shape, parent=root)(inputs)
Esempio n. 9
0
 def test_attr_param_name_collision(self):
   """A param named like the `bias` field raises errors.NameInUseError."""
   inputs = jnp.array([1.])

   class Dummy(nn.Module):
     bias: bool  # dataclass field whose name the param below reuses

     def setup(self):
       self.bias = self.param('bias', initializers.ones, (3, 3))

     def __call__(self, x):
       return x + self.bias

   root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
   expected = 'Could not create param "bias" in Module Dummy: Name in use'
   with self.assertRaisesRegex(errors.NameInUseError, expected):
     Dummy(inputs.shape, parent=root)(inputs)
Esempio n. 10
0
 def test_submodule_var_collision(self):
   """A param and a submodule may not both claim the name 'bias'."""
   inputs = jnp.array([1.])

   def assert_init_fails(module_cls, exc_type, pattern):
     # Each case gets its own fresh, mutable root scope.
     root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
     with self.assertRaisesRegex(exc_type, pattern):
       module_cls(inputs.shape, parent=root)(inputs)

   # Case 1: setup declares param 'bias', then rebinds self.bias to a
   # submodule.
   class Dummy(nn.Module):
     xshape: Tuple[int]

     def setup(self):
       self.bias = self.param('bias', initializers.ones, self.xshape)
       self.bias = DummyModule()

     def __call__(self, x):
       return x + self.bias

   assert_init_fails(Dummy, errors.ScopeNameInUseError,
                     r'Duplicate use of scope name: "bias"')

   # Case 2: setup declares param 'bias'; the compact __call__ then builds a
   # submodule explicitly named 'bias'.
   class Dummy(nn.Module):
     xshape: Tuple[int]

     def setup(self):
       self.bias = self.param('bias', initializers.ones, self.xshape)

     @compact
     def __call__(self, x):
       DummyModule(name='bias')
       return x + self.bias

   assert_init_fails(Dummy, ValueError, 'name bias exists already')

   # Case 3: setup assigns a submodule to self.bias; the compact __call__
   # then declares a param named 'bias'.
   class Dummy(nn.Module):
     xshape: Tuple[int]

     def setup(self):
       self.bias = DummyModule()

     @compact
     def __call__(self, x):
       self.param('bias', initializers.ones, self.xshape)
       return x + self.bias

   assert_init_fails(Dummy, ValueError, 'bias already')
Esempio n. 11
0
 def test_call_var_collision(self):
   """Declaring param 'bias' twice in a compact __call__ raises NameInUseError."""
   inputs = jnp.array([1.])

   class Dummy(nn.Module):
     xshape: Tuple[int]

     @compact
     def __call__(self, x):
       self.param('bias', initializers.ones, self.xshape)
       # Second declaration of the same name is the collision under test.
       b = self.param('bias', initializers.ones, self.xshape)
       return x + b

   root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
   expected = 'Could not create param "bias" in Module Dummy: Name in use'
   with self.assertRaisesRegex(errors.NameInUseError, expected):
     Dummy(inputs.shape, parent=root)(inputs)
Esempio n. 12
0
  def test_setattr_name_var_disagreement_allowed_in_lists(self):
    """Params held in a list attribute may carry names unrelated to it."""
    inputs = jnp.array([1.])

    class Dummy(nn.Module):
      xshape: Tuple[int]

      def setup(self):
        # Param names 'bias_0'..'bias_3' differ from the attribute 'biases'.
        self.biases = [
            self.param(f'bias_{idx}', initializers.ones, self.xshape)
            for idx in range(4)
        ]

      def __call__(self, x):
        first_bias = self.biases[0]
        return x + first_bias

    root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
    out = Dummy(inputs.shape, parent=root)(inputs)
    self.assertEqual(out, jnp.array([2.]))
Esempio n. 13
0
    def test_setattr_name_submodule_redundant(self):
        """Naming a setup submodule via name= as well as self.<attr> raises."""
        inputs = jnp.array([1.])

        class Dummy(nn.Module):
            xshape: Tuple[int]

            def setup(self):
                # The explicit name= is what the expected error objects to.
                self.bias = DummyModule(name='bias')

            def __call__(self, x):
                return x + self.bias

        root = Scope({}, {'param': jax.random.PRNGKey(0)})
        with self.assertRaisesRegex(ValueError, 'assign names via self'):
            Dummy(inputs.shape, parent=root)(inputs)
Esempio n. 14
0
 def test_param_in_setup(self):
   """A param declared in setup is usable from __call__ and stable on re-apply."""
   inputs = jnp.array([1.])

   class DummyModule(nn.Module):
     xshape: Tuple[int]

     def setup(self):
       self.bias = self.param('bias', initializers.ones, self.xshape)

     def __call__(self, x):
       return x + self.bias

   root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
   first = DummyModule(inputs.shape, parent=root)(inputs)
   init_params = root.variables()['params']
   # Re-apply through a rewound scope; the output must match the first call.
   second = DummyModule(inputs.shape, parent=root.rewound())(inputs)
   np.testing.assert_allclose(first, second)
   np.testing.assert_allclose(first, jnp.array([2.]))
   self.assertEqual(init_params, {'bias': jnp.array([1.])})
Esempio n. 15
0
  def test_only_one_compact_method_subclass(self):
    """A subclass may override a @compact method with another @compact one."""

    class Dummy(nn.Module):

      @nn.compact
      def __call__(self):
        pass

    class SubDummy(Dummy):

      @nn.compact
      def __call__(self):
        super().__call__()

    # @compact on both the base class and the subclass is valid as long as it
    # decorates the same method, so this call must not raise.
    SubDummy(parent=Scope(variables={}))()
Esempio n. 16
0
    def test_setup_var_collision(self):
        """Declaring param 'bias' twice in setup raises ValueError."""
        inputs = jnp.array([1.])

        class Dummy(nn.Module):
            xshape: Tuple[int]

            def setup(self):
                # The repeated declaration of 'bias' is the collision under test.
                self.bias = self.param('bias', initializers.ones, self.xshape)
                self.bias = self.param('bias', initializers.ones, self.xshape)

            def __call__(self, x):
                return x + self.bias

        root = Scope({}, {'param': jax.random.PRNGKey(0)})
        with self.assertRaisesRegex(ValueError, 'bias already in use'):
            Dummy(inputs.shape, parent=root)(inputs)
Esempio n. 17
0
  def test_setattr_name_var_disagreement_allowed_in_dicts(self):
    """Params held in a dict attribute may carry names unrelated to it."""
    inputs = jnp.array([1.])

    class Dummy(nn.Module):
      xshape: Tuple[int]

      def setup(self):
        # NOTE: dict keys must still be strings, to keep open a future move
        # to automatically derived param names for dict-assigned params (as
        # already happens for submodules). Discussion:
        # https://github.com/google/flax/issues/705#issuecomment-738761853
        self.biases = {
            str(idx): self.param(f'bias_{idx}', initializers.ones, self.xshape)
            for idx in range(4)
        }

      def __call__(self, x):
        first_bias = self.biases['0']
        return x + first_bias

    root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
    out = Dummy(inputs.shape, parent=root)(inputs)
    self.assertEqual(out, jnp.array([2.]))
Esempio n. 18
0
  def test_submodule_var_collision_with_scope(self):
    """Rebinding a setup param attribute to a submodule reuses its scope name."""
    inputs = jnp.array([1.])

    class Dummy(nn.Module):
      xshape: Tuple[int]

      def setup(self):
        self.bias = self.param('bias', initializers.ones, self.xshape)
        # Reassigning the same attribute to a submodule is the collision.
        self.bias = DummyModule()

      def __call__(self, x):
        return x + self.bias

    root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
    expected = 'Duplicate use of scope name: "bias"'
    with self.assertRaisesWithLiteralMatch(ValueError, expected):
      Dummy(inputs.shape, parent=root)(inputs)
Esempio n. 19
0
 def test_util_fun(self):
   """Dense layers built in a helper called from a compact method get auto-names."""

   class MLP(nn.Module):

     @compact
     def __call__(self, x):
       # Two helper calls -> two distinct Dense layers (Dense_0, Dense_1).
       for _ in range(2):
         x = self._mydense(x)
       return x

     def _mydense(self, x):
       return Dense(3)(x)

   inputs = jnp.ones((10,))
   root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
   first = MLP(parent=root)(inputs)
   init_params = root.variables()['params']
   second = MLP(parent=root.rewound())(inputs)
   np.testing.assert_allclose(first, second)
   expected_shapes = {
       'Dense_0': {'kernel': (10, 3)},
       'Dense_1': {'kernel': (3, 3)},
   }
   self.assertEqual(jax.tree_map(jnp.shape, init_params), expected_shapes)
Esempio n. 20
0
 def test_setup_dict_assignment(self):
   """Submodules stored in a dict attribute are named '<attr>_<key>'."""

   class MLP(nn.Module):

     def setup(self):
       self.lyrs1 = {'a': Dense(3), 'b': Dense(3)}
       # Never called in __call__; the expected params have no entry for it.
       self.lyrs2 = [Dense(3), Dense(3)]

     def __call__(self, x):
       return self.lyrs1['b'](self.lyrs1['a'](x))

   inputs = jnp.ones((10,))
   root = Scope({}, {'params': jax.random.PRNGKey(0)}, mutable=['params'])
   first = MLP(parent=root)(inputs)
   init_params = root.variables()['params']
   second = MLP(parent=root.rewound())(inputs)
   np.testing.assert_allclose(first, second)
   expected_shapes = {
       'lyrs1_a': {'kernel': (10, 3)},
       'lyrs1_b': {'kernel': (3, 3)},
   }
   self.assertEqual(jax.tree_map(jnp.shape, init_params), expected_shapes)
Esempio n. 21
0
    def test_nested_module_reuse(self):
        """Calling one MLP instance twice yields a single MLP_0 param subtree."""

        class MLP(nn.Module):

            @compact
            def __call__(self, x):
                for _ in range(2):
                    x = self._mydense(x)
                return x

            def _mydense(self, x):
                return Dense(3)(x)

        class Top(nn.Module):

            @compact
            def __call__(self, x):
                shared = MLP()
                # Both calls go through the same submodule instance.
                return shared(x) + shared(x)

        inputs = jnp.ones((10, ))
        root = Scope({}, {'param': jax.random.PRNGKey(0)})
        first = Top(parent=root)(inputs)
        init_params = root.variables()['param']
        second = Top(parent=root.rewound())(inputs)
        onp.testing.assert_allclose(first, second)
        expected_shapes = {
            'MLP_0': {
                'Dense_0': {'kernel': (10, 3)},
                'Dense_1': {'kernel': (3, 3)},
            },
        }
        self.assertEqual(jax.tree_map(jnp.shape, init_params), expected_shapes)
Esempio n. 22
0
 def test_module_with_scope_is_not_hashable(self):
   """hash() on a Module constructed with a Scope parent raises ValueError."""
   bound = nn.Dense(10, parent=Scope({}))
   expected = 'Can\'t call __hash__ on modules that hold variables.'
   with self.assertRaisesWithLiteralMatch(ValueError, expected):
     hash(bound)
Esempio n. 23
0
 def test_setup_cloning(self):
   """Exercises clone() on a setup-based Module constructed with a Scope parent."""
   class MLP(nn.Module):
     def setup(self):
       self.dense = Dense(3)
   scope = Scope({})
   MLPclone = MLP(parent=scope).clone()