Пример #1
0
    def test_to_module_error_docs(self):
        """Checks that `to_module` carries the wrapped function's docstring.

        For a documented function, its docstring is expected on both the
        instance `__doc__` and `__call__.__doc__`. For an undocumented one, the
        instance keeps the generic "Module produced by `hk.to_module`." text
        and `__call__` has no docstring.
        """
        def documented_fn():
            """Really great docs."""

        def undocumented_fn():
            pass

        expected_doc = "Really great docs."
        with_docs = basic.to_module(documented_fn)()
        self.assertEqual(with_docs.__doc__, expected_doc)
        self.assertEqual(with_docs.__call__.__doc__, expected_doc)

        without_docs = basic.to_module(undocumented_fn)()
        self.assertEqual(without_docs.__doc__,
                         "Module produced by `hk.to_module`.")
        self.assertIsNone(without_docs.__call__.__doc__)
Пример #2
0
    def test_to_module(self):
        """A parameter-creating function can be wrapped and called as a module.

        The bias parameter is initialised to ones, so ones([]) + 1 == 2.
        """
        def bias_fn(x):
            bias = base.get_parameter("b", [], init=jnp.ones)
            return x + bias

        module_cls = basic.to_module(bias_fn)
        instance = module_cls()
        output = instance(jnp.ones([]))
        self.assertEqual(output, 2.)
Пример #3
0
    def test_includes_no_param_modules(self):
        """A module that owns no parameters still shows up in the table."""
        dropout_cls = basic.to_module(
            lambda x: basic.dropout(base.next_rng_key(), 0.5, x))
        inputs = jnp.ones([4])

        def apply_dropout():
            return dropout_cls(name="dropout")(inputs)

        rows = tabulate_to_list(apply_dropout, columns=("module",))
        self.assertEqual(rows, [["dropout (ToModuleWrapper)"]])
Пример #4
0
    def test_to_module_error_invalid_name(self):
        """Constructing a `to_module` class with a non-string name raises.

        The error message is expected to contain the offending object's
        string form after "got: ".
        """
        import re  # local import: only needed to escape the repr below

        def bias_fn(x):
            b = base.get_parameter("b", [], init=jnp.ones)
            return x + b

        cls = basic.to_module(bias_fn)
        garbage = object()
        # The object's str() is spliced into a regex pattern; escape it so any
        # metacharacters in the repr are matched literally, not as syntax.
        with self.assertRaisesRegex(
                TypeError,
                f"Expected a string name .* got: {re.escape(str(garbage))}"):
            cls(garbage)  # pytype: disable=wrong-arg-types
Пример #5
0
    def test_params_or_state(self, params, num_elems):
        """Summary lists created params (or state) under the module's name.

        Args:
          params: if True, entries are created with `base.get_parameter` and
            read back from `module_details.params`; otherwise they are created
            with `base.get_state` and read from `module_details.state`.
          num_elems: number of scalar entries ("x0", "x1", ...) to create.
        """
        # Loop-invariant: decide once which creator to use.
        create = base.get_parameter if params else base.get_state

        def fn():
            for i in range(num_elems):
                create(f"x{i}", [], init=jnp.zeros)

        f = lambda: basic.to_module(fn)(name="foo")()
        # Exactly one module invocation is expected; unpacking enforces it.
        (invocation,) = get_summary(f)
        details = invocation.module_details
        d = details.params if params else details.state
        self.assertEqual(list(d), [f"foo/x{i}" for i in range(num_elems)])
Пример #6
0
    def test_owned_params_sorted_by_size_then_name(self):
        """Owned params are listed largest first, with ties ordered by name."""
        def f():
            for param_name, size in (("a", 1), ("b", 2), ("c", 2), ("d", 3)):
                base.get_parameter(param_name, [size], init=jnp.zeros)
            return 0

        run = lambda: basic.to_module(f)()()
        rows = tabulate_to_list(run, columns=("owned_params",))
        expected_cell = "\n".join(
            ["d: f32[3]", "b: f32[2]", "c: f32[2]", "a: f32[1]"])
        self.assertEqual(rows, [[expected_cell]])