def __init__(self, cluster_resolver, checkpoint, checkpoint_dir):
    """Creates the failure handler.

    Args:
      cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`. You
        may also get it through the `cluster_resolver` attribute of the
        strategy in use.
      checkpoint: a `tf.train.Checkpoint` that will be saved upon preemption and
        loaded upon restart by the `CoordinatedCheckpointManager` API
        automatically.
      checkpoint_dir: a directory where the `CoordinatedCheckpointManager` saves
        and manages checkpoints. `CoordinatedCheckpointManager` will create a
        `tf.train.CheckpointManager` to manage the passed-in `checkpoint`. Since
        only one `tf.train.CheckpointManager` should be active in a particular
        directory at a time, this `checkpoint_dir` arg should preferably be
        separate from the directory where the user saves checkpoints for
        purposes other than fault tolerance.
    """
    self._cluster_resolver = cluster_resolver
    self._checkpoint = checkpoint
    self._id_in_cluster = str(
        multi_worker_util.id_in_cluster(
            self._cluster_resolver.cluster_spec(),
            self._cluster_resolver.task_type,
            self._cluster_resolver.task_id))

    # The number of calls to `CoordinatedCheckpointManager.run` when the latest
    # checkpoint was saved.
    self._checkpointed_runs = variables.Variable(
        initial_value=constant_op.constant(0, dtype=dtypes.int64),
        trainable=False,
        name=_ITERATION_VARIABLE)
    if not hasattr(self._checkpoint,
                   _ITERATION_VARIABLE):
      setattr(self._checkpoint, _ITERATION_VARIABLE,
              self._checkpointed_runs)

    # Make CheckpointManagers. MultiWorkerMirroredStrategy requires different
    # setup on chief and on other workers.
    self._read_checkpoint_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory=checkpoint_dir, max_to_keep=1)
    if multi_worker_util.is_chief(
        cluster_spec=cluster_resolver.cluster_spec(),
        task_type=cluster_resolver.task_type,
        task_id=cluster_resolver.task_id):
      self._write_checkpoint_manager = self._read_checkpoint_manager
    else:
      self._write_checkpoint_manager = checkpoint_management.CheckpointManager(
          checkpoint,
          _mwms_write_checkpoint_dir(checkpoint_dir, cluster_resolver.task_type,
                                     cluster_resolver.task_id,
                                     cluster_resolver.cluster_spec()),
          max_to_keep=1)

    # Restore from the latest checkpoint in `checkpoint_dir`, if there is one,
    # so that training resumes from the last saved state.
    self._read_checkpoint_manager.restore_or_initialize()

    # An internal step counter that's restored to the checkpointed number of
    # runs (`_checkpointed_runs`) when training is restored. It increments by
    # one every time `CoordinatedCheckpointManager.run` is called. Note that in
    # this case, the user must pass a single-step training function to
    # `CoordinatedCheckpointManager.run` instead of a multiple-step one.
    self._run_counter = self._checkpointed_runs.numpy()

    # The worker itself has received a preemption signal.
    self._received_own_sigterm = threading.Event()

    # Some cluster member (possibly this worker) has received a preemption
    # signal, and the step number at which to save a checkpoint has been agreed
    # on across workers.
    self._received_sigterm_and_step = threading.Event()

    # TODO(wxinyi): Enforce that only one instance of this class is created
    # per program.
    # TODO(wxinyi): make the thread non-daemon.
    threading.Thread(target=self._wait_for_signal, daemon=True).start()

    self._platform_device = gce_util.detect_platform()
    if self._platform_device is gce_util.PlatformDevice.GCE_GPU:
      self._start_polling_for_gce_signal()
      self._exit_code = gce_util._RESTARTABLE_EXIT_CODE
    elif self._platform_device is gce_util.PlatformDevice.INTERNAL:
      self._start_watching_for_signal()
      self._exit_code = _RESTARTABLE_EXIT_CODE
    else:
      raise NotImplementedError('CoordinatedCheckpointManager is only supported'
                                ' for MultiWorkerMirroredStrategy with GPU.')
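
A minimal construction sketch for this version (a hedged illustration, not part of the source above): the import path, strategy setup, and checkpoint directory are assumptions, and the constructor only succeeds on a supported platform (Borg internally, or a GCE GPU VM).

# Hedged usage sketch. The import path below is an assumption about where this
# class lives inside TensorFlow; adjust it to the actual module if it differs.
import tensorflow as tf
from tensorflow.python.distribute.failure_handling import failure_handling

strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
checkpoint = tf.train.Checkpoint(model=model)

# `cluster_resolver` can come straight from the strategy, as the docstring notes.
fh_manager = failure_handling.CoordinatedCheckpointManager(
    cluster_resolver=strategy.cluster_resolver,
    checkpoint=checkpoint,
    checkpoint_dir='/tmp/fault_tolerance_ckpts')  # illustrative path

# Training then goes through `fh_manager.run`, one step per call, so that a
# checkpoint can be saved at an aligned step when a preemption notice arrives.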
Example #2
    def __init__(self,
                 cluster_resolver,
                 checkpoint,
                 checkpoint_dir,
                 termination_config=TerminationConfig()):
        """Creates the failure handler.

    Args:
      cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`. You
        may also get it through the `cluster_resolver` attribute of the strategy
        in use.
      checkpoint: a `tf.train.Checkpoint` that will be saved upon preemption and
        loaded upon restart by the `WorkerPreemptionHandler` API automatically.
      checkpoint_dir: a directory for the `WorkerPreemptionHandler` to play with
        checkpoints. `WorkerPreemptionHandler` will create a
        `tf.train.CheckpointManager` to manage the passed-in `checkpoint`. Since
        only one `tf.train.CheckpointManager` should be active in a particular
        directory at a time, this `checkpoint_dir` arg should preferably be
        separated from where the user saves their checkpoint for non-fault
        tolerance purpose.
      termination_config: a `TerminationConfig` object to configure for a
        platform other than Google Borg or GCP.
    """
        self._cluster_resolver = cluster_resolver
        self._checkpoint = checkpoint
        self._id_in_cluster = str(
            multi_worker_util.id_in_cluster(
                self._cluster_resolver.cluster_spec(),
                self._cluster_resolver.task_type,
                self._cluster_resolver.task_id))

        # The number of calls to `WorkerPreemptionHandler.run` when the latest
        # checkpoint was saved.
        self._checkpointed_runs = variables.Variable(
            initial_value=constant_op.constant(0, dtype=dtypes.int64),
            trainable=False,
            name=_ITERATION_VARIABLE)
        if not hasattr(self._checkpoint, _ITERATION_VARIABLE):
            setattr(self._checkpoint, _ITERATION_VARIABLE,
                    self._checkpointed_runs)

        # Make CheckpointManagers. MultiWorkerMirroredStrategy requires different
        # setup on chief and on other workers.
        self._read_checkpoint_manager = checkpoint_management.CheckpointManager(
            checkpoint, directory=checkpoint_dir, max_to_keep=1)
        if multi_worker_util.is_chief(
                cluster_spec=cluster_resolver.cluster_spec(),
                task_type=cluster_resolver.task_type,
                task_id=cluster_resolver.task_id):
            self._write_checkpoint_manager = self._read_checkpoint_manager
        else:
            self._write_checkpoint_manager = checkpoint_management.CheckpointManager(
                checkpoint,
                _mwms_write_checkpoint_dir(checkpoint_dir,
                                           cluster_resolver.task_type,
                                           cluster_resolver.task_id,
                                           cluster_resolver.cluster_spec()),
                max_to_keep=1)

        self._read_checkpoint_manager.restore_or_initialize()

        # Grace-period countdown flag. Set to True on all workers once they
        # finish the timed save of a checkpoint. Once this phase is entered, new
        # preemption/maintenance notices are not handled, since the whole cluster
        # goes down when the worker that first initiated the grace period goes
        # down.
        self._final_checkpoint_countdown = False

        self._estimated_run_time = 0

        # An internal step counter that's restored to the checkpointed number of
        # runs (`_checkpointed_runs`) when training is restored. It increments by
        # one every time `WorkerPreemptionHandler.run` is called. Note that in
        # this case, the user must pass a single-step training function to
        # `WorkerPreemptionHandler.run` instead of a multiple-step one.
        self._run_counter = self._checkpointed_runs.numpy()

        # The worker itself has received a preemption signal.
        self._received_own_sigterm = threading.Event()

        # Some cluster member (possibly this worker) has received a preemption
        # signal, and the step number at which to save a checkpoint has been
        # agreed on across workers.
        self._received_checkpoint_step = threading.Event()

        self._platform_device = gce_util.detect_platform()

        completed_termination_config = _complete_config_for_environement(
            self._platform_device, termination_config)
        self._termination_watcher_function = completed_termination_config.termination_watcher_function
        self._exit_fn = completed_termination_config.exit_fn
        self._grace_period = completed_termination_config.time_till_termination

        # When training is interrupted, we explicitly call the cleanup methods
        # for the thread watching for the local worker's termination signal and
        # the thread watching for cluster-wide information before we save a
        # checkpoint and exit. When training finishes without any interruption,
        # we rely on __del__ to clean up. However, there is no guarantee when or
        # whether __del__ is executed, so we make the threads daemon threads to
        # avoid them preventing the program from exiting.
        self._cluster_wise_termination_watcher_thread = threading.Thread(
            target=self._watch_step_to_save_key,
            name='PeerTerminationWatcher-%s' % self._id_in_cluster,
            daemon=True)
        logging.info('Start watcher for peer\'s signal.')
        self._cluster_wise_termination_watcher_thread.start()

        self._poll_termination_signal_thread = None

        if completed_termination_config.termination_watcher_function:
            self._start_polling_for_termination_signal()
        else:
            self._start_watching_for_signal()
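
A hedged configuration sketch for this version, showing how a platform other than Borg or GCP could plug in through `termination_config`. The `TerminationConfig` keyword names are assumed to mirror the attributes read from the completed config above (`termination_watcher_function`, `exit_fn`, `time_till_termination`); the watcher body, resolver, checkpoint contents, and paths are placeholders.

# Hedged usage sketch under the assumptions stated above.
import sys
import tensorflow as tf

def watch_for_my_platform_signal():
  # Placeholder: block until the platform reports an upcoming preemption,
  # then return True so the handler can start a coordinated checkpoint save.
  ...

resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
checkpoint = tf.train.Checkpoint(
    model=tf.keras.Sequential([tf.keras.layers.Dense(1)]))

handler = WorkerPreemptionHandler(
    cluster_resolver=resolver,
    checkpoint=checkpoint,
    checkpoint_dir='/tmp/fault_tolerance_ckpts',  # illustrative path
    termination_config=TerminationConfig(
        termination_watcher_function=watch_for_my_platform_signal,
        exit_fn=lambda: sys.exit(42),      # exit code treated as restartable (illustrative)
        time_till_termination=30))         # grace period before termination (illustrative)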
Example #3
    def __init__(self,
                 cluster_resolver,
                 checkpoint_or_checkpoint_manager,
                 checkpoint_dir=None,
                 termination_config=None):
        """Creates the `PreemptionCheckpointHandler`.

    Args:
      cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`
        object. You may also obtain it through the `cluster_resolver` attribute
        of the distribution strategy in use.
      checkpoint_or_checkpoint_manager: a `tf.train.CheckpointManager` or a
        `tf.train.Checkpoint`. If you are using a `tf.train.CheckpointManager`
        to manage checkpoints outside the `PreemptionCheckpointHandler` for
        backup purpose as well, pass it as `checkpoint_or_checkpoint_manager`
        argument. Otherwise, pass a `tf.train.Checkpoint` and the
        `PreemptionCheckpointHandler` will create
        a `tf.train.CheckpointManager` to manage it in the `checkpoint_dir`.
      checkpoint_dir: a directory where the `PreemptionCheckpointHandler` saves
        and restores checkpoints. When a `PreemptionCheckpointHandler` is
        created, the latest checkpoint in the `checkpoint_dir` will be restored.
        (This is not needed if a `tf.train.CheckpointManager` instead of a
        `tf.train.Checkpoint` is passed as the
        `checkpoint_or_checkpoint_manager` argument.)
      termination_config: optional, a
        `tf.distribute.experimental.TerminationConfig` object to configure for a
        platform other than Google Borg or GCP.
    """
        self._cluster_resolver = cluster_resolver
        if isinstance(checkpoint_or_checkpoint_manager,
                      checkpoint_lib.Checkpoint) and not checkpoint_dir:
            raise errors.InvalidArgumentError(
                'When a checkpoint is passed, a checkpoint_dir must be passed '
                'as well.')
        self._id_in_cluster = str(
            multi_worker_util.id_in_cluster(
                self._cluster_resolver.cluster_spec(),
                self._cluster_resolver.task_type,
                self._cluster_resolver.task_id))

        # The number of calls to `PreemptionCheckpointHandler.run` when the latest
        # checkpoint was saved.
        self._checkpointed_runs = variables.Variable(
            initial_value=constant_op.constant(0, dtype=dtypes.int64),
            trainable=False,
            name=_ITERATION_VARIABLE)

        self._maybe_create_checkpoint_manager(checkpoint_or_checkpoint_manager,
                                              checkpoint_dir, cluster_resolver)

        if not hasattr(self._write_checkpoint_manager._checkpoint,
                       _ITERATION_VARIABLE):
            setattr(self._write_checkpoint_manager._checkpoint,
                    _ITERATION_VARIABLE, self._checkpointed_runs)

        if not hasattr(self._read_checkpoint_manager._checkpoint,
                       _ITERATION_VARIABLE):
            setattr(self._read_checkpoint_manager._checkpoint,
                    _ITERATION_VARIABLE, self._checkpointed_runs)

        self._read_checkpoint_manager.restore_or_initialize()

        # Grace-period countdown flag. Set to True on all workers once they
        # finish the timed save of a checkpoint. Once this phase is entered, new
        # preemption/maintenance notices are not handled, since the whole cluster
        # goes down when the worker that first initiated the grace period goes
        # down.
        self._final_checkpoint_countdown = False

        self._estimated_run_time = 0

        # An internal step counter that's restored to the checkpointed number of
        # runs (`_checkpointed_runs`) when training is restored. It increments by
        # one every time `PreemptionCheckpointHandler.run` is called. Note that
        # in this case, the user must pass a single-step training function to
        # `PreemptionCheckpointHandler.run` instead of a multiple-step one.
        self._run_counter = self._checkpointed_runs.numpy()

        # The worker itself has received a preemption signal.
        self._received_own_sigterm = threading.Event()

        # Some cluster member (possibly this worker) has received a preemption
        # signal, and the step number at which to save a checkpoint has been
        # agreed on across workers.
        self._received_checkpoint_step = threading.Event()

        self._platform_device = gce_util.detect_platform()

        if self._platform_device in (gce_util.PlatformDevice.GCE_TPU,
                                     gce_util.PlatformDevice.GCE_CPU):
            # While running MultiWorkerMirroredStrategy training with GPUs and
            # with CPUs is the same on Borg, GCE CPU VMs and GPU VMs differ in
            # terms of live migration, grace period, etc. We can make it work
            # upon request.
            raise NotImplementedError(
                'PreemptionCheckpointHandler does not support '
                'training with TPU or CPU device on GCP.')

        completed_termination_config = _complete_config_for_environment(
            self._platform_device, termination_config)
        self._termination_watcher_fn = completed_termination_config.termination_watcher_fn
        self._exit_fn = completed_termination_config.exit_fn
        self._grace_period = completed_termination_config.grace_period

        # When training is interrupted, we explicitly call the cleanup methods
        # for the thread watching for the local worker's termination signal and
        # the thread watching for cluster-wide information before we save a
        # checkpoint and exit. When training finishes without any interruption,
        # we rely on __del__ to clean up. However, there is no guarantee when or
        # whether __del__ is executed, so we make the threads daemon threads to
        # avoid them preventing the program from exiting.
        self._cluster_wise_termination_watcher_thread = threading.Thread(
            target=self._watch_step_to_save_key,
            name='PeerTerminationWatcher-%s' % self._id_in_cluster,
            daemon=True)
        logging.info('Start watcher for peer\'s signal.')
        self._cluster_wise_termination_watcher_thread.start()

        self._poll_termination_signal_thread = None

        if completed_termination_config.termination_watcher_fn:
            self._start_polling_for_termination_signal()
        else:
            self._start_watching_for_signal()
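
A hedged end-to-end sketch for this version, which matches the public `tf.distribute.experimental.PreemptionCheckpointHandler` API: pass a `tf.train.CheckpointManager` directly (so no `checkpoint_dir` is needed) and call `run` with a single-step training function. The model, dataset, directory, and step count are illustrative.

# Hedged usage sketch under the assumptions stated above.
import tensorflow as tf

strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  optimizer = tf.keras.optimizers.SGD()

checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint, directory='/tmp/training_ckpts', max_to_keep=1)  # illustrative

# A CheckpointManager is passed, so checkpoint_dir is omitted.
handler = tf.distribute.experimental.PreemptionCheckpointHandler(
    strategy.cluster_resolver, checkpoint_manager)

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.ones((64, 4)), tf.ones((64, 1)))).batch(8).repeat()
dist_iterator = iter(strategy.experimental_distribute_dataset(dataset))

@tf.function
def train_step(iterator):
  def step_fn(inputs):
    features, labels = inputs
    with tf.GradientTape() as tape:
      loss = tf.reduce_mean(tf.square(model(features) - labels))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
  strategy.run(step_fn, args=(next(iterator),))

for _ in range(1000):
  # `run` takes a single-step training function, per the comments above; it
  # saves a checkpoint and exits with a restartable code if a preemption
  # signal arrives.
  handler.run(train_step, dist_iterator)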