Example #1
0
    def test_filter(self):
        """`filtering.filter` keeps only the entries accepted by a predicate.

        Initialises ``get_net`` via ``transform.transform`` and then checks two
        predicates: one that selects a whole module by its name, and one that
        selects parameters named ``"b"`` regardless of module.
        """
        init_fn, _ = transform.transform(get_net)
        params = init_fn(jax.random.PRNGKey(428), jnp.ones((1, 1)))

        # Select by module name: only the second layer's entries should remain.
        second_layer_params = filtering.filter(
            lambda module_name, *_: module_name == "second_layer", params)
        self.assertEqual(get_names(second_layer_params),
                         {"second_layer/w", "second_layer/b"})

        # Select by parameter name: the "b" entry of every layer should remain.
        # The pytype suppression below is kept on the same line as in the
        # original call layout so it continues to apply.
        biases = filtering.filter(lambda module_name, name, _: name == "b",
                                  params)  # pytype: disable=wrong-arg-types
        self.assertEqual(get_names(biases),
                         {"first_layer/b", "second_layer/b"})
Example #2
0
    def assert_output_type(self, out_cls):
        """Assert that the filtering APIs return containers of type ``out_cls``.

        For input structures built from both ``dict`` and
        ``data_structures.FlatMap``, this runs ``filtering.filter``, ``map``,
        ``merge``, ``partition`` and ``partition_n``. It then compares the exact
        type of each returned structure against ``out_cls``. Each input class
        runs in its own ``subTest``.

        Note: only the outermost container of each result is compared; the
        nested per-module mappings inside it are not inspected.

        Args:
            out_cls: The exact type every returned structure must have.
        """
        def check_outer_type(structure):
            # Exact type equality (not isinstance), so a subclass would fail.
            self.assertEqual(type(structure), out_cls)

        for in_cls in (dict, data_structures.FlatMap):
            with self.subTest(str(in_cls)):
                structure_a = in_cls({"m1": in_cls({"w": None})})
                structure_b = in_cls({"m2": in_cls({"w": None})})
                # Five modules named "0".."4", each holding a single "w".
                structure_c = in_cls(
                    {str(i): in_cls({"w": None}) for i in range(5)})

                check_outer_type(
                    filtering.filter(lambda m, n, v: True, structure_a))
                check_outer_type(
                    filtering.map(lambda m, n, v: v, structure_a))
                check_outer_type(
                    filtering.merge(structure_a, structure_b))

                # Presumably splits modules "0","1" from "2".."4"; every
                # returned part is checked.
                for part in filtering.partition(lambda m, n, v: int(m) > 1,
                                                structure_c):
                    check_outer_type(part)

                # Presumably routes module "i" to partition i out of 5.
                for part in filtering.partition_n(lambda m, n, v: int(m),
                                                  structure_c, 5):
                    check_outer_type(part)