def restore_saveables(
    self,
    tensor_saveables: Dict[str, saveable_object.SaveableObject],
    python_saveables: List[base.PythonStateSaveable],
    registered_savers: Optional[Dict[str, Dict[str, base.Trackable]]] = None
) -> Optional[List[ops.Operation]]:
  """Run or build restore operations for SaveableObjects.

  Args:
    tensor_saveables: `SaveableObject`s which correspond to Tensors.
    python_saveables: `PythonStateSaveable`s which correspond to Python
      values.
    registered_savers: a dict mapping saver name -> object name -> Trackable.
      This argument is not implemented for DTensorCheckpoint.

  Returns:
    When graph building, a list of restore operations, either cached or newly
    created, to restore `tensor_saveables`.
  """
  del registered_savers

  restore_ops = []
  # Eagerly run restorations for Python state.
  reader = None
  for saveable in python_saveables:
    if reader is None:
      # Lazily create the NewCheckpointReader, since this requires file
      # access and we may not have any Python saveables.
      reader = py_checkpoint_reader.NewCheckpointReader(self.save_path_string)
    spec_names = [spec.name for spec in saveable.specs]
    saveable.python_restore(
        [reader.get_tensor(name) for name in spec_names])

  # If we have new SaveableObjects, extract and cache restore ops.
  if tensor_saveables:
    validated_saveables = saveable_object_util.validate_and_slice_inputs(
        tensor_saveables)
    validated_names = set(saveable.name for saveable in validated_saveables)
    if set(tensor_saveables.keys()) != validated_names:
      raise AssertionError(
          ("Saveable keys changed when validating. Got back %s, was "
           "expecting %s") % (tensor_saveables.keys(), validated_names))
    # DTensor change: Use _DSaver that does restore on DTensor with
    # customized DTensorRestoreV2 op.
    new_restore_ops = _DSaver(self._mesh, validated_saveables).restore(
        self.save_path_tensor, self.options)
    if not context.executing_eagerly():
      for name, restore_op in sorted(new_restore_ops.items()):
        restore_ops.append(restore_op)
        assert name not in self.restore_ops_by_name
        self.restore_ops_by_name[name] = restore_op
  return restore_ops
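

# Illustrative usage sketch (not part of the original module): how a caller
# might drive `restore_saveables` during a DTensor checkpoint restore. The
# saver/checkpoint object, saveable collections, and session below are
# hypothetical.
def _example_restore_saveables(dtensor_saver, tensor_saveables,
                               python_saveables, session=None):
  restore_ops = dtensor_saver.restore_saveables(
      tensor_saveables=tensor_saveables,
      python_saveables=python_saveables)
  # When executing eagerly the restores have already run and the list is
  # empty; when graph building, the returned ops still need to be executed.
  if restore_ops and session is not None:
    session.run(restore_ops)
  return restore_ops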
def warm_start(ckpt_to_initialize_from,
               vars_to_warm_start=".*",
               var_name_to_prev_var_name=None):
  """Warm-starts de.Variable using the given settings.

  Args:
    ckpt_to_initialize_from: [Required] A string specifying the directory with
      checkpoint file(s) or path to checkpoint from which to warm-start the
      model parameters.
    vars_to_warm_start: [Optional] One of the following:
      - A regular expression (string) that captures which variables to
        warm-start (see tf.compat.v1.get_collection). This expression will
        only consider variables in the TRAINABLE_VARIABLES collection -- if
        you need to warm-start non-TRAINABLE vars (such as optimizer
        accumulators or batch norm statistics), please use the below option.
      - A list of strings, each a regex scope provided to
        tf.compat.v1.get_collection with GLOBAL_VARIABLES (please see
        tf.compat.v1.get_collection). For backwards compatibility reasons,
        this is separate from the single-string argument type.
      - A list of Variables to warm-start. If you do not have access to the
        `Variable` objects at the call site, please use the above option.
      - `None`, in which case only TRAINABLE variables specified in
        `var_name_to_vocab_info` will be warm-started.
      Defaults to `'.*'`, which warm-starts all variables in the
      TRAINABLE_VARIABLES collection. Note that this excludes variables such
      as accumulators and moving statistics from batch norm.
    var_name_to_prev_var_name: [Optional] Dict mapping a variable name to the
      name of the previously-trained variable it should be restored from.

  Raises:
    ValueError: If a saveable's spec.name does not match the pattern defined
      by de.Variable._make_name.
  """

  def _replace_var_in_spec_name(spec_name, var_name):

    def _replace(m):
      return '{}_mht_{}of{}'.format(var_name, m.groups()[1], m.groups()[2])

    pattern = r'(\w+)_mht_(\d+)of(\d+)'
    # `re.sub` returns its input unchanged when the pattern does not match, so
    # check for a match explicitly before substituting.
    if re.search(pattern, spec_name) is None:
      raise ValueError(
          "Invalid spec name, should match `{}_mht_{}of{}`, given %s" %
          spec_name)
    return re.sub(pattern, _replace, spec_name)

  logging.info("Warm-starting from: {}".format(ckpt_to_initialize_from))

  de_variables = _get_de_variables(vars_to_warm_start)
  if not var_name_to_prev_var_name:
    var_name_to_prev_var_name = {}

  ckpt_file = checkpoint_utils._get_checkpoint_filename(
      ckpt_to_initialize_from)
  assign_ops = []
  for variable in de_variables:
    var_name = variable.name
    prev_var_name = var_name_to_prev_var_name.get(var_name)
    if prev_var_name:
      logging.debug("Warm-start variable: {}: prev_var_name: {}".format(
          var_name, prev_var_name or "Unchanged"))
    else:
      prev_var_name = var_name

    saveables = saveable_object_util.validate_and_slice_inputs([variable])
    for saveable in saveables:
      restore_specs = []
      for spec in saveable.specs:
        restore_specs.append((_replace_var_in_spec_name(spec.name,
                                                        prev_var_name),
                              spec.slice_spec, spec.dtype))
      names, slices, dtypes = zip(*restore_specs)
      # Load tensors in cuckoo_hashtable op's device.
      with ops.colocate_with(saveable.op._resource_handle.op):
        saveable_tensors = io_ops.restore_v2(ckpt_file, names, slices, dtypes)
      assign_ops.append(saveable.restore(saveable_tensors, None))

  return control_flow_ops.group(assign_ops)
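

# Illustrative usage sketch (not part of the original module): warm-starting
# all dynamic-embedding variables whose names contain "embedding" from a
# previous checkpoint, with one variable renamed. The checkpoint path, regex,
# and name mapping below are hypothetical.
def _example_warm_start(session):
  warm_start_op = warm_start(
      ckpt_to_initialize_from="/tmp/prev_model",
      vars_to_warm_start=".*embedding.*",
      var_name_to_prev_var_name={"new_embedding": "old_embedding"})
  # warm_start returns a grouped assign op; run it once after variable
  # initialization and before training starts.
  session.run(warm_start_op)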
def _build_internal(self,
                    names_to_saveables,
                    reshape=False,
                    sharded=False,
                    max_to_keep=5,
                    keep_checkpoint_every_n_hours=10000.0,
                    name=None,
                    restore_sequentially=False,
                    filename="model",
                    build_save=True,
                    build_restore=True):
  """build() with option to only perform save and restore."""
  if not context.executing_eagerly() and (not build_save or not build_restore):
    raise ValueError("save and restore operations need to be built together "
                     "when eager execution is not enabled.")

  saveables = saveable_object_util.validate_and_slice_inputs(
      names_to_saveables)
  if max_to_keep is None:
    max_to_keep = 0

  with ops.name_scope(name, "save",
                      [saveable.op for saveable in saveables]) as name:
    # Add a placeholder string tensor for the filename.
    filename_tensor = array_ops.placeholder_with_default(
        filename or "model", shape=(), name="filename")
    # Keep the name "Const" for backwards compatibility.
    filename_tensor = array_ops.placeholder_with_default(
        filename_tensor, shape=(), name="Const")

    # Add the save ops.
    if sharded:
      per_device = self._GroupByDevices(saveables)
      if build_save:
        op_list = []
        with tf.name_scope("Save_Weight_Update_Sharding"):
          grad_and_var_items = util.get_all_grad_item()
          for item in grad_and_var_items:
            if item.var in names_to_saveables:
              rank_id = item.root_rank_id
              if rank_id >= 0:
                with tf.get_default_graph().control_dependencies(op_list):
                  out_var = hccl_ops.broadcast([item.var], rank_id, 2, rank_id)
                op_list.append(out_var[0].op)
        if len(op_list) > 0:
          with tf.get_default_graph().control_dependencies(op_list):
            save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
        else:
          save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
      if build_restore:
        restore_op = self._AddShardedRestoreOps(filename_tensor, per_device,
                                                restore_sequentially, reshape)
    else:
      if build_save:
        op_list = []
        with tf.name_scope("Save_Weight_Update_Sharding"):
          grad_and_var_items = util.get_all_grad_item()
          for item in grad_and_var_items:
            if item.var in names_to_saveables:
              rank_id = item.root_rank_id
              if rank_id >= 0:
                with tf.get_default_graph().control_dependencies(op_list):
                  out_var = hccl_ops.broadcast([item.var], rank_id, 2, rank_id)
                op_list.append(out_var[0].op)
        if len(op_list) > 0:
          with tf.get_default_graph().control_dependencies(op_list):
            save_tensor = self._AddSaveOps(filename_tensor, saveables)
        else:
          save_tensor = self._AddSaveOps(filename_tensor, saveables)
      if build_restore:
        restore_op = self._AddRestoreOps(filename_tensor, saveables,
                                         restore_sequentially, reshape)

  # In the following use case, it's possible to have restore_ops be called
  # something else:
  # - Build inference graph and export a meta_graph.
  # - Import the inference meta_graph.
  # - Extend the inference graph to a train graph.
  # - Export a new meta_graph.
  # Now the second restore_op will be called "restore_all_1".
  # As such, comment out the assert for now until we know whether supporting
  # such usage model makes sense.
  #
  # assert restore_op.name.endswith("restore_all"), restore_op.name

  if context.executing_eagerly():
    # Store the tensor values to the tensor_names.
    save_tensor_name = save_tensor.numpy() if build_save else ""
    return saver_pb2.SaverDef(
        filename_tensor_name=filename_tensor.numpy(),
        save_tensor_name=save_tensor_name,
        restore_op_name="",
        max_to_keep=max_to_keep,
        sharded=sharded,
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
        version=self._write_version)
  else:
    graph = ops.get_default_graph()
    # Do some sanity checking on collections containing PartitionedVariables.
    # If a saved collection has a PartitionedVariable, the GraphDef needs to
    # include concat ops to get the value (or there'll be a lookup error on
    # load).
    check_collection_list = graph.get_all_collection_keys()
    for collection_type in check_collection_list:
      for element in graph.get_collection(collection_type):
        if isinstance(element, variables.PartitionedVariable):
          try:
            graph.get_operation_by_name(element.name)
          except KeyError:
            # Create a concat op for this PartitionedVariable. The user may
            # not need it, but we'll try looking it up on MetaGraph restore
            # since it's in a collection.
            element.as_tensor()
    return saver_pb2.SaverDef(
        filename_tensor_name=filename_tensor.name,
        save_tensor_name=save_tensor.name,
        restore_op_name=restore_op.name,
        max_to_keep=max_to_keep,
        sharded=sharded,
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
        version=self._write_version)
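

# Illustrative usage sketch (not part of the original module): building a
# SaverDef with the builder above in graph mode. The builder instance and the
# name-to-saveable mapping passed in are hypothetical; note that build_save
# and build_restore must both be True when eager execution is not enabled.
def _example_build_saver_def(builder, names_to_saveables):
  return builder._build_internal(
      names_to_saveables,
      sharded=True,  # group saveables by device and emit sharded save ops
      max_to_keep=5,
      build_save=True,
      build_restore=True)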