def test_check_duplicates(self):
    """merge(check_duplicates=True) raises ValueError on conflicting leaves.

    Each case merges structures that share the "a"/"b" key but disagree on
    dtype, shape or value type, and checks the reported mismatch text.
    """
    err = "Duplicate array found"
    a = {"a": {"b": jnp.array([0.0, 0.1], dtype=jnp.float32)}}
    b = {"a": {"b": jnp.array([0.0, 0.1], dtype=jnp.bfloat16)}}
    c = {"a": {"b": jnp.array([0.0, 0.1, 0.2], dtype=jnp.float32)}}
    d = {"a": {"b": "foo"}}

    # (subtest name, structures passed to merge, expected mismatch detail).
    cases = (
        ("dtype_mismatch", (a, b), r"f32\[2\] vs bf16\[2\]"),
        ("shape_mismatch", (a, c), r"f32\[2\] vs f32\[3\]"),
        # With several conflicts, only the first (a vs b) is expected.
        ("multiple_mismatch", (a, b, c), r"f32\[2\] vs bf16\[2\]"),
        ("object_mismatch", (a, d), r"f32\[2\] vs 'foo'"),
    )
    for case_name, structures, detail in cases:
        with self.subTest(case_name):
            with self.assertRaisesRegex(ValueError, f"{err}.*{detail}"):
                filtering.merge(*structures, check_duplicates=True)
def assert_output_type(self, out_cls):
    """Asserts filtering ops return `out_cls` for dict and FlatMap inputs.

    Runs filter, map, merge, partition and partition_n over structures built
    from each input mapping type. Only the outermost container type of each
    result is compared against `out_cls`; nested values are not inspected.

    Args:
      out_cls: The exact type every top-level result is expected to have.
    """
    def assert_top_level_type(result):
        # Exact type comparison (not isinstance), as in the original helper.
        self.assertEqual(type(result), out_cls)

    for in_cls in (dict, data_structures.FlatMap):
        with self.subTest(str(in_cls)):
            structure_a = in_cls({"m1": in_cls({"w": None})})
            structure_b = in_cls({"m2": in_cls({"w": None})})
            structure_c = in_cls(
                {str(i): in_cls({"w": None}) for i in range(5)})

            assert_top_level_type(
                filtering.filter(lambda m, n, v: True, structure_a))
            assert_top_level_type(
                filtering.map(lambda m, n, v: v, structure_a))
            assert_top_level_type(filtering.merge(structure_a, structure_b))

            # Module names in structure_c are "0".."4", so int(m) is valid.
            for part in filtering.partition(
                    lambda m, n, v: int(m) > 1, structure_c):
                assert_top_level_type(part)
            for part in filtering.partition_n(
                    lambda m, n, v: int(m), structure_c, 5):
                assert_top_level_type(part)
def test_merge_different_mappings(self):
    """merge() combines inputs that are different Mapping implementations."""
    from_defaultdict = collections.defaultdict(dict)
    from_defaultdict["foo"]["bar"] = 1
    from_dict = {"foo": {"baz": 2}}
    from_proxy = types.MappingProxyType({"foo": {"bat": 3}})

    self.assertEqual(
        filtering.merge(from_defaultdict, from_dict, from_proxy),
        {"foo": {"bar": 1, "baz": 2, "bat": 3}},
    )
def test_partition_n_merge_isomorphism(self, n):
    """Merging the outputs of partition_n() reproduces its input structure."""
    counter = itertools.count()

    def next_index(m, n, v):
        # Ignores its arguments; each call yields the next partition index.
        # (This `n` is the callback's own parameter, not the test's `n`.)
        del m, n, v
        return next(counter)

    original = {f"layer_{i}": {"w": None} for i in range(n)}
    parts = filtering.partition_n(next_index, original, n)
    self.assertEqual(filtering.merge(*parts), original)
def test_partition_n(self, n):
    """A counting callback makes partition_n() put one layer per partition."""
    counter = itertools.count()

    def next_index(m, n, v):
        # Ignores its arguments; each call yields the next partition index.
        # (This `n` is the callback's own parameter, not the test's `n`.)
        del m, n, v
        return next(counter)

    original = {f"layer_{i}": {"w": None} for i in range(n)}
    parts = filtering.partition_n(next_index, original, n)

    self.assertLen(parts, n)
    self.assertEqual(filtering.merge(*parts), original)
    # Partition k is expected to hold exactly layer_k.
    for index, part in enumerate(parts):
        self.assertEqual(part, {f"layer_{index}": {"w": None}})
def test_merge_nested(self):
    """merge() unions disjoint entries that share the same module key."""
    a = {"layer": {"a": [1, 2, 3]}}
    b = {"layer": {"b": {object()}}}
    c = {"layer": {"c": {"a": "b"}}}

    merged = filtering.merge(a, b, c)

    # Each expected value is the very object from the corresponding input.
    expected = {
        "layer": {
            key: source["layer"][key]
            for key, source in (("a", a), ("b", b), ("c", c))
        }
    }
    self.assertEqual(expected, merged)
def wrapper(p1, p2, *args, **kwargs):
    """Merges `p1` and `p2`, then calls the enclosing `f` with the result.

    The merged structure becomes `f`'s first positional argument; all other
    arguments are forwarded unchanged. `f` is bound by the enclosing scope.
    """
    merged = filtering.merge(p1, p2)
    return f(merged, *args, **kwargs)