Beispiel #1
0
 def __init__(self):
     """Populate the test configuration from the host and TF2ONNX_TEST_* variables.

     Each setting is read from the environment, falling back to the
     project defaults (``constants.PREFERRED_OPSET``,
     ``constants.DEFAULT_TARGET``, ``"onnxruntime"``) when unset.
     """
     env = os.environ
     self.platform = sys.platform
     self.tf_version = utils.get_tf_version()
     opset_setting = env.get("TF2ONNX_TEST_OPSET", constants.PREFERRED_OPSET)
     self.opset = int(opset_setting)
     # Targets are a comma-separated list; the default is DEFAULT_TARGET joined the same way.
     default_targets = ",".join(constants.DEFAULT_TARGET)
     self.target = env.get("TF2ONNX_TEST_TARGET", default_targets).split(',')
     self.backend = env.get("TF2ONNX_TEST_BACKEND", "onnxruntime")
     # Called only after self.backend is assigned -- presumably the helper reads it; verify.
     self.backend_version = self._get_backend_version()
     self.log_level = logging.WARNING
     self.temp_dir = utils.get_temp_directory()
def reload_tf_graph(tf_graph):
    """Invoke tensorflow cpp shape inference by reloading graph_def.

    The graph is serialized with ``add_shapes=True`` and re-imported into a
    brand-new ``tf.Graph``. ``name=""`` is passed to ``import_graph_def`` so
    no name-scope prefix is added and tensor names stay the same as in the
    input graph.

    Args:
        tf_graph: the ``tf.Graph`` to reload.

    Returns:
        A new ``tf.Graph`` holding the re-imported nodes. The input graph is
        not modified.
    """
    # This branch only emits a hint for old TF versions; it does not change
    # how the graph is built or reloaded below.
    if utils.get_tf_version() < LooseVersion("1.8"):
        logger.debug(
            "On TF < 1.8, graph is constructed by python API, "
            "which doesn't invoke shape inference, please set "
            "TF_C_API_GRAPH_CONSTRUCTION=1 to enable it"
        )

    graph_def = tf_graph.as_graph_def(add_shapes=True)
    with tf.Graph().as_default() as inferred_graph:
        tf.import_graph_def(graph_def, name="")
    return inferred_graph
def infer_shape(tf_graph, shape_override):
    """Run shape inference on ``tf_graph``, applying ``shape_override`` first.

    Args:
        tf_graph: the TensorFlow graph to analyse.
        shape_override: mapping of tensor name -> shape to force onto the
            graph before inference. A falsy value skips the override step.

    Returns:
        The graph after shape inference. If some op outputs still have no
        shape afterwards, the legacy inference pass is applied as well.
    """
    if shape_override:
        logger.info("Apply shape override:")
        for tensor_name, forced_shape in shape_override.items():
            logger.info("\tSet %s shape to %s", tensor_name, forced_shape)
            tf_graph.get_tensor_by_name(tensor_name).set_shape(forced_shape)
        # Reload so the overridden shapes go through TF's own shape inference.
        tf_graph = reload_tf_graph(tf_graph)

    tf_graph = infer_shape_for_graph(tf_graph)

    unresolved = check_shape_for_tf_graph(tf_graph)
    if not unresolved:
        return tf_graph

    # Warnings are only emitted on TF > 1.5.0; the legacy pass runs either way.
    if utils.get_tf_version() > LooseVersion("1.5.0"):
        for op_name, output_names in unresolved.items():
            logger.warning("Cannot infer shape for %s: %s", op_name,
                           ",".join(output_names))
    return infer_shape_for_graph_legacy(tf_graph)