Esempio n. 1
0
    def test_check_duplicates(self):
        """merge(check_duplicates=True) raises when one path holds conflicting leaves."""
        err = "Duplicate array found"
        f32_pair = {"a": {"b": jnp.array([0.0, 0.1], dtype=jnp.float32)}}
        bf16_pair = {"a": {"b": jnp.array([0.0, 0.1], dtype=jnp.bfloat16)}}
        f32_triple = {"a": {"b": jnp.array([0.0, 0.1, 0.2], dtype=jnp.float32)}}
        non_array = {"a": {"b": "foo"}}

        # Each case: (subtest name, structures handed to merge, message regex).
        cases = (
            ("dtype_mismatch", (f32_pair, bf16_pair),
             fr"{err}.*f32\[2\] vs bf16\[2\]"),
            ("shape_mismatch", (f32_pair, f32_triple),
             fr"{err}.*f32\[2\] vs f32\[3\]"),
            # With three inputs the expected message is the f32 vs bf16 one.
            ("multiple_mismatch", (f32_pair, bf16_pair, f32_triple),
             fr"{err}.*f32\[2\] vs bf16\[2\]"),
            ("object_mismatch", (f32_pair, non_array),
             fr"{err}.*f32\[2\] vs 'foo'"),
        )
        for subtest_name, structures, pattern in cases:
            with self.subTest(subtest_name):
                with self.assertRaisesRegex(ValueError, pattern):
                    filtering.merge(*structures, check_duplicates=True)
Esempio n. 2
0
    def assert_output_type(self, out_cls):
        """Checks that filtering operations return mappings of type `out_cls`.

        Every operation is run on both plain `dict` and
        `data_structures.FlatMap` inputs. The same `out_cls` is expected in
        both cases.

        Args:
          out_cls: Mapping type expected for each result and for every
            mapping nested inside it.
        """
        from collections.abc import Mapping  # Local: keeps the helper self-contained.

        def assert_type_recursive(s):
            # Check this level, then descend into nested mappings (the
            # per-module inner dicts). Non-mapping leaves such as `None` stop
            # the recursion.
            self.assertIs(type(s), out_cls)
            for value in s.values():
                if isinstance(value, Mapping):
                    assert_type_recursive(value)

        for in_cls in (dict, data_structures.FlatMap):
            with self.subTest(str(in_cls)):
                structure_a = in_cls({"m1": in_cls({"w": None})})
                structure_b = in_cls({"m2": in_cls({"w": None})})
                # Module names "0".."4" so predicates can use int(module_name).
                structure_c = in_cls(
                    {f"{i}": in_cls({"w": None})
                     for i in range(5)})
                assert_type_recursive(
                    filtering.filter(lambda m, n, v: True, structure_a))
                assert_type_recursive(
                    filtering.map(lambda m, n, v: v, structure_a))
                assert_type_recursive(filtering.merge(structure_a,
                                                      structure_b))
                parts = filtering.partition(lambda m, n, v: int(m) > 1,
                                            structure_c)
                for part in parts:
                    assert_type_recursive(part)
                # Module "i" is routed to partition i.
                parts = filtering.partition_n(lambda m, n, v: int(m),
                                              structure_c, 5)
                for part in parts:
                    assert_type_recursive(part)
Esempio n. 3
0
 def test_merge_different_mappings(self):
     """merge() accepts heterogeneous mapping types and unions their contents."""
     default_dict = collections.defaultdict(dict)
     default_dict["foo"]["bar"] = 1
     plain_dict = {"foo": {"baz": 2}}
     read_only = types.MappingProxyType({"foo": {"bat": 3}})
     merged = filtering.merge(default_dict, plain_dict, read_only)
     expected = {"foo": {"bar": 1, "baz": 2, "bat": 3}}
     self.assertEqual(merged, expected)
Esempio n. 4
0
 def test_partition_n_merge_isomorphism(self, n):
     """Merging the outputs of partition_n reconstructs the original input.

     Args:
       n: Number of modules in the input, and the number of partitions.
     """
     def fn(module_name, name, value):
         # Derive the partition from the module name ("layer_<i>") instead of
         # a shared call counter. Routing then does not depend on how many
         # times, or in what order, partition_n invokes `fn`. The parameter
         # names also avoid shadowing this test's `n`.
         del name, value  # Unused.
         return int(module_name.split("_")[-1])

     input_structure = {f"layer_{i}": {"w": None} for i in range(n)}
     structures = filtering.partition_n(fn, input_structure, n)
     merged_structure = filtering.merge(*structures)
     self.assertEqual(merged_structure, input_structure)
Esempio n. 5
0
 def test_partition_n(self, n):
     """partition_n places each module in the partition that `fn` selects.

     Args:
       n: Number of modules in the input, and the number of partitions.
     """
     def fn(module_name, name, value):
         # Route "layer_<i>" to partition i, derived from the name rather than
         # a call counter. The per-partition expectation below then holds
         # regardless of iteration order (e.g. "layer_10" vs "layer_2").
         # The parameter names avoid shadowing this test's `n`.
         del name, value  # Unused.
         return int(module_name.split("_")[-1])

     structure = {f"layer_{i}": {"w": None} for i in range(n)}
     structures = filtering.partition_n(fn, structure, n)
     self.assertLen(structures, n)
     self.assertEqual(filtering.merge(*structures), structure)
     for index, substructure in enumerate(structures):
         expected = {f"layer_{index}": {"w": None}}
         self.assertEqual(substructure, expected)
Esempio n. 6
0
 def test_merge_nested(self):
     """merge() combines disjoint entries under the same module name.

     The leaf values are a list, a set and a dict. The expected result
     references the exact same leaf objects that went in.
     """
     a = {"layer": {"a": [1, 2, 3]}}
     b = {"layer": {"b": {object()}}}
     c = {"layer": {"c": {"a": "b"}}}
     actual = filtering.merge(a, b, c)
     expected = {
         "layer": {
             "a": a["layer"]["a"],
             "b": b["layer"]["b"],
             "c": c["layer"]["c"]
         }
     }
     self.assertEqual(expected, actual)
Esempio n. 7
0
 def wrapper(p1, p2, *args, **kwargs):
     """Merges `p1` and `p2` and passes the result to `f` as its first argument.

     All remaining positional and keyword arguments are forwarded unchanged.
     """
     merged = filtering.merge(p1, p2)
     return f(merged, *args, **kwargs)