Example #1
  def _ensure_threads_closed(self):
    """Ensures worker and preemption threads are closed."""
    # Worker and preemption threads should exist before releasing
    # ClusterCoordinator.
    running_threads = test_util.get_running_threads()
    self.assertTrue(
        test_util.has_thread(_WORKER_THREAD_PREFIX, running_threads))
    self.assertIn(_WORKER_PREEMPTION_THREAD_NAME, running_threads)

    # Print object graph if ClusterCoordinator may leak.
    if sys.getrefcount(self.cluster_coord) > 2:
      try:
        test_util.show_backref(self.cluster_coord)
      except:  # pylint: disable=bare-except
        pass

    # Wait for threads to close.
    self.cluster_coord = None
    self.strategy = None
    gc.collect()
    time.sleep(1)

    # Verify thread names.
    running_threads = test_util.get_running_threads()
    self.assertNotIn(_WORKER_PREEMPTION_THREAD_NAME, running_threads)
    self.assertFalse(
        test_util.has_thread(_WORKER_THREAD_PREFIX, running_threads),
        "Worker thread is not stopped properly.")
Example #2
  def worker_fn(self,
                checkpoint_dir,
                cluster_spec,
                maintenance_event=None,
                training_finished=None,
                frequent_send=False):

    _enable_coordination_service(cluster_spec)
    strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

    def mock_termination_watcher_function_gce(*args, **kwargs):
      del args, kwargs
      if not frequent_send:
        time.sleep(1)
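        # randrange(0, 20) > 18 holds only for 19, so each one-second poll
        # has a 1/20 chance of delivering the simulated termination notice.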
        if (not maintenance_event.is_set()) and (random.randrange(0, 20) > 18):
          maintenance_event.set()
          logging.info('Termination notice available.')
          return True

      elif frequent_send and not maintenance_event.is_set():
        logging.info('Termination notice available.')
        return True

      return False

    with mock.patch.object(
        gce_util, 'termination_watcher_function_gce',
        mock_termination_watcher_function_gce), mock.patch.object(
            gce_util, 'detect_platform',
            lambda: gce_util.PlatformDevice.GCE_GPU):

      class Model(module.Module):

        def __init__(self):
          self.v = variables_lib.Variable(
              0.,
              synchronization=variables_lib.VariableSynchronization.ON_WRITE,
              aggregation=variables_lib.VariableAggregation.SUM)

        @def_function.function(input_signature=[])
        def __call__(self):
          return self.v.read_value()

      with strategy.scope():
        model = Model()
        fh_ckpt = tracking_util.Checkpoint(model=model)

        worker_preemption_watcher = failure_handling.WorkerPreemptionHandler(
            strategy.cluster_resolver, fh_ckpt, checkpoint_dir)

      def distributed_train_step(current_epoch, current_step):

        @def_function.function
        def train_step():
          model.v.assign_add(constant_op.constant(1.))

        strategy.run(train_step)

        if current_step == STEPS_PER_EPOCH - 1:
          logging.info('epoch %d finished', current_epoch)

      logging.info('Start training at %d', worker_preemption_watcher.total_runs)
      for epoch in range(
          worker_preemption_watcher.total_runs // STEPS_PER_EPOCH,
          EPOCHS_TO_RUN):

        for step in range(
            worker_preemption_watcher.total_runs % STEPS_PER_EPOCH,
            STEPS_PER_EPOCH):
          worker_preemption_watcher.run(distributed_train_step, epoch, step)

      training_finished.set()

      self.assertEqual(
          model.v.numpy(),
          strategy.num_replicas_in_sync * EPOCHS_TO_RUN * STEPS_PER_EPOCH)

      running_threads = test_util.get_running_threads()
      if test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                              running_threads) and test_util.has_thread(
                                  _LOCAL_WATCHER_THREAD_PREFIX,
                                  running_threads):
        try:
          # Explicitly call __del__, since setting the handler to None and
          # calling gc.collect() does not invoke __del__ here.
          worker_preemption_watcher.__del__()

          time.sleep(2)

          running_threads = test_util.get_running_threads()
          self.assertFalse(
              test_util.has_thread(_LOCAL_WATCHER_THREAD_PREFIX,
                                   running_threads))
          self.assertFalse(
              test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                   running_threads))

        except urllib.error.URLError as e:
          if 'Temporary failure in name resolution' in str(e):
            # This is caused by flakiness in which mock.patch fails to
            # correctly patch gce_util.request_compute_metadata, so a real
            # request is attempted and fails inside
            # gce_util.request_compute_metadata.
            logging.warning('Hit a mock issue.')
            return
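In the real test suite a function like this runs in several worker processes (via TensorFlow's internal multi_process_runner), with maintenance_event and training_finished shared across them. A simplified stand-in using only the standard library, treating worker_fn as a free function and eliding cluster wiring (launch_workers and num_workers are hypothetical names):

import multiprocessing


def launch_workers(worker_fn, checkpoint_dir, cluster_spec, num_workers=2):
  # Manager-backed events are visible in every child process: the mocked GCE
  # watcher in one worker uses them to signal the simulated preemption, and
  # the parent uses them to observe that training eventually completed.
  manager = multiprocessing.Manager()
  maintenance_event = manager.Event()
  training_finished = manager.Event()

  workers = [
      multiprocessing.Process(
          target=worker_fn,
          args=(checkpoint_dir, cluster_spec),
          kwargs={
              'maintenance_event': maintenance_event,
              'training_finished': training_finished,
          }) for _ in range(num_workers)
  ]
  for worker in workers:
    worker.start()
  for worker in workers:
    worker.join()

  assert training_finished.is_set()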
Example #3
    def worker_fn(
        self,
        checkpoint_dir,
        cluster_spec,
        input_arg,
        maintenance_event=None,
        training_finished=None,
        frequent_send=False,
        training_restarted=None,
        termination_config=failure_handling.TerminationConfig(grace_period=0)):

        _enable_coordination_service(cluster_spec)
        strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

        def mock_termination_watcher_function_gce(*args, **kwargs):
            del args, kwargs
            if not frequent_send:
                time.sleep(1)
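                # randrange(0, 7) == 5 holds with probability 1/7 on each
                # one-second poll of the mocked termination watcher.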
                if (not maintenance_event.is_set()) and (random.randrange(
                        0, 7) == 5):
                    maintenance_event.set()
                    logging.info('Termination notice available.')
                    return True

            elif frequent_send and not maintenance_event.is_set():
                logging.info('Termination notice available.')
                return True

            return False

        with mock.patch.object(
                gce_util, 'termination_watcher_function_gce',
                mock_termination_watcher_function_gce), mock.patch.object(
                    gce_util, 'detect_platform',
                    lambda: gce_util.PlatformDevice.GCE_GPU):

            class Model(module.Module):
                def __init__(self):
                    self.v = variables_lib.Variable(
                        0.,
                        synchronization=(
                            variables_lib.VariableSynchronization.ON_WRITE),
                        aggregation=variables_lib.VariableAggregation.SUM)

                @def_function.function(input_signature=[])
                def __call__(self):
                    return self.v.read_value()

            with strategy.scope():
                model = Model()
                fh_ckpt = tracking_util.Checkpoint(model=model)

                if input_arg == 'checkpoint':
                    checkpoint_or_manager = fh_ckpt
                else:
                    checkpoint_or_manager = _make_checkpoint_manager(
                        fh_ckpt, checkpoint_dir, strategy.cluster_resolver)
                preemption_handler = (
                    failure_handling.PreemptionCheckpointHandler(
                        strategy.cluster_resolver, checkpoint_or_manager,
                        checkpoint_dir, termination_config))

            def distributed_train_step(current_epoch, current_step):
                @def_function.function
                def train_step():
                    model.v.assign_add(constant_op.constant(1.))

                strategy.run(train_step)

                if current_step == STEPS_PER_EPOCH - 1:
                    logging.info('epoch %d finished', current_epoch)

            logging.info('Start training at %d',
                         preemption_handler.total_run_calls)

            # If the training process has been restarted, verify that the
            # expected number of checkpoints has been written.
            # We also check training_finished, because there is a corner case
            # where the termination signal arrives so late that training
            # finishes before the grace period ends.
            if training_restarted.is_set() and not training_finished.is_set():
                match_group = [
                    re.search(r'.*ckpt-(\d+)\.index', a_file)
                    for a_file in gfile.ListDirectory(checkpoint_dir)
                ]
                checkpoint_index = [
                    a_match.group(1) for a_match in match_group if a_match
                ]
                if termination_config.grace_period > 0:
                    # Two checkpoints were saved for the extended grace period.
                    self.assertEqual(
                        max([
                            int(ckpt_index) for ckpt_index in checkpoint_index
                        ]), 2)
                else:
                    self.assertEqual(
                        max([
                            int(ckpt_index) for ckpt_index in checkpoint_index
                        ]), 1)

            for epoch in range(
                    preemption_handler.total_run_calls // STEPS_PER_EPOCH,
                    EPOCHS_TO_RUN):

                for step in range(
                        preemption_handler.total_run_calls % STEPS_PER_EPOCH,
                        STEPS_PER_EPOCH):
                    preemption_handler.run(distributed_train_step, epoch, step)

            logging.info('Training finished.')
            training_finished.set()

            self.assertEqual(
                model.v.numpy(), strategy.num_replicas_in_sync *
                EPOCHS_TO_RUN * STEPS_PER_EPOCH)

            running_threads = test_util.get_running_threads()
            if test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                    running_threads) and test_util.has_thread(
                                        _LOCAL_WATCHER_THREAD_PREFIX,
                                        running_threads):
                try:
                    # Explicitly call __del__, since setting the handler to
                    # None and calling gc.collect() does not invoke __del__
                    # here.
                    preemption_handler.__del__()

                    time.sleep(2)

                    running_threads = test_util.get_running_threads()
                    self.assertFalse(
                        test_util.has_thread(_LOCAL_WATCHER_THREAD_PREFIX,
                                             running_threads))
                    self.assertFalse(
                        test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                             running_threads))

                except urllib.error.URLError as e:
                    if 'Temporary failure in name resolution' in str(e):
                        # This is caused by flakiness in which mock.patch
                        # fails to correctly patch
                        # gce_util.request_compute_metadata, so a real request
                        # is attempted and fails inside
                        # gce_util.request_compute_metadata.
                        logging.warning('Hit a mock issue.')
                        return
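The resume arithmetic in the training loops (floor division for the starting epoch, modulo for the starting step) relies on total_run_calls being a live counter that advances with every run() call: the modulo therefore only skips steps in the first, partially finished epoch and evaluates to 0 for every later epoch. A self-contained sketch of that logic (the constants and remaining_steps are stand-ins):

STEPS_PER_EPOCH = 5
EPOCHS_TO_RUN = 3


def remaining_steps(total_run_calls):
  # Replays the double loop from worker_fn, advancing the counter the way
  # preemption_handler.run() would.
  executed = []
  for epoch in range(total_run_calls // STEPS_PER_EPOCH, EPOCHS_TO_RUN):
    for step in range(total_run_calls % STEPS_PER_EPOCH, STEPS_PER_EPOCH):
      executed.append((epoch, step))
      total_run_calls += 1
  return executed


# A fresh run executes every step exactly once.
assert len(remaining_steps(0)) == EPOCHS_TO_RUN * STEPS_PER_EPOCH

# A worker preempted after 7 steps resumes at epoch 1, step 2, and then runs
# the remaining 8 steps without skipping anything in epoch 2.
assert remaining_steps(7)[0] == (1, 2)
assert len(remaining_steps(7)) == EPOCHS_TO_RUN * STEPS_PER_EPOCH - 7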
Example #4
    def worker_fn(self,
                  checkpoint_dir,
                  cluster_spec,
                  maintenance_event,
                  training_finished,
                  frequent_send=False):

        _enable_coordination_service(cluster_spec)
        strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

        def mock_request_compute_metadata(*args, **kwargs):
            del kwargs  # Unused.
            if args[0] == 'instance/maintenance-event':
                if not frequent_send:
                    time.sleep(1)
                    if (not maintenance_event.is_set()) and (random.randrange(
                            0, 20) > 18):
                        maintenance_event.set()
                        logging.info('Maintenance notice available.')
                        return 'TERMINATE_ON_HOST_MAINTENANCE'
                elif frequent_send and not maintenance_event.is_set():
                    return 'TERMINATE_ON_HOST_MAINTENANCE'

            return 'NONE'

        with mock.patch.object(
                gce_util, 'request_compute_metadata',
                mock_request_compute_metadata), mock.patch.object(
                    gce_util, 'detect_platform',
                    lambda: gce_util.PlatformDevice.GCE_GPU):

            class Model(module.Module):
                def __init__(self):
                    self.v = variables_lib.Variable(
                        0.,
                        synchronization=(
                            variables_lib.VariableSynchronization.ON_WRITE),
                        aggregation=variables_lib.VariableAggregation.SUM)

                @def_function.function(input_signature=[])
                def __call__(self):
                    return self.v.read_value()

            with strategy.scope():
                model = Model()
                fh_ckpt = tracking_util.Checkpoint(model=model)

                failure_handler = failure_handling.CoordinatedCheckpointManager(
                    strategy.cluster_resolver, fh_ckpt, checkpoint_dir)

            def distributed_train_step(current_epoch, current_step):
                @def_function.function
                def train_step():
                    model.v.assign_add(constant_op.constant(1.))

                strategy.run(train_step)

                if current_step == STEPS_PER_EPOCH - 1:
                    logging.info('epoch %d finished', current_epoch)

            logging.info('Start training at %d', failure_handler.total_runs)
            for epoch in range(failure_handler.total_runs // STEPS_PER_EPOCH,
                               EPOCHS_TO_RUN):

                for step in range(failure_handler.total_runs % STEPS_PER_EPOCH,
                                  STEPS_PER_EPOCH):
                    failure_handler.run(distributed_train_step, epoch, step)

            self.assertEqual(
                model.v.numpy(), strategy.num_replicas_in_sync *
                EPOCHS_TO_RUN * STEPS_PER_EPOCH)

            training_finished.set()

            running_threads = test_util.get_running_threads()
            strategy.gather(constant_op.constant([10]), axis=0)
            self.assertTrue(
                test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                     running_threads))
            self.assertTrue(
                test_util.has_thread(_LOCAL_WATCHER_THREAD_PREFIX,
                                     running_threads))

            strategy.gather(constant_op.constant([10]), axis=0)

            # Explicitly call __del__, since setting the handler to None and
            # calling gc.collect() does not invoke __del__ here.
            failure_handler.__del__()

            time.sleep(2)

            running_threads = test_util.get_running_threads()
            self.assertFalse(
                test_util.has_thread(_LOCAL_WATCHER_THREAD_PREFIX,
                                     running_threads))
            self.assertFalse(
                test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                     running_threads))
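All of the variants above mock the platform-side check (either termination_watcher_function_gce or request_compute_metadata) that a background watcher thread polls. A minimal stand-alone model of such a polling thread, with illustrative names that are not the failure_handling internals:

import threading
import time


def start_termination_watcher(check_fn, on_termination, interval=1.0):
  # Daemon thread that polls the (possibly mocked) platform check until it
  # reports an upcoming termination, then fires the callback exactly once.
  def _poll():
    while not check_fn():
      time.sleep(interval)
    on_termination()

  thread = threading.Thread(
      target=_poll, name='TerminationWatcher', daemon=True)
  thread.start()
  return thread


# Usage with a fake check that signals on the third poll.
polls = {'count': 0}


def fake_check():
  polls['count'] += 1
  return polls['count'] >= 3


terminated = threading.Event()
start_termination_watcher(fake_check, terminated.set, interval=0.01)
assert terminated.wait(timeout=5)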