    def test_multiple_workers_preempted_consecutively(self, grace_period,
                                                      input_arg):
        has_chief = False
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            has_chief=has_chief, num_workers=CLUSTER_SIZE)
        maintenance_event = multi_process_runner.manager().Event()
        training_finished = multi_process_runner.manager().Event()
        training_restarted = multi_process_runner.manager().Event()

        checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt/')

        if _is_oss():
            rpc_layer = 'grpc'
        else:
            rpc_layer = 'grpc+loas'

        termination_config = failure_handling.TerminationConfig(
            grace_period=grace_period)
        mpr = multi_process_runner.MultiProcessRunner(
            self.worker_fn,
            cluster_spec,
            args=(checkpoint_dir, cluster_spec, input_arg, maintenance_event,
                  training_finished, True, training_restarted,
                  termination_config),
            rpc_layer=rpc_layer,
            return_output=True,
            dependence_on_chief=has_chief)

        logging.info('Cluster starting.')
        mpr.start()

        # Wait for all workers to exit, with a timeout.
        waiting_time = 0
        exit_process_count = 0
        # The generous timeout mitigates the fact that the per-step time in
        # this test is very short.
        while exit_process_count != CLUSTER_SIZE and waiting_time < max(
                grace_period + 15, 40):
            exit_process_count = 0
            for worker_id in range(CLUSTER_SIZE):
                if not mpr.process_exists('worker', worker_id):
                    exit_process_count += 1
            waiting_time += 1
            time.sleep(1)

        if waiting_time == max(grace_period + 15, 40):
            raise RuntimeError(
                'Waited for a long time, but at least one worker still '
                'exists. Considering the size of our model, this should '
                'not happen.')

        maintenance_event.set()
        logging.info('restarting workers')
        training_restarted.set()
        for worker_id in range(CLUSTER_SIZE):
            mpr.start_single_process('worker', worker_id, cluster_spec)
        logging.info('workers restarted')

        mpr.join(timeout=250)
        self.assertTrue(training_finished.is_set())
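
The examples pick the RPC layer with _is_oss(), which is defined outside this listing ('grpc' for open-source builds, 'grpc+loas' otherwise). A minimal sketch of such a helper, assuming the build type is inferred from the test binary's argv (the exact check is an assumption, not part of the listing):

import sys

def _is_oss():
    """Returns whether the test is running in an open-source environment."""
    return len(sys.argv) >= 1 and 'bazel' not in sys.argv[0]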
Example 2
    def test_grace_period_continue_training(self, input_arg):
        grace_period = 5
        has_chief = False
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            has_chief=has_chief, num_workers=CLUSTER_SIZE)
        training_started_event = multi_process_runner.manager().Event()
        training_restarted = multi_process_runner.manager().Event()
        training_finished = multi_process_runner.manager().Event()
        checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

        if _is_oss():
            rpc_layer = 'grpc'
        else:
            rpc_layer = 'grpc+loas'

        termination_config = failure_handling.TerminationConfig(
            grace_period=grace_period)
        mpr = multi_process_runner.MultiProcessRunner(
            self.worker_fn,
            cluster_spec,
            args=(checkpoint_dir, cluster_spec, input_arg,
                  [training_started_event], None, training_restarted,
                  training_finished, termination_config),
            rpc_layer=rpc_layer,
            return_output=True,
            dependence_on_chief=has_chief)

        logging.info('Cluster starting.')
        mpr.start()
        while not training_started_event.is_set():
            time.sleep(1)

        killed_worker = random.randrange(0, CLUSTER_SIZE)
        logging.info('sending SIGTERM')
        os.kill(mpr.get_process_id('worker', killed_worker), signal.SIGTERM)
        logging.info('SIGTERM sent')

        # Wait for all workers to exit within the grace period (plus a buffer,
        # since the per-step time here is very small).
        waiting_time = 0
        exit_process_count = 0
        while exit_process_count != CLUSTER_SIZE and waiting_time < grace_period + 10:
            exit_process_count = 0
            for worker_id in range(CLUSTER_SIZE):
                if not mpr.process_exists('worker', worker_id):
                    exit_process_count += 1
            waiting_time += 1
            time.sleep(1)

        if waiting_time == grace_period + 10:
            raise RuntimeError(
                'Waited longer than the grace period, but at least one '
                'worker still exists.')

        logging.info('restarting workers')
        training_restarted.set()
        for worker_id in range(CLUSTER_SIZE):
            mpr.start_single_process('worker', worker_id, cluster_spec)
        logging.info('workers restarted')

        mpr.join(timeout=250)
    def test_grace_period_continue_training(self, input_arg):
        grace_period = 7
        has_chief = False
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            has_chief=has_chief, num_workers=CLUSTER_SIZE)

        checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt/')
        maintenance_event = multi_process_runner.manager().Event()
        training_finished = multi_process_runner.manager().Event()
        training_restarted = multi_process_runner.manager().Event()

        if _is_oss():
            rpc_layer = 'grpc'
        else:
            rpc_layer = 'grpc+loas'

        termination_config = failure_handling.TerminationConfig(
            grace_period=grace_period)
        mpr = multi_process_runner.MultiProcessRunner(
            self.worker_fn,
            cluster_spec,
            args=(checkpoint_dir, cluster_spec, input_arg, maintenance_event,
                  training_finished, False, training_restarted,
                  termination_config),
            rpc_layer=rpc_layer,
            return_output=True,
            dependence_on_chief=has_chief)

        logging.info('Cluster starting.')
        mpr.start()

        while (not maintenance_event.is_set()) and (
                not training_finished.is_set()):
            time.sleep(1)

        # Sleep past the grace period (plus a buffer) to mitigate the fact
        # that the per-step time in this test is very short.
        time.sleep(grace_period + 10)
        if not training_finished.is_set():
            logging.info('restarting workers')
            training_restarted.set()
            for worker_id in range(CLUSTER_SIZE):
                mpr.start_single_process('worker', worker_id, cluster_spec)
            logging.info('workers restarted')

        mpr.join(timeout=250)

        self.assertTrue(training_finished.is_set())
Example 4
  def test_grace_period_continue_training(self, input_arg):
    grace_period = 5
    has_chief = False
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=has_chief,
        num_workers=CLUSTER_SIZE)
    training_started_event = multi_process_runner.manager().Event()
    training_restarted = multi_process_runner.manager().Event()
    training_finished = multi_process_runner.manager().Event()
    checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

    if _is_oss():
      rpc_layer = 'grpc'
    else:
      rpc_layer = 'grpc+loas'

    termination_config = failure_handling.TerminationConfig(
        grace_period=grace_period)
    mpr = multi_process_runner.MultiProcessRunner(
        self.worker_fn,
        cluster_spec,
        args=(checkpoint_dir, cluster_spec, input_arg, [training_started_event],
              None, training_restarted, training_finished, termination_config),
        rpc_layer=rpc_layer,
        return_output=True,
        dependence_on_chief=has_chief)

    logging.info('Cluster starting.')
    mpr.start()
    while not training_started_event.is_set():
      time.sleep(1)

    killed_worker = random.randrange(0, CLUSTER_SIZE)
    logging.info('sending SIGTERM')
    os.kill(mpr.get_process_id('worker', killed_worker), signal.SIGTERM)
    logging.info('SIGTERM sent')

    raise_if_not_all_exit(grace_period, mpr)

    logging.info('restarting workers')
    training_restarted.set()
    for worker_id in range(CLUSTER_SIZE):
      mpr.start_single_process('worker', worker_id, cluster_spec)
    logging.info('workers restarted')

    mpr.join(timeout=250)
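
Examples from this point on call raise_if_not_all_exit instead of spelling out the wait loop; the helper itself is not included in the listing. A sketch of what it might look like, reconstructed from the inline loops in the earlier examples and reusing the module-level CLUSTER_SIZE constant and time import (the exact signature and timeout policy are assumptions):

def raise_if_not_all_exit(grace_period, mpr):
  """Waits for every worker process to exit within a timeout, else raises."""
  waiting_time = 0
  exit_process_count = 0
  # The generous timeout mitigates the very short per-step time in these tests.
  while exit_process_count != CLUSTER_SIZE and waiting_time < max(
      grace_period + 15, 40):
    exit_process_count = 0
    for worker_id in range(CLUSTER_SIZE):
      if not mpr.process_exists('worker', worker_id):
        exit_process_count += 1
    waiting_time += 1
    time.sleep(1)

  if waiting_time == max(grace_period + 15, 40):
    raise RuntimeError('Waited for a long time, but at least one worker still '
                       'exists. Considering the size of our model, this '
                       'should not happen.')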
Example 5
    def test_multiple_workers_preempted_consecutively(self, grace_period,
                                                      input_arg):
        has_chief = False
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            has_chief=has_chief, num_workers=CLUSTER_SIZE)
        maintenance_event = multi_process_runner.manager().Event()
        training_finished = multi_process_runner.manager().Event()
        training_restarted = multi_process_runner.manager().Event()

        checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt/')

        if _is_oss():
            rpc_layer = 'grpc'
        else:
            rpc_layer = 'grpc+loas'

        termination_config = failure_handling.TerminationConfig(
            grace_period=grace_period)
        mpr = multi_process_runner.MultiProcessRunner(
            self.worker_fn,
            cluster_spec,
            args=(checkpoint_dir, cluster_spec, input_arg, maintenance_event,
                  training_finished, True, training_restarted,
                  termination_config),
            rpc_layer=rpc_layer,
            return_output=True,
            dependence_on_chief=has_chief)

        logging.info('Cluster starting.')
        mpr.start()

        raise_if_not_all_exit(grace_period, mpr)

        maintenance_event.set()
        logging.info('restarting workers')
        training_restarted.set()
        for worker_id in range(CLUSTER_SIZE):
            mpr.start_single_process('worker', worker_id, cluster_spec)
        logging.info('workers restarted')

        mpr.join(timeout=250)
        self.assertTrue(training_finished.is_set())
    def worker_fn(
        self,
        checkpoint_dir,
        cluster_spec,
        input_arg,
        maintenance_event=None,
        training_finished=None,
        frequent_send=False,
        training_restarted=None,
        termination_config=failure_handling.TerminationConfig(grace_period=0)):

        _enable_coordination_service(cluster_spec)
        strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

        def mock_termination_watcher_function_gce(*args, **kwargs):
            del args, kwargs
            if not frequent_send:
                time.sleep(1)
                if (not maintenance_event.is_set()) and (random.randrange(
                        0, 7) == 5):
                    maintenance_event.set()
                    logging.info('Termination notice available.')
                    return True

            elif frequent_send and not maintenance_event.is_set():
                logging.info('Termination notice available.')
                return True

            return False

        with mock.patch.object(
                gce_util, 'termination_watcher_function_gce',
                mock_termination_watcher_function_gce), mock.patch.object(
                    gce_util, 'detect_platform',
                    lambda: gce_util.PlatformDevice.GCE_GPU):

            class Model(module.Module):
                def __init__(self):
                    self.v = variables_lib.Variable(
                        0.,
                        synchronization=variables_lib.VariableSynchronization.
                        ON_WRITE,
                        aggregation=variables_lib.VariableAggregation.SUM)

                @def_function.function(input_signature=[])
                def __call__(self):
                    return self.v.read_value()

            with strategy.scope():
                model = Model()
                fh_ckpt = tracking_util.Checkpoint(model=model)

                if input_arg == 'checkpoint':
                    checkpoint_or_manager = fh_ckpt
                else:
                    checkpoint_or_manager = _make_checkpoint_manager(
                        fh_ckpt, checkpoint_dir, strategy.cluster_resolver)
                preemption_handler = (
                    failure_handling.PreemptionCheckpointHandler(
                        strategy.cluster_resolver, checkpoint_or_manager,
                        checkpoint_dir, termination_config))

            def distributed_train_step(current_epoch, current_step):
                @def_function.function
                def train_step():
                    model.v.assign_add(constant_op.constant(1.))

                strategy.run(train_step)

                if current_step == STEPS_PER_EPOCH - 1:
                    logging.info('epoch %d finished', current_epoch)

            logging.info('Start training at %d',
                         preemption_handler.total_run_calls)

            # If the training process has been restarted, verify that the expected
            # number of checkpoints have been written.
            # We also want to check training_finished, because there's a corner case
            # where the signal is sent quite late and training finishes before the
            # grace period ends.
            if training_restarted.is_set() and not training_finished.is_set():
                match_group = [
                    re.search(r'.*ckpt-(\d+).index', a_file)
                    for a_file in gfile.ListDirectory(checkpoint_dir)
                ]
                checkpoint_index = [
                    a_match.group(1) for a_match in match_group if a_match
                ]
                if termination_config.grace_period > 0:
                    # Two checkpoints were saved for the extended grace period.
                    self.assertEqual(
                        max([
                            int(ckpt_index) for ckpt_index in checkpoint_index
                        ]), 2)
                else:
                    self.assertEqual(
                        max([
                            int(ckpt_index) for ckpt_index in checkpoint_index
                        ]), 1)

            for epoch in range(
                    preemption_handler.total_run_calls // STEPS_PER_EPOCH,
                    EPOCHS_TO_RUN):

                for step in range(
                        preemption_handler.total_run_calls % STEPS_PER_EPOCH,
                        STEPS_PER_EPOCH):
                    preemption_handler.run(distributed_train_step, epoch, step)

            logging.info('Training finished.')
            training_finished.set()

            self.assertEqual(
                model.v.numpy(), strategy.num_replicas_in_sync *
                EPOCHS_TO_RUN * STEPS_PER_EPOCH)

            running_threads = test_util.get_running_threads()
            if test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                    running_threads) and test_util.has_thread(
                                        _LOCAL_WATCHER_THREAD_PREFIX,
                                        running_threads):
                try:
                    # Explicitly call __del__, since setting the handler to
                    # None and calling gc.collect() does not invoke __del__
                    # here.
                    preemption_handler.__del__()

                    time.sleep(2)

                    running_threads = test_util.get_running_threads()
                    self.assertFalse(
                        test_util.has_thread(_LOCAL_WATCHER_THREAD_PREFIX,
                                             running_threads))
                    self.assertFalse(
                        test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                             running_threads))

                except urllib.error.URLError as e:
                    if 'Temporary failure in name resolution' in str(e):
                        # This is caused by a flaky issue where mock.patch does not
                        # correctly patch gce_util.request_compute_metadata, so a real
                        # request is attempted and an error is raised inside
                        # gce_util.request_compute_metadata.
                        logging.warning('Hit a mock issue.')
                        return
Example 7
    def worker_fn(self,
                  checkpoint_dir,
                  cluster_spec,
                  input_arg='checkpoint',
                  training_started_event=None,
                  raise_app_error_on_worker=None,
                  training_restarted=None,
                  training_finished=None,
                  termination_config=failure_handling.TerminationConfig()):

        _enable_coordination_service(cluster_spec)
        strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

        class Model(module.Module):
            def __init__(self):
                self.v = variables_lib.Variable(
                    0.,
                    synchronization=variables_lib.VariableSynchronization.
                    ON_WRITE,
                    aggregation=variables_lib.VariableAggregation.SUM)

            @def_function.function(input_signature=[])
            def __call__(self):
                return self.v.read_value()

        with mock.patch.object(gce_util, 'on_gcp', lambda: False):

            with strategy.scope():
                model = Model()
                # Named fh_ckpt because it is better for the user to keep their
                # regular checkpoint separate from the checkpoint used by
                # PreemptionCheckpointHandler, since we create a CheckpointManager
                # to manage that checkpoint and only one CheckpointManager should
                # be active in a particular directory at a time.
                fh_ckpt = tracking_util.Checkpoint(model=model)
                if input_arg == 'checkpoint':
                    checkpoint_or_manager = fh_ckpt
                else:
                    checkpoint_or_manager = _make_checkpoint_manager(
                        fh_ckpt, checkpoint_dir, strategy.cluster_resolver)
                preemption_handler = (
                    failure_handling.PreemptionCheckpointHandler(
                        strategy.cluster_resolver, checkpoint_or_manager,
                        checkpoint_dir, termination_config))

            def distributed_train_step(current_epoch, current_step):
                @def_function.function
                def train_step():
                    if distribution_strategy_context.get_distribution_strategy(
                    ).cluster_resolver.task_id == raise_app_error_on_worker:
                        raise errors_impl.ResourceExhaustedError(
                            node_def=None,
                            op=None,
                            message='Running out of resources')

                    model.v.assign_add(constant_op.constant(1.))

                strategy.run(train_step)

                if current_step == STEPS_PER_EPOCH - 1:
                    logging.info('epoch %d finished', current_epoch)

            logging.info('Start training at %d',
                         preemption_handler.total_run_calls)

            # If the training process has been restarted, verify that the expected
            # number of checkpoints have been written.
            # We also want to check training_finished, because there's a corner case
            # where the signal is sent quite late and training finishes before the
            # grace period ends.
            if (training_restarted and training_restarted.is_set() and
                    not training_finished.is_set()):
                logging.info('training restarted')
                match_group = [
                    re.search(r'.*ckpt-(\d+).index', a_file)
                    for a_file in gfile.ListDirectory(checkpoint_dir)
                ]
                checkpoint_index = [
                    a_match.group(1) for a_match in match_group if a_match
                ]
                if getattr(termination_config, 'grace_period', 0):
                    # Two checkpoints were saved for the extended grace period.
                    self.assertEqual(int(checkpoint_index[0]), 2)
                else:
                    self.assertEqual(int(checkpoint_index[0]), 1)

            for epoch in range(
                    preemption_handler.total_run_calls // STEPS_PER_EPOCH,
                    EPOCHS_TO_RUN):

                for step in range(
                        preemption_handler.total_run_calls % STEPS_PER_EPOCH,
                        STEPS_PER_EPOCH):
                    preemption_handler.run(distributed_train_step, epoch, step)
                # Add some randomness to when preemption actually happens. We should
                # trigger it for sure if the training is coming to an end and it hasn't
                # been triggered yet.
                if epoch >= EPOCHS_TO_RUN - 2:
                    trigger_it = True
                else:
                    trigger_it = False

                self._maybe_trigger_a_preemption(training_started_event,
                                                 trigger_it)

            training_finished.set()

            logging.info('Training finished.')

            self.assertEqual(
                model.v.numpy(), strategy.num_replicas_in_sync *
                EPOCHS_TO_RUN * STEPS_PER_EPOCH)
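
When input_arg is not 'checkpoint', worker_fn hands PreemptionCheckpointHandler a CheckpointManager built by _make_checkpoint_manager, which is not shown in this listing. A rough sketch of one way such a helper could look, assuming checkpoint_management and multi_worker_util are imported from tensorflow.python and that non-chief workers write to a per-task subdirectory (the directory scheme and max_to_keep value are assumptions):

def _make_checkpoint_manager(checkpoint, checkpoint_dir, cluster_resolver):
  """Builds a CheckpointManager; non-chief tasks write to their own directory."""
  if (not cluster_resolver or not cluster_resolver.cluster_spec().as_dict() or
      multi_worker_util.is_chief(
          cluster_spec=cluster_resolver.cluster_spec(),
          task_type=cluster_resolver.task_type,
          task_id=cluster_resolver.task_id)):
    return checkpoint_management.CheckpointManager(
        checkpoint, directory=checkpoint_dir, max_to_keep=1)
  else:
    # Giving each non-chief worker its own directory avoids two
    # CheckpointManagers being active in the same directory at once.
    return checkpoint_management.CheckpointManager(
        checkpoint,
        directory=os.path.join(checkpoint_dir,
                               'workertemp_%d' % cluster_resolver.task_id),
        max_to_keep=1)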
  def test_grace_period_continue_training(self, input_arg, mwms_mode):
    if _is_oss():
      rpc_layer = 'grpc'
    else:
      rpc_layer = 'grpc+loas'

    checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

    if mwms_mode == 'multi_worker':
      grace_period = 5
      termination_config = failure_handling.TerminationConfig(
          grace_period=grace_period)
      has_chief = False
      cluster_spec = multi_worker_test_base.create_cluster_spec(
          has_chief=has_chief,
          num_workers=CLUSTER_SIZE)
      training_started_event = multi_process_runner.manager().Event()
      training_restarted = multi_process_runner.manager().Event()
      training_finished = multi_process_runner.manager().Event()

      mpr = multi_process_runner.MultiProcessRunner(
          self.worker_fn,
          cluster_spec,
          args=(checkpoint_dir, cluster_spec, input_arg,
                [training_started_event], None, training_restarted,
                training_finished, termination_config),
          rpc_layer=rpc_layer,
          return_output=True,
          dependence_on_chief=has_chief)

      logging.info('Cluster starting.')
      mpr.start()
      while not training_started_event.is_set():
        time.sleep(1)

      killed_worker = random.randrange(0, CLUSTER_SIZE)
      logging.info('sending SIGTERM')
      os.kill(mpr.get_process_id('worker', killed_worker), signal.SIGTERM)
      logging.info('SIGTERM sent')

      raise_if_not_all_exit(grace_period, mpr)

      logging.info('restarting workers')
      training_restarted.set()
      for worker_id in range(CLUSTER_SIZE):
        mpr.start_single_process('worker', worker_id, cluster_spec)
      logging.info('workers restarted')

      mpr.join(timeout=250)

    else:
      # A single worker trains very fast relative to the size of the "model"
      # here. With a longer grace period, training would simply finish within
      # the grace period, so we could not verify the exit behavior.
      grace_period = 1
      termination_config = failure_handling.TerminationConfig(
          grace_period=grace_period)
      cluster_spec = server_lib.ClusterSpec({})

      training_started_event = threading.Event()
      training_restarted = threading.Event()
      training_finished = threading.Event()
      def sending_sigterm(training_started_event):
        while not training_started_event.is_set():
          time.sleep(1)
        logging.info('sending sigterm')
        training_started_event.set()
        os.kill(os.getpid(), signal.SIGTERM)

      preemption_sender_thread = threading.Thread(
          target=sending_sigterm, args=(training_started_event,))
      preemption_sender_thread.start()

      caught_exit = False
      try:
        self.worker_fn(checkpoint_dir, cluster_spec, input_arg,
                       [training_started_event], None, training_restarted,
                       training_finished, termination_config)

      except SystemExit as exit_error:
        caught_exit = True
        # We cannot use assertRaises instead, since termination is not always
        # triggered.
        self.assertEqual(exit_error.code, 42)  # pylint: disable=g-assert-in-except

      preemption_sender_thread.join(10)
      if not training_finished.is_set():
        self.assertTrue(caught_exit)

        logging.info('restarting workers')
        training_restarted.set()
        self.worker_fn(checkpoint_dir, cluster_spec, input_arg,
                       [training_started_event], None, training_restarted,
                       training_finished, termination_config)
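
In the single-worker branch above, delivering SIGTERM to the process is expected to surface as a SystemExit whose code the test checks (42 here; a later example checks 143). A stripped-down, self-contained illustration of that mechanism, with the handler body and exit code as placeholders for what the preemption handler actually does:

import os
import signal
import sys
import time

def _on_sigterm(signum, frame):
  # The real handler would save a checkpoint before exiting.
  sys.exit(42)  # sys.exit raises SystemExit, which the caller can catch.

signal.signal(signal.SIGTERM, _on_sigterm)

try:
  os.kill(os.getpid(), signal.SIGTERM)  # Deliver SIGTERM to this process.
  time.sleep(5)  # Stand-in for the training loop; the handler interrupts it.
except SystemExit as e:
  assert e.code == 42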
Example 9
    def worker_fn(self,
                  checkpoint_dir,
                  cluster_spec,
                  training_started_event=None,
                  raise_app_error_on_worker=None,
                  termination_config=failure_handling.TerminationConfig()):

        _enable_coordination_service(cluster_spec)
        strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

        class Model(module.Module):
            def __init__(self):
                self.v = variables_lib.Variable(
                    0.,
                    synchronization=variables_lib.VariableSynchronization.
                    ON_WRITE,
                    aggregation=variables_lib.VariableAggregation.SUM)

            @def_function.function(input_signature=[])
            def __call__(self):
                return self.v.read_value()

        with mock.patch.object(gce_util, 'on_gcp', lambda: False):

            with strategy.scope():
                model = Model()
                # Named fh_ckpt because it is better for the user to keep their
                # regular checkpoint separate from the checkpoint used by
                # WorkerPreemptionHandler, since we create a CheckpointManager
                # to manage that checkpoint and only one CheckpointManager should
                # be active in a particular directory at a time.
                fh_ckpt = tracking_util.Checkpoint(model=model)

                worker_preemption_watcher = failure_handling.WorkerPreemptionHandler(
                    strategy.cluster_resolver, fh_ckpt, checkpoint_dir,
                    termination_config)

            def distributed_train_step(current_epoch, current_step):
                @def_function.function
                def train_step():
                    if distribution_strategy_context.get_distribution_strategy(
                    ).cluster_resolver.task_id == raise_app_error_on_worker:
                        raise errors_impl.ResourceExhaustedError(
                            node_def=None,
                            op=None,
                            message='Running out of resources')

                    model.v.assign_add(constant_op.constant(1.))

                strategy.run(train_step)

                if current_step == STEPS_PER_EPOCH - 1:
                    logging.info('epoch %d finished', current_epoch)

            logging.info('Restored training at %d',
                         worker_preemption_watcher.total_runs)
            for epoch in range(
                    worker_preemption_watcher.total_runs // STEPS_PER_EPOCH,
                    EPOCHS_TO_RUN):

                for step in range(
                        worker_preemption_watcher.total_runs % STEPS_PER_EPOCH,
                        STEPS_PER_EPOCH):
                    worker_preemption_watcher.run(distributed_train_step,
                                                  epoch, step)
                # Add some randomness to when preemption actually happens. We should
                # trigger it for sure if the training is coming to an end and it hasn't
                # been triggered yet.
                if epoch >= EPOCHS_TO_RUN - 2:
                    trigger_it = True
                else:
                    trigger_it = False

                self._maybe_trigger_a_preemption(training_started_event,
                                                 trigger_it)

            logging.info('Training finished.')

            self.assertEqual(
                model.v.numpy(), strategy.num_replicas_in_sync *
                EPOCHS_TO_RUN * STEPS_PER_EPOCH)
Example 10
    def test_grace_period_continue_training(self):
        grace_period = 5
        has_chief = False
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            has_chief=has_chief, num_workers=CLUSTER_SIZE)
        training_started_event = multi_process_runner.manager().Event()

        checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

        if _is_oss():
            rpc_layer = 'grpc'
        else:
            rpc_layer = 'grpc+loas'

        termination_config = failure_handling.TerminationConfig(
            time_till_termination=grace_period)
        mpr = multi_process_runner.MultiProcessRunner(
            self.worker_fn,
            cluster_spec,
            args=(checkpoint_dir, cluster_spec, [training_started_event], None,
                  termination_config),
            rpc_layer=rpc_layer,
            return_output=True,
            dependence_on_chief=has_chief)

        logging.info('Cluster starting.')
        mpr.start()
        while not training_started_event.is_set():
            time.sleep(1)

        killed_worker = random.randrange(0, CLUSTER_SIZE)
        logging.info('sending SIGTERM')
        os.kill(mpr.get_process_id('worker', killed_worker), signal.SIGTERM)
        logging.info('SIGTERM sent')

        # Wait for all workers to exit within the grace period (plus a buffer,
        # since the per-step time here is very small).
        waiting_time = 0
        exit_process_count = 0
        while exit_process_count != CLUSTER_SIZE and waiting_time < grace_period + 10:
            exit_process_count = 0
            for worker_id in range(CLUSTER_SIZE):
                if not mpr.process_exists('worker', worker_id):
                    exit_process_count += 1
            waiting_time += 1
            time.sleep(1)

        if waiting_time == grace_period + 10:
            raise RuntimeError(
                'Waited longer than the grace period, but at least one '
                'worker still exists.')

        logging.info('restarting workers')
        for worker_id in range(CLUSTER_SIZE):
            mpr.start_single_process('worker', worker_id, cluster_spec)
        logging.info('workers restarted')

        stdout = mpr.join(timeout=250).stdout
        all_start_point = []
        checkpoint_count = []
        for msg in stdout:
            # TODO(wxinyi): remove the string matching and assert checkpoint number.
            matched_group = re.search(r'.*Restored training at (\d+)', msg)
            checkpoint_group = re.search(r'.*RUN_TO_CHECKPOINT set to (\d+)',
                                         msg)

            if matched_group:
                all_start_point.append(int(matched_group.group(1)))

            if checkpoint_group:
                checkpoint_count.append(int(checkpoint_group.group(1)))

        # Remove duplicate logs created due to the presence of multiple workers.
        start_points = all_start_point[::CLUSTER_SIZE]

        # assert that after restarting, we don't repeat previous training steps
        self.assertNotEqual(start_points[-1], 0)

        # One for timing, another for final call.
        self.assertLen(set(checkpoint_count), 2)
Example 11
    def worker_fn(self,
                  checkpoint_dir,
                  cluster_spec,
                  maintenance_event=None,
                  training_finished=None,
                  frequent_send=False,
                  termination_config=failure_handling.TerminationConfig()):

        _enable_coordination_service(cluster_spec)
        strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

        def mock_termination_watcher_function_gce(*args, **kwargs):
            del args, kwargs
            if not frequent_send:
                time.sleep(1)
                if (not maintenance_event.is_set()) and (random.randrange(
                        0, 20) > 18):
                    maintenance_event.set()
                    logging.info('Termination notice available.')
                    return True

            elif frequent_send and not maintenance_event.is_set():
                logging.info('Termination notice available.')
                return True

            return False

        with mock.patch.object(
                gce_util, 'termination_watcher_function_gce',
                mock_termination_watcher_function_gce), mock.patch.object(
                    gce_util, 'detect_platform',
                    lambda: gce_util.PlatformDevice.GCE_GPU):

            class Model(module.Module):
                def __init__(self):
                    self.v = variables_lib.Variable(
                        0.,
                        synchronization=variables_lib.VariableSynchronization.
                        ON_WRITE,
                        aggregation=variables_lib.VariableAggregation.SUM)

                @def_function.function(input_signature=[])
                def __call__(self):
                    return self.v.read_value()

            with strategy.scope():
                model = Model()
                fh_ckpt = tracking_util.Checkpoint(model=model)

                worker_preemption_watcher = failure_handling.WorkerPreemptionHandler(
                    strategy.cluster_resolver, fh_ckpt, checkpoint_dir,
                    termination_config)

            def distributed_train_step(current_epoch, current_step):
                @def_function.function
                def train_step():
                    model.v.assign_add(constant_op.constant(1.))

                strategy.run(train_step)

                if current_step == STEPS_PER_EPOCH - 1:
                    logging.info('epoch %d finished', current_epoch)

            logging.info('Start training at %d',
                         worker_preemption_watcher.total_runs)
            for epoch in range(
                    worker_preemption_watcher.total_runs // STEPS_PER_EPOCH,
                    EPOCHS_TO_RUN):

                for step in range(
                        worker_preemption_watcher.total_runs % STEPS_PER_EPOCH,
                        STEPS_PER_EPOCH):
                    worker_preemption_watcher.run(distributed_train_step,
                                                  epoch, step)

            logging.info('Training finished.')
            training_finished.set()

            self.assertEqual(
                model.v.numpy(), strategy.num_replicas_in_sync *
                EPOCHS_TO_RUN * STEPS_PER_EPOCH)

            running_threads = test_util.get_running_threads()
            if test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                    running_threads) and test_util.has_thread(
                                        _LOCAL_WATCHER_THREAD_PREFIX,
                                        running_threads):
                try:
                    # Explicitly call __del__, since setting the handler to
                    # None and calling gc.collect() does not invoke __del__
                    # here.
                    worker_preemption_watcher.__del__()

                    time.sleep(2)

                    running_threads = test_util.get_running_threads()
                    self.assertFalse(
                        test_util.has_thread(_LOCAL_WATCHER_THREAD_PREFIX,
                                             running_threads))
                    self.assertFalse(
                        test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                             running_threads))

                except urllib.error.URLError as e:
                    if 'Temporary failure in name resolution' in str(e):
                        # This is caused by a flaky issue where mock.patch does not
                        # correctly patch gce_util.request_compute_metadata, so a real
                        # request is attempted and an error is raised inside
                        # gce_util.request_compute_metadata.
                        logging.warning('Hit a mock issue.')
                        return
Example 12
    def test_grace_period_continue_training(self):
        grace_period = 7
        has_chief = False
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            has_chief=has_chief, num_workers=CLUSTER_SIZE)

        checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')
        maintenance_event = multi_process_runner.manager().Event()
        training_finished = multi_process_runner.manager().Event()

        if _is_oss():
            rpc_layer = 'grpc'
        else:
            rpc_layer = 'grpc+loas'

        termination_config = failure_handling.TerminationConfig(
            time_till_termination=grace_period)
        mpr = multi_process_runner.MultiProcessRunner(
            self.worker_fn,
            cluster_spec,
            args=(checkpoint_dir, cluster_spec, maintenance_event,
                  training_finished, False, termination_config),
            rpc_layer=rpc_layer,
            return_output=True,
            dependence_on_chief=has_chief)

        logging.info('Cluster starting.')
        mpr.start()

        while (not maintenance_event.is_set()) and (
                not training_finished.is_set()):
            time.sleep(1)

        # Sleep past the grace period (plus a buffer) to mitigate the fact
        # that the per-step time in this test is very short.
        time.sleep(grace_period + 10)
        if not training_finished.is_set():
            logging.info('restarting workers')
            for worker_id in range(CLUSTER_SIZE):
                mpr.start_single_process('worker', worker_id, cluster_spec)
            logging.info('workers restarted')

        if maintenance_event.is_set():
            stdout = mpr.join(timeout=250).stdout
            all_start_point = []
            checkpoint_count = []
            for msg in stdout:
                matched_group = re.search(r'.*Start training at (\d+)', msg)
                checkpoint_group = re.search(
                    r'.*RUN_TO_CHECKPOINT set to (\d+)', msg)

                if matched_group:
                    all_start_point.append(int(matched_group.group(1)))

                if checkpoint_group:
                    checkpoint_count.append(int(checkpoint_group.group(1)))

            # Remove duplicate logs created due to the presence of multiple workers.
            start_points = all_start_point[::CLUSTER_SIZE]

            # if maintenance_event is set at the very end of training and training
            # completes, there won't be a restart.
            if len(start_points) > 1:
                # assert that after restarting, we don't repeat previous training steps
                self.assertNotEqual(start_points[-1], 0)

                # One for timing, another for final call.
                self.assertLen(set(checkpoint_count), 2)
Example 13
    def test_multiple_workers_preempted_consecutively(self, grace_period):
        has_chief = False
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            has_chief=has_chief, num_workers=CLUSTER_SIZE)
        maintenance_event = multi_process_runner.manager().Event()
        training_finished = multi_process_runner.manager().Event()

        checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

        if _is_oss():
            rpc_layer = 'grpc'
        else:
            rpc_layer = 'grpc+loas'

        termination_config = failure_handling.TerminationConfig(
            time_till_termination=grace_period)
        mpr = multi_process_runner.MultiProcessRunner(
            self.worker_fn,
            cluster_spec,
            args=(checkpoint_dir, cluster_spec, maintenance_event,
                  training_finished, True, termination_config),
            rpc_layer=rpc_layer,
            return_output=True,
            dependence_on_chief=has_chief)

        logging.info('Cluster starting.')
        mpr.start()

        # Wait for all workers to exit, with a timeout.
        waiting_time = 0
        exit_process_count = 0
        # The generous timeout mitigates the fact that the per-step time in
        # this test is very short.
        while exit_process_count != CLUSTER_SIZE and waiting_time < max(
                grace_period + 15, 40):
            exit_process_count = 0
            for worker_id in range(CLUSTER_SIZE):
                if not mpr.process_exists('worker', worker_id):
                    exit_process_count += 1
            waiting_time += 1
            time.sleep(1)

        if waiting_time == max(grace_period + 15, 40):
            raise RuntimeError(
                'Waited for a long time, but at least one worker still '
                'exists. Considering the size of our model, this should '
                'not happen.')

        maintenance_event.set()
        logging.info('restarting workers')
        for worker_id in range(CLUSTER_SIZE):
            mpr.start_single_process('worker', worker_id, cluster_spec)
        logging.info('workers restarted')

        stdout = mpr.join(timeout=250).stdout
        found_message = 0
        checkpoint_count = []
        for msg in stdout:
            matched_group = re.search(r'.*has received termination notice.*',
                                      msg)
            checkpoint_group = re.search(r'.*RUN_TO_CHECKPOINT set to (\d+)',
                                         msg)
            if matched_group:
                found_message += 1
            if checkpoint_group:
                checkpoint_count.append(int(checkpoint_group.group(1)))

        self.assertGreaterEqual(found_message, 1)
        if grace_period > 0:
            self.assertLen(set(checkpoint_count), 2)
Example 14
    def test_grace_period_continue_training(self, input_arg, mwms_mode):
        checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt/')
        grace_period = 7
        if _is_oss():
            rpc_layer = 'grpc'
        else:
            rpc_layer = 'grpc+loas'

        termination_config = failure_handling.TerminationConfig(
            grace_period=grace_period)

        if mwms_mode == 'multi_worker':
            has_chief = False
            cluster_spec = multi_worker_test_base.create_cluster_spec(
                has_chief=has_chief, num_workers=CLUSTER_SIZE)

            checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt/')
            maintenance_event = multi_process_runner.manager().Event()
            training_finished = multi_process_runner.manager().Event()
            training_restarted = multi_process_runner.manager().Event()

            mpr = multi_process_runner.MultiProcessRunner(
                self.worker_fn,
                cluster_spec,
                args=(checkpoint_dir, cluster_spec, input_arg,
                      maintenance_event, training_finished, False,
                      training_restarted, termination_config),
                rpc_layer=rpc_layer,
                return_output=True,
                dependence_on_chief=has_chief)

            logging.info('Cluster starting.')
            mpr.start()

            while (not maintenance_event.is_set()) and (
                    not training_finished.is_set()):
                time.sleep(1)

            raise_if_not_all_exit(grace_period, mpr)

            if not training_finished.is_set():
                logging.info('restarting workers')
                training_restarted.set()
                for worker_id in range(CLUSTER_SIZE):
                    mpr.start_single_process('worker', worker_id, cluster_spec)
                logging.info('workers restarted')

            mpr.join(timeout=250)

            self.assertTrue(training_finished.is_set())

        else:
            maintenance_event = threading.Event()
            training_finished = threading.Event()
            training_restarted = threading.Event()

            cluster_spec = server_lib.ClusterSpec({})
            caught_exit = False
            try:
                self.worker_fn(checkpoint_dir, cluster_spec, input_arg,
                               maintenance_event, training_finished, False,
                               training_restarted, termination_config)
            except SystemExit as exit_error:
                caught_exit = True
                # We cannot use assertRaises instead, since termination is not always
                # triggered.
                self.assertEqual(exit_error.code, 143)  # pylint:disable=g-assert-in-except

            if maintenance_event.is_set() and not training_finished.is_set():
                self.assertTrue(caught_exit)

                logging.info('restarting workers')
                training_restarted.set()
                self.worker_fn(checkpoint_dir, cluster_spec, input_arg,
                               maintenance_event, training_finished, False,
                               training_restarted, termination_config)
                self.assertTrue(training_finished.is_set())