def setNodeConstValue(gdef, node, value):
    """Overwrite ``node`` in place with a constant carrying ``value``.

    The rebuilt node keeps only the original name, op and ``dtype``
    attribute; every other field (inputs, device, remaining attrs) is
    discarded. ``value`` is converted with
    ``tf.contrib.util.make_tensor_proto`` using the node's existing dtype.

    Args:
        gdef: Unused; kept for call-site compatibility.
        node: The NodeDef to overwrite (mutated in place).
        value: Data stored in the node's ``value`` attribute.

    Returns:
        The same ``node`` object, after mutation.
    """
    dtype_enum = node.attr["dtype"].type
    tensor = tf.contrib.util.make_tensor_proto(value, dtype=dtype_enum)

    replacement = tf.NodeDef()
    replacement.name = node.name
    replacement.op = node.op
    replacement.attr["dtype"].type = dtype_enum
    replacement.attr["value"].CopyFrom(tf.AttrValue(tensor=tensor))

    # Replace every field of the caller's node with the stripped-down copy.
    node.CopyFrom(replacement)
    return node
def quantize_weights(self, input_graph, quantization_mode):
    """Quantize float Const ops.

    There are two modes of operations, both replace float Const ops with
    quantized values.
    1. If quantization_mode is "weights_rounded", this function replaces float
    Const ops with quantized float Const ops - same as the original op, but
    float values being mapped to the center of one of 1<<FLAGS.bitdepth buckets.
    This does not change the raw model size, but compression algorithms such as
    zip (as used for compressing apks) or bzip2 will achieve a very good
    compression ratio.
    2. For other quantization modes ("MIN_COMBINED" or "MIN_FIRST"), float
    Const ops are quantized and replaced by a tuple of four ops to perform
    the dequantization at runtime:
    * eight-bit Const (bucket indices, same shape as original float Const op
    * two float Const ops (min and max value of original float Const op)
    * Dequantize op to convert the eight-bit consts to float tensors.
    The quantization mode is important because we see accuracy problems when
    quantizing weights for different situations depending on the algorithm
    used. We haven't figured out exactly what the underlying cause is yet,
    unfortunately.

    Args:
      input_graph: A GraphDef of the model containing float Const ops.
      quantization_mode: How to quantize and dequantize the values. One of
        "weights_rounded", "MIN_COMBINED" or "MIN_FIRST"; each may be given
        as ``str`` or ``bytes``.

    Returns:
      A GraphDef of the converted graph. Non-float32 and non-Const nodes are
      copied unchanged.

    Raises:
      ValueError: If quantization_mode is unsupported (checked up front, even
        when the graph contains no float Const ops).
    """
    # Normalise spellings. "weights_rounded" is compared as str, while the
    # eight-bit modes are forwarded to quantize_weight_eightbit as bytes
    # (the only form the original code accepted for them).
    if quantization_mode == b"weights_rounded":
        quantization_mode = "weights_rounded"
    elif isinstance(quantization_mode, str) and quantization_mode in (
            "MIN_COMBINED", "MIN_FIRST"):
        quantization_mode = quantization_mode.encode("ascii")

    if (quantization_mode != "weights_rounded"
            and quantization_mode not in (b"MIN_COMBINED", b"MIN_FIRST")):
        raise ValueError("Unsupported quantization mode %s." % quantization_mode)

    output_graph = tf.GraphDef()
    for input_node in input_graph.node:
        should_quantize = (
            input_node.op == "Const"
            and tf.as_dtype(input_node.attr["dtype"].type) == tf.float32)
        if not should_quantize:
            output_node = tf.NodeDef()
            output_node.CopyFrom(input_node)
            output_graph.node.extend([output_node])
        elif quantization_mode == "weights_rounded":
            output_graph.node.extend(quantize_weight_rounded(input_node))
        else:
            output_graph.node.extend(
                quantize_weight_eightbit(input_node, quantization_mode))
    return output_graph
def rename_nodes(graph_def: tf.GraphDef, new_names: Dict[str, str]) -> tf.GraphDef:
    """Rename items in the graph to new ones defined in new_names

    A node is renamed when its name is a key of ``new_names`` after an optional
    leading ``cond/`` or ``cond_<k>/`` scope and an optional trailing ``_<k>``
    suffix are stripped. All references to renamed nodes are updated:
    data inputs (``name`` and ``name:port``), control inputs (``^name``) and
    ``loc:@name`` entries of ``_class`` attributes.

    Parameters
    ----------
    graph_def : tf.GraphDef
        Graph Definition
    new_names : Dict[str, str]
        Mapping old name -> new name

    Returns
    -------
    tf.GraphDef
        A copy of the input GraphDef with renamed nodes

    Raises
    ------
    TensorsNotFoundError
        If new_names refers to an node not found in graph_def
    """
    pattern = re.compile(r"^(?:cond(?:_\d+)?/)?(.+?)(?:_\d+)?$")

    # Copy each node, renaming where the stripped name is a key of new_names.
    nodes = []
    renamed = {}  # actual original node name -> new node name
    for node in graph_def.node:
        new_node = tf.NodeDef()
        new_node.CopyFrom(node)
        nodes.append(new_node)
        match = pattern.match(node.name)
        if match and match.group(1) in new_names:
            new_node.name = new_names[match.group(1)]
            renamed[node.name] = new_node.name
            LOGGER.info(f"Node renamed: {node.name} -> {new_node.name}")

    # Check that all new names were used
    missing = set(new_names.values()) - {node.name for node in nodes}
    if missing:
        raise TensorsNotFoundError(missing)

    def _new_name(name):
        # Prefer the rename actually applied (covers cond/ prefixes and _k
        # suffixes); fall back to new_names for verbatim key matches, as before.
        return renamed.get(name, new_names.get(name, name))

    # Update node references (inputs and location) to renamed nodes
    for node in nodes:
        for idx, ref in enumerate(node.input):
            prefix = "^" if ref.startswith("^") else ""
            base, sep, port = ref[len(prefix):].partition(":")
            node.input[idx] = f"{prefix}{_new_name(base)}{sep}{port}"
        if "_class" in node.attr:
            attr = node.attr["_class"]
            for idx, item in enumerate(attr.list.s):
                loc_match = re.match(r"^loc:@(.+)$", item.decode())
                if loc_match:
                    target = _new_name(loc_match.group(1))
                    if target != loc_match.group(1):
                        attr.list.s[idx] = f"loc:@{target}".encode()

    # Create Graph with renamed nodes
    new_graph = tf.GraphDef()
    new_graph.node.extend(nodes)
    return new_graph
def to_node_def(self):
    """Serialize this node into a freshly allocated ``tf.NodeDef``.

    Data inputs are written by tensor name, followed by control inputs as
    ``^<node name>``. Each ``(name, value)`` pair in ``self._attributes`` is
    converted with ``_python_type_to_attr_value``.
    """
    node_def = tf.NodeDef()
    node_def.name = self.name
    node_def.op = self.op_name
    node_def.input.extend([tensor.name for tensor in self.inputs])
    node_def.input.extend(["^" + dep.name for dep in self.control_inputs])
    node_def.device = self.device
    for key, value in self._attributes:
        # Message-valued map entries cannot be assigned directly; CopyFrom
        # fills the entry that indexing auto-creates.
        node_def.attr[key].CopyFrom(_python_type_to_attr_value(value))
    return node_def
def init_node(self, node_name):
    """Find a node of ``self.graph_pb`` by name.

    :param node_name: exact name of the node to look up.
    :return: the first node whose ``name`` equals ``node_name`` (the node
        object itself, not a copy), or None after printing a notice when no
        such node exists.
    """
    for node in self.graph_pb.node:
        if node.name == node_name:
            return node
    # Reached only when nothing matched, including node_name == "" (the old
    # comparison against an empty placeholder NodeDef skipped the notice).
    print("There isn't this node in graph")
    return None
def make_placeholders(graph_def: tf.GraphDef, names: List[str]) -> tf.GraphDef:
    """Create placeholders for names and remove other placeholders

    Parameters
    ----------
    graph_def : tf.GraphDef
        Graph definition (not modified)
    names : List[str]
        Names of placeholders to keep / create for this graph

    Returns
    -------
    tf.GraphDef
        A copy of the input GraphDef with new placeholders. A created
        placeholder takes its dtype from the node's ``T`` attribute, or from
        ``dtype`` if there is no ``T``. Its shape comes from the first entry of
        ``_output_shapes`` when present; otherwise the shape is left unset.

    Raises
    ------
    ValueError
        If names refers to a node that is not present, or a node to convert
        has neither a ``T`` nor a ``dtype`` attribute
    """
    wanted = set(names)

    # Create copy of each node and change to Placeholder if in names
    nodes = []
    for node in graph_def.node:
        if node.op == "Placeholder" and node.name not in wanted:
            LOGGER.info(f"Removing placeholder {node.name}")
            continue
        new_node = tf.NodeDef()
        if node.name in wanted and node.op != "Placeholder":
            LOGGER.info(f"Creating placeholder {node.name}")
            new_node.name = node.name
            new_node.op = "Placeholder"
            # Test membership before indexing: reading a missing key of a
            # protobuf message map inserts an empty entry, which would
            # silently modify the caller's graph_def.
            if ("_output_shapes" in node.attr
                    and node.attr["_output_shapes"].list.shape):
                new_node.attr["shape"].CopyFrom(
                    tf.AttrValue(shape=node.attr["_output_shapes"].list.shape[0]))
            for dtype_key in ("T", "dtype"):
                if dtype_key in node.attr:
                    new_node.attr["dtype"].CopyFrom(node.attr[dtype_key])
                    break
            else:
                raise ValueError(f"Cannot infer dtype for placeholder {node.name}")
        else:
            new_node.CopyFrom(node)
        nodes.append(new_node)

    # Check that all expected placeholders have been found
    missing = wanted - {node.name for node in nodes}
    if missing:
        raise ValueError(f"Missing placeholders: {missing}")

    # Create Graph with renamed nodes
    new_graph = tf.GraphDef()
    new_graph.node.extend(nodes)
    return new_graph
def create_new_node(input_weight_a, intput_weight_b, old_node):
    """Build a Const NodeDef holding two weight arrays joined along axis 1.

    Args:
        input_weight_a: First array-like weight (at least 2-D).
        intput_weight_b: Second array-like weight; must match
            ``input_weight_a`` on every axis except axis 1. (The parameter
            spelling is kept for keyword callers.)
        old_node: Node whose ``name`` the new Const node reuses.

    Returns:
        A ``Const`` NodeDef named ``old_node.name`` whose ``value`` is the
        concatenation and whose ``dtype`` attribute matches that tensor.
    """
    import numpy as np

    merged = np.concatenate((input_weight_a, intput_weight_b), axis=1)
    tensor_proto = tf.make_tensor_proto(merged)
    # Take the dtype from the stored tensor. The old hard-coded DT_FLOAT gave
    # a mismatched node for non-float32 inputs (e.g. float64 arrays).
    return tf.NodeDef(
        name=old_node.name,
        op='Const',
        attr={
            'value': tf.AttrValue(tensor=tensor_proto),
            'dtype': tf.AttrValue(type=tensor_proto.dtype),
        })
def _operator_to_node(shapes, op):
    """Translate an operator description into a TensorFlow NodeDef.

    Copies the name, inputs, op type and device. When ``shapes`` is
    non-empty, adds output shapes in output order, stopping at the first
    output with no entry in ``shapes``. Finally, every ``op.arg`` becomes a
    node attribute via ``_set_tf_attr``.
    """
    assert op.name, op
    # Check for existance of __version__ for backwards compatibility
    node_cls = tf.NodeDef if hasattr(tf, '__version__') else graph_pb2.NodeDef
    node = node_cls()
    node.name = op.name
    node.input.extend(op.input)
    node.op = op.type
    node.device = _tf_device(op.device_option)
    if shapes:
        for output_name in op.output:
            if output_name not in shapes:
                # Later outputs are skipped too, even if they have shapes.
                break
            _add_tf_shape(node.attr, shapes[output_name])
    for argument in op.arg:
        _set_tf_attr(node.attr, argument)
    return node
def test_add_control_deps_for_init_op(self):
    # (name, inputs) pairs describing the input graph, in node order.
    edges = [
        ('foo', []),
        ('bar', ['foo']),
        ('baz', ['foo', 'bar']),
        ('bak', ['bar', '^abc']),
        ('abc', ['def:0']),
        ('def', ['^ghi']),
        ('ghi', []),
    ]
    graph_def = tf.compat.v1.GraphDef(
        node=[tf.NodeDef(name=name, input=inputs) for name, inputs in edges])
    new_graph_def = graph_utils.add_control_deps_for_init_op(graph_def, 'abc')

    def _describe(node):
        return '{}({})'.format(node.name, ','.join(node.input))

    actual = ','.join(_describe(node) for node in new_graph_def.node)
    expected = ('foo(^abc),bar(foo,^abc),baz(foo,bar,^abc),'
                'bak(bar,^abc),abc(def:0),def(^ghi),ghi()')
    self.assertEqual(actual, expected)
def quantize_node(self, input_node):
    """Handles quantizing a single node.

    Emits, in this order:
    - a copy of ``input_node`` renamed ``<name>_original``;
    - a Reshape to shape [-1];
    - Max and Min reductions over axis 0 of that reshape;
    - a Quantize (T=quint8, mode MIN_FIRST);
    - a Dequantize that reuses the original node name.

    References to the original name therefore resolve to the Dequantize
    output. A node name is processed at most once (tracked in
    ``self.already_quantized``).
    """
    name = input_node.name
    if name in self.already_quantized:
        return
    self.already_quantized[name] = True

    original_name = name + "_original"
    reshape_name = name + "_reshape"
    reshape_dims_name = name + "_reshape_dims"
    max_name = name + "_max"
    min_name = name + "_min"
    dims_name = name + "_dims"
    quantize_name = name + "_quantize"

    original_copy = tf.NodeDef()
    original_copy.CopyFrom(input_node)
    original_copy.name = original_name
    self.add_output_graph_node(original_copy)

    # Flatten to 1-D so a single reduction over axis 0 covers every element.
    self.add_output_graph_node(
        create_constant_node(reshape_dims_name, -1, tf.int32, [1]))
    reshape = create_node("Reshape", reshape_name,
                          [original_name, reshape_dims_name])
    set_attr_dtype(reshape, "T", tf.float32)
    self.add_output_graph_node(reshape)

    self.add_output_graph_node(
        create_constant_node(dims_name, 0, tf.int32, [1]))
    for op_type, reduce_name in (("Max", max_name), ("Min", min_name)):
        reduce_node = create_node(op_type, reduce_name, [reshape_name, dims_name])
        set_attr_dtype(reduce_node, "T", tf.float32)
        set_attr_bool(reduce_node, "keep_dims", False)
        self.add_output_graph_node(reduce_node)

    for op_type, out_name, first_input in (
            ("Quantize", quantize_name, original_name),
            ("Dequantize", name, quantize_name)):
        range_node = create_node(op_type, out_name,
                                 [first_input, min_name, max_name])
        set_attr_dtype(range_node, "T", tf.quint8)
        set_attr_string(range_node, "mode", b"MIN_FIRST")
        self.add_output_graph_node(range_node)
def conv_bn(self, op_list):
    """Fold the Mul / Add(V2) ops of ``op_list[1:]`` into the Conv2D of ``op_list[0]``.

    The conv weights (``conv_op.node.input[1]``) are multiplied by the Mul
    constant. The bias is then handled in one of two ways:
    - if ``op_list[0]`` also holds an Add/AddV2, its constant becomes
      ``bias * mul + add``;
    - otherwise a new bias Const and a new Add node are inserted after the
      conv, and the consumers of ``op_list[-1]`` are rewired to that Add.

    The folded Mul/Add ops are removed via ``self._remove_node``. The name
    suggests a batch-norm fold (scale/shift) — NOTE(review): confirm.

    Returns:
        The Conv2D operator.
    """
    conv_op = add_op = mul_value = add_value = None
    next_op_list = op_list[-1].outputs
    # op_list[0] is itself iterated: presumably a group holding the Conv2D and
    # an optional bias Add — TODO confirm against the caller.
    for op in op_list[0]:
        if op.node.op == "Conv2D":
            conv_op = op
        elif op.node.op == "Add" or op.node.op == "AddV2":
            add_op = op
    for op in op_list[1:]:
        if op.node.op == "Mul":
            self._remove_node(self.op_dict[op.node.name])
            value = self._run_tensor(op.node.input[1])[0]
            # NOTE(review): (0, 2, 3, 1) presumably moves the channel axis last
            # (NCHW -> NHWC) so it broadcasts against the weights — confirm.
            mul_value = np.transpose(value, (0, 2, 3, 1))
        elif op.node.op == "Add" or op.node.op == "AddV2":
            self._remove_node(self.op_dict[op.node.name])
            value = self._run_tensor(op.node.input[1])[0]
            add_value = np.transpose(value, (0, 2, 3, 1))
    # NOTE(review): if op_list[1:] contains no Mul, mul_value is still None here
    # and the multiplication below fails — callers must guarantee a Mul.
    weight_value = self._run_tensor(conv_op.node.input[1])[0]
    weight_value *= mul_value
    self._create_const_node(conv_op.node.input[1], [weight_value])
    if add_op:
        # Existing bias: scale it, then shift it.
        bias_value = self._run_tensor(add_op.node.input[1])[0]
        bias_value = bias_value * mul_value + add_value
        self._create_const_node(add_op.node.input[1], [bias_value])
    else:
        # No bias yet: create Const(bias) and Add(conv, bias) right after the conv.
        bias_value = add_value
        bias_name = self.fork_name("bias")
        self._create_const_node(bias_name, [bias_value])
        node = tf.NodeDef()
        node.name = self.fork_name("add")
        node.op = "Add"
        node.input.extend([conv_op.node.name, bias_name])
        node.attr['T'].type = op_list[-1].node.attr['T'].type
        self.op_dict[node.name] = Operator(node)
        self.op_dict[node.name].inputs = [conv_op]
        self.op_dict[node.name].outputs = conv_op.outputs
        conv_op.outputs = [self.op_dict[node.name]]
        # Redirect downstream consumers from the conv to the new Add. This
        # relies on their inputs already naming conv_op (list.index raises
        # ValueError otherwise) — presumably arranged by _remove_node; confirm.
        for op in next_op_list:
            op.node.input[list(op.node.input).index(
                conv_op.node.name)] = node.name
            op.inputs.remove(conv_op)
            op.inputs.append(self.op_dict[node.name])
    return conv_op
def creat_conv_node(op_name, stride, padding=b'VALID', dtype=tf.float32):
    """Create a standalone ``Conv2D`` NodeDef (no inputs attached).

    :param op_name: name for the new node.
    :param stride: iterable of ints stored in the ``strides`` attribute.
    :param padding: padding scheme as bytes, e.g. b'VALID'.
    :param dtype: TensorFlow dtype used for the ``T`` attribute.
    :return: the new ``tf.NodeDef``.
    """
    conv = tf.NodeDef(op='Conv2D', name=op_name)
    attributes = (
        ("T", tf.AttrValue(type=dtype.as_datatype_enum)),
        ("use_cudnn_on_gpu", tf.AttrValue(b=1)),
        ("strides", tf.AttrValue(list=tf.AttrValue.ListValue(i=stride))),
        ("padding", tf.AttrValue(s=padding)),
    )
    for key, attr_value in attributes:
        conv.attr[key].CopyFrom(attr_value)
    return conv
def _blob_to_node(producing_ops, shapes, name):
    """Build a NodeDef standing for the blob called ``name``.

    - Blobs with at least one producer become ``Blob`` nodes. Their inputs
      are ``"<op name>:<output index>"`` for every ``(op, index)`` in
      ``producing_ops[name]``.
    - Blobs without producers become ``Placeholder`` nodes.
    - A device is set only when all producers share one ``device_option``.
    - A shape is attached when ``shapes`` contains the blob.
    """
    assert name
    # Check for existance of __version__ for backwards compatibility
    node_cls = tf.NodeDef if hasattr(tf, '__version__') else graph_pb2.NodeDef
    node = node_cls()
    node.name = name
    producers = producing_ops.get(name, [])
    node.op = 'Blob' if producers else 'Placeholder'
    node.input.extend('%s:%d' % (producer.name, index)
                      for producer, index in producers)
    if producers:
        first_device = producers[0][0].device_option
        if all(entry[0].device_option == first_device for entry in producers):
            node.device = _tf_device(first_device)
    if shapes and name in shapes:
        _add_tf_shape(node.attr, shapes[name])
    return node
def add_node(self, head_node_name, tail_node_name, new_node):
    """Insert ``new_node`` on the edge from ``head_node_name`` to ``tail_node_name``.

    ----------------------------------------------
    head_op------->tail_op
    ----------------------------------------------
    head_op----->add_op------>tail_op
    ----------------------------------------------

    Steps:
    - Append the head node's name to ``new_node``'s inputs (``new_node`` is
      mutated).
    - Redirect every input of the tail node equal to ``head_node_name`` to
      ``new_node.name``.
    - Append a copy of each node of ``self.graph_pb`` whose reference count
      is not below 1 to ``self.new_graph``, then ``new_node`` itself.

    :param head_node_name: name of the upstream node.
    :param tail_node_name: name of the downstream node.
    :param new_node: the ``tf.NodeDef`` to insert.
    :return: ``self.new_graph``, or None (after printing a message) when
        ``new_node`` is not a ``tf.NodeDef``.
    """
    if not isinstance(new_node, tf.NodeDef):
        print("New_node must be the type of tf.NodeDef")
        return

    head_node = self.init_node(head_node_name)
    tail_node = self.init_node(tail_node_name)
    new_node.input.extend([head_node.name])

    # Rewire the tail node's matching inputs to the inserted node.
    for candidate in self.graph_pb.node:
        if candidate.name != tail_node.name:
            continue
        for idx, input_name in enumerate(candidate.input):
            if input_name == head_node_name:
                candidate.input[idx] = new_node.name

    for original in self.graph_pb.node:
        if self.node_reference_count[original.name] < 1:
            continue
        duplicate = tf.NodeDef()
        duplicate.CopyFrom(original)
        self.new_graph.node.extend([duplicate])
    self.new_graph.node.extend([new_node])
    return self.new_graph
def strip_unused(input_graph, input_binary, output_graph, input_node_names,
                 output_node_names, placeholder_type_enum):
    """Removes unused nodes from a graph.

    Steps:
    - Read a GraphDef from ``input_graph`` (binary if ``input_binary``, else
      text format).
    - Replace each node named in the comma-separated ``input_node_names``
      with a Placeholder of type ``placeholder_type_enum``.
    - Keep only what the comma-separated ``output_node_names`` need.
    - Write the result in binary form to ``output_graph``.

    Returns:
      -1 if the input file is missing or no output names were given;
      otherwise None.
    """
    if not tf.gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1
    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    graph_def = tf.GraphDef()
    with tf.gfile.FastGFile(input_graph, "rb" if input_binary else "r") as f:
        contents = f.read()
    if input_binary:
        graph_def.ParseFromString(contents)
    else:
        text_format.Merge(contents, graph_def)

    # Turning the requested inputs into placeholders cuts them off from their
    # own upstream nodes, which extract_sub_graph() then drops as unused.
    replace_names = input_node_names.split(",")
    replaced_graph_def = tf.GraphDef()
    for node in graph_def.node:
        if node.name not in replace_names:
            replaced_graph_def.node.extend([copy.deepcopy(node)])
            continue
        placeholder = tf.NodeDef()
        placeholder.op = "Placeholder"
        placeholder.name = node.name
        placeholder.attr["dtype"].CopyFrom(
            tf.AttrValue(type=placeholder_type_enum))
        replaced_graph_def.node.extend([placeholder])

    pruned_graph_def = graph_util.extract_sub_graph(
        replaced_graph_def, output_node_names.split(","))
    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(pruned_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(pruned_graph_def.node))
def test_connect_to_shared_init_op(self):
    group_deps_name = 'group_deps'
    init_names = ['table_init_1', 'table_init_2']
    graph_def = tf.GraphDef()

    # No init ops: the graph must stay empty.
    meta_graph_editor._connect_to_shared_init_op(graph_def, group_deps_name, [])
    self.assertEqual(tf.GraphDef(), graph_def)

    # With init ops: one NoOp named group_deps with a control input per init op.
    expected = tf.GraphDef()
    group_node = tf.NodeDef(name=group_deps_name, op='NoOp')
    group_node.input.extend('^' + init_name for init_name in init_names)
    expected.node.extend([group_node])
    meta_graph_editor._connect_to_shared_init_op(
        graph_def, group_deps_name, list(init_names))
    self.assertEqual(expected, graph_def)
def quantize_nodes_recursively(self, current_node):
    """The entry point for quantizing nodes to eight bit and back.

    Steps:
    - Visit the inputs of ``current_node`` depth-first, skipping names
      already recorded in ``self.already_visited``.
    - Mark ``current_node`` as visited.
    - If its op is exactly one of Conv2D, BiasAdd or MatMul, pass each input
      node and then the node itself to ``self.quantize_node``.
    - Otherwise, copy the node unchanged into the output graph.
    """
    for input_name in current_node.input:
        input_name = node_name_from_input(input_name)
        if input_name in self.already_visited:
            continue
        self.quantize_nodes_recursively(self.nodes_map[input_name])
    self.already_visited[current_node.name] = True

    # Exact membership. The former ``current_node.op in s`` was a substring
    # test, so ops such as "Add" (in "BiasAdd"), "Mul" (in "MatMul"), "Conv"
    # or "" were wrongly quantized as well.
    nodes_to_quantize = ("Conv2D", "BiasAdd", "MatMul")
    if current_node.op in nodes_to_quantize:
        for input_name in current_node.input:
            self.quantize_node(self.nodes_map[node_name_from_input(input_name)])
        self.quantize_node(current_node)
    else:
        new_node = tf.NodeDef()
        new_node.CopyFrom(current_node)
        self.add_output_graph_node(new_node)
def to_node_def(self, target=None, add_shapes=True):
    # type: (tf.NodeDef, bool) -> tf.NodeDef
    """Serialize this node into a NodeDef proto.

    Args:
      target: optional preallocated, empty NodeDef object to fill in. If not
        provided, a new `tf.NodeDef` is allocated.
      add_shapes: If True, add the special "_output_shapes" attribute with
        output shape information from this Node's output metadata.

    Returns:
      A copy of the contents of this node as a NodeDef proto. The returned
      proto will *not* change if this node is changed after the call, and
      vice versa.
    """
    node_def = tf.NodeDef() if target is None else target
    node_def.name = self.name
    node_def.op = self.op_type
    node_def.input.extend([tensor.name for tensor in self.inputs])
    node_def.input.extend(["^" + dep.name for dep in self.control_inputs])
    node_def.device = self.device
    for key, value in self._attributes:
        # Message-valued map entries must be filled via CopyFrom.
        node_def.attr[key].CopyFrom(util.python_type_to_attr_value(value))

    if len(self._colocation_groups) > 0:
        # Colocation groups are stored as prefixed names; see the
        # colocation_groups property docstring.
        prefixed = [_COLOCATION_PREFIX + group for group in self._colocation_groups]
        node_def.attr[_COLOCATION_ATTR_NAME].CopyFrom(
            util.python_type_to_attr_value(prefixed))

    has_outputs = self._outputs is not None and len(self._outputs) > 0
    if add_shapes and has_outputs:
        node_def.attr[_OUTPUT_SHAPES_ATTR_NAME].CopyFrom(
            util.python_type_to_attr_value([out.shape for out in self._outputs]))
    return node_def
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
    """Removes unused nodes from a GraphDef.

    Args:
      input_graph_def: A graph with nodes we want to prune.
      input_node_names: A list of the nodes we use as inputs.
      output_node_names: A list of the output nodes.
      placeholder_type_enum: The AttrValue enum for the placeholder data type,
        or a list that specifies one value per input node name.

    Returns:
      A GraphDef with all unnecessary ops removed.
    """
    per_input_types = isinstance(placeholder_type_enum, list)

    def _make_placeholder(node):
        # Placeholder with the same name, so consumers stay connected while
        # the original node's upstream subgraph becomes unreachable.
        placeholder = tf.NodeDef()
        placeholder.op = "Placeholder"
        placeholder.name = node.name
        if per_input_types:
            dtype = placeholder_type_enum[input_node_names.index(node.name)]
        else:
            dtype = placeholder_type_enum
        placeholder.attr["dtype"].CopyFrom(tf.AttrValue(type=dtype))
        if "_output_shapes" in node.attr:
            placeholder.attr["_output_shapes"].CopyFrom(
                node.attr["_output_shapes"])
        return placeholder

    replaced_graph_def = tf.GraphDef()
    for node in input_graph_def.node:
        if node.name in input_node_names:
            replaced_graph_def.node.extend([_make_placeholder(node)])
        else:
            replaced_graph_def.node.extend([copy.deepcopy(node)])
    return graph_util.extract_sub_graph(replaced_graph_def, output_node_names)
def create_node(name, op=None, trt_plugin=False, **kwargs):
    """Create a free-standing TensorFlow NodeDef with the given properties.

    Args:
        name (str): The name of the node.
        op (str): The node's operation.
        trt_plugin (bool): Forwarded to ``update_node``. When falsy, a warning
            pointing at ``create_plugin_node`` is printed first.

    Keyword Args:
        dtype (tensorflow.DType): TensorFlow dtype.
        shape (tuple(int)): Iterable container (usually a tuple) describing
            the shape of a tensor.
        inputs (list(tensorflow.NodeDef) or str): Iterable container (usually
            a tuple) of input nodes or input node names. Supports mixed-type
            lists.
        **kwargs (AttrName=Value): Any additional fields that should be
            present in the node. Currently supports int, float, bool,
            list(int), list(float), str and NumPy arrays. NumPy arrays are
            inserted into the "value" attribute of the node, which is useful
            for creating constant nodes equivalent to those created by
            tensorflow.constant.

    Returns:
        tensorflow.NodeDef: the node populated by ``update_node``.
    """
    warning = ("WARNING: To create TensorRT plugin nodes, please use the "
               "`create_plugin_node` function instead.")
    if not trt_plugin:
        print(warning)
    blank_node = tf.NodeDef()
    return update_node(blank_node, name, op, trt_plugin, **kwargs)
def creat_maxpool_node(op_name, ksize, stride, padding=b'VALID', dtype=tf.float32):
    """Create a standalone ``MaxPool`` NodeDef (no inputs attached).

    :param op_name: name for the new node.
    :param ksize: iterable of ints stored in the ``ksize`` attribute.
    :param stride: iterable of ints stored in the ``strides`` attribute.
    :param padding: padding scheme as bytes, e.g. b'VALID'.
    :param dtype: TensorFlow dtype used for the ``T`` attribute.
    :return: the new ``tf.NodeDef``.
    """
    pool = tf.NodeDef(op='MaxPool', name=op_name)
    int_list = tf.AttrValue.ListValue
    attributes = (
        ("ksize", tf.AttrValue(list=int_list(i=ksize))),
        ("T", tf.AttrValue(type=dtype.as_datatype_enum)),
        ("strides", tf.AttrValue(list=int_list(i=stride))),
        ("padding", tf.AttrValue(s=padding)),
    )
    for key, attr_value in attributes:
        pool.attr[key].CopyFrom(attr_value)
    return pool
def delete_node(self, node_name):
    """Remove a node and splice its first input into its consumers.

    ----------------------------------------------
                    |------>op_1
    op----->remove_op------>op_2
                    |------>op_3
    ----------------------------------------------
          |------>op_1
    op----->op_2
          |------>op_3
    ----------------------------------------------

    Every input anywhere in ``self.graph_pb`` that equals the removed node's
    name is replaced by the removed node's first input. The node's reference
    count is set to 0, so the copy step (which appends nodes whose count is
    not below 1 to ``self.new_graph``) skips it.

    :param node_name: name of the node to remove.
    :return: ``self.new_graph``.
    """
    target = self.init_node(node_name)
    assert node_name in self.node_name, "This node isn't in graph"
    self.node_reference_count[node_name] = 0

    for consumer in self.graph_pb.node:
        for idx, input_name in enumerate(consumer.input):
            if input_name == target.name:
                consumer.input[idx] = target.input[0]

    for original in self.graph_pb.node:
        if self.node_reference_count[original.name] < 1:
            continue
        duplicate = tf.NodeDef()
        duplicate.CopyFrom(original)
        self.new_graph.node.extend([duplicate])
    return self.new_graph
def strip_pruning_vars_fn(input_graph_def, output_node_names):
    """Removes mask variable from the graph.

    Every node whose name contains 'masked_weight' is replaced by a Const
    holding the precomputed element-wise product of mask and weight (from
    ``_get_masked_weights``), typed by the node's ``T`` attribute. The graph
    is then pruned to what ``output_node_names`` need, which drops the now
    stranded mask and weight nodes.

    Args:
      input_graph_def: A GraphDef in which the variables have been converted
        to constants. This is typically the output of
        tf.graph_util.convert_variables_to_constant()
      output_node_names: List of name strings for the result nodes of the
        graph

    Returns:
      A GraphDef in which pruning-related variables have been removed
    """
    masked_weights = _get_masked_weights(input_graph_def)

    def _to_const(node):
        const = tf.NodeDef()
        const.op = 'Const'
        const.name = node.name
        const.attr['dtype'].CopyFrom(node.attr['T'])
        const.attr['value'].CopyFrom(
            tf.AttrValue(tensor=tf.make_tensor_proto(masked_weights[node.name])))
        return const

    pruned_graph_def = tf.GraphDef()
    for node in input_graph_def.node:
        if 'masked_weight' in node.name:
            replacement = _to_const(node)
        else:
            replacement = tf.NodeDef()
            replacement.CopyFrom(node)
        pruned_graph_def.node.extend([replacement])

    return tf.graph_util.extract_sub_graph(pruned_graph_def, output_node_names)
def test_get_deps_for_graph_node(self):
    edges = (
        ('foo', []),
        ('bar', ['foo:0']),
        ('baz', ['foo:1', 'bar']),
        ('bak', ['bar', '^abc']),
        ('abc', []),
        ('def', ['abc:0']),
        ('ghi', ['^def']),
    )
    graph_def = tf.compat.v1.GraphDef(
        node=[tf.NodeDef(name=name, input=inputs) for name, inputs in edges])

    def _get_deps(x):
        deps = graph_utils.get_deps_for_graph_node(graph_def, x)
        return ','.join(sorted(deps))

    # Expected transitive dependencies, sorted and comma-joined.
    expectations = {
        'foo': '',
        'bar': 'foo',
        'baz': 'bar,foo',
        'bak': 'abc,bar,foo',
        'abc': '',
        'def': 'abc',
        'ghi': 'abc,def',
    }
    for node_name, expected in expectations.items():
        self.assertEqual(_get_deps(node_name), expected)
def create_const_for_anchor_generator():
    """Create the 'const_for_anchors' Const node fed to 'MultipleGridAnchorGenerator'.

    The 'MultipleGridAnchorGenerator' TRT plugin node requires a [1.0, 1.0]
    float32 array as input.
    Reference: https://stackoverflow.com/a/56296195/7596504
    """
    import numpy as np
    import tensorflow as tf
    from tensorflow.core.framework.tensor_pb2 import TensorProto
    from tensorflow.core.framework.tensor_shape_pb2 import TensorShapeProto

    anchors = np.array([1.0, 1.0], dtype=np.float32)
    dtype_enum = tf.as_dtype(anchors.dtype).as_datatype_enum
    dims = [TensorShapeProto.Dim(size=extent) for extent in anchors.shape]
    tensor = TensorProto(
        dtype=dtype_enum,
        tensor_shape=TensorShapeProto(dim=dims),
        tensor_content=anchors.tobytes())
    attributes = {
        'value': tf.AttrValue(tensor=tensor),
        'dtype': tf.AttrValue(type=dtype_enum),
    }
    return tf.NodeDef(name='const_for_anchors', op='Const', attr=attributes)
def creat_const_node(op_name, arrary=None, dtype=tf.float32, shape=None):
    """Create a ``Const`` NodeDef holding ``arrary``.

    :param op_name: name for the new node.
    :param arrary: array-like value to store.
    :param dtype: TensorFlow dtype of the node and its tensor.
    :param shape: expected shape of ``arrary`` as a list or tuple of ints.
        If None, the shape of ``arrary`` is used.
    :return: the new ``tf.NodeDef``, including an ``_output_shapes`` attribute.
    :raises AssertionError: if ``shape`` does not match the shape of ``arrary``.
    """
    actual_shape = list(np.shape(arrary))
    # Compare as lists so tuple shapes (e.g. ``arr.shape``) are accepted.
    # Previously a tuple, or the default None, always failed this check.
    shape = actual_shape if shape is None else list(shape)
    assert actual_shape == shape, "Please check the value"

    new_node = tf.NodeDef()
    new_node.op = 'Const'
    new_node.name = op_name
    new_node.attr['dtype'].CopyFrom(
        tf.AttrValue(type=dtype.as_datatype_enum))
    new_node.attr['value'].CopyFrom(
        tf.AttrValue(tensor=tf.make_tensor_proto(arrary, dtype, shape)))
    new_node.attr['_output_shapes'].CopyFrom(
        tf.AttrValue(list=tf.AttrValue.ListValue(shape=[
            tensor_shape_pb2.TensorShapeProto(dim=[
                tensor_shape_pb2.TensorShapeProto.Dim(size=x) for x in shape
            ])
        ])))
    return new_node
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
                 output_node_names, restore_op_name, filename_tensor_name,
                 output_graph, clear_devices):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
      input_graph: Path to the GraphDef file (binary or text, see input_binary).
      input_saver: Optional path to a SaverDef file. When empty, the
        checkpoint is restored by running ``restore_op_name`` instead.
      input_binary: True if input_graph / input_saver are binary protobufs,
        False for text format.
      input_checkpoint: Checkpoint path holding the variable values.
      output_node_names: Comma-separated names of the nodes to keep.
      restore_op_name: Op run to restore variables when no saver is given.
      filename_tensor_name: Tensor fed with input_checkpoint for that op.
      output_graph: Path the frozen binary GraphDef is written to.
      clear_devices: If True, drop all device placements before import.

    Returns:
      -1 if an input file is missing or no output node names were given;
      otherwise None.
    """
    if not gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return -1
    if input_saver and not gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1
    if not gfile.Exists(input_checkpoint):
        print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
        return -1
    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    input_graph_def = tf.GraphDef()
    with open(input_graph, "rb") as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read(), input_graph_def)

    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ""
    _ = tf.import_graph_def(input_graph_def, name="")

    with tf.Session() as sess:
        if input_saver:
            with open(input_saver, "rb") as f:
                saver_def = tf.train.SaverDef()
                if input_binary:
                    saver_def.ParseFromString(f.read())
                else:
                    text_format.Merge(f.read(), saver_def)
            saver = tf.train.Saver(saver_def=saver_def)
            saver.restore(sess, input_checkpoint)
        else:
            sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
        # Read the current value of every variable that has an Assign op.
        found_variables = {}
        for node in input_graph_def.node:
            if node.op == "Assign":
                variable_name = node.input[0]
                found_variables[variable_name] = sess.run(variable_name + ":0")

    # This graph only includes the nodes needed to evaluate the output nodes,
    # and removes unneeded nodes like those involved in saving and assignment.
    inference_graph = graph_util.extract_sub_graph(
        input_graph_def, output_node_names.split(","))

    output_graph_def = tf.GraphDef()
    how_many_converted = 0
    for input_node in inference_graph.node:
        output_node = tf.NodeDef()
        if input_node.name in found_variables:
            output_node.op = "Const"
            output_node.name = input_node.name
            dtype = input_node.attr["dtype"]
            data = found_variables[input_node.name]
            set_attr_dtype(output_node, "dtype", dtype)
            set_attr_tensor(output_node, "value", data, dtype.type, data.shape)
            how_many_converted += 1
        else:
            output_node.CopyFrom(input_node)
        output_graph_def.node.extend([output_node])

    # SerializeToString() returns bytes, so open the file in binary mode
    # (the previous "w" text mode was wrong for binary protobuf output).
    with gfile.FastGFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("Converted %d variables to const ops." % how_many_converted)
    print("%d ops in the final graph." % len(output_graph_def.node))
def copy_op_to_graph(org_instance, to_graph, variables, scope=''):
    """Returns a copy of an operation from another Graph under a specified scope.

    Given an `Operation` `org_instance` from one `Graph`, initializes and
    returns a copy of it from another `Graph`, under the specified scope
    (default `""`).

    The copying is done recursively, so any `Operation` whose output is
    required to evaluate the `org_instance` is also copied (unless already
    done).

    Since `Variable` instances are copied separately, those required to
    evaluate `org_instance` must be provided as input.

    Args:
      org_instance: An `Operation` from some `Graph`. Could be a `Placeholder`
        as well.
      to_graph: The `Graph` to copy `org_instance` to.
      variables: An iterable of `Variable` instances to copy `org_instance` to.
      scope: A scope for the new `Variable` (default `""`).

    Returns:
      The copied `Operation` from `to_graph`.

    Raises:
      TypeError: If `org_instance` is not an `Operation` or `Tensor`.
    """
    # The name of the new instance
    if scope != '':
        new_name = scope + '/' + org_instance.name
    else:
        new_name = org_instance.name

    # If a variable by the new name already exists, return the corresponding
    # tensor that will act as an input.
    copied_variables = {x.name: x for x in variables}
    if new_name in copied_variables:
        return to_graph.get_tensor_by_name(copied_variables[new_name].name)

    # If an instance of the same name exists, return it. as_graph_element
    # reports "missing / not a valid element name" with KeyError, ValueError
    # or TypeError. The former bare ``except`` also swallowed
    # KeyboardInterrupt and genuine bugs.
    try:
        return to_graph.as_graph_element(new_name, allow_tensor=True,
                                         allow_operation=True)
    except (KeyError, ValueError, TypeError):
        pass

    # Collections the new instance must be added to, placed under the scope.
    collections = [
        name if scope == '' else scope + '/' + name
        for name, collection in org_instance.graph._collections.items()
        if org_instance in collection
    ]

    if isinstance(org_instance, ops.Tensor):
        # A Tensor is an output of its op: copy the op, then return the
        # matching output.
        op = org_instance.op
        new_op = copy_op_to_graph(op, to_graph, variables, scope)
        output_index = op.outputs.index(org_instance)
        new_tensor = new_op.outputs[output_index]
        for collection in collections:
            to_graph.add_to_collection(collection, new_tensor)
        return new_tensor

    if isinstance(org_instance, ops.Operation):
        op = org_instance

        # If it has an original_op parameter, copy it
        if op._original_op is not None:
            new_original_op = copy_op_to_graph(op._original_op, to_graph,
                                               variables, scope)
        else:
            new_original_op = None

        # Recursively copy control inputs and data inputs.
        new_control_inputs = [
            copy_op_to_graph(x, to_graph, variables, scope)
            for x in op.control_inputs
        ]
        new_inputs = [
            copy_op_to_graph(x, to_graph, variables, scope) for x in op.inputs
        ]

        # New NodeDef based on the original. Colocation info ('_class') is
        # cleared rather than copied.
        new_attr = {
            key: op.node_def.attr[key]
            for key in op.node_def.attr
            if key != '_class'
        }
        new_node_def = tf.NodeDef(name=new_name,
                                  op=op.node_def.op,
                                  input=op.node_def.input,
                                  device=op.node_def.device,
                                  attr=new_attr)

        # Copy the other inputs needed for initialization
        output_types = op._output_types[:]
        input_types = op._input_types[:]

        # Make a copy of the op_def too. It's unique to every _type_ of Operation.
        op_def = deepcopy(op.op_def)

        # Initialize a new Operation instance
        new_op = ops.Operation(new_node_def, to_graph, new_inputs, output_types,
                               new_control_inputs, input_types, new_original_op,
                               op_def)
        # Use Graph's hidden methods to add the op
        to_graph._add_op(new_op)  # pylint: disable=protected-access
        to_graph._record_op_seen_by_control_dependencies(new_op)
        for device_function in reversed(to_graph._device_function_stack):
            new_op._set_device(device_function(new_op))
        return new_op

    raise TypeError('Could not copy instance: ' + str(org_instance))
def remove_unneeded_nodes(self, input_graph):
    """Prunes out nodes that aren't needed for inference.

    There are nodes like Identity and CheckNumerics that are only useful
    during training, and can be removed in graphs that will be used for
    nothing but inference. Here we identify and remove them, returning an
    equivalent graph.

    Two passes are made:
      1. Every CheckNumerics node is dropped, together with every input
         reference (data or control) that pointed at it.
      2. Every Identity node that has at least one input and no control-edge
         inputs is spliced out: its consumers are rewired to the Identity's
         first input. Chains of spliced Identity nodes are followed to the
         first node that is kept, and a control edge ("^name") stays a
         control edge after rewiring.

    Args:
      input_graph: GraphDef of the model to analyze and prune. It is not
        modified.

    Returns:
      A new GraphDef containing copies of the remaining nodes.
    """
    types_to_remove = {"CheckNumerics"}
    types_to_splice = {"Identity"}

    def _node_name(full_input_name):
        # Strip the control-edge marker so the name can be compared with
        # NodeDef.name values.
        if full_input_name.startswith("^"):
            return full_input_name[1:]
        return full_input_name

    input_nodes = input_graph.node

    # Pass 1: drop CheckNumerics nodes and every reference to them.
    names_to_remove = {node.name for node in input_nodes
                       if node.op in types_to_remove}
    nodes_after_removal = []
    for node in input_nodes:
        if node.name in names_to_remove:
            continue
        new_node = tf.NodeDef()
        new_node.CopyFrom(node)
        del new_node.input[:]
        # NOTE(review): a *data* edge coming from a removed CheckNumerics node
        # is dropped rather than rewired to that node's input; this assumes
        # CheckNumerics outputs are only consumed through control edges.
        new_node.input.extend([
            full_input_name for full_input_name in node.input
            if _node_name(full_input_name) not in names_to_remove
        ])
        nodes_after_removal.append(new_node)

    # Pass 2: find Identity nodes that can be spliced out, mapping each one
    # to the input its consumers should read from instead.
    names_to_splice = {}
    for node in nodes_after_removal:
        if node.op not in types_to_splice or not node.input:
            continue
        # We don't want to remove nodes that have control edge inputs, because
        # they might be involved in subtle dependency issues that removing them
        # will jeopardize.
        if any(input_name.startswith("^") for input_name in node.input):
            continue
        names_to_splice[node.name] = node.input[0]

    def _rewire(full_input_name):
        """Returns the input reference with spliced Identity nodes bypassed."""
        target = _node_name(full_input_name)
        if target not in names_to_splice:
            return full_input_name
        # Follow Identity -> Identity chains; otherwise a consumer could end up
        # pointing at another Identity that is also being removed. `visited`
        # only guards against looping forever on a malformed cyclic graph.
        visited = set()
        while target in names_to_splice and target not in visited:
            visited.add(target)
            target = names_to_splice[target]
        if full_input_name.startswith("^"):
            # Keep control edges as control edges. They name nodes rather than
            # tensors, so any ":port" suffix on the new target is dropped.
            return "^" + target.split(":")[0]
        return target

    nodes_after_splicing = []
    for node in nodes_after_removal:
        if node.name in names_to_splice:
            continue
        new_node = tf.NodeDef()
        new_node.CopyFrom(node)
        del new_node.input[:]
        new_node.input.extend([_rewire(input_name) for input_name in node.input])
        nodes_after_splicing.append(new_node)

    output_graph = tf.GraphDef()
    output_graph.node.extend(nodes_after_splicing)
    return output_graph
def _fix_batch_norm_nodes(graph_def):
    """Rewrites ref-typed batch-norm update ops in graph_def, in place, so it can be frozen.

    RefSwitch/AssignSub/AssignAdd operate on variable (ref) inputs. Freezing
    turns those variables into Const nodes, so the ops are swapped for their
    value-typed counterparts (Switch/Sub/Add). RefSwitch inputs that refer to
    batch-norm moving statistics ('moving_' in the name) are pointed at the
    variables' '/read' tensors instead.
    """
    value_op_for = {'AssignSub': 'Sub', 'AssignAdd': 'Add'}
    for node in graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # Iterate over a snapshot because entries are reassigned in place.
            for index, input_name in enumerate(list(node.input)):
                if 'moving_' in input_name:
                    node.input[index] = input_name + '/read'
        elif node.op in value_op_for:
            node.op = value_op_for[node.op]
            # Sub/Add have no 'use_locking' attr; leaving it would make the
            # NodeDef invalid for the new op.
            if 'use_locking' in node.attr:
                del node.attr['use_locking']


def _replace_inputs_with_placeholders(graph_def, input_names):
    """Returns a copy of graph_def in which every node named in input_names is a float32 Placeholder.

    Only the node list is copied; all other nodes are copied unchanged.
    """
    input_names = set(input_names)
    result = tf.GraphDef()
    for node in graph_def.node:
        if node.name in input_names:
            placeholder_node = result.node.add()
            placeholder_node.op = 'Placeholder'
            placeholder_node.name = node.name
            placeholder_node.attr['dtype'].CopyFrom(
                tf.AttrValue(type=tf.float32.as_datatype_enum))
            print('%s is replaced with placeholder' % node.name)
        else:
            result.node.add().CopyFrom(node)
    return result


def main(_):
    """Restores the pix2pix model and writes a frozen inference graph to 'export_model.pb'.

    The exported graph is:
      * frozen, with variables converted to constants;
      * reduced to the sub-graph that computes 'generator/Tanh';
      * fed through 'real_A_and_B_images', which is replaced by a float32
        Placeholder.
    """
    output_node_names = ['generator/Tanh']
    input_node_names = ['real_A_and_B_images']

    print("Pix2pix tensorflow Exporter!")
    for directory in (args.checkpoint_dir, args.sample_dir, args.test_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    with tf.Session() as sess:
        model = pix2pix(sess, image_size=args.fine_size,
                        batch_size=args.batch_size,
                        output_size=args.fine_size,
                        dataset_name=args.dataset_name,
                        checkpoint_dir=args.checkpoint_dir,
                        sample_dir=args.sample_dir,
                        input_c_dim=args.input_nc,
                        output_c_dim=args.output_nc,
                        direction=args.which_direction)
        model.load_model(args)

        input_graph_def = tf.get_default_graph().as_graph_def()
        _fix_batch_norm_nodes(input_graph_def)

        # Freeze: replace variables with constants holding their restored values.
        freeze_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_node_names)

        # Replace the input-producing nodes with Placeholders, then keep only
        # what the outputs depend on.
        input_replaced_graph_def = _replace_inputs_with_placeholders(
            freeze_graph_def, input_node_names)
        output_sub_graph_def = graph_util.extract_sub_graph(
            input_replaced_graph_def, output_node_names)

        with tf.gfile.GFile('export_model.pb', 'wb') as f:
            f.write(output_sub_graph_def.SerializeToString())