Example 1
def auto_parallel(metagraph, model):
    # FLAGS and UpdateCollection are assumed to be defined elsewhere in the
    # original module; rewriter_config_pb2 comes from tensorflow.core.protobuf.
    from tensorflow.core.protobuf import rewriter_config_pb2
    from tensorflow.python.grappler import tf_optimizer
    rewriter_config = rewriter_config_pb2.RewriterConfig()
    rewriter_config.optimizers.append("autoparallel")
    rewriter_config.auto_parallel.enable = True
    rewriter_config.auto_parallel.num_replicas = FLAGS.num_gpus
    optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
    metagraph.graph_def.CopyFrom(optimized_graph)
    UpdateCollection(metagraph, model)
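A minimal invocation sketch (hypothetical: it assumes a `model` object and `FLAGS.num_gpus` are defined, and that `UpdateCollection` is the module's helper for fixing up the metagraph's collections):

from tensorflow.python.framework import meta_graph, ops

with ops.Graph().as_default() as g:
    # ... build the model here ...
    mg = meta_graph.create_meta_graph_def(graph=g)
auto_parallel(mg, model)  # replaces mg.graph_def with the replicated graph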
Example 2
def _run_inline_graph_optimization(func, lower_control_flow):
  """Apply function inline optimization to the graph.

  Returns the GraphDef after Grappler's function inlining optimization is
  applied. This optimization does not work on models with control flow.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control flow
      ops such as If and While. (default True)

  Returns:
    GraphDef
  """
  graph_def = func.graph.as_graph_def()
  if not lower_control_flow:
    graph_def = disable_lower_using_switch_merge(graph_def)

  # In some cases, a secondary implementation of the function (e.g. for GPU) is
  # written to the "api_implements" attribute (e.g. `tf.keras.layers.LSTM` in
  # TF2 produces a CuDNN-based RNN for GPU).
  # This function is supposed to inline all function calls, but "api_implements"
  # prevents this from happening. Removing the attribute solves the problem.
  # To learn more about "api_implements", see:
  #   tensorflow/core/grappler/optimizers/implementation_selector.h
  for function in graph_def.library.function:
    if "api_implements" in function.attr:
      del function.attr["api_implements"]

  meta_graph = export_meta_graph(graph_def=graph_def, graph=func.graph)

  # Clear the initializer_name for the variables collections, since it is not
  # needed after the model is saved to a SavedModel.
  for name in [
      "variables", "model_variables", "trainable_variables", "local_variables"
  ]:
    raw_list = []
    for raw in meta_graph.collection_def[name].bytes_list.value:
      variable = variable_pb2.VariableDef()
      variable.ParseFromString(raw)
      variable.ClearField("initializer_name")
      raw_list.append(variable.SerializeToString())
    meta_graph.collection_def[name].bytes_list.value[:] = raw_list

  # Add a collection 'train_op' so that Grappler knows the outputs.
  fetch_collection = meta_graph_pb2.CollectionDef()
  for array in func.inputs + func.outputs:
    fetch_collection.node_list.value.append(array.name)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

  # Initialize RewriterConfig with everything disabled except function inlining.
  config = config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  rewrite_options.min_graph_nodes = -1  # do not skip small graphs
  rewrite_options.optimizers.append("function")
  return tf_optimizer.OptimizeGraph(config, meta_graph)
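A hedged usage sketch for the helper above; the `tf.function` and its signature are made up for illustration:

import tensorflow as tf

@tf.function
def f(x):
    return tf.nn.relu(x) * 2.0

concrete = f.get_concrete_function(tf.TensorSpec([None], tf.float32))
# Returns a GraphDef with f's library functions inlined into the main graph.
inlined_graph_def = _run_inline_graph_optimization(
    concrete, lower_control_flow=True)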
Example 3
  def GetOptimizedGraph():
    mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
    config = config_pb2.ConfigProto()
    config.graph_options.rewrite_options.CopyFrom(
        rewriter_config_pb2.RewriterConfig(
            constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
            memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL))
    return tf_optimizer.OptimizeGraph(config, mg)
Example 4
def optimize_graph_spec(graph_spec_obj):
    meta_graph_def = graph_spec_obj.to_meta_graph_def()
    config_proto = tf.compat.v1.ConfigProto()
    # TODO(b/154367032): Determine a set of optimizer configurations for TFF.
    optimized_graph_def = tf_optimizer.OptimizeGraph(config_proto,
                                                     meta_graph_def)
    return graph_spec.GraphSpec(optimized_graph_def,
                                init_op=graph_spec_obj.init_op,
                                in_names=graph_spec_obj.in_names,
                                out_names=graph_spec_obj.out_names)
Example 5
def optimize_graph(graph,
                   output_node_names,
                   output_graph,
                   tf_version,
                   quantization_dtype=None,
                   skip_op_check=False,
                   strip_debug_ops=False):
    """Takes a Python Graph object and optimizes the graph.

  Args:
    graph: The frozen graph to optimize.
    output_node_names: List of output node names.
    output_graph: The location of the output graph.
    tf_version: Tensorflow version of the input graph.
    quantization_dtype: An optional numpy dtype to quantize weights to for
      compression. Only np.uint8 and np.uint16 are supported.
    skip_op_check: Bool whether to skip the op check.
    strip_debug_ops: Bool whether to strip debug ops.
  """

    # Add a collection 'train_op' so that Grappler knows the outputs.
    for output in output_node_names:
        graph.add_to_collection('train_op',
                                graph.get_operation_by_name(output))

    graph_def = graph.as_graph_def()
    unsupported = validate(graph_def.node, skip_op_check, strip_debug_ops)
    if unsupported:
        raise ValueError('Unsupported Ops in the model before optimization\n' +
                         ', '.join(unsupported))

    config = config_pb2.ConfigProto()
    rewriter_config = config.graph_options.rewrite_options
    rewriter_config.optimizers[:] = [
        'pruning', 'constfold', 'arithmetic', 'dependency', 'pruning', 'remap',
        'constfold', 'arithmetic', 'dependency'
    ]
    if strip_debug_ops:
        rewriter_config.optimizers.insert(0, 'debug_stripper')
    meta_graph = export_meta_graph(graph_def=graph_def, graph=graph)

    optimized_graph = tf_optimizer.OptimizeGraph(config,
                                                 meta_graph,
                                                 cluster=get_cluster())

    unsupported = validate(optimized_graph.node, skip_op_check,
                           strip_debug_ops)

    if unsupported:
        raise ValueError('Unsupported Ops in the model after optimization\n' +
                         ', '.join(unsupported))

    extract_weights(optimized_graph, output_graph, tf_version,
                    quantization_dtype)
    return optimized_graph
Example 6
  def testGradient(self):
    meta_graph = _simple_metagraph()
    rewrite_options = rewriter_config_pb2.RewriterConfig(
        layout_optimizer=rewriter_config_pb2.RewriterConfig.ON)
    optimized_graph = tf_optimizer.OptimizeGraph(
        rewrite_options, meta_graph, cluster=_get_cluster())

    found = 0
    for node in optimized_graph.node:
      if node.op in ['Conv2D', 'Conv2DBackpropFilter', 'Conv2DBackpropInput']:
        found += 1
        self.assertEqual(node.attr['data_format'].s, b'NCHW')
    self.assertEqual(found, 5)
Example 7
    def _inline_functions(self, graph_def, arrays):
        meta_graph = export_meta_graph(graph_def=graph_def)
        fetch_collection = meta_graph_pb2.CollectionDef()
        for name in arrays:
            fetch_collection.node_list.value.append(name)
        meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

        # Initialize RewriterConfig with everything disabled except function
        # inlining.
        config = tf.compat.v1.ConfigProto()
        rewrite_options = config.graph_options.rewrite_options
        rewrite_options.optimizers.append("function")
        return tf_optimizer.OptimizeGraph(config, meta_graph)
Example 8
def main(_):
  with gfile.GFile(FLAGS.input, 'rb') as input_file:
    metagraph = meta_graph_pb2.MetaGraphDef()
    metagraph.ParseFromString(input_file.read())

  if FLAGS.rewriter_config is not None:
    rewriter_config = rewriter_config_pb2.RewriterConfig()
    text_format.Merge(FLAGS.rewriter_config, rewriter_config)
    optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
    metagraph.graph_def.CopyFrom(optimized_graph)

  report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report)
  print(report)
Example 9
def main(_):
    metagraph = get_metagraph()
    rewriter_config = rewriter_config_pb2.RewriterConfig()
    if FLAGS.rewriter_config is not None:
        text_format.Merge(FLAGS.rewriter_config, rewriter_config)
    optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
    metagraph.graph_def.CopyFrom(optimized_graph)

    report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report,
                                              FLAGS.verbose)
    print(report)
    if FLAGS.memory_report:
        report = cost_analyzer.GenerateMemoryReport(metagraph)
        print(report)
Example 10
def _optimize_graph(meta_graph_def, signature_def):
    """Optimize `meta_graph_def` using grappler.  Returns a `GraphDef`."""
    # We need to add a collection called 'train_op' so that grappler
    # knows what the outputs are.
    new_meta_graph_def = copy.deepcopy(meta_graph_def)
    fetch_collection = meta_graph_pb2.CollectionDef()
    for tensor_info in (list(signature_def.inputs.values()) +
                        list(signature_def.outputs.values())):
        fetch_collection.node_list.value.append(tensor_info.name)

    new_meta_graph_def.collection_def['train_op'].CopyFrom(fetch_collection)

    config = config_pb2.ConfigProto()
    return tf_optimizer.OptimizeGraph(config, new_meta_graph_def)
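One hypothetical way to obtain the two arguments is the TF1 SavedModel loader; `export_dir` and the signature key are assumptions:

import tensorflow as tf

with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    # load() returns the MetaGraphDef that matches the given tags.
    mg = tf.compat.v1.saved_model.loader.load(sess, ['serve'], export_dir)
sig = mg.signature_def['serving_default']
optimized_graph_def = _optimize_graph(mg, sig)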
Example 11
  def _run_conversion(self):
    """Run Grappler's OptimizeGraph() tool to convert the graph."""
    # Create custom ConfigProto for Grappler.
    grappler_session_config = config_pb2.ConfigProto()
    grappler_session_config.CopyFrom(self._session_config)
    custom_rewriter_config = self.get_rewriter_config()
    grappler_session_config.graph_options.rewrite_options.CopyFrom(
        custom_rewriter_config)

    # Run Grappler.
    self._converted_graph_def = tf_optimizer.OptimizeGraph(
        grappler_session_config,
        self._grappler_meta_graph_def,
        graph_id=b"tf_graph")
    self._converted = True
Example 12
  def _run_conversion(self, meta_graph_def):
    """Run Grappler's OptimizeGraph() tool to convert the graph.

    Args:
      meta_graph_def: the MetaGraphDef instance to run the optimizations on.

    Returns:
      The optimized GraphDef.
    """
    rewriter_config = get_tensorrt_rewriter_config(
        conversion_params=self._conversion_params, is_v2=True)
    grappler_session_config = config_pb2.ConfigProto()
    grappler_session_config.graph_options.rewrite_options.CopyFrom(
        rewriter_config)
    return tf_optimizer.OptimizeGraph(
        grappler_session_config, meta_graph_def, graph_id=b"tf_graph")
Example 13
  def testDepthwise(self):
    meta_graph = _simple_metagraph(depthwise=True)
    rewrite_options = rewriter_config_pb2.RewriterConfig(
        layout_optimizer=rewriter_config_pb2.RewriterConfig.ON)
    optimized_graph = tf_optimizer.OptimizeGraph(
        rewrite_options, meta_graph, cluster=_get_cluster())

    found = 0
    for node in optimized_graph.node:
      if node.op in [
          'DepthwiseConv2dNative', 'DepthwiseConv2dNativeBackpropFilter',
          'DepthwiseConv2dNativeBackpropInput'
      ]:
        found += 1
        self.assertEqual(node.attr['data_format'].s, b'NCHW')
    self.assertEqual(found, 6)
Example 14
  def testHintDoesRewrite(self):
    graph = self._annotated_graph()[0]
    with graph.as_default():
      metagraph = train.export_meta_graph()
    self.assertEqual(
        0,
        len([node for node in metagraph.graph_def.node
             if 'Recomputed/' in node.name]))
    rewritten_graph_def = tf_optimizer.OptimizeGraph(
        rewriter_config_pb2.RewriterConfig(
            memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL),
        metagraph)
    self.assertEqual(
        9,
        len([node for node in rewritten_graph_def.node
             if 'Recomputed/' in node.name]))
Example 15
def tf_optimize_grappler(input_names, output_names, graph_def):
    config = config_pb2.ConfigProto()
    rewrite_options = config.graph_options.rewrite_options
    config.graph_options.infer_shapes = True
    rewrite_options.optimizers[:] = [
        'pruning', 'constfold', 'arithmetic', 'dependency', 'function',
    ]

    meta_graph = tf.compat.v1.train.export_meta_graph(graph_def=graph_def)

    fetch_collection = meta_graph_pb2.CollectionDef()
    for t in input_names + output_names:
        fetch_collection.node_list.value.append(t)
    meta_graph.collection_def['train_op'].CopyFrom(fetch_collection)

    graph_def = tf_optimizer.OptimizeGraph(config, meta_graph)
    return graph_def
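A brief usage sketch; the tensor names and `frozen_graph_def` are placeholders:

optimized_graph_def = tf_optimize_grappler(
    ['input:0'], ['output:0'], frozen_graph_def)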
Example 16
def optimize(g, inputs, outputs):
    sd = SignatureDef()
    for name in inputs:
        input_t = g.get_operation_by_name(name).outputs[0]
        sd.inputs[name].name = name
        sd.inputs[name].dtype = input_t.dtype.as_datatype_enum
        sd.inputs[name].tensor_shape.CopyFrom(input_t.shape.as_proto())
    for name in outputs:
        output_t = g.get_operation_by_name(name).outputs[0]
        sd.outputs[name].name = name
        sd.outputs[name].dtype = output_t.dtype.as_datatype_enum
        sd.outputs[name].tensor_shape.CopyFrom(output_t.shape.as_proto())

    tf.compat.v1.enable_resource_variables()
    cl = cluster.Cluster(disable_detailed_stats=True)

    # We have to run this twice to eliminate constants that are left over after
    # optimizing away split/pad/transpose nodes. These are const parameters
    # (e.g. axis, perm); they survive the first optimization pass because we
    # list them in the whitelist.
    for i in range(2):
        if i == 0:
            graph = g
            c = get_default_config()
        else:
            graph = get_graph_from(optimized_graph_def)
            c = get_only_prune_config()

        white_list = get_white_list(graph)
        for name in white_list:
            graph.add_to_collection(
                GraphKeys.TRAIN_OP, graph.get_operation_by_name(name)
            )

        meta_graph = tf.compat.v1.train.export_meta_graph(
            graph_def=graph.as_graph_def(), graph=graph
        )
        meta_graph.signature_def["not_used_key"].CopyFrom(sd)

        optimized_graph_def = tf_optimizer.OptimizeGraph(
            config_proto=c, metagraph=meta_graph, cluster=cl
        )
    # Restore the default so that future graphs create VariableV2 ops instead
    # of VarHandleOp/ReadVariableOp/VarIsInitializedOp.
    tf.compat.v1.disable_resource_variables()
    return optimized_graph_def
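A hypothetical call, assuming get_default_config, get_only_prune_config, get_white_list and get_graph_from are defined elsewhere in the same module:

optimized_graph_def = optimize(g, inputs=['input'], outputs=['output'])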
Example 17
  def testNoSwapping(self):
    """Make sure the graph is preserved when there is nothing to swap."""
    a = constant_op.constant(10, name='a')
    b = constant_op.constant(20, name='b')
    c = math_ops.add_n([a, b], name='c')
    d = math_ops.add_n([b, c], name='d')
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(d)
    mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())

    rewriter_config = rewriter_config_pb2.RewriterConfig(
        memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
    graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)

    self.assertEqual(len(graph.node), 4)
    self.assertItemsEqual([node.name
                           for node in graph.node], ['a', 'b', 'c', 'd'])
Example 18
def _run_inline_graph_optimization(func, lower_control_flow):
    """Apply function inline optimization to the graph.

    Returns the GraphDef after Grappler's function inlining optimization is
    applied. This optimization does not work on models with control flow.

    Args:
      func: ConcreteFunction.
      lower_control_flow: Boolean indicating whether or not to lower control
        flow ops such as If and While. (default True)

    Returns:
      GraphDef
    """
    graph_def = func.graph.as_graph_def()
    if not lower_control_flow:
        graph_def = disable_lower_using_switch_merge(graph_def)
    meta_graph = export_meta_graph(graph_def=graph_def, graph=func.graph)

    # Clear the initializer_name for the variables collections, since it is not
    # needed after the model is saved to a SavedModel.
    for name in [
            "variables", "model_variables", "trainable_variables",
            "local_variables"
    ]:
        raw_list = []
        for raw in meta_graph.collection_def[name].bytes_list.value:
            variable = variable_pb2.VariableDef()
            variable.ParseFromString(raw)
            variable.ClearField("initializer_name")
            raw_list.append(variable.SerializeToString())
        meta_graph.collection_def[name].bytes_list.value[:] = raw_list

    # Add a collection 'train_op' so that Grappler knows the outputs.
    fetch_collection = meta_graph_pb2.CollectionDef()
    for array in func.inputs + func.outputs:
        fetch_collection.node_list.value.append(array.name)
    meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

    # Initialize RewriterConfig with everything disabled except function inlining.
    config = config_pb2.ConfigProto()
    rewrite_options = config.graph_options.rewrite_options
    rewrite_options.min_graph_nodes = -1  # do not skip small graphs
    rewrite_options.optimizers.append("function")
    return tf_optimizer.OptimizeGraph(config, meta_graph)
Example 19
    def testBasic(self):
        """Make sure arguments can be passed correctly."""
        a = constant_op.constant(10, name='a')
        b = constant_op.constant(20, name='b')
        c = math_ops.add_n([a, b], name='c')
        d = math_ops.add_n([b, c], name='d')
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        train_op.append(d)
        mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())

        rewriter_config = rewriter_config_pb2.RewriterConfig()
        rewriter_config.optimizers.append('constfold')

        graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)

        self.assertEqual(len(graph.node), 5)
        self.assertItemsEqual([node.name for node in graph.node],
                              ['a', 'b', 'c', 'd', 'ConstantFolding/c'])
Example 20
def optimize_graph(graph, output_graph, quantization_dtype=None):
  """Takes a Python Graph object and optimizes the graph.

  Args:
    graph: tf.Graph TensorFlow dataflow graph.
    output_graph: The location of the output graph.
    quantization_dtype: An optional numpy dtype to quantize weights to for
      compression. Only np.uint8 and np.uint16 are supported.
  """
  rewriter_config = rewriter_config_pb2.RewriterConfig()
  rewriter_config.optimizers[:] = [
      'pruning', 'constfold', 'arithmetic', 'dependency', 'pruning',
      'constfold', 'arithmetic', 'dependency'
  ]
  meta_graph = tf.train.export_meta_graph(
      graph_def=graph.as_graph_def(), graph=graph)
  optimized_graph = tf_optimizer.OptimizeGraph(
      rewriter_config, meta_graph, cluster=get_cluster())

  extract_weights(optimized_graph, output_graph, quantization_dtype)
  return optimized_graph
Example 21
def constfold(graphdef, output_name):
    graph = ops.Graph()
    with graph.as_default():
        outputs = output_name.split(',')
        output_collection = meta_graph_pb2.CollectionDef()
        output_list = output_collection.node_list.value
        for output in outputs:
            output_list.append(output)
        importer.import_graph_def(graphdef, name="")
        metagraph = saver.export_meta_graph(
            graph_def=graph.as_graph_def(add_shapes=True), graph=graph)
        metagraph.collection_def["train_op"].CopyFrom(output_collection)

    rewriter_config = rewriter_config_pb2.RewriterConfig()
    rewriter_config.optimizers.extend(["constfold"])
    rewriter_config.meta_optimizer_iterations = (
        rewriter_config_pb2.RewriterConfig.ONE)
    session_config = config_pb2.ConfigProto()
    session_config.graph_options.rewrite_options.CopyFrom(rewriter_config)
    
    return tf_optimizer.OptimizeGraph(session_config, metagraph)
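A short usage sketch with a made-up frozen GraphDef and output names:

folded_graph_def = constfold(frozen_graph_def, 'logits,probabilities')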
Example 22
    def do_transformation(self):
        convert = False
        for node in self.model.node:
            if 'Conv' in node.op and \
               'data_format' in node.attr and \
               node.attr['data_format'].s == b'NCHW':
                convert = True
                break
        if convert:
            assert tf.version.VERSION >= '2.4.0', (
                'layout conversion is only supported by TensorFlow 2.4.0 '
                'and above')

            g = tf.Graph()
            with g.as_default(): # pylint: disable=not-context-manager
                # Note: import_graph_def() returns None here, so we must not
                # rebind `g` to its result; `g` stays the Graph we export from.
                tf.compat.v1.import_graph_def(self.model, name='')
                meta_graph = saver_lib.export_meta_graph(
                    graph_def=self.model, graph=g, clear_devices=True)
                fetch_collection = meta_graph_pb2.CollectionDef()
                for fetch in self.outputs:
                    fetch_collection.node_list.value.append(fetch) # pylint: disable=no-member
                meta_graph.collection_def["train_op"].CopyFrom( # pylint: disable=no-member
                                                    fetch_collection) # pylint: disable=no-member
                
            config = config_pb2.ConfigProto()
            convert = rewriter_config_pb2.RewriterConfig.NCHW_TO_NHWC # pylint: disable=no-member
            config.graph_options.rewrite_options.CopyFrom( # pylint: disable=no-member
                rewriter_config_pb2.RewriterConfig(
                    disable_model_pruning=True,
                    constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
                    dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
                    memory_optimization=rewriter_config_pb2.RewriterConfig.NO_MEM_OPT,
                    arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
                    shape_optimization=rewriter_config_pb2.RewriterConfig.OFF,
                    loop_optimization=rewriter_config_pb2.RewriterConfig.OFF,
                    function_optimization=rewriter_config_pb2.RewriterConfig.OFF,
                    remapping=rewriter_config_pb2.RewriterConfig.OFF,
                    implementation_selector=rewriter_config_pb2.RewriterConfig.OFF,
                    cpu_layout_conversion=convert))

            optimized_graph = tf_optimizer.OptimizeGraph(config, meta_graph)
            return optimized_graph
        else:
            return self.model
Example 23
    def testLoops(self):
        g = ops.Graph()
        with g.as_default():

            def _Cond(_, counter):
                return counter < end

            def _Body(buf, counter):
                buf = array_ops.concat([buf, [counter]], 0)
                counter += 1
                return [buf, counter]

            start = array_ops.placeholder(shape=[], dtype=dtypes.int32)
            end = array_ops.placeholder(shape=[], dtype=dtypes.int32)
            init_buf = array_ops.zeros(shape=[0], dtype=dtypes.int32)
            loop_vars = [init_buf, start]
            shape_inv = [
                tensor_shape.TensorShape([None]),
                tensor_shape.TensorShape([])
            ]
            buf, _ = control_flow_ops.while_loop(_Cond, _Body, loop_vars,
                                                 shape_inv)

            f = -array_ops.ones_like(buf, optimize=False)  # pylint: disable=invalid-unary-operand-type
            buf_shape = array_ops.shape(buf)
            f_shape = array_ops.shape(f)
            ops.add_to_collection('train_op', buf_shape)
            ops.add_to_collection('train_op', f_shape)

        # Optimize the graph.
        mg = meta_graph.create_meta_graph_def(graph=g)
        config = config_pb2.ConfigProto()
        rewriter_config = config.graph_options.rewrite_options
        rewriter_config.min_graph_nodes = -1
        optimized_graph = tf_optimizer.OptimizeGraph(config, mg)
        mg.graph_def.CopyFrom(optimized_graph)

        # Check that the nodes referenced in various collections have been preserved
        item = gitem.Item(mg)
        props = item.GetOpProperties()
        buf_prop = props[buf.op.name]
        f_prop = props[f.op.name]
        self.assertEqual(buf_prop, f_prop)
Example 24
    def testNoSwapping(self):
        """Make sure the graph is preserved when there is nothing to swap."""
        a = variables.Variable(10, name='a')
        b = variables.Variable(20, name='b')
        c = math_ops.add_n([a, b], name='c')
        d = math_ops.add_n([b, c], name='d')
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        train_op.append(d)
        mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
        graph_size = len(mg.graph_def.node)
        nodes = [node.name for node in mg.graph_def.node]

        rewriter_config = rewriter_config_pb2.RewriterConfig(
            disable_model_pruning=True,
            memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
        graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)

        self.assertEqual(len(graph.node), graph_size)
        self.assertItemsEqual([node.name for node in graph.node], nodes)
Example 25
    def _test_convert_variables_with_functions(self, inline_functions):
        """Freezes a graph with functions."""
        @function.Defun(dtypes.float32)
        def plus_one(x):
            return x + 1.0

        with ops.Graph().as_default():
            variable_node = variables.Variable(1.0, name="variable_node")
            _ = variables.Variable(1.0, name="unused_variable_node")
            defun_node = plus_one(variable_node)
            _ = math_ops_lib.multiply(defun_node, 2.0, name="output_node")

            with session.Session() as sess:
                self.evaluate(variables.variables_initializer([variable_node]))
                variable_graph_def = sess.graph.as_graph_def()

                if inline_functions:
                    # Run Grappler to create the VarOpHandle --> Placeholder -->
                    # ResourceVariable pattern.
                    meta_graph = export_meta_graph(
                        graph_def=variable_graph_def)
                    fetch_collection = meta_graph_pb2.CollectionDef()
                    for name in ["variable_node", "output_node"]:
                        fetch_collection.node_list.value.append(name)
                    meta_graph.collection_def["train_op"].CopyFrom(
                        fetch_collection)

                    # Initialize RewriterConfig with everything disabled except function
                    # inlining.
                    config = config_pb2.ConfigProto()
                    rewrite_options = config.graph_options.rewrite_options
                    rewrite_options.optimizers.append("function")
                    variable_graph_def = tf_optimizer.OptimizeGraph(
                        config, meta_graph)

                constant_graph_def = graph_util.convert_variables_to_constants(
                    sess, variable_graph_def, ["output_node"])

        # Ensure there are no variables after freezing.
        for node in constant_graph_def.node:
            self.assertNotIn(
                node.op,
                ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"])
Example 26
def run_graph_optimizations(graph_def,
                            input_arrays,
                            output_arrays,
                            config,
                            graph=None):
    """Apply standard TensorFlow optimizations to the graph_def.

  Args:
    graph_def: Frozen GraphDef to be optimized.
    input_arrays: List of arrays that are considered inputs of the graph.
    output_arrays: List of arrays that are considered outputs of the graph.
    config: tf.ConfigProto.
    graph: TensorFlow Graph. Required when Eager mode is enabled. (default None)

  Returns:
    A new, optimized GraphDef.
  """
    meta_graph = _export_meta_graph(graph_def=graph_def, graph=graph)

    signature = _meta_graph_pb2.SignatureDef()
    for array in input_arrays:
        signature.inputs[array.name].name = array.name
        signature.inputs[array.name].dtype = array.dtype.as_datatype_enum
        signature.inputs[array.name].tensor_shape.CopyFrom(
            array.shape.as_proto())

    for array in output_arrays:
        signature.outputs[array.name].name = array.name
        signature.outputs[array.name].dtype = array.dtype.as_datatype_enum
        signature.outputs[array.name].tensor_shape.CopyFrom(
            array.shape.as_proto())

    meta_graph.signature_def["not_used_key"].CopyFrom(signature)

    # We need to add a collection called 'train_op' so that grappler
    # knows what the outputs are.
    fetch_collection = _meta_graph_pb2.CollectionDef()
    for array in input_arrays + output_arrays:
        fetch_collection.node_list.value.append(array.name)
    meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

    return tf_optimizer.OptimizeGraph(config, meta_graph)
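A hedged usage sketch; the graph, its tensors, and the optimizer list are illustrative:

config = tf.compat.v1.ConfigProto()
config.graph_options.rewrite_options.optimizers.extend(
    ['constfold', 'dependency'])
optimized_graph_def = run_graph_optimizations(
    graph.as_graph_def(), [input_tensor], [output_tensor], config,
    graph=graph)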
Example 27
def tf_optimize_grappler(input_names, output_names, graph_def, fold_constant=None):
    from tensorflow.core.protobuf import config_pb2, meta_graph_pb2
    from tensorflow.python.grappler import tf_optimizer as tf_opt

    config = config_pb2.ConfigProto()
    rewrite_options = config.graph_options.rewrite_options
    config.graph_options.infer_shapes = True
    # TODO: if we turn on pruning, grappler removes some identities that the tf-1.x lstm rewriter
    #   depends on so for now don't turn this on.
    rewrite_options.optimizers[:] = [
        # 'pruning', 'constfold', 'arithmetic', 'dependency', 'function',
        'constfold', 'function'
    ]
    meta_graph = tf.compat.v1.train.export_meta_graph(graph_def=graph_def)
    fetch_collection = meta_graph_pb2.CollectionDef()
    for t in input_names + output_names:
        fetch_collection.node_list.value.append(t)
    meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
    graph_def = tf_opt.OptimizeGraph(config, meta_graph)
    return graph_def
Example 28
    def testBasic(self):
        """Make sure arguments can be passed correctly."""
        a = constant_op.constant(10, name='a')
        b = constant_op.constant(20, name='b')
        c = math_ops.add_n([a, b], name='c')
        d = math_ops.add_n([b, c], name='d')
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        # Being a train_op will make 'd' to be added as a fetch node.
        train_op.append(d)
        mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())

        config = config_pb2.ConfigProto()
        rewriter_config = config.graph_options.rewrite_options
        rewriter_config.optimizers.append('constfold')
        rewriter_config.min_graph_nodes = -1

        graph = tf_optimizer.OptimizeGraph(config, mg)

        self.assertEqual(len(graph.node), 1)
        self.assertItemsEqual([node.name for node in graph.node], ['d'])
Example 29
    def testSimpleSwap(self):
        """Check that the swap annotations are followed."""
        with ops.device('/gpu:0'):
            a = variables.VariableV1(10, name='a')
            b = variables.VariableV1(20, name='b')
            c = math_ops.add_n([a, b], name='c')
            d = math_ops.add_n([b, c], name='d')
            train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
            train_op.append(d)

            d.op._set_attr('_swap_to_host', attr_value_pb2.AttrValue(i=0))

            mg = meta_graph.create_meta_graph_def(
                graph=ops.get_default_graph())
            graph_size = len(mg.graph_def.node)

            config = config_pb2.ConfigProto()
            config.graph_options.rewrite_options.CopyFrom(
                rewriter_config_pb2.RewriterConfig(
                    disable_model_pruning=True,
                    meta_optimizer_iterations=(
                        rewriter_config_pb2.RewriterConfig.ONE),
                    constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
                    memory_optimization=(
                        rewriter_config_pb2.RewriterConfig.MANUAL),
                    min_graph_nodes=-1))
            graph = tf_optimizer.OptimizeGraph(config, mg)

            self.assertEqual(len(graph.node), graph_size + 2)
            self.assertTrue(
                set([node.name for node in graph.node]) > set(
                    ['a', 'b', 'c', 'd', 'swap_in_d_0', 'swap_out_d_0']))
            for node in graph.node:
                if node.name == 'swap_in_d_0':
                    self.assertEqual('swap_out_d_0', node.input[0])
                    self.assertEqual('^b/read', node.input[1])
                elif node.name == 'swap_out_d_0':
                    self.assertEqual('b/read', node.input[0])
                elif node.name == 'd':
                    self.assertEqual('swap_in_d_0', node.input[0])
                    self.assertEqual('c', node.input[1])
Example 30
def grappler_optimize(graph, fetches=None, rewriter_config=None):
    """Tries to optimize the provided graph using grappler.

  Args:
    graph: A @{tf.Graph} instance containing the graph to optimize.
    fetches: An optional list of `Tensor`s to fetch (i.e. not optimize away).
      Grappler uses the 'train_op' collection to look for fetches, so if not
      provided this collection should be non-empty.
    rewriter_config: An optional @{tf.RewriterConfig} to use when rewriting the
      graph.

  Returns:
    A @{tf.GraphDef} containing the rewritten graph.
  """
    if rewriter_config is None:
        rewriter_config = rewriter_config_pb2.RewriterConfig()
    if fetches is not None:
        for fetch in fetches:
            graph.add_to_collection('train_op', fetch)
    metagraph = saver.export_meta_graph(graph_def=graph.as_graph_def())
    return tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
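A minimal usage sketch in TF1-style graph mode; the constant names are illustrative:

import tensorflow.compat.v1 as tf

g = tf.Graph()
with g.as_default():
    out = tf.add(tf.constant(1.0, name='a'), tf.constant(2.0, name='b'))
    # Keep 'out' alive via the 'train_op' collection and run grappler.
    optimized_graph_def = grappler_optimize(g, fetches=[out])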