Esempio n. 1
0
  def _init_from_proto(self, queue_runner_def, import_scope=None):
    """Create a QueueRunner from `QueueRunnerDef`.

    Every op/tensor name stored in the proto is prefixed with `import_scope`
    and then resolved in the current default graph.

    Args:
      queue_runner_def: `QueueRunnerDef` protocol buffer (required; its type
        is checked with an assertion).
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(queue_runner_def, queue_runner_pb2.QueueRunnerDef)
    graph = ops.get_default_graph()

    def lookup(name):
      # Resolve a stored name after re-scoping it under `import_scope`.
      return graph.as_graph_element(ops.prepend_name_scope(name, import_scope))

    self._queue = lookup(queue_runner_def.queue_name)
    self._enqueue_ops = [
        lookup(op_name) for op_name in queue_runner_def.enqueue_op_name]
    self._close_op = lookup(queue_runner_def.close_op_name)
    self._cancel_op = lookup(queue_runner_def.cancel_op_name)
    # Protos written before `queue_closed_exception_types` existed leave it
    # empty; fall back to treating only OutOfRangeError as "queue closed".
    self._queue_closed_exception_types = tuple(
        errors.exception_type_from_error_code(code)
        for code in queue_runner_def.queue_closed_exception_types) or (
            errors.OutOfRangeError,)
Esempio n. 2
0
    def _init_from_proto(self, queue_runner_def, import_scope=None):
        """Create a QueueRunner from `QueueRunnerDef`.

        Every op/tensor name stored in the proto is prefixed with
        `import_scope` and then resolved in the current default graph.

        Args:
          queue_runner_def: `QueueRunnerDef` protocol buffer (required; its
            type is checked with an assertion).
          import_scope: Optional `string`. Name scope to add.
        """
        assert isinstance(queue_runner_def, queue_runner_pb2.QueueRunnerDef)
        graph = ops.get_default_graph()

        def lookup(name):
            # Resolve a stored name after re-scoping it under `import_scope`.
            return graph.as_graph_element(
                ops.prepend_name_scope(name, import_scope))

        self._queue = lookup(queue_runner_def.queue_name)
        self._enqueue_ops = [
            lookup(op_name) for op_name in queue_runner_def.enqueue_op_name
        ]
        self._close_op = lookup(queue_runner_def.close_op_name)
        self._cancel_op = lookup(queue_runner_def.cancel_op_name)
        # Protos written before `queue_closed_exception_types` existed leave
        # it empty; fall back to treating only OutOfRangeError as "closed".
        self._queue_closed_exception_types = tuple(
            errors.exception_type_from_error_code(code)
            for code in queue_runner_def.queue_closed_exception_types) or (
                errors.OutOfRangeError, )
Esempio n. 3
0
  def _init_from_proto(self, variable_def, import_scope=None):
    """Creates a new variable from `VariableDef` protocol buffer.

    The variable, initializer and snapshot names stored in the proto are
    prefixed with `import_scope` and resolved in the default graph.

    Args:
      variable_def: `VariableDef` protocol buffer.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(variable_def, variable_pb2.VariableDef)
    graph = ops.get_default_graph()

    def scoped(name):
      # Look up `name`, re-scoped under `import_scope`, in the default graph.
      return graph.as_graph_element(
          ops.prepend_name_scope(name, import_scope=import_scope))

    self._variable = scoped(variable_def.variable_name)
    self._initializer_op = scoped(variable_def.initializer_name)
    self._snapshot = scoped(variable_def.snapshot_name)
    # NOTE(review): `import_scope` is not forwarded to SaveSliceInfo here, so
    # the slice's `full_name` is not re-scoped; confirm whether this
    # version's SaveSliceInfo accepts `import_scope`.
    self._save_slice_info = (
        Variable.SaveSliceInfo(
            save_slice_info_def=variable_def.save_slice_info_def)
        if variable_def.HasField("save_slice_info_def") else None)
    self._caching_device = None
  def _init_from_proto(self, variable_def, import_scope=None):
    """Initializes from `VariableDef` proto.

    Args:
      variable_def: `VariableDef` protocol buffer; `is_resource` must be set.
      import_scope: Optional `string`. Name scope prepended to every stored
        name before it is resolved in the default graph.

    Raises:
      ValueError: If `variable_def` does not describe a resource variable.
    """
    assert isinstance(variable_def, variable_pb2.VariableDef)
    if not variable_def.is_resource:
      raise ValueError("Trying to restore Variable as ResourceVariable.")

    graph = ops.get_default_graph()

    def scoped(name):
      # Look up `name`, re-scoped under `import_scope`, in the default graph.
      return graph.as_graph_element(
          ops.prepend_name_scope(name, import_scope=import_scope))

    self._handle = scoped(variable_def.variable_name)
    self._initialize_op = scoped(variable_def.initializer_name)
    # An empty snapshot name means no cached value was recorded.
    snapshot_name = variable_def.snapshot_name
    self._cached_value = scoped(snapshot_name) if snapshot_name else None
    # NOTE(review): `import_scope` is not forwarded to SaveSliceInfo here;
    # confirm whether this version's SaveSliceInfo accepts it.
    if variable_def.HasField("save_slice_info_def"):
      self._save_slice_info = variables.Variable.SaveSliceInfo(
          save_slice_info_def=variable_def.save_slice_info_def)
    else:
      self._save_slice_info = None
    self._caching_device = None
    self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
Esempio n. 5
0
    def _init_from_proto(self, variable_def, import_scope=None):
        """Creates a new variable from `VariableDef` protocol buffer.

        The variable, initializer and snapshot names stored in the proto are
        prefixed with `import_scope` and resolved in the default graph.

        Args:
          variable_def: `VariableDef` protocol buffer.
          import_scope: Optional `string`. Name scope to add.
        """
        assert isinstance(variable_def, variable_pb2.VariableDef)
        graph = ops.get_default_graph()

        def scoped(name):
            # Look up `name`, re-scoped under `import_scope`, in the graph.
            return graph.as_graph_element(
                ops.prepend_name_scope(name, import_scope=import_scope))

        self._variable = scoped(variable_def.variable_name)
        self._initializer_op = scoped(variable_def.initializer_name)
        self._snapshot = scoped(variable_def.snapshot_name)
        # NOTE(review): `import_scope` is not forwarded to SaveSliceInfo
        # here; confirm whether this version's SaveSliceInfo accepts it.
        self._save_slice_info = (
            Variable.SaveSliceInfo(
                save_slice_info_def=variable_def.save_slice_info_def)
            if variable_def.HasField("save_slice_info_def") else None)
        self._caching_device = None
Esempio n. 6
0
    def _init_from_proto(self, variable_def, import_scope=None):
        """Initializes from `VariableDef` proto (graph mode only).

        Args:
          variable_def: `VariableDef` protocol buffer; `is_resource` must be
            set.
          import_scope: Optional `string`. Name scope prepended to every
            stored name before it is resolved in the default graph.

        Raises:
          ValueError: If `variable_def` does not describe a resource variable.
        """
        assert context.in_graph_mode()
        assert isinstance(variable_def, variable_pb2.VariableDef)
        if not variable_def.is_resource:
            raise ValueError("Trying to restore Variable as ResourceVariable.")

        graph = ops.get_default_graph()

        def scoped(name):
            # Look up `name`, re-scoped under `import_scope`, in the graph.
            return graph.as_graph_element(
                ops.prepend_name_scope(name, import_scope=import_scope))

        self._handle = scoped(variable_def.variable_name)
        self._handle_name = self._handle.name
        self._initializer_op = scoped(variable_def.initializer_name)
        # An empty snapshot name means no cached value was recorded.
        snapshot_name = variable_def.snapshot_name
        self._cached_value = scoped(snapshot_name) if snapshot_name else None
        # NOTE(review): `import_scope` is not forwarded to SaveSliceInfo
        # here; confirm whether this version's SaveSliceInfo accepts it.
        if variable_def.HasField("save_slice_info_def"):
            self._save_slice_info = variables.Variable.SaveSliceInfo(
                save_slice_info_def=variable_def.save_slice_info_def)
        else:
            self._save_slice_info = None
        self._caching_device = None
        self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
        # `value()` is called last, once the handle and dtype are populated.
        self._graph_element = self.value()
        self._constraint = None
Esempio n. 7
0
 def _init_from_proto(self, variable_def, import_scope=None):
   """Creates a variable from a `VariableDef` protocol buffer.

   Args:
     variable_def: `VariableDef` protocol buffer.
     import_scope: Optional `string`. Name scope prepended to every stored
       name before it is resolved in the default graph.
   """
   assert isinstance(variable_def, variable_pb2.VariableDef)
   graph = ops.get_default_graph()

   def scoped(name):
     # Look up `name`, re-scoped under `import_scope`, in the default graph.
     return graph.as_graph_element(
         ops.prepend_name_scope(name, import_scope=import_scope))

   self._variable = scoped(variable_def.variable_name)
   self._initializer_op = scoped(variable_def.initializer_name)
   # Older protos may not define `initial_value_name`; a missing or empty
   # value means no initial value was recorded (backwards compatibility).
   initial_value_name = getattr(variable_def, "initial_value_name", None)
   self._initial_value = (
       scoped(initial_value_name) if initial_value_name else None)
   self._snapshot = scoped(variable_def.snapshot_name)
   if variable_def.HasField("save_slice_info_def"):
     self._save_slice_info = Variable.SaveSliceInfo(
         save_slice_info_def=variable_def.save_slice_info_def,
         import_scope=import_scope)
   else:
     self._save_slice_info = None
   self._caching_device = None
   self._constraint = None
  def _init_from_proto(self, variable_def, import_scope=None):
    """Rebuilds this resource variable from a `VariableDef` proto.

    Only graph mode is supported. Every name stored in the proto is prefixed
    with `import_scope` and resolved against the default graph.

    Args:
      variable_def: `VariableDef` protocol buffer; `is_resource` must be set.
      import_scope: Optional `string`. Name scope to prepend to stored names.

    Raises:
      ValueError: If `variable_def` does not describe a resource variable.
    """
    # Note that init_from_proto is currently not supported in Eager mode.
    assert not context.executing_eagerly()
    self._in_graph_mode = True
    assert isinstance(variable_def, variable_pb2.VariableDef)
    if not variable_def.is_resource:
      raise ValueError("Trying to restore Variable as ResourceVariable.")

    graph = ops.get_default_graph()

    def scoped(name):
      # Look up `name`, re-scoped under `import_scope`, in the default graph.
      return graph.as_graph_element(
          ops.prepend_name_scope(name, import_scope=import_scope))

    self._handle = scoped(variable_def.variable_name)
    self._shape = tensor_shape.TensorShape(self._handle.op.get_attr("shape"))
    self._handle_name = self._handle.name
    self._unique_id = self._handle_name
    self._initializer_op = scoped(variable_def.initializer_name)

    # Older protos may not define `initial_value_name`; a missing or empty
    # value means no initial value was recorded (backwards compatibility).
    initial_value_name = getattr(variable_def, "initial_value_name", None)
    self._initial_value = (
        scoped(initial_value_name) if initial_value_name else None)
    self._trainable = getattr(variable_def, "trainable", True)

    if not variable_def.snapshot_name:
      self._cached_value = None
      # Legacy protos carry no snapshot name; assume the read tensor uses
      # the conventional name below.
      self._graph_element = graph.get_tensor_by_name(
          self._handle.op.name + "/Read/ReadVariableOp:0")
    else:
      self._cached_value = scoped(variable_def.snapshot_name)
      # Follow first inputs back from the snapshot until the output of a
      # `ReadVariableOp` is reached; that tensor is the graph element.
      read_tensor = self._cached_value
      while read_tensor.op.type != "ReadVariableOp":
        read_tensor = read_tensor.op.inputs[0]
      self._graph_element = read_tensor

    if variable_def.HasField("save_slice_info_def"):
      self._save_slice_info = variables.Variable.SaveSliceInfo(
          save_slice_info_def=variable_def.save_slice_info_def,
          import_scope=import_scope)
    else:
      self._save_slice_info = None
    self._caching_device = None
    self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
    self._constraint = None
    self._cached_shape_as_list = None
Esempio n. 9
0
  def _init_from_proto(self, variable_def, import_scope=None):
    """Rebuilds this resource variable from a `VariableDef` proto.

    Only graph mode is supported. Every name stored in the proto is prefixed
    with `import_scope` and resolved against the default graph.

    Args:
      variable_def: `VariableDef` protocol buffer; `is_resource` must be set.
      import_scope: Optional `string`. Name scope to prepend to stored names.

    Raises:
      ValueError: If `variable_def` does not describe a resource variable.
    """
    # Note that init_from_proto is currently not supported in Eager mode.
    assert not context.executing_eagerly()
    self._in_graph_mode = True
    assert isinstance(variable_def, variable_pb2.VariableDef)
    if not variable_def.is_resource:
      raise ValueError("Trying to restore Variable as ResourceVariable.")

    graph = ops.get_default_graph()

    def scoped(name):
      # Look up `name`, re-scoped under `import_scope`, in the default graph.
      return graph.as_graph_element(
          ops.prepend_name_scope(name, import_scope=import_scope))

    self._handle = scoped(variable_def.variable_name)
    self._shape = tensor_shape.TensorShape(self._handle.op.get_attr("shape"))
    self._handle_name = self._handle.name
    self._unique_id = self._handle_name
    self._initializer_op = scoped(variable_def.initializer_name)

    # Older protos may not define `initial_value_name`; a missing or empty
    # value means no initial value was recorded (backwards compatibility).
    initial_value_name = getattr(variable_def, "initial_value_name", None)
    self._initial_value = (
        scoped(initial_value_name) if initial_value_name else None)
    self._trainable = getattr(variable_def, "trainable", True)

    if not variable_def.snapshot_name:
      self._cached_value = None
      # Legacy protos carry no snapshot name; assume the read tensor uses
      # the conventional name below.
      self._graph_element = graph.get_tensor_by_name(
          self._handle.op.name + "/Read/ReadVariableOp:0")
    else:
      self._cached_value = scoped(variable_def.snapshot_name)
      # Follow first inputs back from the snapshot until the output of a
      # `ReadVariableOp` is reached; that tensor is the graph element.
      read_tensor = self._cached_value
      while read_tensor.op.type != "ReadVariableOp":
        read_tensor = read_tensor.op.inputs[0]
      self._graph_element = read_tensor

    if variable_def.HasField("save_slice_info_def"):
      self._save_slice_info = variables.Variable.SaveSliceInfo(
          save_slice_info_def=variable_def.save_slice_info_def,
          import_scope=import_scope)
    else:
      self._save_slice_info = None
    self._caching_device = None
    self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
    self._constraint = None
    self._cached_shape_as_list = None
Esempio n. 10
0
    def _build(self, checkpoint_path, build_save, build_restore):
        """Builds `saver_def`.

        In graph mode this runs at most once per saver (guarded by
        `self._is_built`). When executing eagerly, the guard is skipped and a
        fresh `saver_def` is built on every call.

        Args:
            checkpoint_path: Forwarded to the builder as `filename`.
            build_save: Forwarded to the builder's `build_save` argument.
            build_restore: Forwarded to the builder's `build_restore`
                argument.

        Raises:
            ValueError: If there are no variables to save and empty savers are
                not allowed.
        """
        if not context.executing_eagerly():
            if self._is_built:
                return
            self._is_built = True

        if not self.saver_def or context.executing_eagerly():
            if self._builder is None:
                # Lazily create the SecureBulkSaverBuilder used by this saver.
                self._builder = SecureBulkSaverBuilder(self._write_version)

            if self._var_list is None:
                # pylint: disable=protected-access
                self._var_list = variables._all_saveable_objects()
            if not self._var_list:
                if not self._allow_empty:
                    raise ValueError("No variables to save")
                self._is_empty = True
                return
            self._is_empty = False

            build_internal = self._builder._build_internal  # pylint: disable=protected-access
            every_n_hours = self._keep_checkpoint_every_n_hours
            self.saver_def = build_internal(
                self._var_list,
                reshape=self._reshape,
                sharded=self._sharded,
                max_to_keep=self._max_to_keep,
                keep_checkpoint_every_n_hours=every_n_hours,
                name=self._name,
                restore_sequentially=self._restore_sequentially,
                filename=checkpoint_path,
                build_save=build_save,
                build_restore=build_restore)
        elif self.saver_def and self._name:
            # `self._name` is the name scope the builder uses, so it doubles as
            # the "import_scope" for an existing saver_def: re-scope its names.
            for field_name in ("filename_tensor_name", "save_tensor_name",
                               "restore_op_name"):
                current = getattr(self.saver_def, field_name)
                setattr(self.saver_def, field_name,
                        ops.prepend_name_scope(current, self._name))

        self._check_saver_def()
        if not context.executing_eagerly():
            # Graph mode sets the next checkpoint time here; when executing
            # eagerly it is set in __init__ instead.
            hours = self.saver_def.keep_checkpoint_every_n_hours
            self._next_checkpoint_time = time.time() + hours * 3600
Esempio n. 11
0
    def __init__(self,
                 full_name=None,
                 full_shape=None,
                 var_offset=None,
                 var_shape=None,
                 save_slice_info_def=None,
                 import_scope=None):
      """Create a `SaveSliceInfo`.

      Args:
        full_name: Name of the full variable of which this `Variable` is a
            slice.
        full_shape: Shape of the full variable, as a list of int.
        var_offset: Offset of this `Variable` into the full variable, as a
            list of int.
        var_shape: Shape of this `Variable`, as a list of int.
        save_slice_info_def: `SaveSliceInfoDef` protocol buffer. If given,
          the SaveSliceInfo object is recreated from its contents and the
          other slice arguments are ignored (they are mutually exclusive).
        import_scope: Optional `string`. Name scope to add. Only used
          when initializing from protocol buffer.
      """
      if not save_slice_info_def:
        # Plain construction from the explicit arguments.
        self.full_name = full_name
        self.full_shape = full_shape
        self.var_offset = var_offset
        self.var_shape = var_shape
        return

      assert isinstance(save_slice_info_def, variable_pb2.SaveSliceInfoDef)
      self.full_name = ops.prepend_name_scope(
          save_slice_info_def.full_name, import_scope=import_scope)
      # Copy the repeated proto fields into plain Python lists.
      self.full_shape = list(save_slice_info_def.full_shape)
      self.var_offset = list(save_slice_info_def.var_offset)
      self.var_shape = list(save_slice_info_def.var_shape)
Esempio n. 12
0
    def __init__(self,
                 full_name=None,
                 full_shape=None,
                 var_offset=None,
                 var_shape=None,
                 save_slice_info_def=None,
                 import_scope=None):
      """Create a `SaveSliceInfo`.

      Args:
        full_name: Name of the full variable of which this `Variable` is a
            slice.
        full_shape: Shape of the full variable, as a list of int.
        var_offset: Offset of this `Variable` into the full variable, as a
            list of int.
        var_shape: Shape of this `Variable`, as a list of int.
        save_slice_info_def: `SaveSliceInfoDef` protocol buffer. If given,
          the SaveSliceInfo object is recreated from its contents and the
          other slice arguments are ignored (they are mutually exclusive).
        import_scope: Optional `string`. Name scope to add. Only used
          when initializing from protocol buffer.
      """
      if not save_slice_info_def:
        # Plain construction from the explicit arguments.
        self.full_name = full_name
        self.full_shape = full_shape
        self.var_offset = var_offset
        self.var_shape = var_shape
        return

      assert isinstance(save_slice_info_def, variable_pb2.SaveSliceInfoDef)
      self.full_name = ops.prepend_name_scope(
          save_slice_info_def.full_name, import_scope=import_scope)
      # Copy the repeated proto fields into plain Python lists.
      self.full_shape = list(save_slice_info_def.full_shape)
      self.var_offset = list(save_slice_info_def.var_offset)
      self.var_shape = list(save_slice_info_def.var_shape)
Esempio n. 13
0
 def _restore_collections(dest_graph, src_meta_graph_def,
                          collection_keys):
     """Restores collections that we need to keep.

     Copies each collection named in `collection_keys` from
     `src_meta_graph_def` into `dest_graph`. Entries that can no longer be
     resolved in `dest_graph` (e.g. the graph was optimized and the node was
     removed) are skipped instead of failing the whole restore.

     Args:
       dest_graph: Graph whose collections are populated.
       src_meta_graph_def: Meta graph proto whose `collection_def` map holds
         the source collections.
       collection_keys: Iterable of collection keys to restore.
     """
     scope = ""
     for key in collection_keys:
         collection_def = src_meta_graph_def.collection_def[key]
         kind = collection_def.WhichOneof("kind")
         if kind is None:
             tf_logging.error(
                 "Cannot identify data type for collection %s. Skipping.",
                 key)
             continue
         from_proto = ops.get_from_proto_function(key)
         if from_proto and kind == "bytes_list":
             proto_type = ops.get_collection_proto_type(key)
             # It is assumed that there are no Variables Keys in collections
             for value in collection_def.bytes_list.value:
                 proto = proto_type()
                 proto.ParseFromString(value)
                 try:
                     new_value = from_proto(proto, import_scope=scope)
                 except Exception as e:  # pylint: disable=broad-except
                     # Best effort: the referenced elements may have been
                     # optimized away. Unlike a bare `except:`, this still
                     # lets KeyboardInterrupt/SystemExit propagate.
                     tf_logging.debug(
                         "Skipping an entry of collection %s: %s", key, e)
                     continue
                 dest_graph.add_to_collection(key, new_value)
         else:
             field = getattr(collection_def, kind)
             if kind == "node_list":
                 for value in field.value:
                     name = ops.prepend_name_scope(value, scope)
                     # Since the graph has been optimized, the node may no
                     # longer exist.
                     try:
                         col_op = dest_graph.as_graph_element(name)
                     except (TypeError, ValueError, KeyError):
                         continue
                     dest_graph.add_to_collection(key, col_op)
             elif kind == "int64_list":
                 # NOTE(opensource): This force conversion is to work around the
                 # fact that Python2 distinguishes between int and long, while
                 # Python3 has only int.
                 for value in field.value:
                     dest_graph.add_to_collection(key, int(value))
             else:
                 for value in field.value:
                     dest_graph.add_to_collection(
                         key, ops.prepend_name_scope(value, scope))
Esempio n. 14
0
  def _init_from_proto(self, variable_def, import_scope=None):
    """Initializes from `VariableDef` proto (graph mode only).

    Args:
      variable_def: `VariableDef` protocol buffer; `is_resource` must be set.
      import_scope: Optional `string`. Name scope prepended to every stored
        name before it is resolved in the default graph.

    Raises:
      ValueError: If `variable_def` does not describe a resource variable.
    """
    # Note that init_from_proto is currently not supported in Eager mode.
    assert context.in_graph_mode()
    self._in_graph_mode = True
    assert isinstance(variable_def, variable_pb2.VariableDef)
    if not variable_def.is_resource:
      raise ValueError("Trying to restore Variable as ResourceVariable.")

    graph = ops.get_default_graph()

    def scoped(name):
      # Look up `name`, re-scoped under `import_scope`, in the default graph.
      return graph.as_graph_element(
          ops.prepend_name_scope(name, import_scope=import_scope))

    self._handle = scoped(variable_def.variable_name)
    self._shape = tensor_shape.TensorShape(self._handle.op.get_attr("shape"))
    self._handle_device = self._handle.device
    self._handle_name = self._handle.name
    self._initializer_op = scoped(variable_def.initializer_name)
    # Older protos may not define `initial_value_name`; a missing or empty
    # value means no initial value was recorded (backwards compatibility).
    initial_value_name = getattr(variable_def, "initial_value_name", None)
    self._initial_value = (
        scoped(initial_value_name) if initial_value_name else None)
    # An empty snapshot name means no cached value was recorded.
    snapshot_name = variable_def.snapshot_name
    self._cached_value = scoped(snapshot_name) if snapshot_name else None
    if variable_def.HasField("save_slice_info_def"):
      self._save_slice_info = variables.Variable.SaveSliceInfo(
          save_slice_info_def=variable_def.save_slice_info_def,
          import_scope=import_scope)
    else:
      self._save_slice_info = None
    self._caching_device = None
    self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
    # `value()` is called once the handle and dtype are populated.
    self._graph_element = self.value()
    self._constraint = None
Esempio n. 15
0
    def _init_from_proto(self, variable_def, import_scope=None):
        """Initializes from `VariableDef` proto (graph mode only).

        Args:
          variable_def: `VariableDef` protocol buffer; `is_resource` must be
            set.
          import_scope: Optional `string`. Name scope prepended to every
            stored name before it is resolved in the default graph.

        Raises:
          ValueError: If `variable_def` does not describe a resource variable.
        """
        # Note that init_from_proto is currently not supported in Eager mode.
        assert context.in_graph_mode()
        self._in_graph_mode = True
        assert isinstance(variable_def, variable_pb2.VariableDef)
        if not variable_def.is_resource:
            raise ValueError("Trying to restore Variable as ResourceVariable.")

        graph = ops.get_default_graph()

        def scoped(name):
            # Look up `name`, re-scoped under `import_scope`, in the graph.
            return graph.as_graph_element(
                ops.prepend_name_scope(name, import_scope=import_scope))

        self._handle = scoped(variable_def.variable_name)
        self._shape = tensor_shape.TensorShape(
            self._handle.op.get_attr("shape"))
        self._handle_device = self._handle.device
        self._handle_name = self._handle.name
        self._initializer_op = scoped(variable_def.initializer_name)
        # Older protos may not define `initial_value_name`; a missing or
        # empty value means no initial value was recorded.
        initial_value_name = getattr(variable_def, "initial_value_name", None)
        self._initial_value = (scoped(initial_value_name)
                               if initial_value_name else None)
        # An empty snapshot name means no cached value was recorded.
        snapshot_name = variable_def.snapshot_name
        self._cached_value = scoped(snapshot_name) if snapshot_name else None
        if variable_def.HasField("save_slice_info_def"):
            self._save_slice_info = variables.Variable.SaveSliceInfo(
                save_slice_info_def=variable_def.save_slice_info_def,
                import_scope=import_scope)
        else:
            self._save_slice_info = None
        self._caching_device = None
        self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
        # `value()` is called once the handle and dtype are populated.
        self._graph_element = self.value()
        self._constraint = None
Esempio n. 16
0
 def _restore_collections(dest_graph, src_meta_graph_def, collection_keys):
   """Restores collections that we need to keep.

   Copies each collection named in `collection_keys` from
   `src_meta_graph_def` into `dest_graph`. Entries that can no longer be
   resolved in `dest_graph` (e.g. the graph was optimized and the node was
   removed) are skipped instead of failing the whole restore.

   Args:
     dest_graph: Graph whose collections are populated.
     src_meta_graph_def: Meta graph proto whose `collection_def` map holds the
       source collections.
     collection_keys: Iterable of collection keys to restore.
   """
   scope = ""
   for key in collection_keys:
     collection_def = src_meta_graph_def.collection_def[key]
     kind = collection_def.WhichOneof("kind")
     if kind is None:
       tf_logging.error(
           "Cannot identify data type for collection %s. Skipping.", key)
       continue
     from_proto = ops.get_from_proto_function(key)
     if from_proto and kind == "bytes_list":
       proto_type = ops.get_collection_proto_type(key)
       # It is assumed that there are no Variables Keys in collections
       for value in collection_def.bytes_list.value:
         proto = proto_type()
         proto.ParseFromString(value)
         try:
           new_value = from_proto(proto, import_scope=scope)
         except Exception as e:  # pylint: disable=broad-except
           # Best effort: the referenced elements may have been optimized
           # away. Unlike a bare `except:`, this still lets
           # KeyboardInterrupt/SystemExit propagate.
           tf_logging.debug("Skipping an entry of collection %s: %s", key, e)
           continue
         dest_graph.add_to_collection(key, new_value)
     else:
       field = getattr(collection_def, kind)
       if kind == "node_list":
         for value in field.value:
           name = ops.prepend_name_scope(value, scope)
           # Since the graph has been optimized, the node may no longer
           # exist.
           try:
             col_op = dest_graph.as_graph_element(name)
           except (TypeError, ValueError, KeyError):
             continue
           dest_graph.add_to_collection(key, col_op)
       elif kind == "int64_list":
         # NOTE(opensource): This force conversion is to work around the
         # fact that Python2 distinguishes between int and long, while
         # Python3 has only int.
         for value in field.value:
           dest_graph.add_to_collection(key, int(value))
       else:
         for value in field.value:
           dest_graph.add_to_collection(key,
                                        ops.prepend_name_scope(value, scope))
    def _init_from_proto(self, variable_def, import_scope=None):
        """Initializes an EmbeddingVariable from a `VariableDef` proto.

        Only graph mode is supported. Every name stored in the proto, including
        the save-slice `full_name`, is prefixed with `import_scope`.

        Args:
          variable_def: `VariableDef` protocol buffer; `is_resource` must be
            set.
          import_scope: Optional `string`. Name scope prepended to every
            stored name before it is resolved in the default graph.

        Raises:
          ValueError: If `variable_def` does not describe a resource variable.
        """
        # Note that init_from_proto is currently not supported in Eager mode.
        assert not context.executing_eagerly()
        self._in_graph_mode = True
        assert isinstance(variable_def, variable_pb2.VariableDef)
        if not variable_def.is_resource:
            raise ValueError(
                "Trying to restore Variable as EmbeddingVariable.")

        # Create from variable_def.
        g = ops.get_default_graph()
        self._handle = g.as_graph_element(
            ops.prepend_name_scope(variable_def.variable_name,
                                   import_scope=import_scope))
        self._graph_shape = tensor_shape.TensorShape(
            self._handle.op.get_attr("shape"))
        self._handle_device = self._handle.device
        self._handle_name = self._handle.name
        self._initializer_op = g.as_graph_element(
            ops.prepend_name_scope(variable_def.initializer_name,
                                   import_scope=import_scope))
        self._trainable = getattr(variable_def, "trainable", True)
        if variable_def.snapshot_name:
            self._cached_value = g.as_graph_element(
                ops.prepend_name_scope(variable_def.snapshot_name,
                                       import_scope=import_scope))
        else:
            self._cached_value = None
        if variable_def.HasField("save_slice_info_def"):
            # Forward `import_scope` so the slice's `full_name` is re-scoped
            # like every other name restored above.
            self._save_slice_info = variables.Variable.SaveSliceInfo(
                save_slice_info_def=variable_def.save_slice_info_def,
                import_scope=import_scope)
        else:
            self._save_slice_info = None
        self._caching_device = None
        self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
        self._invalid_key = -1
        # NOTE(review): this creates a new constant tensor in the default
        # graph rather than restoring one from the proto; confirm intended.
        self._initial_value = ops.convert_to_tensor([0],
                                                    name="initial_value",
                                                    dtype=self._dtype)
        self._invalid_key_type = dtypes.as_dtype(
            self._handle.op.get_attr("Tkeys"))
        self._graph_element = None
        self._constraint = None
Esempio n. 18
0
    def _remap_feed(self, feed, feed_val=None):
        """Remap a feed onto the matching element(s) of the transformed graph.

        If the feed name is not found as-is, the graph is assumed to contain
        one copy of the placeholder per local replica (names prefixed by
        `replica_prefix(i)`), and all of those copies are returned.

        Args:
            feed: Feed graph element or its name.
            feed_val: Optional feed value.

        Returns:
            A 2-tuple ``(transformed_feeds, expand_feed_val)``.
            ``transformed_feeds`` is a list of graph elements, or, when
            ``feed_val`` is not None, a list of ``(new_feed, new_feed_value)``
            pairs. ``expand_feed_val`` maps one feed value to a list of
            per-replica values.
        """
        graph = self._graph_item.graph
        feed_name = feed.name if not isinstance(feed, str) else feed
        try:
            transformed_feeds = [graph.as_graph_element(feed_name)]
        except KeyError:
            # Not present under its own name: gather the per-replica copies.
            transformed_feeds = [
                graph.as_graph_element(
                    ops.prepend_name_scope(feed_name, replica_prefix(idx)))
                for idx in range(self._graph_transformer.num_local_replicas)
            ]

        replica_count = self._graph_transformer.num_local_replicas
        if isinstance(feed, str):
            feed = transformed_feeds[0]

        def expand_feed_val(feed_val, feed=feed):
            """Expand one feed value into a list with one entry per replica.

            A placeholder with a polymorphic dimension gets the value split
            along that axis; otherwise every replica receives the same value.
            """
            axis = self._polymorphic_dim(feed)
            # NOTE(review): a truthiness test treats axis 0 like "no
            # polymorphic dimension"; confirm whether `_polymorphic_dim` can
            # return 0.
            if not axis:
                return [feed_val] * replica_count
            return np.array_split(np.asarray(feed_val), replica_count,
                                  axis=axis)

        if feed_val is not None:
            transformed_feeds = list(
                zip(transformed_feeds, expand_feed_val(feed_val)))
        return transformed_feeds, expand_feed_val
Esempio n. 19
0
def get_element_from_tensor_info(tensor_info, graph=None, import_scope=None):
    """Looks up the Op or Tensor named by a TensorInfo proto.

    Args:
      tensor_info: A TensorInfo proto whose `name` identifies an Op or Tensor.
      graph: The tf.Graph to search. If None, the current default graph is
        used.
      import_scope: If not None, this scope is prepended to `tensor_info.name`
        before the lookup.

    Returns:
      The Op or Tensor in `graph` that `tensor_info` describes.

    Raises:
      KeyError: If `tensor_info` does not correspond to an op or tensor in `graph`
    """
    if not graph:
        graph = ops.get_default_graph()
    scoped_name = ops.prepend_name_scope(tensor_info.name,
                                         import_scope=import_scope)
    return graph.as_graph_element(scoped_name)
Esempio n. 20
0
def get_element_from_tensor_info(tensor_info, graph=None, import_scope=None):
  """Resolves a TensorInfo proto to the graph element it names.

  Args:
    tensor_info: A TensorInfo proto describing an Op or Tensor by name.
    graph: The tf.Graph to look in. If None, the current default graph is
      used instead.
    import_scope: If not None, prefixed onto `tensor_info.name` before the
      lookup.

  Returns:
    The Op or tensor in `graph` described by `tensor_info`.

  Raises:
    KeyError: If `tensor_info` does not correspond to an op or tensor in `graph`
  """
  lookup_graph = graph if graph else ops.get_default_graph()
  full_name = ops.prepend_name_scope(
      tensor_info.name, import_scope=import_scope)
  return lookup_graph.as_graph_element(full_name)
Esempio n. 21
0
def import_scoped_meta_graph_with_return_elements(
        meta_graph_or_file,
        clear_devices=False,
        graph=None,
        import_scope=None,
        input_map=None,
        unbound_inputs_col_name="unbound_inputs",
        restore_collections_predicate=(lambda key: True),
        return_elements=None):
    """Imports graph from `MetaGraphDef` and returns vars and return elements.

    This function takes a `MetaGraphDef` protocol buffer as input. If
    the argument is a file containing a `MetaGraphDef` protocol buffer,
    it constructs a protocol buffer from the file content. The function
    then adds all the nodes from the `graph_def` field to the
    current graph, recreates the desired collections, and returns a dictionary
    of all the Variables imported into the name scope.

    In combination with `export_scoped_meta_graph()`, this function can be
    used to

    * Serialize a graph along with other Python objects such as `QueueRunner`,
      `Variable` into a `MetaGraphDef`.

    * Restart training from a saved graph and checkpoints.

    * Run inference from a saved graph and checkpoints.

    Args:
      meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
        the path) containing a `MetaGraphDef`.
      clear_devices: Boolean which controls whether to clear device information
        from graph_def. Default false.
      graph: The `Graph` to import into. If `None`, use the default graph.
      import_scope: Optional `string`. Name scope into which to import the
        subgraph. If `None`, the graph is imported to the root name scope.
      input_map: A dictionary mapping input names (as strings) in `graph_def` to
        `Tensor` objects. The values of the named input tensors in the imported
        graph will be re-mapped to the respective `Tensor` values.
      unbound_inputs_col_name: Collection name for looking up unbound inputs.
        A collection under this name with no `kind` set is treated as having
        no unbound inputs.
      restore_collections_predicate: a predicate on collection names. A
        collection named c (i.e whose key is c) will be restored iff
        1) `restore_collections_predicate(c)` is True, and
        2) `c != unbound_inputs_col_name`.
      return_elements: A list of strings containing operation names in the
        `MetaGraphDef` that will be returned as `Operation` objects; and/or
        tensor names in `MetaGraphDef` that will be returned as `Tensor`
        objects.

    Returns:
      A tuple of (
        dictionary of all the `Variables` imported into the name scope,
        list of `Operation` or `Tensor` objects from the `return_elements`
        list).

    Raises:
      ValueError: If eager execution is enabled, or if the graph_def contains
        unbound inputs whose names do not exactly match the keys of
        `input_map`.
    """
    if context.executing_eagerly():
        raise ValueError(
            "Exporting/importing meta graphs is not supported when "
            "eager execution is enabled.")
    if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
        meta_graph_def = meta_graph_or_file
    else:
        meta_graph_def = read_meta_graph_file(meta_graph_or_file)

    if unbound_inputs_col_name:
        for key, col_def in meta_graph_def.collection_def.items():
            if key == unbound_inputs_col_name:
                kind = col_def.WhichOneof("kind")
                # An empty CollectionDef has no `kind`; getattr(col_def, None)
                # would raise TypeError, so treat it as "no unbound inputs".
                if kind is None:
                    break
                field = getattr(col_def, kind)
                unbound_names = [compat.as_str(v) for v in field.value]
                if unbound_names and (not input_map or
                                      sorted(unbound_names) != sorted(input_map)):
                    raise ValueError(
                        "Graph contains unbound inputs: %s. Must "
                        "provide these inputs through input_map." % ",".join(
                            name for name in unbound_names
                            if not input_map or name not in input_map))
                break

    # Sets graph to default graph if it's not passed in.
    graph = graph or ops.get_default_graph()

    # Gathers the list of nodes we are interested in.
    with graph.as_default():
        producer_op_list = None
        if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
            producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
        input_graph_def = meta_graph_def.graph_def
        # Remove all the explicit device specifications for this node. This helps to
        # make the graph more portable.
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""

        scope_to_prepend_to_names = graph.unique_name(import_scope or "",
                                                      mark_as_used=False)

        imported_return_elements = importer.import_graph_def(
            input_graph_def,
            name=(import_scope or scope_to_prepend_to_names),
            input_map=input_map,
            producer_op_list=producer_op_list,
            return_elements=return_elements)

        # TensorFlow versions before 1.9 (not inclusive) exported SavedModels
        # without a VariableDef.trainable field set.
        tf_version = meta_graph_def.meta_info_def.tensorflow_version
        if not tf_version:
            variables_have_trainable = True
        else:
            variables_have_trainable = (packaging_version.parse(tf_version) >=
                                        packaging_version.parse("1.9"))

        # Sort collections so we see TRAINABLE_VARIABLES first and can default these
        # variables to trainable if the value is not set in their VariableDef.
        sorted_collections = []
        if ops.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def:
            sorted_collections.append((ops.GraphKeys.TRAINABLE_VARIABLES,
                                       meta_graph_def.collection_def[
                                           ops.GraphKeys.TRAINABLE_VARIABLES]))
        for key, value in sorted(meta_graph_def.collection_def.items()):
            if key != ops.GraphKeys.TRAINABLE_VARIABLES:
                sorted_collections.append((key, value))

        # Restores all the other collections.
        variable_objects = {}
        for key, col_def in sorted_collections:
            # Don't add unbound_inputs to the new graph.
            if key == unbound_inputs_col_name:
                continue
            if not restore_collections_predicate(key):
                continue

            kind = col_def.WhichOneof("kind")
            if kind is None:
                logging.error(
                    "Cannot identify data type for collection %s. Skipping.",
                    key)
                continue
            from_proto = ops.get_from_proto_function(key)

            # Temporary change to allow the TFMA evaluator to read metric variables
            # saved as a bytes list.
            # TODO(kathywu): Remove this hack once cl/248406059 has been submitted.
            if key == ops.GraphKeys.METRIC_VARIABLES:
                # Metric variables will use the same proto functions as GLOBAL_VARIABLES
                from_proto = ops.get_from_proto_function(
                    ops.GraphKeys.GLOBAL_VARIABLES)
            if from_proto and kind == "bytes_list":
                proto_type = ops.get_collection_proto_type(key)
                if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
                    for value in col_def.bytes_list.value:
                        # Reuse one Python variable object per serialized
                        # VariableDef, even if it appears in several collections.
                        variable = variable_objects.get(value, None)
                        if variable is None:
                            proto = proto_type()
                            proto.ParseFromString(value)
                            if not variables_have_trainable:
                                # If the VariableDef proto does not contain a "trainable"
                                # property because it was exported before that property was
                                # added, we default it to whether the variable is in the
                                # TRAINABLE_VARIABLES collection. We've sorted
                                # TRAINABLE_VARIABLES to be first, so trainable variables will
                                # be created from that collection.
                                proto.trainable = (
                                    key == ops.GraphKeys.TRAINABLE_VARIABLES)
                            variable = from_proto(
                                proto, import_scope=scope_to_prepend_to_names)
                            variable_objects[value] = variable
                        graph.add_to_collection(key, variable)
                else:
                    for value in col_def.bytes_list.value:
                        proto = proto_type()
                        proto.ParseFromString(value)
                        graph.add_to_collection(
                            key,
                            from_proto(proto,
                                       import_scope=scope_to_prepend_to_names))
            else:
                field = getattr(col_def, kind)
                if key in _COMPAT_COLLECTION_LIST:
                    logging.warning(
                        "The saved meta_graph is possibly from an older release:\n"
                        "'%s' collection should be of type 'byte_list', but instead "
                        "is of type '%s'.", key, kind)
                if kind == "node_list":
                    for value in field.value:
                        col_op = graph.as_graph_element(
                            ops.prepend_name_scope(value,
                                                   scope_to_prepend_to_names))
                        graph.add_to_collection(key, col_op)
                elif kind == "int64_list":
                    # NOTE(opensource): This force conversion is to work around the fact
                    # that Python2 distinguishes between int and long, while Python3 has
                    # only int.
                    for value in field.value:
                        graph.add_to_collection(key, int(value))
                else:
                    for value in field.value:
                        graph.add_to_collection(
                            key,
                            ops.prepend_name_scope(value,
                                                   scope_to_prepend_to_names))

        var_list = {}
        variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                         scope=scope_to_prepend_to_names)
        for v in variables:
            var_list[ops.strip_name_scope(v.name,
                                          scope_to_prepend_to_names)] = v

    return var_list, imported_return_elements
Esempio n. 22
0
 def _get_tensor(name):
     """Return the tensor called `name`, scoped by the enclosing `import_scope`.

     `graph` and `import_scope` are free variables taken from the enclosing
     function, which is not visible here.
     """
     scoped_name = ops.prepend_name_scope(name, import_scope=import_scope)
     return graph.get_tensor_by_name(scoped_name)
Esempio n. 23
0
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
      A collection under this name with no `kind` set is treated as having no
      unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If eager execution is enabled, or if the graph_def contains
      unbound inputs whose names do not exactly match the keys of `input_map`.
  """
  if context.executing_eagerly():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        # An empty CollectionDef has no `kind`; getattr(col_def, None) would
        # raise TypeError, so treat it as "no unbound inputs".
        if kind is None:
          break
        field = getattr(col_def, kind)
        unbound_names = [compat.as_str(v) for v in field.value]
        if unbound_names and (
            not input_map or
            sorted(unbound_names) != sorted(input_map)):
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join([name for name in unbound_names
                                     if not input_map or name not in input_map]))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    importer.import_graph_def(
        input_graph_def,
        name=(import_scope or scope_to_prepend_to_names),
        input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    variable_objects = {}
    for key, col_def in sorted(meta_graph_def.collection_def.items()):
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
          for value in col_def.bytes_list.value:
            # Reuse one Python variable object per serialized VariableDef,
            # even if it appears in several collections.
            variable = variable_objects.get(value, None)
            if variable is None:
              proto = proto_type()
              proto.ParseFromString(value)
              variable = from_proto(
                  proto, import_scope=scope_to_prepend_to_names)
              variable_objects[value] = variable
            graph.add_to_collection(key, variable)
        else:
          for value in col_def.bytes_list.value:
            proto = proto_type()
            proto.ParseFromString(value)
            graph.add_to_collection(
                key, from_proto(
                    proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_to_prepend_to_names)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list
Esempio n. 24
0
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the
  current graph, recreates the desired collections, and returns a dictionary of
  all the Variables imported into the name scope.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
      A collection under this name with no `kind` set is treated as having no
      unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If eager execution is enabled, or if the graph_def contains
      unbound inputs whose names do not exactly match the keys of `input_map`.
  """
  if context.in_eager_mode():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "eager execution is enabled.")
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        # An empty CollectionDef has no `kind`; getattr(col_def, None) would
        # raise TypeError, so treat it as "no unbound inputs".
        if kind is None:
          break
        field = getattr(col_def, kind)
        unbound_names = [compat.as_str(v) for v in field.value]
        if unbound_names and (
            not input_map or
            sorted(unbound_names) != sorted(input_map)):
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join([name for name in unbound_names
                                     if not input_map or name not in input_map]))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""

    scope_to_prepend_to_names = graph.unique_name(
        import_scope or "", mark_as_used=False)

    importer.import_graph_def(
        input_graph_def, name=(import_scope or ""), input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    for key, col_def in sorted(meta_graph_def.collection_def.items()):
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue
      if not restore_collections_predicate(key):
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        for value in col_def.bytes_list.value:
          proto = proto_type()
          proto.ParseFromString(value)
          graph.add_to_collection(
              key, from_proto(proto, import_scope=scope_to_prepend_to_names))
      else:
        field = getattr(col_def, kind)
        if key in _COMPAT_COLLECTION_LIST:
          logging.warning(
              "The saved meta_graph is possibly from an older release:\n"
              "'%s' collection should be of type 'byte_list', but instead "
              "is of type '%s'.", key, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, scope_to_prepend_to_names))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))

    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope_to_prepend_to_names)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v

  return var_list
Esempio n. 25
0
def copy_ops_meta_graph(op_list, from_scope, to_scope, replace=None):
    """Copies a list of `Operation`s from one scope to another, with variables
    shared between them.

    Inputs that come from outside `op_list` (or that are listed in `replace`),
    external control dependencies, and colocation targets outside `op_list`
    are exported as unbound inputs and re-bound through `input_map` when the
    copy is imported under `to_scope`.

    Args:
      op_list: A list of `Operation` objects to be copied.
      from_scope: `String` name scope containing the ops to be copied.
      to_scope: `String` name scope under which the copied ops will reside.
      replace: Optional dictionary mapping input Tensors of these ops to their
        replacements. `None` (the default) means no replacements.

    Returns:
      A dictionary mapping each original `Operation` to its copy.

    Raises:
      ValueError: If `from_scope` and `to_scope` are the same, or if an op in
        `op_list` is not under `from_scope`.
    """
    if from_scope == to_scope:
        raise ValueError("'from_scope' and 'to_scope' need to be different "
                         "when performing copy in the same graph.")
    # `tensor in replace` below needs a container; None is the documented default.
    if replace is None:
        replace = {}
    op_list = set(op_list)
    op_names = {op.name for op in op_list}
    op_outputs = set()
    for op in op_list:
        if not op.name.startswith(from_scope):
            raise ValueError("The Operation (%s) to copy is not under "
                             "'from_scope'." % op.name)
        op_outputs.update(op.outputs)

    input_map = {}
    as_unbound_inputs = []
    for op in op_list:
        for tensor in op.inputs:
            if tensor not in op_outputs or tensor in replace:
                # The first output of an op is referred to by the bare op name.
                name = tensor.name[:-2] if tensor.name.endswith(":0") else tensor.name
                as_unbound_inputs.append(name)
                if tensor in replace:
                    input_map[_unbound_name(name)] = replace[tensor]
                else:
                    input_map[_unbound_name(name)] = tensor
        for dep in op.control_inputs:
            if dep not in op_list:
                name = "^" + dep.name
                as_unbound_inputs.append(name)
                input_map[_unbound_name(name)] = dep
        for group in op.colocation_groups():
            # colocation_groups() yields byte strings (b"loc:@<op name>") under
            # Python 3; normalize to text so the membership test against the
            # str names in `op_names` can succeed.
            group = compat.as_str(group)
            colocated_name = group[5:]  # strip the "loc:@" prefix
            if colocated_name not in op_names:
                as_unbound_inputs.append(group)
                input_map[_unbound_name(group)] = (
                    ops.get_default_graph().as_graph_element(colocated_name))

    orig_meta_graph = export_ops_meta_graph(
        op_list, export_scope=from_scope, as_unbound_inputs=as_unbound_inputs)
    _ = import_scoped_meta_graph(orig_meta_graph,
                                 import_scope=to_scope,
                                 input_map=input_map)
    copied_ops = {}
    for op in op_list:
        new_op_name = ops.prepend_name_scope(
            ops.strip_name_scope(op.name, from_scope), to_scope)
        new_op = ops.get_default_graph().as_graph_element(new_op_name,
                                                          allow_tensor=False)
        copied_ops[op] = new_op
    return copied_ops
Esempio n. 26
0
 def _get_tensor(name):
   """Fetch tensor `name` from `graph` after prefixing it with `import_scope`.

   Both `graph` and `import_scope` come from the enclosing (unseen) function.
   """
   full_name = ops.prepend_name_scope(name, import_scope=import_scope)
   return graph.get_tensor_by_name(full_name)
Esempio n. 27
0
  def __iadd__(self, other):
    """Implements `+=` by returning `self + other`.

    Nothing is assigned to the variable here; a warning (logged once) points
    users at the explicit alternatives.
    """
    logging.log_first_n(
        logging.WARN,
        "Variable += will be deprecated. Use variable.assign_add"
        " if you want assignment to the variable value or 'x = x + y'"
        " if you want a new python Tensor object.", 1)
    return self + other

  def __isub__(self, other):
    """Implements `-=` by returning `self - other`.

    Nothing is assigned to the variable here; a warning (logged once) points
    users at the explicit alternatives.
    """
    logging.log_first_n(
        logging.WARN,
        "Variable -= will be deprecated. Use variable.assign_sub"
        " if you want assignment to the variable value or 'x = x - y'"
        " if you want a new python Tensor object.", 1)
    return self - other

  def __imul__(self, other):
    """Implements `*=` by returning `self * other`.

    Nothing is assigned to the variable here; a warning (logged once) points
    users at the explicit alternatives.
    """
    logging.log_first_n(
        logging.WARN,
        "Variable *= will be deprecated. Use `var.assign(var * other)`"
        " if you want assignment to the variable value or `x = x * y`"
        " if you want a new python Tensor object.", 1)
    return self * other

  def __idiv__(self, other):
    """Implements Python 2 `/=` by returning `self / other`.

    Nothing is assigned to the variable here; a warning (logged once) points
    users at the explicit alternatives.
    """
    logging.log_first_n(
        logging.WARN,
        "Variable /= will be deprecated. Use `var.assign(var / other)`"
        " if you want assignment to the variable value or `x = x / y`"
        " if you want a new python Tensor object.", 1)
    return self / other

  def __itruediv__(self, other):
    """Implements true-division `/=` by returning `self / other`.

    Nothing is assigned to the variable here; a warning (logged once) points
    users at the explicit alternatives.
    """
    logging.log_first_n(
        logging.WARN,
        "Variable /= will be deprecated. Use `var.assign(var / other)`"
        " if you want assignment to the variable value or `x = x / y`"
        " if you want a new python Tensor object.", 1)
    return self / other

  def __irealdiv__(self, other):
    """In-place real division hook; returns `self / other`.

    Nothing is assigned to the variable here; a warning (logged once) points
    users at the explicit alternatives.
    """
    logging.log_first_n(
        logging.WARN,
        "Variable /= will be deprecated. Use `var.assign(var / other)`"
        " if you want assignment to the variable value or `x = x / y`"
        " if you want a new python Tensor object.", 1)
    return self / other

  def __ipow__(self, other):
    """Implements `**=` by returning `self ** other`.

    Nothing is assigned to the variable here; a warning (logged once) points
    users at the explicit alternatives.
    """
    logging.log_first_n(
        logging.WARN,
        "Variable **= will be deprecated. Use `var.assign(var ** other)`"
        " if you want assignment to the variable value or `x = x ** y`"
        " if you want a new python Tensor object.", 1)
    return self ** other

  class SaveSliceInfo(object):
    """Records where a variable slice sits inside its full variable.

    Attributes (each may be None when constructed without arguments):
      full_name: Name of the full variable.
      full_shape: Shape of the full variable, one int per dimension.
      var_offset: Start offset of this slice, one int per dimension.
      var_shape: Shape of this slice, one int per dimension.
    """

    def __init__(self,
                 full_name=None,
                 full_shape=None,
                 var_offset=None,
                 var_shape=None,
                 save_slice_info_def=None,
                 import_scope=None):
      """Builds the info from explicit fields or from a `SaveSliceInfoDef`.

      When `save_slice_info_def` is given, the explicit field arguments are
      ignored: values are copied from the proto and `import_scope` is
      prepended to its `full_name`.
      """
      if not save_slice_info_def:
        self.full_name = full_name
        self.full_shape = full_shape
        self.var_offset = var_offset
        self.var_shape = var_shape
        return
      assert isinstance(save_slice_info_def, variable_pb2.SaveSliceInfoDef)
      self.full_name = ops.prepend_name_scope(
          save_slice_info_def.full_name, import_scope=import_scope)
      self.full_shape = list(save_slice_info_def.full_shape)
      self.var_offset = list(save_slice_info_def.var_offset)
      self.var_shape = list(save_slice_info_def.var_shape)

    @property
    def spec(self):
      """Slice spec string.

      Full-shape dims joined by spaces, then a space, then `offset,length`
      pairs joined by colons, e.g. "4 6 0,4:3,3".
      """
      dims = " ".join("%d" % dim for dim in self.full_shape)
      pieces = ":".join(
          "%d,%d" % pair for pair in zip(self.var_offset, self.var_shape))
      return "%s %s" % (dims, pieces)

    def to_proto(self, export_scope=None):
      """Serializes to a `SaveSliceInfoDef` with `export_scope` stripped.

      Returns None when `export_scope` is given and `full_name` does not start
      with it.
      """
      if export_scope is not None and not self.full_name.startswith(
          export_scope):
        return None
      proto = variable_pb2.SaveSliceInfoDef()
      proto.full_name = ops.strip_name_scope(self.full_name, export_scope)
      for dim in self.full_shape:
        proto.full_shape.append(dim)
      for offset in self.var_offset:
        proto.var_offset.append(offset)
      for length in self.var_shape:
        proto.var_shape.append(length)
      return proto

  def _set_save_slice_info(self, save_slice_info):
    self._save_slice_info = save_slice_info

  def _get_save_slice_info(self):
    return self._save_slice_info
Esempio n. 28
0
 def prepend_name_scope(name_scope):
     """Prefix `name_scope` with the enclosing function's `import_scope`."""
     scoped = ops.prepend_name_scope(name_scope, import_scope)
     return scoped
Esempio n. 29
0
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs"):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  This function takes a `MetaGraphDef` protocol buffer as input. If
  the argument is a file containing a `MetaGraphDef` protocol buffer,
  it constructs a protocol buffer from the file content. The function
  then adds all the nodes from the `graph_def` field to the target
  graph and recreates all the collections.

  In combination with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false. Devices are cleared on a copy, so a
      `MetaGraphDef` passed in by the caller is not modified.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.

  Returns:
    A dictionary mapping variable names (with `import_scope` stripped) to the
    `Variables` imported into the name scope.

  Raises:
    ValueError: If the graph_def contains unbound inputs that are not exactly
      the keys of `input_map`.
  """
  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
    meta_graph_def = meta_graph_or_file
  else:
    meta_graph_def = read_meta_graph_file(meta_graph_or_file)

  if unbound_inputs_col_name:
    for key, col_def in meta_graph_def.collection_def.items():
      if key == unbound_inputs_col_name:
        kind = col_def.WhichOneof("kind")
        # A collection with no populated "kind" holds no values; bail out
        # instead of calling getattr(col_def, None), which raises TypeError.
        if kind is None:
          break
        field = getattr(col_def, kind)
        if field.value and (
            not input_map or
            sorted([compat.as_str(v) for v in field.value]) !=
            sorted(input_map)):
          raise ValueError("Graph contains unbound inputs: %s. Must "
                           "provide these inputs through input_map." %
                           ",".join([compat.as_str(v) for v in field.value]))
        break

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Gathers the list of nodes we are interested in.
  with graph.as_default():
    producer_op_list = None
    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
    input_graph_def = meta_graph_def.graph_def
    # Remove all the explicit device specifications for this node. This helps to
    # make the graph more portable.
    if clear_devices:
      # Work on a copy so the caller's MetaGraphDef is left untouched.
      graph_def_copy = type(input_graph_def)()
      graph_def_copy.CopyFrom(input_graph_def)
      input_graph_def = graph_def_copy
      for node in input_graph_def.node:
        node.device = ""
    importer.import_graph_def(
        input_graph_def, name=(import_scope or ""), input_map=input_map,
        producer_op_list=producer_op_list)

    # Restores all the other collections.
    for key, col_def in meta_graph_def.collection_def.items():
      # Don't add unbound_inputs to the new graph.
      if key == unbound_inputs_col_name:
        continue

      kind = col_def.WhichOneof("kind")
      if kind is None:
        logging.error("Cannot identify data type for collection %s. Skipping.",
                      key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto:
        assert kind == "bytes_list"
        proto_type = ops.get_collection_proto_type(key)
        for value in col_def.bytes_list.value:
          proto = proto_type()
          proto.ParseFromString(value)
          graph.add_to_collection(
              key, from_proto(proto, import_scope=import_scope))
      else:
        field = getattr(col_def, kind)
        if kind == "node_list":
          for value in field.value:
            col_op = graph.as_graph_element(
                ops.prepend_name_scope(value, import_scope))
            graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the fact
          # that Python2 distinguishes between int and long, while Python3 has
          # only int.
          for value in field.value:
            graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            graph.add_to_collection(
                key, ops.prepend_name_scope(value, import_scope))

    # Collect the imported variables, keyed by name with import_scope removed.
    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.VARIABLES,
                                     scope=import_scope)
    for v in variables:
      var_list[ops.strip_name_scope(v.name, import_scope)] = v

  return var_list
Esempio n. 30
0
def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             saver_def=None,
                             clear_extraneous_savers=False,
                             strip_default_attrs=False,
                             save_debug_info=False,
                             **kwargs):
    """Returns `MetaGraphDef` proto. Optionally writes it to filename.

    This function exports the graph, saver, and collection objects into
    `MetaGraphDef` protocol buffer with the intention of it being imported
    at a later time or location to restart training, run inference, or be
    a subgraph.

    When `export_scope`, `clear_extraneous_savers` or `clear_devices` is set,
    the exported `GraphDef` is rebuilt node by node (from `graph_def` if one
    is given, otherwise from `graph`). Otherwise `graph_def` is passed through
    to `create_meta_graph_def` unchanged. In the rebuilt case, if
    `unbound_inputs_col_name` is set, the collection of that name on `graph`
    is cleared and refilled with the unbound inputs collected by `_node_def`
    (a side effect on `graph`).

    Args:
      filename: Optional filename including the path for writing the
        generated `MetaGraphDef` protocol buffer.
      graph_def: `GraphDef` protocol buffer.
      graph: The `Graph` to export. If `None`, use the default graph.
      export_scope: Optional `string`. Name scope under which to extract
        the subgraph. The scope name will be stripped from the node definitions
        for easy import later into new name scopes. If `None`, the whole graph
        is exported.
      as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
      unbound_inputs_col_name: Optional `string`. If provided, a string
        collection with the given name will be added to the returned
        `MetaGraphDef`, containing the names of tensors that must be remapped
        when importing the `MetaGraphDef`.
      clear_devices: Boolean which controls whether to clear device information
        before exporting the graph.
      saver_def: `SaverDef` protocol buffer.
      clear_extraneous_savers: Remove any Saver-related information from the
        graph (both Save/Restore ops and SaverDefs) that are not associated
        with the provided SaverDef.
      strip_default_attrs: Set to true if default valued attributes must be
        removed while exporting the GraphDef.
      save_debug_info: If `True` and `filename` is given, also write the
        GraphDebugInfo of the exported nodes to a separate file in the same
        directory as `filename`. Its name is `filename` with the extension
        replaced by `.debug` (e.g. `model.pbtxt` -> `model.debug`).
      **kwargs: Optional keyed arguments, including meta_info_def and
        collection_list.

    Returns:
      A tuple `(meta_graph_def, var_list)`: the `MetaGraphDef` proto and a
      dictionary mapping variable names (with `export_scope` stripped) to the
      global `Variables` in the exported name scope.

    Raises:
      ValueError: When the `GraphDef` built from `graph` is larger than 2GB.
      ValueError: When executing in Eager mode and either `graph_def` or `graph`
        is undefined.
    """
    # In eager mode both the GraphDef and the Graph must be passed explicitly.
    if context.executing_eagerly() and not (graph_def is not None
                                            and graph is not None):
        raise ValueError(
            "Exporting/importing meta graphs is not supported when "
            "Eager Execution is enabled.")
    graph = graph or ops.get_default_graph()

    exclude_nodes = None
    unbound_inputs = []
    # Rebuild the GraphDef only when nodes must be filtered or rewritten;
    # otherwise graph_def (possibly None) is handed on as-is.
    if export_scope or clear_extraneous_savers or clear_devices:
        if graph_def:
            # Filter/rewrite the caller's GraphDef into a fresh one, keeping its
            # versions and function library.
            new_graph_def = graph_pb2.GraphDef()
            new_graph_def.versions.CopyFrom(graph_def.versions)
            new_graph_def.library.CopyFrom(graph_def.library)

            if clear_extraneous_savers:
                exclude_nodes = _find_extraneous_saver_nodes(
                    graph_def, saver_def)

            for node_def in graph_def.node:
                if _should_include_node(node_def.name, export_scope,
                                        exclude_nodes):
                    new_node_def = _node_def(node_def,
                                             export_scope,
                                             unbound_inputs,
                                             clear_devices=clear_devices)
                    new_graph_def.node.extend([new_node_def])
            graph_def = new_graph_def
        else:
            # No GraphDef was given: build one from `graph`, keeping only the
            # included nodes.
            graph_def = graph_pb2.GraphDef()
            # pylint: disable=protected-access
            graph_def.versions.CopyFrom(graph.graph_def_versions)
            # Running serialized size of the included nodes, checked against
            # the 2GB limit below.
            bytesize = 0

            if clear_extraneous_savers:
                exclude_nodes = _find_extraneous_saver_nodes(
                    graph.as_graph_def(), saver_def)

            # Iterate nodes in id order so the output is deterministic.
            for key in sorted(graph._nodes_by_id):
                if _should_include_node(graph._nodes_by_id[key].name,
                                        export_scope, exclude_nodes):
                    value = graph._nodes_by_id[key]
                    # pylint: enable=protected-access
                    node_def = _node_def(value.node_def,
                                         export_scope,
                                         unbound_inputs,
                                         clear_devices=clear_devices)
                    graph_def.node.extend([node_def])
                    # Record the op's output shapes on the exported node.
                    if value.outputs:
                        assert "_output_shapes" not in graph_def.node[-1].attr
                        graph_def.node[-1].attr[
                            "_output_shapes"].list.shape.extend([
                                output.get_shape().as_proto()
                                for output in value.outputs
                            ])
                    bytesize += value.node_def.ByteSize()
                    if bytesize >= (1 << 31) or bytesize < 0:
                        raise ValueError("GraphDef cannot be larger than 2GB. "
                                         f"Received size: {bytesize}.")

            # Adds the graph's functions to graph_def; bytesize is passed along
            # (presumably so the size limit covers functions too — confirm).
            graph._copy_functions_to_graph_def(graph_def, bytesize)  # pylint: disable=protected-access

        # It's possible that not all the inputs are in the export_scope.
        # If we would like such information included in the exported meta_graph,
        # add them to a special unbound_inputs collection.
        if unbound_inputs_col_name:
            # Clears the unbound_inputs collections.
            graph.clear_collection(unbound_inputs_col_name)
            for k in unbound_inputs:
                graph.add_to_collection(unbound_inputs_col_name, k)

    # Global variables under export_scope, keyed by name with the scope removed.
    var_list = {}
    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                     scope=export_scope)
    for v in variables:
        if _should_include_node(v, export_scope, exclude_nodes):
            var_list[ops.strip_name_scope(v.name, export_scope)] = v

    scoped_meta_graph_def = create_meta_graph_def(
        graph_def=graph_def,
        graph=graph,
        export_scope=export_scope,
        exclude_nodes=exclude_nodes,
        clear_extraneous_savers=clear_extraneous_savers,
        saver_def=saver_def,
        strip_default_attrs=strip_default_attrs,
        **kwargs)

    if filename:
        graph_io.write_graph(scoped_meta_graph_def,
                             os.path.dirname(filename),
                             os.path.basename(filename),
                             as_text=as_text)
        if save_debug_info:
            # The debug file replaces filename's extension with ".debug".
            name, _ = os.path.splitext(filename)
            debug_filename = "{name}{ext}".format(name=name, ext=".debug")

            # Map every exported node back to its op in `graph`, re-adding the
            # export_scope that was stripped from the exported node names.
            # TODO(liufengdb): fix this for functions.
            ops_to_export = []
            for node in scoped_meta_graph_def.graph_def.node:
                scoped_op_name = ops.prepend_name_scope(
                    node.name, export_scope)
                ops_to_export.append(
                    ("", graph.get_operation_by_name(scoped_op_name)))

            graph_debug_info = error_interpolation.create_graph_debug_info_def(
                ops_to_export)

            graph_io.write_graph(graph_debug_info,
                                 os.path.dirname(debug_filename),
                                 os.path.basename(debug_filename),
                                 as_text=as_text)

    return scoped_meta_graph_def, var_list
Esempio n. 31
0
    def _remap_fetch(self, fetch):
        """
        Remap the user-provided fetches to the right list of fetches after graph transformations.

        Cases:
            * If original fetch exists (which is not affected by graph transformation), fetch the original.
            * Otherwise, for fetches that are train_ops, fetch them on all replicas;
            * for other fetches, only fetch it on master replica.
                * For example, for partitioned vars, it corresponds to the concat one as_tensor on the first replica.
        """
        _remap_element = self._remap_element
        fetch_type = type(fetch)
        fetch_name = fetch if isinstance(fetch, str) else fetch.name
        contract_fn = lambda fetched_vals: fetched_vals[0]  # noqa: E731
        try:
            transformed_fetch = [_remap_element(fetch_type, fetch_name)]
        except KeyError:
            master_replica_name = ops.prepend_name_scope(
                fetch_name, replica_prefix(0))
            master_replica_fetch = _remap_element(fetch_type,
                                                  master_replica_name)
            polymorphic_dim = self._polymorphic_dim(master_replica_fetch)

            def is_train_op(op):
                # In TF2: train_op as AssignAddVariableOp
                # In TF1 (being deprecated): no_op with a groups of stateful ops as control dependencies
                # TODO(unless deprecating): make the checking as strict as possible
                return isinstance(
                    op, ops.Operation) and (op.op_def.is_stateful
                                            or op.op_def.name == 'NoOp')

            if is_train_op(master_replica_fetch):
                transformed_fetch = [
                    _remap_element(
                        fetch_type,
                        ops.prepend_name_scope(fetch_name, replica_prefix(i)))
                    for i in range(self._graph_transformer.num_local_replicas)
                ]
                ####################################################################
                # # For Debugging Local Replicas
                ####################################################################
                # transformed_fetch = [
                #     self._graph_item.graph.as_graph_element('AutoDist-Replica-0/emb/part_0_take_grad')
                # ]
                # transformed_fetch = [
                #     _remap_element(ops.Tensor, ops.prepend_name_scope(
                #         'Mean:0',
                #         replica_prefix(i)))
                #     for i in range(self._graph_transformer.num_local_replicas)
                # ]
                # transformed_fetch = [_remap_element(ops.Tensor,
                #     ops.prepend_name_scope(
                #         'sampled_softmax_loss/embedding_lookup:0',
                #         replica_prefix(1)
                #     )
                # )]
                ####################################################################
                logging.debug('Fetch mapped from {} to {}'.format(
                    fetch, transformed_fetch))
            elif polymorphic_dim:
                transformed_fetch = [
                    _remap_element(
                        fetch_type,
                        ops.prepend_name_scope(fetch_name, replica_prefix(i)))
                    for i in range(self._graph_transformer.num_local_replicas)
                ]
                contract_fn = lambda fetch_vals: np.concatenate(
                    fetch_vals, axis=polymorphic_dim)  # noqa: E731
            else:
                transformed_fetch = [master_replica_fetch]
        return transformed_fetch, contract_fn
Esempio n. 32
0
def partially_apply_saved_transform(saved_model_dir, input_tensors):
    """Apply a transform graph, represented as a SavedModel, to existing Tensors.

    This adds nodes to a graph that already contains Tensors representing the
    inputs.  These input Tensors may be placeholders that will be fed when the
    graph is executed, or may be the outputs of some Ops.  Most typically, the
    input Tensors are reading and/or parsing Ops, but they could be anything--
    including the outputs of a prior application of this function using another
    transform graph.

    This function operates on the default Graph in the default Session, and so
    must be called within a context where these are provided.

    Args:
      saved_model_dir: A SavedModel directory providing a transform
        graph.  The MetaGraphDef and signature are selected from the SavedModel
        using keys defined in `../constants.py` ('transform' and
        'transform_signature', respectively).
      input_tensors: a dict of logical name to Tensor.  The logical names must
        be a subset of those in the input signature of the transform graph, and
        the corresponding Tensors must have the expected types and shapes.

    Returns:
      A pair of (unbound_inputs, outputs) where unbound_inputs is a dict of
      logical name to Tensors that are yet to be mapped or fed, and outputs is
      a dict of logical name to Tensor, as provided by the output signature
      of the transform graph

    Raises:
      ValueError: if the provided input_tensors dict has keys that are not part
        of the input signature, or any of the provided inputs have the wrong
        type or shape.
      RuntimeError: if there is no default graph available to which to apply the
        transform.
    """
    decomposed_input_tensors = _decompose_sparse_tensors(input_tensors)

    meta_graph_def, input_signature, output_signature = (
        _load_transform_saved_model(saved_model_dir))

    # Check for inputs that were not part of the input signature.
    unexpected_inputs = (set(decomposed_input_tensors.keys()) -
                         set(input_signature.keys()))
    if unexpected_inputs:
        raise ValueError('Unexpected inputs '
                         'to transform: {}'.format(unexpected_inputs))

    # Create a map from tensor names in the graph to be imported, to the tensors
    # specified in `input_tensors`.
    input_map = {
        input_signature[decomposed_logical_name]:
        decomposed_input_tensors[decomposed_logical_name]
        for decomposed_logical_name in decomposed_input_tensors
    }

    graph = tf.get_default_graph()
    if graph is None:
        raise RuntimeError('apply_saved_transform() requires a default graph.')

    # unique_name may produce e.g. transform_5.  The result has no trailing slash.
    # It is an "absolute" name that already includes the current name scope.
    scope = graph.unique_name('transform', mark_as_used=False)

    # import_meta_graph wants a name relative to the current scope, so strip
    # the current name stack (if any) from the absolute `scope`.  Without this,
    # calling inside a name scope imports under a doubled prefix and the
    # lookups by `scope` below cannot find the imported tensors.
    current_name_scope = graph.get_name_scope()
    if current_name_scope:
        scope_prefix = current_name_scope + '/'
        assert scope.startswith(scope_prefix)
        import_scope = scope[len(scope_prefix):]
    else:
        import_scope = scope

    # Load the transform graph, applying it to existing Tensors via input_map.
    # Throws ValueError if the input_map gives mismatched types or shapes.
    saver = tf_saver.import_meta_graph(meta_graph_def,
                                       import_scope=import_scope,
                                       input_map=input_map)
    if saver:
        tf.logging.warn(
            'Transform graphs should not have saved Variables, but this '
            'one does.  Variable values will *not* be restored.')

    # Add computed output tensors to the output.  There are two cases.  When the
    # output is not in the input_map, then we look up the tensor in the imported
    # graph by prepending the import scope and looking up the tensor by name.
    # This will fail if the expected output tensor is not now in the graph
    # under the expected name scope.  When the output is in the input map, then
    # that tensor will have been re-mapped so we use the tensor given in the
    # input_map.
    def lookup_remapped_tensor(tensor_name):
        if tensor_name in input_map:
            return input_map[tensor_name]
        else:
            return graph.get_tensor_by_name(
                ops.prepend_name_scope(tensor_name, scope))

    decomposed_output_tensors = {
        decomposed_logical_name: lookup_remapped_tensor(tensor_name)
        for decomposed_logical_name, tensor_name in output_signature.items()
    }
    # Do the same for input tensors, where we assume such tensors are not in the
    # input_map since identical tensors in an input_map would be an error.
    decomposed_unbound_input_tensors = {
        decomposed_logical_name:
        graph.get_tensor_by_name(ops.prepend_name_scope(tensor_name, scope))
        for decomposed_logical_name, tensor_name in input_signature.items()
        if decomposed_logical_name not in decomposed_input_tensors
    }

    outputs = _recompose_sparse_tensors(decomposed_output_tensors)
    unbound_inputs = _recompose_sparse_tensors(
        decomposed_unbound_input_tensors)
    return unbound_inputs, outputs
Esempio n. 33
0
 def lookup_remapped_tensor(tensor_name):
     """Resolve `tensor_name` to a Tensor after importing the transform graph.

     Names present in `input_map` resolve to the mapped tensor.  Any other name
     is looked up in `graph` under `scope`.  `input_map`, `graph` and `scope`
     are free variables taken from the enclosing function.
     """
     if tensor_name not in input_map:
         scoped_name = ops.prepend_name_scope(tensor_name, scope)
         return graph.get_tensor_by_name(scoped_name)
     return input_map[tensor_name]
Esempio n. 34
0
def _partially_apply_saved_transform_impl(
    saved_model_dir, logical_input_map, tensor_replacement_map=None,
    fetch_tensor_names=None):
  """Shared code for partially_apply_saved_transform and fetch_tensor_values.

  This adds nodes to a graph that already contains Tensors representing the
  inputs.  These input Tensors may be placeholders that will be fed when the
  graph is executed, or may be the outputs of some Ops.  Most typically, the
  input Tensors are reading and/or parsing Ops, but they could be anything--
  including the outputs of a prior application of this function using another
  transform graph.

  This function operates on the default Graph in the default Session, and so
  must be called within a context where these are provided.

  Args:
    saved_model_dir: A SavedModel directory providing a transform
      graph.  The MetaGraphDef and signature are selected from the SavedModel
      using keys defined in `../constants.py` ('transform' and
      'transform_signature', respectively).
    logical_input_map: a dict of logical name to Tensor.  The logical names must
      be a subset of those in the input signature of the transform graph, and
      the corresponding Tensors must have the expected types and shapes.
    tensor_replacement_map: a dict of tensor names to `Tensors`.
    fetch_tensor_names: a list of tensor names.

  Returns:
    A tuple of (unbound_inputs, outputs, fetched_tensors) where unbound_inputs
    is a dict of logical name to Tensors that are yet to be mapped or fed,
    outputs is a dict of logical name to Tensor, as provided by the output
    signature of the transform graph, and fetched_tensors is a dict of tensor
    names to `Tensor`s where the tensor names are the names given by
    `fetched_tensor_names`.

  Raises:
    ValueError: if the provided logical_input_map dict has keys that are not
      part of the input signature, if any of the provided inputs have the wrong
      type or shape, or if the SavedModel contains trainable variables.
    RuntimeError: if there is no default graph available to which to apply the
      transform.
  """
  graph = tf.get_default_graph()
  if graph is None:
    raise RuntimeError('apply_saved_transform() requires a default graph.')

  decomposed_input_tensors = _decompose_sparse_tensors(logical_input_map)

  meta_graph_def, input_signature, output_signature, asset_path_dict = (
      _load_transform_saved_model(saved_model_dir))
  asset_tensor_dict = {k: ops.convert_to_tensor(v)
                       for k, v in asset_path_dict.items()}

  # Check for inputs that were not part of the input signature.
  unexpected_inputs = (set(six.iterkeys(decomposed_input_tensors)) -
                       set(six.iterkeys(input_signature)))
  if unexpected_inputs:
    raise ValueError('Unexpected inputs '
                     'to transform: {}'.format(unexpected_inputs))

  # Create a map from tensor names in the graph to be imported, to the tensors
  # specified in `input_tensors`.
  input_map = {
      input_signature[decomposed_logical_name]:
      decomposed_input_tensors[decomposed_logical_name]
      for decomposed_logical_name in decomposed_input_tensors}
  input_map.update(asset_tensor_dict)
  if tensor_replacement_map:
    input_map.update(tensor_replacement_map)

  # unique_name may produce e.g. transform_5.  The result has no trailing slash.
  scope = graph.unique_name('transform', mark_as_used=False)

  # unique_name returns an "absolute" name while we want a name relative to the
  # current scope.  Therefore, we check if the current name stack is non-empty,
  # and if so, strip out the existing name scope.
  if graph.get_name_scope():
    current_name_scope = graph.get_name_scope() + '/'
    assert scope.startswith(current_name_scope)
    import_scope = scope[len(current_name_scope):]
  else:
    import_scope = scope

  # Save the ASSET_FILEPATHS before importing the MetaGraphDef
  current_assets = graph.get_collection(tf.GraphKeys.ASSET_FILEPATHS)

  # Reject SavedModels that carry trainable variables.
  if tf.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def:
    trainable_vars = meta_graph_def.collection_def[
        tf.GraphKeys.TRAINABLE_VARIABLES].bytes_list.value
    if trainable_vars:
      raise ValueError(
          'The SavedModel contained trainable variables {}.  Because this '
          'function is typically called in the input_fn, trainable variables '
          'are disallowed'.format(trainable_vars))

  # Load the transform graph, applying it to existing Tensors via input_map.
  # Throws ValueError if the input_map gives mismatched types or shapes.
  saver = tf_saver.import_meta_graph(meta_graph_def,
                                     import_scope=import_scope,
                                     input_map=input_map)

  # Wipe out AssetFileDef collection; it is obsolete after loading
  graph.clear_collection(tf.saved_model.constants.ASSETS_KEY)

  # The import may have added Tensors to the ASSET_FILEPATHS collection that
  # were substituted via input_map.  To account for this, wipe out the
  # collection, restore the preexisting collection values, and then write in
  # the new substituted Tensors.
  graph.clear_collection(tf.GraphKeys.ASSET_FILEPATHS)
  for asset_path_tensor in current_assets:
    graph.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, asset_path_tensor)
  for asset_path_tensor in asset_tensor_dict.values():
    graph.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, asset_path_tensor)

  if saver:
    checkpoint_path = os.path.join(
        tf.compat.as_bytes(saved_model_dir),
        tf.compat.as_bytes(tf.saved_model.constants.VARIABLES_DIRECTORY),
        tf.compat.as_bytes(tf.saved_model.constants.VARIABLES_FILENAME))

    # We can't use the scope rename from init_from_checkpoint because it relies
    # on var scopes not rebuilt by import_meta_graph. So we need to construct it
    # explicitly by iterating over the variables.
    # Match on "<scope>/" rather than a bare prefix so that unrelated variables
    # such as "transformer/w" are not mistaken for variables of "transform".
    scope_prefix = scope + '/'
    var_map = {}
    for var in tf.global_variables():
      if var.op.name.startswith(scope_prefix):
        var_map[var.op.name[len(scope_prefix):]] = var

    if var_map:
      tf.train.init_from_checkpoint(checkpoint_path, var_map)

  # Add computed output tensors to the output.  There are two cases.  When the
  # output is not in the input_map, then we look up the tensor in the imported
  # graph by prepending the import scope and looking up the tensor by name.
  # This will fail if the expected output tensor is not now in the graph
  # under the expected name scope.  When the output is in the input map, then
  # that tensor will have been re-mapped so we use the tensor given in the
  # input_map.
  def lookup_remapped_tensor(tensor_name):
    if tensor_name in input_map:
      return input_map[tensor_name]
    else:
      return graph.get_tensor_by_name(
          ops.prepend_name_scope(tensor_name, scope))
  decomposed_output_tensors = {
      decomposed_logical_name: lookup_remapped_tensor(tensor_name)
      for decomposed_logical_name, tensor_name
      in six.iteritems(output_signature)
  }
  # Do the same for input tensors, where we assume such tensors are not in the
  # input_map since identical tensors in an input_map would be an error.
  decomposed_unbound_input_tensors = {
      decomposed_logical_name: graph.get_tensor_by_name(
          ops.prepend_name_scope(tensor_name, scope))
      for decomposed_logical_name, tensor_name in six.iteritems(input_signature)
      if decomposed_logical_name not in decomposed_input_tensors
  }
  if fetch_tensor_names is None:
    fetch_tensor_names = []
  fetched_tensors = {
      name: lookup_remapped_tensor(name) for name in fetch_tensor_names}

  outputs = _recompose_sparse_tensors(decomposed_output_tensors)
  unbound_inputs = _recompose_sparse_tensors(decomposed_unbound_input_tensors)
  return unbound_inputs, outputs, fetched_tensors
Esempio n. 35
0
def _partially_apply_saved_transform_impl(saved_model_dir,
                                          logical_input_map,
                                          tensor_replacement_map=None,
                                          fetch_tensor_names=None):
    """Shared code for partially_apply_saved_transform and fetch_tensor_values.

    This adds nodes to a graph that already contains Tensors representing the
    inputs.  These input Tensors may be placeholders that will be fed when the
    graph is executed, or may be the outputs of some Ops.  Most typically, the
    input Tensors are reading and/or parsing Ops, but they could be anything--
    including the outputs of a prior application of this function using another
    transform graph.

    This function operates on the default Graph in the default Session, and so
    must be called within a context where these are provided.

    Args:
      saved_model_dir: A SavedModel directory providing a transform
        graph.  The MetaGraphDef and signature are selected from the SavedModel
        using keys defined in `../constants.py` ('transform' and
        'transform_signature', respectively).
      logical_input_map: a dict of logical name to Tensor.  The logical names
        must be a subset of those in the input signature of the transform
        graph, and the corresponding Tensors must have the expected types and
        shapes.
      tensor_replacement_map: Optional dict of tensor names (as they appear in
        the transform graph, i.e. without the import scope) to `Tensor`s.
        These entries are merged into the import `input_map` last, so they
        take precedence over both the logical inputs and the asset tensors
        that map to the same name.
      fetch_tensor_names: Optional list of tensor names (as they appear in the
        transform graph, without the import scope) to look up after import.
        `None` is treated as an empty list.

    Returns:
      A tuple of (unbound_inputs, outputs, fetched_tensors) where
      unbound_inputs is a dict of logical name to Tensors that are yet to be
      mapped or fed, outputs is a dict of logical name to Tensor, as provided
      by the output signature of the transform graph, and fetched_tensors is a
      dict of tensor names to `Tensor`s where the tensor names are the names
      given by `fetch_tensor_names`.

    Raises:
      ValueError: if `logical_input_map` has keys (after
        `_decompose_sparse_tensors`) that are not part of the input signature,
        or any of the provided inputs have the wrong type or shape (the latter
        is raised by `import_meta_graph`).
      RuntimeError: if there is no default graph available to which to apply
        the transform.
    """
    graph = tf.get_default_graph()
    if graph is None:
        raise RuntimeError('apply_saved_transform() requires a default graph.')

    decomposed_input_tensors = _decompose_sparse_tensors(logical_input_map)

    meta_graph_def, input_signature, output_signature, asset_path_dict = (
        _load_transform_saved_model(saved_model_dir))
    # Asset paths become constant tensors in the current graph; they are
    # substituted for the imported graph's asset tensors via input_map below.
    asset_tensor_dict = {
        k: ops.convert_to_tensor(v)
        for k, v in asset_path_dict.items()
    }

    # Check for inputs that were not part of the input signature.
    unexpected_inputs = (set(six.iterkeys(decomposed_input_tensors)) -
                         set(six.iterkeys(input_signature)))
    if unexpected_inputs:
        raise ValueError('Unexpected inputs '
                         'to transform: {}'.format(unexpected_inputs))

    # Create a map from tensor names in the graph to be imported, to the tensors
    # specified in `logical_input_map`.  Later updates win on key collisions,
    # so `tensor_replacement_map` has the final say.
    input_map = {
        input_signature[decomposed_logical_name]:
        decomposed_input_tensors[decomposed_logical_name]
        for decomposed_logical_name in decomposed_input_tensors
    }
    input_map.update(asset_tensor_dict)
    if tensor_replacement_map:
        input_map.update(tensor_replacement_map)

    # unique_name may produce e.g. transform_5.  The result has no trailing slash.
    # mark_as_used=False: import_meta_graph itself claims the scope name.
    scope = graph.unique_name('transform', mark_as_used=False)

    # Save the ASSET_FILEPATHS before importing the MetaGraphDef
    current_assets = graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS)

    # Load the transform graph, applying it to existing Tensors via input_map.
    # Throws ValueError if the input_map gives mismatched types or shapes.
    saver = tf_saver.import_meta_graph(meta_graph_def,
                                       import_scope=scope,
                                       input_map=input_map)

    # Strip a 'T' attr from Where ops in the graph.
    # NOTE(review): `op.type` is a `str` on Python 3, so this bytes comparison
    # likely only matches on Python 2 -- confirm whether this workaround is
    # still needed before changing it.
    for op in graph.get_operations():
        # pylint: disable=protected-access
        if op.type == b'Where' and 'T' in op._node_def.attr:
            del op._node_def.attr['T']

    # Wipe out AssetFileDef collection; it is obsolete after loading
    graph.clear_collection(tf.saved_model.constants.ASSETS_KEY)

    # The import may have added Tensors to the ASSET_FILEPATHS collection that
    # were substituted via input_map.  To account for this, wipe out the
    # collection, restore the preexisting collection values, and then write in
    # the new substituted Tensors.
    graph.clear_collection(ops.GraphKeys.ASSET_FILEPATHS)
    for asset_path_tensor in current_assets:
        graph.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS,
                                asset_path_tensor)
    for asset_path_tensor in asset_tensor_dict.values():
        graph.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS,
                                asset_path_tensor)

    # import_meta_graph returns a Saver only when the imported graph has
    # variables; their values are not restored here.
    if saver:
        tf.logging.warn(
            'Transform graphs should not have saved Variables, but this '
            'one does.  Variable values will *not* be restored.')

    # Add computed output tensors to the output.  There are two cases.  When the
    # output is not in the input_map, then we look up the tensor in the imported
    # graph by prepending the import scope and looking up the tensor by name.
    # This will fail if the expected output tensor is not now in the graph
    # under the expected name scope.  When the output is in the input map, then
    # that tensor will have been re-mapped so we use the tensor given in the
    # input_map.
    def lookup_remapped_tensor(tensor_name):
        if tensor_name in input_map:
            return input_map[tensor_name]
        else:
            return graph.get_tensor_by_name(
                ops.prepend_name_scope(tensor_name, scope))

    decomposed_output_tensors = {
        decomposed_logical_name: lookup_remapped_tensor(tensor_name)
        for decomposed_logical_name, tensor_name in six.iteritems(
            output_signature)
    }
    # Do the same for input tensors, where we assume such tensors are not in the
    # input_map since identical tensors in an input_map would be an error.
    decomposed_unbound_input_tensors = {
        decomposed_logical_name:
        graph.get_tensor_by_name(ops.prepend_name_scope(tensor_name, scope))
        for decomposed_logical_name, tensor_name in six.iteritems(
            input_signature)
        if decomposed_logical_name not in decomposed_input_tensors
    }
    if fetch_tensor_names is None:
        fetch_tensor_names = []
    fetched_tensors = {
        name: lookup_remapped_tensor(name)
        for name in fetch_tensor_names
    }

    outputs = _recompose_sparse_tensors(decomposed_output_tensors)
    unbound_inputs = _recompose_sparse_tensors(
        decomposed_unbound_input_tensors)
    return unbound_inputs, outputs, fetched_tensors
Esempio n. 36
0
def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             saver_def=None,
                             clear_extraneous_savers=False,
                             strip_default_attrs=False,
                             save_debug_info=False,
                             **kwargs):
  """Returns `MetaGraphDef` proto. Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into
  `MetaGraphDef` protocol buffer with the intention of it being imported
  at a later time or location to restart training, run inference, or be
  a subgraph.

  Args:
    filename: Optional filename including the path for writing the
      generated `MetaGraphDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract
      the subgraph. The scope name will be stripped from the node definitions
      for easy import later into new name scopes. If `None`, the whole graph
      is exported.
    as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
    unbound_inputs_col_name: Optional `string`. If provided, a string collection
      with the given name will be added to the returned `MetaGraphDef`,
      containing the names of tensors that must be remapped when importing the
      `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      before exporting the graph.
    saver_def: `SaverDef` protocol buffer.
    clear_extraneous_savers: Remove any Saver-related information from the
      graph (both Save/Restore ops and SaverDefs) that are not associated
      with the provided SaverDef.
    strip_default_attrs: Set to true if default valued attributes must be
      removed while exporting the GraphDef.
    save_debug_info: If `True` and `filename` is given, also save the
      GraphDebugInfo to a separate file in the same directory as `filename`,
      named like `filename` with its extension replaced by `.debug`
      (e.g. `model.pbtxt` -> `model.debug`). Ignored when `filename` is not
      given.
    **kwargs: Optional keyed arguments, including meta_info_def and
      collection_list.

  Returns:
    A tuple of the `MetaGraphDef` proto and a dictionary of `Variables` in the
    exported name scope, keyed by variable name with `export_scope` stripped.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB, or when called while
      eager execution is enabled.
  """
  if context.executing_eagerly():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "Eager Execution is enabled.")
  graph = graph or ops.get_default_graph()

  exclude_nodes = None
  # Filled in by _node_def with inputs that reference nodes outside the scope.
  unbound_inputs = []
  if export_scope or clear_extraneous_savers or clear_devices:
    if graph_def:
      # Filter/rewrite the caller-provided GraphDef node by node.
      new_graph_def = graph_pb2.GraphDef()
      new_graph_def.versions.CopyFrom(graph_def.versions)
      new_graph_def.library.CopyFrom(graph_def.library)

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph_def, saver_def)

      for node_def in graph_def.node:
        if _should_include_node(node_def.name, export_scope, exclude_nodes):
          new_node_def = _node_def(node_def, export_scope, unbound_inputs,
                                   clear_devices=clear_devices)
          new_graph_def.node.extend([new_node_def])
      graph_def = new_graph_def
    else:
      # No GraphDef given: build one from the live graph's nodes.
      graph_def = graph_pb2.GraphDef()
      # pylint: disable=protected-access
      graph_def.versions.CopyFrom(graph.graph_def_versions)
      bytesize = 0

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph.as_graph_def(),
                                                     saver_def)

      for key in sorted(graph._nodes_by_id):
        if _should_include_node(graph._nodes_by_id[key].name,
                                export_scope,
                                exclude_nodes):
          value = graph._nodes_by_id[key]
          # pylint: enable=protected-access
          node_def = _node_def(value.node_def, export_scope, unbound_inputs,
                               clear_devices=clear_devices)
          graph_def.node.extend([node_def])
          # Record the static output shapes of each exported node.
          if value.outputs:
            assert "_output_shapes" not in graph_def.node[-1].attr
            graph_def.node[-1].attr["_output_shapes"].list.shape.extend([
                output.get_shape().as_proto() for output in value.outputs])
          bytesize += value.node_def.ByteSize()
          if bytesize >= (1 << 31) or bytesize < 0:
            raise ValueError("GraphDef cannot be larger than 2GB.")

      graph._copy_functions_to_graph_def(graph_def, bytesize)  # pylint: disable=protected-access

    # It's possible that not all the inputs are in the export_scope.
    # If we would like such information included in the exported meta_graph,
    # add them to a special unbound_inputs collection.
    if unbound_inputs_col_name:
      # Clears the unbound_inputs collections.
      graph.clear_collection(unbound_inputs_col_name)
      for k in unbound_inputs:
        graph.add_to_collection(unbound_inputs_col_name, k)

  var_list = {}
  variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                   scope=export_scope)
  for v in variables:
    if _should_include_node(v, export_scope, exclude_nodes):
      var_list[ops.strip_name_scope(v.name, export_scope)] = v

  scoped_meta_graph_def = create_meta_graph_def(
      graph_def=graph_def,
      graph=graph,
      export_scope=export_scope,
      exclude_nodes=exclude_nodes,
      clear_extraneous_savers=clear_extraneous_savers,
      saver_def=saver_def,
      strip_default_attrs=strip_default_attrs,
      **kwargs)

  if filename:
    graph_io.write_graph(
        scoped_meta_graph_def,
        os.path.dirname(filename),
        os.path.basename(filename),
        as_text=as_text)
    if save_debug_info:
      name, _ = os.path.splitext(filename)
      debug_filename = "{name}{ext}".format(name=name, ext=".debug")

      # Map every exported node back to its Operation in the source graph.
      # Exported node names have export_scope stripped, so re-add it for the
      # lookup. A list keeps node order, so the debug output is deterministic.
      ops_to_export = [
          graph.get_operation_by_name(
              ops.prepend_name_scope(node.name, export_scope))
          for node in scoped_meta_graph_def.graph_def.node
      ]

      graph_debug_info = create_graph_debug_info_def(ops_to_export)

      graph_io.write_graph(
          graph_debug_info,
          os.path.dirname(debug_filename),
          os.path.basename(debug_filename),
          as_text=as_text)

  return scoped_meta_graph_def, var_list