Example 1
    def restore_saveables(
        self,
        tensor_saveables: Dict[str, saveable_object.SaveableObject],
        python_saveables: List[base.PythonStateSaveable],
        registered_savers: Optional[Dict[str, Dict[str,
                                                   base.Trackable]]] = None
    ) -> Optional[List[ops.Operation]]:
        """Run or build restore operations for SaveableObjects.

        Args:
          tensor_saveables: `SaveableObject`s which correspond to Tensors.
          python_saveables: `PythonStateSaveable`s which correspond to Python
            values.
          registered_savers: a dict mapping saver name -> object name -> Trackable.
            This argument is not implemented for DTensorCheckpoint.

        Returns:
          When graph building, a list of restore operations, either cached or newly
          created, to restore `tensor_saveables`.
        """
        del registered_savers

        restore_ops = []
        # Eagerly run restorations for Python state.
        reader = None
        for saveable in python_saveables:
            if reader is None:
                # Lazily create the NewCheckpointReader, since this requires file access
                # and we may not have any Python saveables.
                reader = py_checkpoint_reader.NewCheckpointReader(
                    self.save_path_string)
            spec_names = [spec.name for spec in saveable.specs]
            saveable.python_restore(
                [reader.get_tensor(name) for name in spec_names])

        # If we have new SaveableObjects, extract and cache restore ops.
        if tensor_saveables:
            validated_saveables = saveable_object_util.validate_and_slice_inputs(
                tensor_saveables)
            validated_names = set(saveable.name
                                  for saveable in validated_saveables)
            if set(tensor_saveables.keys()) != validated_names:
                raise AssertionError(
                    ("Saveable keys changed when validating. Got back %s, was "
                     "expecting %s") %
                    (tensor_saveables.keys(), validated_names))
            # DTensor change: Use _DSaver that does restore on DTensor with
            # customized DTensorRestoreV2 op.
            new_restore_ops = _DSaver(self._mesh, validated_saveables).restore(
                self.save_path_tensor, self.options)
            if not context.executing_eagerly():
                for name, restore_op in sorted(new_restore_ops.items()):
                    restore_ops.append(restore_op)
                    assert name not in self.restore_ops_by_name
                    self.restore_ops_by_name[name] = restore_op
        return restore_ops
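
For context, a small, hedged sketch of the reader used in the Python-state branch above: py_checkpoint_reader.NewCheckpointReader is exposed publicly as tf.train.load_checkpoint, which returns the same reader type with get_tensor. The checkpoint path below is hypothetical, and the snippet only illustrates reading tensors by name, i.e. the eager half of restore_saveables.

import tensorflow as tf

# Write a small checkpoint so the reader has something to load (hypothetical path).
ckpt = tf.train.Checkpoint(v=tf.Variable([1.0, 2.0]))
save_path = ckpt.save("/tmp/demo_ckpt")

# Same kind of reader that restore_saveables creates lazily for python_saveables.
reader = tf.train.load_checkpoint(save_path)
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape, reader.get_tensor(name))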
def warm_start(ckpt_to_initialize_from,
               vars_to_warm_start=".*",
               var_name_to_prev_var_name=None):
  """Warm-starts de.Variable using the given settings.

    Args:
      ckpt_to_initialize_from: [Required] A string specifying the directory with
        checkpoint file(s) or path to checkpoint from which to warm-start the
        model parameters.
      vars_to_warm_start: [Optional] One of the following:
        - A regular expression (string) that captures which variables to
          warm-start (see tf.compat.v1.get_collection).  This expression will only
          consider variables in the TRAINABLE_VARIABLES collection -- if you need
          to warm-start non-TRAINABLE vars (such as optimizer accumulators or
          batch norm statistics), please use the below option.
        - A list of strings, each a regex scope provided to
          tf.compat.v1.get_collection with GLOBAL_VARIABLES (please see
          tf.compat.v1.get_collection).  For backwards compatibility reasons,
          this is separate from the single-string argument type.
        - A list of Variables to warm-start.  If you do not have access to the
          `Variable` objects at the call site, please use the above option.
        - `None`, in which case only TRAINABLE variables specified in
          `var_name_to_vocab_info` will be warm-started.
        Defaults to `'.*'`, which warm-starts all variables in the
        TRAINABLE_VARIABLES collection.  Note that this excludes variables such
        as accumulators and moving statistics from batch norm.
      var_name_to_prev_var_name: [Optional] Dict mapping each variable name to
        the name under which it was stored in `ckpt_to_initialize_from`, for
        variables that were renamed since the checkpoint was written.

    Raises:
      ValueError: If a saveable's spec.name does not match the pattern
        defined by de.Variable._make_name.
    """

  def _replace_var_in_spec_name(spec_name, var_name):
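    # Illustrative example (hypothetical spec names): if the variable is
    # currently named "new_emb" but was saved under "old_emb", a spec name such
    # as "new_emb_mht_1of4-keys" is rewritten to "old_emb_mht_1of4-keys" so the
    # restore below can find the corresponding tensors in the checkpoint.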

    def _replace(m):
      return '{}_mht_{}of{}'.format(var_name, m.groups()[1], m.groups()[2])

    # re.sub never returns None; use re.subn so a failed match is detected.
    out, num_subs = re.subn(r'(\w+)_mht_(\d+)of(\d+)', _replace, spec_name)
    if num_subs == 0:
      raise ValueError(
          "Invalid spec name `%s`, expected a name matching "
          "`{var_name}_mht_{shard}of{num_shards}`." % spec_name)
    return out

  logging.info("Warm-starting from: {}".format(ckpt_to_initialize_from))

  de_variables = _get_de_variables(vars_to_warm_start)
  if not var_name_to_prev_var_name:
    var_name_to_prev_var_name = {}

  ckpt_file = checkpoint_utils._get_checkpoint_filename(ckpt_to_initialize_from)
  assign_ops = []
  for variable in de_variables:
    var_name = variable.name
    prev_var_name = var_name_to_prev_var_name.get(var_name)
    if prev_var_name:
      logging.debug("Warm-start variable: {}: prev_var_name: {}".format(
          var_name, prev_var_name or "Unchanged"))
    else:
      prev_var_name = var_name

    saveables = saveable_object_util.validate_and_slice_inputs([variable])
    for saveable in saveables:
      restore_specs = []
      for spec in saveable.specs:
        restore_specs.append((_replace_var_in_spec_name(spec.name,
                                                        prev_var_name),
                              spec.slice_spec, spec.dtype))

      names, slices, dtypes = zip(*restore_specs)
      # Load tensors on the cuckoo_hashtable op's device.
      with ops.colocate_with(saveable.op._resource_handle.op):
        saveable_tensors = io_ops.restore_v2(ckpt_file, names, slices, dtypes)
        assign_ops.append(saveable.restore(saveable_tensors, None))

  return control_flow_ops.group(assign_ops)
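
A short, hedged sketch of the core pattern warm_start assembles above: look up a variable's value in an existing checkpoint under a (possibly different) previous name with RestoreV2, assign it, and group everything into one warm-start op. The checkpoint prefix and previous name are hypothetical, and the de.Variable-specific spec renaming is omitted.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

w = tf.get_variable("w", shape=[2, 2])
prev_name_in_ckpt = "prev/w"  # hypothetical name under which `w` was saved

# RestoreV2 is the same op warm_start reaches through io_ops.restore_v2.
restored = tf.raw_ops.RestoreV2(
    prefix="/tmp/prev_run/model.ckpt",  # hypothetical checkpoint prefix
    tensor_names=[prev_name_in_ckpt],
    shape_and_slices=[""],
    dtypes=[w.dtype.base_dtype])
warm_start_op = tf.group(w.assign(restored[0]))

with tf.Session() as sess:
  sess.run(warm_start_op)  # assumes the checkpoint exists and contains "prev/w"
  print(sess.run(w))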
Example 3
    def _build_internal(self,
                        names_to_saveables,
                        reshape=False,
                        sharded=False,
                        max_to_keep=5,
                        keep_checkpoint_every_n_hours=10000.0,
                        name=None,
                        restore_sequentially=False,
                        filename="model",
                        build_save=True,
                        build_restore=True):
        """build() with option to only perform save and restore."""
        if not context.executing_eagerly() and (not build_save or
                                                not build_restore):
            raise ValueError("save and restore operations need to be built together "
                            " when eager execution is not enabled.")

        saveables = saveable_object_util.validate_and_slice_inputs(
            names_to_saveables)
        if max_to_keep is None:
            max_to_keep = 0

        with ops.name_scope(name, "save",
                            [saveable.op for saveable in saveables]) as name:
            # Add a placeholder string tensor for the filename.
            filename_tensor = array_ops.placeholder_with_default(
                filename or "model", shape=(), name="filename")
            # Keep the name "Const" for backwards compatibility.
            filename_tensor = array_ops.placeholder_with_default(
                filename_tensor, shape=(), name="Const")

            # Add the save ops.
            if sharded:
                per_device = self._GroupByDevices(saveables)
                if build_save:
                    op_list = []
                    with tf.name_scope("Save_Weight_Update_Sharding"):
                        grad_and_var_items = util.get_all_grad_item()
                        for item in grad_and_var_items:
                            if item.var in names_to_saveables:
                                rank_id = item.root_rank_id
                                if rank_id >= 0:
                                    with tf.get_default_graph().control_dependencies(op_list):
                                        out_var = hccl_ops.broadcast([item.var], rank_id, 2, rank_id)
                                    op_list.append(out_var[0].op)
                    if op_list:
                        with tf.get_default_graph().control_dependencies(op_list):
                            save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
                    else:
                        save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
                if build_restore:
                    restore_op = self._AddShardedRestoreOps(
                        filename_tensor, per_device, restore_sequentially, reshape)
            else:
                if build_save:
                    op_list = []
                    with tf.name_scope("Save_Weight_Update_Sharding"):
                        grad_and_var_items = util.get_all_grad_item()
                        for item in grad_and_var_items:
                            if item.var in names_to_saveables:
                                rank_id = item.root_rank_id
                                if rank_id >= 0:
                                    with tf.get_default_graph().control_dependencies(op_list):
                                        out_var = hccl_ops.broadcast([item.var], rank_id, 2, rank_id)
                                    op_list.append(out_var[0].op)
                    if op_list:
                        with tf.get_default_graph().control_dependencies(op_list):
                            save_tensor = self._AddSaveOps(filename_tensor, saveables)
                    else:
                        save_tensor = self._AddSaveOps(filename_tensor, saveables)
                if build_restore:
                    restore_op = self._AddRestoreOps(
                        filename_tensor, saveables, restore_sequentially, reshape)

        # In the following use case, it's possible to have restore_ops be called
        # something else:
        # - Build inference graph and export a meta_graph.
        # - Import the inference meta_graph
        # - Extend the inference graph to a train graph.
        # - Export a new meta_graph.
        # Now the second restore_op will be called "restore_all_1".
        # As such, comment out the assert for now until we know whether supporting
        # such usage model makes sense.
        #
        # assert restore_op.name.endswith("restore_all"), restore_op.name
        if context.executing_eagerly():
            # Store the tensor values to the tensor_names.
            save_tensor_name = save_tensor.numpy() if build_save else ""
            return saver_pb2.SaverDef(
                filename_tensor_name=filename_tensor.numpy(),
                save_tensor_name=save_tensor_name,
                restore_op_name="",
                max_to_keep=max_to_keep,
                sharded=sharded,
                keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
                version=self._write_version)
        else:
            graph = ops.get_default_graph()
            # Do some sanity checking on collections containing
            # PartitionedVariables. If a saved collection has a PartitionedVariable,
            # the GraphDef needs to include concat ops to get the value (or there'll
            # be a lookup error on load).
            check_collection_list = graph.get_all_collection_keys()
            for collection_type in check_collection_list:
                for element in graph.get_collection(collection_type):
                    if isinstance(element, variables.PartitionedVariable):
                        try:
                            graph.get_operation_by_name(element.name)
                        except KeyError:
                            # Create a concat op for this PartitionedVariable. The user may
                            # not need it, but we'll try looking it up on MetaGraph restore
                            # since it's in a collection.
                            element.as_tensor()
            return saver_pb2.SaverDef(
                filename_tensor_name=filename_tensor.name,
                save_tensor_name=save_tensor.name,
                restore_op_name=restore_op.name,
                max_to_keep=max_to_keep,
                sharded=sharded,
                keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
                version=self._write_version)
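
For orientation, a hedged usage sketch of the public entry point that drives this builder: constructing a tf.compat.v1.train.Saver runs the builder's _build_internal, and the resulting SaverDef (filename placeholder, save tensor, restore op) backs save() and restore(). The NPU-specific Save_Weight_Update_Sharding broadcast path (hccl_ops) is not exercised here, and the output path is hypothetical; this shows only the standard graph-mode flow.

import os

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

v = tf.get_variable("v", shape=[3], initializer=tf.zeros_initializer())
saver = tf.train.Saver(max_to_keep=5, sharded=False)  # builder runs here

saver_def = saver.as_saver_def()  # the SaverDef assembled by _build_internal
print(saver_def.filename_tensor_name, saver_def.save_tensor_name,
      saver_def.restore_op_name)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    os.makedirs("/tmp/demo", exist_ok=True)       # hypothetical output directory
    prefix = saver.save(sess, "/tmp/demo/model")  # writes checkpoint files
    saver.restore(sess, prefix)                   # runs the cached restore op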