Esempio n. 1
0
def variable_device(device, name):
    """Resolve the device string on which a variable's ops should be placed.

    When `device` is callable (a device function), it is called with a
    synthetic NodeDef of op type 'Variable' whose name is the current
    variable scope name joined with `name` by '/'; the callable's result is
    then used as the device. A `None` device, whether passed in or returned
    by the callable, is normalized to the empty string.

    Args:
      device: A device string, a device function, or None.
      name: The variable's name inside the current variable scope.

    Returns:
      The device string, '' when no device is specified.
    """
    if not callable(device):
        return '' if device is None else device
    scoped_name = '/'.join([tf.get_variable_scope().name, name])
    resolved = device(graph_pb2.NodeDef(name=scoped_name, op='Variable'))
    return '' if resolved is None else resolved
Esempio n. 2
0
def convert_variables_to_constants(sess, input_graph_def, output_node_names):
    """Replaces all the variables in a graph with constants of the same values.

    If you have a trained graph containing Variable ops, it can be convenient to
    convert them all to Const ops holding the same values. This makes it possible
    to describe the network fully with a single GraphDef file, and allows the
    removal of a lot of ops related to loading and saving the variables.

    Variables are discovered as the first input of every `Assign` op in
    `input_graph_def`. Each distinct variable is fetched exactly once from
    `sess`.

    Args:
      sess: Active TensorFlow session containing the variables.
      input_graph_def: GraphDef object holding the network.
      output_node_names: List of name strings for the result nodes of the graph.

    Returns:
      GraphDef containing a simplified version of the original.
    """
    # A variable may be the target of more than one Assign op. Deduplicate the
    # names, keeping first-seen order, so each variable is fetched once and
    # the "Frozen" count reflects distinct variables.
    variable_dict_names = []
    seen_names = set()
    for node in input_graph_def.node:
        if node.op == "Assign":
            variable_name = node.input[0]
            if variable_name not in seen_names:
                seen_names.add(variable_name)
                variable_dict_names.append(variable_name)

    # Only call sess.run when there is something to fetch.
    if variable_dict_names:
        returned_variables = sess.run(
            [name + ":0" for name in variable_dict_names])
    else:
        returned_variables = []
    found_variables = dict(zip(variable_dict_names, returned_variables))
    logging.info("Frozen %d variables.", len(returned_variables))

    # This graph only includes the nodes needed to evaluate the output nodes, and
    # removes unneeded nodes like those involved in saving and assignment.
    inference_graph = extract_sub_graph(input_graph_def, output_node_names)

    output_graph_def = graph_pb2.GraphDef()
    how_many_converted = 0
    for input_node in inference_graph.node:
        output_node = graph_pb2.NodeDef()
        if input_node.name in found_variables:
            # Replace the variable node with a Const of the same name that
            # carries the variable's dtype attr and its fetched value.
            output_node.op = "Const"
            output_node.name = input_node.name
            dtype = input_node.attr["dtype"]
            data = found_variables[input_node.name]
            output_node.attr["dtype"].CopyFrom(dtype)
            output_node.attr["value"].CopyFrom(
                attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
                    data, dtype=dtype.type, shape=data.shape)))
            how_many_converted += 1
        else:
            output_node.CopyFrom(input_node)
        output_graph_def.node.extend([output_node])
    logging.info("Converted %d variables to const ops.", how_many_converted)
    return output_graph_def
Esempio n. 3
0
def _operator_to_node(shapes, op):
    """Translate an operator description into a TensorFlow-style NodeDef.

    The node copies the operator's name, inputs and type. Its device is
    derived from `op.device_option` via `_tf_device`. Output shapes from
    `shapes` are attached in output order, stopping at the first output with
    no known shape. Each entry of `op.arg` is then converted into a node
    attribute by `_set_tf_attr`.

    Args:
      shapes: Mapping from blob name to shape, or a falsy value to skip shapes.
      op: The operator to convert; must have a non-empty `name`.

    Returns:
      The populated NodeDef.
    """
    assert op.name, op
    # Kept for backwards compatibility: use tf.NodeDef when `tf` exposes a
    # __version__, otherwise build the proto from graph_pb2 directly.
    node_factory = tf.NodeDef if hasattr(tf, '__version__') else graph_pb2.NodeDef
    node = node_factory()
    node.name = op.name
    node.input.extend(op.input)
    node.op = op.type
    node.device = _tf_device(op.device_option)

    if shapes:
        for out_name in op.output:
            # Shapes are attached positionally, so the first output without a
            # known shape ends the list instead of being skipped.
            if out_name not in shapes:
                break
            _add_tf_shape(node.attr, shapes[out_name])

    for argument in op.arg:
        _set_tf_attr(node.attr, argument)
    return node
Esempio n. 4
0
def _blob_to_node(producing_ops, shapes, name):
    """Build a NodeDef representing a blob.

    A blob with no producers becomes a 'Placeholder' node. Otherwise it
    becomes a 'Blob' node whose inputs are "<op name>:<index>" for every
    producer. When all producers share one device_option, that device is
    attached to the node. A known shape for `name` is added as an attribute.

    Args:
      producing_ops: Mapping from blob name to a list of (op, index) pairs
        (presumably the producing operator and its output slot).
      shapes: Mapping from blob name to shape, or a falsy value.
      name: Non-empty blob name.

    Returns:
      The populated NodeDef.
    """
    assert name
    # Check for existence of __version__ for backwards compatibility.
    node = tf.NodeDef() if hasattr(tf, '__version__') else graph_pb2.NodeDef()
    node.name = name
    producers = producing_ops.get(name, [])

    if not producers:
        node.op = 'Placeholder'
    else:
        node.op = 'Blob'
        node.input.extend(['%s:%d' % (op, idx) if False else '%s:%d' % (op.name, idx)
                           for op, idx in producers])
        first_device = producers[0][0].device_option
        # Only pin a device when every producer agrees on it.
        if all(producer[0].device_option == first_device for producer in producers):
            node.device = _tf_device(first_device)

    if shapes and name in shapes:
        _add_tf_shape(node.attr, shapes[name])
    return node
Esempio n. 5
0
 def _StripNode(self, nd):
     """Return a reduced copy of NodeDef `nd`.

     Only the name, op and inputs are kept, plus the device when `nd` has a
     non-empty one. All other fields (such as attrs) are left out.
     """
     stripped = graph_pb2.NodeDef(name=nd.name, op=nd.op, input=nd.input)
     if not nd.device:
         return stripped
     stripped.device = nd.device
     return stripped
Esempio n. 6
0
def convert_variables_to_constants(sess, input_graph_def, output_node_names):
    """Replaces all the variables in a graph with constants of the same values.

    If you have a trained graph containing Variable ops, it can be convenient to
    convert them all to Const ops holding the same values. This makes it possible
    to describe the network fully with a single GraphDef file, and allows the
    removal of a lot of ops related to loading and saving the variables.

    Variables are discovered as the first input of every `Assign` op. Each
    distinct variable is fetched exactly once. Progress messages are written
    to stdout while searching and evaluating.

    Args:
      sess: Active TensorFlow session containing the variables.
      input_graph_def: GraphDef object holding the network.
      output_node_names: List of name strings for the result nodes of the graph.

    Returns:
      GraphDef containing a simplified version of the original.
    """
    print('call convert_variables')
    num_nodes = len(input_graph_def.node)
    variable_name_list = []
    seen_names = set()
    print('search nodes...')
    for i, node in enumerate(input_graph_def.node):
        if node.op == "Assign":
            variable_name = node.input[0]
            # A variable may be the target of several Assign ops; keep it
            # once (first-seen order) so it is fetched and counted once.
            if variable_name not in seen_names:
                seen_names.add(variable_name)
                variable_name_list.append(variable_name)
            sys.stdout.write(
                "\r%s" % "node: {0}/{1}".format(i + 1, num_nodes))
            sys.stdout.flush()
    print('')
    print('{0} nodes founded'.format(len(variable_name_list)))
    print('evaluate nodes..')
    # Skip sess.run entirely when there is nothing to fetch.
    if variable_name_list:
        found_variables_list = sess.run([v + ":0" for v in variable_name_list])
    else:
        found_variables_list = []
    print('insert values..')
    found_variables = {}
    num_variables = len(variable_name_list)
    for i, (v, value) in enumerate(zip(variable_name_list, found_variables_list)):
        found_variables[v] = value
        sys.stdout.write(
            "\r%s" % "node: {0}/{1}".format(i + 1, num_variables))
        sys.stdout.flush()
    print('')

    # This graph only includes the nodes needed to evaluate the output nodes, and
    # removes unneeded nodes like those involved in saving and assignment.
    inference_graph = graph_util.extract_sub_graph(input_graph_def,
                                                   output_node_names)

    output_graph_def = graph_pb2.GraphDef()
    how_many_converted = 0
    for input_node in inference_graph.node:
        output_node = graph_pb2.NodeDef()
        if input_node.name in found_variables:
            # Replace the variable node with a Const of the same name that
            # carries the variable's dtype attr and its fetched value.
            output_node.op = "Const"
            output_node.name = input_node.name
            dtype = input_node.attr["dtype"]
            data = found_variables[input_node.name]
            output_node.attr["dtype"].CopyFrom(dtype)
            output_node.attr["value"].CopyFrom(
                attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
                    data, dtype=dtype.type, shape=data.shape)))
            how_many_converted += 1
        else:
            output_node.CopyFrom(input_node)
        output_graph_def.node.extend([output_node])
    print("Converted %d variables to const ops." % how_many_converted)
    return output_graph_def
Esempio n. 7
0
def remove_training_nodes(input_graph):
    """Prunes out nodes that aren't needed for inference.

    There are nodes like Identity and CheckNumerics that are only useful
    during training, and can be removed in graphs that will be used for
    nothing but inference. Here we identify and remove them, returning an
    equivalent graph. To be specific, CheckNumerics nodes are always removed, and
    Identity nodes that aren't involved in control edges are spliced out so that
    their input and outputs are directly connected.

    Splicing follows chains of removable Identity nodes back to the first
    node that is kept. A control input ("^name") that pointed at a spliced
    Identity becomes a control input on that kept producer, so it stays a
    control input.

    Args:
      input_graph: Model to analyze and prune.

    Returns:
      A GraphDef containing the nodes with the unnecessary ones removed.
    """
    types_to_remove = {"CheckNumerics"}
    types_to_splice = {"Identity"}

    def _strip_control_prefix(full_input_name):
        # Control inputs are written "^node_name"; data inputs have no prefix.
        if full_input_name.startswith("^"):
            return full_input_name[1:]
        return full_input_name

    input_nodes = input_graph.node
    names_to_remove = {node.name for node in input_nodes
                       if node.op in types_to_remove}

    nodes_after_removal = []
    for node in input_nodes:
        if node.name in names_to_remove:
            continue
        new_node = graph_pb2.NodeDef()
        new_node.CopyFrom(node)
        del new_node.input[:]
        # NOTE(review): data inputs coming from a removed CheckNumerics node
        # are dropped, not rewired to its input (unchanged behavior) --
        # confirm this is intended for graphs that consume its output.
        new_node.input.extend([
            full_input_name for full_input_name in node.input
            if _strip_control_prefix(full_input_name) not in names_to_remove
        ])
        nodes_after_removal.append(new_node)

    names_to_splice = {}
    for node in nodes_after_removal:
        if node.op not in types_to_splice or not node.input:
            continue
        # We don't want to remove nodes that have control edge inputs, because
        # they might be involved in subtle dependency issues that removing them
        # will jeopardize.
        if any(input_name.startswith("^") for input_name in node.input):
            continue
        names_to_splice[node.name] = node.input[0]

    def _resolve_input(full_input_name):
        """Rewrite one input so it skips every spliced Identity in its chain."""
        is_control = full_input_name.startswith("^")
        input_name = _strip_control_prefix(full_input_name)
        if input_name not in names_to_splice:
            return full_input_name
        visited = set()
        # Follow Identity -> Identity chains; `visited` guards against a
        # malformed cyclic graph looping forever.
        while input_name in names_to_splice and input_name not in visited:
            visited.add(input_name)
            input_name = names_to_splice[input_name]
        if is_control:
            # Control inputs reference a node, never an output port.
            return "^" + input_name.split(":")[0]
        return input_name

    nodes_after_splicing = []
    for node in nodes_after_removal:
        if node.name in names_to_splice:
            continue
        new_node = graph_pb2.NodeDef()
        new_node.CopyFrom(node)
        del new_node.input[:]
        new_node.input.extend([_resolve_input(name) for name in node.input])
        nodes_after_splicing.append(new_node)

    output_graph = graph_pb2.GraphDef()
    output_graph.node.extend(nodes_after_splicing)
    return output_graph