示例#1
0
 def test_different_metadata(self):
   # Both roots are dicts that differ only in their keys ({1: ...} vs
   # {3: ...}); the test expects this to surface as a metadata mismatch
   # reported at the tree root.
   lhs, rhs = {1: 2}, {3: 4}
   (err,) = prefix_errors(lhs, rhs)
   pattern = ("pytree structure error: different pytree metadata "
              "at key path\n"
              "    in_axes tree root")
   with self.assertRaisesRegex(ValueError, pattern):
     raise err('in_axes')
示例#2
0
 def test_different_metadata_nested(self):
   # Same key mismatch as the root case, but wrapped in a one-element list,
   # so the error is expected at key path in_axes[0].
   lhs, rhs = [{1: 2}], [{3: 4}]
   (err,) = prefix_errors(lhs, rhs)
   # Square brackets are escaped because the pattern is a regular expression.
   pattern = ("pytree structure error: different pytree metadata "
              "at key path\n"
              r"    in_axes\[0\]")
   with self.assertRaisesRegex(ValueError, pattern):
     raise err('in_axes')
示例#3
0
 def test_different_num_children_nested(self):
   # The inner lists have lengths 1 and 2; the mismatch lives one level down,
   # at key path in_axes[0].
   lhs, rhs = [[1]], [[2, 3]]
   (err,) = prefix_errors(lhs, rhs)
   pattern = ("pytree structure error: different numbers of pytree children "
              "at key path\n"
              r"    in_axes\[0\]")
   with self.assertRaisesRegex(ValueError, pattern):
     raise err('in_axes')
示例#4
0
 def test_different_num_children(self):
   # Root tuples of length 1 vs length 2: exactly one error, at the root.
   lhs, rhs = (1,), (2, 3)
   (err,) = prefix_errors(lhs, rhs)
   pattern = ("pytree structure error: different numbers of pytree children "
              "at key path\n"
              "    in_axes tree root")
   with self.assertRaisesRegex(ValueError, pattern):
     raise err('in_axes')
示例#5
0
 def test_different_types_multiple(self):
   # Each tuple child of the prefix is paired with a list child, so two
   # "different types" errors are expected, for in_axes[0] and in_axes[1].
   # The expected text uses the same "at key path\n    <name><path>" layout
   # asserted by the other different-types tests (the previous pattern,
   # "different types at in_axes[0]", predates that layout and cannot match).
   e1, e2 = prefix_errors(((1,), (2,)), ([3], [4]))
   expected = ("pytree structure error: different types at key path\n"
               r"    in_axes\[0\]")
   with self.assertRaisesRegex(ValueError, expected):
     raise e1('in_axes')
   expected = ("pytree structure error: different types at key path\n"
               r"    in_axes\[1\]")
   with self.assertRaisesRegex(ValueError, expected):
     raise e2('in_axes')
示例#6
0
 def test_different_metadata_multiple(self):
   # Both list elements are dicts whose keys differ from their counterparts,
   # so two metadata errors are expected, for in_axes[0] and in_axes[1].
   # The expected text uses the same "at key path\n    <name><path>" layout
   # asserted by the other metadata tests (the previous pattern,
   # "different pytree metadata at in_axes[0]", cannot match that layout).
   e1, e2 = prefix_errors([{1: 2}, {3: 4}], [{3: 4}, {5: 6}])
   expected = ("pytree structure error: different pytree metadata "
               "at key path\n"
               r"    in_axes\[0\]")
   with self.assertRaisesRegex(ValueError, expected):
     raise e1('in_axes')
   expected = ("pytree structure error: different pytree metadata "
               "at key path\n"
               r"    in_axes\[1\]")
   with self.assertRaisesRegex(ValueError, expected):
     raise e2('in_axes')
示例#7
0
 def test_different_num_children_multiple(self):
   # Every inner list has a length mismatch (1 vs 2), so exactly two errors
   # are expected; the two-target unpacking enforces that count.
   first, second = prefix_errors([[1], [2]], [[3, 4], [5, 6]])
   cases = (
       (first, ("pytree structure error: different numbers of pytree children "
                "at key path\n"
                r"    in_axes\[0\]")),
       (second, ("pytree structure error: different numbers of pytree children "
                 "at key path\n"
                 r"    in_axes\[1\]")),
   )
   # Checked in order: the first error must match before the second is tried.
   for err, pattern in cases:
     with self.assertRaisesRegex(ValueError, pattern):
       raise err('in_axes')
示例#8
0
 def test_no_errors(self):
   # (1, 2) is a valid prefix of ((11, 12, 13), 2): the leaf 1 covers a whole
   # subtree. Unpacking into an empty target raises if any error is returned.
   prefix, full = (1, 2), ((11, 12, 13), 2)
   [] = prefix_errors(prefix, full)
示例#9
0
 def test_fallback_keypath(self):
   # `Special` is defined elsewhere in the file; judging by the expected text,
   # its children have no named keys, so the path falls back to a flat index.
   # Child 1 is a list in the prefix but a leaf (4) in the full tree.
   (err,) = prefix_errors(Special(1, [2]), Special(3, 4))
   pattern = ("pytree structure error: different types at key path\n"
              r"    in_axes\[<flat index 1>\]")
   with self.assertRaisesRegex(ValueError, pattern):
     raise err('in_axes')
示例#10
0
 def test_different_types_nested(self):
   # Only the first child disagrees (tuple vs list); the second is a tuple on
   # both sides, so a single error at in_axes[0] is expected.
   lhs = ((1,), (2,))
   rhs = ([3], (4,))
   (err,) = prefix_errors(lhs, rhs)
   pattern = ("pytree structure error: different types at key path\n"
              r"    in_axes\[0\]")
   with self.assertRaisesRegex(ValueError, pattern):
     raise err('in_axes')
示例#11
0
 def test_different_types(self):
   # Same leaves, but a tuple root vs a list root: one error at the root.
   lhs, rhs = (1, 2), [1, 2]
   (err,) = prefix_errors(lhs, rhs)
   pattern = ("pytree structure error: different types at key path\n"
              "    in_axes tree root")
   with self.assertRaisesRegex(ValueError, pattern):
     raise err('in_axes')