Exemplo n.º 1
0
 def test_coverage(self):
     all_modules = frozenset(test_utils.find_subclasses(hk, hk.Module))
     tested_modules = {
         type(descriptors.unwrap(d.create()))
         for d in ALL_MODULES
     }
     self.assertEmpty(all_modules - (tested_modules | IGNORED_MODULES))
Exemplo n.º 2
0
def module_type(module_fn: ModuleFn) -> Type[hk.Module]:
    f = hk.transform(lambda: type(descriptors.unwrap(module_fn())), state=True)
    return f.apply(*f.init(jax.random.PRNGKey(42)))[0]
Exemplo n.º 3
0
 def test_protocols(self, module_fn: ModuleFn, shape, dtype):
     del shape, dtype
     module = descriptors.unwrap(module_fn())
     self.assertIsInstance(module, hk.ModuleProtocol)
     # NOTE: All current Haiku builtin modules are callable.
     self.assertIsInstance(module, hk.SupportsCall)
Exemplo n.º 4
0
def module_name(d: descriptors.ModuleDescriptor):
    name = hk.testing.transform_and_run(
        lambda: str(descriptors.unwrap(d.create())))()
    return name.split("\n")
Exemplo n.º 5
0
def get_module_cls(module_fn: ModuleFn) -> Type[hk.Module]:
    get_cls = lambda: type(descriptors.unwrap(module_fn()))
    return hk.testing.transform_and_run(get_cls)()