def test_shape_utils(self):
    """Sanity-check the shape helpers in utils: merge, compatibility, equality."""
    # merge_shapes: unknown dims (None / -1 / "unk") are filled from the other shape.
    merge_cases = [
        ((None, None), None),
        (([], None), []),
        ((None, [1, 2, 3]), [1, 2, 3]),
        (([1, 3], [None, 3]), [1, 3]),
        (([1, None, 3], (-1, 2, "unk")), [1, 2, 3]),
    ]
    for (lhs, rhs), expected in merge_cases:
        self.assertEqual(utils.merge_shapes(lhs, rhs), expected)

    # are_shapes_compatible: unknown dims match anything; rank and known dims must agree.
    self.assertTrue(utils.are_shapes_compatible(None, []))
    self.assertTrue(utils.are_shapes_compatible([1, None, 3], (-1, 2, "unk")))
    self.assertFalse(utils.are_shapes_compatible([1, 2, 3], (2, 3)))
    self.assertFalse(utils.are_shapes_compatible([1, 2, 3], (4, 5, 6)))

    # are_shapes_equal: strict element-wise equality, but None == None; list vs tuple is fine.
    self.assertTrue(utils.are_shapes_equal(None, None))
    self.assertFalse(utils.are_shapes_equal(None, []))
    self.assertTrue(utils.are_shapes_equal([1, 2, 3], (1, 2, 3)))
def _get_output_shape_dtype(self, cond_context):
    """Collect per-output shapes and dtypes for a tf cond.

    For each paired output of the true/false branches, verify that the two
    branches agree on dtype and have compatible shapes, then record the
    merged (and deliberately vague) shape plus the shared dtype.

    Raises:
        RuntimeError: if a pair of branch outputs disagrees on shape or dtype.
    """
    shapes = []
    dtypes = []
    true_outputs = cond_context.true_branch_context.output
    false_outputs = cond_context.false_branch_context.output
    # Index-based loop (not zip) so a length mismatch still raises IndexError.
    for idx in range(len(true_outputs)):
        t_out = true_outputs[idx]
        f_out = false_outputs[idx]
        t_shape = self.g.get_shape(t_out)
        f_shape = self.g.get_shape(f_out)
        t_dtype = self.g.get_dtype(t_out)
        f_dtype = self.g.get_dtype(f_out)
        if not utils.are_shapes_compatible(t_shape, f_shape):
            raise RuntimeError(
                "the shape of outputs {} and {} mismatch: {}, {}".format(
                    t_out, f_out, t_shape, f_shape
                )
            )
        if t_dtype != f_dtype:
            raise RuntimeError(
                "the dtype of outputs {} and {} mismatch: {}, {}".format(
                    t_out, f_out, t_dtype, f_dtype
                )
            )
        # in tf, the shape of different branched can be different,
        # for example output shape of branch A can be [-1] while branch B can be [1].
        # Under this case, we should set output shape to be [-1]
        merged = utils.merge_shapes(t_shape, f_shape)
        shapes.append(utils.create_vague_shape_like(merged))
        dtypes.append(t_dtype)
    return shapes, dtypes
def _compare_shape_for_op(self, op1, op2):
    """Assert that each output shape of op2 is compatible with the matching output of op1.

    Fix: the original guarded on ``out1 is not None`` — but ``out1`` comes
    from ``zip(op1.outputs, ...)`` and is never None, so the check was dead
    code. The evident intent is to skip comparison when the *expected shape*
    is unknown, so the guard now tests ``expected_shape``.
    """
    for out1, out2 in zip(op1.outputs, op2.outputs):
        expected_shape = get_tf_tensor_shape(out1)
        if expected_shape is not None:
            actual_shape = get_tf_tensor_shape(out2)
            self.assertTrue(
                utils.are_shapes_compatible(expected_shape, actual_shape))
def _run_test_case(self, graph, feed_dict):
    """Run model with onnxruntime and compare results' shape with internal shape inference."""
    output_names = graph.outputs
    backend_results = self.run_backend(graph, output_names, feed_dict)
    for result, name in zip(backend_results, output_names):
        # Shape inferred by the graph must be compatible with what the backend produced.
        # NOTE(review): assumes graph.get_shape(name) is non-None here — tuple() would
        # raise otherwise; confirm against callers.
        inferred_shape = tuple(graph.get_shape(name))
        self.assertTrue(utils.are_shapes_compatible(result.shape, inferred_shape))
        # Dtype must match exactly after mapping ONNX dtype -> numpy dtype.
        expected_dtype = utils.ONNX_TO_NUMPY_DTYPE[graph.get_dtype(name)]
        self.assertEqual(result.dtype, expected_dtype)
def _get_output_shape_dtype(self, cond_context):
    """Return ([shape, ...], [dtype, ...]) for the outputs of a tf cond.

    Each true-branch output is paired with the false-branch output at the
    same index; the two must have compatible shapes and equal dtypes, and
    the recorded shape is their merge.

    Raises:
        RuntimeError: if a pair of branch outputs disagrees on shape or dtype.
    """
    shapes = []
    dtypes = []
    true_outputs = cond_context.true_branch_context.output
    false_outputs = cond_context.false_branch_context.output
    # Index-based loop (not zip) so a length mismatch still raises IndexError.
    for idx in range(len(true_outputs)):
        t_out = true_outputs[idx]
        f_out = false_outputs[idx]
        t_shape = self.g.get_shape(t_out)
        f_shape = self.g.get_shape(f_out)
        t_dtype = self.g.get_dtype(t_out)
        f_dtype = self.g.get_dtype(f_out)
        if not utils.are_shapes_compatible(t_shape, f_shape):
            raise RuntimeError(
                "the shape of outputs {} and {} mismatch: {}, {}".format(
                    t_out, f_out, t_shape, f_shape))
        if t_dtype != f_dtype:
            raise RuntimeError(
                "the dtype of outputs {} and {} mismatch: {}, {}".format(
                    t_out, f_out, t_dtype, f_dtype))
        shapes.append(utils.merge_shapes(t_shape, f_shape))
        dtypes.append(t_dtype)
    return shapes, dtypes