示例#1
0
    def test_shape_utils(self):
        """Exercise the shape helpers in ``utils`` on known and unknown dimensions.

        A whole shape of ``None`` means "unknown rank". Inside a shape, ``None``,
        ``-1`` and string dims such as ``"unk"`` act as unknown dimensions that a
        concrete value can fill in. Lists and tuples are interchangeable as
        shape containers.
        """
        # merge_shapes: an unknown-rank side yields the other side; known dims win.
        self.assertIsNone(utils.merge_shapes(None, None))
        self.assertEqual(utils.merge_shapes([], None), [])
        self.assertEqual(utils.merge_shapes(None, [1, 2, 3]), [1, 2, 3])
        self.assertEqual(utils.merge_shapes([1, 3], [None, 3]), [1, 3])
        self.assertEqual(utils.merge_shapes([1, None, 3], (-1, 2, "unk")), [1, 2, 3])

        # are_shapes_compatible: unknowns match anything; rank and known dims must agree.
        self.assertTrue(utils.are_shapes_compatible(None, []))
        self.assertTrue(utils.are_shapes_compatible([1, None, 3], (-1, 2, "unk")))
        self.assertFalse(utils.are_shapes_compatible([1, 2, 3], (2, 3)))
        self.assertFalse(utils.are_shapes_compatible([1, 2, 3], (4, 5, 6)))

        # are_shapes_equal: unknown rank (None) is not equal to the scalar shape [],
        # while the container type (list vs tuple) does not matter.
        self.assertTrue(utils.are_shapes_equal(None, None))
        self.assertFalse(utils.are_shapes_equal(None, []))
        self.assertTrue(utils.are_shapes_equal([1, 2, 3], (1, 2, 3)))
示例#2
0
 def _get_output_shape_dtype(self, cond_context):
     """Compute the output shape and dtype for each paired output of a cond.

     The i-th output of the true branch is paired with the i-th output of the
     false branch. Only their ranks and dtypes have to agree: in TF the two
     branches may produce different dimension values (e.g. [-1] vs [1], or
     [2] vs [3]). The resulting shape therefore keeps the rank but leaves every
     dimension unknown, so it is valid whichever branch actually runs.

     Args:
         cond_context: object exposing ``true_branch_context.output`` and
             ``false_branch_context.output`` sequences of output names.

     Returns:
         ``(output_shapes, output_dtypes)``: two lists with one entry per true
         branch output. A shape entry is ``None`` when neither branch's rank
         is known.

     Raises:
         RuntimeError: if a pair of outputs has different (known) ranks or
             different dtypes.
         IndexError: if the false branch has fewer outputs than the true branch.
     """
     true_outputs = cond_context.true_branch_context.output
     false_outputs = cond_context.false_branch_context.output
     output_shapes = []
     output_dtypes = []
     for i, true_output in enumerate(true_outputs):
         false_output = false_outputs[i]
         true_shape = self.g.get_shape(true_output)
         true_dtype = self.g.get_dtype(true_output)
         false_shape = self.g.get_shape(false_output)
         false_dtype = self.g.get_dtype(false_output)
         # Dimension values may legitimately differ between branches, so only
         # the rank is checked; a None shape (unknown rank) matches anything.
         if (true_shape is not None and false_shape is not None
                 and len(true_shape) != len(false_shape)):
             raise RuntimeError(
                 "the rank of outputs {} and {} mismatch: {}, {}".format(
                     true_output,
                     false_output,
                     true_shape,
                     false_shape
                 )
             )
         if true_dtype != false_dtype:
             raise RuntimeError(
                 "the dtype of outputs {} and {} mismatch: {}, {}".format(
                     true_output,
                     false_output,
                     true_dtype,
                     false_dtype
                 )
             )
         # Take the rank from whichever branch knows it. Merging concrete dims
         # from one branch would be wrong when the other branch runs, so the
         # shape is made vague (same rank, all dims -1). Unknown rank stays None
         # instead of being handed to create_vague_shape_like.
         known_shape = true_shape if true_shape is not None else false_shape
         if known_shape is None:
             output_shapes.append(None)
         else:
             output_shapes.append(utils.create_vague_shape_like(known_shape))
         output_dtypes.append(true_dtype)
     return output_shapes, output_dtypes
 def _get_output_shape_dtype(self, cond_context):
     """Compute the output shape and dtype for each paired output of a cond.

     The i-th output of the true branch is paired with the i-th output of the
     false branch. Only their ranks and dtypes have to agree: the two branches
     may produce different dimension values (e.g. [-1] vs [1], or [2] vs [3]).
     Picking one branch's concrete dims (as a shape merge would) is wrong when
     the other branch runs. The resulting shape therefore keeps the rank but
     leaves every dimension unknown.

     Args:
         cond_context: object exposing ``true_branch_context.output`` and
             ``false_branch_context.output`` sequences of output names.

     Returns:
         ``(output_shapes, output_dtypes)``: two lists with one entry per true
         branch output. A shape entry is ``None`` when neither branch's rank
         is known.

     Raises:
         RuntimeError: if a pair of outputs has different (known) ranks or
             different dtypes.
         IndexError: if the false branch has fewer outputs than the true branch.
     """
     true_outputs = cond_context.true_branch_context.output
     false_outputs = cond_context.false_branch_context.output
     output_shapes = []
     output_dtypes = []
     for i, true_output in enumerate(true_outputs):
         false_output = false_outputs[i]
         true_shape = self.g.get_shape(true_output)
         true_dtype = self.g.get_dtype(true_output)
         false_shape = self.g.get_shape(false_output)
         false_dtype = self.g.get_dtype(false_output)
         # Only the rank must agree; a None shape (unknown rank) matches anything.
         if (true_shape is not None and false_shape is not None
                 and len(true_shape) != len(false_shape)):
             raise RuntimeError(
                 "the rank of outputs {} and {} mismatch: {}, {}".format(
                     true_output, false_output, true_shape, false_shape))
         if true_dtype != false_dtype:
             raise RuntimeError(
                 "the dtype of outputs {} and {} mismatch: {}, {}".format(
                     true_output, false_output, true_dtype, false_dtype))
         # Rank from whichever branch knows it; all dims made unknown (-1) so the
         # shape is valid for both branches. Unknown rank stays None.
         known_shape = true_shape if true_shape is not None else false_shape
         if known_shape is None:
             output_shapes.append(None)
         else:
             output_shapes.append(utils.create_vague_shape_like(known_shape))
         output_dtypes.append(true_dtype)
     return output_shapes, output_dtypes