Ejemplo n.º 1
0
    def test_shape_utils(self):
        """Check the shape helpers in ``utils``: merging, compatibility and equality.

        Each table row is ``(lhs, rhs, expected)``; rows are checked in order and
        the first mismatch fails the test.
        """
        # merge_shapes: unknown dims (None / negative / non-int) are filled from the other side.
        merge_cases = [
            (None, None, None),
            ([], None, []),
            (None, [1, 2, 3], [1, 2, 3]),
            ([1, 3], [None, 3], [1, 3]),
            ([1, None, 3], (-1, 2, "unk"), [1, 2, 3]),
        ]
        for lhs, rhs, merged in merge_cases:
            self.assertEqual(utils.merge_shapes(lhs, rhs), merged)

        # are_shapes_compatible: rank mismatch or conflicting known dims -> incompatible.
        compatibility_cases = [
            (None, [], True),
            ([1, None, 3], (-1, 2, "unk"), True),
            ([1, 2, 3], (2, 3), False),
            ([1, 2, 3], (4, 5, 6), False),
        ]
        for lhs, rhs, compatible in compatibility_cases:
            verify = self.assertTrue if compatible else self.assertFalse
            verify(utils.are_shapes_compatible(lhs, rhs))

        # are_shapes_equal: None (unknown rank) only equals None; list vs tuple is irrelevant.
        equality_cases = [
            (None, None, True),
            (None, [], False),
            ([1, 2, 3], (1, 2, 3), True),
        ]
        for lhs, rhs, equal in equality_cases:
            verify = self.assertTrue if equal else self.assertFalse
            verify(utils.are_shapes_equal(lhs, rhs))
Ejemplo n.º 2
0
    def update_node_shape_dtype(self, node, override=False):
        """Best-effort inference of the output shapes and dtypes of ``node``.

        Shapes/dtypes already recorded on the graph (e.g. taken from TF) are kept
        by default. With ``override=True`` the inferred values replace conflicting
        existing ones, and a warning is logged for each replacement.

        Const nodes, graph inputs and nodes outside the ONNX domain are skipped.

        Args:
            node: node whose outputs should be annotated.
            override: when True, inferred values win over conflicting existing ones.
        """
        if node.is_const() or node.is_graph_input():
            return
        # Only ONNX-domain nodes are supported for now.
        if not utils.is_onnx_domain(node.domain):
            return

        logger.debug("Infer shape and dtype for [%s]", node.name)

        # Some ops (e.g. Reshape needs its "shape" input value) can only infer an
        # output shape from actual input values, so constant inputs are handed to
        # the inference routine as initializers named after their output tensor.
        initializers = []
        for idx, producer in enumerate(node.inputs):
            if producer:
                if producer.is_const():
                    # NOTE(review): this presumably renames the tensor stored on the
                    # const node's attribute in place — confirm that is intended.
                    const_tensor = helper.get_attribute_value(producer.get_attr("value"))
                    const_tensor.name = producer.output[0]
                    initializers.append(const_tensor)
            # NOTE(review): VERBOSE is not a stdlib logging level; presumably it is
            # registered by the project's logging wrapper.
            elif logger.isEnabledFor(logging.VERBOSE):
                logger.warning(
                    "[%s] infer a inexistent node: [%s], please check the code",
                    node.name, node.input[idx]
                )

        input_shapes = list(map(self.get_shape, node.input))
        input_dtypes = list(map(self.get_dtype, node.input))

        shapes, dtypes = infer_onnx_shape_dtype(node, self._opset, input_shapes, input_dtypes, initializers)
        if not (shapes and dtypes):
            return

        for out_name, inferred_shape, inferred_dtype in zip(node.output, shapes, dtypes):
            # --- dtype: keep the existing value on conflict unless overriding.
            if inferred_dtype == TensorProto.UNDEFINED:
                logger.debug("Inferred dtype for [%s, type: %s] is UNDEFINED, SKIP", node.name, node.type)
            else:
                current_dtype = self.get_dtype(out_name)
                dtype_conflict = current_dtype is not None and current_dtype != inferred_dtype
                if dtype_conflict and override:
                    logger.warning("Override dtype of %s from %s to %s", out_name, current_dtype, inferred_dtype)
                elif dtype_conflict:
                    inferred_dtype = current_dtype
                self.set_dtype(out_name, inferred_dtype)
                logger.debug("Set dtype of [%s] to %s", out_name, inferred_dtype)

            # --- shape: same policy as dtype; a None shape means "not inferred".
            if inferred_shape is None:
                logger.debug("Inferred shape for [%s, type: %s] is None, SKIP", node.name, node.type)
                continue
            current_shape = self.get_shape(out_name)
            if current_shape is not None and not utils.are_shapes_equal(current_shape, inferred_shape):
                if override:
                    logger.warning("Override shape of %s from %s to %s", out_name, current_shape, inferred_shape)
                else:
                    inferred_shape = current_shape
            self.set_shape(out_name, inferred_shape)
            logger.debug("Set shape of [%s] to %s", out_name, inferred_shape)