Example #1
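These worker functions come from TensorFlow's multi-worker preemption-handling tests and are shown without their surrounding test module. As a hedged sketch only, the module paths below reflect TensorFlow's internal source layout at the time these tests were written and may differ across releases; the constants at the end are placeholders for fixtures the snippets reference but do not define.

# Approximate imports for the snippets in this listing (assumed, not copied
# from the original test file).
import random
import re
import time
import urllib.error
from unittest import mock

from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute.failure_handling import failure_handling
from tensorflow.python.distribute.failure_handling import gce_util
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.module import module
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import util as tracking_util

# Placeholder values for module-level fixtures referenced by the snippets; the
# real test defines its own values as well as _enable_coordination_service,
# _maybe_trigger_a_preemption, and the *_WATCHER_THREAD_PREFIX constants.
EPOCHS_TO_RUN = 5
STEPS_PER_EPOCH = 6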
  def worker_fn(self,
                checkpoint_dir,
                cluster_spec,
                training_started_event=None,
                raise_app_error_on_worker=None):

    _enable_coordination_service(cluster_spec)
    strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

    class Model(module.Module):

      def __init__(self):
        self.v = variables_lib.Variable(
            0.,
            synchronization=variables_lib.VariableSynchronization.ON_WRITE,
            aggregation=variables_lib.VariableAggregation.SUM)

      @def_function.function(input_signature=[])
      def __call__(self):
        return self.v.read_value()

    with mock.patch.object(gce_util, 'on_gcp', lambda: False):

      with strategy.scope():
        model = Model()
        # Named fh_ckpt because the user should keep their regular checkpoint
        # separate from the one used by WorkerPreemptionHandler: the handler
        # creates its own CheckpointManager, and only one CheckpointManager
        # should be active in a given directory at a time.
        fh_ckpt = tracking_util.Checkpoint(model=model)

        worker_preemption_watcher = failure_handling.WorkerPreemptionHandler(
            strategy.cluster_resolver, fh_ckpt, checkpoint_dir)

      def distributed_train_step(current_epoch, current_step):

        @def_function.function
        def train_step():
          if distribution_strategy_context.get_distribution_strategy(
          ).cluster_resolver.task_id == raise_app_error_on_worker:
            raise errors_impl.ResourceExhaustedError(
                node_def=None, op=None, message='Running out of resources')

          model.v.assign_add(constant_op.constant(1.))

        strategy.run(train_step)

        if current_step == STEPS_PER_EPOCH - 1:
          logging.info('epoch %d finished', current_epoch)

      logging.info('Restored training at %d',
                   worker_preemption_watcher.total_runs)
      for epoch in range(
          worker_preemption_watcher.total_runs // STEPS_PER_EPOCH,
          EPOCHS_TO_RUN):

        for step in range(
            worker_preemption_watcher.total_runs % STEPS_PER_EPOCH,
            STEPS_PER_EPOCH):
          worker_preemption_watcher.run(distributed_train_step, epoch, step)
        # Randomize when the preemption actually happens, but trigger it for
        # sure if training is about to end and it has not been triggered yet.
        if epoch >= EPOCHS_TO_RUN - 2:
          trigger_it = True
        else:
          trigger_it = False

        self._maybe_trigger_a_preemption(training_started_event, trigger_it)

      self.assertEqual(
          model.v.numpy(),
          strategy.num_replicas_in_sync * EPOCHS_TO_RUN * STEPS_PER_EPOCH)
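The method above is meant to run inside each worker process of a simulated cluster. A hedged driver sketch follows; multi_process_runner, multi_worker_test_base, and the exact argument shapes are assumptions based on TensorFlow's internal test utilities, not part of this listing.

# Hypothetical launcher (illustrative only). Assumes TensorFlow's internal
# multi_process_runner and multi_worker_test_base helpers.
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base

def launch_workers(test_case, checkpoint_dir):
  # Two workers, no chief; each subprocess runs test_case.worker_fn.
  cluster_spec = multi_worker_test_base.create_cluster_spec(
      has_chief=False, num_workers=2)
  training_started_event = multi_process_runner.manager().Event()

  mpr = multi_process_runner.MultiProcessRunner(
      test_case.worker_fn,
      cluster_spec,
      args=(checkpoint_dir, cluster_spec, training_started_event),
      rpc_layer='grpc',
      return_output=True)
  mpr.start()
  return mpr, training_started_event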
Example #2
  def worker_fn(self,
                checkpoint_dir,
                cluster_spec,
                maintenance_event=None,
                training_finished=None,
                frequent_send=False,
                training_restarted=None,
                termination_config=failure_handling.TerminationConfig(
                    time_till_termination=0)):

    _enable_coordination_service(cluster_spec)
    strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

    def mock_termination_watcher_function_gce(*args, **kwargs):
      del args, kwargs
      if not frequent_send:
        time.sleep(1)
        if (not maintenance_event.is_set()) and (random.randrange(0, 7) == 5):
          maintenance_event.set()
          logging.info('Termination notice available.')
          return True

      elif frequent_send and not maintenance_event.is_set():
        logging.info('Termination notice available.')
        return True

      return False

    with mock.patch.object(
        gce_util, 'termination_watcher_function_gce',
        mock_termination_watcher_function_gce), mock.patch.object(
            gce_util, 'detect_platform',
            lambda: gce_util.PlatformDevice.GCE_GPU):

      class Model(module.Module):

        def __init__(self):
          self.v = variables_lib.Variable(
              0.,
              synchronization=variables_lib.VariableSynchronization.ON_WRITE,
              aggregation=variables_lib.VariableAggregation.SUM)

        @def_function.function(input_signature=[])
        def __call__(self):
          return self.v.read_value()

      with strategy.scope():
        model = Model()
        fh_ckpt = tracking_util.Checkpoint(model=model)

        worker_preemption_watcher = failure_handling.WorkerPreemptionHandler(
            strategy.cluster_resolver, fh_ckpt, checkpoint_dir,
            termination_config)

      def distributed_train_step(current_epoch, current_step):

        @def_function.function
        def train_step():
          model.v.assign_add(constant_op.constant(1.))

        strategy.run(train_step)

        if current_step == STEPS_PER_EPOCH - 1:
          logging.info('epoch %d finished', current_epoch)

      logging.info('Start training at %d', worker_preemption_watcher.total_runs)

      # If the training process has been restarted, verify that the expected
      # number of checkpoints has been written.
      # We also want to check training_finished, because there's a corner case
      # where the signal is sent quite late and training finishes before the
      # grace period ends.
      if training_restarted.is_set() and not training_finished.is_set():
        match_group = [
            re.search(r'.*ckpt-(\d+)\.index', a_file)
            for a_file in gfile.ListDirectory(checkpoint_dir)
        ]
        checkpoint_index = [
            a_match.group(1) for a_match in match_group if a_match
        ]
        if termination_config.time_till_termination > 0:
          # Two checkpoints were saved for the extended grace period.
          self.assertEqual(int(checkpoint_index[0]), 2)
        else:
          self.assertEqual(int(checkpoint_index[0]), 1)

      for epoch in range(
          worker_preemption_watcher.total_runs // STEPS_PER_EPOCH,
          EPOCHS_TO_RUN):

        for step in range(
            worker_preemption_watcher.total_runs % STEPS_PER_EPOCH,
            STEPS_PER_EPOCH):
          worker_preemption_watcher.run(distributed_train_step, epoch, step)

      logging.info('Training finished.')
      training_finished.set()

      self.assertEqual(
          model.v.numpy(),
          strategy.num_replicas_in_sync * EPOCHS_TO_RUN * STEPS_PER_EPOCH)

      running_threads = test_util.get_running_threads()
      if test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                              running_threads) and test_util.has_thread(
                                  _LOCAL_WATCHER_THREAD_PREFIX,
                                  running_threads):
        try:
          # Explicitly call __del__ since setting the handler to None and
          # calling gc.collect() does not invoke __del__ here.
          worker_preemption_watcher.__del__()

          time.sleep(2)

          running_threads = test_util.get_running_threads()
          self.assertFalse(
              test_util.has_thread(_LOCAL_WATCHER_THREAD_PREFIX,
                                   running_threads))
          self.assertFalse(
              test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                   running_threads))

        except urllib.error.URLError as e:
          if 'Temporary failure in name resolution' in str(e):
            # Caused by a known flakiness where mock.patch fails to patch
            # gce_util.request_compute_metadata, so a real metadata request is
            # attempted and fails inside gce_util.request_compute_metadata.
            logging.warning('Hit a mock issue.')
            return
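In this variant, TerminationConfig's time_till_termination acts as a grace period after the termination notice; with a non-zero value the handler saves an extra checkpoint, which is what the checkpoint_index assertion above distinguishes. A minimal sketch of passing an extended grace period (the concrete value is an illustrative assumption):

# Illustrative only: re-run the same worker_fn with a non-zero grace period so
# that two checkpoints are written before the worker exits.
grace_period = 7  # seconds; placeholder value
termination_config = failure_handling.TerminationConfig(
    time_till_termination=grace_period)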
Example #3
  def worker_fn(self,
                checkpoint_dir,
                cluster_spec,
                maintenance_event=None,
                training_finished=None,
                frequent_send=False):

    _enable_coordination_service(cluster_spec)
    strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

    def mock_termination_watcher_function_gce(*args, **kwargs):
      del args, kwargs
      if not frequent_send:
        time.sleep(1)
        if (not maintenance_event.is_set()) and (random.randrange(0, 20) > 18):
          maintenance_event.set()
          logging.info('Termination notice available.')
          return True

      elif frequent_send and not maintenance_event.is_set():
        logging.info('Termination notice available.')
        return True

      return False

    with mock.patch.object(
        gce_util, 'termination_watcher_function_gce',
        mock_termination_watcher_function_gce), mock.patch.object(
            gce_util, 'detect_platform',
            lambda: gce_util.PlatformDevice.GCE_GPU):

      class Model(module.Module):

        def __init__(self):
          self.v = variables_lib.Variable(
              0.,
              synchronization=variables_lib.VariableSynchronization.ON_WRITE,
              aggregation=variables_lib.VariableAggregation.SUM)

        @def_function.function(input_signature=[])
        def __call__(self):
          return self.v.read_value()

      with strategy.scope():
        model = Model()
        fh_ckpt = tracking_util.Checkpoint(model=model)

        worker_preemption_watcher = failure_handling.WorkerPreemptionHandler(
            strategy.cluster_resolver, fh_ckpt, checkpoint_dir)

      def distributed_train_step(current_epoch, current_step):

        @def_function.function
        def train_step():
          model.v.assign_add(constant_op.constant(1.))

        strategy.run(train_step)

        if current_step == STEPS_PER_EPOCH - 1:
          logging.info('epoch %d finished', current_epoch)

      logging.info('Start training at %d', worker_preemption_watcher.total_runs)
      for epoch in range(
          worker_preemption_watcher.total_runs // STEPS_PER_EPOCH,
          EPOCHS_TO_RUN):

        for step in range(
            worker_preemption_watcher.total_runs % STEPS_PER_EPOCH,
            STEPS_PER_EPOCH):
          worker_preemption_watcher.run(distributed_train_step, epoch, step)

      training_finished.set()

      self.assertEqual(
          model.v.numpy(),
          strategy.num_replicas_in_sync * EPOCHS_TO_RUN * STEPS_PER_EPOCH)

      running_threads = test_util.get_running_threads()
      if test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                              running_threads) and test_util.has_thread(
                                  _LOCAL_WATCHER_THREAD_PREFIX,
                                  running_threads):
        try:
          # Explicitly call __del__ since setting the handler to None and
          # calling gc.collect() does not invoke __del__ here.
          worker_preemption_watcher.__del__()

          time.sleep(2)

          running_threads = test_util.get_running_threads()
          self.assertFalse(
              test_util.has_thread(_LOCAL_WATCHER_THREAD_PREFIX,
                                   running_threads))
          self.assertFalse(
              test_util.has_thread(_PEER_WATCHER_THREAD_PREFIX,
                                   running_threads))

        except urllib.error.URLError as e:
          if 'Temporary failure in name resolution' in str(e):
            # Caused by a known flakiness where mock.patch fails to patch
            # gce_util.request_compute_metadata, so a real metadata request is
            # attempted and fails inside gce_util.request_compute_metadata.
            logging.warning('Hit a mock issue.')
            return
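maintenance_event, training_finished, and similar flags must be visible to both the test process and the worker subprocesses. A hedged sketch of creating them, assuming multi_process_runner.manager() from TensorFlow's internal test utilities:

# Assumed helper: multi_process_runner.manager() returns a multiprocessing
# manager whose Event objects can be shared with the worker subprocesses.
from tensorflow.python.distribute import multi_process_runner

manager = multi_process_runner.manager()
maintenance_event = manager.Event()
training_finished = manager.Event()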
Example #4
    def worker_fn(self,
                  checkpoint_dir,
                  cluster_spec,
                  training_started_event=None,
                  raise_app_error_on_worker=None,
                  training_restarted=None,
                  training_finished=None,
                  termination_config=failure_handling.TerminationConfig()):

        _enable_coordination_service(cluster_spec)
        strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

        class Model(module.Module):
            def __init__(self):
                self.v = variables_lib.Variable(
                    0.,
                    synchronization=(
                        variables_lib.VariableSynchronization.ON_WRITE),
                    aggregation=variables_lib.VariableAggregation.SUM)

            @def_function.function(input_signature=[])
            def __call__(self):
                return self.v.read_value()

        with mock.patch.object(gce_util, 'on_gcp', lambda: False):

            with strategy.scope():
                model = Model()
                # Named fh_ckpt because the user should keep their regular
                # checkpoint separate from the one used by
                # WorkerPreemptionHandler: the handler creates its own
                # CheckpointManager, and only one CheckpointManager should be
                # active in a given directory at a time.
                fh_ckpt = tracking_util.Checkpoint(model=model)

                worker_preemption_watcher = failure_handling.WorkerPreemptionHandler(
                    strategy.cluster_resolver, fh_ckpt, checkpoint_dir,
                    termination_config)

            def distributed_train_step(current_epoch, current_step):
                @def_function.function
                def train_step():
                    if distribution_strategy_context.get_distribution_strategy(
                    ).cluster_resolver.task_id == raise_app_error_on_worker:
                        raise errors_impl.ResourceExhaustedError(
                            node_def=None,
                            op=None,
                            message='Running out of resources')

                    model.v.assign_add(constant_op.constant(1.))

                strategy.run(train_step)

                if current_step == STEPS_PER_EPOCH - 1:
                    logging.info('epoch %d finished', current_epoch)

            logging.info('Start training at %d',
                         worker_preemption_watcher.total_runs)

            # If the training process has been restarted, verify that the
            # expected number of checkpoints has been written.
            # We also want to check training_finished, because there's a corner
            # case where the signal is sent quite late and training finishes
            # before the grace period ends.
            if (training_restarted and training_restarted.is_set()
                    and not training_finished.is_set()):
                logging.info('training restarted')
                match_group = [
                    re.search(r'.*ckpt-(\d+)\.index', a_file)
                    for a_file in gfile.ListDirectory(checkpoint_dir)
                ]
                checkpoint_index = [
                    a_match.group(1) for a_match in match_group if a_match
                ]
                if getattr(termination_config, 'time_till_termination', 0):
                    # Two checkpoints were saved for the extended grace period.
                    self.assertEqual(int(checkpoint_index[0]), 2)
                else:
                    self.assertEqual(int(checkpoint_index[0]), 1)

            for epoch in range(
                    worker_preemption_watcher.total_runs // STEPS_PER_EPOCH,
                    EPOCHS_TO_RUN):

                for step in range(
                        worker_preemption_watcher.total_runs % STEPS_PER_EPOCH,
                        STEPS_PER_EPOCH):
                    worker_preemption_watcher.run(distributed_train_step,
                                                  epoch, step)
                # Randomize when the preemption actually happens, but trigger
                # it for sure if training is about to end and it has not been
                # triggered yet.
                if epoch >= EPOCHS_TO_RUN - 2:
                    trigger_it = True
                else:
                    trigger_it = False

                self._maybe_trigger_a_preemption(training_started_event,
                                                 trigger_it)

            training_finished.set()

            logging.info('Training finished.')

            self.assertEqual(
                model.v.numpy(), strategy.num_replicas_in_sync *
                EPOCHS_TO_RUN * STEPS_PER_EPOCH)
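The epoch and step loops above derive the resume point from WorkerPreemptionHandler.total_runs, the number of distributed steps completed before the restart. A small worked example with placeholder constants:

# Placeholder values chosen for illustration, not taken from the test.
STEPS_PER_EPOCH = 5
total_runs = 12  # steps completed before the preemption
resume_epoch = total_runs // STEPS_PER_EPOCH  # -> 2
resume_step = total_runs % STEPS_PER_EPOCH    # -> 2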