Example #1
from tensorflow.core.framework import variable_pb2
from tensorflow.core.protobuf import control_flow_pb2
from tensorflow.python.framework import graph_util


# Note: `eprint` is a project-local helper that prints to stderr.
def strip_meta_graph(meta_graph_def, node_names, var_names):
    node_names = node_names[:]
    collections = meta_graph_def.collection_def

    # Look for matching variable names and initializers and keep them too.
    var_def = variable_pb2.VariableDef()
    for var_col_name in ["variables", "trainable_variables"]:
        var_def_bs = collections[var_col_name].bytes_list.value
        for var_def_b in var_def_bs:
            var_def.ParseFromString(var_def_b)
            if var_def.variable_name not in var_names:
                # TODO(adamb) Should remove variable from collection.
                continue
            node_names.append(var_def.initializer_name)

    wc_def = control_flow_pb2.WhileContextDef()
    wc_values = collections["while_context"].bytes_list.value
    for wc_ix in range(len(wc_values) - 1, -1, -1):
        wc_bytes = wc_values[wc_ix]
        wc_def.ParseFromString(wc_bytes)
        unused = True
        wc_pivot_name = wc_def.pivot_name
        for name in node_names:
            if name.startswith(wc_pivot_name):
                unused = False
                break

        if unused:
            del wc_values[wc_ix]

    graph_def = meta_graph_def.graph_def
    eprint("only keeping", node_names, "from",
           [n.name for n in graph_def.node])
    graph_def = graph_util.extract_sub_graph(graph_def, node_names)
    meta_graph_def.graph_def.CopyFrom(graph_def)
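A minimal usage sketch for strip_meta_graph, assuming a MetaGraphDef exported to a "model.meta" file; the file path and the node/variable names below are hypothetical.
from tensorflow.core.protobuf import meta_graph_pb2

meta_graph_def = meta_graph_pb2.MetaGraphDef()
with open("model.meta", "rb") as f:  # hypothetical .meta file (a serialized MetaGraphDef)
    meta_graph_def.ParseFromString(f.read())

strip_meta_graph(meta_graph_def,
                 node_names=["output"],         # ops to keep reachable
                 var_names=["dense/kernel:0"])  # variables whose initializers to keep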
Example #2
    def to_proto(self, export_scope=None):
        """Converts a `ResourceVariable` to a `VariableDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Raises:
      RuntimeError: If run in EAGER mode.

    Returns:
      A `VariableDef` protocol buffer, or `None` if the `Variable` is not
      in the specified name scope.
    """
        if context.in_eager_mode():
            raise RuntimeError("to_proto not supported in EAGER mode.")
        if export_scope is None or self.handle.name.startswith(export_scope):
            var_def = variable_pb2.VariableDef()
            var_def.variable_name = ops.strip_name_scope(
                self.handle.name, export_scope)
            var_def.initializer_name = ops.strip_name_scope(
                self.initializer.name, export_scope)
            if self._cached_value is not None:
                var_def.snapshot_name = ops.strip_name_scope(
                    self._cached_value.name, export_scope)
            var_def.is_resource = True
            if self._save_slice_info:
                var_def.save_slice_info_def.MergeFrom(
                    self._save_slice_info.to_proto(export_scope=export_scope))
            return var_def
        else:
            return None
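A minimal round-trip sketch for the method above, assuming TF 1.x-style graph mode (to_proto raises in eager mode); the variable name is illustrative.
import tensorflow.compat.v1 as tf
from tensorflow.python.ops import resource_variable_ops

with tf.Graph().as_default():
    v = tf.get_variable("v", shape=[2], use_resource=True)
    var_def = v.to_proto()  # serialize to a VariableDef
    # Rebuild an equivalent variable object from the proto.
    v2 = resource_variable_ops.ResourceVariable(variable_def=var_def)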
Example #3
    def to_proto(self, export_scope=None):
        """Converts a `Variable` to a `VariableDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `VariableDef` protocol buffer, or `None` if the `Variable` is not
      in the specified name scope.
    """
        if (export_scope is None
                or self._variable.name.startswith(export_scope)):
            var_def = variable_pb2.VariableDef()
            var_def.variable_name = ops.strip_name_scope(
                self._variable.name, export_scope)
            var_def.initializer_name = ops.strip_name_scope(
                self.initializer.name, export_scope)
            var_def.snapshot_name = ops.strip_name_scope(
                self._snapshot.name, export_scope)
            if self._save_slice_info:
                var_def.save_slice_info_def.MergeFrom(
                    self._save_slice_info.to_proto(export_scope=export_scope))
            return var_def
        else:
            return None
Example #4
  def to_proto(self, export_scope=None):
    full_proto = super(_MissingFieldsVariable, self).to_proto(export_scope)
    # Deliberately omit snapshot_name and trainable so the proto looks like
    # one written by an older version of the code.
    return variable_pb2.VariableDef(
        variable_name=full_proto.variable_name,
        initial_value_name=full_proto.initial_value_name,
        initializer_name=full_proto.initializer_name,
        save_slice_info_def=full_proto.save_slice_info_def,
        is_resource=full_proto.is_resource)
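Because VariableDef is a protobuf message, fields omitted from the keyword constructor simply keep their defaults; that is how the test variable above "loses" snapshot_name and trainable on a round trip. A quick illustration:
from tensorflow.core.framework import variable_pb2

partial = variable_pb2.VariableDef(variable_name="v:0")
assert partial.snapshot_name == ""  # unset string field defaults to ""
assert partial.trainable is False   # unset bool field defaults to False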
Example #5
def _run_inline_graph_optimization(func, lower_control_flow):
    """Apply function inline optimization to the graph.

  Returns the GraphDef after Grappler's function inlining optimization is
  applied. This optimization does not work on models with control flow.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control flow
      ops such as If and While. (default True)

  Returns:
    GraphDef
  """
    graph_def = func.graph.as_graph_def()
    if not lower_control_flow:
        graph_def = disable_lower_using_switch_merge(graph_def)

    # In some cases, a secondary implementation of the function (e.g. for GPU) is
    # written to the "api_implements" attribute. (e.g. `tf.keras.layers.LSTM` in
    # TF2 produces a CuDNN-based RNN for GPU).
    # This function is supposed to inline all function calls, but
    # "api_implements" prevents that from happening. Removing the attribute
    # solves the problem.
    # To learn more about "api_implements", see:
    #   tensorflow/core/grappler/optimizers/implementation_selector.h
    for function in graph_def.library.function:
        if "api_implements" in function.attr:
            del function.attr["api_implements"]

    meta_graph = export_meta_graph(graph_def=graph_def, graph=func.graph)

    # Clear initializer_name in the variable collections, since it is not
    # needed once the graph has been saved to a SavedModel.
    for name in [
            "variables", "model_variables", "trainable_variables",
            "local_variables"
    ]:
        raw_list = []
        for raw in meta_graph.collection_def[name].bytes_list.value:
            variable = variable_pb2.VariableDef()
            variable.ParseFromString(raw)
            variable.ClearField("initializer_name")
            raw_list.append(variable.SerializeToString())
        meta_graph.collection_def[name].bytes_list.value[:] = raw_list

    # Add a collection 'train_op' so that Grappler knows the outputs.
    fetch_collection = meta_graph_pb2.CollectionDef()
    for array in func.inputs + func.outputs:
        fetch_collection.node_list.value.append(array.name)
    meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

    # Initialize RewriterConfig with everything disabled except function inlining.
    config = config_pb2.ConfigProto()
    rewrite_options = config.graph_options.rewrite_options
    rewrite_options.min_graph_nodes = -1  # do not skip small graphs
    rewrite_options.optimizers.append("function")
    return tf_optimizer.OptimizeGraph(config, meta_graph)
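A usage sketch, assuming TF 2.x where this helper lives (module-private) in tensorflow/python/framework/convert_to_constants.py; the model is a toy stand-in.
import tensorflow as tf

@tf.function
def model(x):
    return tf.nn.relu(x) * 2.0

concrete = model.get_concrete_function(tf.TensorSpec([None, 4], tf.float32))
inlined = _run_inline_graph_optimization(concrete, lower_control_flow=True)
print(len(inlined.node))  # GraphDef with function calls inlined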
Example #6
File: origin.py Project: wangye707/PY_DELL
def update_snapshot_name(self, var_coll_name):
    var_list = self._metagraph.collection_def[var_coll_name]
    for i, value in enumerate(var_list.bytes_list.value):
        var_def = variable_pb2.VariableDef()
        var_def.ParseFromString(value)
        if var_def.snapshot_name != "Model/global_step/read:0":
            var_def.snapshot_name = with_autoparallel_prefix(
                0, var_def.snapshot_name)
        value = var_def.SerializeToString()
        var_list.bytes_list.value[i] = value
Example #7
def _duplicate_layer(layer_name,
                     layer_sgv,
                     branch_name,
                     add_to_collections=True):
    """Duplicates a network layer, while preserving connections.

    Args:
      layer_name:         a layer is identified by its name scope
      layer_sgv:          SubgraphView (see tf.contrib.graph_editor)
      branch_name:        the duplicate is "layer_name + branch_name"
      add_to_collections: add duplicate vars to the same collections

    Returns:
      info:            see ret vals of `tf.contrib.graph_editor.copy`
      var_duplication: a list of tuples (var, dup_of_var)
    """

    if layer_name[-1] == '/':
        new_layer_name = layer_name[:-1] + branch_name + '/'
    else:
        new_layer_name = layer_name + branch_name

    # Map each external input tensor to itself so incoming connections are
    # preserved when the subgraph is copied.
    replacement_ts = {}
    for ts in layer_sgv.inputs:
        replacement_ts[ts] = ts

    duplicate_sgv, info = ge.copy_with_input_replacements(
        layer_sgv,
        replacement_ts=replacement_ts,
        src_scope=layer_name,
        dst_scope=new_layer_name)

    var_duplication = []
    for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
        if layer_name not in v.name:
            continue
        vproto = v.to_proto()
        new_vardef = variable_pb2.VariableDef()
        for field, val in vproto.ListFields():
            if isinstance(val, str):
                new_val = val.replace(layer_name, new_layer_name)
            else:
                new_val = val
            setattr(new_vardef, field.name, new_val)
        new_var = tf.Variable(variable_def=new_vardef)
        tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, new_var)
        var_duplication.append((v, new_var))

        if add_to_collections:
            for k in tf.get_default_graph().get_all_collection_keys():
                collection = tf.get_collection(k)
                if v in collection and new_var not in collection:
                    tf.add_to_collection(k, new_var)

    return info, var_duplication
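A hypothetical usage sketch: duplicate everything under the "fc1/" name scope into "fc1_branch/", assuming TF 1.x with tf.contrib.graph_editor available.
import tensorflow as tf
from tensorflow.contrib import graph_editor as ge

layer_ops = ge.get_name_scope_ops(
    tf.get_default_graph().get_operations(), "fc1")
layer_sgv = ge.make_view(layer_ops)
info, var_dups = _duplicate_layer("fc1/", layer_sgv, "_branch")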
Example #8
def update_snapshot_name(self, var_coll_name):
    var_list = self._metagraph.collection_def[var_coll_name]
    for i, value in enumerate(var_list.bytes_list.value):
        var_def = variable_pb2.VariableDef()
        var_def.ParseFromString(value)
        # Somehow the node Model/global_step/read doesn't have any fanout and
        # seems to be used only for the snapshot; this differs from all other
        # variables.
        if var_def.snapshot_name != "Model/global_step/read:0":
            var_def.snapshot_name = with_autoparallel_prefix(0, var_def.snapshot_name)
        value = var_def.SerializeToString()
        var_list.bytes_list.value[i] = value
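with_autoparallel_prefix is project-local and not shown here; below is a plausible sketch based on how Grappler's AutoParallel optimizer names replicated nodes (an assumption, not the verified original).
def with_autoparallel_prefix(replica_id, tensor_name):
    # AutoParallel rewrites node names as "AutoParallel-Replica-<i>/<name>".
    return "AutoParallel-Replica-%d/%s" % (replica_id, tensor_name)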
Example #9
  def to_proto(self):
    """Converts a `Variable` to a `VariableDef` protocol buffer.

    Returns:
      A `VariableDef` protocol buffer.
    """
    var_def = variable_pb2.VariableDef()
    var_def.variable_name = self._variable.name
    var_def.initializer_name = self.initializer.name
    var_def.snapshot_name = self._snapshot.name
    if self._save_slice_info:
      var_def.save_slice_info_def.MergeFrom(self._save_slice_info.to_proto())
    return var_def
Example #10
    def to_proto(self):
        """
    Inverse of `from_proto()` method.

    Returns a `VariableDef` protocol buffer message that represents this
    variable.
    """
        ret = variable_pb2.VariableDef()
        ret.variable_name = self._variable_name
        ret.initial_value_name = self._initial_value_name
        ret.initializer_name = self._initializer_name
        ret.snapshot_name = self._snapshot_name
        ret.trainable = self._trainable
        return ret
Example #11
    def to_proto(self) -> variable_pb2.VariableDef:
        """
    Convert this object into its equivalent TensorFlow protocol buffer
    message.

    Returns a `VariableDef` protobuf equivalent to this object.
    """
        ret = variable_pb2.VariableDef()
        ret.variable_name = self.name
        ret.initial_value_name = self.initial_value_name
        ret.initializer_name = self.initializer_name
        ret.snapshot_name = self.snapshot_name
        ret.trainable = self.trainable
        # TODO(frreiss): Figure out what to do with the is_resource field
        # TODO(frreiss): Figure out what to do with the save_slice_info_def field
        return ret
Example #12
def _run_inline_graph_optimization(func, lower_control_flow):
    """Apply function inlining optimization to the graph.

    Returns the GraphDef after Grappler's function inlining optimization is
    applied. This optimization does not work on models with control flow.

    Args:
      func: ConcreteFunction.
      lower_control_flow: Boolean indicating whether or not to lower control
        flow ops such as If and While. (default True)

    Returns:
      GraphDef
    """
    graph_def = func.graph.as_graph_def()
    if not lower_control_flow:
        graph_def = disable_lower_using_switch_merge(graph_def)
    meta_graph = export_meta_graph(graph_def=graph_def, graph=func.graph)

    # Clear initializer_name in the variable collections, since it is not
    # needed once the graph has been saved to a SavedModel.
    for name in [
            "variables", "model_variables", "trainable_variables",
            "local_variables"
    ]:
        raw_list = []
        for raw in meta_graph.collection_def[name].bytes_list.value:
            variable = variable_pb2.VariableDef()
            variable.ParseFromString(raw)
            variable.ClearField("initializer_name")
            raw_list.append(variable.SerializeToString())
        meta_graph.collection_def[name].bytes_list.value[:] = raw_list

    # Add a collection 'train_op' so that Grappler knows the outputs.
    fetch_collection = meta_graph_pb2.CollectionDef()
    for array in func.inputs + func.outputs:
        fetch_collection.node_list.value.append(array.name)
    meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

    # Initialize RewriterConfig with everything disabled except function inlining.
    config = config_pb2.ConfigProto()
    rewrite_options = config.graph_options.rewrite_options
    rewrite_options.min_graph_nodes = -1  # do not skip small graphs
    rewrite_options.optimizers.append("function")
    return tf_optimizer.OptimizeGraph(config, meta_graph)
Example #13
  def to_proto(self, export_scope=None):
    """Converts a `ResourceVariable` to a `VariableDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Raises:
      RuntimeError: If run in EAGER mode.

    Returns:
      A `VariableDef` protocol buffer, or `None` if the `Variable` is not
      in the specified name scope.
    """
    if context.executing_eagerly():
      raise RuntimeError("to_proto not supported in EAGER mode.")
    if export_scope is None or self.handle.name.startswith(export_scope):
      var_def = variable_pb2.VariableDef()
      var_def.variable_name = ops.strip_name_scope(self.handle.name,
                                                   export_scope)
      if self._initial_value is not None:
        # This is inside an if-statement for backwards compatibility, since
        # self._initial_value might be None for variables constructed from old
        # protos.
        var_def.initial_value_name = ops.strip_name_scope(
            self._initial_value.name, export_scope)
      var_def.initializer_name = ops.strip_name_scope(self.initializer.name,
                                                      export_scope)
      if self._cached_value is not None:
        var_def.snapshot_name = ops.strip_name_scope(self._cached_value.name,
                                                     export_scope)
      else:
        # Store the graph_element here
        var_def.snapshot_name = ops.strip_name_scope(self._graph_element.name,
                                                     export_scope)
      var_def.is_resource = True
      var_def.trainable = self.trainable
      if self._save_slice_info:
        var_def.save_slice_info_def.MergeFrom(
            self._save_slice_info.to_proto(export_scope=export_scope))
      return var_def
    else:
      return None
Example #14
  def to_proto(self, export_scope=None):
    if (export_scope is None or
        self._variable.name.startswith(export_scope)):
      var_def = variable_pb2.VariableDef()
      var_def.variable_name = ops.strip_name_scope(
          self._variable.name, export_scope)
      if self._initial_value is not None:
        # For backwards compatibility.
        var_def.initial_value_name = ops.strip_name_scope(
            self._initial_value.name, export_scope)
      var_def.initializer_name = ops.strip_name_scope(
          self.initializer.name, export_scope)
      var_def.snapshot_name = ops.strip_name_scope(
          self._snapshot.name, export_scope)
      if self._save_slice_info:
        var_def.save_slice_info_def.MergeFrom(self._save_slice_info.to_proto(
            export_scope=export_scope))
      return var_def
    else:
      return None
Example #15
def _get_grads(single_gpu_meta_graph_def):
    trainable_vars = []
    trainable_vars_defs = single_gpu_meta_graph_def.collection_def[
        tf.GraphKeys.TRAINABLE_VARIABLES]
    for var_def_string in trainable_vars_defs.bytes_list.value:
        var_def = variable_pb2.VariableDef()
        var_def.ParseFromString(var_def_string)
        trainable_vars.append(var_def.variable_name)
    sparse_grads = []
    dense_grads = []
    grad_info_defs = single_gpu_meta_graph_def.collection_def[
        tf.GraphKeys.GRADIENTS_INFO]
    for grad_info_def_string in grad_info_defs.bytes_list.value:
        gradients_info_def = gradients_info_pb2.GradientsInfoDef()
        gradients_info_def.ParseFromString(grad_info_def_string)
        if (gradients_info_def.target_tensor_info.values_tensor_name
                not in trainable_vars):
            continue
        if (gradients_info_def.grad_tensor_info.tensor_type ==
                gradients_info_pb2.GradientsInfoDef.TensorInfoDef.INDEXED_SLICES):
            sparse_grads.append(gradients_info_def)
        else:
            dense_grads.append(gradients_info_def)
    assert len(sparse_grads) > 0 or len(dense_grads) > 0
    return sparse_grads, dense_grads
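For context, a minimal sketch of how the trainable-variables collection parsed above is populated: each entry is a serialized VariableDef. This assumes TF 1.x; GRADIENTS_INFO and gradients_info_pb2 are specific to this project and are not sketched here.
import tensorflow.compat.v1 as tf
from tensorflow.core.framework import variable_pb2

with tf.Graph().as_default():
    v = tf.get_variable("w", shape=[3])
    meta_graph_def = tf.train.export_meta_graph()
    raw_values = meta_graph_def.collection_def[
        tf.GraphKeys.TRAINABLE_VARIABLES].bytes_list.value
    for raw in raw_values:
        var_def = variable_pb2.VariableDef()
        var_def.ParseFromString(raw)
        print(var_def.variable_name)  # e.g. "w:0"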