Example 1
0
  def test_partitioning(self):
    """Checks `filtering.partition` against several predicates.

    Every case partitions the same initialised `get_net` params and compares
    the paths reported by `get_names` for the matching and non-matching
    halves with the expected sets.
    """
    init_fn, _ = transform.transform(get_net)
    params = init_fn(jax.random.PRNGKey(428), jnp.ones((1, 1)))

    def check_split(predicate, expected_in, expected_out):
      # `partition` yields (matching, not_matching); verify both halves.
      selected, rejected = filtering.partition(predicate, params)
      self.assertEqual(get_names(selected), set(expected_in))
      self.assertEqual(get_names(rejected), set(expected_out))

    # Split on the module name alone.
    check_split(
        lambda module_name, *_: module_name == "first_layer",
        ["first_layer/w", "first_layer/b"],
        ["second_layer/w", "second_layer/b"])

    # Split on the parameter name (weights vs. biases).
    check_split(
        lambda module_name, name, _: name == "w",
        ["first_layer/w", "second_layer/w"],
        ["first_layer/b", "second_layer/b"])

    # A regex compiled from two patterns, matched against "module/name".
    regex = compile_regex(["first_layer.*", ".*w"])
    check_split(
        lambda module_name, name, _: regex.match(f"{module_name}/{name}"),
        ["first_layer/w", "first_layer/b", "second_layer/w"],
        ["second_layer/b"])

    # Conjunction of a module-name test and a parameter-name test.
    check_split(
        lambda mod_name, name, _: mod_name == "first_layer" and name != "w",
        ["first_layer/b"],
        ["first_layer/w", "second_layer/w", "second_layer/b"])
Example 2
0
    def assert_output_type(self, out_cls):
        """Checks that filtering helpers return containers of type `out_cls`.

        Runs `filter`, `map`, `merge`, `partition` and `partition_n` on inputs
        built from both plain `dict` and `data_structures.FlatMap`, and asserts
        the exact type (not an `isinstance` check) of every returned structure.

        Args:
          out_cls: The exact container type each result is expected to have.
        """
        def check_type(result):
            # Only the outermost container's type is compared; nested values
            # are not inspected.
            self.assertEqual(type(result), out_cls)

        for in_cls in (dict, data_structures.FlatMap):
            with self.subTest(str(in_cls)):
                struct_a = in_cls({"m1": in_cls({"w": None})})
                struct_b = in_cls({"m2": in_cls({"w": None})})
                # Module names "0".."4" so predicates can parse them as ints.
                numbered = in_cls(
                    {str(idx): in_cls({"w": None}) for idx in range(5)})

                check_type(filtering.filter(lambda m, n, v: True, struct_a))
                check_type(filtering.map(lambda m, n, v: v, struct_a))
                check_type(filtering.merge(struct_a, struct_b))

                for piece in filtering.partition(
                        lambda m, n, v: int(m) > 1, numbered):
                    check_type(piece)
                for piece in filtering.partition_n(
                        lambda m, n, v: int(m), numbered, 5):
                    check_type(piece)
Example 3
0
    def test_map(self):
        """Checks `filtering.map` doubles every parameter outside the first layer.

        `map_fn` returns parameters of modules whose name contains
        "first_layer" unchanged and multiplies all others by 2. The mapped
        result is then compared leaf-by-leaf with the original params, using
        `filtering.partition` to decide which expectation applies to each leaf.
        """
        init_fn, _ = transform.transform(get_net)
        params = init_fn(jax.random.PRNGKey(428), jnp.ones((1, 1)))

        # Identity for the first layer, doubling for every other module.
        def map_fn(module_name, name, v):
            del name  # The transformation depends only on the module.
            if "first_layer" in module_name:
                return v
            return 2. * v

        new_params = filtering.map(map_fn, params)
        # `jax.tree_leaves` is a deprecated alias that newer JAX versions no
        # longer provide; `jax.tree_util.tree_leaves` is the supported API.
        self.assertLen(jax.tree_util.tree_leaves(new_params), 4)

        first_layer_params, second_layer_params = filtering.partition(
            lambda module_name, *_: module_name == "first_layer", params)

        # Use `jnp.array_equal` instead of `assertEqual`: `==` on arrays is
        # elementwise, so `assertEqual` would raise an "ambiguous truth value"
        # error for any leaf holding more than one element.
        for mn, layer in first_layer_params.items():
            for n in layer:
                self.assertTrue(
                    jnp.array_equal(params[mn][n], new_params[mn][n]))

        for mn, layer in second_layer_params.items():
            for n in layer:
                self.assertTrue(
                    jnp.array_equal(2. * params[mn][n], new_params[mn][n]))
Example 4
0
 def fn_with_filter(p, *args, **kwargs):
     """Splits `p` with `predicate` and forwards both halves to `jaxed_fn`.

     The matching half is passed first, the non-matching half second, followed
     by any extra positional and keyword arguments unchanged. `predicate` and
     `jaxed_fn` are taken from the enclosing scope (not visible here).
     """
     selected, remainder = filtering.partition(predicate, p)
     return jaxed_fn(selected, remainder, *args, **kwargs)