def test_different_metadata(self):
    """Root dicts with different keys ({1: ...} vs {3: ...}) yield one
    'different pytree metadata' error located at the tree root."""
    # Single-element unpacking also asserts that exactly one error is produced.
    (err,) = prefix_errors({1: 2}, {3: 4})
    pattern = ("pytree structure error: different pytree metadata at key path\n"
               " in_axes tree root")
    with self.assertRaisesRegex(ValueError, pattern):
        # The error object is a factory: calling it with the argument name
        # builds the exception, which is then raised for the regex check.
        raise err('in_axes')
def test_different_metadata_nested(self):
    """A dict-key mismatch one level down is reported at ``in_axes[0]``."""
    # Single-element unpacking also asserts that exactly one error is produced.
    (err,) = prefix_errors([{1: 2}], [{3: 4}])
    # Brackets are escaped because the pattern is a regular expression.
    pattern = ("pytree structure error: different pytree metadata at key path\n"
               " in_axes\\[0\\]")
    with self.assertRaisesRegex(ValueError, pattern):
        raise err('in_axes')
def test_different_num_children_nested(self):
    """Inner lists of length 1 vs 2 give a child-count error at ``in_axes[0]``."""
    # Single-element unpacking also asserts that exactly one error is produced.
    (err,) = prefix_errors([[1]], [[2, 3]])
    # Brackets are escaped because the pattern is a regular expression.
    pattern = ("pytree structure error: different numbers of pytree children"
               " at key path\n in_axes\\[0\\]")
    with self.assertRaisesRegex(ValueError, pattern):
        raise err('in_axes')
def test_different_num_children(self):
    """Root tuples of length 1 vs 2 give a child-count error at the tree root."""
    # Single-element unpacking also asserts that exactly one error is produced.
    (err,) = prefix_errors((1,), (2, 3))
    pattern = ("pytree structure error: different numbers of pytree children"
               " at key path\n in_axes tree root")
    with self.assertRaisesRegex(ValueError, pattern):
        raise err('in_axes')
def test_different_types_multiple(self):
    """Two tuple-vs-list mismatches produce two errors, one per child index.

    The expected messages use the same ``at key path\\n <name><keys>`` layout
    asserted by every other prefix-error test here. The previous
    ``different types at in_axes[i]`` form was inconsistent with them.
    """
    # Two-element unpacking also asserts that exactly two errors are produced.
    e1, e2 = prefix_errors(((1, ), (2, )), ([3], [4]))
    expected = ("pytree structure error: different types at key path\n"
                r" in_axes\[0\]")
    with self.assertRaisesRegex(ValueError, expected):
        raise e1('in_axes')
    expected = ("pytree structure error: different types at key path\n"
                r" in_axes\[1\]")
    with self.assertRaisesRegex(ValueError, expected):
        raise e2('in_axes')
def test_different_metadata_multiple(self):
    """Two dict-key mismatches produce two metadata errors, one per index.

    The expected messages use the same ``at key path\\n <name><keys>`` layout
    as the single- and nested-metadata tests. The previous
    ``different pytree metadata at in_axes[i]`` form was inconsistent with them.
    """
    # Two-element unpacking also asserts that exactly two errors are produced.
    e1, e2 = prefix_errors([{1: 2}, {3: 4}], [{3: 4}, {5: 6}])
    expected = ("pytree structure error: different pytree metadata "
                "at key path\n"
                r" in_axes\[0\]")
    with self.assertRaisesRegex(ValueError, expected):
        raise e1('in_axes')
    expected = ("pytree structure error: different pytree metadata "
                "at key path\n"
                r" in_axes\[1\]")
    with self.assertRaisesRegex(ValueError, expected):
        raise e2('in_axes')
def test_different_num_children_multiple(self):
    """Both inner lists differ in length, so errors appear at indices 0 and 1."""
    # Two-element unpacking also asserts that exactly two errors are produced.
    first, second = prefix_errors([[1], [2]], [[3, 4], [5, 6]])
    for index, err in enumerate((first, second)):
        pattern = ("pytree structure error: different numbers of pytree children"
                   " at key path\n in_axes\\[%d\\]" % index)
        with self.assertRaisesRegex(ValueError, pattern):
            raise err('in_axes')
def test_no_errors(self):
    """A leaf may stand in for a whole subtree, so a valid prefix yields no errors."""
    errors = prefix_errors((1, 2), ((11, 12, 13), 2))
    # Unpacking into an empty target raises ValueError if anything was returned.
    [] = errors
def test_fallback_keypath(self):
    """The mismatch location is rendered as ``<flat index 1>``.

    ``Special`` is defined elsewhere in this file. It is presumably a custom
    pytree node whose children carry no key-path entries, so the error falls
    back to a flat child index (confirm against its definition).
    """
    # Single-element unpacking also asserts that exactly one error is produced.
    (err,) = prefix_errors(Special(1, [2]), Special(3, 4))
    pattern = ("pytree structure error: different types at key path\n"
               " in_axes\\[<flat index 1>\\]")
    with self.assertRaisesRegex(ValueError, pattern):
        raise err('in_axes')
def test_different_types_nested(self):
    """Only the first child differs (tuple vs list), so one error at ``in_axes[0]``."""
    # Single-element unpacking also asserts that exactly one error is produced.
    (err,) = prefix_errors(((1,), (2,)), ([3], (4,)))
    # Brackets are escaped because the pattern is a regular expression.
    pattern = ("pytree structure error: different types at key path\n"
               " in_axes\\[0\\]")
    with self.assertRaisesRegex(ValueError, pattern):
        raise err('in_axes')
def test_different_types(self):
    """A root tuple vs a root list is a type mismatch reported at the tree root."""
    # Single-element unpacking also asserts that exactly one error is produced.
    (err,) = prefix_errors((1, 2), [1, 2])
    pattern = ("pytree structure error: different types at key path\n"
               " in_axes tree root")
    with self.assertRaisesRegex(ValueError, pattern):
        raise err('in_axes')