Ejemplo n.º 1
0
def graph_def_session(model, input_tensor_names, output_tensor_names,
                      **kwargs):
    """Build session with tf.compat.v1.GraphDef.

    The GraphDef is first imported unchanged into a new ``tf.Graph``. If that
    import raises, the function takes a fallback path:

    * the tensor names are passed through
      ``validate_and_inference_input_output``;
    * the GraphDef is passed through ``fix_ref_type_of_graph_def``;
    * ``strip_unused_nodes`` is applied, using the node names derived from
      the tensor names by ``tensor_to_node``;
    * the import is retried.

    Args:
        model (tf.compat.v1.GraphDef): tf.compat.v1.GraphDef object
        input_tensor_names (list of string): input_tensor_names of model
        output_tensor_names (list of string): output_tensor_names of model
        **kwargs: forwarded unchanged to ``graph_session``.

    Returns:
        The result of ``graph_session(graph, input_tensor_names,
        output_tensor_names, **kwargs)``, documented as:

        sess (tf.compat.v1.Session): tf.compat.v1.Session object
        input_tensor_names (list of string): validated input_tensor_names
        output_tensor_names (list of string): validated output_tensor_names

        Note: this function only validates the names itself on the fallback
        path. When the first import succeeds, the names are handed to
        ``graph_session`` exactly as given.

    Raises:
        Any exception from the fallback path (name validation, the repair
        helpers, or the retried import) propagates to the caller. Python
        implicitly chains it to the original import failure.
    """

    graph = tf.Graph()
    try:
        with graph.as_default():
            tf.import_graph_def(model, name='')
    # ``Exception`` rather than a bare ``except:`` so that KeyboardInterrupt /
    # SystemExit raised during the import are not mistaken for a malformed
    # GraphDef and silently turned into a repair-and-retry.
    except Exception:
        input_tensor_names, output_tensor_names = validate_and_inference_input_output(
            model, input_tensor_names, output_tensor_names)
        from lpot.adaptor.tf_utils.util import fix_ref_type_of_graph_def
        from lpot.adaptor.tf_utils.util import strip_unused_nodes
        model = fix_ref_type_of_graph_def(model)
        input_node_names = tensor_to_node(input_tensor_names)
        output_node_names = tensor_to_node(output_tensor_names)
        model = strip_unused_nodes(model, input_node_names, output_node_names)
        with graph.as_default():
            tf.import_graph_def(model, name='')

    return graph_session(graph, input_tensor_names, output_tensor_names,
                         **kwargs)
Ejemplo n.º 2
0
def graph_def_session(model, input_tensor_names, output_tensor_names,
                      **kwargs):
    """Build session with tf.compat.v1.GraphDef.

    The GraphDef is first imported unchanged into a new ``tf.Graph``. If that
    import raises, the function takes a fallback path:

    * the tensor names are passed through
      ``validate_and_inference_input_output``;
    * the GraphDef is passed through ``fix_ref_type_of_graph_def``;
    * ``strip_unused_nodes`` is applied, using the node names derived from
      the tensor names by ``tensor_to_node``;
    * the import is retried.

    Args:
        model (tf.compat.v1.GraphDef): tf.compat.v1.GraphDef object
        input_tensor_names (list of string): input_tensor_names of model
        output_tensor_names (list of string): output_tensor_names of model
        **kwargs: forwarded unchanged to ``graph_session``.

    Returns:
        The result of ``graph_session(graph, input_tensor_names,
        output_tensor_names, **kwargs)``.

        Note: this function only validates the names itself on the fallback
        path. When the first import succeeds, the names are handed to
        ``graph_session`` exactly as given.

    Raises:
        Any exception from the fallback path (name validation, the repair
        helpers, or the retried import) propagates to the caller. Python
        implicitly chains it to the original import failure.
    """
    graph = tf.Graph()
    try:
        with graph.as_default():
            tf.import_graph_def(model, name='')
    # ``Exception`` rather than a bare ``except:`` so that KeyboardInterrupt /
    # SystemExit raised during the import are not mistaken for a malformed
    # GraphDef and silently turned into a repair-and-retry.
    except Exception:
        input_tensor_names, output_tensor_names = validate_and_inference_input_output(
            model, input_tensor_names, output_tensor_names)
        from lpot.adaptor.tf_utils.util import fix_ref_type_of_graph_def
        from lpot.adaptor.tf_utils.util import strip_unused_nodes
        model = fix_ref_type_of_graph_def(model)
        input_node_names = tensor_to_node(input_tensor_names)
        output_node_names = tensor_to_node(output_tensor_names)
        model = strip_unused_nodes(model, input_node_names, output_node_names)
        with graph.as_default():
            tf.import_graph_def(model, name='')

    return graph_session(graph, input_tensor_names, output_tensor_names,
                         **kwargs)
Ejemplo n.º 3
0
 def do_transformation(self):
     """Repair ``self.model`` in place, then return a pruned copy of it.

     ``self.model`` is replaced by the output of
     ``fix_ref_type_of_graph_def``. The repaired model is then handed to
     ``strip_unused_nodes`` together with ``self.input_node_names`` and
     ``self.output_node_names``.

     Returns:
         Whatever ``strip_unused_nodes`` returns. That pruned result is only
         returned, not stored: afterwards ``self.model`` still holds the
         repaired but un-stripped model.
     """
     # Deferred import: these helpers are only needed when the
     # transformation actually runs.
     from lpot.adaptor.tf_utils.util import (
         fix_ref_type_of_graph_def,
         strip_unused_nodes,
     )

     self.model = fix_ref_type_of_graph_def(self.model)
     pruned = strip_unused_nodes(
         self.model,
         self.input_node_names,
         self.output_node_names,
     )
     return pruned