コード例 #1
0
    def test_shape_utils(self):
        """Check merge_shapes, are_shapes_compatible and are_shapes_equal.

        The expectations below treat a None shape as "unknown shape", and
        None, -1 and non-integer entries such as "unk" as unknown dimensions.
        """
        # merge_shapes: unknown shapes/dims are filled in from the other shape.
        self.assertIsNone(utils.merge_shapes(None, None))
        self.assertEqual(utils.merge_shapes([], None), [])
        self.assertEqual(utils.merge_shapes(None, [1, 2, 3]), [1, 2, 3])
        self.assertEqual(utils.merge_shapes([1, 3], [None, 3]), [1, 3])
        self.assertEqual(utils.merge_shapes([1, None, 3], (-1, 2, "unk")), [1, 2, 3])

        # are_shapes_compatible: unknown dims match anything, but rank and
        # known dims must agree.
        self.assertTrue(utils.are_shapes_compatible(None, []))
        self.assertTrue(utils.are_shapes_compatible([1, None, 3], (-1, 2, "unk")))
        self.assertFalse(utils.are_shapes_compatible([1, 2, 3], (2, 3)))
        self.assertFalse(utils.are_shapes_compatible([1, 2, 3], (4, 5, 6)))

        # are_shapes_equal: strict comparison; an unknown shape is not equal
        # to a scalar shape, while list vs tuple containers do not matter.
        self.assertTrue(utils.are_shapes_equal(None, None))
        self.assertFalse(utils.are_shapes_equal(None, []))
        self.assertTrue(utils.are_shapes_equal([1, 2, 3], (1, 2, 3)))
コード例 #2
0
 def _get_output_shape_dtype(self, cond_context):
     """Return the output shapes and dtypes shared by both branches of a cond.

     The i-th output of the true branch is paired with the i-th output of the
     false branch; each pair must have compatible shapes and equal dtypes.

     Args:
         cond_context: object whose ``true_branch_context.output`` and
             ``false_branch_context.output`` are the branch output names
             (presumably lists — TODO confirm against the caller).

     Returns:
         A ``(output_shapes, output_dtypes)`` pair of lists with one entry per
         branch output. A shape is None when neither branch knows it.

     Raises:
         RuntimeError: if the branches have a different number of outputs, or
             if a pair of outputs has incompatible shapes or different dtypes.
     """
     true_outputs = cond_context.true_branch_context.output
     false_outputs = cond_context.false_branch_context.output
     # zip() below would silently drop unmatched outputs, so fail loudly here.
     if len(true_outputs) != len(false_outputs):
         raise RuntimeError(
             "the number of outputs of true and false branches mismatch: {}, {}".format(
                 len(true_outputs),
                 len(false_outputs)
             )
         )

     output_shapes = []
     output_dtypes = []
     for true_output, false_output in zip(true_outputs, false_outputs):
         true_shape = self.g.get_shape(true_output)
         true_dtype = self.g.get_dtype(true_output)
         false_shape = self.g.get_shape(false_output)
         false_dtype = self.g.get_dtype(false_output)
         if not utils.are_shapes_compatible(true_shape, false_shape):
             raise RuntimeError(
                 "the shape of outputs {} and {} mismatch: {}, {}".format(
                     true_output,
                     false_output,
                     true_shape,
                     false_shape
                 )
             )
         if true_dtype != false_dtype:
             raise RuntimeError(
                 "the dtype of outputs {} and {} mismatch: {}, {}".format(
                     true_output,
                     false_output,
                     true_dtype,
                     false_dtype
                 )
             )
         # In TF the static shapes of the two branches can differ, e.g. [-1] in
         # branch A and [1] in branch B. merge_shapes would keep the more
         # specific [1], which does not hold whichever branch runs, so the merged
         # shape is relaxed with create_vague_shape_like (intended result: [-1]).
         merged_shape = utils.merge_shapes(true_shape, false_shape)
         if merged_shape is None:
             # Both shapes unknown: keep the output shape unknown instead of
             # handing None to create_vague_shape_like.
             output_shapes.append(None)
         else:
             output_shapes.append(utils.create_vague_shape_like(merged_shape))
         output_dtypes.append(true_dtype)
     return output_shapes, output_dtypes
コード例 #3
0
 def _compare_shape_for_op(self, op1, op2):
     """Align outputs of op2 to op1.

     Assert that each output of ``op2`` has a shape compatible with the output
     at the same position of ``op1``. Outputs are paired with ``zip``, so any
     surplus outputs on the longer op are not compared, and positions where
     ``op1`` has no output (None) are skipped.
     """
     for index, (out1, out2) in enumerate(zip(op1.outputs, op2.outputs)):
         # Check for a missing reference output before querying its shape.
         if out1 is None:
             continue
         expected_shape = get_tf_tensor_shape(out1)
         actual_shape = get_tf_tensor_shape(out2)
         self.assertTrue(
             utils.are_shapes_compatible(expected_shape, actual_shape),
             "output #{}: expected shape {}, got {}".format(
                 index, expected_shape, actual_shape))
コード例 #4
0
    def _run_test_case(self, graph, feed_dict):
        """Run model with onnxruntime and compare results' shape with internal shape inference.

        For every graph output, assert that the shape of the value produced by
        the backend is compatible with the shape recorded in ``graph``, and
        that its numpy dtype equals the recorded ONNX dtype mapped through
        ``utils.ONNX_TO_NUMPY_DTYPE``.
        """
        outputs = graph.outputs
        results = self.run_backend(graph, outputs, feed_dict)

        for actual, inferred in zip(results, outputs):
            actual_shape = actual.shape
            inferred_shape = graph.get_shape(inferred)
            # get_shape may return None for an unknown shape; tuple(None) would
            # raise TypeError, whereas are_shapes_compatible accepts None.
            if inferred_shape is not None:
                inferred_shape = tuple(inferred_shape)
            self.assertTrue(
                utils.are_shapes_compatible(actual_shape, inferred_shape),
                "shape mismatch for output {}: actual {}, inferred {}".format(
                    inferred, actual_shape, inferred_shape))

            actual_dtype = actual.dtype
            inferred_dtype = utils.ONNX_TO_NUMPY_DTYPE[graph.get_dtype(inferred)]
            self.assertEqual(actual_dtype, inferred_dtype)
コード例 #5
0
 def _get_output_shape_dtype(self, cond_context):
     """Return the output shapes and dtypes shared by both branches of a cond.

     The i-th output of the true branch is paired with the i-th output of the
     false branch; each pair must have compatible shapes and equal dtypes.

     Args:
         cond_context: object whose ``true_branch_context.output`` and
             ``false_branch_context.output`` are the branch output names
             (presumably lists — TODO confirm against the caller).

     Returns:
         A ``(output_shapes, output_dtypes)`` pair of lists with one entry per
         branch output. A shape is None when neither branch knows it.

     Raises:
         RuntimeError: if the branches have a different number of outputs, or
             if a pair of outputs has incompatible shapes or different dtypes.
     """
     true_outputs = cond_context.true_branch_context.output
     false_outputs = cond_context.false_branch_context.output
     # zip() below would silently drop unmatched outputs, so fail loudly here.
     if len(true_outputs) != len(false_outputs):
         raise RuntimeError(
             "the number of outputs of true and false branches mismatch: {}, {}".format(
                 len(true_outputs), len(false_outputs)))

     output_shapes = []
     output_dtypes = []
     for true_output, false_output in zip(true_outputs, false_outputs):
         true_shape = self.g.get_shape(true_output)
         true_dtype = self.g.get_dtype(true_output)
         false_shape = self.g.get_shape(false_output)
         false_dtype = self.g.get_dtype(false_output)
         if not utils.are_shapes_compatible(true_shape, false_shape):
             raise RuntimeError(
                 "the shape of outputs {} and {} mismatch: {}, {}".format(
                     true_output, false_output, true_shape, false_shape))
         if true_dtype != false_dtype:
             raise RuntimeError(
                 "the dtype of outputs {} and {} mismatch: {}, {}".format(
                     true_output, false_output, true_dtype, false_dtype))
         # The branches may report different static shapes (e.g. [-1] vs [1]);
         # merge_shapes keeps the more specific [1], which is not valid for
         # whichever branch actually runs, so relax it with
         # create_vague_shape_like (intended result: [-1]).
         merged_shape = utils.merge_shapes(true_shape, false_shape)
         if merged_shape is None:
             # Both shapes unknown: keep the output shape unknown.
             output_shapes.append(None)
         else:
             output_shapes.append(utils.create_vague_shape_like(merged_shape))
         output_dtypes.append(true_dtype)
     return output_shapes, output_dtypes