# Module-level imports assumed by the functions below (matching the
# surrounding converter module). CLEARED_TENSOR_FIELDS, write_artifacts,
# validate and _run_grappler are defined elsewhere in this file.
import numpy as np
import tensorflow as tf
from google.protobuf.json_format import MessageToDict
from tensorflow.core.protobuf import config_pb2

from tensorflowjs.converters import fold_batch_norms
from tensorflowjs.converters import fuse_prelu


def extract_weights(graph_def,
                    output_graph,
                    tf_version,
                    signature_def,
                    quantization_dtype=None):
  """Takes a Python GraphDef object and extracts the weights.

  Args:
    graph_def: tf.GraphDef TensorFlow GraphDef proto object, which represents
      the model topology.
    output_graph: The location of the output graph.
    tf_version: Tensorflow version of the input graph.
    signature_def: the SignatureDef of the inference graph.
    quantization_dtype: An optional numpy dtype to quantize weights to for
      compression. Only np.uint8 and np.uint16 are supported.
  """
  constants = [node for node in graph_def.node if node.op == 'Const']
  const_inputs = {}
  # Remove the conditional inputs for constants (restored after evaluation).
  for const in constants:
    const_inputs[const.name] = const.input[:]
    del const.input[:]

  print('Writing weight file ' + output_graph + '...')
  const_manifest = []

  graph = tf.Graph()
  fuse_prelu.register_prelu_func(graph)

  with tf.compat.v1.Session(graph=graph) as sess:
    tf.import_graph_def(graph_def, name='')
    for const in constants:
      tensor = graph.get_tensor_by_name(const.name + ':0')
      value = tensor.eval(session=sess)
      if not isinstance(value, np.ndarray):
        value = np.array(value)

      const_manifest.append({'name': const.name, 'data': value})

      # Restore the conditional inputs.
      const.input[:] = const_inputs[const.name]

      # Remove the binary array from the tensor and save it to the external
      # weight file.
      for field_name in CLEARED_TENSOR_FIELDS:
        const.attr["value"].tensor.ClearField(field_name)

  write_artifacts(MessageToDict(graph_def), [const_manifest], output_graph,
                  tf_version, signature_def,
                  quantization_dtype=quantization_dtype)
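
# A minimal usage sketch for extract_weights (not part of the original
# module). It assumes a frozen GraphDef loaded from disk and an empty
# SignatureDef; the file paths are hypothetical placeholders.
#
#   from tensorflow.core.framework import graph_pb2
#   from tensorflow.core.protobuf import meta_graph_pb2
#
#   graph_def = graph_pb2.GraphDef()
#   with open('/tmp/frozen_model.pb', 'rb') as f:  # hypothetical path
#     graph_def.ParseFromString(f.read())
#
#   signature = meta_graph_pb2.SignatureDef()  # empty signature for the sketch
#   extract_weights(graph_def, '/tmp/web_model/model.json', tf.__version__,
#                   signature, quantization_dtype=np.uint8)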
def optimize_graph(graph, output_node_names, output_graph, tf_version,
                   signature_def, quantization_dtype=None,
                   skip_op_check=False, strip_debug_ops=False):
  """Takes a Python Graph object and optimizes the graph.

  Args:
    graph: The frozen graph to optimize.
    output_node_names: List of output node names.
    output_graph: The location of the output graph.
    tf_version: Tensorflow version of the input graph.
    signature_def: the SignatureDef of the inference graph.
    quantization_dtype: An optional numpy dtype to quantize weights to for
      compression. Only np.uint8 and np.uint16 are supported.
    skip_op_check: Bool whether to skip the op check.
    strip_debug_ops: Bool whether to strip debug ops.
  """
  fuse_prelu.register_prelu_func(graph)

  # Add a collection 'train_op' so that Grappler knows the outputs.
  for output in output_node_names:
    graph.add_to_collection('train_op', graph.get_operation_by_name(output))

  graph_def = graph.as_graph_def()

  unsupported = validate(graph_def.node, skip_op_check, strip_debug_ops)
  if unsupported:
    raise ValueError('Unsupported Ops in the model before optimization\n' +
                     ', '.join(unsupported))

  # Because TF breaks the Prelu op into 6 ops, for performance we fuse
  # those ops back into a single Prelu.
  optimized_graph = fuse_prelu.fuse_ops_for_prelu(graph_def)

  # First pass of Grappler optimization; this is needed for batch norm
  # folding. The pass list is deliberately repeated so that each optimizer
  # runs twice.
  config = config_pb2.ConfigProto()
  rewriter_config = config.graph_options.rewrite_options
  rewriter_config.optimizers[:] = [
      'pruning', 'constfold', 'arithmetic', 'dependency', 'pruning',
      'constfold', 'arithmetic', 'dependency'
  ]
  if strip_debug_ops:
    rewriter_config.optimizers.insert(0, 'debug_stripper')
  optimized_graph = _run_grappler(config, optimized_graph, graph)

  # Batch norm folding.
  optimized_graph = fold_batch_norms.fold_batch_norms(optimized_graph)

  # Set the device to CPU for all Conv2D nodes, since the Grappler remap
  # optimizer only supports FusedConv2D on CPU.
  for node in optimized_graph.node:
    if node.op == 'Conv2D':
      node.device = '/device:CPU:0'

  # Rerun Grappler to fuse Conv2D.
  config.graph_options.rewrite_options.optimizers[:] = [
      'remap', 'constfold', 'arithmetic', 'dependency'
  ]
  optimized_graph = _run_grappler(config, optimized_graph, graph)

  # Since the Grappler remap optimizer does not support Prelu as the
  # activation function for the _FusedConv2D op, we fuse it manually here.
  optimized_graph = fuse_prelu.fuse_prelu_with_fused_conv2d(optimized_graph)

  unsupported = validate(optimized_graph.node, skip_op_check, strip_debug_ops)
  if unsupported:
    raise ValueError('Unsupported Ops in the model after optimization\n' +
                     ', '.join(unsupported))

  extract_weights(optimized_graph, output_graph, tf_version, signature_def,
                  quantization_dtype)
  return optimized_graph
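
# A minimal usage sketch for optimize_graph (not part of the original
# module), assuming a tf.Graph already holds the frozen model, a
# SignatureDef is available, and 'Identity' is the sole output node;
# names and paths are hypothetical placeholders.
#
#   frozen_graph = tf.Graph()
#   with frozen_graph.as_default():
#     tf.import_graph_def(graph_def, name='')
#
#   optimize_graph(frozen_graph, ['Identity'], '/tmp/web_model/model.json',
#                  tf.__version__, signature, skip_op_check=False,
#                  strip_debug_ops=True)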