Ejemplo n.º 1
0
def haiku_public_symbols():
    """Return the qualified names of all non-hidden exported symbols.

    Walks every ``(module_name, module)`` pair produced by
    ``test_utils.find_internal_python_modules(hk)`` and qualifies each entry
    of the module's ``__all__`` as ``"<module_name>.<name>"``.

    Returns:
        A set of qualified symbol names, excluding any that appear in
        ``HIDDEN_SYMBOLS``.
    """
    qualified_names = (
        f"{module_name}.{name}"
        for module_name, module in test_utils.find_internal_python_modules(hk)
        for name in module.__all__
    )
    return {
        symbol_name
        for symbol_name in qualified_names
        if symbol_name not in HIDDEN_SYMBOLS
    }
Ejemplo n.º 2
0
class DoctestTest(parameterized.TestCase):
    """Runs the doctests attached to the exported symbols of each module.

    The test is parameterized over the ``(name, module)`` pairs produced by
    ``test_utils.find_internal_python_modules(hk)``.
    """

    def setUp(self):
        """Set ``HAIKU_FLATMAPPING`` to ``"0"`` for the duration of the test."""
        super().setUp()
        # Remember the pre-existing value (if any) so tearDown can restore it
        # rather than unconditionally deleting a setting it did not create.
        self._saved_flatmapping = os.environ.get("HAIKU_FLATMAPPING")
        os.environ["HAIKU_FLATMAPPING"] = "0"

    def tearDown(self):
        """Restore ``HAIKU_FLATMAPPING`` to whatever it was before ``setUp``."""
        super().tearDown()
        if self._saved_flatmapping is None:
            # `pop` with a default tolerates the variable already being gone.
            os.environ.pop("HAIKU_FLATMAPPING", None)
        else:
            os.environ["HAIKU_FLATMAPPING"] = self._saved_flatmapping

    @parameterized.named_parameters(
        test_utils.find_internal_python_modules(hk))
    def test_doctest(self, module):
        """Run the doctests of every symbol listed in ``module.__all__``.

        Each exported symbol is registered in ``module.__test__`` and checked
        inside its own ``hk.transform_with_state`` call; the registration is
        removed again afterwards, even when the check fails.

        Args:
            module: The module whose exported symbols are doctested.
        """
        def run_test():
            # Reads `module.__test__` at call time, so it only sees the
            # entries registered for the symbol currently being processed.
            num_failed, num_attempted = doctest.testmod(
                module,
                optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE,
                extraglobs={
                    "itertools": itertools,
                    "chex": chex,
                    "collections": collections,
                    "contextlib": contextlib,
                    "unittest": unittest,
                    "hk": hk,
                    "jnp": jnp,
                    "jax": jax,
                    "jmp": jmp,
                })
            tests_symbols = ", ".join(module.__test__)
            if num_attempted == 0:
                logging.info("No doctests in %s", tests_symbols)
            self.assertEqual(num_failed, 0, f"{num_failed} doctests failed")
            logging.info("%s tests passed in %s", num_attempted, tests_symbols)

        # `hk` et al import all dependencies from `src`, however doctest does not
        # test imported deps so we must manually set `__test__` such that imported
        # symbols are tested.
        # See: docs.python.org/3/library/doctest.html#which-docstrings-are-examined
        if not getattr(module, "__test__", None):
            module.__test__ = {}

        # Many tests expect to be run as part of an `hk.transform`. We loop over all
        # exported symbols and run them in their own `hk.transform` so parameter and
        # module names don't clash.
        for name in module.__all__:
            value = getattr(module, name)
            if inspect.ismodule(value):
                continue

            # Skip type annotations in Python 3.7.
            if hasattr(value, "__origin__"):
                continue

            logging.info("Testing name: %r value: %r", name, value)

            # Keys registered in `module.__test__` for this symbol; tracked so
            # that exactly these entries are removed again below.
            test_names = [name]
            module.__test__[name] = value

            if inspect.isclass(value):
                # Find unbound methods on classes, doctest doesn't seem to find them.
                for attr_name in dir(value):
                    attr_value = getattr(value, attr_name)
                    if inspect.isfunction(attr_value):
                        test_name = name + "_" + attr_name
                        test_names.append(test_name)
                        module.__test__[test_name] = attr_value

            try:
                init_fn, _ = hk.transform_with_state(run_test)
                rng = jax.random.PRNGKey(42)
                init_fn(rng)
            finally:
                # Always unregister, so a failing doctest does not leave stale
                # entries behind on the module object.
                for test_name in test_names:
                    module.__test__.pop(test_name, None)
Ejemplo n.º 3
0
def haiku_public_symbols():
    """Return the qualified names of all exported symbols.

    Every ``(module_name, module)`` pair produced by
    ``test_utils.find_internal_python_modules(hk)`` contributes one
    ``"<module_name>.<name>"`` entry per name in the module's ``__all__``.

    Returns:
        A set of qualified symbol names.
    """
    return {
        ".".join((module_name, exported))
        for module_name, module in test_utils.find_internal_python_modules(hk)
        for exported in module.__all__
    }