def __init__(self, cluster_resolver, checkpoint, checkpoint_dir):
  """Creates the failure handler.

  Args:
    cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`. You
      may also get it through the `cluster_resolver` attribute of the strategy
      in use.
    checkpoint: a `tf.train.Checkpoint` that will be saved upon preemption and
      loaded upon restart by the `CoordinatedCheckpointManager` API
      automatically.
    checkpoint_dir: a directory for the `CoordinatedCheckpointManager` to play
      with checkpoints. `CoordinatedCheckpointManager` will create a
      `tf.train.CheckpointManager` to manage the passed-in `checkpoint`. Since
      only one `tf.train.CheckpointManager` should be active in a particular
      directory at a time, this `checkpoint_dir` arg should preferably be
      separated from where the user saves their checkpoints for non-fault
      tolerance purposes.
  """
  self._cluster_resolver = cluster_resolver
  self._checkpoint = checkpoint
  self._id_in_cluster = str(
      multi_worker_util.id_in_cluster(
          self._cluster_resolver.cluster_spec(),
          self._cluster_resolver.task_type,
          self._cluster_resolver.task_id))

  # The number of calls to `CoordinatedCheckpointManager.run` when the latest
  # checkpoint was saved.
  self._checkpointed_runs = variables.Variable(
      initial_value=constant_op.constant(0, dtype=dtypes.int64),
      trainable=False,
      name=_ITERATION_VARIABLE)
  if not hasattr(self._checkpoint, _ITERATION_VARIABLE):
    setattr(self._checkpoint, _ITERATION_VARIABLE, self._checkpointed_runs)

  # Make CheckpointManagers. MultiWorkerMirroredStrategy requires a different
  # setup on the chief than on other workers.
  self._read_checkpoint_manager = checkpoint_management.CheckpointManager(
      checkpoint, directory=checkpoint_dir, max_to_keep=1)
  if multi_worker_util.is_chief(
      cluster_spec=cluster_resolver.cluster_spec(),
      task_type=cluster_resolver.task_type,
      task_id=cluster_resolver.task_id):
    self._write_checkpoint_manager = self._read_checkpoint_manager
  else:
    self._write_checkpoint_manager = checkpoint_management.CheckpointManager(
        checkpoint,
        _mwms_write_checkpoint_dir(checkpoint_dir, cluster_resolver.task_type,
                                   cluster_resolver.task_id,
                                   cluster_resolver.cluster_spec()),
        max_to_keep=1)
  self._read_checkpoint_manager.restore_or_initialize()

  # An internal step counter that's restored to checkpointed_iterations when
  # training is restored. It increments by one every time
  # `CoordinatedCheckpointManager.run` is called. Note that in this case, the
  # user must pass a single-step training function to
  # `CoordinatedCheckpointManager.run` instead of a multiple-step one.
  self._run_counter = self._checkpointed_runs.numpy()

  # The worker itself has received a preemption signal.
  self._received_own_sigterm = threading.Event()

  # Some member (could be oneself) has received a preemption signal, and the
  # step number at which to save a checkpoint has been aligned.
  self._received_sigterm_and_step = threading.Event()

  # TODO(wxinyi): Enforce that only one instance of this class is created
  # per program.
  # TODO(wxinyi): make the thread non-daemon.
  threading.Thread(target=self._wait_for_signal, daemon=True).start()

  self._platform_device = gce_util.detect_platform()
  if self._platform_device is gce_util.PlatformDevice.GCE_GPU:
    self._start_polling_for_gce_signal()
    self._exit_code = gce_util._RESTARTABLE_EXIT_CODE
  elif self._platform_device is gce_util.PlatformDevice.INTERNAL:
    self._start_watching_for_signal()
    self._exit_code = _RESTARTABLE_EXIT_CODE
  else:
    raise NotImplementedError('CoordinatedCheckpointManager is only supported'
                              ' for MultiWorkerMirroredStrategy with GPU.')
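# Usage sketch (illustrative, not part of the original source): wiring this
# handler into a `MultiWorkerMirroredStrategy` training loop. `make_model`,
# `train_step_fn`, `num_steps`, and the checkpoint directory are hypothetical
# placeholders; `train_step_fn` must be a single-step training function, per
# the `_run_counter` comment above.
#
#   strategy = tf.distribute.MultiWorkerMirroredStrategy()
#   with strategy.scope():
#     model = make_model()
#     optimizer = tf.keras.optimizers.SGD()
#   checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
#   handler = CoordinatedCheckpointManager(
#       cluster_resolver=strategy.cluster_resolver,
#       checkpoint=checkpoint,
#       checkpoint_dir='/tmp/fault_tolerance_ckpt')
#   for _ in range(num_steps):
#     handler.run(train_step_fn)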
def __init__(self,
             cluster_resolver,
             checkpoint,
             checkpoint_dir,
             termination_config=TerminationConfig()):
  """Creates the failure handler.

  Args:
    cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`. You
      may also get it through the `cluster_resolver` attribute of the strategy
      in use.
    checkpoint: a `tf.train.Checkpoint` that will be saved upon preemption and
      loaded upon restart by the `WorkerPreemptionHandler` API automatically.
    checkpoint_dir: a directory for the `WorkerPreemptionHandler` to play with
      checkpoints. `WorkerPreemptionHandler` will create a
      `tf.train.CheckpointManager` to manage the passed-in `checkpoint`. Since
      only one `tf.train.CheckpointManager` should be active in a particular
      directory at a time, this `checkpoint_dir` arg should preferably be
      separated from where the user saves their checkpoints for non-fault
      tolerance purposes.
    termination_config: a `TerminationConfig` object to configure for a
      platform other than Google Borg or GCP.
  """
  self._cluster_resolver = cluster_resolver
  self._checkpoint = checkpoint
  self._id_in_cluster = str(
      multi_worker_util.id_in_cluster(
          self._cluster_resolver.cluster_spec(),
          self._cluster_resolver.task_type,
          self._cluster_resolver.task_id))

  # The number of calls to `WorkerPreemptionHandler.run` when the latest
  # checkpoint was saved.
  self._checkpointed_runs = variables.Variable(
      initial_value=constant_op.constant(0, dtype=dtypes.int64),
      trainable=False,
      name=_ITERATION_VARIABLE)
  if not hasattr(self._checkpoint, _ITERATION_VARIABLE):
    setattr(self._checkpoint, _ITERATION_VARIABLE, self._checkpointed_runs)

  # Make CheckpointManagers. MultiWorkerMirroredStrategy requires a different
  # setup on the chief than on other workers.
  self._read_checkpoint_manager = checkpoint_management.CheckpointManager(
      checkpoint, directory=checkpoint_dir, max_to_keep=1)
  if multi_worker_util.is_chief(
      cluster_spec=cluster_resolver.cluster_spec(),
      task_type=cluster_resolver.task_type,
      task_id=cluster_resolver.task_id):
    self._write_checkpoint_manager = self._read_checkpoint_manager
  else:
    self._write_checkpoint_manager = checkpoint_management.CheckpointManager(
        checkpoint,
        _mwms_write_checkpoint_dir(checkpoint_dir, cluster_resolver.task_type,
                                   cluster_resolver.task_id,
                                   cluster_resolver.cluster_spec()),
        max_to_keep=1)
  self._read_checkpoint_manager.restore_or_initialize()

  # Grace-period countdown. Set to True for all workers once they finish
  # timing the checkpoint save. Once this phase is entered, new
  # preemption/maintenance notices will not be handled, since the whole
  # cluster goes down when the worker that first initiated the grace period
  # goes down.
  self._final_checkpoint_countdown = False

  self._estimated_run_time = 0

  # An internal step counter that's restored to checkpointed_iterations when
  # training is restored. It increments by one every time
  # `WorkerPreemptionHandler.run` is called. Note that in this case, the
  # user must pass a single-step training function to
  # `WorkerPreemptionHandler.run` instead of a multiple-step one.
  self._run_counter = self._checkpointed_runs.numpy()

  # The worker itself has received a preemption signal.
  self._received_own_sigterm = threading.Event()

  # Some member (could be oneself) has received a preemption signal, and the
  # step number at which to save a checkpoint has been aligned.
  self._received_checkpoint_step = threading.Event()

  self._platform_device = gce_util.detect_platform()
  completed_termination_config = _complete_config_for_environement(
      self._platform_device, termination_config)
  self._termination_watcher_function = (
      completed_termination_config.termination_watcher_function)
  self._exit_fn = completed_termination_config.exit_fn
  self._grace_period = completed_termination_config.time_till_termination

  # When training is interrupted, we explicitly call the cleanup methods for
  # the thread watching for the local worker's termination signal and the
  # thread watching for cluster-wise information before we save a checkpoint
  # and exit. At the end of training, when no interruption is encountered, we
  # rely on __del__ to clean up. However, there is no guarantee when or
  # whether __del__ is executed, so we make the threads daemon threads to
  # avoid them preventing the program from exiting.
  self._cluster_wise_termination_watcher_thread = threading.Thread(
      target=self._watch_step_to_save_key,
      name='PeerTerminationWatcher-%s' % self._id_in_cluster,
      daemon=True)
  logging.info('Start watcher for peer\'s signal.')
  self._cluster_wise_termination_watcher_thread.start()

  self._poll_termination_signal_thread = None

  if completed_termination_config.termination_watcher_function:
    self._start_polling_for_termination_signal()
  else:
    self._start_watching_for_signal()
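# Illustrative sketch (not part of the original source): supplying a
# `TerminationConfig` for a platform other than Borg or GCP. The keyword names
# are an assumption inferred from the attributes read above
# (`termination_watcher_function`, `exit_fn`, `time_till_termination`), and
# `my_signal_poller` is a hypothetical user-provided callable that returns
# True once the platform announces an upcoming preemption.
#
#   config = TerminationConfig(
#       termination_watcher_function=my_signal_poller,
#       exit_fn=lambda: sys.exit(42),
#       time_till_termination=120)  # grace period, in seconds
#   handler = WorkerPreemptionHandler(
#       cluster_resolver=strategy.cluster_resolver,
#       checkpoint=checkpoint,
#       checkpoint_dir='/tmp/fault_tolerance_ckpt',
#       termination_config=config)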
def __init__(self,
             cluster_resolver,
             checkpoint_or_checkpoint_manager,
             checkpoint_dir=None,
             termination_config=None):
  """Creates the `PreemptionCheckpointHandler`.

  Args:
    cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`
      object. You may also obtain it through the `cluster_resolver` attribute
      of the distribution strategy in use.
    checkpoint_or_checkpoint_manager: a `tf.train.CheckpointManager` or a
      `tf.train.Checkpoint`. If you are using a `tf.train.CheckpointManager`
      to manage checkpoints outside the `PreemptionCheckpointHandler` for
      backup purposes as well, pass it as the
      `checkpoint_or_checkpoint_manager` argument. Otherwise, pass a
      `tf.train.Checkpoint` and the `PreemptionCheckpointHandler` will create
      a `tf.train.CheckpointManager` to manage it in the `checkpoint_dir`.
    checkpoint_dir: a directory where the `PreemptionCheckpointHandler` saves
      and restores checkpoints. When a `PreemptionCheckpointHandler` is
      created, the latest checkpoint in `checkpoint_dir` will be restored.
      (This is not needed if a `tf.train.CheckpointManager` instead of a
      `tf.train.Checkpoint` is passed as the
      `checkpoint_or_checkpoint_manager` argument.)
    termination_config: optional, a
      `tf.distribute.experimental.TerminationConfig` object to configure for
      a platform other than Google Borg or GCP.
  """
  self._cluster_resolver = cluster_resolver
  if isinstance(checkpoint_or_checkpoint_manager,
                checkpoint_lib.Checkpoint) and not checkpoint_dir:
    raise errors.InvalidArgumentError(
        'When a checkpoint is passed, a checkpoint_dir must be passed as '
        'well.')
  self._id_in_cluster = str(
      multi_worker_util.id_in_cluster(
          self._cluster_resolver.cluster_spec(),
          self._cluster_resolver.task_type,
          self._cluster_resolver.task_id))

  # The number of calls to `PreemptionCheckpointHandler.run` when the latest
  # checkpoint was saved.
  self._checkpointed_runs = variables.Variable(
      initial_value=constant_op.constant(0, dtype=dtypes.int64),
      trainable=False,
      name=_ITERATION_VARIABLE)

  self._maybe_create_checkpoint_manager(checkpoint_or_checkpoint_manager,
                                        checkpoint_dir, cluster_resolver)

  if not hasattr(self._write_checkpoint_manager._checkpoint,
                 _ITERATION_VARIABLE):
    setattr(self._write_checkpoint_manager._checkpoint, _ITERATION_VARIABLE,
            self._checkpointed_runs)

  if not hasattr(self._read_checkpoint_manager._checkpoint,
                 _ITERATION_VARIABLE):
    setattr(self._read_checkpoint_manager._checkpoint, _ITERATION_VARIABLE,
            self._checkpointed_runs)

  self._read_checkpoint_manager.restore_or_initialize()

  # Grace-period countdown. Set to True for all workers once they finish
  # timing the checkpoint save. Once this phase is entered, new
  # preemption/maintenance notices will not be handled, since the whole
  # cluster goes down when the worker that first initiated the grace period
  # goes down.
  self._final_checkpoint_countdown = False

  self._estimated_run_time = 0

  # An internal step counter that's restored to checkpointed_iterations when
  # training is restored. It increments by one every time
  # `PreemptionCheckpointHandler.run` is called. Note that in this case, the
  # user must pass a single-step training function to
  # `PreemptionCheckpointHandler.run` instead of a multiple-step one.
  self._run_counter = self._checkpointed_runs.numpy()

  # The worker itself has received a preemption signal.
  self._received_own_sigterm = threading.Event()

  # Some member (could be oneself) has received a preemption signal, and the
  # step number at which to save a checkpoint has been aligned.
  self._received_checkpoint_step = threading.Event()

  self._platform_device = gce_util.detect_platform()

  if self._platform_device in (gce_util.PlatformDevice.GCE_TPU,
                               gce_util.PlatformDevice.GCE_CPU):
    # Although MultiWorkerMirroredStrategy training behaves the same with GPUs
    # and CPUs on Borg, GCE CPU VMs and GPU VMs differ in terms of live
    # migration, grace period, etc. We can make this work upon request.
    raise NotImplementedError('PreemptionCheckpointHandler does not support '
                              'training with TPU or CPU device on GCP.')

  completed_termination_config = _complete_config_for_environment(
      self._platform_device, termination_config)
  self._termination_watcher_fn = (
      completed_termination_config.termination_watcher_fn)
  self._exit_fn = completed_termination_config.exit_fn
  self._grace_period = completed_termination_config.grace_period

  # When training is interrupted, we explicitly call the cleanup methods for
  # the thread watching for the local worker's termination signal and the
  # thread watching for cluster-wise information before we save a checkpoint
  # and exit. At the end of training, when no interruption is encountered, we
  # rely on __del__ to clean up. However, there is no guarantee when or
  # whether __del__ is executed, so we make the threads daemon threads to
  # avoid them preventing the program from exiting.
  self._cluster_wise_termination_watcher_thread = threading.Thread(
      target=self._watch_step_to_save_key,
      name='PeerTerminationWatcher-%s' % self._id_in_cluster,
      daemon=True)
  logging.info('Start watcher for peer\'s signal.')
  self._cluster_wise_termination_watcher_thread.start()

  self._poll_termination_signal_thread = None

  if completed_termination_config.termination_watcher_fn:
    self._start_polling_for_termination_signal()
  else:
    self._start_watching_for_signal()
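# Illustrative sketch (not part of the original source): the two accepted ways
# to construct the handler, per the docstring above. `model` and `optimizer`
# are hypothetical placeholders. When a bare `tf.train.Checkpoint` is passed,
# `checkpoint_dir` is required (see the InvalidArgumentError check above);
# when a `tf.train.CheckpointManager` is passed, `checkpoint_dir` may be
# omitted.
#
#   checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
#
#   # Option 1: let the handler create its own CheckpointManager.
#   handler = PreemptionCheckpointHandler(
#       strategy.cluster_resolver, checkpoint, checkpoint_dir='/tmp/ckpt')
#
#   # Option 2: reuse an existing CheckpointManager for backup as well.
#   manager = tf.train.CheckpointManager(
#       checkpoint, directory='/tmp/ckpt', max_to_keep=1)
#   handler = PreemptionCheckpointHandler(strategy.cluster_resolver, manager)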