def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type,
                   name="restore_slice", preferred_shard=-1):
    """Restore one slice of a tensor from checkpoint files matching a pattern.

    Thin wrapper around the generated restore-slice op. It converts
    `tensor_type` with `dtypes.as_dtype` and forwards only its base dtype.

    Example:
      _restore_slice("/foo/bar-?????-of-?????", "w", "10 10 0,2:-",
                     dtypes.float32)

    Args:
      file_pattern: File pattern matching the set of checkpoint files to read.
      tensor_name: Name of the tensor to restore.
      shape_and_slice: Shape-and-slice specification string of the slice.
      tensor_type: Type of the tensor to restore; anything accepted by
        `dtypes.as_dtype`.
      name: String. Optional name for the op.
      preferred_shard: Int. Optional shard to open first in the checkpoint
        file. Defaults to -1.

    Returns:
      A tensor of type `tensor_type` (its base dtype).
    """
    restored_dtype = dtypes.as_dtype(tensor_type).base_dtype
    return gen_io_ops._restore_slice(
        file_pattern,
        tensor_name,
        shape_and_slice,
        restored_dtype,
        preferred_shard=preferred_shard,
        name=name,
    )
def _set_checkpoint_initializer(variable, file_pattern, tensor_name, slice_spec,
                                name="checkpoint_initializer"):
    """Sets a variable's initializer to assign it a value from a checkpoint.

    Builds a restore-slice op that reads `tensor_name` (restricted to
    `slice_spec`) from the checkpoint files matching `file_pattern`. It then
    replaces `variable._initializer_op` with an op that assigns the restored
    value to `variable`.

    Args:
      variable: `Variable` object whose initializer is replaced.
      file_pattern: String, the pattern of checkpoint files to load from.
      tensor_name: Name of the `Tensor` to load from the checkpoint.
      slice_spec: Slice specification for loading partitioned variables.
      name: Name given to the restore operation (not to the assign op).

    Returns:
      None. The variable's `_initializer_op` is mutated in place.
    """
    # Request the restored value using the variable's base dtype.
    base_type = variable.dtype.base_dtype
    restore_op = gen_io_ops._restore_slice(
        file_pattern, tensor_name, slice_spec, base_type,
        preferred_shard=-1, name=name)
    # Overwrite the private initializer slot so that running the variable's
    # initializer loads the checkpointed value.
    variable._initializer_op = state_ops.assign(variable, restore_op)
def _set_checkpoint_initializer(variable, file_pattern, tensor_name, slice_spec,
                                name="checkpoint_initializer"):
    """Replace a variable's initializer with an assignment from a checkpoint.

    The new initializer assigns to `variable` the value of `tensor_name`
    (restricted to `slice_spec`) read from the checkpoint files that match
    `file_pattern`.

    Args:
      variable: `Variable` object to re-initialize from the checkpoint.
      file_pattern: String, where to load checkpoints from.
      tensor_name: Name of the `Tensor` to load from the checkpoint reader.
      slice_spec: Slice specification for loading partitioned variables.
      name: Name used for the restore operation.
    """
    wanted_dtype = variable.dtype.base_dtype
    restored_value = gen_io_ops._restore_slice(
        file_pattern,
        tensor_name,
        slice_spec,
        wanted_dtype,
        preferred_shard=-1,
        name=name,
    )
    assign_from_checkpoint = state_ops.assign(variable, restored_value)
    variable._initializer_op = assign_from_checkpoint
def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type,
                   name="restore_slice", preferred_shard=-1):
    """Build an op that restores a tensor slice from sharded checkpoint files.

    Example:
      _restore_slice("/foo/bar-?????-of-?????", "w", "10 10 0,2:-",
                     dtypes.float32)

    Args:
      file_pattern: Pattern matching the set of checkpoint files.
      tensor_name: Name of the tensor to restore.
      shape_and_slice: Shape-and-slice spec describing the slice to read.
      tensor_type: Type of the tensor to restore; normalized with
        `dtypes.as_dtype` and reduced to its base dtype.
      name: String. Optional name for the op.
      preferred_shard: Int. Optional shard to open first in the checkpoint
        file.

    Returns:
      A tensor of type `tensor_type`.
    """
    # The generated op receives the base (non-reference) form of the dtype.
    return gen_io_ops._restore_slice(
        file_pattern, tensor_name, shape_and_slice,
        dtypes.as_dtype(tensor_type).base_dtype,
        preferred_shard, name=name)
def testRestoreSlice(self):
    """Restore-slice op infers its static shape from the slice spec."""
    # Full shape 3x4; take one row starting at index 0 and every column.
    slice_spec = "3 4 0,1:-"
    restored = gen_io_ops._restore_slice("model", "var", slice_spec,
                                         dtypes.float32)
    expected_shape = [1, 4]
    self.assertEqual(expected_shape, restored.get_shape())