Exemplo n.º 1
0
    def assert_output_type(self, out_cls):
        """Checks that the `filtering` helpers return mappings of type `out_cls`.

        For each input mapping type (`dict` and `data_structures.FlatMap`), in its
        own subtest, this runs `filtering.filter`, `filtering.map`,
        `filtering.merge`, `filtering.partition` and `filtering.partition_n` on
        two-level structures built from that type. It then asserts that the
        outermost type of every returned structure is exactly `out_cls`.

        Args:
          out_cls: The exact type every output mapping is expected to have.
        """

        def check_top_level_type(structure):
            # Only the outermost container is compared (exact type, not
            # isinstance); nested values are not inspected.
            self.assertEqual(type(structure), out_cls)

        for in_cls in (dict, data_structures.FlatMap):
            with self.subTest(str(in_cls)):
                structure_a = in_cls({"m1": in_cls({"w": None})})
                structure_b = in_cls({"m2": in_cls({"w": None})})
                # Module names are "0".."4" so the predicates below can int() them.
                structure_c = in_cls(
                    {str(i): in_cls({"w": None}) for i in range(5)})

                check_top_level_type(
                    filtering.filter(lambda m, n, v: True, structure_a))
                check_top_level_type(
                    filtering.map(lambda m, n, v: v, structure_a))
                check_top_level_type(
                    filtering.merge(structure_a, structure_b))

                # Split modules by whether int(module name) > 1.
                for part in filtering.partition(
                        lambda m, n, v: int(m) > 1, structure_c):
                    check_top_level_type(part)

                # The fn returns int(module name); presumably that is the bucket
                # index among the 5 requested parts.
                for part in filtering.partition_n(
                        lambda m, n, v: int(m), structure_c, 5):
                    check_top_level_type(part)
Exemplo n.º 2
0
    def test_map(self):
        """Checks that `filtering.map` applies a per-module function to params.

        Initializes `get_net` and maps a function that leaves parameters of
        modules matching "first_layer" unchanged and doubles all others. It then
        partitions the original params with the same module predicate and checks
        each mapped value against the rule that was applied to it.
        """
        # Local import: element-wise array comparison with a useful diff message.
        import numpy as np

        init_fn, _ = transform.transform(get_net)
        params = init_fn(jax.random.PRNGKey(428), jnp.ones((1, 1)))

        def is_first_layer(module_name):
            # Single source of truth for layer selection. The map function and
            # the partition below must agree, or the wrong relation gets checked.
            return "first_layer" in module_name

        # Scale parameters by layer: keep "first_layer", double everything else.
        def map_fn(module_name, name, v):
            del name
            return v if is_first_layer(module_name) else 2. * v

        new_params = filtering.map(map_fn, params)
        # The top-level `jax.tree_leaves` alias is deprecated/removed in newer
        # JAX releases; `jax.tree_util.tree_leaves` works across versions.
        self.assertLen(jax.tree_util.tree_leaves(new_params), 4)

        first_layer_params, other_params = filtering.partition(
            lambda module_name, *_: is_first_layer(module_name), params)
        # Guard against the verification loops below passing vacuously.
        self.assertNotEmpty(first_layer_params)
        self.assertNotEmpty(other_params)

        for mn, module_params in first_layer_params.items():
            for n in module_params:
                np.testing.assert_array_equal(
                    new_params[mn][n], params[mn][n])

        for mn, module_params in other_params.items():
            for n in module_params:
                np.testing.assert_array_equal(
                    new_params[mn][n], 2. * params[mn][n])