def _set_mht_saveable_checkpoint_initializer(saveable,
                                             ckpt_file,
                                             saveable_name_in_ckpt,
                                             name="mht_checkpoint_initializer"):
  """Registers an initializer restoring a (keys, values) saveable from a checkpoint.

  The keys are read from `<saveable_name_in_ckpt>-keys` and the values from
  `<saveable_name_in_ckpt>-values` in `ckpt_file`. They are fed to
  `saveable.restore`, and the resulting op is added to the
  `DYNAMIC_EMBEDDING_VARIABLE_INITIALIZERS` collection.

  Args:
    saveable: Saveable object whose `specs` are exactly `(keys, values)`.
    ckpt_file: Path of the checkpoint to read from.
    saveable_name_in_ckpt: Name prefix of the saveable inside the checkpoint.
    name: Name for the RestoreV2 ops.

  Returns:
    The restore op that was added to the collection.

  Raises:
    ValueError: If the saveable does not have exactly two specs, or if its
      tensors are not all on the same device.
  """
  specs = list(saveable.specs)
  # Validate up front: without this, an empty spec list surfaces as the
  # misleading device error below, and any other length fails inside tuple
  # unpacking with an unhelpful message.
  if len(specs) != 2:
    raise ValueError("Expected exactly two specs (keys, values) for saveable "
                     "%s, got %d." % (saveable.name, len(specs)))
  canonical_device = {
      pydev.canonical_name(spec.tensor.device) for spec in specs
  }
  if len(canonical_device) != 1:
    raise ValueError("All tensors of a saveable object must be "
                     "on the same device: %s" % saveable.name)
  device = canonical_device.pop()
  keys_spec, values_spec = specs
  # Keep the saveable's job/task but run the reads on its CPU.
  with ops.device(device), ops.device("/cpu:0"):
    restore_keys_op = io_ops.restore_v2(ckpt_file,
                                        [saveable_name_in_ckpt + "-keys"],
                                        [keys_spec.slice_spec],
                                        [keys_spec.dtype],
                                        name=name)[0]
    restore_values_op = io_ops.restore_v2(ckpt_file,
                                          [saveable_name_in_ckpt + "-values"],
                                          [values_spec.slice_spec],
                                          [values_spec.dtype],
                                          name=name)[0]
    init_op = saveable.restore([restore_keys_op, restore_values_op],
                               restored_shapes=None)
  ops.add_to_collection(ops.GraphKeys.DYNAMIC_EMBEDDING_VARIABLE_INITIALIZERS,
                        init_op)
  return init_op
def _process_slot_restoration(self, slot_restoration, variable):
  """Restores one slot variable's value, creating the slot if it is missing.

  If the optimizer has no slot for `variable` yet, a new slot is created whose
  initializer reads the checkpointed value; the initializer is run right away
  when the value pointer carries a session. Otherwise the existing slot is
  assigned from the checkpoint via `_assign_existing_variable`.
  """
  # TODO(allenl): Move this to Optimizer
  assert isinstance(self, optimizer_lib.Optimizer)
  slots_by_key = self._slot_dict(slot_restoration.slot_name)
  key = optimizer_lib._var_key(variable)  # pylint: disable=protected-access
  pointer = slot_restoration.value_pointer
  current_slot = slots_by_key.get(key, None)
  if current_slot is not None:
    _assign_existing_variable(current_slot, value_pointer=pointer)
    return

  restored_value, = io_ops.restore_v2(
      prefix=pointer.save_path,
      tensor_names=[pointer.checkpoint_key],
      shape_and_slices=[""],
      dtypes=[pointer.dtype.base_dtype],
      name="checkpoint_initializer")
  created_slot = slot_creator.create_slot(
      variable, restored_value, slot_restoration.slot_name)
  if pointer.session is not None:
    pointer.session.run(created_slot.initializer)
  slots_by_key[key] = created_slot
def ReadNpArrays(file_prefix, nmap):
  """Reads from a tf checkpoint to fill in values of a NestedMap.

  All tensors are fetched with a single RestoreV2 op in a private graph.

  Args:
    file_prefix: A TF checkpoint filename prefix.
    nmap: A NestedMap of numpy dtypes.

  Returns:
    A NestedMap with numpy arrays compatible w/ nmap.
  """
  items = list(nmap.FlattenItems())
  if not items:
    # Nothing to read; the original per-tensor version also packed [] here.
    return nmap.Pack([])
  tensor_names = [name for name, _ in items]
  tensor_dtypes = [dtype for _, dtype in items]
  g = tf.Graph()
  with g.as_default():
    # One RestoreV2 op reading every tensor, instead of one op per tensor.
    reads = io_ops.restore_v2(
        prefix=file_prefix,
        tensor_names=tensor_names,
        shape_and_slices=[""] * len(tensor_names),
        dtypes=tensor_dtypes)
  with tf.Session(graph=g) as sess:
    vals = sess.run(list(reads))
  return nmap.Pack(vals)
def restore(self, file_prefix):
  """Restore the saveable objects from a checkpoint with `file_prefix`.

  All tensors are read with a single RestoreV2 op on CPU. Each saveable is
  then restored from its slice of the results.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix for
      files to read from.

  Returns:
    A scalar string Tensor containing `file_prefix` with control dependencies
    on the restore ops.
  """
  from tensorflow.python.ops import array_ops  # pylint: disable=g-import-not-at-top
  restore_specs = []
  tensor_structure = []
  for saveable in self._saveable_objects:
    saveable_tensor_structure = []
    tensor_structure.append(saveable_tensor_structure)
    for spec in saveable.specs:
      saveable_tensor_structure.append(spec.name)
      restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
  tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
  with ops.device("cpu:0"):
    restored_tensors = io_ops.restore_v2(file_prefix, tensor_names,
                                         tensor_slices, tensor_dtypes)
  structured_restored_tensors = nest.pack_sequence_as(
      tensor_structure, restored_tensors)
  restore_ops = []
  for saveable, saveable_tensors in zip(self._saveable_objects,
                                        structured_restored_tensors):
    restore_op = saveable.restore(saveable_tensors, restored_shapes=None)
    if restore_op is not None:
      restore_ops.append(restore_op)
  # Previously the restore ops were discarded and `file_prefix` was returned
  # unchanged, so in graph mode nothing forced the restores to run. Tie them
  # to the returned tensor, as the docstring promises.
  with ops.control_dependencies(restore_ops), ops.device("cpu:0"):
    return array_ops.identity(file_prefix)
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides given variable's initialization op.

  Sets variable initializer to assign op that initializes variable from tensor's
  value in the checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  base_type = variable.dtype.base_dtype
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restore_op = io_ops.restore_v2(ckpt_file, [tensor_name], [slice_spec],
                                   [base_type], name=name)[0]
    # `is_resource_variable` (the check used by the other versions of this
    # helper) also accepts resource-variable-like objects, not only exact
    # `ResourceVariable` instances.
    if resource_variable_ops.is_resource_variable(variable):
      init_op = variable.assign(restore_op, read_value=False)
    else:
      init_op = state_ops.assign(variable, restore_op)
    variable._initializer_op = init_op  # pylint:disable=protected-access
    restore_op.set_shape(variable.shape)
    variable._initial_value = restore_op  # pylint:disable=protected-access
def testRelativePath(self):
  """Saves and restores a checkpoint addressed by a path relative to cwd."""
  # Restore the original working directory afterwards so the chdir below does
  # not leak into other tests running in the same process.
  self.addCleanup(os.chdir, os.getcwd())
  os.chdir(self.get_temp_dir())
  self.evaluate(io_ops.save_v2(
      "ckpt", ["x"], [""], [constant_op.constant(100.)]))
  self.assertAllEqual([100.], self.evaluate(io_ops.restore_v2(
      "ckpt", ["x"], [""], [dtypes.float32])))
def restore(self, file_prefix):
  """Restores each `SaveableObject` from the checkpoint at `file_prefix`.

  Each spec is read with its own RestoreV2 op. The ops are placed on the CPU
  of the saveable's device when it has one.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix for
      files to read from.

  Returns:
    An operation which restores the `Saver`'s `SaveableObject`s when run, or
    None if executing eagerly.
  """
  restore_ops = []
  for saveable in self._saveable_objects:
    target_device = (saveable_object_util.set_cpu0(saveable.device)
                     if saveable.device else None)
    with ops.device(target_device):
      restored = [
          io_ops.restore_v2(file_prefix, [spec.name], [spec.slice_spec],
                            [spec.dtype])[0]
          for spec in saveable.specs
      ]
      restore_ops.append(saveable.restore(restored, restored_shapes=None))
  return control_flow_ops.group(restore_ops)
def restore_tfra_variable(ckpt_path, variable):
  """Builds restore ops for every hash table backing a TFRA variable.

  Keys and values of all tables are read with one RestoreV2 op. Each read uses
  the dtype declared by the table's own save specs.

  Args:
    ckpt_path: Checkpoint location, resolved with `_get_checkpoint_filename`.
    variable: Object whose `_tables` each expose a `saveable` with exactly two
      specs, (keys, values).

  Returns:
    A list with one restore op per table, in `variable._tables` order.

  Raises:
    ValueError: If a table's saveable does not have exactly two specs.
  """
  mhts = variable._tables  # pylint: disable=protected-access
  key_specs = []
  value_specs = []
  for mht in mhts:
    specs = mht.saveable.specs
    if len(specs) != 2:
      raise ValueError("Expected (keys, values) specs for table saveable, "
                       "got %d spec(s)." % len(specs))
    key_specs.append(specs[0])
    value_specs.append(specs[1])
  all_specs = key_specs + value_specs
  latest_ckpt = _get_checkpoint_filename(ckpt_path)
  # Use each spec's dtype instead of assuming int64 keys / float32 values, so
  # tables with other key/value dtypes restore correctly.
  restored = io_ops.restore_v2(latest_ckpt,
                               [spec.name for spec in all_specs],
                               [""] * len(all_specs),
                               [spec.dtype for spec in all_specs])
  key_tensors = restored[:len(key_specs)]
  value_tensors = restored[len(key_specs):]
  return [
      mht.saveable.restore((key_tensor, value_tensor), None)
      for mht, key_tensor, value_tensor in zip(mhts, key_tensors,
                                               value_tensors)
  ]
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides given variable's initialization op.

  Sets variable initializer to assign op that initializes variable from tensor's
  value in the checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  base_type = variable.dtype.base_dtype
  # Do not colocate with the variable: RestoreV2 only runs on CPU, and
  # colocation would force the variable (and everything colocated with it)
  # onto CPU too. Keep the variable's job/task but pin the op to its CPU; the
  # initializer runs only once, so CPU placement is fine.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restore_op = io_ops.restore_v2(ckpt_file, [tensor_name], [slice_spec],
                                   [base_type], name=name)[0]
    variable._initializer_op = state_ops.assign(variable, restore_op)  # pylint:disable=protected-access
    restore_op.set_shape(variable.shape)
    variable._initial_value = restore_op  # pylint:disable=protected-access
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides given variable's initialization op.

  Sets variable initializer to assign op that initializes variable from tensor's
  value in the checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  base_type = variable.dtype.base_dtype
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restore_op = io_ops.restore_v2(
        ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]
    # Use `is_resource_variable` rather than an exact `isinstance` check so
    # resource-variable-like objects also take the resource assign path.
    if resource_variable_ops.is_resource_variable(variable):
      init_op = variable.assign(restore_op, read_value=False)
    else:
      init_op = state_ops.assign(variable, restore_op)
    variable._initializer_op = init_op  # pylint:disable=protected-access
    restore_op.set_shape(variable.shape)
    variable._initial_value = restore_op  # pylint:disable=protected-access
def restore(self, file_prefix, options=None):
  """Reads all saveable tensors in one RestoreV2 op and restores each saveable.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix for
      files to read from.
    options: Optional `CheckpointOptions` object. Its `experimental_io_device`
      (default "cpu:0") is where the read op is placed.

  Returns:
    A dictionary mapping from SaveableObject names to restore operations.
  """
  options = options or checkpoint_options.CheckpointOptions()
  saveables = self._saveable_objects
  # Per-saveable list of spec names, used to regroup the flat RestoreV2 output.
  tensor_structure = [[spec.name for spec in saveable.specs]
                      for saveable in saveables]
  flat_specs = [(spec.name, spec.slice_spec, spec.dtype)
                for saveable in saveables
                for spec in saveable.specs]
  tensor_names, tensor_slices, tensor_dtypes = zip(*flat_specs)
  with ops.device(options.experimental_io_device or "cpu:0"):
    flat_restored = io_ops.restore_v2(
        file_prefix, tensor_names, tensor_slices, tensor_dtypes)
  grouped_restored = nest.pack_sequence_as(tensor_structure, flat_restored)
  return {
      saveable.name: saveable.restore(tensors, restored_shapes=None)
      for saveable, tensors in zip(saveables, grouped_restored)
  }
def restore(self, file_prefix):
  """Restores every saveable from `file_prefix` using a single CPU RestoreV2 op.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix for
      files to read from.

  Returns:
    A dictionary mapping from SaveableObject names to restore operations.
  """
  flat_specs = []
  tensor_structure = []
  for saveable in self._saveable_objects:
    names_for_saveable = []
    for spec in saveable.specs:
      names_for_saveable.append(spec.name)
      flat_specs.append((spec.name, spec.slice_spec, spec.dtype))
    tensor_structure.append(names_for_saveable)
  tensor_names, tensor_slices, tensor_dtypes = zip(*flat_specs)
  with ops.device("cpu:0"):
    flat_restored = io_ops.restore_v2(
        file_prefix, tensor_names, tensor_slices, tensor_dtypes)
  per_saveable = nest.pack_sequence_as(tensor_structure, flat_restored)
  result = {}
  for saveable, tensors in zip(self._saveable_objects, per_saveable):
    result[saveable.name] = saveable.restore(tensors, restored_shapes=None)
  return result
def testRestoreV2WithSliceInput(self):
  """Checks static shapes inferred for RestoreV2 outputs with and without slices."""
  with ops.Graph().as_default():
    outputs = io_ops.restore_v2("model", ["var1", "var2"], ["", "3 4 0,1:-"],
                                [dtypes.float32, dtypes.float32])
    self.assertEqual(2, len(outputs))
    unsliced, sliced = outputs
    # No slice spec: the shape is unknown at graph-construction time.
    self.assertFalse(unsliced.get_shape().is_fully_defined())
    # The slice spec fixes the static shape of the second output.
    self.assertEqual([1, 4], sliced.get_shape())
def restore_op(self, filename_tensor, saveable, preferred_shard):
  """Builds one RestoreV2 op per spec of `saveable`.

  Args:
    filename_tensor: String tensor with the checkpoint prefix.
    saveable: Saveable object whose `specs` are restored.
    preferred_shard: Unused; kept for interface compatibility.

  Returns:
    A list with one restored tensor per spec, in spec order.
  """
  import logging  # pylint: disable=g-import-not-at-top
  logger = logging.getLogger(__name__)
  tensors = []
  for spec in saveable.specs:
    # Was a bare debugging `print(spec.name)`; use logging so library code
    # does not write to stdout.
    logger.debug("Restoring %s", spec.name)
    tensors.append(
        io_ops.restore_v2(filename_tensor, [spec.name], [spec.slice_spec],
                          [spec.tensor.dtype])[0])
  return tensors
def restore_fn(trackables, merged_prefix):
  """Reads tensors for `trackables` from `merged_prefix` into their `name` attributes.

  If the first read raises NotFoundError, it is retried with the keys produced
  by `_get_tensors(..., append_name=False)`.
  """
  names, slices, tensors, targets = _get_tensors(trackables)
  tensor_dtypes = [tensor.dtype for tensor in tensors]
  try:
    values = io_ops.restore_v2(merged_prefix, names, slices, tensor_dtypes)
  except errors_impl.NotFoundError:
    # A NotFoundError means the checkpoint was written prior to the saver
    # registration migration; fall back to the legacy tensor keys.
    names, slices, _, targets = _get_tensors(trackables, append_name=False)
    values = io_ops.restore_v2(merged_prefix, names, slices, tensor_dtypes)
  for target, value in zip(targets, values):
    target.name = value
def restore_op(self, filename_tensor, saveable, preferred_shard):
  """Restores `saveable`'s specs, mapping replicated tower names to checkpoint names.

  Specs whose name starts with 'replicated_' are read under the name with its
  first '/'-separated component removed. BatchNorm moving statistics of every
  tower other than 'replicated_0' are skipped.

  NOTE(review): skipped specs make the returned list shorter than
  `saveable.specs`; confirm callers tolerate that.
  """
  tensors = []
  for spec in saveable.specs:
    checkpoint_name = spec.name
    if spec.name.startswith('replicated_'):
      # Ignore the moving_mean and moving_variance in other towers.
      if (not spec.name.startswith('replicated_0') and
          'BatchNorm/moving_' in spec.name):
        continue
      checkpoint_name = '/'.join(spec.name.split('/')[1:])
    tensors.append(
        io_ops.restore_v2(filename_tensor, [checkpoint_name],
                          [spec.slice_spec], [spec.tensor.dtype])[0])
  return tensors
def _BuildRestore(self):
  """Builds restore ops.

  For each variable in `self._vars`, reads the tensor keyed by `_VarKey(var)`
  from the prefix fed through `self._restore_prefix_ph` and assigns it. All
  assigns are grouped into `self._restore_op`.
  """

  def _ReadAndAssign(var):
    value, = io_ops.restore_v2(
        prefix=self._restore_prefix_ph,
        tensor_names=[_VarKey(var)],
        shape_and_slices=[""],
        dtypes=[var.dtype])
    return var.assign(value)

  self._restore_op = tf.group(*[_ReadAndAssign(var) for var in self._vars])
def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                 restore_sequentially):
  """Restores all saveables in one CPU RestoreV2 op, reading floats as float32.

  Floating-point tensors are requested from the checkpoint as float32 and cast
  to each spec's dtype. Non-floating tensors (e.g. an int64 step counter) are
  requested with their own dtype: RestoreV2 rejects a dtype that differs from
  the saved one, so forcing float32 on them made the whole restore fail.

  Args:
    filename_tensor: String tensor with the checkpoint prefix.
    saveables: Saveable objects to restore.
    preferred_shard: Unused; kept for interface compatibility.
    restore_sequentially: Unused; kept for interface compatibility.

  Returns:
    A list of restored tensors, one per spec, cast to the spec's dtype.
  """
  from tensorflow.python.ops import io_ops
  restore_specs = []
  for saveable in saveables:
    for spec in saveable.specs:
      restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
  names, slices, spec_dtypes = zip(*restore_specs)
  restore_dtypes = [
      tf.float32 if dt.is_floating else dt.base_dtype for dt in spec_dtypes
  ]
  with tf.device("cpu:0"):
    restored = io_ops.restore_v2(filename_tensor, names, slices,
                                 restore_dtypes)
  return [tf.cast(r, dt) for r, dt in zip(restored, spec_dtypes)]
def restore_stacks_and_parts(trackables, merged_prefix):
  """Reads stacked tensors and assigns each unstacked slice to its part.

  Each restored tensor is reshaped to its trackable's `value()` static shape,
  unstacked along axis 0, and the slices are assigned to `trackable.parts` in
  order.
  """
  names, slices, tensors, targets = get_tensor_slices(trackables)
  restored = io_ops.restore_v2(merged_prefix, names, slices,
                               [tensor.dtype for tensor in tensors])
  for target, flat_value in zip(targets, restored):
    shaped = array_ops.reshape(flat_value, target.value().get_shape())
    unstacked = array_ops.unstack(shaped)
    for part, part_value in zip(target.parts, unstacked):
      part.assign(part_value)
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides given variable's initialization op.

  Sets variable initializer to assign op that initializes variable from tensor's
  value in the checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  base_type = variable.dtype.base_dtype
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restore_op = io_ops.restore_v2(ckpt_file, [tensor_name], [slice_spec],
                                   [base_type], name=name)[0]
    # TODO(priyag, allenl): Use `SaveableObject.restore` instead here.
    if resource_variable_ops.is_resource_variable(variable):
      init_op = variable.assign(restore_op, read_value=False)
    else:
      init_op = state_ops.assign(variable, restore_op)

    # pylint:disable=protected-access
    # We need special handling for `DistributedVariable`s as they contain
    # multiple actual variables. `assign` on a `DistributedVariable` returns a
    # combined `init_op` which contains initializers for all the contained
    # variables. We then set each underlying variable's `_initializer_op` using
    # the corresponding `init_op`.
    # TODO(priyag): Use `isinstance` checks when `DistributedVariable` class
    # moves out of contrib.
    # Walk the full MRO (not only the direct `__bases__`) so indirect
    # subclasses of `DistributedVariable` are detected too, approximating the
    # intended `isinstance` check.
    is_distributed = any(cls.__name__ == "DistributedVariable"
                         for cls in variable.__class__.__mro__)
    if is_distributed:
      assert distribute_lib.get_cross_tower_context()
      assert hasattr(variable, "_index")
      for (d, v) in six.iteritems(variable._index):
        v._initializer_op = init_op._index[d]
        restore_op.set_shape(v.shape)
        v._initial_value = restore_op
    else:
      variable._initializer_op = init_op
      restore_op.set_shape(variable.shape)
      variable._initial_value = restore_op
    # pylint:enable=protected-access
def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                 restore_sequentially):
  """Restores every spec of `saveables` with a single RestoreV2 op on CPU 0.

  Returns:
    The list of restored tensors, one per spec in saveable/spec order.
  """
  # Ignored: bulk restore is internally sequential.
  del restore_sequentially
  flat_specs = [(spec.name, spec.slice_spec, spec.dtype)
                for saveable in saveables
                for spec in saveable.specs]
  names, slices, spec_dtypes = zip(*flat_specs)
  # Load all tensors onto CPU 0 for compatibility with existing code.
  with ops.device("cpu:0"):
    return io_ops.restore_v2(filename_tensor, names, slices, spec_dtypes)
def _assign_existing_variable(variable_to_restore, value_pointer):
  """Set a variable from a _ValuePointer object.

  Reads `value_pointer.checkpoint_key` from `value_pointer.save_path`, installs
  an assign op as the variable's `_initializer_op`, and runs it immediately
  when the pointer carries a session.
  """
  base_type = variable_to_restore.dtype.base_dtype
  # RestoreV2 only runs on CPU. Colocating it with the variable would pull the
  # variable's whole colocation group onto CPU, so keep the variable's
  # job/task but pin the read to its CPU instead.
  with ops.device(variable_to_restore.device), ops.device("/cpu:0"):
    # TODO(allenl): Handle partitioned variables
    value_to_restore, = io_ops.restore_v2(
        prefix=value_pointer.save_path,
        tensor_names=[value_pointer.checkpoint_key],
        shape_and_slices=[""],
        dtypes=[base_type],
        name="checkpoint_initializer")
  # The assign itself stays colocated with the variable, as before.
  with ops.colocate_with(variable_to_restore):
    initializer_op = state_ops.assign(variable_to_restore, value_to_restore)
  variable_to_restore._initializer_op = initializer_op  # pylint:disable=protected-access
  if value_pointer.session is not None:
    value_pointer.session.run(initializer_op)
def _set_checkpoint_initializer(variable, file_pattern, tensor_name, slice_spec,
                                name="checkpoint_initializer"):
  """Replaces the variable's initializer with an assign from a checkpoint tensor.

  Args:
    variable: `Variable` object.
    file_pattern: string, where to load checkpoints from.
    tensor_name: Name of the `Tensor` to load from checkpoint reader.
    slice_spec: Slice specification for loading partitioned variables.
    name: Name of the operation.
  """
  restored = io_ops.restore_v2(file_pattern, [tensor_name], [slice_spec],
                               [variable.dtype.base_dtype], name=name)
  variable._initializer_op = state_ops.assign(variable, restored[0])  # pylint: disable=protected-access
def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                 restore_sequentially):
  """Restores all saveables in one CPU RestoreV2 op, reading float16 as float32.

  Specs whose base dtype is float16 are requested as float32 (presumably the
  checkpoint holds float32 values; verify) and every result is cast back to
  its spec's base dtype. Each name and its dtype mapping is printed.
  """
  flat_specs = [(spec.name, spec.slice_spec, spec.dtype)
                for saveable in saveables
                for spec in saveable.specs]
  names, slices, spec_dtypes = zip(*flat_specs)
  restore_dtypes = [
      tf.float32 if dt.base_dtype == tf.float16 else dt for dt in spec_dtypes
  ]
  # print info
  for tensor_name, read_dtype, target_dtype in zip(names, restore_dtypes,
                                                   spec_dtypes):
    print(tensor_name, 'from', read_dtype, 'to', target_dtype.base_dtype)
  with tf.device("cpu:0"):
    restored = io_ops.restore_v2(filename_tensor, names, slices,
                                 restore_dtypes)
  return [
      tf.cast(value, dt.base_dtype) for value, dt in zip(restored, spec_dtypes)
  ]
def _set_checkpoint_initializer(variable,
                                file_pattern,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Makes the variable initialize itself from a tensor stored in a checkpoint.

  Args:
    variable: `Variable` object.
    file_pattern: string, where to load checkpoints from.
    tensor_name: Name of the `Tensor` to load from checkpoint reader.
    slice_spec: Slice specification for loading partitioned variables.
    name: Name of the operation.
  """
  checkpoint_value, = io_ops.restore_v2(
      file_pattern, [tensor_name], [slice_spec], [variable.dtype.base_dtype],
      name=name)
  assign_op = state_ops.assign(variable, checkpoint_value)
  variable._initializer_op = assign_op  # pylint: disable=protected-access
def _assign_existing_variable(variable_to_restore, value_pointer):
  """Set a variable from a _ValuePointer object.

  Builds an assign op that loads `value_pointer.checkpoint_key` from
  `value_pointer.save_path` into `variable_to_restore` and stores it as the
  variable's `_initializer_op`. The op runs immediately if
  `value_pointer.session` is set.
  """
  base_type = variable_to_restore.dtype.base_dtype
  # Do not colocate the read with the variable: RestoreV2 only runs on CPU,
  # and colocation would force the variable's colocation group onto CPU.
  # Place the read on the CPU of the variable's job/task instead.
  with ops.device(variable_to_restore.device), ops.device("/cpu:0"):
    # TODO(allenl): Handle partitioned variables
    value_to_restore, = io_ops.restore_v2(
        prefix=value_pointer.save_path,
        tensor_names=[value_pointer.checkpoint_key],
        shape_and_slices=[""],
        dtypes=[base_type],
        name="checkpoint_initializer")
  with ops.colocate_with(variable_to_restore):
    initializer_op = state_ops.assign(variable_to_restore, value_to_restore)
  variable_to_restore._initializer_op = initializer_op  # pylint:disable=protected-access
  if value_pointer.session is not None:
    value_pointer.session.run(initializer_op)
def __call__(self, shape, dtype=None, partition_info=None):
  """Returns the checkpointed tensor for this variable, with static `shape` set.

  The checkpoint key is `<scope>/<var_name>`. The scope comes from the current
  variable scope when `self._scope` is None, from `self._scope(current)` when
  it is callable, and from `self._scope` otherwise. An empty scope means the
  bare variable name is used.
  """
  # Creating different RestoreV2 ops when a single one could
  # output several tensors seems inefficient, but that's actually
  # what tf.Saver.restore_op (via tf.BaseSaverBuilder) does too.
  scope = self._scope
  if scope is None or callable(scope):
    enclosing_scope = tf.get_variable_scope().name
    scope_name = enclosing_scope if scope is None else scope(enclosing_scope)
  else:
    scope_name = scope
  if scope_name:
    tensor_name = '{}/{}'.format(scope_name, self._var_name)
  else:
    tensor_name = self._var_name
  slice_spec = self._partition_spec(shape, partition_info)
  restored = io_ops.restore_v2(
      self._filename, [tensor_name], [slice_spec], [dtype])[0]
  restored.set_shape(shape)
  return restored
def _BuildRestore(self):
  """Builds restore ops.

  Variables are grouped by device. Each group's RestoreV2/assign ops are built
  under that device in `self._var_graph`, and everything is grouped into
  `self._restore_op`.
  """
  assign_ops = []
  with self._var_graph.as_default():
    vars_by_device = collections.defaultdict(list)
    for var in self._vars:
      vars_by_device[var.device].append(var)

    for device in vars_by_device:
      with self._var_graph.device(device):
        for var in vars_by_device[device]:
          value, = io_ops.restore_v2(
              prefix=self._restore_prefix_ph,
              tensor_names=[_VarKey(var)],
              shape_and_slices=[""],
              dtypes=[var.dtype])
          assign_ops.append(var.assign(value))

    self._restore_op = tf.group(*assign_ops)
def __call__(self, shape, dtype=None, partition_info=None):
  """Initializer entry point: reads this variable's slice from the checkpoint.

  Resolves the scope (current variable scope, a callable applied to it, or a
  fixed string), prefixes it to `self._var_name` when non-empty, restores that
  tensor with the partition's slice spec, and pins its static shape.
  """
  # Creating different RestoreV2 ops when a single one could
  # output several tensors seems inefficient, but that's actually
  # what tf.Saver.restore_op (via tf.BaseSaverBuilder) does too.
  if self._scope is not None and not callable(self._scope):
    scope_name = self._scope
  elif self._scope is None:
    scope_name = tf.get_variable_scope().name
  else:
    scope_name = self._scope(tf.get_variable_scope().name)

  tensor_name = (
      '{}/{}'.format(scope_name, self._var_name)
      if scope_name else self._var_name)
  outputs = io_ops.restore_v2(
      self._filename,
      [tensor_name],
      [self._partition_spec(shape, partition_info)],
      [dtype])
  value = outputs[0]
  value.set_shape(shape)
  return value
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Makes the variable's initializer restore it from a checkpoint tensor.

  The restore goes through the variable's single `SaveableObject`, so the
  variable-type-specific restore logic is used.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restore_op = io_ops.restore_v2(ckpt_file, [tensor_name], [slice_spec],
                                   [variable.dtype.base_dtype], name=name)[0]
    saveables_by_name = saver.BaseSaverBuilder.OpListToDict([variable])
    saveable_objects = [
        saveable
        for op_name, op in saveables_by_name.items()
        for saveable in saver.BaseSaverBuilder.SaveableObjectsForOp(
            op, op_name)
    ]
    assert len(saveable_objects) == 1  # Should be only one variable.
  init_op = saveable_objects[0].restore([restore_op], restored_shapes=None)

  # pylint:disable=protected-access
  variable._initializer_op = init_op
  restore_op.set_shape(variable.shape)
  variable._initial_value = restore_op
  # pylint:enable=protected-access
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Points the variable's initializer at a value stored in a checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  checkpoint_value, = io_ops.restore_v2(
      ckpt_file, [tensor_name], [slice_spec], [variable.dtype.base_dtype],
      name=name)
  variable._initializer_op = state_ops.assign(variable, checkpoint_value)  # pylint:disable=protected-access
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Installs a checkpoint-reading assign op as the variable's initializer.

  For resource variables whose assign op carries an `_index` (per-device
  components), the pieces are grouped through the current distribution
  strategy.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    checkpoint_value = io_ops.restore_v2(
        ckpt_file, [tensor_name], [slice_spec], [variable.dtype.base_dtype],
        name=name)[0]

    # TODO(priyag, allenl): Use `SaveableObject.restore` instead here.
    if not resource_variable_ops.is_resource_variable(variable):
      init_op = state_ops.assign(variable, checkpoint_value)
    else:
      init_op = variable.assign(checkpoint_value, read_value=False)
      # TODO(priyag): Remove this when using `SaveableObject.restore` instead.
      if hasattr(init_op, "_index"):
        init_op = distribute_lib.get_distribution_strategy().group(init_op)

    # pylint:disable=protected-access
    variable._initializer_op = init_op
    checkpoint_value.set_shape(variable.shape)
    variable._initial_value = checkpoint_value
    # pylint:enable=protected-access
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides the variable's initializer with a restore from `ckpt_file`.

  The value is applied through the variable's single SaveableObject rather
  than a direct assign.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  base_type = variable.dtype.base_dtype
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restored_value = io_ops.restore_v2(
        ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]

    saveable_objects = []
    op_dict = saver.BaseSaverBuilder.OpListToDict([variable])
    for saveable_name in op_dict:
      saveable_objects.extend(
          saver.BaseSaverBuilder.SaveableObjectsForOp(op_dict[saveable_name],
                                                      saveable_name))

    assert len(saveable_objects) == 1  # Should be only one variable.
  init_op = saveable_objects[0].restore([restored_value], restored_shapes=None)

  # pylint:disable=protected-access
  variable._initializer_op = init_op
  restored_value.set_shape(variable.shape)
  variable._initial_value = restored_value
  # pylint:enable=protected-access
def restore_from_saveable_objects(file_prefix, saveable_objects):
  """Reads from a checkpoint and returns restore ops for `saveable_objects`s.

  One RestoreV2 op on CPU reads every spec. Its outputs are regrouped per
  saveable and passed to each saveable's `restore`.

  Returns:
    A list of restore ops, one per saveable, in input order.
  """
  tensor_structure = []
  restore_specs = []
  for saveable in saveable_objects:
    tensor_structure.append([spec.name for spec in saveable.specs])
    restore_specs.extend(
        (spec.name, spec.slice_spec, spec.dtype) for spec in saveable.specs)
  tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
  with ops.device("cpu:0"):
    flat_restored = io_ops.restore_v2(file_prefix, tensor_names,
                                      tensor_slices, tensor_dtypes)
  grouped = nest.pack_sequence_as(tensor_structure, flat_restored)
  return [
      saveable.restore(tensors, restored_shapes=None)
      for saveable, tensors in zip(saveable_objects, grouped)
  ]
def restore_from_saveable_objects(file_prefix, saveable_objects):
  """Reads from a checkpoint and returns restore ops for `saveable_objects`s.

  Every spec is read by a single CPU RestoreV2 op, then each saveable restores
  itself from its own group of tensors.
  """
  flat_specs = [(spec.name, spec.slice_spec, spec.dtype)
                for saveable in saveable_objects
                for spec in saveable.specs]
  names_per_saveable = [[spec.name for spec in saveable.specs]
                        for saveable in saveable_objects]
  tensor_names, tensor_slices, tensor_dtypes = zip(*flat_specs)
  with ops.device("cpu:0"):
    flat_restored = io_ops.restore_v2(
        file_prefix, tensor_names, tensor_slices, tensor_dtypes)
  restored_groups = nest.pack_sequence_as(names_per_saveable, flat_restored)
  restore_ops = []
  for saveable, group in zip(saveable_objects, restored_groups):
    restore_ops.append(saveable.restore(group, restored_shapes=None))
  return restore_ops
def restore(self, file_prefix, options=None):
  """Restore the saveable objects from a checkpoint with `file_prefix`.

  All slices in `self._tensor_slice_dict` are read with one RestoreV2 op,
  placed on `options.experimental_io_device` (default "cpu:0"). A
  `SaveSpec` entry is read under its own name/slice spec; any other entry is
  read under its checkpoint key and dict slice spec.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix for
      files to read from.
    options: Optional `CheckpointOptions` object.

  Returns:
    A restored tensor dict (maps checkpoint_key -> slice_spec -> tensor).
  """
  options = options or checkpoint_options.CheckpointOptions()
  tensor_names = []
  tensor_dtypes = []
  slice_specs = []
  # (checkpoint_key, slice_spec) for every requested tensor, in request order,
  # so results can be mapped back by position.
  result_keys = []
  for checkpoint_key, tensor_slices in self._tensor_slice_dict.items():
    for slice_spec, tensor in tensor_slices.items():
      tensor_dtypes.append(tensor.dtype)
      if isinstance(tensor, saveable_object.SaveSpec):
        slice_specs.append(tensor.slice_spec)
        tensor_names.append(tensor.name)
      else:
        slice_specs.append(slice_spec)
        tensor_names.append(checkpoint_key)
      result_keys.append((checkpoint_key, slice_spec))

  restore_device = options.experimental_io_device or "cpu:0"
  with ops.device(restore_device):
    restored_tensors = io_ops.restore_v2(
        file_prefix, tensor_names, slice_specs, tensor_dtypes)

  # Index by position instead of repeatedly calling `pop(0)`, which made this
  # loop quadratic in the number of tensors.
  restored_tensor_dict = {}
  for position, (checkpoint_key, slice_spec) in enumerate(result_keys):
    restored_tensor_dict.setdefault(checkpoint_key, {})[slice_spec] = (
        restored_tensors[position])
  return restored_tensor_dict
def _process_slot_restoration(self, slot_restoration, variable):
  """Restore a slot variable's value, creating the slot first when needed.

  An existing slot for `variable` is assigned via `_assign_existing_variable`.
  Otherwise a new slot is created with a checkpoint-reading initializer, which
  is run at once if the value pointer has a session, and it is registered in
  the optimizer's slot dict.
  """
  # TODO(allenl): Move this to Optimizer
  assert isinstance(self, optimizer_lib.Optimizer)
  named_slots = self._slot_dict(slot_restoration.slot_name)
  slot_key = optimizer_lib._var_key(variable)  # pylint: disable=protected-access
  found = named_slots.get(slot_key, None)
  value_pointer = slot_restoration.value_pointer

  if found is None:
    dtype_to_read = value_pointer.dtype.base_dtype
    checkpoint_value = io_ops.restore_v2(
        prefix=value_pointer.save_path,
        tensor_names=[value_pointer.checkpoint_key],
        shape_and_slices=[""],
        dtypes=[dtype_to_read],
        name="checkpoint_initializer")[0]
    slot = slot_creator.create_slot(variable, checkpoint_value,
                                    slot_restoration.slot_name)
    if value_pointer.session is not None:
      value_pointer.session.run(slot.initializer)
    named_slots[slot_key] = slot
  else:
    _assign_existing_variable(found, value_pointer=value_pointer)
def restore_op(self, filename_tensor, saveable, preferred_shard):
  """Restores each spec of `saveable` from the name given by `unscope_name`.

  Returns:
    A list with one restored tensor per spec, in spec order.
  """
  restored = []
  for spec in saveable.specs:
    checkpoint_name = unscope_name(spec.name)
    outputs = io_ops.restore_v2(filename_tensor, [checkpoint_name],
                                [spec.slice_spec], [spec.tensor.dtype])
    restored.append(outputs[0])
  return restored
def restore(save_path, root_checkpointable, session=None):
  """Restore a training checkpoint.

  Restores the values of variables created with
  `Checkpointable._add_variable` in `root_checkpointable` and any objects that
  it tracks (transitive). Either assigns values immediately if variables to
  restore have been created already, or defers restoration until the
  variables are created. Dependencies added to `root_checkpointable` after
  this call will be matched if they have a corresponding object in the
  checkpoint.

  When building a graph, restorations are executed in the default session if
  `session` is `None`. Variable initializers read checkpointed values.

  To disallow deferred loading, assert immediately that all checkpointed
  variables have been matched to variable objects:

  ```python
  restore(path, root).assert_consumed()
  ```

  An exception will be raised unless every object was matched and its
  variables already exist.

  Args:
    save_path: The path to the checkpoint, as returned by `save` or
      `tf.train.latest_checkpoint`. If None (as when there is no latest
      checkpoint for `tf.train.latest_checkpoint` to return), does nothing.
    root_checkpointable: The root of the object graph to restore. Variables to
      restore need not have been created yet, but all dependencies on other
      Checkpointable objects should already be declared. Objects in the
      dependency graph are matched to objects in the checkpointed graph, and
      matching objects have their variables restored (or the checkpointed
      values saved for eventual restoration when the variable is created).
    session: The session to evaluate assignment ops in. Ignored when executing
      eagerly. If not provided when graph building, the default session is
      used.

  Returns:
    A CheckpointLoadStatus object, which can be used to make assertions about
    the status of checkpoint restoration, or `None` if `save_path` is `None`.

  Raises:
    ValueError: If building a graph, `save_path` is not None, and neither
      `session` nor a default session is available to read the checkpoint's
      object graph.
  """
  if save_path is None:
    return None
  in_graph_mode = context.in_graph_mode()
  if in_graph_mode:
    if session is None:
      session = ops.get_default_session()
    if session is None:
      # The object graph must be evaluated below to match objects; a symbolic
      # tensor cannot be read without a session, so fail with a clear message.
      raise ValueError(
          "A session is required to restore a checkpoint when graph "
          "building. Pass `session` or call restore() with a default "
          "session.")
  else:
    session = None

  object_graph_string, = io_ops.restore_v2(
      prefix=save_path,
      tensor_names=[_OBJECT_GRAPH_PROTO_KEY],
      shape_and_slices=[""],
      dtypes=[dtypes.string],
      name="object_graph_proto_read")
  # Branch on the execution mode itself rather than on `session is not None`.
  if in_graph_mode:
    object_graph_string = session.run(object_graph_string)
  else:
    object_graph_string = object_graph_string.numpy()

  object_graph_proto = (
      checkpointable_object_graph_pb2.CheckpointableObjectGraph())
  object_graph_proto.ParseFromString(object_graph_string)
  checkpoint = core_checkpointable._Checkpoint(  # pylint: disable=protected-access
      object_graph_proto=object_graph_proto,
      save_path=save_path,
      session=session)
  core_checkpointable._CheckpointPosition(  # pylint: disable=protected-access
      checkpoint=checkpoint, proto_id=0).restore(root_checkpointable)
  return CheckpointLoadStatus(checkpoint)
def restore(self, save_path, session=None):
  """Restore a training checkpoint.

  Restores `root_checkpointable` and any objects that it tracks
  (transitive). Either assigns values immediately if variables to restore
  have been created already, or defers restoration until the variables are
  created. Dependencies added to the `root_checkpointable` passed to the
  constructor after this call will be matched if they have a corresponding
  object in the checkpoint.

  When building a graph, restorations are added to the graph but not run. A
  session is required to retrieve checkpoint metadata.

  To disallow deferred loading, assert immediately that all checkpointed
  variables have been matched to variable objects:

  ```python
  saver = Saver(root)
  saver.restore(path).assert_consumed()
  ```

  An exception will be raised unless every object was matched and its
  variables already exist.

  When graph building, `assert_consumed()` indicates that all of the restore
  ops which will be created for this checkpoint have been created. They can
  be run via the `run_restore_ops()` function of the status object:

  ```python
  saver.restore(path).assert_consumed().run_restore_ops()
  ```

  If the checkpoint has not been consumed completely, then the list of
  restore ops will grow as more objects are added to the dependency graph.

  Name-based `tf.train.Saver` checkpoints can be loaded using this
  method. There is no deferred loading, and names are used to match
  variables. No restore ops are created/run until `run_restore_ops()` or
  `initialize_or_restore()` are called on the returned status object, even
  when executing eagerly. Re-encode name-based checkpoints using this
  object-based `Saver.save` as soon as possible.

  Args:
    save_path: The path to the checkpoint, as returned by `save` or
      `tf.train.latest_checkpoint`. If None (as when there is no latest
      checkpoint for `tf.train.latest_checkpoint` to return), returns an
      object which may run initializers for objects in the dependency
      graph. If the checkpoint was written by the name-based
      `tf.train.Saver`, names are used to match variables.
    session: The session to retrieve metadata with. Ignored when executing
      eagerly. If not provided when graph building, the default session is
      used.

  Returns:
    A load status object, which can be used to make assertions about the
    status of checkpoint restoration and run initialization/restore ops
    (of type `CheckpointLoadStatus`, or `InitializationOnlyStatus` if
    `save_path` is `None`).

    If `save_path` points to a name-based checkpoint, a `NameBasedSaverStatus`
    object is returned which runs restore ops from a name-based saver.

  Raises:
    ValueError: If graph building with a non-None `save_path` and neither
      `session` nor a default session is available.
    NotImplementedError: If graph building and this Saver has already
      restored a different object graph.
  """
  if save_path is None:
    return InitializationOnlyStatus(self._root_checkpointable)
  in_graph_mode = context.in_graph_mode()
  if in_graph_mode:
    if session is None:
      session = ops.get_default_session()
    if session is None:
      # The object graph is read with `session.run` below; fail with a clear
      # message instead of an AttributeError on `None.run`.
      raise ValueError(
          "A session is required to retrieve checkpoint metadata when graph "
          "building. Pass `session` or call restore() with a default "
          "session.")
    file_prefix_tensor = self._file_prefix_placeholder
    file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
  else:
    session = None
    file_prefix_tensor = constant_op.constant(save_path)
    file_prefix_feed_dict = None
  try:
    # In graph mode the read op is built once and cached, then re-run with a
    # different feed for each restore.
    if not in_graph_mode or self._object_graph_restore_tensor is None:
      object_graph_string, = io_ops.restore_v2(
          prefix=file_prefix_tensor,
          tensor_names=[_OBJECT_GRAPH_PROTO_KEY],
          shape_and_slices=[""],
          dtypes=[dtypes.string],
          name="object_graph_proto_read")
      if in_graph_mode:
        self._object_graph_restore_tensor = object_graph_string
    if in_graph_mode:
      object_graph_string = session.run(
          self._object_graph_restore_tensor,
          feed_dict=file_prefix_feed_dict)
    else:
      object_graph_string = object_graph_string.numpy()
  except errors_impl.NotFoundError:
    # The object graph proto does not exist in this checkpoint. Try again with
    # name-based saving.
    return NameBasedSaverStatus(self, save_path)

  object_graph_proto = (
      checkpointable_object_graph_pb2.CheckpointableObjectGraph())
  object_graph_proto.ParseFromString(object_graph_string)
  if in_graph_mode and object_graph_proto == self._last_restore_object_graph:
    checkpoint = self._last_restore_checkpoint
  else:
    if in_graph_mode:
      dtype_map = None
    else:
      reader = pywrap_tensorflow.NewCheckpointReader(save_path)
      dtype_map = reader.get_variable_to_dtype_map()
    checkpoint = core_checkpointable_utils._Checkpoint(  # pylint: disable=protected-access
        object_graph_proto=object_graph_proto,
        save_path=file_prefix_tensor,
        dtype_map=dtype_map)
    if in_graph_mode:
      if self._last_restore_object_graph is not None:
        raise NotImplementedError(
            "Using a single Saver to restore different object graphs is not "
            "currently supported when graph building. Use a different Saver "
            "for each object graph (restore ops will be duplicated), or "
            "file a feature request if this limitation bothers you.")
      self._last_restore_checkpoint = checkpoint
      self._last_restore_object_graph = object_graph_proto
  core_checkpointable._CheckpointPosition(  # pylint: disable=protected-access
      checkpoint=checkpoint, proto_id=0).restore(self._root_checkpointable)
  load_status = CheckpointLoadStatus(
      checkpoint, feed_dict=file_prefix_feed_dict)
  return load_status
def restore(self, save_path, session=None):
  """Restore a training checkpoint.

  Restores `root_checkpointable` and any objects that it tracks
  (transitive). Either assigns values immediately if variables to restore
  have been created already, or defers restoration until the variables are
  created. Dependencies added to the `root_checkpointable` passed to the
  constructor after this call will be matched if they have a corresponding
  object in the checkpoint.

  When building a graph, restorations are added to the graph but not run. A
  session is required to retrieve checkpoint metadata.

  To disallow deferred loading, assert immediately that all checkpointed
  variables have been matched to variable objects:

  ```python
  saver = Saver(root)
  saver.restore(path).assert_consumed()
  ```

  An exception will be raised unless every object was matched and its
  variables already exist.

  When graph building, `assert_consumed()` indicates that all of the restore
  ops which will be created for this checkpoint have been created. They can
  be run via the `run_restore_ops()` function of the status object:

  ```python
  saver.restore(path).assert_consumed().run_restore_ops()
  ```

  If the checkpoint has not been consumed completely, then the list of
  restore ops will grow as more objects are added to the dependency graph.

  Name-based `tf.train.Saver` checkpoints can be loaded using this
  method. There is no deferred loading, and names are used to match
  variables. No restore ops are created/run until `run_restore_ops()` or
  `initialize_or_restore()` are called on the returned status object, even
  when executing eagerly. Re-encode name-based checkpoints using this
  object-based `Saver.save` as soon as possible.

  Args:
    save_path: The path to the checkpoint, as returned by `save` or
      `tf.train.latest_checkpoint`. If None (as when there is no latest
      checkpoint for `tf.train.latest_checkpoint` to return), returns an
      object which may run initializers for objects in the dependency
      graph. If the checkpoint was written by the name-based
      `tf.train.Saver`, names are used to match variables.
    session: The session to retrieve metadata with. Ignored when executing
      eagerly. If not provided when graph building, the default session is
      used.

  Returns:
    A load status object, which can be used to make assertions about the
    status of checkpoint restoration and run initialization/restore ops
    (of type `CheckpointLoadStatus`, or `InitializationOnlyStatus` if
    `save_path` is `None`).

    If `save_path` points to a name-based checkpoint, a `NameBasedSaverStatus`
    object is returned which runs restore ops from a name-based saver.

  Raises:
    ValueError: If graph building with a non-None `save_path` and neither
      `session` nor a default session is available.
    NotImplementedError: If graph building and this Saver has already
      restored a different object graph.
  """
  if save_path is None:
    return InitializationOnlyStatus(self._root_checkpointable)
  in_graph_mode = context.in_graph_mode()
  if in_graph_mode:
    if session is None:
      session = ops.get_default_session()
    if session is None:
      # The object graph is read with `session.run` below; fail with a clear
      # message instead of an AttributeError on `None.run`.
      raise ValueError(
          "A session is required to retrieve checkpoint metadata when graph "
          "building. Pass `session` or call restore() with a default "
          "session.")
    file_prefix_tensor = self._file_prefix_placeholder
    file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
  else:
    session = None
    file_prefix_tensor = constant_op.constant(save_path)
    file_prefix_feed_dict = None
  try:
    # In graph mode the read op is built once and cached, then re-run with a
    # different feed for each restore.
    if not in_graph_mode or self._object_graph_restore_tensor is None:
      object_graph_string, = io_ops.restore_v2(
          prefix=file_prefix_tensor,
          tensor_names=[_OBJECT_GRAPH_PROTO_KEY],
          shape_and_slices=[""],
          dtypes=[dtypes.string],
          name="object_graph_proto_read")
      if in_graph_mode:
        self._object_graph_restore_tensor = object_graph_string
    if in_graph_mode:
      object_graph_string = session.run(
          self._object_graph_restore_tensor,
          feed_dict=file_prefix_feed_dict)
    else:
      object_graph_string = object_graph_string.numpy()
  except errors_impl.NotFoundError:
    # The object graph proto does not exist in this checkpoint. Try again with
    # name-based saving.
    return NameBasedSaverStatus(self, save_path)

  object_graph_proto = (
      checkpointable_object_graph_pb2.CheckpointableObjectGraph())
  object_graph_proto.ParseFromString(object_graph_string)
  if in_graph_mode and object_graph_proto == self._last_restore_object_graph:
    checkpoint = self._last_restore_checkpoint
  else:
    if in_graph_mode:
      dtype_map = None
    else:
      reader = pywrap_tensorflow.NewCheckpointReader(save_path)
      dtype_map = reader.get_variable_to_dtype_map()
    checkpoint = core_checkpointable_utils._Checkpoint(  # pylint: disable=protected-access
        object_graph_proto=object_graph_proto,
        save_path=file_prefix_tensor,
        dtype_map=dtype_map)
    if in_graph_mode:
      if self._last_restore_object_graph is not None:
        raise NotImplementedError(
            "Using a single Saver to restore different object graphs is not "
            "currently supported when graph building. Use a different Saver "
            "for each object graph (restore ops will be duplicated), or "
            "file a feature request if this limitation bothers you.")
      self._last_restore_checkpoint = checkpoint
      self._last_restore_object_graph = object_graph_proto
  core_checkpointable._CheckpointPosition(  # pylint: disable=protected-access
      checkpoint=checkpoint, proto_id=0).restore(self._root_checkpointable)
  load_status = CheckpointLoadStatus(checkpoint,
                                     feed_dict=file_prefix_feed_dict)
  return load_status
def add_variable(self, name, shape=None, dtype=dtypes.float32,
                 initializer=None, **kwargs):
  """Create a variable owned by (and saved with) this `Checkpointable`.

  When a restore for `name` has been deferred (the checkpoint was requested
  before this variable existed), the variable is initialized from the
  checkpoint instead: `initializer` and `dtype` are replaced by a
  checkpoint-reading tensor and the checkpointed dtype. Any pending slot
  restorations for the variable are then handed to their optimizers.

  Args:
    name: Name of the variable; must be unique within this object.
    shape: Shape of the variable.
    dtype: Data type of the variable.
    initializer: Initializer to use; ignored when a deferred restoration
      exists for `name`.
    **kwargs: Forwarded to the getter. A `getter` entry, if present, is
      removed and used in place of the default getter.

  Returns:
    The newly created variable.

  Raises:
    ValueError: If a variable called `name` already exists on this object.
    RuntimeError: If `Checkpointable.__init__` has not been called.
  """
  if not hasattr(self, "_owned_variables"):
    raise RuntimeError("Need to call Checkpointable.__init__ before adding "
                       "variables.")
  if name in self._owned_variables:
    raise ValueError(
        ("A variable named '%s' already exists in this Checkpointable, but "
         "Checkpointable.add_variable called to create another with "
         "that name. Variable names must be unique within a Checkpointable "
         "object.") % (name,))
  # A caller may override the getter, typically for compatibility with some
  # other variable creation mechanism; this should be rare in user code.
  getter = kwargs.pop("getter", _default_getter)

  pending = self._deferred_restorations.pop(name, None)
  if pending is not None:
    pointer = pending.value_pointer
    dtype = pointer.dtype
    base_type = dtype.base_dtype
    # TODO(allenl): Handle partitioned variables here too
    with ops.init_scope():
      (initializer,) = io_ops.restore_v2(
          prefix=pointer.save_path,
          tensor_names=[pointer.checkpoint_key],
          shape_and_slices=[""],
          dtypes=[base_type],
          name="checkpoint_initializer")
    # Carry the requested static shape on the checkpoint-reading initializer
    # so the variable's shape stays known, then clear `shape` because a
    # tensor initializer is being supplied to the getter.
    initializer.set_shape(shape)
    shape = None

  new_variable = getter(
      name=name, shape=shape, dtype=dtype, initializer=initializer, **kwargs)

  if pending is not None:
    restore_session = pending.value_pointer.session
    if restore_session is not None:
      restore_session.run(new_variable.initializer)
    for slot_restoration in pending.slot_restorations:
      optimizer = slot_restoration.optimizer_ref()
      if optimizer is None:
        # The optimizer has been garbage collected, so its slot variable
        # does not need to be created.
        continue
      optimizer._process_slot_restoration(  # pylint: disable=protected-access
          slot_restoration, new_variable)

  self._owned_variables[name] = new_variable
  return new_variable