def test_filter(self):
    """Checks `filtering.filter` with a module-name and a parameter-name predicate."""
    init_fn, _ = transform.transform(get_net)
    rng = jax.random.PRNGKey(428)
    params = init_fn(rng, jnp.ones((1, 1)))

    # Predicate on the module name only; the remaining arguments are ignored.
    def in_second_layer(module_name, *_):
        return module_name == "second_layer"

    second_layer_params = filtering.filter(in_second_layer, params)
    expected_second_layer = {"second_layer/w", "second_layer/b"}
    self.assertEqual(get_names(second_layer_params), expected_second_layer)

    # Predicate on the parameter name: keep every entry called "b".
    bias_params = filtering.filter(lambda module_name, name, _: name == "b", params)  # pytype: disable=wrong-arg-types
    expected_biases = {"first_layer/b", "second_layer/b"}
    self.assertEqual(get_names(bias_params), expected_biases)
def assert_output_type(self, out_cls):
    """Asserts that each `filtering` operation returns a structure of type `out_cls`.

    `filter`, `map`, `merge`, `partition` and `partition_n` are each run on
    inputs built from `dict` and from `data_structures.FlatMap`, with one
    sub-test per input class.

    Args:
      out_cls: The exact type every returned structure must have.
    """

    def check_exact_type(structure):
        # Compares the type of the returned container itself (exact match,
        # not isinstance); nested values are not inspected.
        self.assertEqual(type(structure), out_cls)

    for in_cls in (dict, data_structures.FlatMap):
        with self.subTest(str(in_cls)):

            def new_module():
                # A fresh single-parameter module mapping for every use.
                return in_cls({"w": None})

            structure_a = in_cls({"m1": new_module()})
            structure_b = in_cls({"m2": new_module()})
            # Module names are "0".."4" so the predicates below can call
            # int() on them.
            structure_c = in_cls({f"{i}": new_module() for i in range(5)})

            filtered = filtering.filter(lambda m, n, v: True, structure_a)
            check_exact_type(filtered)

            mapped = filtering.map(lambda m, n, v: v, structure_a)
            check_exact_type(mapped)

            merged = filtering.merge(structure_a, structure_b)
            check_exact_type(merged)

            # Two-way split on whether the module index exceeds 1.
            for part in filtering.partition(
                lambda m, n, v: int(m) > 1, structure_c
            ):
                check_exact_type(part)

            # Five-way split using the module index as the selector.
            for part in filtering.partition_n(
                lambda m, n, v: int(m), structure_c, 5
            ):
                check_exact_type(part)