def assert_output_type(self, out_cls):
  """Asserts that `filtering` ops return results of exactly type `out_cls`.

  Runs `filter`, `map`, `merge`, `partition` and `partition_n` on small nested
  structures built from both `dict` and `data_structures.FlatMap` (one subtest
  per input container type) and checks the type of every result.

  Args:
    out_cls: The exact container type every result must have.
  """

  def check_exact_type(result):
    # Exact `type(...)` comparison, not isinstance: a subclass of out_cls fails.
    # NOTE(review): only the outermost container is checked; nested per-module
    # mappings inside `result` are not inspected.
    self.assertEqual(type(result), out_cls)

  for in_cls in (dict, data_structures.FlatMap):
    with self.subTest(str(in_cls)):

      def single_module(module_name):
        # One module holding a single parameter "w", built fresh on each call.
        return in_cls({module_name: in_cls({"w": None})})

      structure_a = single_module("m1")
      structure_b = single_module("m2")
      structure_c = in_cls({str(i): in_cls({"w": None}) for i in range(5)})

      check_exact_type(filtering.filter(lambda m, n, v: True, structure_a))
      check_exact_type(filtering.map(lambda m, n, v: v, structure_a))
      check_exact_type(filtering.merge(structure_a, structure_b))

      # Module names in structure_c are "0".."4", so int(m) parses them.
      for part in filtering.partition(lambda m, n, v: int(m) > 1, structure_c):
        check_exact_type(part)
      for part in filtering.partition_n(
          lambda m, n, v: int(m), structure_c, 5):
        check_exact_type(part)
def test_map(self):
  """`filtering.map` rewrites exactly the parameters its function changes.

  Doubles every parameter outside the "first_layer" module. Then uses
  `filtering.partition` with the same module predicate to check that
  first-layer parameters are unchanged and all other parameters are doubled.
  """
  # Local import: only this test needs numpy's array-aware assertions.
  import numpy as np

  init_fn, _ = transform.transform(get_net)
  params = init_fn(jax.random.PRNGKey(428), jnp.ones((1, 1)))

  # One predicate shared by `map_fn` and `partition`, so the expectations
  # below classify modules exactly the way the transform did.
  def is_first_layer(module_name):
    return module_name == "first_layer"

  # Parse by layer: keep first-layer params, double everything else.
  def map_fn(module_name, name, v):
    del name  # Only the module name decides how a parameter is transformed.
    return v if is_first_layer(module_name) else 2. * v

  new_params = filtering.map(map_fn, params)
  # NOTE(review): 4 leaves presumably = one weight and one bias for each of
  # the two layers built by `get_net` — confirm against its definition.
  self.assertLen(jax.tree_util.tree_leaves(new_params), 4)

  first_layer_params, second_layer_params = filtering.partition(
      lambda module_name, *_: is_first_layer(module_name), params)
  # Guard against a vacuous pass: both comparison loops must check something.
  self.assertNotEmpty(first_layer_params)
  self.assertNotEmpty(second_layer_params)

  # Array-aware comparisons: assertEqual on arrays only works for size-1
  # arrays (truthiness of a multi-element comparison raises).
  for mn, module_params in first_layer_params.items():
    for n in module_params:
      np.testing.assert_array_equal(params[mn][n], new_params[mn][n])
  for mn, module_params in second_layer_params.items():
    for n in module_params:
      np.testing.assert_array_equal(2. * params[mn][n], new_params[mn][n])