Example #1
def setNodeConstValue(gdef, node, value):
    output_node = tf.NodeDef()
    output_node.name = node.name
    output_node.op = node.op
    dtype = node.attr["dtype"].type
    output_node.attr["dtype"].type = dtype
    output_node.attr["value"].CopyFrom(
        tf.AttrValue(
            tensor=tf.contrib.util.make_tensor_proto(value, dtype=dtype)))
    node.CopyFrom(output_node)
    return node
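A minimal usage sketch (TF 1.x, given the tf.contrib call above); the file path and node name are hypothetical:

import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.GFile('model.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

for node in graph_def.node:
    # Overwrite a Const node's value in place, e.g. to pin dropout off.
    if node.op == 'Const' and node.name == 'keep_prob':
        setNodeConstValue(graph_def, node, 1.0)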
Example #2
  def quantize_weights(self, input_graph, quantization_mode):
    """Quantize float Const ops.

    There are two modes of operations, both replace float Const ops with
    quantized values.
    1. If quantization_mode is "weights_rounded", this function replaces float
    Const ops with quantized float Const ops - the same as the original op, but
    with each float value mapped to the center of one of 1<<FLAGS.bitdepth
    buckets. This does not change the raw model size, but compression
    algorithms such as zip (as used for compressing apks) or bzip2 achieve a
    very good compression ratio on the result.
    2. For the other quantization modes ("MIN_COMBINED" or "MIN_FIRST"), float
    Const ops are quantized and each is replaced by a tuple of four ops that
    perform the dequantization at runtime:
    * an eight-bit Const (bucket indices, same shape as the original float
      Const op)
    * two float Const ops (the min and max values of the original float
      Const op)
    * a Dequantize op to convert the eight-bit consts back to float tensors.
    The quantization mode matters because we see accuracy problems when
    quantizing weights in some situations, depending on the algorithm used;
    the underlying cause is unfortunately not yet understood.

    Args:
      input_graph: A GraphDef of the model containing float Const ops.
      quantization_mode: How to quantize and dequantize the values.

    Returns:
      A GraphDef of the converted graph.

    Raises:
      ValueError: If quantization_mode is unsupported.
    """
    output_graph = tf.GraphDef()
    for input_node in input_graph.node:
      should_quantize = False
      if input_node.op == "Const":
        dtype = tf.as_dtype(input_node.attr["dtype"].type)
        if dtype == tf.float32:
          should_quantize = True
      if should_quantize:
        if quantization_mode == "weights_rounded":
          output_graph.node.extend(quantize_weight_rounded(input_node))
        elif quantization_mode in ("MIN_COMBINED", "MIN_FIRST"):
          output_graph.node.extend(quantize_weight_eightbit(input_node,
                                                            quantization_mode))
        else:
          raise ValueError("Unsupported quantization mode %s." %
                           quantization_mode)
      else:
        output_node = tf.NodeDef()
        output_node.CopyFrom(input_node)
        output_graph.node.extend([output_node])
    return output_graph
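To make mode 1 concrete, here is a standalone NumPy sketch of the bucket-center mapping described above; it illustrates the idea and is not the library's quantize_weight_rounded:

import numpy as np

def round_to_bucket_centers(weights, bitdepth=8):
    # Map each float to the center of one of (1 << bitdepth) equal-width
    # buckets spanning [min, max] of the tensor.
    num_buckets = 1 << bitdepth
    lo, hi = float(weights.min()), float(weights.max())
    if hi == lo:
        return weights.copy()
    width = (hi - lo) / num_buckets
    # Bucket index per weight, clipped so the max value lands in the last bucket.
    idx = np.minimum((weights - lo) // width, num_buckets - 1)
    return (lo + (idx + 0.5) * width).astype(weights.dtype)

Many distinct floats collapse onto the same few centers, so the serialized bytes become repetitive and compress well, while the tensor stays float and the raw model size is unchanged.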
Example #3
def rename_nodes(graph_def: tf.GraphDef, new_names: Dict[str, str]) -> tf.GraphDef:
    """Rename items in the graph to new ones defined in new_names

    Parameters
    ----------
    graph_def : tf.GraphDef
        Graph Definition
    new_names : Dict[str, str]
        Mapping old name -> new name

    Returns
    -------
    tf.GraphDef
        A copy of the input GraphDef with renamed nodes

    Raises
    ------
    TensorsNotFoundError
        If new_names refers to a node not found in graph_def
    """
    # Create copy of each node with a new name
    nodes = []
    for node in graph_def.node:
        new_node = tf.NodeDef()
        new_node.CopyFrom(node)
        nodes.append(new_node)
        match = re.match(r"^(?:cond(?:_\d+)?/)?(.+?)(?:_\d+)?$", node.name)
        if match and match.groups()[0] in new_names:
            new_name = new_names[match.groups()[0]]
            new_node.name = new_name
            LOGGER.info(f"Node renamed: {node.name} -> {new_node.name}")

    # Check that all new names were used
    if not set(new_names.values()) <= set(node.name for node in nodes):
        missing = set(new_names.values()) - set(node.name for node in nodes)
        raise TensorsNotFoundError(missing)

    # Update node references (inputs and location) to renamed nodes
    for node in nodes:
        for idx, name in enumerate(node.input):
            node.input[idx] = new_names[name] if name in new_names else name
        if "_class" in node.attr:
            attr = node.attr["_class"]
            for idx, item in enumerate(attr.list.s):
                loc_match = re.match(r"^loc:@(.+)$", item.decode())
                if loc_match and loc_match.groups()[0] in new_names:
                    new_name = new_names[loc_match.groups()[0]]
                    attr.list.s[idx] = f"loc:@{new_name}".encode()

    # Create Graph with renamed nodes
    new_graph = tf.GraphDef()
    new_graph.node.extend(nodes)
    return new_graph
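A usage sketch with hypothetical node names; note the regex also matches names behind an optional cond*/ prefix or with a trailing _N suffix, so cond/logits_1 would be renamed by a 'logits' entry:

renamed = rename_nodes(graph_def, {'logits': 'output', 'images': 'input'})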
Example #4
 def to_node_def(self):
   ret = tf.NodeDef()
   ret.name = self.name
   ret.op = self.op_name
   for input_tensor in self.inputs:
     ret.input.append(input_tensor.name)
   for control_input_node in self.control_inputs:
     ret.input.append("^" + control_input_node.name)
   ret.device = self.device
   for (attr_name, attr_value) in self._attributes:
     # Funky syntax for setting a field of a union in a protobuf
     ret.attr[attr_name].CopyFrom(_python_type_to_attr_value(attr_value))
   return ret
Example #5
 def init_node(self, node_name):
     """
     according to the node_name, find the node from the graph
     :param node_name:
     :return:
     """
     new_node = tf.NodeDef()
     for node in self.graph_pb.node:
         if node.name == node_name:
             new_node = node
             return new_node
     if new_node.name != node_name:
         print("There isn't this node in graph")
         return
Example #6
def make_placeholders(graph_def: tf.GraphDef, names: List[str]) -> tf.GraphDef:
    """Create placeholders for names and remove other placeholders

    Parameters
    ----------
    graph_def : tf.GraphDef
        Graph definition
    names : List[str]
        Names of placeholders to keep / create for this graph

    Returns
    -------
    tf.GraphDef
        A copy of the input GraphDef with new placeholders

    Raises
    ------
    ValueError
        If names refers to a node that is not present
    """
    # Create copy of each node and change to Placeholder if in names
    nodes = []
    for node in graph_def.node:
        if node.name not in names and node.op == "Placeholder":
            LOGGER.info(f"Removing placeholder {node.name}")
            continue
        new_node = tf.NodeDef()
        if node.name in names and node.op != "Placeholder":
            LOGGER.info(f"Creating placeholder {node.name}")
            new_node.name = node.name
            new_node.op = "Placeholder"
            new_node.attr["shape"].CopyFrom(
                tf.AttrValue(shape=node.attr["_output_shapes"].list.shape[0]))
            new_node.attr["dtype"].CopyFrom(node.attr["T"])
        else:
            new_node.CopyFrom(node)
        nodes.append(new_node)

    # Check that all expected placeholders have been found
    if not set(names) <= set(node.name for node in nodes):
        raise ValueError(
            f"Missing placeholders: {set(names) - set(node.name for node in nodes)}"
        )

    # Create graph with the new placeholders
    new_graph = tf.GraphDef()
    new_graph.node.extend(nodes)
    return new_graph
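For example, to keep (or create) a single input placeholder and drop all others; the name is hypothetical, and conversion relies on nodes carrying the _output_shapes and T attributes read above:

graph_def = make_placeholders(graph_def, names=['input'])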
Example #7
def create_new_node(input_weight_a, input_weight_b, old_node):
    """Concatenate two weight matrices and wrap the result in a Const node."""
    import numpy as np

    merge_weight = np.concatenate((input_weight_a, input_weight_b), axis=1)
    tensor_proto = tf.make_tensor_proto(merge_weight)

    new_node = tf.NodeDef(name=old_node.name,
                          op='Const',
                          attr={
                              'value': tf.AttrValue(tensor=tensor_proto),
                              'dtype': tf.AttrValue(type='DT_FLOAT')
                          })
    return new_node
Example #8
def _operator_to_node(shapes, op):
    assert op.name, op
    # Check for existence of __version__ for backwards compatibility
    n = tf.NodeDef() if hasattr(tf, '__version__') else graph_pb2.NodeDef()
    n.name = op.name
    n.input.extend(op.input)
    n.op = op.type
    n.device = _tf_device(op.device_option)
    if shapes:
        # Add shapes in order.
        for output in op.output:
            if output not in shapes:
                break
            _add_tf_shape(n.attr, shapes[output])
    for arg in op.arg:
        _set_tf_attr(n.attr, arg)
    return n
Example #9
 def test_add_control_deps_for_init_op(self):
   graph_def = tf.compat.v1.GraphDef(node=[
       tf.NodeDef(name='foo', input=[]),
       tf.NodeDef(name='bar', input=['foo']),
       tf.NodeDef(name='baz', input=['foo', 'bar']),
       tf.NodeDef(name='bak', input=['bar', '^abc']),
       tf.NodeDef(name='abc', input=['def:0']),
       tf.NodeDef(name='def', input=['^ghi']),
       tf.NodeDef(name='ghi', input=[]),
   ])
   new_graph_def = graph_utils.add_control_deps_for_init_op(graph_def, 'abc')
   self.assertEqual(
       ','.join('{}({})'.format(node.name, ','.join(node.input))
                for node in new_graph_def.node),
       'foo(^abc),bar(foo,^abc),baz(foo,bar,^abc),'
       'bak(bar,^abc),abc(def:0),def(^ghi),ghi()')
Example #10
 def quantize_node(self, input_node):
     """Handles quantizing a single node."""
     input_name = input_node.name
     if input_name in self.already_quantized:
         return
     self.already_quantized[input_name] = True
     original_input_name = input_name + "_original"
     reshape_name = input_name + "_reshape"
     reshape_dims_name = input_name + "_reshape_dims"
     max_name = input_name + "_max"
     min_name = input_name + "_min"
     dims_name = input_name + "_dims"
     quantize_name = input_name + "_quantize"
     dequantize_name = input_name
     original_input_node = tf.NodeDef()
     original_input_node.CopyFrom(input_node)
     original_input_node.name = original_input_name
     self.add_output_graph_node(original_input_node)
     reshape_dims_node = create_constant_node(reshape_dims_name, -1,
                                              tf.int32, [1])
     self.add_output_graph_node(reshape_dims_node)
     reshape_node = create_node("Reshape", reshape_name,
                                [original_input_name, reshape_dims_name])
     set_attr_dtype(reshape_node, "T", tf.float32)
     self.add_output_graph_node(reshape_node)
     dims_node = create_constant_node(dims_name, 0, tf.int32, [1])
     self.add_output_graph_node(dims_node)
     max_node = create_node("Max", max_name, [reshape_name, dims_name])
     set_attr_dtype(max_node, "T", tf.float32)
     set_attr_bool(max_node, "keep_dims", False)
     self.add_output_graph_node(max_node)
     min_node = create_node("Min", min_name, [reshape_name, dims_name])
     set_attr_dtype(min_node, "T", tf.float32)
     set_attr_bool(min_node, "keep_dims", False)
     self.add_output_graph_node(min_node)
     quantize_node = create_node("Quantize", quantize_name,
                                 [original_input_name, min_name, max_name])
     set_attr_dtype(quantize_node, "T", tf.quint8)
     set_attr_string(quantize_node, "mode", b"MIN_FIRST")
     self.add_output_graph_node(quantize_node)
     dequantize_node = create_node("Dequantize", dequantize_name,
                                   [quantize_name, min_name, max_name])
     set_attr_dtype(dequantize_node, "T", tf.quint8)
     set_attr_string(dequantize_node, "mode", b"MIN_FIRST")
     self.add_output_graph_node(dequantize_node)
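For orientation, the rewrite expands each quantized node into this small chain (suffixes as defined above); the Dequantize op takes over the original name, so downstream consumers need no rewiring:

  <name>_original   copy of the original float node
  <name>_reshape    flatten to 1-D (via <name>_reshape_dims)
  <name>_min/_max   value range over the flat tensor (via <name>_dims)
  <name>_quantize   eight-bit Quantize, mode MIN_FIRST
  <name>            Dequantize back to float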
Example #11
    def conv_bn(self, op_list):
        conv_op = add_op = mul_value = add_value = None
        next_op_list = op_list[-1].outputs
        for op in op_list[0]:
            if op.node.op == "Conv2D":
                conv_op = op
            elif op.node.op == "Add" or op.node.op == "AddV2":
                add_op = op

        for op in op_list[1:]:
            if op.node.op == "Mul":
                self._remove_node(self.op_dict[op.node.name])
                value = self._run_tensor(op.node.input[1])[0]
                mul_value = np.transpose(value, (0, 2, 3, 1))
            elif op.node.op == "Add" or op.node.op == "AddV2":
                self._remove_node(self.op_dict[op.node.name])
                value = self._run_tensor(op.node.input[1])[0]
                add_value = np.transpose(value, (0, 2, 3, 1))

        weight_value = self._run_tensor(conv_op.node.input[1])[0]
        weight_value *= mul_value
        self._create_const_node(conv_op.node.input[1], [weight_value])
        if add_op:
            bias_value = self._run_tensor(add_op.node.input[1])[0]
            bias_value = bias_value * mul_value + add_value
            self._create_const_node(add_op.node.input[1], [bias_value])
        else:
            bias_value = add_value
            bias_name = self.fork_name("bias")
            self._create_const_node(bias_name, [bias_value])
            node = tf.NodeDef()
            node.name = self.fork_name("add")
            node.op = "Add"
            node.input.extend([conv_op.node.name, bias_name])
            node.attr['T'].type = op_list[-1].node.attr['T'].type
            self.op_dict[node.name] = Operator(node)
            self.op_dict[node.name].inputs = [conv_op]
            self.op_dict[node.name].outputs = conv_op.outputs
            conv_op.outputs = [self.op_dict[node.name]]
            for op in next_op_list:
                op.node.input[list(op.node.input).index(
                    conv_op.node.name)] = node.name
                op.inputs.remove(conv_op)
                op.inputs.append(self.op_dict[node.name])
        return conv_op
Example #12
    def creat_conv_node(op_name, stride, padding=b'VALID', dtype=tf.float32):
        """
        :param op_name:
        :param stride:
        :param padding:
        :param dtype:
        :return:
        """

        new_node = tf.NodeDef()
        new_node.op = 'Conv2D'
        new_node.name = op_name
        new_node.attr["T"].CopyFrom(tf.AttrValue(type=dtype.as_datatype_enum))
        new_node.attr["use_cudnn_on_gpu"].CopyFrom(tf.AttrValue(b=1))
        new_node.attr["strides"].CopyFrom(
            tf.AttrValue(list=tf.AttrValue.ListValue(i=stride)))
        new_node.attr["padding"].CopyFrom(tf.AttrValue(s=padding))
        return new_node
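A usage sketch: stride is the full 4-element NHWC strides list, and the two data inputs still have to be wired up afterwards (the input names below are hypothetical):

conv = creat_conv_node('conv1', stride=[1, 1, 1, 1], padding=b'SAME')
conv.input.extend(['input_tensor', 'conv1/weights'])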
Example #13
def _blob_to_node(producing_ops, shapes, name):
    assert name
    # Check for existence of __version__ for backwards compatibility
    n = tf.NodeDef() if hasattr(tf, '__version__') else graph_pb2.NodeDef()
    n.name = name
    inputs = producing_ops.get(name, [])
    if inputs:
        n.op = 'Blob'
    else:
        n.op = 'Placeholder'
    n.input.extend('%s:%d' % (op.name, i) for op, i in inputs)
    if inputs:
        device = inputs[0][0].device_option
        if all(op.device_option == device for op, _ in inputs):
            n.device = _tf_device(device)
    if shapes and name in shapes:
        _add_tf_shape(n.attr, shapes[name])
    return n
Example #14
    def add_node(self, head_node_name, tail_node_name, new_node):
        """
        the following operations are supported
        ----------------------------------------------
         head_op------->tail_op
        ----------------------------------------------
        head_op----->add_op------>tail_op
        ----------------------------------------------
        :param head_node_name:
        :param tail_node_name:
        :param new_node:
        :return:
        """

        if isinstance(new_node, tf.NodeDef):
            # Look up the head and tail nodes by name

            head_node = self.init_node(head_node_name)
            tail_node = self.init_node(tail_node_name)

            # extend the input
            new_node.input.extend([head_node.name])

            # route the new_node
            for item in self.graph_pb.node:
                if item.name == tail_node.name:
                    for i, _name in enumerate(item.input):
                        if head_node_name == _name:
                            item.input[i] = new_node.name

            # build new_graph
            for node in self.graph_pb.node:
                if self.node_reference_count[node.name] < 1:
                    continue
                new = tf.NodeDef()
                new.CopyFrom(node)
                self.new_graph.node.extend([new])
            self.new_graph.node.extend([new_node])
            return self.new_graph
        else:
            print("new_node must be of type tf.NodeDef")
            return
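A usage sketch with hypothetical node names, inserting a ReLU between two existing nodes (editor stands for an instance of this class):

relu_node = tf.NodeDef(name='relu_new', op='Relu')
new_graph = editor.add_node('conv1/BiasAdd', 'pool1', relu_node)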
Example #15
def strip_unused(input_graph, input_binary, output_graph, input_node_names,
                 output_node_names, placeholder_type_enum):
    """Removes unused nodes from a graph."""

    if not tf.gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = tf.GraphDef()
    mode = "rb" if input_binary else "r"
    with tf.gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read(), input_graph_def)

    # Here we replace the nodes we're going to override as inputs with
    # placeholders so that any unused nodes that are inputs to them are
    # automatically stripped out by extract_sub_graph().
    input_node_names_list = input_node_names.split(",")
    inputs_replaced_graph_def = tf.GraphDef()
    for node in input_graph_def.node:
        if node.name in input_node_names_list:
            placeholder_node = tf.NodeDef()
            placeholder_node.op = "Placeholder"
            placeholder_node.name = node.name
            placeholder_node.attr["dtype"].CopyFrom(
                tf.AttrValue(type=placeholder_type_enum))
            inputs_replaced_graph_def.node.extend([placeholder_node])
        else:
            inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])

    output_graph_def = graph_util.extract_sub_graph(
        inputs_replaced_graph_def, output_node_names.split(","))

    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
Example #16
    def test_connect_to_shared_init_op(self):
        group_deps_name = 'group_deps'
        init_node_1 = 'table_init_1'
        init_node_2 = 'table_init_2'

        orig_graph_def = tf.GraphDef()
        expected_graph_def_1 = tf.GraphDef()

        meta_graph_editor._connect_to_shared_init_op(orig_graph_def,
                                                     group_deps_name, [])
        self.assertEqual(expected_graph_def_1, orig_graph_def)

        expected_graph_def_2 = tf.GraphDef()
        node_def = tf.NodeDef(name=group_deps_name, op='NoOp')
        node_def.input.extend(['^' + init_node_1, '^' + init_node_2])
        expected_graph_def_2.node.extend([node_def])

        meta_graph_editor._connect_to_shared_init_op(
            orig_graph_def, group_deps_name, [init_node_1, init_node_2])
        self.assertEqual(expected_graph_def_2, orig_graph_def)
Example #17
 def quantize_nodes_recursively(self, current_node):
     """The entry point for quantizing nodes to eight bit and back."""
     for input_node_name in current_node.input:
         input_node_name = node_name_from_input(input_node_name)
         if input_node_name in self.already_visited:
             continue
         input_node = self.nodes_map[input_node_name]
         self.quantize_nodes_recursively(input_node)
     self.already_visited[current_node.name] = True
     nodes_to_quantize = ["Conv2D", "BiasAdd", "MatMul"]
      if current_node.op in nodes_to_quantize:
         for input_name in current_node.input:
             input_name = node_name_from_input(input_name)
             input_node = self.nodes_map[input_name]
             self.quantize_node(input_node)
         self.quantize_node(current_node)
     else:
         new_node = tf.NodeDef()
         new_node.CopyFrom(current_node)
         self.add_output_graph_node(new_node)
Example #18
 def to_node_def(self, target = None, add_shapes = True):
   # type: (tf.NodeDef, bool) -> tf.NodeDef
   """
   Args:
     target: optional preallocated, empty NodeDef object to fill in. If not
       provided, this method will allocate a new `tf.NodeDef` object.
     add_shapes: If True, add the special "_output_shapes" attribute with
       output shape information from this Node's output metadata.
   Returns:
       A copy of the contents of this node as a NodeDef proto. The returned
       proto will *not* change if this node is changed after the call, and
       vice versa.
   """
   if target is None:
     target = tf.NodeDef()
   target.name = self.name
   target.op = self.op_type
   for input_tensor in self.inputs:
     target.input.append(input_tensor.name)
   for control_input_node in self.control_inputs:
     target.input.append("^" + control_input_node.name)
   target.device = self.device
   for (attr_name, attr_value) in self._attributes:
     # Funky syntax for setting a field of a union in a protobuf
     target.attr[attr_name].CopyFrom(
       util.python_type_to_attr_value(attr_value))
   if len(self._colocation_groups) > 0:
     # Serialize colocation groups. See docstring in getter for
     # colocation_groups property for more information.
     transformed_names = [_COLOCATION_PREFIX + name
                          for name in self._colocation_groups]
     target.attr[_COLOCATION_ATTR_NAME].CopyFrom(
       util.python_type_to_attr_value(transformed_names)
     )
   if add_shapes and self._outputs is not None and len(self._outputs) > 0:
     shapes_list = [t.shape for t in self._outputs]
     target.attr[_OUTPUT_SHAPES_ATTR_NAME].CopyFrom(
       util.python_type_to_attr_value(shapes_list)
     )
   return target
Example #19
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
    """Removes unused nodes from a GraphDef.

  Args:
    input_graph_def: A graph with nodes we want to prune.
    input_node_names: A list of the nodes we use as inputs.
    output_node_names: A list of the output nodes.
    placeholder_type_enum: The AttrValue enum for the placeholder data type, or
        a list that specifies one value per input node name.

  Returns:
    A GraphDef with all unnecessary ops removed.
  """
    # Here we replace the nodes we're going to override as inputs with
    # placeholders so that any unused nodes that are inputs to them are
    # automatically stripped out by extract_sub_graph().
    inputs_replaced_graph_def = tf.GraphDef()
    for node in input_graph_def.node:
        if node.name in input_node_names:
            placeholder_node = tf.NodeDef()
            placeholder_node.op = "Placeholder"
            placeholder_node.name = node.name
            if isinstance(placeholder_type_enum, list):
                input_node_index = input_node_names.index(node.name)
                placeholder_node.attr["dtype"].CopyFrom(
                    tf.AttrValue(type=placeholder_type_enum[input_node_index]))
            else:
                placeholder_node.attr["dtype"].CopyFrom(
                    tf.AttrValue(type=placeholder_type_enum))
            if "_output_shapes" in node.attr:
                placeholder_node.attr["_output_shapes"].CopyFrom(
                    node.attr["_output_shapes"])
            inputs_replaced_graph_def.node.extend([placeholder_node])
        else:
            inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])

    output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                    output_node_names)
    return output_graph_def
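Unlike the file-based variant in Example #15, this version takes name lists and accepts one dtype per input; the names here are hypothetical:

pruned = strip_unused(graph_def,
                      input_node_names=['image', 'seq_len'],
                      output_node_names=['logits'],
                      placeholder_type_enum=[tf.float32.as_datatype_enum,
                                             tf.int32.as_datatype_enum])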
Example #20
def create_node(name, op=None, trt_plugin=False, **kwargs):
    '''
    Creates a free-standing TensorFlow NodeDef with the specified properties.

    Args:
        name (str): The name of the node.
        op (str): The node's operation.

    Keyword Args:
        dtype (tensorflow.DType): TensorFlow dtype.
        shape (tuple(int)): Iterable container (usually a tuple) describing the shape of a tensor.
        inputs (list(tensorflow.NodeDef) or str): Iterable container (usually a tuple) of input nodes or input node names. Supports mixed-type lists.
        **kwargs (AttrName=Value): Any additional fields that should be present in the node. Currently supports int, float, bool, list(int), list(float), str and NumPy arrays. NumPy arrays will be inserted into the "value" attribute of the node - this can be useful for creating constant nodes equivalent to those created by tensorflow.constant.

    Returns:
        tensorflow.NodeDef
    '''
    if not trt_plugin:
        print(
            "WARNING: To create TensorRT plugin nodes, please use the `create_plugin_node` function instead."
        )
    node = tf.NodeDef()
    return update_node(node, name, op, trt_plugin, **kwargs)
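Assuming update_node maps keyword arguments onto node attributes as the docstring describes, a constant node could be built like this (the name is hypothetical):

import numpy as np
anchors = create_node('anchor_scales', op='Const',
                      value=np.array([0.25, 0.5, 1.0], dtype=np.float32))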
Example #21
 def creat_maxpool_node(op_name,
                        ksize,
                        stride,
                        padding=b'VALID',
                        dtype=tf.float32):
     """
     :param op_name:
     :param ksize:
     :param stride:
     :param padding:
     :param dtype:
     :return:
     """
     new_node = tf.NodeDef()
     new_node.op = 'MaxPool'
     new_node.name = op_name
     new_node.attr["ksize"].CopyFrom(
         tf.AttrValue(list=tf.AttrValue.ListValue(i=ksize)))
     new_node.attr["T"].CopyFrom(tf.AttrValue(type=dtype.as_datatype_enum))
     new_node.attr["strides"].CopyFrom(
         tf.AttrValue(list=tf.AttrValue.ListValue(i=stride)))
     new_node.attr["padding"].CopyFrom(tf.AttrValue(s=padding))
     return new_node
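Usage mirrors the Conv2D helper in Example #12; both ksize and stride are full 4-element NHWC lists:

pool = creat_maxpool_node('pool1', ksize=[1, 2, 2, 1],
                          stride=[1, 2, 2, 1], padding=b'SAME')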
Example #22
    def delete_node(self, node_name):
        """
        the following operations are supported
        ----------------------------------------------
                         |------>op_1
         op----->remove_op------>op_2
                        |------>op_3
        ----------------------------------------------
          |------>op_1
         op----->op_2
        |------>op_3
        ----------------------------------------------
        :param node_name:
        :return:
        """

        # init remove_node
        remove_node = self.init_node(node_name)

        # remove old_node
        assert node_name in self.node_name, "Node '%s' not found in graph" % node_name
        self.node_reference_count[node_name] = 0

        # route the new_node
        for item in self.graph_pb.node:
            for i, _name in enumerate(item.input):
                if remove_node.name == _name:
                    item.input[i] = remove_node.input[0]

        # build new_graph
        for node in self.graph_pb.node:
            if self.node_reference_count[node.name] < 1:
                continue
            new = tf.NodeDef()
            new.CopyFrom(node)
            self.new_graph.node.extend([new])
        return self.new_graph
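For example, splicing out a pass-through node (hypothetical name; editor is an instance of this class):

new_graph = editor.delete_node('conv1/dropout')

Note that consumers are rerouted to the removed node's first input, so this only suits nodes that pass their first input through unchanged (Identity-like ops).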
Example #23
def strip_pruning_vars_fn(input_graph_def, output_node_names):
    """Removes mask variable from the graph.

  Replaces the masked_weight tensor with element-wise multiplication of mask
  and the corresponding weight variable.

  Args:
    input_graph_def: A GraphDef in which the variables have been converted to
      constants. This is typically the output of
      tf.graph_util.convert_variables_to_constants()
    output_node_names: List of name strings for the result nodes of the graph

  Returns:
    A GraphDef in which pruning-related variables have been removed
  """
    masked_weights_dict = _get_masked_weights(input_graph_def)
    pruned_graph_def = tf.GraphDef()

    # Replace masked_weight with a const op containing the
    # result of tf.multiply(mask,weight)
    for node in input_graph_def.node:
        output_node = tf.NodeDef()
        if 'masked_weight' in node.name:
            output_node.op = 'Const'
            output_node.name = node.name
            dtype = node.attr['T']
            data = masked_weights_dict[node.name]
            output_node.attr['dtype'].CopyFrom(dtype)
            output_node.attr['value'].CopyFrom(
                tf.AttrValue(tensor=tf.make_tensor_proto(data)))

        else:
            output_node.CopyFrom(node)
        pruned_graph_def.node.extend([output_node])

    # Remove stranded nodes: mask and weights
    return tf.graph_util.extract_sub_graph(pruned_graph_def, output_node_names)
Example #24
  def test_get_deps_for_graph_node(self):
    graph_def = tf.compat.v1.GraphDef(node=[
        tf.NodeDef(name='foo', input=[]),
        tf.NodeDef(name='bar', input=['foo:0']),
        tf.NodeDef(name='baz', input=['foo:1', 'bar']),
        tf.NodeDef(name='bak', input=['bar', '^abc']),
        tf.NodeDef(name='abc', input=[]),
        tf.NodeDef(name='def', input=['abc:0']),
        tf.NodeDef(name='ghi', input=['^def']),
    ])

    def _get_deps(x):
      return ','.join(
          sorted(list(graph_utils.get_deps_for_graph_node(graph_def, x))))

    self.assertEqual(_get_deps('foo'), '')
    self.assertEqual(_get_deps('bar'), 'foo')
    self.assertEqual(_get_deps('baz'), 'bar,foo')
    self.assertEqual(_get_deps('bak'), 'abc,bar,foo')
    self.assertEqual(_get_deps('abc'), '')
    self.assertEqual(_get_deps('def'), 'abc')
    self.assertEqual(_get_deps('ghi'), 'abc,def')
Example #25
def create_const_for_anchor_generator():
    """Creates a 'Const' node as an input to 'MultipleGridAnchorGenerator'
    Note the 'MultipleGridAnchorGenerator' TRT plugin node requires a
    [1.0, 1.0] array as input.
    Reference: https://stackoverflow.com/a/56296195/7596504
    """
    import numpy as np
    import tensorflow as tf
    from tensorflow.core.framework.tensor_pb2 import TensorProto
    from tensorflow.core.framework.tensor_shape_pb2 import TensorShapeProto

    value = np.array([1.0, 1.0], dtype=np.float32)
    dt = tf.as_dtype(value.dtype).as_datatype_enum
    tensor_shape = TensorShapeProto(
        dim=[TensorShapeProto.Dim(size=s) for s in value.shape])
    tensor_proto = TensorProto(tensor_content=value.tobytes(),
                               tensor_shape=tensor_shape,
                               dtype=dt)
    return tf.NodeDef(name='const_for_anchors',
                      op='Const',
                      attr={
                          'value': tf.AttrValue(tensor=tensor_proto),
                          'dtype': tf.AttrValue(type=dt)
                      })
Example #26
 def creat_const_node(op_name, array=None, dtype=tf.float32, shape=None):
      """
      Create a Const node holding the given array.
      :param op_name: name of the new node
      :param array: value to store in the node
      :param dtype: element dtype of the tensor
      :param shape: expected shape of the array
      :return: the new tf.NodeDef
      """
     new_node = tf.NodeDef()
     new_node.op = 'Const'
     new_node.name = op_name
     new_node.attr['dtype'].CopyFrom(
         tf.AttrValue(type=dtype.as_datatype_enum))
      assert list(np.shape(array)) == shape, "shape does not match the array"
      new_node.attr['value'].CopyFrom(
          tf.AttrValue(tensor=tf.make_tensor_proto(array, dtype, shape)))
     new_node.attr['_output_shapes'].CopyFrom(
         tf.AttrValue(list=tf.AttrValue.ListValue(shape=[
             tensor_shape_pb2.TensorShapeProto(dim=[
                 tensor_shape_pb2.TensorShapeProto.Dim(size=x)
                 for x in shape
             ])
         ])))
     return new_node
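A usage sketch; the assert requires shape to match the array exactly:

import numpy as np
bias = creat_const_node('conv1/bias', array=np.zeros((64,), dtype=np.float32),
                        shape=[64])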
Example #27
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices):
    """Converts all variables in a graph and checkpoint into constants."""

    if not gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1

    if input_saver and not gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1

    if not gfile.Exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = tf.GraphDef()
    mode = "rb" if input_binary else "r"
    with open(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read(), input_graph_def)
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""
    _ = tf.import_graph_def(input_graph_def, name="")

    with tf.Session() as sess:
        if input_saver:
            with open(input_saver, "rb") as f:
                saver_def = tf.train.SaverDef()
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    text_format.Merge(f.read(), saver_def)
                saver = tf.train.Saver(saver_def=saver_def)
                saver.restore(sess, input_checkpoint)
        else:
            sess.run([restore_op_name],
                     {filename_tensor_name: input_checkpoint})
        found_variables = {}
        for node in input_graph_def.node:
            if node.op == "Assign":
                variable_name = node.input[0]
                found_variables[variable_name] = sess.run(variable_name + ":0")

    # This graph only includes the nodes needed to evaluate the output nodes, and
    # removes unneeded nodes like those involved in saving and assignment.
    inference_graph = graph_util.extract_sub_graph(
        input_graph_def, output_node_names.split(","))

    output_graph_def = tf.GraphDef()
    how_many_converted = 0
    for input_node in inference_graph.node:
        output_node = tf.NodeDef()
        if input_node.name in found_variables:
            output_node.op = "Const"
            output_node.name = input_node.name
            dtype = input_node.attr["dtype"]
            data = found_variables[input_node.name]
            set_attr_dtype(output_node, "dtype", dtype)
            set_attr_tensor(output_node, "value", data, dtype.type, data.shape)
            how_many_converted += 1
        else:
            output_node.CopyFrom(input_node)
        output_graph_def.node.extend([output_node])

    with gfile.FastGFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("Converted %d variables to const ops." % how_many_converted)
    print("%d ops in the final graph." % len(output_graph_def.node))
Example #28
def copy_op_to_graph(org_instance, to_graph, variables, scope=''):
    """Returns a copy of an operation from another Graph under a specified scope.

  Given an `Operation` `org_instance` from one `Graph`,
  initializes and returns a copy of it from another `Graph`,
  under the specified scope (default `""`).

  The copying is done recursively, so any `Operation` whose output
  is required to evaluate the `org_instance`, is also copied (unless
  already done).

  Since `Variable` instances are copied separately, those required
  to evaluate `org_instance` must be provided as input.

  Args:
    org_instance: An `Operation` from some `Graph`. Could be a
      `Placeholder` as well.
    to_graph: The `Graph` to copy `org_instance` to.
    variables: An iterable of `Variable` instances required to evaluate `org_instance`.
    scope: A scope for the new `Variable` (default `""`).

  Returns:
    The copied `Operation` from `to_graph`.

  Raises:
    TypeError: If `org_instance` is not an `Operation` or `Tensor`.
  """

    #The name of the new instance
    if scope != '':
        new_name = scope + '/' + org_instance.name
    else:
        new_name = org_instance.name
    # print(new_name)
    #Extract names of variables
    copied_variables = dict((x.name, x) for x in variables)
    #If a variable by the new name already exists, return the
    #corresponding tensor that will act as an input
    if new_name in copied_variables:
        return to_graph.get_tensor_by_name(copied_variables[new_name].name)
    #If an instance of the same name exists, return appropriately
    try:
        already_present = to_graph.as_graph_element(new_name,
                                                    allow_tensor=True,
                                                    allow_operation=True)
        return already_present
    except Exception:
        pass
    #Get the collections that the new instance needs to be added to.
    #The new collections will also be a part of the given scope.
    collections = []
    for name, collection in org_instance.graph._collections.items():
        if org_instance in collection:
            if scope == '':
                collections.append(name)
            else:
                collections.append(scope + '/' + name)
    #Take action based on the class of the instance

    if isinstance(org_instance, ops.Tensor):
        #If it's a Tensor, it is one of the outputs of the underlying
        #op. Therefore, copy the op itself and return the appropriate
        #output.
        op = org_instance.op
        new_op = copy_op_to_graph(op, to_graph, variables, scope)
        output_index = op.outputs.index(org_instance)
        new_tensor = new_op.outputs[output_index]
        #Add to collections if any
        for collection in collections:
            to_graph.add_to_collection(collection, new_tensor)

        return new_tensor

    elif isinstance(org_instance, ops.Operation):
        op = org_instance

        #If it has an original_op parameter, copy it
        if op._original_op is not None:
            new_original_op = copy_op_to_graph(op._original_op, to_graph,
                                               variables, scope)
        else:
            new_original_op = None

        #If it has control inputs, call this function recursively on each.
        new_control_inputs = [
            copy_op_to_graph(x, to_graph, variables, scope)
            for x in op.control_inputs
        ]

        #If it has inputs, call this function recursively on each.
        new_inputs = [
            copy_op_to_graph(x, to_graph, variables, scope) for x in op.inputs
        ]

        #Make a new node_def based on that of the original.
        #An instance of tensorflow.core.framework.node_def_pb2.NodeDef, it
        #stores String-based info such as name, device and type of the op.
        #Unique to every Operation instance.
        #Colocate info needs to be cleared here
        new_attr = dict()
        for key in op.node_def.attr:
            # don't copy colocate info
            if key == '_class':
                pass
            else:
                new_attr[key] = op.node_def.attr[key]

        new_node_def = tf.NodeDef(name=new_name,
                                  op=op.node_def.op,
                                  input=op.node_def.input,
                                  device=op.node_def.device,
                                  attr=new_attr)

        #Copy the other inputs needed for initialization
        output_types = op._output_types[:]
        input_types = op._input_types[:]

        #Make a copy of the op_def too.
        #It's unique to every _type_ of Operation.
        op_def = deepcopy(op.op_def)

        #Initialize a new Operation instance
        new_op = ops.Operation(new_node_def, to_graph, new_inputs,
                               output_types, new_control_inputs, input_types,
                               new_original_op, op_def)
        #Use Graph's hidden methods to add the op
        to_graph._add_op(new_op)  # pylint: disable=protected-access
        to_graph._record_op_seen_by_control_dependencies(new_op)
        for device_function in reversed(to_graph._device_function_stack):
            new_op._set_device(device_function(new_op))

        return new_op

    else:
        raise TypeError('Could not copy instance: ' + str(org_instance))
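A usage sketch: copying a tensor (and, recursively, every op it depends on) into a fresh graph under a scope; logits is a hypothetical tensor from the source graph:

dst_graph = tf.Graph()
new_logits = copy_op_to_graph(logits, dst_graph, variables=[], scope='copy')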
Example #29
    def remove_unneeded_nodes(self, input_graph):
        """Prunes out nodes that aren't needed for inference.

    There are nodes like Identity and CheckNumerics that are only useful
    during training, and can be removed in graphs that will be used for
    nothing but inference. Here we identify and remove them, returning an
    equivalent graph.

    Args:
      input_graph: Model to analyze and prune.

    Returns:
      A GraphDef with the unnecessary nodes removed.
    """

        types_to_remove = {"CheckNumerics": True}

        input_nodes = input_graph.node
        names_to_remove = {}
        for node in input_nodes:
            if node.op in types_to_remove:
                names_to_remove[node.name] = True

        nodes_after_removal = []
        for node in input_nodes:
            if node.name in names_to_remove:
                continue
            new_node = tf.NodeDef()
            new_node.CopyFrom(node)
            input_before_removal = node.input
            del new_node.input[:]
            for full_input_name in input_before_removal:
                input_name = re.sub(r"^\^", "", full_input_name)
                if input_name in names_to_remove:
                    continue
                new_node.input.append(full_input_name)
            nodes_after_removal.append(new_node)

        types_to_splice = {"Identity": True}
        names_to_splice = {}
        for node in nodes_after_removal:
            if node.op in types_to_splice:
                # We don't want to remove nodes that have control edge inputs, because
                # they might be involved in subtle dependency issues that removing them
                # will jeopardize.
                has_control_edge = False
                for input_name in node.input:
                    if re.match(r"^\^", input_name):
                        has_control_edge = True
                if not has_control_edge:
                    names_to_splice[node.name] = node.input[0]

        nodes_after_splicing = []
        for node in nodes_after_removal:
            if node.name in names_to_splice:
                continue
            new_node = tf.NodeDef()
            new_node.CopyFrom(node)
            input_before_removal = node.input
            del new_node.input[:]
            for full_input_name in input_before_removal:
                input_name = re.sub(r"^\^", "", full_input_name)
                if input_name in names_to_splice:
                    new_node.input.append(names_to_splice[input_name])
                else:
                    new_node.input.append(full_input_name)
            nodes_after_splicing.append(new_node)

        output_graph = tf.GraphDef()
        output_graph.node.extend(nodes_after_splicing)
        return output_graph
Example #30
def main(_):
    print("Pix2pix tensorflow Exporter!")
    if not os.path.exists(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)
    if not os.path.exists(args.sample_dir):
        os.makedirs(args.sample_dir)
    if not os.path.exists(args.test_dir):
        os.makedirs(args.test_dir)

    with tf.Session() as sess:
        model = pix2pix(sess,
                        image_size=args.fine_size,
                        batch_size=args.batch_size,
                        output_size=args.fine_size,
                        dataset_name=args.dataset_name,
                        checkpoint_dir=args.checkpoint_dir,
                        sample_dir=args.sample_dir,
                        input_c_dim=args.input_nc,
                        output_c_dim=args.output_nc,
                        direction=args.which_direction)

        model.load_model(args)
        graph = tf.get_default_graph()
        input_graph_def = graph.as_graph_def()

        # fix batch norm nodes
        for node in input_graph_def.node:
            if node.op == 'RefSwitch':
                node.op = 'Switch'
                for index in range(len(node.input)):
                    if 'moving_' in node.input[index]:
                        node.input[index] = node.input[index] + '/read'
            elif node.op == 'AssignSub':
                node.op = 'Sub'
                if 'use_locking' in node.attr:
                    del node.attr['use_locking']

        # freeze!
        freeze_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, ['generator/Tanh'])

        #copy input-related sub graph_util
        input_node_names_list = ['real_A_and_B_images']
        input_replaced_graph_def = tf.GraphDef()
        for node in freeze_graph_def.node:
            if node.name in input_node_names_list:
                placeholder_node = tf.NodeDef()
                placeholder_node.op = 'Placeholder'
                placeholder_node.name = node.name
                placeholder_node.attr['dtype'].CopyFrom(
                    tf.AttrValue(type=tf.float32.as_datatype_enum))
                input_replaced_graph_def.node.extend([placeholder_node])
                print(node.name, 'is replaced with placeholder')
            else:
                input_replaced_graph_def.node.extend([copy.deepcopy(node)])

        # extract subgraph
        output_sub_graph_def = graph_util.extract_sub_graph(
            input_replaced_graph_def, ['generator/Tanh'])

        with tf.gfile.GFile('export_model.pb', 'wb') as f:
            f.write(output_sub_graph_def.SerializeToString())