Ejemplo n.º 1
0
    def value_tensors(self):
        """Create value `Tensor`s for this object's attributes.

        Does not require that the Python object has been created. Used for
        restore-on-create when executing eagerly.

        Returns:
          A dictionary mapping from object attribute names to `Tensor`s, with
          one entry per serialized attribute (empty if there are none).
        """
        value_tensors = {}
        for serialized_tensor in self.object_proto.attributes:
            checkpoint_key = serialized_tensor.checkpoint_key
            dtype = self._checkpoint.dtype_map[checkpoint_key]
            base_type = dtype.base_dtype
            with ops.init_scope():
                with ops.device("/cpu:0"):
                    # Run the restore itself on the CPU.
                    value, = io_ops.restore_v2(
                        prefix=self._checkpoint.save_path_tensor,
                        tensor_names=[checkpoint_key],
                        shape_and_slices=[""],
                        dtypes=[base_type],
                        name="%s_checkpoint_read" % (serialized_tensor.name, ))
                # Copy the value to the current device if necessary.
                value_tensors[serialized_tensor.name] = array_ops.identity(
                    value)
        # Return only after every attribute has been processed. Returning from
        # inside the loop would yield just the first attribute, or None when
        # there are no attributes at all.
        return value_tensors
Ejemplo n.º 2
0
  def value_tensors(self):
    """Create value `Tensor`s for this object's attributes.

    Does not require that the Python object has been created. Used for
    restore-on-create when executing eagerly.

    Returns:
      A dictionary mapping from object attribute names to `Tensor`s, with one
      entry per serialized attribute (empty if there are none).
    """
    value_tensors = {}
    for serialized_tensor in self.object_proto.attributes:
      checkpoint_key = serialized_tensor.checkpoint_key
      dtype = self._checkpoint.dtype_map[checkpoint_key]
      base_type = dtype.base_dtype
      with ops.init_scope():
        with ops.device("/cpu:0"):
          # Run the restore itself on the CPU.
          value, = io_ops.restore_v2(
              prefix=self._checkpoint.save_path,
              tensor_names=[checkpoint_key],
              shape_and_slices=[""],
              dtypes=[base_type],
              name="%s_checkpoint_read" % (serialized_tensor.name,))
        # Copy the value to the current device if necessary.
        value_tensors[serialized_tensor.name] = array_ops.identity(value)
    # Return only after every attribute has been processed. Returning from
    # inside the loop would yield just the first attribute, or None when there
    # are no attributes at all.
    return value_tensors
Ejemplo n.º 3
0
 def restore_ops(self):
     """Create restore ops for this object's attributes.

     Issues one `io_ops.restore_v2` call per serialized attribute, inside
     `ops.init_scope()`, reading from `self._checkpoint.save_path`.

     Returns:
       A dictionary mapping from object attribute names to the single output
       of the corresponding `restore_v2` call (empty if the object has no
       serialized attributes).
     """
     restore_tensors = {}
     for serialized_tensor in self.object_proto.attributes:
         checkpoint_key = serialized_tensor.checkpoint_key
         dtype = self._checkpoint.dtype_map[checkpoint_key]
         base_type = dtype.base_dtype
         with ops.init_scope():
             restore, = io_ops.restore_v2(prefix=self._checkpoint.save_path,
                                          tensor_names=[checkpoint_key],
                                          shape_and_slices=[""],
                                          dtypes=[base_type],
                                          name="%s_checkpoint_read" %
                                          (serialized_tensor.name, ))
             restore_tensors[serialized_tensor.name] = restore
     # Return only after every attribute has been processed. Returning from
     # inside the loop would yield just the first attribute, or None when
     # there are no attributes at all.
     return restore_tensors
Ejemplo n.º 4
0
    def value_tensors(self, shape_and_slices=None):
        """Build the restored value `Tensor` for each serialized attribute.

        The Python object itself need not exist yet; this is used for
        restore-on-create when executing eagerly.

        Args:
          shape_and_slices: Optional dict keyed by object attribute name whose
            values are shape-and-slice strings forwarded to the RestoreV2 op.
            When the dict is None, or lacks an entry for an attribute, the
            full tensor is restored for that attribute.

        Returns:
          A dictionary mapping from object attribute names to `Tensor`s.
        """
        restored_by_name = {}
        checkpoint = self._checkpoint
        for attribute in self.object_proto.attributes:
            key = attribute.checkpoint_key
            restore_dtype = checkpoint.dtype_map[key].base_dtype
            # Fall back to the CPU when no I/O device was configured.
            device = checkpoint.options.experimental_io_device or "cpu:0"
            with ops.init_scope():
                with ops.device(device):
                    # The read itself happens on the I/O device.
                    slice_spec = (
                        shape_and_slices[attribute.name]
                        if (shape_and_slices is not None
                            and attribute.name in shape_and_slices)
                        else "")
                    (restored,) = io_ops.restore_v2(
                        prefix=checkpoint.save_path_tensor,
                        tensor_names=[key],
                        shape_and_slices=[slice_spec],
                        dtypes=[restore_dtype],
                        name="%s_checkpoint_read" % (attribute.name,))
                # Outside the I/O device scope: copy the value to the current
                # device if necessary.
                restored_by_name[attribute.name] = array_ops.identity(restored)
        return restored_by_name
Ejemplo n.º 5
0
  def value_tensors(self):
    """Create value `Tensor`s for this object's attributes.

    Does not require that the Python object has been created. Used for
    restore-on-create when executing eagerly.

    Returns:
      A dictionary mapping from object attribute names to `Tensor`s, with one
      entry per serialized attribute (empty if there are none).
    """
    value_tensors = {}
    for serialized_tensor in self.object_proto.attributes:
      checkpoint_key = serialized_tensor.checkpoint_key
      dtype = self._checkpoint.dtype_map[checkpoint_key]
      base_type = dtype.base_dtype
      with ops.init_scope():
        value, = io_ops.restore_v2(
            prefix=self._checkpoint.save_path,
            tensor_names=[checkpoint_key],
            shape_and_slices=[""],
            dtypes=[base_type],
            name="%s_checkpoint_read" % (serialized_tensor.name,))
        value_tensors[serialized_tensor.name] = value
    # Return only after every attribute has been processed. Returning from
    # inside the loop would yield just the first attribute, or None when there
    # are no attributes at all.
    return value_tensors
Ejemplo n.º 6
0
 def _call_restore_v2():
     """Call `gen_io_ops.restore_v2` for every name in `all_names`.

     Passes one empty shape-and-slice string per name; the call's result is
     discarded. `checkpoint_path`, `all_names` and `all_dtypes` come from the
     enclosing scope.
     """
     empty_slices = [""] * len(all_names)
     gen_io_ops.restore_v2(
         checkpoint_path, all_names, empty_slices, all_dtypes)