예제 #1
0
def load_graph(path):
    """
    Loads a TensorFlow frozen model.

    Args:
        path (Union[str, tf.Graph, tf.compat.v1.GraphDef]):
                A path to the frozen model, or a frozen TensorFlow graph or graphdef.

    Returns:
        tf.Graph: The TensorFlow graph. A ``tf.Graph`` input is returned unchanged;
        otherwise the GraphDef is imported into a fresh graph with an empty name scope.

    Raises:
        Whatever ``G_LOGGER.critical`` raises, if the file cannot be parsed as a
        GraphDef or if ``path`` is of an unsupported type.
        NOTE(review): this assumes ``G_LOGGER.critical`` raises; confirm.
    """
    if isinstance(path, tf.Graph):
        # Already a graph: nothing to import.
        return path

    if isinstance(path, str):
        # Import the exception class explicitly. A bare ``import google`` does not
        # guarantee that the ``google.protobuf.message`` submodule has been loaded.
        from google.protobuf.message import DecodeError

        graphdef = tf.compat.v1.GraphDef()
        try:
            graphdef.ParseFromString(util.load_file(path, description="GraphDef"))
        except DecodeError:
            G_LOGGER.backtrace()
            G_LOGGER.critical(
                "Could not import TensorFlow GraphDef from: {:}. Is this a valid TensorFlow model?".format(path)
            )
    elif isinstance(path, tf.compat.v1.GraphDef):
        graphdef = path
    else:
        # Without this branch, ``graphdef`` would be unbound below and the caller
        # would see an opaque UnboundLocalError instead of a descriptive message.
        G_LOGGER.critical(
            "Unsupported type for `path`: {:}. Expected a path string, tf.Graph, or tf.compat.v1.GraphDef.".format(
                type(path).__name__
            )
        )

    # Import into a fresh graph with name="" so node names are not prefixed.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graphdef, name="")
        return graph
예제 #2
0
def fail_unavailable(what):
    """
    Report that a feature is unavailable on the installed TensorRT version.

    Logs a backtrace, then emits a critical-level message naming the feature and
    ``trt.__version__``. Whether this raises depends on ``G_LOGGER.critical``
    (presumably it does; confirm against the logger implementation).

    Args:
        what: Description of the unavailable feature. It is interpolated with
            ``str.format``.
    """
    G_LOGGER.backtrace()
    installed_version = trt.__version__
    G_LOGGER.critical("{:} is not available on TensorRT version {:}.".format(what, installed_version))