def testMultipleChiefs(self):
  cluster_spec = {
      "chief": ["127.0.0.1:8258", "127.0.0.1:7566"],
  }
  with self.assertRaisesRegexp(ValueError,
                               "There must be at most one 'chief' job."):
    multi_worker_util.id_in_cluster(cluster_spec, "chief", 0)

def testMultipleChiefs(self):
  cluster_spec = {
      "chief": ["127.0.0.1:8258", "127.0.0.1:7566"],
  }
  with self.assertRaisesRegex(ValueError,
                              "There must be at most one 'chief' job."):
    multi_worker_util.id_in_cluster(cluster_spec, "chief", 0)

def testWorkerId(self):
  cluster_spec = {
      "chief": ["127.0.0.1:1234"],
      "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
      "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
  }
  self.assertEqual(
      multi_worker_util.id_in_cluster(cluster_spec, "worker", 1), 2)

  cluster_spec = {
      "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
      "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
  }
  self.assertEqual(
      multi_worker_util.id_in_cluster(cluster_spec, "worker", 1), 1)

def testEvaluatorId(self):
  cluster_spec = {
      "chief": ["127.0.0.1:1234"],
      "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
      "evaluator": ["127.0.0.1:7566"]
  }
  self.assertEqual(
      multi_worker_util.id_in_cluster(cluster_spec, "evaluator", 0), 0)

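# A minimal demonstration of the id mapping exercised by the tests above,
# assuming the internal import path `tensorflow.python.distribute.multi_worker_util`:
# with a chief present, the chief takes id 0 and worker ids are shifted up by
# one; without a chief, a worker's id equals its task_id; evaluators are
# numbered independently; "ps" tasks have no id and raise ValueError.
from tensorflow.python.distribute import multi_worker_util

cluster_spec = {
    "chief": ["127.0.0.1:1234"],
    "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
    "evaluator": ["127.0.0.1:7566"],
}
assert multi_worker_util.id_in_cluster(cluster_spec, "chief", 0) == 0
assert multi_worker_util.id_in_cluster(cluster_spec, "worker", 0) == 1
assert multi_worker_util.id_in_cluster(cluster_spec, "worker", 1) == 2
assert multi_worker_util.id_in_cluster(cluster_spec, "evaluator", 0) == 0
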
def maybe_shard_dataset(dataset):
  """Shards the dataset if running in a multi-node environment."""
  cluster_resolver = TFConfigClusterResolver()
  cluster_spec = cluster_resolver.cluster_spec().as_dict()
  if cluster_spec:
    dataset = dataset.shard(
        multi_worker_util.worker_count(cluster_spec,
                                       cluster_resolver.task_type),
        multi_worker_util.id_in_cluster(cluster_spec,
                                        cluster_resolver.task_type,
                                        cluster_resolver.task_id))
  return dataset

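# A hedged usage sketch for `maybe_shard_dataset` above: `TFConfigClusterResolver`
# builds the cluster spec from the TF_CONFIG environment variable, so sharding
# only kicks in when TF_CONFIG describes a multi-worker cluster. The addresses
# and task assignment below are illustrative.
import json
import os

import tensorflow as tf

os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
    },
    "task": {"type": "worker", "index": 1},
})
# For this task: worker_count == 3 (chief + 2 workers) and id_in_cluster == 2,
# so the dataset is split into 3 shards and this worker reads shard 2.
sharded = maybe_shard_dataset(tf.data.Dataset.range(100))
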
def _make_input_context(self):
  if self._cluster_spec is None:
    input_pipeline_id = 0
  else:
    input_pipeline_id = multi_worker_util.id_in_cluster(
        self._cluster_spec, self._task_type, self._task_id)
  input_context = distribute_lib.InputContext(
      num_input_pipelines=self._num_workers,
      input_pipeline_id=input_pipeline_id,
      num_replicas_in_sync=self._num_replicas_in_sync)
  return input_context

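# A sketch of how the `InputContext` built above is typically consumed: a
# user-supplied dataset function receives it and maps `num_input_pipelines` /
# `input_pipeline_id` onto `Dataset.shard`, plus the per-replica batch size.
# `tf.distribute.InputContext` is the public counterpart of
# `distribute_lib.InputContext`; the numbers below are illustrative.
import tensorflow as tf

def dataset_fn(input_context):
  per_replica_batch = input_context.get_per_replica_batch_size(
      global_batch_size=64)
  dataset = tf.data.Dataset.range(1000)
  dataset = dataset.shard(input_context.num_input_pipelines,
                          input_context.input_pipeline_id)
  return dataset.batch(per_replica_batch)
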
def batch_and_maybe_shard_dataset(dataset, global_batch_size):
  """Shards the dataset if running in a multi-node environment."""
  cluster_resolver = TFConfigClusterResolver()
  cluster_spec = cluster_resolver.cluster_spec().as_dict()
  if cluster_spec:
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    num_workers = int(
        multi_worker_util.worker_count(cluster_spec, task_type))
    id_in_cluster = int(
        multi_worker_util.id_in_cluster(cluster_spec, task_type, task_id))
    dataset = dataset.shard(num_workers, id_in_cluster)
  return dataset.batch(global_batch_size)

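# Illustrative note on the sharding above: `Dataset.shard(num_workers, id)`
# keeps every num_workers-th element starting at index `id`, so each worker
# reads a disjoint slice of the data before batching with the global batch
# size.
import tensorflow as tf

# e.g. the worker with id_in_cluster == 1 out of 3 workers over range(10):
list(tf.data.Dataset.range(10).shard(3, 1).as_numpy_iterator())  # [1, 4, 7]
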
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  """Distributes the dataset to each local GPU."""
  if self._cluster_spec is None:
    input_pipeline_id = 0
  else:
    input_pipeline_id = multi_worker_util.id_in_cluster(
        self._cluster_spec, self._task_type, self._task_id)
  input_context = distribute_lib.InputContext(
      num_input_pipelines=self._num_workers,
      input_pipeline_id=input_pipeline_id,
      num_replicas_in_sync=self._num_replicas_in_sync)
  return values.PerReplicaDataset(
      self._call_dataset_fn(input_fn, input_context), self._devices, True)

def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  """Distributes the dataset to each local GPU."""
  if self._cluster_spec is None:
    input_pipeline_id = 0
  else:
    input_pipeline_id = multi_worker_util.id_in_cluster(
        self._cluster_spec, self._task_type, self._task_id)
  input_context = distribute_lib.InputContext(
      num_input_pipelines=self._num_workers,
      input_pipeline_id=input_pipeline_id,
      num_replicas_in_sync=self._num_replicas_in_sync)
  return input_lib.InputFunctionIterator(
      input_fn, self._input_workers, [input_context])

def _experimental_distribute_datasets_from_function(self, dataset_fn):
  if self._cluster_spec:
    input_pipeline_id = multi_worker_util.id_in_cluster(
        self._cluster_spec, self._task_type, self._task_id)
    num_input_pipelines = multi_worker_util.worker_count(
        self._cluster_spec, self._task_type)
  else:
    input_pipeline_id = 0
    num_input_pipelines = 1

  input_context = distribute_lib.InputContext(
      num_input_pipelines=num_input_pipelines,
      input_pipeline_id=input_pipeline_id,
      num_replicas_in_sync=self._num_replicas_in_sync)

  return input_lib.get_distributed_datasets_from_function(
      dataset_fn, self._input_workers, [input_context],
      self._container_strategy())

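# A hedged sketch of the user-facing call that reaches the method above: the
# strategy invokes `dataset_fn` once per input pipeline with the
# `InputContext` built from `id_in_cluster` / `worker_count`. The public name
# has shifted across releases (`experimental_distribute_datasets_from_function`
# was later renamed `distribute_datasets_from_function`), so treat this as
# illustrative.
import tensorflow as tf

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
dist_dataset = strategy.experimental_distribute_datasets_from_function(
    lambda ctx: tf.data.Dataset.range(64).shard(
        ctx.num_input_pipelines, ctx.input_pipeline_id).batch(8))
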
def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  """Distributes the dataset to each local GPU."""
  if self._cluster_spec:
    input_pipeline_id = multi_worker_util.id_in_cluster(
        self._cluster_spec, self._task_type, self._task_id)
    num_input_pipelines = multi_worker_util.worker_count(
        self._cluster_spec, self._task_type)
  else:
    input_pipeline_id = 0
    num_input_pipelines = 1
  input_context = distribute_lib.InputContext(
      num_input_pipelines=num_input_pipelines,
      input_pipeline_id=input_pipeline_id,
      num_replicas_in_sync=self._num_replicas_in_sync)
  return input_lib.InputFunctionIterator(input_fn, self._input_workers,
                                         [input_context])

def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  """Distributes the dataset to each local GPU."""
  if self._cluster_spec:
    input_pipeline_id = multi_worker_util.id_in_cluster(
        self._cluster_spec, self._task_type, self._task_id)
    num_input_pipelines = multi_worker_util.worker_count(
        self._cluster_spec, self._task_type)
  else:
    input_pipeline_id = 0
    num_input_pipelines = 1
  input_context = distribute_lib.InputContext(
      num_input_pipelines=num_input_pipelines,
      input_pipeline_id=input_pipeline_id,
      num_replicas_in_sync=self._num_replicas_in_sync)
  worker_device_pairs = [(self._worker_device, self._compute_devices)]
  return values.InputFunctionIterator(
      input_fn, worker_device_pairs, [input_context])

def _initialize_multi_worker(self, cluster_resolver):
  """Initializes the object for multi-worker training."""
  cluster_spec = multi_worker_util.normalize_cluster_spec(
      cluster_resolver.cluster_spec())
  task_type = cluster_resolver.task_type
  task_id = cluster_resolver.task_id
  if task_type is None or task_id is None:
    raise ValueError("When `cluster_spec` is given, you must also specify "
                     "`task_type` and `task_id`.")
  self._cluster_spec = cluster_spec
  self._task_type = task_type
  self._task_id = task_id
  self._id_in_cluster = multi_worker_util.id_in_cluster(
      self._cluster_spec, self._task_type, self._task_id)

  self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type)
  if not self._num_workers:
    raise ValueError("No `worker`, `chief` or `evaluator` tasks can be found "
                     "in `cluster_spec`.")

  self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                              task_id)

  self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
  self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)

  if (ops.executing_eagerly_outside_functions() and
      not getattr(self, "_local_or_standalone_client_mode", False)):
    context.context().configure_collective_ops(
        collective_leader=multi_worker_util.collective_leader(
            cluster_spec, task_type, task_id),
        scoped_allocator_enabled_ops=("CollectiveReduce",),
        device_filters=("/job:%s/task:%d" % (task_type, task_id),))
    self._collective_ops_configured = True

  # Starting a std server in eager mode and in independent worker mode.
  if (context.executing_eagerly() and
      not getattr(self, "_std_server_started", False) and
      not getattr(self, "_local_or_standalone_client_mode", False)):
    # Checking _local_or_standalone_client_mode as well because we should not
    # create the std server in standalone client mode.
    config_proto = copy.deepcopy(context.context().config)
    config_proto = self._update_config_proto(config_proto)

    if hasattr(cluster_resolver, "port"):
      port = cluster_resolver.port
    else:
      port = 0
    server_def = tensorflow_server_pb2.ServerDef(
        cluster=cluster_spec.as_cluster_def(),
        default_session_config=config_proto,
        job_name=task_type,
        task_index=task_id,
        protocol=cluster_resolver.rpc_layer or "grpc",
        port=port)
    context.context().enable_collective_ops(server_def)
    self._std_server_started = True
    # The `ensure_initialized` is needed before calling
    # `context.context().devices()`.
    context.context().ensure_initialized()
    logging.info(
        "Enabled multi-worker collective ops with available devices: %r",
        context.context().devices())

  # TODO(yuefengz): The `num_gpus` is only for this particular task. It
  # assumes all workers have the same number of GPUs. We should remove this
  # assumption by querying all tasks for their numbers of GPUs.
  # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
  # some cases.
  if isinstance(cluster_resolver, TFConfigClusterResolver):
    num_gpus = context.num_gpus()
  else:
    num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

  if num_gpus:
    local_devices = tuple("%s/device:GPU:%d" % (self._worker_device, i)
                          for i in range(num_gpus))
  else:
    local_devices = (self._worker_device,)

  self._collective_keys = cross_device_utils.CollectiveKeys()
  self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
      devices=local_devices,
      group_size=len(local_devices) * self._num_workers,
      collective_keys=self._collective_keys,
      communication=self._communication)
  # CrossDeviceOps for per host tensors.
  self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
      devices=[self._worker_device],
      group_size=self._num_workers,
      collective_keys=self._collective_keys,
      communication=cross_device_ops_lib.CollectiveCommunication.RING,
  )
  super(CollectiveAllReduceExtended, self)._initialize_single_worker(
      local_devices)

  # Add a default device so that ops without specified devices will not end
  # up on other workers.
  self._default_device = "/job:%s/task:%d" % (task_type, task_id)

  # Save the num_gpus_per_worker and rpc_layer for configure method.
  self._num_gpus_per_worker = num_gpus
  self._rpc_layer = cluster_resolver.rpc_layer
  self._warn_nccl_no_gpu()

  # TODO(b/151232436): Enable check health thread by default.
  if self._enable_check_health:
    self._start_check_health_thread()

  logging.info(
      "MultiWorkerMirroredStrategy with cluster_spec = %r, task_type = %r, "
      "task_id = %r, num_workers = %r, local_devices = %r, "
      "communication = %s", cluster_spec.as_dict(), task_type, task_id,
      self._num_workers, local_devices, self._communication)

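# A hedged end-to-end sketch: constructing the strategy on each task with
# TF_CONFIG set (as in the earlier sharding example) runs
# `_initialize_multi_worker`, pinning the task to "/job:<task_type>/task:<task_id>"
# and enabling collective ops across all workers. The public symbol has moved
# between `tf.distribute.experimental.MultiWorkerMirroredStrategy` and
# `tf.distribute.MultiWorkerMirroredStrategy` across releases, so treat this
# as illustrative.
import tensorflow as tf

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
print("Number of replicas in sync:", strategy.num_replicas_in_sync)
print("Worker devices:", strategy.extended.worker_devices)
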
def __init__(self,
             cluster_resolver,
             checkpoint,
             checkpoint_dir,
             termination_config=TerminationConfig()):
  """Creates the failure handler.

  Args:
    cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`. You
      may also get it through the `cluster_resolver` attribute of the strategy
      in use.
    checkpoint: a `tf.train.Checkpoint` that will be saved upon preemption and
      loaded upon restart by the `WorkerPreemptionHandler` API automatically.
    checkpoint_dir: a directory for the `WorkerPreemptionHandler` to play with
      checkpoints. `WorkerPreemptionHandler` will create a
      `tf.train.CheckpointManager` to manage the passed-in `checkpoint`. Since
      only one `tf.train.CheckpointManager` should be active in a particular
      directory at a time, this `checkpoint_dir` arg should preferably be
      separated from where the user saves their checkpoint for non-fault
      tolerance purposes.
    termination_config: a `TerminationConfig` object to configure for a
      platform other than Google Borg or GCP.
  """
  self._cluster_resolver = cluster_resolver
  self._checkpoint = checkpoint
  self._id_in_cluster = str(
      multi_worker_util.id_in_cluster(
          self._cluster_resolver.cluster_spec(),
          self._cluster_resolver.task_type,
          self._cluster_resolver.task_id))

  # The number of calls to `WorkerPreemptionHandler.run` when the latest
  # checkpoint was saved.
  self._checkpointed_runs = variables.Variable(
      initial_value=constant_op.constant(0, dtype=dtypes.int64),
      trainable=False,
      name=_ITERATION_VARIABLE)
  if not hasattr(self._checkpoint, _ITERATION_VARIABLE):
    setattr(self._checkpoint, _ITERATION_VARIABLE, self._checkpointed_runs)

  # Make CheckpointManagers. MultiWorkerMirroredStrategy requires different
  # setup on chief and on other workers.
  self._read_checkpoint_manager = checkpoint_management.CheckpointManager(
      checkpoint, directory=checkpoint_dir, max_to_keep=1)
  if multi_worker_util.is_chief(
      cluster_spec=cluster_resolver.cluster_spec(),
      task_type=cluster_resolver.task_type,
      task_id=cluster_resolver.task_id):
    self._write_checkpoint_manager = self._read_checkpoint_manager
  else:
    self._write_checkpoint_manager = checkpoint_management.CheckpointManager(
        checkpoint,
        _mwms_write_checkpoint_dir(checkpoint_dir, cluster_resolver.task_type,
                                   cluster_resolver.task_id,
                                   cluster_resolver.cluster_spec()),
        max_to_keep=1)

  self._read_checkpoint_manager.restore_or_initialize()

  # Grace period countdown. Set to True for all workers once they finish
  # timing saving a checkpoint. Once entering this phase, new
  # preemption/maintenance notices will not be handled, since the whole
  # cluster goes down as the worker who first initiates the grace period
  # goes down.
  self._final_checkpoint_countdown = False

  self._estimated_run_time = 0

  # An internal step counter that's restored to checkpointed_iterations when
  # training is restored. It increments by one every time
  # `WorkerPreemptionHandler.run` is called. Note that in this case, the user
  # must pass a single-step training function to `WorkerPreemptionHandler.run`
  # instead of a multiple-step one.
  self._run_counter = self._checkpointed_runs.numpy()

  # The worker itself has received a preemption signal.
  self._received_own_sigterm = threading.Event()

  # Some member (could be oneself) has received a preemption signal, and the
  # step number to save a checkpoint has been aligned.
  self._received_checkpoint_step = threading.Event()

  self._platform_device = gce_util.detect_platform()

  completed_termination_config = _complete_config_for_environement(
      self._platform_device, termination_config)
  self._termination_watcher_function = (
      completed_termination_config.termination_watcher_function)
  self._exit_fn = completed_termination_config.exit_fn
  self._grace_period = completed_termination_config.time_till_termination

  # When training is interrupted, we explicitly call the cleanup methods for
  # the thread watching for the local worker's termination signal and the
  # thread watching for cluster-wise information before we save a checkpoint
  # and exit. In the final chapter of the training where no interruption is
  # encountered, we rely on __del__ to clean up. However, there is no
  # guarantee when or whether __del__ is executed; thus we make the threads
  # daemon to avoid them preventing the program from exiting.
  self._cluster_wise_termination_watcher_thread = threading.Thread(
      target=self._watch_step_to_save_key,
      name='PeerTerminationWatcher-%s' % self._id_in_cluster,
      daemon=True)
  logging.info('Start watcher for peer\'s signal.')
  self._cluster_wise_termination_watcher_thread.start()

  self._poll_termination_signal_thread = None

  if completed_termination_config.termination_watcher_function:
    self._start_polling_for_termination_signal()
  else:
    self._start_watching_for_signal()

def testPsId(self):
  cluster_spec = {"chief": ["127.0.0.1:1234"], "ps": ["127.0.0.1:7566"]}
  with self.assertRaisesRegexp(ValueError,
                               "There is no id for task_type 'ps'"):
    multi_worker_util.id_in_cluster(cluster_spec, "ps", 0)

def testPsId(self):
  cluster_spec = {"chief": ["127.0.0.1:1234"], "ps": ["127.0.0.1:7566"]}
  with self.assertRaisesRegex(ValueError,
                              "There is no id for task_type 'ps'"):
    multi_worker_util.id_in_cluster(cluster_spec, "ps", 0)

def __init__(self,
             cluster_resolver,
             checkpoint_or_checkpoint_manager,
             checkpoint_dir=None,
             termination_config=None):
  """Creates the `PreemptionCheckpointHandler`.

  Args:
    cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`
      object. You may also obtain it through the `cluster_resolver` attribute
      of the distribution strategy in use.
    checkpoint_or_checkpoint_manager: a `tf.train.CheckpointManager` or a
      `tf.train.Checkpoint`. If you are using a `tf.train.CheckpointManager`
      to manage checkpoints outside the `PreemptionCheckpointHandler` for
      backup purposes as well, pass it as the
      `checkpoint_or_checkpoint_manager` argument. Otherwise, pass a
      `tf.train.Checkpoint` and the `PreemptionCheckpointHandler` will create
      a `tf.train.CheckpointManager` to manage it in the `checkpoint_dir`.
    checkpoint_dir: a directory where the `PreemptionCheckpointHandler` saves
      and restores checkpoints. When a `PreemptionCheckpointHandler` is
      created, the latest checkpoint in the `checkpoint_dir` will be restored.
      (This is not needed if a `tf.train.CheckpointManager` instead of a
      `tf.train.Checkpoint` is passed as the
      `checkpoint_or_checkpoint_manager` argument.)
    termination_config: optional, a
      `tf.distribute.experimental.TerminationConfig` object to configure for
      a platform other than Google Borg or GCP.
  """
  self._cluster_resolver = cluster_resolver
  if isinstance(checkpoint_or_checkpoint_manager,
                checkpoint_lib.Checkpoint) and not checkpoint_dir:
    raise errors.InvalidArgumentError('When a checkpoint is passed, a '
                                      'checkpoint_dir must be passed as well'
                                      '.')
  self._id_in_cluster = str(
      multi_worker_util.id_in_cluster(
          self._cluster_resolver.cluster_spec(),
          self._cluster_resolver.task_type,
          self._cluster_resolver.task_id))

  # The number of calls to `PreemptionCheckpointHandler.run` when the latest
  # checkpoint was saved.
  self._checkpointed_runs = variables.Variable(
      initial_value=constant_op.constant(0, dtype=dtypes.int64),
      trainable=False,
      name=_ITERATION_VARIABLE)

  self._maybe_create_checkpoint_manager(checkpoint_or_checkpoint_manager,
                                        checkpoint_dir, cluster_resolver)

  if not hasattr(self._write_checkpoint_manager._checkpoint,
                 _ITERATION_VARIABLE):
    setattr(self._write_checkpoint_manager._checkpoint, _ITERATION_VARIABLE,
            self._checkpointed_runs)

  if not hasattr(self._read_checkpoint_manager._checkpoint,
                 _ITERATION_VARIABLE):
    setattr(self._read_checkpoint_manager._checkpoint, _ITERATION_VARIABLE,
            self._checkpointed_runs)

  self._read_checkpoint_manager.restore_or_initialize()

  # Grace period countdown. Set to True for all workers once they finish
  # timing saving a checkpoint. Once entering this phase, new
  # preemption/maintenance notices will not be handled, since the whole
  # cluster goes down as the worker who first initiates the grace period
  # goes down.
  self._final_checkpoint_countdown = False

  self._estimated_run_time = 0

  # An internal step counter that's restored to checkpointed_iterations when
  # training is restored. It increments by one every time
  # `PreemptionCheckpointHandler.run` is called. Note that in this case, the
  # user must pass a single-step training function to
  # `PreemptionCheckpointHandler.run` instead of a multiple-step one.
  self._run_counter = self._checkpointed_runs.numpy()

  # The worker itself has received a preemption signal.
  self._received_own_sigterm = threading.Event()

  # Some member (could be oneself) has received a preemption signal, and the
  # step number to save a checkpoint has been aligned.
  self._received_checkpoint_step = threading.Event()

  self._platform_device = gce_util.detect_platform()

  if self._platform_device in (gce_util.PlatformDevice.GCE_TPU,
                               gce_util.PlatformDevice.GCE_CPU):
    # While MultiWorkerMirroredStrategy training with GPUs and with CPUs
    # works the same way on Borg, GCE CPU VMs and GPU VMs differ in terms of
    # live migration, grace period, etc. We can make it work upon request.
    raise NotImplementedError('PreemptionCheckpointHandler does not support '
                              'training with TPU or CPU device on GCP.')

  completed_termination_config = _complete_config_for_environment(
      self._platform_device, termination_config)
  self._termination_watcher_fn = (
      completed_termination_config.termination_watcher_fn)
  self._exit_fn = completed_termination_config.exit_fn
  self._grace_period = completed_termination_config.grace_period

  # When training is interrupted, we explicitly call the cleanup methods for
  # the thread watching for the local worker's termination signal and the
  # thread watching for cluster-wise information before we save a checkpoint
  # and exit. In the final chapter of the training where no interruption is
  # encountered, we rely on __del__ to clean up. However, there is no
  # guarantee when or whether __del__ is executed; thus we make the threads
  # daemon to avoid them preventing the program from exiting.
  self._cluster_wise_termination_watcher_thread = threading.Thread(
      target=self._watch_step_to_save_key,
      name='PeerTerminationWatcher-%s' % self._id_in_cluster,
      daemon=True)
  logging.info('Start watcher for peer\'s signal.')
  self._cluster_wise_termination_watcher_thread.start()

  self._poll_termination_signal_thread = None

  if completed_termination_config.termination_watcher_fn:
    self._start_polling_for_termination_signal()
  else:
    self._start_watching_for_signal()

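# A hedged sketch of how the constructor above is reached through the public
# API, assuming `tf.distribute.experimental.PreemptionCheckpointHandler` as
# exposed in recent TF releases; the model, directory, and training loop are
# illustrative placeholders, and TF_CONFIG is assumed to be set on each worker
# as in the earlier sharding sketch.
import tensorflow as tf

strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
checkpoint = tf.train.Checkpoint(model=model)

handler = tf.distribute.experimental.PreemptionCheckpointHandler(
    strategy.cluster_resolver, checkpoint, checkpoint_dir="/tmp/ckpt_dir")

@tf.function
def train_step(batch):
  # Illustrative no-op step; a real step would run forward/backward passes
  # under `strategy.run`.
  return batch

for batch in tf.data.Dataset.range(8).batch(2):
  # `run` executes one training step and, if a preemption notice has been
  # received, saves a checkpoint and exits with a restartable exit code.
  handler.run(train_step, batch)
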
def __init__(self, cluster_resolver, checkpoint, checkpoint_dir):
  """Creates the failure handler.

  Args:
    cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`. You
      may also get it through the `cluster_resolver` attribute of the strategy
      in use.
    checkpoint: a `tf.train.Checkpoint` that will be saved upon preemption and
      loaded upon restart by the `CoordinatedCheckpointManager` API
      automatically.
    checkpoint_dir: a directory for the `CoordinatedCheckpointManager` to play
      with checkpoints. `CoordinatedCheckpointManager` will create a
      `tf.train.CheckpointManager` to manage the passed-in `checkpoint`. Since
      only one `tf.train.CheckpointManager` should be active in a particular
      directory at a time, this `checkpoint_dir` arg should preferably be
      separated from where the user saves their checkpoint for non-fault
      tolerance purposes.
  """
  self._cluster_resolver = cluster_resolver
  self._checkpoint = checkpoint
  self._id_in_cluster = str(
      multi_worker_util.id_in_cluster(
          self._cluster_resolver.cluster_spec(),
          self._cluster_resolver.task_type,
          self._cluster_resolver.task_id))

  # The number of calls to `CoordinatedCheckpointManager.run` when the latest
  # checkpoint was saved.
  self._checkpointed_runs = variables.Variable(
      initial_value=constant_op.constant(0, dtype=dtypes.int64),
      trainable=False,
      name=_ITERATION_VARIABLE)
  if not hasattr(self._checkpoint, _ITERATION_VARIABLE):
    setattr(self._checkpoint, _ITERATION_VARIABLE, self._checkpointed_runs)

  # Make CheckpointManagers. MultiWorkerMirroredStrategy requires different
  # setup on chief and on other workers.
  self._read_checkpoint_manager = checkpoint_management.CheckpointManager(
      checkpoint, directory=checkpoint_dir, max_to_keep=1)
  if multi_worker_util.is_chief(
      cluster_spec=cluster_resolver.cluster_spec(),
      task_type=cluster_resolver.task_type,
      task_id=cluster_resolver.task_id):
    self._write_checkpoint_manager = self._read_checkpoint_manager
  else:
    self._write_checkpoint_manager = checkpoint_management.CheckpointManager(
        checkpoint,
        _mwms_write_checkpoint_dir(checkpoint_dir, cluster_resolver.task_type,
                                   cluster_resolver.task_id,
                                   cluster_resolver.cluster_spec()),
        max_to_keep=1)
  self._read_checkpoint_manager.restore_or_initialize()

  # An internal step counter that's restored to checkpointed_iterations when
  # training is restored. It increments by one every time
  # `CoordinatedCheckpointManager.run` is called. Note that in this case, the
  # user must pass a single-step training function to
  # `CoordinatedCheckpointManager.run` instead of a multiple-step one.
  self._run_counter = self._checkpointed_runs.numpy()

  # The worker itself has received a preemption signal.
  self._received_own_sigterm = threading.Event()

  # Some member (could be oneself) has received a preemption signal, and the
  # step number to save a checkpoint has been aligned.
  self._received_sigterm_and_step = threading.Event()

  # TODO(wxinyi): Enforce that only one instance of this class is created
  # per program.
  # TODO(wxinyi): make the thread non-daemon.
  threading.Thread(target=self._wait_for_signal, daemon=True).start()
  signal.signal(signal.SIGTERM, self._sigterm_handler_fn)

def __init__(self, cluster_resolver, checkpoint, checkpoint_dir):
  """Creates the failure handler.

  Args:
    cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`. You
      may also get it through the `cluster_resolver` attribute of the strategy
      in use.
    checkpoint: a `tf.train.Checkpoint` that will be saved upon preemption and
      loaded upon restart by the `CoordinatedCheckpointManager` API
      automatically.
    checkpoint_dir: a directory for the `CoordinatedCheckpointManager` to play
      with checkpoints. `CoordinatedCheckpointManager` will create a
      `tf.train.CheckpointManager` to manage the passed-in `checkpoint`. Since
      only one `tf.train.CheckpointManager` should be active in a particular
      directory at a time, this `checkpoint_dir` arg should preferably be
      separated from where the user saves their checkpoint for non-fault
      tolerance purposes.
  """
  self._cluster_resolver = cluster_resolver
  self._checkpoint = checkpoint
  self._id_in_cluster = str(
      multi_worker_util.id_in_cluster(
          self._cluster_resolver.cluster_spec(),
          self._cluster_resolver.task_type,
          self._cluster_resolver.task_id))

  # The number of calls to `CoordinatedCheckpointManager.run` when the latest
  # checkpoint was saved.
  self._checkpointed_runs = variables.Variable(
      initial_value=constant_op.constant(0, dtype=dtypes.int64),
      trainable=False,
      name=_ITERATION_VARIABLE)
  if not hasattr(self._checkpoint, _ITERATION_VARIABLE):
    setattr(self._checkpoint, _ITERATION_VARIABLE, self._checkpointed_runs)

  # Make CheckpointManagers. MultiWorkerMirroredStrategy requires different
  # setup on chief and on other workers.
  self._read_checkpoint_manager = checkpoint_management.CheckpointManager(
      checkpoint, directory=checkpoint_dir, max_to_keep=1)
  if multi_worker_util.is_chief(
      cluster_spec=cluster_resolver.cluster_spec(),
      task_type=cluster_resolver.task_type,
      task_id=cluster_resolver.task_id):
    self._write_checkpoint_manager = self._read_checkpoint_manager
  else:
    self._write_checkpoint_manager = checkpoint_management.CheckpointManager(
        checkpoint,
        _mwms_write_checkpoint_dir(checkpoint_dir, cluster_resolver.task_type,
                                   cluster_resolver.task_id,
                                   cluster_resolver.cluster_spec()),
        max_to_keep=1)
  self._read_checkpoint_manager.restore_or_initialize()

  # An internal step counter that's restored to checkpointed_iterations when
  # training is restored. It increments by one every time
  # `CoordinatedCheckpointManager.run` is called. Note that in this case, the
  # user must pass a single-step training function to
  # `CoordinatedCheckpointManager.run` instead of a multiple-step one.
  self._run_counter = self._checkpointed_runs.numpy()

  # The worker itself has received a preemption signal.
  self._received_own_sigterm = threading.Event()

  # Some member (could be oneself) has received a preemption signal, and the
  # step number to save a checkpoint has been aligned.
  self._received_sigterm_and_step = threading.Event()

  # When training is interrupted, we explicitly call the cleanup methods for
  # the thread watching for the local worker's termination signal and the
  # thread watching for cluster-wise information before we save a checkpoint
  # and exit. In the final chapter of the training where no interruption is
  # encountered, we rely on __del__ to clean up. However, there is no
  # guarantee when or whether __del__ is executed; thus we make the threads
  # daemon to avoid them preventing the program from exiting.
  self._cluster_wise_termination_watcher_thread = threading.Thread(
      target=self._wait_for_signal,
      name='PeerTerminationWatcher-%s' % self._id_in_cluster,
      daemon=True)
  self._cluster_wise_termination_watcher_thread.start()

  self._poll_gce_signal_thread = None
  self._platform_device = gce_util.detect_platform()
  if self._platform_device is gce_util.PlatformDevice.GCE_GPU:
    self._start_polling_for_gce_signal()
    self._exit_code = gce_util._RESTARTABLE_EXIT_CODE
  elif self._platform_device is gce_util.PlatformDevice.INTERNAL:
    self._start_watching_for_signal()
    self._exit_code = _RESTARTABLE_EXIT_CODE
  else:
    raise NotImplementedError('CoordinatedCheckpointManager is only supported'
                              ' for MultiWorkerMirroredStrategy with GPU.')