def test_to_module_error_docs(self):
    """Check the docstrings exposed by classes built with `basic.to_module`."""

    def documented_fn():
        """Really great docs."""

    def undocumented_fn():
        pass

    # A documented function: its docstring is expected on both the instance
    # and on `__call__`.
    documented = basic.to_module(documented_fn)()
    self.assertEqual(documented.__doc__, "Really great docs.")
    self.assertEqual(documented.__call__.__doc__, "Really great docs.")

    # An undocumented function: the instance reports the generic wrapper
    # docstring and `__call__` has none.
    undocumented = basic.to_module(undocumented_fn)()
    self.assertEqual(undocumented.__doc__,
                     "Module produced by `hk.to_module`.")
    self.assertIsNone(undocumented.__call__.__doc__)
def test_to_module(self):
    """Check that a function wrapped by `basic.to_module` can be called."""

    def bias_fn(x):
        b = base.get_parameter("b", [], init=jnp.ones)
        return x + b

    # `to_module` returns a class, so it is instantiated before being called.
    bias_cls = basic.to_module(bias_fn)
    module = bias_cls()

    # The parameter is initialised with ones, and the input is a scalar one.
    output = module(jnp.ones([]))
    self.assertEqual(output, 2.)
def test_includes_no_param_modules(self):
    """Check that a `to_module`-wrapped dropout call gets a row in the table."""
    dropout_cls = basic.to_module(
        lambda x: basic.dropout(base.next_rng_key(), 0.5, x))
    inputs = jnp.ones([4])

    def apply_dropout():
        module = dropout_cls(name="dropout")
        return module(inputs)

    # Only the "module" column is requested, so each row has a single cell.
    table = tabulate_to_list(apply_dropout, columns=("module", ))
    self.assertEqual(table, [["dropout (ToModuleWrapper)"]])
def test_to_module_error_invalid_name(self):
    """Check that a non-string name passed to the constructor is rejected."""

    def bias_fn(x):
        b = base.get_parameter("b", [], init=jnp.ones)
        return x + b

    module_cls = basic.to_module(bias_fn)

    # An arbitrary non-string object stands in for a bad `name` argument; the
    # error message is expected to include it.
    garbage = object()
    expected_message = f"Expected a string name .* got: {garbage}"

    with self.assertRaisesRegex(TypeError, expected_message):
        module_cls(garbage)  # pytype: disable=wrong-arg-types
def test_params_or_state(self, params, num_elems):
    """Check the summary lists every parameter (or state entry) of a module.

    Args:
      params: If truthy the module creates parameters, otherwise state.
      num_elems: Number of entries the module creates.
    """

    def cls():
        # Choose the creation function once; `params` does not change
        # between iterations.
        create = base.get_parameter if params else base.get_state
        for i in range(num_elems):
            create(f"x{i}", [], init=jnp.zeros)

    def run_module():
        module = basic.to_module(cls)(name="foo")
        return module()

    # Unpacking into a one-element list requires exactly one invocation.
    [invocation] = get_summary(run_module)
    details = invocation.module_details
    entries = details.params if params else details.state

    expected_names = [f"foo/x{i}" for i in range(num_elems)]
    self.assertEqual(list(entries), expected_names)
def test_owned_params_sorted_by_size_then_name(self):
    """Check the order of entries in the `owned_params` column.

    The expected cell lists the largest parameter first, and parameters of
    equal size in name order.
    """

    def f():
        # Sizes are 1, 2, 2 and 3; "b" and "c" tie on size.
        for name, size in (("a", 1), ("b", 2), ("c", 2), ("d", 3)):
            base.get_parameter(name, [size], init=jnp.zeros)
        return 0

    def run_module():
        module = basic.to_module(f)()
        return module()

    table = tabulate_to_list(run_module, columns=("owned_params", ))
    expected_cell = ("d: f32[3]\n"
                     "b: f32[2]\n"
                     "c: f32[2]\n"
                     "a: f32[1]")
    self.assertEqual(table, [[expected_cell]])