Esempio n. 1
0
def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type, name="restore_slice", preferred_shard=-1):
    """Build an op that restores one slice of a tensor from checkpoint files.

    The requested `tensor_type` is converted with `dtypes.as_dtype` and its
    `base_dtype` is what gets handed to the generated restore op.

    Example usage:
      RestoreSlice("/foo/bar-?????-of-?????", "w", "10 10 0,2:-", DT_FLOAT)

    Args:
      file_pattern: the file pattern used to match a set of checkpoint files.
      tensor_name: the name of the tensor to restore.
      shape_and_slice: the shape-and-slice spec of the slice.
      tensor_type: the type of the tensor to restore.
      name: string.  Optional name for the op.
      preferred_shard: Int. Optional shard to open first in the checkpoint
        file. Defaults to -1.

    Returns:
      A tensor of type "tensor_type".
    """
    resolved_dtype = dtypes.as_dtype(tensor_type)
    return gen_io_ops._restore_slice(
        file_pattern,
        tensor_name,
        shape_and_slice,
        resolved_dtype.base_dtype,
        preferred_shard,
        name=name,
    )
Esempio n. 2
0
def _set_checkpoint_initializer(variable, file_pattern, tensor_name, slice_spec,
                                name="checkpoint_initializer"):
  """Points the variable's initializer at a value restored from a checkpoint.

  A restore-slice op is created for `tensor_name` using the variable's base
  dtype, and `variable._initializer_op` is overwritten with an op assigning
  the restored value to the variable. Nothing is returned.

  Args:
    variable: `Variable` object.
    file_pattern: string, where to load checkpoints from.
    tensor_name: Name of the `Tensor` to load from checkpoint reader.
    slice_spec: Slice specification for loading partitioned variables.
    name: Name of the operation.
  """
  restored_value = gen_io_ops._restore_slice(
      file_pattern, tensor_name, slice_spec, variable.dtype.base_dtype,
      preferred_shard=-1, name=name)
  # Replace whatever initializer the variable had with the checkpoint assign.
  variable._initializer_op = state_ops.assign(variable, restored_value)
Esempio n. 3
0
def _set_checkpoint_initializer(variable, file_pattern, tensor_name, slice_spec,
                                name="checkpoint_initializer"):
    """Overrides a variable's initializer with an assign from a checkpoint.

    Creates a restore-slice op that reads `tensor_name` (restricted to
    `slice_spec`) using the variable's base dtype, then stores an assign op
    for that value in `variable._initializer_op`. Nothing is returned.

    Args:
      variable: `Variable` object.
      file_pattern: string, where to load checkpoints from.
      tensor_name: Name of the `Tensor` to load from checkpoint reader.
      slice_spec: Slice specification for loading partitioned variables.
      name: Name of the operation.
    """
    checkpoint_dtype = variable.dtype.base_dtype
    checkpoint_value = gen_io_ops._restore_slice(
        file_pattern,
        tensor_name,
        slice_spec,
        checkpoint_dtype,
        preferred_shard=-1,
        name=name,
    )
    assign_from_checkpoint = state_ops.assign(variable, checkpoint_value)
    variable._initializer_op = assign_from_checkpoint
Esempio n. 4
0
def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type,
                   name="restore_slice", preferred_shard=-1):
  """Restores a slice of a tensor from the files matching a pattern.

  `tensor_type` is passed through `dtypes.as_dtype` and reduced to its
  `base_dtype` before the generated restore op is invoked.

  Example usage:
    RestoreSlice("/foo/bar-?????-of-?????", "w", "10 10 0,2:-", DT_FLOAT)

  Args:
    file_pattern: the file pattern used to match a set of checkpoint files.
    tensor_name: the name of the tensor to restore.
    shape_and_slice: the shape-and-slice spec of the slice.
    tensor_type: the type of the tensor to restore.
    name: string.  Optional name for the op.
    preferred_shard: Int. Optional shard to open first in the checkpoint file.
      Defaults to -1.

  Returns:
    A tensor of type "tensor_type".
  """
  restore_dtype = dtypes.as_dtype(tensor_type).base_dtype
  restored = gen_io_ops._restore_slice(
      file_pattern,
      tensor_name,
      shape_and_slice,
      restore_dtype,
      preferred_shard,
      name=name)
  return restored
Esempio n. 5
0
 def testRestoreSlice(self):
     """Checks the shape reported for a restore-slice op built from a spec.

     The op is created with the shape-and-slice spec "3 4 0,1:-" and the
     result of its `get_shape()` is expected to equal [1, 4].
     """
     expected_shape = [1, 4]
     restored = gen_io_ops._restore_slice(
         "model", "var", "3 4 0,1:-", dtypes.float32)
     self.assertEqual(expected_shape, restored.get_shape())