def save_fn():
  """Writes one checkpoint shard per device saver, then merges the shards.

  Returns:
    The MergeV2Checkpoints op that finalizes the checkpoint at `file_prefix`.
  """
  shard_count = len(self._single_device_savers)
  per_shard_save_ops = []
  shard_prefix_tensors = []
  shard_count_tensor = constant_op.constant(shard_count, name="num_shards")
  final_device = None
  for shard_index, (device_name, device_saver) in enumerate(
      sorted(self._single_device_savers.items())):
    final_device = device_name
    # Filename construction is cheap host-side work; pin it to the CPU of
    # the shard's device.
    with ops.device(saveable_object_util.set_cpu0(device_name)):
      prefix_tensor = sharded_filename(tmp_checkpoint_prefix, shard_index,
                                       shard_count_tensor)
    shard_prefix_tensors.append(prefix_tensor)
    with ops.device(device_name):
      # _SingleDeviceSaver will use the CPU device when necessary, but
      # initial read operations should be placed on the SaveableObject's
      # device.
      per_shard_save_ops.append(device_saver.save(prefix_tensor, options))

  with ops.control_dependencies(per_shard_save_ops):
    # Merge on the io_device when one is configured; otherwise co-locate
    # the merge with (the CPU of) the last device used above.
    if options.experimental_io_device:
      merge_placement = options.experimental_io_device
    else:
      merge_placement = saveable_object_util.set_cpu0(final_device)
    with ops.device(merge_placement):
      # The V2 write path ends with a metadata merge step. Once merged, it
      # attempts to delete the temporary directory,
      # "<user-fed prefix>_temp".
      return gen_io_ops.merge_v2_checkpoints(
          shard_prefix_tensors, file_prefix, delete_old_dirs=True)
def save_fn(): saved_prefixes = [] # Save with the registered savers. These run before default savers due to # the API contract. for saver_name, (save_fn, _) in self._registered_savers.items(): maybe_saved_prefixes = save_fn(registered_paths[saver_name]) if maybe_saved_prefixes is not None: flattened_saved_prefixes = nest.flatten( maybe_saved_prefixes) if not all( tensor_util.is_tf_type(x) and x.dtype == dtypes.string for x in flattened_saved_prefixes): raise ValueError( "Registered saver must return a (maybe empty) list of " f"string type tensors. Got {maybe_saved_prefixes}." ) saved_prefixes.extend(flattened_saved_prefixes) # (Default saver) Save with single device savers. num_shards = len(self._single_device_savers) sharded_saves = [] num_shards_tensor = constant_op.constant(num_shards, name="num_shards") last_device = None for shard, (device, saver) in enumerate( sorted(self._single_device_savers.items())): last_device = device with ops.device(saveable_object_util.set_cpu0(device)): shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor) saved_prefixes.append(shard_prefix) with ops.device(device): # _SingleDeviceSaver will use the CPU device when necessary, but # initial read operations should be placed on the SaveableObject's # device. sharded_saves.append(saver.save(shard_prefix, options)) with ops.control_dependencies(sharded_saves): # Merge on the io_device if specified, otherwise co-locates the merge op # with the last device used. merge_device = (options.experimental_io_device or saveable_object_util.set_cpu0(last_device)) with ops.device(merge_device): # V2 format write path consists of a metadata merge step. Once # merged, attempts to delete the temporary directory, # "<user-fed prefix>_temp". return gen_io_ops.merge_v2_checkpoints( saved_prefixes, file_prefix, delete_old_dirs=True)
def __init__(self, saveable_objects):
  """Specify a list of `SaveableObject`s to save and restore.

  Args:
    saveable_objects: A list of `SaveableObject`s. Objects extending
      `SaveableObject` will be saved and restored, and objects extending
      `SaveableHook` will be called into at save and restore time.
  """
  self._before_save_callbacks = []
  self._after_restore_callbacks = []

  grouped_by_host = {}
  for obj in list(saveable_objects):
    treat_as_saveable = isinstance(obj, saveable_object.SaveableObject)
    treat_as_hook = isinstance(obj, saveable_hook.SaveableHook)
    if not (treat_as_saveable or treat_as_hook):
      raise ValueError(
          f"Expected a dictionary of SaveableObjects, got {obj}.")
    # An object may be both a saveable and a hook; register it as both.
    if treat_as_hook:
      self._before_save_callbacks.append(obj.before_save)
      self._after_restore_callbacks.append(obj.after_restore)
    if treat_as_saveable:
      host = saveable_object_util.set_cpu0(obj.device)
      grouped_by_host.setdefault(host, []).append(obj)

  # One _SingleDeviceSaver per host CPU device.
  self._single_device_savers = {
      host: _SingleDeviceSaver(group)
      for host, group in grouped_by_host.items()
  }
def restore(self, file_prefix):
  """Restore the saveable objects from a checkpoint with `file_prefix`.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix for
      files to read from.

  Returns:
    An operation which restores the `Saver`'s `SaveableObject`s when run, or
    None if executing eagerly.
  """
  restore_ops = []
  for saveable in self._saveable_objects:
    # Pin reads to the host CPU of the saveable's device when one is set;
    # otherwise leave placement to the surrounding context.
    placement = (saveable_object_util.set_cpu0(saveable.device)
                 if saveable.device else None)
    with ops.device(placement):
      loaded = [
          io_ops.restore_v2(file_prefix, [spec.name], [spec.slice_spec],
                            [spec.dtype])[0]
          for spec in saveable.specs
      ]
      restore_ops.append(saveable.restore(loaded, restored_shapes=None))
  return control_flow_ops.group(restore_ops)
def __init__(self,
             saveable_objects,
             registered_savers=None,
             call_with_mapped_captures=None):
  """Specify a list of `SaveableObject`s to save and restore.

  Args:
    saveable_objects: A list of `SaveableObject`s. Objects extending
      `SaveableObject` will be saved and restored, and objects extending
      `SaveableHook` will be called into at save and restore time.
    registered_savers: A dictionary mapping `registration.RegisteredSaver`
      namedtuples to a dictionary of named Trackables. The keys of the
      Trackable dictionary are string names that uniquely identify the
      Trackable in the checkpoint.
    call_with_mapped_captures: NOTE(review): forwarded to the mapped
      registered save/restore function builders; exact semantics are not
      visible here — TODO: document.
  """
  self._before_save_callbacks = []
  self._after_restore_callbacks = []

  grouped_by_host = {}
  for obj in list(saveable_objects):
    treat_as_saveable = isinstance(obj, saveable_object.SaveableObject)
    treat_as_hook = isinstance(obj, saveable_hook.SaveableHook)
    if not (treat_as_saveable or treat_as_hook):
      raise ValueError(
          f"Expected a dictionary of SaveableObjects, got {obj}.")
    # An object may be both a saveable and a hook; register it as both.
    if treat_as_hook:
      self._before_save_callbacks.append(obj.before_save)
      self._after_restore_callbacks.append(obj.after_restore)
    if treat_as_saveable:
      host = saveable_object_util.set_cpu0(obj.device)
      grouped_by_host.setdefault(host, []).append(obj)

  # One _SingleDeviceSaver per host CPU device.
  self._single_device_savers = {
      host: _SingleDeviceSaver(group)
      for host, group in grouped_by_host.items()
  }

  # Wrap each registered saver's save/restore functions so they can be
  # invoked with checkpoint paths later.
  self._registered_savers = {}
  for registered_name, trackables in (registered_savers or {}).items():
    self._registered_savers[registered_name] = (
        _get_mapped_registered_save_fn(
            registration.get_save_function(registered_name), trackables,
            call_with_mapped_captures),
        _get_mapped_registered_restore_fn(
            registration.get_restore_function(registered_name), trackables,
            call_with_mapped_captures))
def tf_function_restore():
  """Runs the restore and returns name -> tensor gated on the restore ops.

  Returns:
    A dict mapping each restore-op name to an identity of `file_prefix`
    that depends on that op.
  """
  op_by_name = restore_fn()
  gated_tensors = {}
  # tf.functions must return tensors, so for each restore op we emit an
  # identity of `file_prefix` under a control dependency on that op — the
  # returned tensor then carries the dependency.
  with ops.device(saveable_object_util.set_cpu0(first_device)):
    for key, restore_op in op_by_name.items():
      with ops.control_dependencies([restore_op]):
        gated_tensors[key] = array_ops.identity(file_prefix)
  return gated_tensors
def __init__(self, saveable_objects, registered_savers=None,
             call_with_mapped_captures=None):
  """Specify a list of `SaveableObject`s to save and restore.

  Args:
    saveable_objects: A list of `SaveableObject`s. Objects extending
      `SaveableObject` will be saved and restored.
    registered_savers: A dictionary mapping `registration.RegisteredSaver`
      namedtuples to a dictionary of named Trackables. The keys of the
      Trackable dictionary are string names that uniquely identify the
      Trackable in the checkpoint.
    call_with_mapped_captures: NOTE(review): forwarded to the mapped
      registered save/restore function builders; exact semantics are not
      visible here — TODO: document.

  Raises:
    ValueError: If two saveables serialize to the same
      (checkpoint key, slice spec) pair.
  """
  saveable_objects = list(saveable_objects)

  # Keep these two data structures so that we can map restored tensors to
  # the Trackable restore functions.
  self._keys_to_restore_fn = {}
  self._restore_fn_to_keys = {}

  # Extract serialized tensors and separate by device.
  tensors_by_device = {}  # device -> checkpoint key -> (slice_spec ->) tensor
  for saveable in saveable_objects:
    tensor_dict = saveable_object_util.saveable_object_to_tensor_dict(
        [saveable])
    restore_fn = saveable_object_util.saveable_object_to_restore_fn(
        [saveable])

    # Divide tensor_dict by device.
    for checkpoint_key, maybe_tensor in tensor_dict.items():
      if not isinstance(maybe_tensor, dict):
        # Make sure that maybe_tensor is structured as {slice_spec -> tensor}.
        maybe_tensor = {"": maybe_tensor}

      for slice_spec, tensor in maybe_tensor.items():
        if (checkpoint_key, slice_spec) in self._keys_to_restore_fn:
          # Fixed: "Recieved" -> "Received", and rejoined the message that
          # had been split by a stray newline inside the string literal.
          raise ValueError(
              "Received multiple tensors with the same checkpoint key and "
              "slice spec. This is invalid because one will overwrite the "
              "other in the checkpoint. This indicates a bug in the "
              "Checkpoint key-generation.")
        self._keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn
        self._restore_fn_to_keys.setdefault(restore_fn, []).append(
            (checkpoint_key, slice_spec))

        # Group each tensor under the host CPU of the device it lives on.
        host_device = saveable_object_util.set_cpu0(tensor.device)
        (tensors_by_device
         .setdefault(host_device, {})
         .setdefault(checkpoint_key, {})[slice_spec]) = tensor

  self._single_device_savers = {
      device: _SingleDeviceSaver(tensor_slice_dict)
      for device, tensor_slice_dict in tensors_by_device.items()}

  # Wrap each registered saver's save/restore functions so they can be
  # invoked with checkpoint paths later.
  self._registered_savers = {}
  if registered_savers:
    for registered_name, trackables in registered_savers.items():
      save_fn = _get_mapped_registered_save_fn(
          registration.get_save_function(registered_name), trackables,
          call_with_mapped_captures)
      restore_fn = _get_mapped_registered_restore_fn(
          registration.get_restore_function(registered_name), trackables,
          call_with_mapped_captures)
      self._registered_savers[registered_name] = (save_fn, restore_fn)
def save(self, file_prefix):
  """Save the saveable objects to a checkpoint with `file_prefix`.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix to
      save under.

  Returns:
    An `Operation`, or None when executing eagerly.
  """
  # IMPLEMENTATION DETAILS: most clients should skip.
  #
  # Suffix for any well-formed "checkpoint_prefix", when sharded.
  # Transformations:
  # * Users pass in "save_path" in save() and restore(). Say "myckpt".
  # * checkpoint_prefix gets fed <save_path><sharded_suffix>.
  #
  # Example:
  #   During runtime, a temporary directory is first created, which contains
  #   files
  #
  #     <train dir>/myckpt_temp/
  #        part-?????-of-?????{.index, .data-00000-of-00001}
  #
  #   Before .save() finishes, they will be (hopefully, atomically) renamed
  #   to
  #
  #     <train dir>/
  #        myckpt{.index, .data-?????-of-?????}
  #
  # Users only need to interact with the user-specified prefix, which is
  # "<train dir>/myckpt" in this case. Save() and Restore() work with the
  # prefix directly, instead of any physical pathname. (On failure and
  # subsequent restore, an outdated and orphaned temporary directory can be
  # safely removed.)
  sharded_suffix = "_temp_%s/part" % uuid.uuid4().hex

  # The joined prefix is a string op; build it on the CPU.
  with ops.device("cpu:0"):
    tmp_checkpoint_prefix = string_ops.string_join(
        [file_prefix, sharded_suffix])

  num_shards = len(self._single_device_savers)
  sharded_saves = []
  sharded_prefixes = []
  num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
  last_device = None
  for shard, (device, saver) in enumerate(
      sorted(self._single_device_savers.items())):
    last_device = device
    # Filename construction happens on the host CPU of the shard's device.
    with ops.device(saveable_object_util.set_cpu0(device)):
      shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard,
                                      num_shards_tensor)
    sharded_prefixes.append(shard_prefix)
    with ops.device(device):
      # _SingleDeviceSaver will use the CPU device when necessary, but
      # initial read operations should be placed on the SaveableObject's
      # device.
      sharded_saves.append(saver.save(shard_prefix))

  with ops.control_dependencies(sharded_saves):
    # Co-locates the merge step with the last device.
    with ops.device(saveable_object_util.set_cpu0(last_device)):
      # V2 format write path consists of a metadata merge step. Once merged,
      # attempts to delete the temporary directory,
      # "<user-fed prefix>_temp".
      return gen_io_ops.merge_v2_checkpoints(
          sharded_prefixes, file_prefix, delete_old_dirs=True)
def save(self, file_prefix, options=None):
  """Save the saveable objects to a checkpoint with `file_prefix`.

  Args:
    file_prefix: A string or scalar string Tensor containing the prefix to
      save under.
    options: Optional `CheckpointOptions` object.

  Returns:
    An `Operation`, or None when executing eagerly.
  """
  options = options or checkpoint_options.CheckpointOptions()
  # Hooks registered at construction time run before any save work starts.
  for callback in self._before_save_callbacks:
    callback()

  # IMPLEMENTATION DETAILS: most clients should skip.
  #
  # Suffix for any well-formed "checkpoint_prefix", when sharded.
  # Transformations:
  # * Users pass in "save_path" in save() and restore(). Say "myckpt".
  # * checkpoint_prefix gets fed <save_path><sharded_suffix>.
  #
  # Example:
  #   During runtime, a temporary directory is first created, which contains
  #   files
  #
  #     <train dir>/myckpt_temp/
  #        part-?????-of-?????{.index, .data-00000-of-00001}
  #
  #   Before .save() finishes, they will be (hopefully, atomically) renamed
  #   to
  #
  #     <train dir>/
  #        myckpt{.index, .data-?????-of-?????}
  #
  # Filesystems with eventual consistency (such as S3), don't need a
  # temporary location. Using a temporary directory in those cases might
  # cause situations where files are not available during copy.
  #
  # Users only need to interact with the user-specified prefix, which is
  # "<train dir>/myckpt" in this case. Save() and Restore() work with the
  # prefix directly, instead of any physical pathname. (On failure and
  # subsequent restore, an outdated and orphaned temporary directory can be
  # safely removed.)
  with ops.device("CPU"):
    # Write directly (".part") on eventually-consistent stores like S3;
    # otherwise stage shards in a unique "_temp_<uuid>/part" directory.
    sharded_suffix = array_ops.where(
        string_ops.regex_full_match(file_prefix, "^s3://.*"),
        constant_op.constant(".part"),
        constant_op.constant("_temp_%s/part" % uuid.uuid4().hex))
    tmp_checkpoint_prefix = string_ops.string_join(
        [file_prefix, sharded_suffix])

  num_shards = len(self._single_device_savers)
  sharded_saves = []
  sharded_prefixes = []
  num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
  last_device = None
  for shard, (device, saver) in enumerate(
      sorted(self._single_device_savers.items())):
    last_device = device
    # Filename construction happens on the host CPU of the shard's device.
    with ops.device(saveable_object_util.set_cpu0(device)):
      shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard,
                                      num_shards_tensor)
    sharded_prefixes.append(shard_prefix)
    with ops.device(device):
      # _SingleDeviceSaver will use the CPU device when necessary, but
      # initial read operations should be placed on the SaveableObject's
      # device.
      sharded_saves.append(saver.save(shard_prefix, options))

  with ops.control_dependencies(sharded_saves):
    # Merge on the io_device if specified, otherwise co-locates the merge op
    # with the last device used.
    merge_device = (options.experimental_io_device or
                    saveable_object_util.set_cpu0(last_device))
    with ops.device(merge_device):
      # V2 format write path consists of a metadata merge step. Once merged,
      # attempts to delete the temporary directory,
      # "<user-fed prefix>_temp".
      return gen_io_ops.merge_v2_checkpoints(sharded_prefixes, file_prefix,
                                             delete_old_dirs=True)