def auto_parallel(metagraph, model):
  from tensorflow.python.grappler import tf_optimizer
  rewriter_config = rewriter_config_pb2.RewriterConfig()
  rewriter_config.optimizers.append("autoparallel")
  rewriter_config.auto_parallel.enable = True
  rewriter_config.auto_parallel.num_replicas = FLAGS.num_gpus
  optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
  metagraph.graph_def.CopyFrom(optimized_graph)
  UpdateCollection(metagraph, model)

from tensorflow.core.framework import variable_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.framework.convert_to_constants import (
    disable_lower_using_switch_merge)
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.training.saver import export_meta_graph


def _run_inline_graph_optimization(func, lower_control_flow):
  """Apply function inline optimization to the graph.

  Returns the GraphDef after Grappler's function inlining optimization is
  applied. This optimization does not work on models with control flow.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control flow
      ops such as If and While. (default True)

  Returns:
    GraphDef
  """
  graph_def = func.graph.as_graph_def()
  if not lower_control_flow:
    graph_def = disable_lower_using_switch_merge(graph_def)

  # In some cases, a secondary implementation of the function (e.g. for GPU) is
  # written to the "api_implements" attribute (e.g. `tf.keras.layers.LSTM` in
  # TF2 produces a CuDNN-based RNN for GPU).
  # This function is supposed to inline all function calls, but
  # "api_implements" prevents this from happening. Removing the attribute
  # solves the problem. To learn more about "api_implements", see:
  #   tensorflow/core/grappler/optimizers/implementation_selector.h
  for function in graph_def.library.function:
    if "api_implements" in function.attr:
      del function.attr["api_implements"]

  meta_graph = export_meta_graph(graph_def=graph_def, graph=func.graph)

  # Clear the initializer_name for the variables collections, since they are
  # not needed after being saved to a SavedModel.
  for name in [
      "variables", "model_variables", "trainable_variables", "local_variables"
  ]:
    raw_list = []
    for raw in meta_graph.collection_def["variables"].bytes_list.value:
      variable = variable_pb2.VariableDef()
      variable.ParseFromString(raw)
      variable.ClearField("initializer_name")
      raw_list.append(variable.SerializeToString())
    meta_graph.collection_def[name].bytes_list.value[:] = raw_list

  # Add a collection 'train_op' so that Grappler knows the outputs.
  fetch_collection = meta_graph_pb2.CollectionDef()
  for array in func.inputs + func.outputs:
    fetch_collection.node_list.value.append(array.name)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

  # Initialize RewriterConfig with everything disabled except function
  # inlining.
  config = config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  rewrite_options.min_graph_nodes = -1  # do not skip small graphs
  rewrite_options.optimizers.append("function")

  return tf_optimizer.OptimizeGraph(config, meta_graph)

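# A hedged usage sketch for _run_inline_graph_optimization above (not part of
# the original sources): the tf.function `add_one` and its signature are
# invented for illustration, and a TF 2.x runtime is assumed.
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def add_one(x):
  return x + 1.0

concrete_func = add_one.get_concrete_function()
inlined_graph_def = _run_inline_graph_optimization(
    concrete_func, lower_control_flow=True)
# After the "function" optimizer runs, the call to add_one is inlined into
# the main graph.
print([n.name for n in inlined_graph_def.node])
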
def GetOptimizedGraph():
  mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
  config = config_pb2.ConfigProto()
  config.graph_options.rewrite_options.CopyFrom(
      rewriter_config_pb2.RewriterConfig(
          constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
          memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL))
  return tf_optimizer.OptimizeGraph(config, mg)

def optimize_graph_spec(graph_spec_obj):
  meta_graph_def = graph_spec_obj.to_meta_graph_def()
  config_proto = tf.compat.v1.ConfigProto()
  # TODO(b/154367032): Determine a set of optimizer configurations for TFF.
  optimized_graph_def = tf_optimizer.OptimizeGraph(config_proto,
                                                   meta_graph_def)
  return graph_spec.GraphSpec(
      optimized_graph_def,
      init_op=graph_spec_obj.init_op,
      in_names=graph_spec_obj.in_names,
      out_names=graph_spec_obj.out_names)

def optimize_graph(graph, output_node_names, output_graph, tf_version,
                   quantization_dtype=None, skip_op_check=False,
                   strip_debug_ops=False):
  """Takes a Python Graph object and optimizes the graph.

  Args:
    graph: The frozen graph to optimize.
    output_node_names: List of output node names.
    output_graph: The location of the output graph.
    tf_version: Tensorflow version of the input graph.
    quantization_dtype: An optional numpy dtype to quantize weights to for
      compression. Only np.uint8 and np.uint16 are supported.
    skip_op_check: Bool whether to skip the op check.
    strip_debug_ops: Bool whether to strip debug ops.
  """
  # Add a collection 'train_op' so that Grappler knows the outputs.
  for output in output_node_names:
    graph.add_to_collection('train_op', graph.get_operation_by_name(output))

  graph_def = graph.as_graph_def()
  unsupported = validate(graph_def.node, skip_op_check, strip_debug_ops)
  if unsupported:
    raise ValueError('Unsupported Ops in the model before optimization\n' +
                     ', '.join(unsupported))

  config = config_pb2.ConfigProto()
  rewriter_config = config.graph_options.rewrite_options
  rewriter_config.optimizers[:] = [
      'pruning', 'constfold', 'arithmetic', 'dependency', 'pruning', 'remap',
      'constfold', 'arithmetic', 'dependency'
  ]
  if strip_debug_ops:
    rewriter_config.optimizers.insert(0, 'debug_stripper')

  meta_graph = export_meta_graph(graph_def=graph_def, graph=graph)
  optimized_graph = tf_optimizer.OptimizeGraph(
      config, meta_graph, cluster=get_cluster())

  unsupported = validate(optimized_graph.node, skip_op_check, strip_debug_ops)
  if unsupported:
    raise ValueError('Unsupported Ops in the model after optimization\n' +
                     ', '.join(unsupported))

  extract_weights(optimized_graph, output_graph, tf_version,
                  quantization_dtype)
  return optimized_graph

def testGradient(self):
  meta_graph = _simple_metagraph()
  rewrite_options = rewriter_config_pb2.RewriterConfig(
      layout_optimizer=rewriter_config_pb2.RewriterConfig.ON)
  optimized_graph = tf_optimizer.OptimizeGraph(
      rewrite_options, meta_graph, cluster=_get_cluster())

  found = 0
  for node in optimized_graph.node:
    if node.op in ['Conv2D', 'Conv2DBackpropFilter', 'Conv2DBackpropInput']:
      found += 1
      self.assertEqual(node.attr['data_format'].s, b'NCHW')
  self.assertEqual(found, 5)

def _inline_functions(self, graph_def, arrays):
  meta_graph = export_meta_graph(graph_def=graph_def)
  fetch_collection = meta_graph_pb2.CollectionDef()
  for name in arrays:
    fetch_collection.node_list.value.append(name)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

  # Initialize RewriterConfig with everything disabled except function
  # inlining.
  config = tf.compat.v1.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  rewrite_options.optimizers.append("function")

  return tf_optimizer.OptimizeGraph(config, meta_graph)

def main(_):
  with gfile.GFile(FLAGS.input) as input_file:
    metagraph = meta_graph_pb2.MetaGraphDef()
    metagraph.ParseFromString(input_file.read())

  if FLAGS.rewriter_config is not None:
    rewriter_config = rewriter_config_pb2.RewriterConfig()
    text_format.Merge(FLAGS.rewriter_config, rewriter_config)
    optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
    metagraph.graph_def.CopyFrom(optimized_graph)

  report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report)
  print(report)

def main(_):
  metagraph = get_metagraph()
  rewriter_config = rewriter_config_pb2.RewriterConfig()
  if FLAGS.rewriter_config is not None:
    text_format.Merge(FLAGS.rewriter_config, rewriter_config)
  optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
  metagraph.graph_def.CopyFrom(optimized_graph)

  report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report,
                                            FLAGS.verbose)
  print(report)
  if FLAGS.memory_report:
    report = cost_analyzer.GenerateMemoryReport(metagraph)
    print(report)

import copy

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.grappler import tf_optimizer


def _optimize_graph(meta_graph_def, signature_def):
  """Optimize `meta_graph_def` using grappler. Returns a `GraphDef`."""
  # We need to add a collection called 'train_op' so that grappler
  # knows what the outputs are.
  new_meta_graph_def = copy.deepcopy(meta_graph_def)
  fetch_collection = meta_graph_pb2.CollectionDef()
  for tensor_info in (list(signature_def.inputs.values()) +
                      list(signature_def.outputs.values())):
    fetch_collection.node_list.value.append(tensor_info.name)
  new_meta_graph_def.collection_def['train_op'].CopyFrom(fetch_collection)

  config = config_pb2.ConfigProto()
  return tf_optimizer.OptimizeGraph(config, new_meta_graph_def)

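# A hedged usage sketch for _optimize_graph above; the toy graph and signature
# are invented for illustration. Building inside a Graph context keeps the
# TF1-style signature utilities working under a TF 2.x runtime.
import tensorflow as tf

with tf.Graph().as_default() as g:
  x = tf.compat.v1.placeholder(tf.float32, shape=[None], name='x')
  y = tf.identity(x * 2.0, name='y')
  sig = tf.compat.v1.saved_model.predict_signature_def(
      inputs={'x': x}, outputs={'y': y})
  mg = tf.compat.v1.train.export_meta_graph(graph=g)

optimized = _optimize_graph(mg, sig)
print([n.name for n in optimized.node])
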
def _run_conversion(self):
  """Run Grappler's OptimizeGraph() tool to convert the graph."""
  # Create custom ConfigProto for Grappler.
  grappler_session_config = config_pb2.ConfigProto()
  grappler_session_config.CopyFrom(self._session_config)
  custom_rewriter_config = self.get_rewriter_config()
  grappler_session_config.graph_options.rewrite_options.CopyFrom(
      custom_rewriter_config)

  # Run Grappler.
  self._converted_graph_def = tf_optimizer.OptimizeGraph(
      grappler_session_config,
      self._grappler_meta_graph_def,
      graph_id=b"tf_graph")
  self._converted = True

def _run_conversion(self, meta_graph_def):
  """Run Grappler's OptimizeGraph() tool to convert the graph.

  Args:
    meta_graph_def: the MetaGraphDef instance to run the optimizations on.

  Returns:
    The optimized GraphDef.
  """
  rewriter_config = get_tensorrt_rewriter_config(
      conversion_params=self._conversion_params, is_v2=True)
  grappler_session_config = config_pb2.ConfigProto()
  grappler_session_config.graph_options.rewrite_options.CopyFrom(
      rewriter_config)
  return tf_optimizer.OptimizeGraph(
      grappler_session_config, meta_graph_def, graph_id=b"tf_graph")

def testDepthwise(self):
  meta_graph = _simple_metagraph(depthwise=True)
  rewrite_options = rewriter_config_pb2.RewriterConfig(
      layout_optimizer=rewriter_config_pb2.RewriterConfig.ON)
  optimized_graph = tf_optimizer.OptimizeGraph(
      rewrite_options, meta_graph, cluster=_get_cluster())

  found = 0
  for node in optimized_graph.node:
    if node.op in [
        'DepthwiseConv2dNative', 'DepthwiseConv2dNativeBackpropFilter',
        'DepthwiseConv2dNativeBackpropInput'
    ]:
      found += 1
      self.assertEqual(node.attr['data_format'].s, b'NCHW')
  self.assertEqual(found, 6)

def testHintDoesRewrite(self):
  graph = self._annotated_graph()[0]
  with graph.as_default():
    metagraph = train.export_meta_graph()
  self.assertEqual(
      0,
      len([node for node in metagraph.graph_def.node
           if 'Recomputed/' in node.name]))
  rewritten_graph_def = tf_optimizer.OptimizeGraph(
      rewriter_config_pb2.RewriterConfig(
          memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL),
      metagraph)
  self.assertEqual(
      9,
      len([node for node in rewritten_graph_def.node
           if 'Recomputed/' in node.name]))

import tensorflow as tf
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.grappler import tf_optimizer


def tf_optimize_grappler(input_names, output_names, graph_def):
  config = config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  config.graph_options.infer_shapes = True
  rewrite_options.optimizers[:] = [
      'pruning', 'constfold', 'arithmetic', 'dependency', 'function',
  ]

  meta_graph = tf.compat.v1.train.export_meta_graph(graph_def=graph_def)

  fetch_collection = meta_graph_pb2.CollectionDef()
  for t in input_names + output_names:
    fetch_collection.node_list.value.append(t)
  meta_graph.collection_def['train_op'].CopyFrom(fetch_collection)

  graph_def = tf_optimizer.OptimizeGraph(config, meta_graph)
  return graph_def

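# A hedged usage sketch for tf_optimize_grappler above; the toy graph and the
# tensor names 'x:0'/'y:0' are invented for illustration.
with tf.Graph().as_default() as g:
  x = tf.compat.v1.placeholder(tf.float32, shape=[2], name='x')
  c = tf.constant([1.0, 2.0]) + tf.constant([3.0, 4.0])  # foldable subgraph
  y = tf.add(x, c, name='y')

optimized = tf_optimize_grappler(['x:0'], ['y:0'], g.as_graph_def())
# 'constfold' should have collapsed the two source constants into one.
print([(n.name, n.op) for n in optimized.node])
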
def optimize(g, inputs, outputs):
  sd = SignatureDef()
  for name in inputs:
    input_t = g.get_operation_by_name(name).outputs[0]
    sd.inputs[name].name = name
    sd.inputs[name].dtype = input_t.dtype.as_datatype_enum
    sd.inputs[name].tensor_shape.CopyFrom(input_t.shape.as_proto())
  for name in outputs:
    output_t = g.get_operation_by_name(name).outputs[0]
    sd.outputs[name].name = name
    sd.outputs[name].dtype = output_t.dtype.as_datatype_enum
    sd.outputs[name].tensor_shape.CopyFrom(output_t.shape.as_proto())

  tf.compat.v1.enable_resource_variables()
  cl = cluster.Cluster(disable_detailed_stats=True)

  # We have to run this twice to eliminate constants that are left after
  # optimizing away split/pad/transpose nodes. They are const parameters such
  # as axis and perm, and they survive the first optimization pass because we
  # list them in the whitelist.
  for i in range(2):
    if i == 0:
      graph = g
      c = get_default_config()
    else:
      graph = get_graph_from(optimized_graph_def)
      c = get_only_prune_config()
    white_list = get_white_list(graph)
    for name in white_list:
      graph.add_to_collection(
          GraphKeys.TRAIN_OP, graph.get_operation_by_name(name))
    meta_graph = tf.compat.v1.train.export_meta_graph(
        graph_def=graph.as_graph_def(), graph=graph)
    meta_graph.signature_def["not_used_key"].CopyFrom(sd)
    optimized_graph_def = tf_optimizer.OptimizeGraph(
        config_proto=c, metagraph=meta_graph, cluster=cl)

  # Don't create VarHandleOp, ReadVariableOp, VarIsInitializedOp;
  # instead, create VariableV2 ops in the future.
  tf.compat.v1.disable_resource_variables()
  return optimized_graph_def

def testNoSwapping(self):
  """Make sure the graph is preserved when there is nothing to swap."""
  a = constant_op.constant(10, name='a')
  b = constant_op.constant(20, name='b')
  c = math_ops.add_n([a, b], name='c')
  d = math_ops.add_n([b, c], name='d')
  train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
  train_op.append(d)

  mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
  rewriter_config = rewriter_config_pb2.RewriterConfig(
      memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
  graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)

  self.assertEqual(len(graph.node), 4)
  self.assertItemsEqual([node.name for node in graph.node],
                        ['a', 'b', 'c', 'd'])

def _run_inline_graph_optimization(func, lower_control_flow):
  """Apply function inline optimization to the graph.

  Returns the GraphDef after Grappler's function inlining optimization is
  applied. This optimization does not work on models with control flow.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control flow
      ops such as If and While. (default True)

  Returns:
    GraphDef
  """
  graph_def = func.graph.as_graph_def()
  if not lower_control_flow:
    graph_def = disable_lower_using_switch_merge(graph_def)
  meta_graph = export_meta_graph(graph_def=graph_def, graph=func.graph)

  # Clear the initializer_name for the variables collections, since they are
  # not needed after being saved to a SavedModel.
  for name in [
      "variables", "model_variables", "trainable_variables", "local_variables"
  ]:
    raw_list = []
    for raw in meta_graph.collection_def["variables"].bytes_list.value:
      variable = variable_pb2.VariableDef()
      variable.ParseFromString(raw)
      variable.ClearField("initializer_name")
      raw_list.append(variable.SerializeToString())
    meta_graph.collection_def[name].bytes_list.value[:] = raw_list

  # Add a collection 'train_op' so that Grappler knows the outputs.
  fetch_collection = meta_graph_pb2.CollectionDef()
  for array in func.inputs + func.outputs:
    fetch_collection.node_list.value.append(array.name)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

  # Initialize RewriterConfig with everything disabled except function
  # inlining.
  config = config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  rewrite_options.min_graph_nodes = -1  # do not skip small graphs
  rewrite_options.optimizers.append("function")

  return tf_optimizer.OptimizeGraph(config, meta_graph)

def testBasic(self):
  """Make sure arguments can be passed correctly."""
  a = constant_op.constant(10, name='a')
  b = constant_op.constant(20, name='b')
  c = math_ops.add_n([a, b], name='c')
  d = math_ops.add_n([b, c], name='d')
  train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
  train_op.append(d)

  mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
  rewriter_config = rewriter_config_pb2.RewriterConfig()
  rewriter_config.optimizers.append('constfold')
  graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)

  self.assertEqual(len(graph.node), 5)
  self.assertItemsEqual([node.name for node in graph.node],
                        ['a', 'b', 'c', 'd', 'ConstantFolding/c'])

def optimize_graph(graph, output_graph, quantization_dtype=None):
  """Takes a Python Graph object and optimizes the graph.

  Args:
    graph: tf.Graph tensorflow dataflow graph.
    output_graph: The location of the output graph.
    quantization_dtype: An optional numpy dtype to quantize weights to for
      compression. Only np.uint8 and np.uint16 are supported.
  """
  rewriter_config = rewriter_config_pb2.RewriterConfig()
  rewriter_config.optimizers[:] = [
      'pruning', 'constfold', 'arithmetic', 'dependency', 'pruning',
      'constfold', 'arithmetic', 'dependency'
  ]
  meta_graph = tf.train.export_meta_graph(
      graph_def=graph.as_graph_def(), graph=graph)
  optimized_graph = tf_optimizer.OptimizeGraph(
      rewriter_config, meta_graph, cluster=get_cluster())

  extract_weights(optimized_graph, output_graph, quantization_dtype)
  return optimized_graph

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.training import saver


def constfold(graphdef, output_name):
  graph = ops.Graph()
  with graph.as_default():
    outputs = output_name.split(',')
    output_collection = meta_graph_pb2.CollectionDef()
    output_list = output_collection.node_list.value
    for output in outputs:
      output_list.append(output)

    importer.import_graph_def(graphdef, name="")
    metagraph = saver.export_meta_graph(
        graph_def=graph.as_graph_def(add_shapes=True), graph=graph)
    metagraph.collection_def["train_op"].CopyFrom(output_collection)

    rewriter_config = rewriter_config_pb2.RewriterConfig()
    rewriter_config.optimizers.extend(["constfold"])
    rewriter_config.meta_optimizer_iterations = (
        rewriter_config_pb2.RewriterConfig.ONE)
    session_config = config_pb2.ConfigProto()
    session_config.graph_options.rewrite_options.CopyFrom(rewriter_config)

    return tf_optimizer.OptimizeGraph(session_config, metagraph)

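# A hedged usage sketch for constfold above; the one-op graph is invented for
# illustration.
import tensorflow as tf

with tf.Graph().as_default() as g:
  tf.add(tf.constant(2.0), tf.constant(3.0), name='total')

folded = constfold(g.as_graph_def(), 'total')
# The add of two constants should now be a single Const node named 'total'.
print([(n.name, n.op) for n in folded.node])
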
def do_transformation(self):
  convert = False
  for node in self.model.node:
    if 'Conv' in node.op and \
        'data_format' in node.attr and \
        node.attr['data_format'].s == b'NCHW':
      convert = True
      break

  if convert:
    assert tf.version.VERSION >= '2.4.0', \
        'layout conversion is only supported by TensorFlow 2.4.0 and above'
    g = tf.Graph()
    with g.as_default():  # pylint: disable=not-context-manager
      tf.compat.v1.import_graph_def(self.model, name='')
      meta_graph = saver_lib.export_meta_graph(
          graph_def=self.model, graph=g, clear_devices=True)
      fetch_collection = meta_graph_pb2.CollectionDef()
      for fetch in self.outputs:
        fetch_collection.node_list.value.append(fetch)
      meta_graph.collection_def["train_op"].CopyFrom(  # pylint: disable=no-member
          fetch_collection)

    config = config_pb2.ConfigProto()
    convert = rewriter_config_pb2.RewriterConfig.NCHW_TO_NHWC  # pylint: disable=no-member
    config.graph_options.rewrite_options.CopyFrom(  # pylint: disable=no-member
        rewriter_config_pb2.RewriterConfig(
            disable_model_pruning=True,
            constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
            dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            memory_optimization=rewriter_config_pb2.RewriterConfig.NO_MEM_OPT,
            arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            shape_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            loop_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            function_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            remapping=rewriter_config_pb2.RewriterConfig.OFF,
            implementation_selector=rewriter_config_pb2.RewriterConfig.OFF,
            cpu_layout_conversion=convert))

    optimized_graph = tf_optimizer.OptimizeGraph(config, meta_graph)
    return optimized_graph
  else:
    return self.model

def testLoops(self):
  g = ops.Graph()
  with g.as_default():

    def _Cond(_, counter):
      return counter < end

    def _Body(buf, counter):
      buf = array_ops.concat([buf, [counter]], 0)
      counter += 1
      return [buf, counter]

    start = array_ops.placeholder(shape=[], dtype=dtypes.int32)
    end = array_ops.placeholder(shape=[], dtype=dtypes.int32)
    init_buf = array_ops.zeros(shape=[0], dtype=dtypes.int32)
    loop_vars = [init_buf, start]
    shape_inv = [
        tensor_shape.TensorShape([None]),
        tensor_shape.TensorShape([])
    ]
    buf, _ = control_flow_ops.while_loop(_Cond, _Body, loop_vars, shape_inv)

    f = -array_ops.ones_like(buf, optimize=False)  # pylint: disable=invalid-unary-operand-type
    buf_shape = array_ops.shape(buf)
    f_shape = array_ops.shape(f)
    ops.add_to_collection('train_op', buf_shape)
    ops.add_to_collection('train_op', f_shape)

  # Optimize the graph.
  mg = meta_graph.create_meta_graph_def(graph=g)
  config = config_pb2.ConfigProto()
  rewriter_config = config.graph_options.rewrite_options
  rewriter_config.min_graph_nodes = -1
  optimized_graph = tf_optimizer.OptimizeGraph(config, mg)
  mg.graph_def.CopyFrom(optimized_graph)

  # Check that the nodes referenced in various collections have been
  # preserved.
  item = gitem.Item(mg)
  props = item.GetOpProperties()
  buf_prop = props[buf.op.name]
  f_prop = props[f.op.name]
  self.assertEqual(buf_prop, f_prop)

def testNoSwapping(self):
  """Make sure the graph is preserved when there is nothing to swap."""
  a = variables.Variable(10, name='a')
  b = variables.Variable(20, name='b')
  c = math_ops.add_n([a, b], name='c')
  d = math_ops.add_n([b, c], name='d')
  train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
  train_op.append(d)

  mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
  graph_size = len(mg.graph_def.node)
  nodes = [node.name for node in mg.graph_def.node]

  rewriter_config = rewriter_config_pb2.RewriterConfig(
      disable_model_pruning=True,
      memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
  graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)

  self.assertEqual(len(graph.node), graph_size)
  self.assertItemsEqual([node.name for node in graph.node], nodes)

def _test_convert_variables_with_functions(self, inline_functions):
  """Freezes a graph with functions."""

  @function.Defun(dtypes.float32)
  def plus_one(x):
    return x + 1.0

  with ops.Graph().as_default():
    variable_node = variables.Variable(1.0, name="variable_node")
    _ = variables.Variable(1.0, name="unused_variable_node")
    defun_node = plus_one(variable_node)
    _ = math_ops_lib.multiply(defun_node, 2.0, name="output_node")

    with session.Session() as sess:
      self.evaluate(variables.variables_initializer([variable_node]))
      variable_graph_def = sess.graph.as_graph_def()

      if inline_functions:
        # Run Grappler to create the VarOpHandle --> Placeholder -->
        # ResourceVariable pattern.
        meta_graph = export_meta_graph(graph_def=variable_graph_def)
        fetch_collection = meta_graph_pb2.CollectionDef()
        for name in ["variable_node", "output_node"]:
          fetch_collection.node_list.value.append(name)
        meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

        # Initialize RewriterConfig with everything disabled except function
        # inlining.
        config = config_pb2.ConfigProto()
        rewrite_options = config.graph_options.rewrite_options
        rewrite_options.optimizers.append("function")
        variable_graph_def = tf_optimizer.OptimizeGraph(config, meta_graph)

      constant_graph_def = graph_util.convert_variables_to_constants(
          sess, variable_graph_def, ["output_node"])

      # Ensure there are no variables after freezing.
      for node in constant_graph_def.node:
        self.assertNotIn(
            node.op,
            ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])

from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.training.saver import (
    export_meta_graph as _export_meta_graph)


def run_graph_optimizations(graph_def, input_arrays, output_arrays, config,
                            graph=None):
  """Apply standard TensorFlow optimizations to the graph_def.

  Args:
    graph_def: Frozen GraphDef to be optimized.
    input_arrays: List of arrays that are considered inputs of the graph.
    output_arrays: List of arrays that are considered outputs of the graph.
    config: tf.ConfigProto.
    graph: TensorFlow Graph. Required when Eager mode is enabled. (default
      None)

  Returns:
    A new, optimized GraphDef.
  """
  meta_graph = _export_meta_graph(graph_def=graph_def, graph=graph)

  signature = _meta_graph_pb2.SignatureDef()
  for array in input_arrays:
    signature.inputs[array.name].name = array.name
    signature.inputs[array.name].dtype = array.dtype.as_datatype_enum
    signature.inputs[array.name].tensor_shape.CopyFrom(array.shape.as_proto())
  for array in output_arrays:
    signature.outputs[array.name].name = array.name
    signature.outputs[array.name].dtype = array.dtype.as_datatype_enum
    signature.outputs[array.name].tensor_shape.CopyFrom(
        array.shape.as_proto())
  meta_graph.signature_def["not_used_key"].CopyFrom(signature)

  # We need to add a collection called 'train_op' so that grappler
  # knows what the outputs are.
  fetch_collection = _meta_graph_pb2.CollectionDef()
  for array in input_arrays + output_arrays:
    fetch_collection.node_list.value.append(array.name)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

  return tf_optimizer.OptimizeGraph(config, meta_graph)

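# A hedged usage sketch for run_graph_optimizations above; the toy graph is
# invented. input_arrays/output_arrays are Tensor objects, so their .name,
# .dtype, and .shape fields feed the SignatureDef directly.
import tensorflow as tf
from tensorflow.core.protobuf import config_pb2

with tf.Graph().as_default() as g:
  x = tf.compat.v1.placeholder(tf.float32, shape=[1, 4], name='x')
  y = tf.identity(x * 2.0, name='y')

optimized = run_graph_optimizations(
    g.as_graph_def(), input_arrays=[x], output_arrays=[y],
    config=config_pb2.ConfigProto(), graph=g)
print([n.name for n in optimized.node])
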
def tf_optimize_grappler(input_names, output_names, graph_def,
                         fold_constant=None):
  from tensorflow.core.protobuf import config_pb2, meta_graph_pb2
  from tensorflow.python.grappler import tf_optimizer as tf_opt

  config = config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  config.graph_options.infer_shapes = True
  # TODO: if we turn on pruning, grappler removes some identities that the
  # tf-1.x lstm rewriter depends on, so for now don't turn this on.
  rewrite_options.optimizers[:] = [
      # 'pruning', 'constfold', 'arithmetic', 'dependency', 'function',
      'constfold', 'function'
  ]

  meta_graph = tf.compat.v1.train.export_meta_graph(graph_def=graph_def)

  fetch_collection = meta_graph_pb2.CollectionDef()
  for t in input_names + output_names:
    fetch_collection.node_list.value.append(t)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

  graph_def = tf_opt.OptimizeGraph(config, meta_graph)
  return graph_def

def testBasic(self):
  """Make sure arguments can be passed correctly."""
  a = constant_op.constant(10, name='a')
  b = constant_op.constant(20, name='b')
  c = math_ops.add_n([a, b], name='c')
  d = math_ops.add_n([b, c], name='d')
  train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
  # Being in the train_op collection causes 'd' to be added as a fetch node.
  train_op.append(d)

  mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
  config = config_pb2.ConfigProto()
  rewriter_config = config.graph_options.rewrite_options
  rewriter_config.optimizers.append('constfold')
  rewriter_config.min_graph_nodes = -1
  graph = tf_optimizer.OptimizeGraph(config, mg)

  self.assertEqual(len(graph.node), 1)
  self.assertItemsEqual([node.name for node in graph.node], ['d'])

def testSimpleSwap(self):
  """Check that the swap annotations are followed."""
  with ops.device('/gpu:0'):
    a = variables.VariableV1(10, name='a')
    b = variables.VariableV1(20, name='b')
    c = math_ops.add_n([a, b], name='c')
    d = math_ops.add_n([b, c], name='d')
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(d)

    d.op._set_attr('_swap_to_host', attr_value_pb2.AttrValue(i=0))

    mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
    graph_size = len(mg.graph_def.node)

    config = config_pb2.ConfigProto()
    config.graph_options.rewrite_options.CopyFrom(
        rewriter_config_pb2.RewriterConfig(
            disable_model_pruning=True,
            meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE,
            constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
            memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL,
            min_graph_nodes=-1))
    graph = tf_optimizer.OptimizeGraph(config, mg)

    self.assertEqual(len(graph.node), graph_size + 2)
    self.assertTrue(
        set([node.name for node in graph.node]) > set(
            ['a', 'b', 'c', 'd', 'swap_in_d_0', 'swap_out_d_0']))
    for node in graph.node:
      if node.name == 'swap_in_d_0':
        self.assertEqual('swap_out_d_0', node.input[0])
        self.assertEqual('^b/read', node.input[1])
      elif node.name == 'swap_out_d_0':
        self.assertEqual('b/read', node.input[0])
      elif node.name == 'd':
        self.assertEqual('swap_in_d_0', node.input[0])
        self.assertEqual('c', node.input[1])

def grappler_optimize(graph, fetches=None, rewriter_config=None):
  """Tries to optimize the provided graph using grappler.

  Args:
    graph: A `tf.Graph` instance containing the graph to optimize.
    fetches: An optional list of `Tensor`s to fetch (i.e. not optimize away).
      Grappler uses the 'train_op' collection to look for fetches, so if not
      provided this collection should be non-empty.
    rewriter_config: An optional `tf.RewriterConfig` to use when rewriting the
      graph.

  Returns:
    A `tf.GraphDef` containing the rewritten graph.
  """
  if rewriter_config is None:
    rewriter_config = rewriter_config_pb2.RewriterConfig()
  if fetches is not None:
    for fetch in fetches:
      graph.add_to_collection('train_op', fetch)
  metagraph = saver.export_meta_graph(graph_def=graph.as_graph_def())
  return tf_optimizer.OptimizeGraph(rewriter_config, metagraph)