def __init__(self):
    """Capture the test-run configuration from the host and the environment.

    Reads the platform, the installed TensorFlow version, and the
    ``TF2ONNX_TEST_OPSET`` / ``TF2ONNX_TEST_TARGET`` / ``TF2ONNX_TEST_BACKEND``
    environment variables. When a variable is unset, the fallback is taken
    from ``constants`` or is the literal ``"onnxruntime"`` backend.
    """
    env = os.environ
    self.platform = sys.platform
    self.tf_version = utils.get_tf_version()

    opset_text = env.get("TF2ONNX_TEST_OPSET", constants.PREFERRED_OPSET)
    self.opset = int(opset_text)

    # Targets are a comma-separated list; the default is DEFAULT_TARGET joined the same way.
    default_targets = ",".join(constants.DEFAULT_TARGET)
    self.target = env.get("TF2ONNX_TEST_TARGET", default_targets).split(',')

    self.backend = env.get("TF2ONNX_TEST_BACKEND", "onnxruntime")
    # NOTE(review): _get_backend_version presumably reads self.backend, so keep this after it.
    self.backend_version = self._get_backend_version()

    self.log_level = logging.WARNING
    self.temp_dir = utils.get_temp_directory()
def reload_tf_graph(tf_graph):
    """Invoke TensorFlow's C++ shape inference by round-tripping the GraphDef.

    Args:
        tf_graph: the TensorFlow graph to reload.

    Returns:
        A new ``tf.Graph`` built by importing ``tf_graph.as_graph_def(add_shapes=True)``
        with an empty name prefix.
    """
    if utils.get_tf_version() < LooseVersion("1.8"):
        # Nothing is switched on here; on old TF only a debug hint is emitted.
        logger.debug(
            "On TF < 1.8, graph is constructed by python API, "
            "which doesn't invoke shape inference, please set "
            "TF_C_API_GRAPH_CONSTRUCTION=1 to enable it"
        )

    exported_def = tf_graph.as_graph_def(add_shapes=True)
    reloaded = tf.Graph()
    with reloaded.as_default():
        tf.import_graph_def(exported_def, name="")
    return reloaded
def infer_shape(tf_graph, shape_override):
    """Infer shapes for a TF graph, applying ``shape_override`` first.

    Args:
        tf_graph: the TensorFlow graph to process.
        shape_override: mapping of tensor name -> shape. It is ignored when falsy.
            Overridden tensors have ``set_shape`` called on ``tf_graph`` itself
            before the graph is reloaded.

    Returns:
        The graph after shape inference. This may be a different graph object
        than the one passed in.
    """
    if shape_override:
        logger.info("Apply shape override:")
        for tensor_name, new_shape in shape_override.items():
            logger.info("\tSet %s shape to %s", tensor_name, new_shape)
            tf_graph.get_tensor_by_name(tensor_name).set_shape(new_shape)
        # Re-import once after all overrides so TF shape inference sees them.
        tf_graph = reload_tf_graph(tf_graph)

    graph = infer_shape_for_graph(tf_graph)

    unresolved = check_shape_for_tf_graph(graph)
    if not unresolved:
        return graph

    if utils.get_tf_version() > LooseVersion("1.5.0"):
        for op_name, outputs in unresolved.items():
            logger.warning("Cannot infer shape for %s: %s", op_name, ",".join(outputs))
    # Some outputs are still shapeless: run the legacy inference pass as a fallback.
    return infer_shape_for_graph_legacy(graph)