def test_partitioning(self):
    """Exercise `filtering.partition` with several kinds of predicate.

    Parameters are initialised from `get_net` and then split four ways:
    by module name, by parameter name, by a compiled regex applied to
    "<module>/<name>", and by a combined module/name condition. After each
    split the names reported by `get_names` for both halves are compared
    against the expected sets.
    """
    init_fn, _ = transform.transform(get_net)
    params = init_fn(jax.random.PRNGKey(428), jnp.ones((1, 1)))

    def expect_names(tree, expected):
        # Same check the test performs everywhere: names found in `tree`
        # must equal the expected set.
        self.assertEqual(get_names(tree), expected)

    # Split on the owning module.
    first_layer_params, second_layer_params = filtering.partition(
        lambda module_name, *_: module_name == "first_layer", params)
    expect_names(first_layer_params, {"first_layer/w", "first_layer/b"})
    expect_names(second_layer_params, {"second_layer/w", "second_layer/b"})

    # Split on the parameter name (weights vs. biases).
    weights, biases = filtering.partition(
        lambda module_name, name, _: name == "w", params)  # pytype: disable=wrong-arg-types
    expect_names(weights, {"first_layer/w", "second_layer/w"})
    expect_names(biases, {"first_layer/b", "second_layer/b"})

    # Split using several regexes combined by `compile_regex`; the
    # predicate result is the (truthy or None) outcome of `regex.match`.
    regex = compile_regex(["first_layer.*", ".*w"])
    matching, not_matching = filtering.partition(
        lambda module_name, name, _: regex.match(f"{module_name}/{name}"),
        params)
    expect_names(
        matching, {"first_layer/w", "first_layer/b", "second_layer/w"})
    expect_names(not_matching, {"second_layer/b"})

    # Split on a conjunction of module and parameter name.
    matching, not_matching = filtering.partition(
        lambda mod_name, name, _: mod_name == "first_layer" and name != "w",
        params)
    expect_names(matching, {"first_layer/b"})
    expect_names(
        not_matching, {"first_layer/w", "second_layer/w", "second_layer/b"})
def assert_output_type(self, out_cls):
    """Assert that the filtering helpers return containers of type `out_cls`.

    For each input container class (`dict` and `data_structures.FlatMap`)
    small two-level structures are built and passed through
    `filtering.filter`, `filtering.map`, `filtering.merge`,
    `filtering.partition` and `filtering.partition_n`; the type of every
    returned structure is compared with `out_cls`.

    Args:
      out_cls: the exact type each returned structure is expected to have.
    """

    def expect_out_cls(result):
        # Only the type of the outermost container is compared; nested
        # containers and leaves are not inspected.
        self.assertEqual(type(result), out_cls)

    for in_cls in (dict, data_structures.FlatMap):
        with self.subTest(str(in_cls)):
            structure_a = in_cls({"m1": in_cls({"w": None})})
            structure_b = in_cls({"m2": in_cls({"w": None})})
            # Five modules keyed "0".."4" so the keys can be parsed back to
            # ints by the partition predicates below.
            structure_c = in_cls(
                {f"{i}": in_cls({"w": None}) for i in range(5)})

            filtered = filtering.filter(lambda m, n, v: True, structure_a)
            expect_out_cls(filtered)

            mapped = filtering.map(lambda m, n, v: v, structure_a)
            expect_out_cls(mapped)

            merged = filtering.merge(structure_a, structure_b)
            expect_out_cls(merged)

            # Two-way split on the integer value of the module name.
            for part in filtering.partition(
                    lambda m, n, v: int(m) > 1, structure_c):
                expect_out_cls(part)

            # Five-way split, one bucket index per module name.
            for part in filtering.partition_n(
                    lambda m, n, v: int(m), structure_c, 5):
                expect_out_cls(part)
def test_map(self):
    """Check that `filtering.map` applies a per-parameter function.

    Parameters are initialised from `get_net`, then mapped with a function
    that leaves values whose module name contains "first_layer" untouched
    and doubles every other value. The test verifies that the mapped tree
    still has four leaves and that each value matches the expected
    (unchanged or doubled) original.
    """
    init_fn, _ = transform.transform(get_net)
    params = init_fn(jax.random.PRNGKey(428), jnp.ones((1, 1)))

    def map_fn(module_name, name, v):
        # Identity for the first layer, doubling for everything else.
        del name
        if "first_layer" in module_name:
            return v
        else:
            return 2. * v

    new_params = filtering.map(map_fn, params)
    # Use the `jax.tree_util` spelling: the top-level `jax.tree_leaves` alias
    # was deprecated and later removed from JAX, whereas
    # `jax.tree_util.tree_leaves` is available in both old and new releases.
    self.assertLen(jax.tree_util.tree_leaves(new_params), 4)

    # Split the ORIGINAL params by layer so each group can be compared
    # against the mapped values with the appropriate expectation.
    first_layer_params, second_layer_params = filtering.partition(
        lambda module_name, *_: module_name == "first_layer", params)

    # First-layer values must be unchanged.
    for mn in first_layer_params:
        for n in first_layer_params[mn]:
            self.assertEqual(params[mn][n], new_params[mn][n])

    # All remaining values must have been doubled.
    for mn in second_layer_params:
        for n in second_layer_params[mn]:
            self.assertEqual(2. * params[mn][n], new_params[mn][n])
def fn_with_filter(p, *args, **kwargs):
    """Partition `p` with `predicate` and forward both halves to `jaxed_fn`.

    `predicate` and `jaxed_fn` are taken from the enclosing scope. The two
    structures returned by `filtering.partition` are passed as the first two
    positional arguments, followed by any extra positional and keyword
    arguments unchanged.
    """
    matching, not_matching = filtering.partition(predicate, p)
    return jaxed_fn(matching, not_matching, *args, **kwargs)