def test_shape_utils(self):
    """Check the shape helpers: merge_shapes, are_shapes_compatible, are_shapes_equal."""
    # merge_shapes(lhs, rhs) -> expected merged shape
    merge_cases = [
        (None, None, None),
        ([], None, []),
        (None, [1, 2, 3], [1, 2, 3]),
        ([1, 3], [None, 3], [1, 3]),
        ([1, None, 3], (-1, 2, "unk"), [1, 2, 3]),
    ]
    for lhs, rhs, merged in merge_cases:
        self.assertEqual(utils.merge_shapes(lhs, rhs), merged)

    # are_shapes_compatible(lhs, rhs) -> expected truthiness
    compatibility_cases = [
        (None, [], True),
        ([1, None, 3], (-1, 2, "unk"), True),
        ([1, 2, 3], (2, 3), False),
        ([1, 2, 3], (4, 5, 6), False),
    ]
    for lhs, rhs, compatible in compatibility_cases:
        check = self.assertTrue if compatible else self.assertFalse
        check(utils.are_shapes_compatible(lhs, rhs))

    # are_shapes_equal(lhs, rhs) -> expected truthiness
    equality_cases = [
        (None, None, True),
        (None, [], False),
        ([1, 2, 3], (1, 2, 3), True),
    ]
    for lhs, rhs, equal in equality_cases:
        check = self.assertTrue if equal else self.assertFalse
        check(utils.are_shapes_equal(lhs, rhs))
def update_node_shape_dtype(self, node, override=False):
    """Infer output shapes and dtypes of ``node`` and record them on the graph.

    Nodes for which ``is_const()`` or ``is_graph_input()`` is true are skipped,
    as are nodes outside the ONNX domain. By default, existing (e.g. TF-provided)
    shapes and dtypes are respected: when an inferred value disagrees with an
    existing one, the existing value is kept.

    Args:
        node: the node whose outputs should be annotated.
        override: if true, an inferred value that conflicts with an existing one
            replaces it, and a warning is logged.
    """
    if node.is_const() or node.is_graph_input():
        return
    # NOTE: only ONNX-domain nodes are supported for now.
    if not utils.is_onnx_domain(node.domain):
        return

    logger.debug("Infer shape and dtype for [%s]", node.name)

    # Shape inference for some ops needs the *values* of their inputs (e.g. Reshape
    # needs its "shape" input), so constant inputs are passed along as initializers.
    initializers = []
    for idx, producer in enumerate(node.inputs):
        if not producer:
            if logger.isEnabledFor(logging.VERBOSE):
                logger.warning(
                    "[%s] infer a inexistent node: [%s], please check the code", node.name, node.input[idx]
                )
            continue
        if not producer.is_const():
            continue
        const_tensor = helper.get_attribute_value(producer.get_attr("value"))
        const_tensor.name = producer.output[0]
        initializers.append(const_tensor)

    input_shapes = list(map(self.get_shape, node.input))
    input_dtypes = list(map(self.get_dtype, node.input))
    shapes, dtypes = infer_onnx_shape_dtype(node, self._opset, input_shapes, input_dtypes, initializers)
    if not (shapes and dtypes):
        return

    for out_name, inferred_shape, inferred_dtype in zip(node.output, shapes, dtypes):
        # --- dtype: skip UNDEFINED, otherwise reconcile with any existing dtype ---
        if inferred_dtype == TensorProto.UNDEFINED:
            logger.debug("Inferred dtype for [%s, type: %s] is UNDEFINED, SKIP", node.name, node.type)
        else:
            current_dtype = self.get_dtype(out_name)
            chosen_dtype = inferred_dtype
            if current_dtype is not None and current_dtype != inferred_dtype:
                if override:
                    logger.warning("Override dtype of %s from %s to %s", out_name, current_dtype, inferred_dtype)
                else:
                    chosen_dtype = current_dtype
            self.set_dtype(out_name, chosen_dtype)
            logger.debug("Set dtype of [%s] to %s", out_name, chosen_dtype)

        # --- shape: skip None, otherwise reconcile with any existing shape ---
        if inferred_shape is None:
            logger.debug("Inferred shape for [%s, type: %s] is None, SKIP", node.name, node.type)
            continue
        current_shape = self.get_shape(out_name)
        chosen_shape = inferred_shape
        if current_shape is not None and not utils.are_shapes_equal(current_shape, inferred_shape):
            if override:
                logger.warning("Override shape of %s from %s to %s", out_name, current_shape, inferred_shape)
            else:
                chosen_shape = current_shape
        self.set_shape(out_name, chosen_shape)
        logger.debug("Set shape of [%s] to %s", out_name, chosen_shape)