Example #1
0
def checkpoint_session(model, input_tensor_names, output_tensor_names,
                       **kwargs):
    """Restore a TF1 checkpoint directory into a new ``tf.compat.v1.Session``.

    The first ``*.meta`` file (in sorted filename order) found directly in
    ``model`` is imported with ``import_meta_graph`` (devices cleared), all
    global variables are initialized, and then the checkpoint sharing the
    same prefix is restored into the session.

    Args:
        model (str): directory containing the ``.meta`` graph file and the
            matching checkpoint files.
        input_tensor_names (list of str): candidate input tensor names. They
            are kept only if ``validate_graph_node`` accepts the node names
            derived from them; otherwise they are replaced by
            ``get_input_node_names`` applied to the restored graph.
        output_tensor_names (list of str): output tensor names; must be
            non-empty.
        **kwargs: accepted for interface compatibility; not used.

    Returns:
        tuple: ``(sess, input_tensor_names, output_tensor_names)``. ``sess``
        is an open session that the caller owns and must close.

    Raises:
        AssertionError: if ``output_tensor_names`` is None or empty (only
            while assertions are enabled).
        FileNotFoundError: if ``model`` contains no ``.meta`` file.
    """
    assert output_tensor_names is not None and len(output_tensor_names) > 0, \
        'outputs should not be None of checkpoint....'

    # os.listdir returns entries in arbitrary order; sort so that the chosen
    # checkpoint is deterministic when several .meta files are present.
    meta_files = sorted(name for name in os.listdir(model)
                        if name.endswith(".meta"))
    if not meta_files:
        raise FileNotFoundError(
            "No .meta file found in checkpoint directory: {}".format(model))
    ckpt_prefix = os.path.splitext(meta_files[0])[0]

    config = tf.compat.v1.ConfigProto()
    config.use_per_session_threads = 1
    config.inter_op_parallelism_threads = 1
    graph = tf.Graph()
    sess = tf.compat.v1.Session(graph=graph, config=config)
    try:
        with graph.as_default():
            saver = tf.compat.v1.train.import_meta_graph(
                os.path.join(model, ckpt_prefix + '.meta'), clear_devices=True)

            # The initializer runs before restore(), so checkpointed values
            # overwrite the freshly initialized ones.
            sess.run(tf.compat.v1.global_variables_initializer())
            saver.restore(sess, os.path.join(model, ckpt_prefix))

        from lpot.adaptor.tf_utils.util import get_input_node_names
        # Serialize the graph once and reuse it for validation and fallback.
        graph_def = sess.graph.as_graph_def()
        if not validate_graph_node(graph_def,
                                   tensor_to_node(input_tensor_names)):
            input_tensor_names = get_input_node_names(graph_def)
    except Exception:
        # The session is never handed to the caller on failure, so release
        # its resources here before propagating the error.
        sess.close()
        raise

    return sess, input_tensor_names, output_tensor_names
Example #2
0
def validate_and_inference_input_output(graph_def, input_tensor_names,
                                        output_tensor_names):
    """Check caller-supplied input/output names against ``graph_def``.

    For each side, the given names are kept when ``validate_graph_node``
    accepts the node names produced by ``tensor_to_node``; otherwise they are
    replaced by the names that ``get_input_node_names`` /
    ``get_output_node_names`` infer from ``graph_def``.

    Args:
        graph_def: the graph definition to validate against.
        input_tensor_names: candidate input tensor names.
        output_tensor_names: candidate output tensor names.

    Returns:
        tuple: ``(input_names, output_names)`` after validation/inference.
    """
    inputs = input_tensor_names
    outputs = output_tensor_names

    from lpot.adaptor.tf_utils.util import get_input_node_names
    if not validate_graph_node(graph_def, tensor_to_node(inputs)):
        # Supplied inputs do not match the graph: derive them instead.
        inputs = get_input_node_names(graph_def)

    from lpot.adaptor.tf_utils.util import get_output_node_names
    if not validate_graph_node(graph_def, tensor_to_node(outputs)):
        # Supplied outputs do not match the graph: derive them instead.
        outputs = get_output_node_names(graph_def)

    return inputs, outputs
Example #3
0
def validate_and_inference_input_output(graph_def, input_tensor_names,
                                        output_tensor_names):
    """Validate the input and output tensor names of ``graph_def``, inferring
    replacements from the graph when validation fails.

    Args:
        graph_def (tf.compat.v1.GraphDef): graph to check the names against.
        input_tensor_names (list of string): candidate input tensor names.
        output_tensor_names (list of string): candidate output tensor names.

    Returns:
        input_tensor_names (list of string): the given input names if they
            validate, otherwise the result of ``get_input_node_names``.
        output_tensor_names (list of string): the given output names if they
            validate, otherwise the result of ``get_output_node_names``.
    """

    def _keep_or_infer(names, infer_fn):
        # Return ``names`` when they map onto nodes that
        # validate_graph_node accepts; otherwise ask ``infer_fn``.
        if validate_graph_node(graph_def, tensor_to_node(names)):
            return names
        return infer_fn(graph_def)

    from lpot.adaptor.tf_utils.util import get_input_node_names
    validated_inputs = _keep_or_infer(input_tensor_names, get_input_node_names)

    from lpot.adaptor.tf_utils.util import get_output_node_names
    validated_outputs = _keep_or_infer(output_tensor_names,
                                       get_output_node_names)

    return validated_inputs, validated_outputs