  def testMultipleChiefs(self):
    cluster_spec = {
        "chief": ["127.0.0.1:8258", "127.0.0.1:7566"],
    }
    with self.assertRaisesRegex(ValueError,
                                "There must be at most one 'chief' job."):
      multi_worker_util.id_in_cluster(cluster_spec, "chief", 0)
  def testWorkerId(self):
    cluster_spec = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
    }
    self.assertEqual(
        multi_worker_util.id_in_cluster(cluster_spec, "worker", 1), 2)

    cluster_spec = {
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
    }
    self.assertEqual(
        multi_worker_util.id_in_cluster(cluster_spec, "worker", 1), 1)
  def testEvaluatorId(self):
    cluster_spec = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "evaluator": ["127.0.0.1:7566"]
    }
    self.assertEqual(
        multi_worker_util.id_in_cluster(cluster_spec, "evaluator", 0), 0)
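Taken together, these tests pin down the id assignment: when a 'chief' task is present it is assumed to take id 0 and the workers are shifted up by one; the 'evaluator' forms its own single-task group and always gets id 0; task types that do not participate in training, such as 'ps', raise a ValueError (see the testPsId example further below). A minimal sketch of that mapping, reusing the cluster layout from the tests above:

cluster_spec = {
    "chief": ["127.0.0.1:1234"],
    "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
    "evaluator": ["127.0.0.1:7566"],
}
# Chief present: worker ids are offset by one, so worker 1 maps to id 2.
assert multi_worker_util.id_in_cluster(cluster_spec, "worker", 1) == 2
# The evaluator is independent of the chief/worker group, so its id is 0.
assert multi_worker_util.id_in_cluster(cluster_spec, "evaluator", 0) == 0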
def maybe_shard_dataset(dataset):
  """Shard the dataset if running in multi-node environment."""
  cluster_resolver = TFConfigClusterResolver()
  cluster_spec = cluster_resolver.cluster_spec().as_dict()
  if cluster_spec:
    dataset = dataset.shard(
        multi_worker_util.worker_count(cluster_spec,
                                       cluster_resolver.task_type),
        multi_worker_util.id_in_cluster(
            cluster_spec, cluster_resolver.task_type, cluster_resolver.task_id))
  return dataset
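A brief usage sketch for maybe_shard_dataset. It assumes TF_CONFIG is set for a multi-worker job so that TFConfigClusterResolver returns a non-empty cluster spec; outside such an environment the dataset is returned unchanged. The dataset and batch size below are purely illustrative:

import tensorflow as tf

dataset = tf.data.Dataset.range(1000)
# shard(num_workers, id_in_cluster) keeps every num_workers-th element starting
# at this worker's id, so the workers read disjoint slices of the data.
dataset = maybe_shard_dataset(dataset)
dataset = dataset.batch(32)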
 def _make_input_context(self):
   if self._cluster_spec is None:
     input_pipeline_id = 0
   else:
     input_pipeline_id = multi_worker_util.id_in_cluster(
         self._cluster_spec, self._task_type, self._task_id)
   input_context = distribute_lib.InputContext(
       num_input_pipelines=self._num_workers,
       input_pipeline_id=input_pipeline_id,
       num_replicas_in_sync=self._num_replicas_in_sync)
   return input_context
def batch_and_maybe_shard_dataset(dataset, global_batch_size):
  """Shard the dataset if running in multi-node environment."""

  cluster_resolver = TFConfigClusterResolver()
  cluster_spec = cluster_resolver.cluster_spec().as_dict()
  if cluster_spec:
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    num_workers = int(multi_worker_util.worker_count(cluster_spec, task_type))
    id_in_cluster = int(
        multi_worker_util.id_in_cluster(cluster_spec, task_type, task_id))
    dataset = dataset.shard(num_workers, id_in_cluster)
  return dataset.batch(global_batch_size)
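Usage mirrors the previous helper, again assuming TF_CONFIG describes the cluster. Because the shard is applied before the batch, each worker assembles full batches of global_batch_size elements drawn from its own shard; the numbers here are illustrative and continue the previous sketch:

dataset = tf.data.Dataset.range(10000)
per_worker_batches = batch_and_maybe_shard_dataset(dataset, global_batch_size=64)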
 def _make_input_fn_iterator(
     self,
     input_fn,
     replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
   """Distributes the dataset to each local GPU."""
   if self._cluster_spec is None:
     input_pipeline_id = 0
   else:
     input_pipeline_id = multi_worker_util.id_in_cluster(
         self._cluster_spec, self._task_type, self._task_id)
   input_context = distribute_lib.InputContext(
       num_input_pipelines=self._num_workers,
       input_pipeline_id=input_pipeline_id,
       num_replicas_in_sync=self._num_replicas_in_sync)
   return values.PerReplicaDataset(
       self._call_dataset_fn(input_fn, input_context), self._devices, True)
  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    """Distributes the dataset to each local GPU."""
    if self._cluster_spec is None:
      input_pipeline_id = 0
    else:
      input_pipeline_id = multi_worker_util.id_in_cluster(
          self._cluster_spec, self._task_type, self._task_id)
    input_context = distribute_lib.InputContext(
        num_input_pipelines=self._num_workers,
        input_pipeline_id=input_pipeline_id,
        num_replicas_in_sync=self._num_replicas_in_sync)

    return input_lib.InputFunctionIterator(
        input_fn, self._input_workers, [input_context])
    def _experimental_distribute_datasets_from_function(self, dataset_fn):
        if self._cluster_spec:
            input_pipeline_id = multi_worker_util.id_in_cluster(
                self._cluster_spec, self._task_type, self._task_id)
            num_input_pipelines = multi_worker_util.worker_count(
                self._cluster_spec, self._task_type)
        else:
            input_pipeline_id = 0
            num_input_pipelines = 1

        input_context = distribute_lib.InputContext(
            num_input_pipelines=num_input_pipelines,
            input_pipeline_id=input_pipeline_id,
            num_replicas_in_sync=self._num_replicas_in_sync)

        return input_lib.get_distributed_datasets_from_function(
            dataset_fn, self._input_workers, [input_context],
            self._container_strategy())
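The dataset_fn handed to the method above receives the InputContext built from id_in_cluster and worker_count. A hypothetical user-side dataset_fn, sketched to show how those fields are typically consumed (GLOBAL_BATCH_SIZE and the range dataset are placeholders, not part of the original code):

import tensorflow as tf

GLOBAL_BATCH_SIZE = 64

def dataset_fn(input_context):
  # Derive the per-replica batch size from the global one and read only this
  # worker's shard of the data.
  batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
  dataset = tf.data.Dataset.range(100000)
  dataset = dataset.shard(input_context.num_input_pipelines,
                          input_context.input_pipeline_id)
  return dataset.batch(batch_size)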
 def _make_input_fn_iterator(
         self,
         input_fn,
         replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
     """Distributes the dataset to each local GPU."""
     if self._cluster_spec:
         input_pipeline_id = multi_worker_util.id_in_cluster(
             self._cluster_spec, self._task_type, self._task_id)
         num_input_pipelines = multi_worker_util.worker_count(
             self._cluster_spec, self._task_type)
     else:
         input_pipeline_id = 0
         num_input_pipelines = 1
     input_context = distribute_lib.InputContext(
         num_input_pipelines=num_input_pipelines,
         input_pipeline_id=input_pipeline_id,
         num_replicas_in_sync=self._num_replicas_in_sync)
     return input_lib.InputFunctionIterator(input_fn, self._input_workers,
                                            [input_context])
 def _make_input_fn_iterator(
     self,
     input_fn,
     replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
   """Distributes the dataset to each local GPU."""
   if self._cluster_spec:
     input_pipeline_id = multi_worker_util.id_in_cluster(
         self._cluster_spec, self._task_type, self._task_id)
     num_input_pipelines = multi_worker_util.worker_count(
         self._cluster_spec, self._task_type)
   else:
     input_pipeline_id = 0
     num_input_pipelines = 1
   input_context = distribute_lib.InputContext(
       num_input_pipelines=num_input_pipelines,
       input_pipeline_id=input_pipeline_id,
       num_replicas_in_sync=self._num_replicas_in_sync)
   worker_device_pairs = [(self._worker_device, self._compute_devices)]
   return values.InputFunctionIterator(
       input_fn, worker_device_pairs, [input_context])
    def _initialize_multi_worker(self, cluster_resolver):
        """Initializes the object for multi-worker training."""
        cluster_spec = multi_worker_util.normalize_cluster_spec(
            cluster_resolver.cluster_spec())
        task_type = cluster_resolver.task_type
        task_id = cluster_resolver.task_id
        if task_type is None or task_id is None:
            raise ValueError(
                "When `cluster_spec` is given, you must also specify "
                "`task_type` and `task_id`.")
        self._cluster_spec = cluster_spec
        self._task_type = task_type
        self._task_id = task_id
        self._id_in_cluster = multi_worker_util.id_in_cluster(
            self._cluster_spec, self._task_type, self._task_id)

        self._num_workers = multi_worker_util.worker_count(
            cluster_spec, task_type)
        if not self._num_workers:
            raise ValueError(
                "No `worker`, `chief` or `evaluator` tasks can be found "
                "in `cluster_spec`.")

        self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                    task_id)

        self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
        self._host_input_device = numpy_dataset.SingleDevice(
            self._worker_device)

        if (ops.executing_eagerly_outside_functions() and
                not getattr(self, "_local_or_standalone_client_mode", False)):
            context.context().configure_collective_ops(
                collective_leader=multi_worker_util.collective_leader(
                    cluster_spec, task_type, task_id),
                scoped_allocator_enabled_ops=("CollectiveReduce", ),
                device_filters=("/job:%s/task:%d" % (task_type, task_id), ))
            self._collective_ops_configured = True

        # Starting a std server in eager mode and in independent worker mode.
        if (context.executing_eagerly()
                and not getattr(self, "_std_server_started", False) and
                not getattr(self, "_local_or_standalone_client_mode", False)):
            # Checking _local_or_standalone_client_mode as well because we should not
            # create the std server in standalone client mode.
            config_proto = copy.deepcopy(context.context().config)
            config_proto = self._update_config_proto(config_proto)

            if hasattr(cluster_resolver, "port"):
                port = cluster_resolver.port
            else:
                port = 0
            server_def = tensorflow_server_pb2.ServerDef(
                cluster=cluster_spec.as_cluster_def(),
                default_session_config=config_proto,
                job_name=task_type,
                task_index=task_id,
                protocol=cluster_resolver.rpc_layer or "grpc",
                port=port)
            context.context().enable_collective_ops(server_def)
            self._std_server_started = True
            # The `ensure_initialized` is needed before calling
            # `context.context().devices()`.
            context.context().ensure_initialized()
            logging.info(
                "Enabled multi-worker collective ops with available devices: %r",
                context.context().devices())

        # TODO(yuefengz): The `num_gpus` is only for this particular task. It
        # assumes all workers have the same number of GPUs. We should remove this
        # assumption by querying all tasks for their numbers of GPUs.
        # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
        # some cases.
        if isinstance(cluster_resolver, TFConfigClusterResolver):
            num_gpus = context.num_gpus()
        else:
            num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

        if num_gpus:
            local_devices = tuple("%s/device:GPU:%d" % (self._worker_device, i)
                                  for i in range(num_gpus))
        else:
            local_devices = (self._worker_device, )

        self._collective_keys = cross_device_utils.CollectiveKeys()
        self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
            devices=local_devices,
            group_size=len(local_devices) * self._num_workers,
            collective_keys=self._collective_keys,
            communication=self._communication)
        # CrossDeviceOps for per host tensors.
        self._host_cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
            devices=[self._worker_device],
            group_size=self._num_workers,
            collective_keys=self._collective_keys,
            communication=cross_device_ops_lib.CollectiveCommunication.RING,
        )
        super(CollectiveAllReduceExtended,
              self)._initialize_single_worker(local_devices)

        # Add a default device so that ops without specified devices will not end up
        # on other workers.
        self._default_device = "/job:%s/task:%d" % (task_type, task_id)

        # Save the num_gpus_per_worker and rpc_layer for configure method.
        self._num_gpus_per_worker = num_gpus
        self._rpc_layer = cluster_resolver.rpc_layer
        self._warn_nccl_no_gpu()

        # TODO(b/151232436): Enable check health thread by default.
        if self._enable_check_health:
            self._start_check_health_thread()

        logging.info(
            "MultiWorkerMirroredStrategy with cluster_spec = %r, task_type = %r, "
            "task_id = %r, num_workers = %r, local_devices = %r, "
            "communication = %s", cluster_spec.as_dict(), task_type, task_id,
            self._num_workers, local_devices, self._communication)
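For reference, the cluster_spec, task_type and task_id consumed above usually come from the TF_CONFIG environment variable when a TFConfigClusterResolver is in use. A minimal two-worker example of that variable (addresses and ports are placeholders):

import json
import os

os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["127.0.0.1:12345", "127.0.0.1:23456"],
    },
    "task": {"type": "worker", "index": 0},  # this process is worker 0
})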
    def __init__(self,
                 cluster_resolver,
                 checkpoint,
                 checkpoint_dir,
                 termination_config=TerminationConfig()):
        """Creates the failure handler.

    Args:
      cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`. You
        may also get it through the `cluster_resolver` attribute of the strategy
        in use.
      checkpoint: a `tf.train.Checkpoint` that will be saved upon preemption and
        loaded upon restart by the `WorkerPreemptionHandler` API automatically.
      checkpoint_dir: a directory for the `WorkerPreemptionHandler` to play with
        checkpoints. `WorkerPreemptionHandler` will create a
        `tf.train.CheckpointManager` to manage the passed-in `checkpoint`. Since
        only one `tf.train.CheckpointManager` should be active in a particular
        directory at a time, this `checkpoint_dir` arg should preferably be
        separated from where the user saves their checkpoint for non-fault
        tolerance purpose.
      termination_config: a `TerminationConfig` object to configure for a
        platform other than Google Borg or GCP.
    """
        self._cluster_resolver = cluster_resolver
        self._checkpoint = checkpoint
        self._id_in_cluster = str(
            multi_worker_util.id_in_cluster(
                self._cluster_resolver.cluster_spec(),
                self._cluster_resolver.task_type,
                self._cluster_resolver.task_id))

        # The number of calls to `WorkerPreemptionHandler.run` when the latest
        # checkpoint was saved.
        self._checkpointed_runs = variables.Variable(
            initial_value=constant_op.constant(0, dtype=dtypes.int64),
            trainable=False,
            name=_ITERATION_VARIABLE)
        if not hasattr(self._checkpoint, _ITERATION_VARIABLE):
            setattr(self._checkpoint, _ITERATION_VARIABLE,
                    self._checkpointed_runs)

        # Make CheckpointManagers. MultiWorkerMirroredStrategy requires different
        # setup on chief and on other workers.
        self._read_checkpoint_manager = checkpoint_management.CheckpointManager(
            checkpoint, directory=checkpoint_dir, max_to_keep=1)
        if multi_worker_util.is_chief(
                cluster_spec=cluster_resolver.cluster_spec(),
                task_type=cluster_resolver.task_type,
                task_id=cluster_resolver.task_id):
            self._write_checkpoint_manager = self._read_checkpoint_manager
        else:
            self._write_checkpoint_manager = checkpoint_management.CheckpointManager(
                checkpoint,
                _mwms_write_checkpoint_dir(checkpoint_dir,
                                           cluster_resolver.task_type,
                                           cluster_resolver.task_id,
                                           cluster_resolver.cluster_spec()),
                max_to_keep=1)

        self._read_checkpoint_manager.restore_or_initialize()

        # Grace-period countdown. Set to True on every worker once it has timed
        # and saved a checkpoint. Once this phase is entered, new
        # preemption/maintenance notices are not handled, since the whole cluster
        # goes down when the worker that first initiated the grace period goes down.
        self._final_checkpoint_countdown = False

        self._estimated_run_time = 0

        # An internal step counter that's restored to checkpointed_iterations when
        # training is restored. It increments by one every time
        # `WorkerPreemptionHandler.run` is called. Note that in this case, the
        # user must pass a single-step training function to
        # `WorkerPreemptionHandler.run` instead of a multiple-step one.
        self._run_counter = self._checkpointed_runs.numpy()

        # The worker itself has received a preemption signal.
        self._received_own_sigterm = threading.Event()

        # Some member (could be oneself) has received preemption signal, and the
        # step number to save a checkpoint has been aligned.
        self._received_checkpoint_step = threading.Event()

        self._platform_device = gce_util.detect_platform()

        completed_termination_config = _complete_config_for_environement(
            self._platform_device, termination_config)
        self._termination_watcher_function = completed_termination_config.termination_watcher_function
        self._exit_fn = completed_termination_config.exit_fn
        self._grace_period = completed_termination_config.time_till_termination

        # When training is interrupted, we explicitly call the cleanup methods
        # for the thread watching for the local worker's termination signal and
        # the thread watching for cluster-wise information before we save a
        # checkpoint and exit. At the end of training, when no interruption is
        # encountered, we rely on __del__ to clean up. However, there is no
        # guarantee when or whether __del__ is executed, so we make the threads
        # daemons to keep them from preventing the program from exiting.
        self._cluster_wise_termination_watcher_thread = threading.Thread(
            target=self._watch_step_to_save_key,
            name='PeerTerminationWatcher-%s' % self._id_in_cluster,
            daemon=True)
        logging.info('Start watcher for peer\'s signal.')
        self._cluster_wise_termination_watcher_thread.start()

        self._poll_termination_signal_thread = None

        if completed_termination_config.termination_watcher_function:
            self._start_polling_for_termination_signal()
        else:
            self._start_watching_for_signal()
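A hedged usage sketch for the handler whose constructor appears above, following its docstring and the comment that the function passed to WorkerPreemptionHandler.run must perform a single training step; the model builder, step function and loop bound are placeholders:

import tensorflow as tf

strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
  model = make_model()  # hypothetical model-building helper
  checkpoint = tf.train.Checkpoint(model=model)

handler = WorkerPreemptionHandler(
    strategy.cluster_resolver, checkpoint, checkpoint_dir="/tmp/backup")

for _ in range(total_steps):  # total_steps is a placeholder
  handler.run(single_train_step)  # hypothetical single-step training function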
  def testPsId(self):
    cluster_spec = {"chief": ["127.0.0.1:1234"], "ps": ["127.0.0.1:7566"]}
    with self.assertRaisesRegex(ValueError,
                                "There is no id for task_type 'ps'"):
      multi_worker_util.id_in_cluster(cluster_spec, "ps", 0)
    def __init__(self,
                 cluster_resolver,
                 checkpoint_or_checkpoint_manager,
                 checkpoint_dir=None,
                 termination_config=None):
        """Creates the `PreemptionCheckpointHandler`.

    Args:
      cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`
        object. You may also obtain it through the `cluster_resolver` attribute
        of the distribution strategy in use.
      checkpoint_or_checkpoint_manager: a `tf.train.CheckpointManager` or a
        `tf.train.Checkpoint`. If you are using a `tf.train.CheckpointManager`
        to manage checkpoints outside the `PreemptionCheckpointHandler` for
        backup purpose as well, pass it as `checkpoint_or_checkpoint_manager`
        argument. Otherwise, pass a `tf.train.Checkpoint` and the
        `PreemptionCheckpointHandler` will create
        a `tf.train.CheckpointManager` to manage it in the `checkpoint_dir`.
      checkpoint_dir: a directory where the `PreemptionCheckpointHandler` saves
        and restores checkpoints. When a `PreemptionCheckpointHandler` is
        created, the latest checkpoint in the `checkpoint_dir` will be restored.
        (This is not needed if a `tf.train.CheckpointManager` instead of a
        `tf.train.Checkpoint` is passed as the
        `checkpoint_or_checkpoint_manager` argument.)
      termination_config: optional, a
        `tf.distribute.experimental.TerminationConfig` object to configure for a
        platform other than Google Borg or GCP.
    """
        self._cluster_resolver = cluster_resolver
        if isinstance(checkpoint_or_checkpoint_manager,
                      checkpoint_lib.Checkpoint) and not checkpoint_dir:
            raise errors.InvalidArgumentError(
                'When a checkpoint is passed, a '
                'checkpoint_dir must be passed as well'
                '.')
        self._id_in_cluster = str(
            multi_worker_util.id_in_cluster(
                self._cluster_resolver.cluster_spec(),
                self._cluster_resolver.task_type,
                self._cluster_resolver.task_id))

        # The number of calls to `PreemptionCheckpointHandler.run` when the latest
        # checkpoint was saved.
        self._checkpointed_runs = variables.Variable(
            initial_value=constant_op.constant(0, dtype=dtypes.int64),
            trainable=False,
            name=_ITERATION_VARIABLE)

        self._maybe_create_checkpoint_manager(checkpoint_or_checkpoint_manager,
                                              checkpoint_dir, cluster_resolver)

        if not hasattr(self._write_checkpoint_manager._checkpoint,
                       _ITERATION_VARIABLE):
            setattr(self._write_checkpoint_manager._checkpoint,
                    _ITERATION_VARIABLE, self._checkpointed_runs)

        if not hasattr(self._read_checkpoint_manager._checkpoint,
                       _ITERATION_VARIABLE):
            setattr(self._read_checkpoint_manager._checkpoint,
                    _ITERATION_VARIABLE, self._checkpointed_runs)

        self._read_checkpoint_manager.restore_or_initialize()

        # Grace-period countdown. Set to True on every worker once it has timed
        # and saved a checkpoint. Once this phase is entered, new
        # preemption/maintenance notices are not handled, since the whole cluster
        # goes down when the worker that first initiated the grace period goes down.
        self._final_checkpoint_countdown = False

        self._estimated_run_time = 0

        # An internal step counter that's restored to checkpointed_iterations when
        # training is restored. It increments by one every time
        # `PreemptionCheckpointHandler.run` is called. Note that in this case, the
        # user must pass a single-step training function to
        # `PreemptionCheckpointHandler.run` instead of a multiple-step one.
        self._run_counter = self._checkpointed_runs.numpy()

        # The worker itself has received a preemption signal.
        self._received_own_sigterm = threading.Event()

        # Some member (could be oneself) has received preemption signal, and the
        # step number to save a checkpoint has been aligned.
        self._received_checkpoint_step = threading.Event()

        self._platform_device = gce_util.detect_platform()

        if self._platform_device in (gce_util.PlatformDevice.GCE_TPU,
                                     gce_util.PlatformDevice.GCE_CPU):
            # While MultiWorkerMirroredStrategy training on GPUs and on CPUs is
            # the same on Borg, GCE CPU VMs and GPU VMs differ in terms of live
            # migration, grace period, etc. We can make it work upon request.
            raise NotImplementedError(
                'PreemptionCheckpointHandler does not support '
                'training with TPU or CPU device on GCP.')

        completed_termination_config = _complete_config_for_environment(
            self._platform_device, termination_config)
        self._termination_watcher_fn = completed_termination_config.termination_watcher_fn
        self._exit_fn = completed_termination_config.exit_fn
        self._grace_period = completed_termination_config.grace_period

        # When training is interrupted, we explicitly call the cleanup methods
        # for the thread watching for the local worker's termination signal and
        # the thread watching for cluster-wise information before we save a
        # checkpoint and exit. At the end of training, when no interruption is
        # encountered, we rely on __del__ to clean up. However, there is no
        # guarantee when or whether __del__ is executed, so we make the threads
        # daemons to keep them from preventing the program from exiting.
        self._cluster_wise_termination_watcher_thread = threading.Thread(
            target=self._watch_step_to_save_key,
            name='PeerTerminationWatcher-%s' % self._id_in_cluster,
            daemon=True)
        logging.info('Start watcher for peer\'s signal.')
        self._cluster_wise_termination_watcher_thread.start()

        self._poll_termination_signal_thread = None

        if completed_termination_config.termination_watcher_fn:
            self._start_polling_for_termination_signal()
        else:
            self._start_watching_for_signal()
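This later variant is the one exposed publicly as tf.distribute.experimental.PreemptionCheckpointHandler. Per the docstring, a tf.train.CheckpointManager can be passed directly, in which case no checkpoint_dir is needed; a sketch under that assumption (the model builder, training step and loop bound are placeholders):

import tensorflow as tf

strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
  model = make_model()  # hypothetical
  checkpoint = tf.train.Checkpoint(model=model)

manager = tf.train.CheckpointManager(checkpoint, directory="/tmp/ckpt",
                                     max_to_keep=1)
handler = tf.distribute.experimental.PreemptionCheckpointHandler(
    strategy.cluster_resolver, manager)

for _ in range(total_steps):  # placeholder loop bound
  handler.run(single_train_step)  # single-step training function, per the comments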
    def __init__(self, cluster_resolver, checkpoint, checkpoint_dir):
        """Creates the failure handler.

    Args:
      cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`. You
        may also get it through the `cluster_resolver` attribute of the
        strategy in use.
      checkpoint: a `tf.train.Checkpoint` that will be saved upon preemption and
        loaded upon restart by the `CoordinatedCheckpointManager` API
        automatically.
      checkpoint_dir: a directory for the `CoordinatedCheckpointManager` to play
        with checkpoints. `CoordinatedCheckpointManager` will create a
        `tf.train.CheckpointManager` to manage the passed-in `checkpoint`. Since
        only one `tf.train.CheckpointManager` should be active in a particular
        directory at a time, this `checkpoint_dir` arg should preferably be
        separated from where the user saves their checkpoint for non-fault
        tolerance purpose.
    """
        self._cluster_resolver = cluster_resolver
        self._checkpoint = checkpoint
        self._id_in_cluster = str(
            multi_worker_util.id_in_cluster(
                self._cluster_resolver.cluster_spec(),
                self._cluster_resolver.task_type,
                self._cluster_resolver.task_id))

        # The number of calls to `CoordinatedCheckpointManager.run` when the latest
        # checkpoint was saved.
        self._checkpointed_runs = variables.Variable(
            initial_value=constant_op.constant(0, dtype=dtypes.int64),
            trainable=False,
            name=_ITERATION_VARIABLE)
        if not hasattr(self._checkpoint, _ITERATION_VARIABLE):
            setattr(self._checkpoint, _ITERATION_VARIABLE,
                    self._checkpointed_runs)

        # Make CheckpointManagers. MultiWorkerMirroredStrategy requires different
        # setup on chief and on other workers.
        self._read_checkpoint_manager = checkpoint_management.CheckpointManager(
            checkpoint, directory=checkpoint_dir, max_to_keep=1)
        if multi_worker_util.is_chief(
                cluster_spec=cluster_resolver.cluster_spec(),
                task_type=cluster_resolver.task_type,
                task_id=cluster_resolver.task_id):
            self._write_checkpoint_manager = self._read_checkpoint_manager
        else:
            self._write_checkpoint_manager = checkpoint_management.CheckpointManager(
                checkpoint,
                _mwms_write_checkpoint_dir(checkpoint_dir,
                                           cluster_resolver.task_type,
                                           cluster_resolver.task_id,
                                           cluster_resolver.cluster_spec()),
                max_to_keep=1)

        self._read_checkpoint_manager.restore_or_initialize()

        # An internal step counter that's restored to checkpointed_iterations when
        # training is restored. It increments by one every time
        # `CoordinatedCheckpointManager.run` is called. Note that in this case, the
        # user must pass a single-step training function to
        # `CoordinatedCheckpointManager.run` instead of a multiple-step one.
        self._run_counter = self._checkpointed_runs.numpy()

        # The worker itself has received a preemption signal.
        self._received_own_sigterm = threading.Event()

        # Some member (could be oneself) has received preemption signal, and the
        # step number to save a checkpoint has been aligned.
        self._received_sigterm_and_step = threading.Event()

        # TODO(wxinyi): Enforce that only one instance of this class is created
        # per program.
        # TODO(wxinyi): make the thread non-daemon.
        threading.Thread(target=self._wait_for_signal, daemon=True).start()
        signal.signal(signal.SIGTERM, self._sigterm_handler_fn)
    def __init__(self, cluster_resolver, checkpoint, checkpoint_dir):
        """Creates the failure handler.

    Args:
      cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`. You
        may also get it through the `cluster_resolver` attribute of the
        strategy in use.
      checkpoint: a `tf.train.Checkpoint` that will be saved upon preemption and
        loaded upon restart by the `CoordinatedCheckpointManager` API
        automatically.
      checkpoint_dir: a directory for the `CoordinatedCheckpointManager` to play
        with checkpoints. `CoordinatedCheckpointManager` will create a
        `tf.train.CheckpointManager` to manage the passed-in `checkpoint`. Since
        only one `tf.train.CheckpointManager` should be active in a particular
        directory at a time, this `checkpoint_dir` arg should preferably be
        separated from where the user saves their checkpoint for non-fault
        tolerance purpose.
    """
        self._cluster_resolver = cluster_resolver
        self._checkpoint = checkpoint
        self._id_in_cluster = str(
            multi_worker_util.id_in_cluster(
                self._cluster_resolver.cluster_spec(),
                self._cluster_resolver.task_type,
                self._cluster_resolver.task_id))

        # The number of calls to `CoordinatedCheckpointManager.run` when the latest
        # checkpoint was saved.
        self._checkpointed_runs = variables.Variable(
            initial_value=constant_op.constant(0, dtype=dtypes.int64),
            trainable=False,
            name=_ITERATION_VARIABLE)
        if not hasattr(self._checkpoint, _ITERATION_VARIABLE):
            setattr(self._checkpoint, _ITERATION_VARIABLE,
                    self._checkpointed_runs)

        # Make CheckpointManagers. MultiWorkerMirroredStrategy requires different
        # setup on chief and on other workers.
        self._read_checkpoint_manager = checkpoint_management.CheckpointManager(
            checkpoint, directory=checkpoint_dir, max_to_keep=1)
        if multi_worker_util.is_chief(
                cluster_spec=cluster_resolver.cluster_spec(),
                task_type=cluster_resolver.task_type,
                task_id=cluster_resolver.task_id):
            self._write_checkpoint_manager = self._read_checkpoint_manager
        else:
            self._write_checkpoint_manager = checkpoint_management.CheckpointManager(
                checkpoint,
                _mwms_write_checkpoint_dir(checkpoint_dir,
                                           cluster_resolver.task_type,
                                           cluster_resolver.task_id,
                                           cluster_resolver.cluster_spec()),
                max_to_keep=1)

        self._read_checkpoint_manager.restore_or_initialize()

        # An internal step counter that's restored to checkpointed_iterations when
        # training is restored. It increments by one every time
        # `CoordinatedCheckpointManager.run` is called. Note that in this case, the
        # user must pass a single-step training function to
        # `CoordinatedCheckpointManager.run` instead of a multiple-step one.
        self._run_counter = self._checkpointed_runs.numpy()

        # The worker itself has received a preemption signal.
        self._received_own_sigterm = threading.Event()

        # Some member (could be oneself) has received preemption signal, and the
        # step number to save a checkpoint has been aligned.
        self._received_sigterm_and_step = threading.Event()

        # When training is interrupted, we explicitly call the cleanup methods
        # for the thread watching for the local worker's termination signal and
        # the thread watching for cluster-wise information before we save a
        # checkpoint and exit. At the end of training, when no interruption is
        # encountered, we rely on __del__ to clean up. However, there is no
        # guarantee when or whether __del__ is executed, so we make the threads
        # daemons to keep them from preventing the program from exiting.
        self._cluster_wise_termination_watcher_thread = threading.Thread(
            target=self._wait_for_signal,
            name='PeerTerminationWatcher-%s' % self._id_in_cluster,
            daemon=True)
        self._cluster_wise_termination_watcher_thread.start()

        self._poll_gce_signal_thread = None
        self._platform_device = gce_util.detect_platform()
        if self._platform_device is gce_util.PlatformDevice.GCE_GPU:
            self._start_polling_for_gce_signal()
            self._exit_code = gce_util._RESTARTABLE_EXIT_CODE
        elif self._platform_device is gce_util.PlatformDevice.INTERNAL:
            self._start_watching_for_signal()
            self._exit_code = _RESTARTABLE_EXIT_CODE
        else:
            raise NotImplementedError(
                'CoordinatedCheckpointManager is only supported'
                ' for MultiWorkerMirroredStrategy with GPU.')