    def worker_fn(
        self,
        checkpoint_dir,
        cluster_spec,
        input_arg,
        maintenance_event=None,
        training_finished=None,
        frequent_send=False,
        training_restarted=None,
        termination_config=failure_handling.TerminationConfig(grace_period=0)):

        _enable_coordination_service(cluster_spec)
        strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

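        # Stand-in for the GCE maintenance-event poller: depending on
        # frequent_send, it either occasionally (at random) or repeatedly
        # reports that a termination notice is available.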
        def mock_termination_watcher_function_gce(*args, **kwargs):
            del args, kwargs
            if not frequent_send:
                time.sleep(1)
                if (not maintenance_event.is_set()) and (random.randrange(
                        0, 7) == 5):
                    maintenance_event.set()
                    logging.info('Termination notice available.')
                    return True

            elif frequent_send and not maintenance_event.is_set():
                logging.info('Termination notice available.')
                return True

            return False

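        # Patch in the fake watcher and force platform detection to GCE_GPU so
        # the preemption handler exercises its GCE code path.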
        with mock.patch.object(
                gce_util, 'termination_watcher_function_gce',
                mock_termination_watcher_function_gce), mock.patch.object(
                    gce_util, 'detect_platform',
                    lambda: gce_util.PlatformDevice.GCE_GPU):

            class Model(module.Module):
                def __init__(self):
                    self.v = variables_lib.Variable(
                        0.,
                        synchronization=variables_lib.VariableSynchronization.
                        ON_WRITE,
                        aggregation=variables_lib.VariableAggregation.SUM)

                @def_function.function(input_signature=[])
                def __call__(self):
                    return self.v.read_value()

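            # Build the model, the checkpoint that tracks it, and the
            # PreemptionCheckpointHandler under the strategy scope; a restarted
            # worker picks up the latest checkpoint in checkpoint_dir through
            # the handler.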
            with strategy.scope():
                model = Model()
                fh_ckpt = tracking_util.Checkpoint(model=model)

                if input_arg == 'checkpoint':
                    checkpoint_or_manager = fh_ckpt
                else:
                    checkpoint_or_manager = _make_checkpoint_manager(
                        fh_ckpt, checkpoint_dir, strategy.cluster_resolver)
                preemption_handler = (
                    failure_handling.PreemptionCheckpointHandler(
                        strategy.cluster_resolver, checkpoint_or_manager,
                        checkpoint_dir, termination_config))

            def distributed_train_step(current_epoch, current_step):
                @def_function.function
                def train_step():
                    model.v.assign_add(constant_op.constant(1.))

                strategy.run(train_step)

                if current_step == STEPS_PER_EPOCH - 1:
                    logging.info('epoch %d finished', current_epoch)

            logging.info('Start training at %d',
                         preemption_handler.total_run_calls)

            # If the training process has been restarted, verify that the expected
            # number of checkpoints has been written.
            # We also want to check training_finished, because there's a corner case
            # where the signal is sent quite late and training finishes before the
            # grace period ends.
            if training_restarted.is_set() and not training_finished.is_set():
                match_group = [
                    re.search(r'.*ckpt-(\d+).index', a_file)
                    for a_file in gfile.ListDirectory(checkpoint_dir)
                ]
                checkpoint_index = [
                    a_match.group(1) for a_match in match_group if a_match
                ]
                if termination_config.grace_period > 0:
                    # Two checkpoints were saved for the extended grace period.
                    self.assertEqual(
                        max([
                            int(ckpt_index) for ckpt_index in checkpoint_index
                        ]), 2)
                else:
                    self.assertEqual(
                        max([
                            int(ckpt_index) for ckpt_index in checkpoint_index
                        ]), 1)

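            # Resume from where the handler left off: total_run_calls counts
            # preemption_handler.run() invocations across restarts, so the
            # starting epoch and step skip work already done before the
            # preemption.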
            for epoch in range(
                    preemption_handler.total_run_calls // STEPS_PER_EPOCH,
                    EPOCHS_TO_RUN):

                for step in range(
                        preemption_handler.total_run_calls % STEPS_PER_EPOCH,
                        STEPS_PER_EPOCH):
                    preemption_handler.run(distributed_train_step, epoch, step)

            logging.info('Training finished.')
            training_finished.set()

            self.assertEqual(
                model.v.numpy(), strategy.num_replicas_in_sync *
                EPOCHS_TO_RUN * STEPS_PER_EPOCH)

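            # Verify that deleting the handler also stops its local and peer
            # termination-watcher threads.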
            running_threads = test_util.get_running_threads()
            if test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                    running_threads) and test_util.has_thread(
                                        _LOCAL_WATCHER_THREAD_PREFIX,
                                        running_threads):
                try:
                    # Explicitly call __del__, since setting the handler to None and
                    # running gc.collect() does not invoke __del__ here.
                    preemption_handler.__del__()

                    time.sleep(2)

                    running_threads = test_util.get_running_threads()
                    self.assertFalse(
                        test_util.has_thread(_LOCAL_WATCHER_THREAD_PREFIX,
                                             running_threads))
                    self.assertFalse(
                        test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                             running_threads))

                except urllib.error.URLError as e:
                    if 'Temporary failure in name resolution' in str(e):
                        # This is caused by a flaky issue where mock.patch does not
                        # correctly patch gce_util.request_compute_metadata, so a
                        # real request is attempted and fails to resolve the
                        # metadata server's name.
                        logging.warning('Hit a mock issue.')
                        return
Example #2
    def worker_fn(self,
                  checkpoint_dir,
                  cluster_spec,
                  input_arg='checkpoint',
                  training_started_event=None,
                  raise_app_error_on_worker=None,
                  training_restarted=None,
                  training_finished=None,
                  termination_config=failure_handling.TerminationConfig()):

        _enable_coordination_service(cluster_spec)
        strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

        class Model(module.Module):
            def __init__(self):
                self.v = variables_lib.Variable(
                    0.,
                    synchronization=variables_lib.VariableSynchronization.
                    ON_WRITE,
                    aggregation=variables_lib.VariableAggregation.SUM)

            @def_function.function(input_signature=[])
            def __call__(self):
                return self.v.read_value()

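        # Patch on_gcp to return False so the handler does not poll the GCE
        # metadata server; preemptions are instead simulated by the test via
        # _maybe_trigger_a_preemption below.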
        with mock.patch.object(gce_util, 'on_gcp', lambda: False):

            with strategy.scope():
                model = Model()
                # Named fh_ckpt because the user should keep their regular
                # checkpoint separate from the checkpoint used by
                # PreemptionCheckpointHandler: we will create a CheckpointManager
                # to manage this checkpoint, and only one CheckpointManager should
                # be active in a particular directory at a time.
                fh_ckpt = tracking_util.Checkpoint(model=model)
                if input_arg == 'checkpoint':
                    checkpoint_or_manager = fh_ckpt
                else:
                    checkpoint_or_manager = _make_checkpoint_manager(
                        fh_ckpt, checkpoint_dir, strategy.cluster_resolver)
                preemption_handler = (
                    failure_handling.PreemptionCheckpointHandler(
                        strategy.cluster_resolver, checkpoint_or_manager,
                        checkpoint_dir, termination_config))

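            # Optionally inject a ResourceExhaustedError on one worker
            # (raise_app_error_on_worker) so the test can check how application
            # errors, as opposed to preemptions, propagate.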
            def distributed_train_step(current_epoch, current_step):
                @def_function.function
                def train_step():
                    if (distribution_strategy_context.get_distribution_strategy()
                            .cluster_resolver.task_id == raise_app_error_on_worker):
                        raise errors_impl.ResourceExhaustedError(
                            node_def=None,
                            op=None,
                            message='Running out of resources')

                    model.v.assign_add(constant_op.constant(1.))

                strategy.run(train_step)

                if current_step == STEPS_PER_EPOCH - 1:
                    logging.info('epoch %d finished', current_epoch)

            logging.info('Start training at %d',
                         preemption_handler.total_run_calls)

            # If the training process has been restarted, verify that the expected
            # number of checkpoints has been written.
            # We also want to check training_finished, because there's a corner case
            # where the signal is sent quite late and training finishes before the
            # grace period ends.
            if (training_restarted and training_restarted.is_set()
                    and not training_finished.is_set()):
                logging.info('training restarted')
                match_group = [
                    re.search(r'.*ckpt-(\d+).index', a_file)
                    for a_file in gfile.ListDirectory(checkpoint_dir)
                ]
                checkpoint_index = [
                    a_match.group(1) for a_match in match_group if a_match
                ]
                if getattr(termination_config, 'grace_period', 0):
                    # Two checkpoints were saved for the extended grace period.
                    self.assertEqual(int(checkpoint_index[0]), 2)
                else:
                    self.assertEqual(int(checkpoint_index[0]), 1)

            for epoch in range(
                    preemption_handler.total_run_calls // STEPS_PER_EPOCH,
                    EPOCHS_TO_RUN):

                for step in range(
                        preemption_handler.total_run_calls % STEPS_PER_EPOCH,
                        STEPS_PER_EPOCH):
                    preemption_handler.run(distributed_train_step, epoch, step)
                # Add some randomness to when preemption actually happens. We should
                # trigger it for sure if the training is coming to an end and it hasn't
                # been triggered yet.
                trigger_it = epoch >= EPOCHS_TO_RUN - 2

                self._maybe_trigger_a_preemption(training_started_event,
                                                 trigger_it)

            training_finished.set()

            logging.info('Training finished.')

            self.assertEqual(
                model.v.numpy(), strategy.num_replicas_in_sync *
                EPOCHS_TO_RUN * STEPS_PER_EPOCH)
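
A minimal sketch of how a worker_fn like the ones above is typically driven from the main test process. It relies on TensorFlow's multi-process test utilities (multi_process_runner, multi_worker_test_base); the function launch_workers, the test_case argument, and the way the events are wired up are illustrative assumptions, not part of the examples above.

def launch_workers(test_case, checkpoint_dir):
    # Hypothetical driver: `test_case` is assumed to be the test instance that
    # defines worker_fn as in Example #2.
    from tensorflow.python.distribute import multi_process_runner
    from tensorflow.python.distribute import multi_worker_test_base

    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)

    # Events must come from the runner's multiprocessing manager so the worker
    # processes and the main process share the same objects.
    manager = multi_process_runner.manager()
    training_started_event = manager.Event()
    training_restarted = manager.Event()
    training_finished = manager.Event()

    # Run worker_fn in one process per task in cluster_spec and wait for them
    # to finish.
    return multi_process_runner.run(
        test_case.worker_fn,
        cluster_spec,
        args=(checkpoint_dir, cluster_spec),
        kwargs={
            'training_started_event': training_started_event,
            'training_restarted': training_restarted,
            'training_finished': training_finished,
        })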