def convert_variables_to_constants(sess, input_graph_def, output_node_names,
                                   variable_names_whitelist=None,
                                   variable_names_blacklist=None,
                                   use_fp16=False):
    """Replace every variable in a graph with a Const holding its current value.

    Works like ``tf.graph_util.convert_variables_to_constants`` but can
    additionally down-cast float32 weights and float-typed node attrs to
    float16 when ``use_fp16`` is True.

    :param sess: active tf.Session used to evaluate the variables
    :param input_graph_def: GraphDef holding the network definition
    :param output_node_names: list of node names; the sub-graph reachable
        from these is what gets frozen
    :param variable_names_whitelist: if given, only these variables are frozen
    :param variable_names_blacklist: if given, these variables are NOT frozen
    :param use_fp16: convert DT_FLOAT weights/attrs to DT_HALF
    :return: a new GraphDef with variables replaced by Const nodes
    """
    from tensorflow.python.framework.graph_util_impl import extract_sub_graph
    from tensorflow.core.framework import graph_pb2
    from tensorflow.core.framework import node_def_pb2
    from tensorflow.core.framework import attr_value_pb2
    from tensorflow.core.framework import types_pb2
    from tensorflow.python.framework import tensor_util

    def patch_dtype(input_node, field_name, output_node):
        # Rewrite a single dtype-valued attr from DT_FLOAT to DT_HALF.
        if use_fp16 and (field_name in input_node.attr) and (
                input_node.attr[field_name].type == types_pb2.DT_FLOAT):
            output_node.attr[field_name].CopyFrom(
                attr_value_pb2.AttrValue(type=types_pb2.DT_HALF))

    inference_graph = extract_sub_graph(input_graph_def, output_node_names)

    # Collect the variables to freeze and the tensor names to evaluate.
    variable_names = []
    variable_dict_names = []
    for node in inference_graph.node:
        if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
            variable_name = node.name
            if ((variable_names_whitelist is not None
                 and variable_name not in variable_names_whitelist)
                    or (variable_names_blacklist is not None
                        and variable_name in variable_names_blacklist)):
                continue
            variable_dict_names.append(variable_name)
            if node.op == "VarHandleOp":
                # Resource variables are read through a ReadVariableOp.
                variable_names.append(variable_name + "/Read/ReadVariableOp:0")
            else:
                variable_names.append(variable_name + ":0")

    if variable_names:
        returned_variables = sess.run(variable_names)
    else:
        returned_variables = []
    found_variables = dict(zip(variable_dict_names, returned_variables))

    output_graph_def = graph_pb2.GraphDef()
    how_many_converted = 0
    for input_node in inference_graph.node:
        output_node = node_def_pb2.NodeDef()
        if input_node.name in found_variables:
            # Variable -> Const holding its evaluated value.
            output_node.op = "Const"
            output_node.name = input_node.name
            dtype = input_node.attr["dtype"]
            data = found_variables[input_node.name]
            if use_fp16 and dtype.type == types_pb2.DT_FLOAT:
                # BUG FIX: the "dtype" attr was previously left unset in
                # this branch, yielding an invalid Const node — a Const
                # requires both "dtype" and "value" attrs (the else branch
                # below sets both).
                output_node.attr["dtype"].CopyFrom(
                    attr_value_pb2.AttrValue(type=types_pb2.DT_HALF))
                output_node.attr["value"].CopyFrom(
                    attr_value_pb2.AttrValue(
                        tensor=tensor_util.make_tensor_proto(
                            data.astype('float16'),
                            dtype=types_pb2.DT_HALF, shape=data.shape)))
            else:
                output_node.attr["dtype"].CopyFrom(dtype)
                output_node.attr["value"].CopyFrom(
                    attr_value_pb2.AttrValue(
                        tensor=tensor_util.make_tensor_proto(
                            data, dtype=dtype.type, shape=data.shape)))
            how_many_converted += 1
        elif input_node.op == "ReadVariableOp" and (
                input_node.input[0] in found_variables):
            # Reads of frozen resource variables become pass-through
            # Identity nodes wired to the new Const.
            output_node.op = "Identity"
            output_node.name = input_node.name
            output_node.input.extend([input_node.input[0]])
            output_node.attr["T"].CopyFrom(input_node.attr["dtype"])
            if "_class" in input_node.attr:
                output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
        else:
            # All remaining ops are copied verbatim, then (optionally) have
            # their float dtype attrs patched to half precision.
            output_node.CopyFrom(input_node)
            patch_dtype(input_node, 'dtype', output_node)
            patch_dtype(input_node, 'T', output_node)
            patch_dtype(input_node, 'DstT', output_node)
            patch_dtype(input_node, 'SrcT', output_node)
            patch_dtype(input_node, 'Tparams', output_node)

        if use_fp16 and ('value' in output_node.attr) and (
                output_node.attr['value'].tensor.dtype == types_pb2.DT_FLOAT):
            # Hard-coded float constants must be down-cast as well.
            # NOTE(review): only float_val[0] is converted, i.e. this
            # assumes such remaining float Consts are scalars — confirm for
            # graphs that carry non-scalar float constants outside the
            # variable set.
            output_node.attr['value'].CopyFrom(
                attr_value_pb2.AttrValue(
                    tensor=tensor_util.make_tensor_proto(
                        output_node.attr['value'].tensor.float_val[0],
                        dtype=types_pb2.DT_HALF)))

        output_graph_def.node.extend([output_node])

    output_graph_def.library.CopyFrom(inference_graph.library)
    return output_graph_def
def quantize_graph_def(graph_def, skip=None, output_nodes=None, rel_tol=None, only=None):
    """Compress large constant tensors of a GraphDef into a QuantizedGraph.

    :type graph_def: GraphDef
    :type skip: set|list
    :type output_nodes: list
    :type rel_tol: float
    :type only: str
    :return: QuantizedGraph
    """
    if output_nodes is not None and len(output_nodes) > 0:
        # restrict to the sub-graph feeding the requested outputs
        graph_def = extract_sub_graph(graph_def, output_nodes)

    kept = []
    quantized = []
    for node in graph_def.node:
        # explicitly excluded nodes pass through untouched
        if should_skip(node, skip):
            kept.append(node)
            continue

        # anything that is not a constant tensor passes through untouched
        try:
            tensor = MakeNdarray(node.attr['value'].tensor)  # type: np.ndarray
        except TypeError:
            kept.append(node)
            continue

        uniform = all_same_value(tensor, rel_tol)
        if uniform is not None:
            # all elements are (approximately) equal: collapse to a
            # one-element const. NOTE(review): node.name does not appear to
            # be forwarded to const_node — verify against its signature.
            kept.append(const_node(
                node.attr['dtype'].type,
                np.array([uniform], dtype=tensor.dtype),
                tensor.shape))
            continue

        if tensor.size < 4096:
            # too small to be worth quantizing
            kept.append(node)
            continue

        # large constant: swap in a placeholder and stash the data
        # separately as a quantized item
        placeholder = NodeDef()
        placeholder.name = node.name
        placeholder.op = 'Placeholder'
        placeholder.attr['dtype'].type = node.attr['dtype'].type
        placeholder.attr['shape'].shape.CopyFrom(
            as_shape(tensor.shape).as_proto())
        kept.append(placeholder)

        item = QuantizedItem()
        item.name = node.name
        item.dtype = node.attr['dtype'].type
        item.shape.extend(tensor.shape)
        print('quantize {}'.format(node.name))
        _fill(item, tensor, only=only)
        quantized.append(item)

    result = QuantizedGraph()
    result.graph.versions.CopyFrom(graph_def.versions)
    result.graph.library.CopyFrom(graph_def.library)
    result.graph.node.extend(kept)
    result.items.extend(quantized)
    return result