def test_basic_run(self, input_arg):
        has_chief = False
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            has_chief=has_chief, num_workers=CLUSTER_SIZE)
        maintenance_event = multi_process_runner.manager().Event()
        training_finished = multi_process_runner.manager().Event()
        training_restarted = multi_process_runner.manager().Event()

        checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt/')

        if _is_oss():
            rpc_layer = 'grpc'
        else:
            rpc_layer = 'grpc+loas'

        mpr = multi_process_runner.MultiProcessRunner(
            self.worker_fn,
            cluster_spec,
            args=(checkpoint_dir, cluster_spec, input_arg, maintenance_event,
                  training_finished, False, training_restarted),
            rpc_layer=rpc_layer,
            return_output=True,
            dependence_on_chief=has_chief)

        logging.info('Cluster starting.')
        mpr.start()

        while (not maintenance_event.is_set()) and (
                not training_finished.is_set()):
            time.sleep(1)

        # Wait for the whole cluster to exit, with a timeout.
        waiting_time = 0
        exit_process_count = 0
        # This is added to mitigate the fact that our step time is too short
        # in the test.
        while exit_process_count != CLUSTER_SIZE and waiting_time < 15:
            exit_process_count = 0
            for worker_id in range(CLUSTER_SIZE):
                if not mpr.process_exists('worker', worker_id):
                    exit_process_count += 1
            waiting_time += 1
            time.sleep(1)

        if waiting_time >= 15:
            raise RuntimeError(
                'Waited a long time, but at least one worker still exists. '
                'Considering the size of our model, this should not'
                ' happen.')

        if not training_finished.is_set():
            logging.info('restarting workers')
            training_restarted.set()
            for worker_id in range(CLUSTER_SIZE):
                mpr.start_single_process('worker', worker_id, cluster_spec)
            logging.info('workers restarted')

        mpr.join(timeout=250)
        self.assertTrue(training_finished.is_set())
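
These snippets assume a few module-level names that are not reproduced on this page: CLUSTER_SIZE, COORDINATION_SERVICE, RPC_PROTOCOL, and the _is_oss() check used to pick between the 'grpc' and 'grpc+loas' RPC layers. Below is a minimal sketch of plausible definitions for readers who want to run an example standalone; the concrete values and the _is_oss heuristic are assumptions, not the exact upstream definitions.

import sys

# Assumed constants; the upstream tests may use different values.
CLUSTER_SIZE = 4                     # number of 'worker' tasks in the test cluster
COORDINATION_SERVICE = 'standalone'  # coordination service type used by some tests
RPC_PROTOCOL = 'grpc'                # RPC layer passed to MultiProcessRunner


def _is_oss():
    """Best-effort guess: True when running outside Google's internal build."""
    return 'bazel' not in sys.argv[0]
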
Example #2
    def test_grace_period_continue_training(self):
        grace_period = 5
        has_chief = False
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            has_chief=has_chief, num_workers=CLUSTER_SIZE)
        training_started_event = multi_process_runner.manager().Event()
        training_restarted = multi_process_runner.manager().Event()
        training_finished = multi_process_runner.manager().Event()
        checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

        if _is_oss():
            rpc_layer = 'grpc'
        else:
            rpc_layer = 'grpc+loas'

        termination_config = failure_handling.TerminationConfig(
            time_till_termination=grace_period)
        mpr = multi_process_runner.MultiProcessRunner(
            self.worker_fn,
            cluster_spec,
            args=(checkpoint_dir, cluster_spec, [training_started_event], None,
                  training_restarted, training_finished, termination_config),
            rpc_layer=rpc_layer,
            return_output=True,
            dependence_on_chief=has_chief)

        logging.info('Cluster starting.')
        mpr.start()
        while not training_started_event.is_set():
            time.sleep(1)

        killed_worker = random.randrange(0, CLUSTER_SIZE)
        logging.info('sending SIGTERM')
        os.kill(mpr.get_process_id('worker', killed_worker), signal.SIGTERM)
        logging.info('SIGTERM sent')

        # Wait for the whole cluster to exit within the given grace period
        # (plus a buffer, since our per-step time here is very small).
        waiting_time = 0
        exit_process_count = 0
        while exit_process_count != CLUSTER_SIZE and waiting_time < grace_period + 10:
            exit_process_count = 0
            for worker_id in range(CLUSTER_SIZE):
                if not mpr.process_exists('worker', worker_id):
                    exit_process_count += 1
            waiting_time += 1
            time.sleep(1)

        if waiting_time >= grace_period + 10:
            raise RuntimeError('Waited longer than the grace period.')

        logging.info('restarting workers')
        training_restarted.set()
        for worker_id in range(CLUSTER_SIZE):
            mpr.start_single_process('worker', worker_id, cluster_spec)
        logging.info('workers restarted')

        mpr.join(timeout=250)
  def test_basic_run(self):
    has_chief = False
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=has_chief,
        num_workers=CLUSTER_SIZE)
    maintenance_event = multi_process_runner.manager().Event()
    training_finished = multi_process_runner.manager().Event()

    checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

    if _is_oss():
      rpc_layer = 'grpc'
    else:
      rpc_layer = 'grpc+loas'

    mpr = multi_process_runner.MultiProcessRunner(
        self.worker_fn,
        cluster_spec,
        args=(checkpoint_dir, cluster_spec, maintenance_event,
              training_finished),
        rpc_layer=rpc_layer,
        return_output=True,
        dependence_on_chief=has_chief)

    logging.info('Cluster starting.')
    mpr.start()

    while (not maintenance_event.is_set()) and (not training_finished.is_set()):
      time.sleep(1)

    time.sleep(5)
    if not training_finished.is_set():
      logging.info('restarting workers')
      for worker_id in range(CLUSTER_SIZE):
        mpr.start_single_process('worker', worker_id, cluster_spec)
      logging.info('workers restarted')

    stdout = mpr.join().stdout
    if maintenance_event.is_set():
      all_start_point = []
      for msg in stdout:
        matched_group = re.search(r'.*Start training at (\d+)', msg)

        if matched_group:
          all_start_point.append(int(matched_group.group(1)))

      # Remove duplicate logs created due to the presence of multiple workers.
      start_points = all_start_point[::CLUSTER_SIZE]

      if len(start_points) > 1:
        # assert that after restarting, we don't repeat previous training steps
        self.assertNotEqual(start_points[-1], 0)
Example #4
    def test_creating_variable(self):
        # See PeerFailureTest.test_creating_variable

        def worker_fn(attempts):
            context.context().enable_coordination_service(COORDINATION_SERVICE)
            strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
            task_id, attempt = get_attempt(strategy, attempts)
            with strategy.scope():
                tf.Variable(1.)
                # worker-1 dies here.
                if attempt == 1 and task_id == 1:
                    quick_exit(1)
                v = tf.Variable(tf.random.uniform(()))
                return v.read_value().numpy()

        cluster_spec = multi_worker_test_base.create_cluster_spec(
            num_workers=2)
        attempts = multi_process_runner.manager().dict()
        mpr = multi_process_runner.MultiProcessRunner(worker_fn,
                                                      cluster_spec,
                                                      rpc_layer=RPC_PROTOCOL,
                                                      args=(attempts, ),
                                                      auto_restart=True)
        mpr.start()
        results = mpr.join(timeout=90).return_value
        self.assertEqual(results[0], results[1])
    def test_reduce_small_tensor_broken(self):
        # This test simulates the case when a worker fails before or during
        # reducing a small tensor, e.g. reading a metric.
        #
        # Note that this is a rather rare corner case and only happens when all
        # of the following conditions are met:
        #   - There are two workers.
        #   - They're reducing a small tensor. The definition of small varies
        #     per platform.
        #   - They're reducing a single tensor. Batched all-reduces are not
        #     affected.
        #   - It must be worker-1 that fails.

        def worker_fn(attempts):
            strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
            task_id, attempt = get_attempt(strategy, attempts)
            value = tf.identity([1.])
            strategy.reduce("sum", value, axis=None)
            # worker-1 dies here.
            if attempt == 1 and task_id == 1:
                quick_exit(1)
            strategy.reduce("sum", value, axis=None)

        cluster_spec = multi_worker_test_base.create_cluster_spec(
            num_workers=2)
        attempts = multi_process_runner.manager().dict()
        mpr = multi_process_runner.MultiProcessRunner(worker_fn,
                                                      cluster_spec,
                                                      args=(attempts, ),
                                                      auto_restart=True)
        mpr.start()
        # TODO(b/151232436): worker-0 should raise Unavailable instead of hanging.
        # Now after worker-1 fails, worker-0 waits on the second reduce; after
        # worker-1 recovers, worker-1 waits on the first reduce.
        with self.assertRaises(multi_process_runner.SubprocessTimeoutError):
            mpr.join(timeout=30)
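
The PeerFailureTest snippets call two helpers, get_attempt and quick_exit, that are not shown on this page. The sketch below is a plausible reconstruction consistent with how they are used (a shared attempts dict that counts restarts, and an abrupt process exit); the real upstream helpers may differ.

import os


def get_attempt(strategy, attempts):
    """Returns (task_id, attempt) for the current task, counting restarts.

    `attempts` is the multi_process_runner.manager().dict() shared across
    process restarts, so a restarted task can tell which attempt it is on.
    """
    task_type = strategy.cluster_resolver.task_type
    task_id = strategy.cluster_resolver.task_id
    attempts[(task_type, task_id)] = attempts.get((task_type, task_id), 0) + 1
    return task_id, attempts[(task_type, task_id)]


def quick_exit(code):
    """Exits the process immediately, skipping normal teardown, to simulate an
    abrupt worker failure."""
    os._exit(code)  # pylint: disable=protected-access
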
Example #6
    def test_reduce_small_tensor(self):
        # See PeerFailureTest.test_reduce_small_tensor

        def worker_fn(attempts):
            context.context().enable_coordination_service(COORDINATION_SERVICE)
            strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
            task_id, attempt = get_attempt(strategy, attempts)
            value = tf.identity([1.])
            strategy.reduce("sum", value, axis=None)
            # worker-1 dies here.
            if attempt == 1 and task_id == 1:
                quick_exit(1)
            return strategy.reduce("sum", value, axis=None).numpy()

        cluster_spec = multi_worker_test_base.create_cluster_spec(
            num_workers=2)
        attempts = multi_process_runner.manager().dict()
        mpr = multi_process_runner.MultiProcessRunner(worker_fn,
                                                      cluster_spec,
                                                      rpc_layer=RPC_PROTOCOL,
                                                      args=(attempts, ),
                                                      auto_restart=True)
        mpr.start()
        results = mpr.join(timeout=90).return_value
        self.assertAllEqual(results, [[2.], [2.]])
    def test_creating_variable_broken(self):
        # This test simulates the case when a worker fails before or during creating
        # a variable. Creating variables involves broadcasting the initial value from
        # the first replica to all replicas.

        def worker_fn(attempts):
            strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
            task_id, attempt = get_attempt(strategy, attempts)
            with strategy.scope():
                tf.Variable(1.)
                # worker-1 dies here.
                if attempt == 1 and task_id == 1:
                    quick_exit(1)
                v = tf.Variable(tf.random.uniform(()))
                return v.read_value().numpy()

        cluster_spec = multi_worker_test_base.create_cluster_spec(
            num_workers=2)
        attempts = multi_process_runner.manager().dict()
        mpr = multi_process_runner.MultiProcessRunner(worker_fn,
                                                      cluster_spec,
                                                      args=(attempts, ),
                                                      auto_restart=True)
        mpr.start()
        # TODO(b/151232436): worker-0 should raise Unavailable instead of hanging.
        # Now after worker-1 fails, worker-0 waits on the second variable creation;
        # after worker-1 recovers, worker-1 waits on the first variable creation.
        with self.assertRaises(multi_process_runner.SubprocessTimeoutError):
            mpr.join(timeout=30)
    def test_grace_period_continue_training(self, input_arg):
        grace_period = 7
        has_chief = False
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            has_chief=has_chief, num_workers=CLUSTER_SIZE)

        checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt/')
        maintenance_event = multi_process_runner.manager().Event()
        training_finished = multi_process_runner.manager().Event()
        training_restarted = multi_process_runner.manager().Event()

        if _is_oss():
            rpc_layer = 'grpc'
        else:
            rpc_layer = 'grpc+loas'

        termination_config = failure_handling.TerminationConfig(
            grace_period=grace_period)
        mpr = multi_process_runner.MultiProcessRunner(
            self.worker_fn,
            cluster_spec,
            args=(checkpoint_dir, cluster_spec, input_arg, maintenance_event,
                  training_finished, False, training_restarted,
                  termination_config),
            rpc_layer=rpc_layer,
            return_output=True,
            dependence_on_chief=has_chief)

        logging.info('Cluster starting.')
        mpr.start()

        while (not maintenance_event.is_set()) and (
                not training_finished.is_set()):
            time.sleep(1)

        # This is added to mitigate the fact that our step time is too short
        # in the test.
        time.sleep(grace_period + 10)
        if not training_finished.is_set():
            logging.info('restarting workers')
            training_restarted.set()
            for worker_id in range(CLUSTER_SIZE):
                mpr.start_single_process('worker', worker_id, cluster_spec)
            logging.info('workers restarted')

        mpr.join(timeout=250)

        self.assertTrue(training_finished.is_set())
Example #9
  def test_grace_period_continue_training(self, input_arg):
    grace_period = 5
    has_chief = False
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=has_chief,
        num_workers=CLUSTER_SIZE)
    training_started_event = multi_process_runner.manager().Event()
    training_restarted = multi_process_runner.manager().Event()
    training_finished = multi_process_runner.manager().Event()
    checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

    if _is_oss():
      rpc_layer = 'grpc'
    else:
      rpc_layer = 'grpc+loas'

    termination_config = failure_handling.TerminationConfig(
        grace_period=grace_period)
    mpr = multi_process_runner.MultiProcessRunner(
        self.worker_fn,
        cluster_spec,
        args=(checkpoint_dir, cluster_spec, input_arg, [training_started_event],
              None, training_restarted, training_finished, termination_config),
        rpc_layer=rpc_layer,
        return_output=True,
        dependence_on_chief=has_chief)

    logging.info('Cluster starting.')
    mpr.start()
    while not training_started_event.is_set():
      time.sleep(1)

    killed_worker = random.randrange(0, CLUSTER_SIZE)
    logging.info('sending SIGTERM')
    os.kill(mpr.get_process_id('worker', killed_worker), signal.SIGTERM)
    logging.info('SIGTERM sent')

    raise_if_not_all_exit(grace_period, mpr)

    logging.info('restarting workers')
    training_restarted.set()
    for worker_id in range(CLUSTER_SIZE):
      mpr.start_single_process('worker', worker_id, cluster_spec)
    logging.info('workers restarted')

    mpr.join(timeout=250)
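
Several examples call raise_if_not_all_exit(grace_period, mpr) instead of spelling out the polling loop shown in the earlier snippets. Below is a sketch of what that helper presumably does, reconstructed from those inline loops; the exact timeout arithmetic is an assumption.

import time


def raise_if_not_all_exit(grace_period, mpr):
    """Polls until every worker process has exited, or raises on timeout."""
    # Same buffering as the inline loops above: allow the grace period plus
    # some slack, since the per-step time in these tests is very short.
    timeout = max(grace_period + 15, 40)
    waiting_time = 0
    exit_process_count = 0
    while exit_process_count != CLUSTER_SIZE and waiting_time < timeout:
        exit_process_count = 0
        for worker_id in range(CLUSTER_SIZE):
            if not mpr.process_exists('worker', worker_id):
                exit_process_count += 1
        waiting_time += 1
        time.sleep(1)

    if exit_process_count != CLUSTER_SIZE:
        raise RuntimeError('Waited a long time, but at least one worker still '
                           'exists. Considering the size of our model, this '
                           'should not happen.')
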
Example #10
    def test_preemption_checkpointing(self, input_arg):
        has_chief = False
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            has_chief=has_chief, num_workers=CLUSTER_SIZE)
        training_started_event = multi_process_runner.manager().Event()
        training_restarted = multi_process_runner.manager().Event()
        training_finished = multi_process_runner.manager().Event()

        checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

        if _is_oss():
            rpc_layer = 'grpc'
        else:
            rpc_layer = 'grpc+loas'

        mpr = multi_process_runner.MultiProcessRunner(
            self.worker_fn,
            cluster_spec,
            args=(checkpoint_dir, cluster_spec, input_arg,
                  [training_started_event
                   ], None, training_restarted, training_finished),
            rpc_layer=rpc_layer,
            return_output=True,
            dependence_on_chief=has_chief)

        logging.info('Cluster starting.')
        mpr.start()
        while not training_started_event.is_set():
            time.sleep(1)

        logging.info('sending sigterm')
        killed_worker = random.randrange(0, CLUSTER_SIZE)
        os.kill(mpr.get_process_id('worker', killed_worker), signal.SIGTERM)

        logging.info('sigterm sent')
        time.sleep(5)

        logging.info('restarting workers')
        training_restarted.set()
        for worker_id in range(CLUSTER_SIZE):
            mpr.start_single_process('worker', worker_id, cluster_spec)
        logging.info('workers restarted')

        mpr.join(timeout=270)
    def __init__(self,
                 cluster_resolver,
                 stream_output=False,
                 collective_leader=None):
        self._cluster_resolver = cluster_resolver
        self._cluster_spec = cluster_resolver.cluster_spec().as_dict()
        self._rpc_layer = cluster_resolver.rpc_layer
        self._stream_output = stream_output
        self._start_events = {}
        self._finish_events = {}
        self._mpr_manager = multi_process_runner.manager()

        def task_function(start_events, finish_events):
            cluster_resolver = TFConfigClusterResolver()
            cluster_spec = cluster_resolver.cluster_spec()
            task_type = cluster_resolver.task_type
            task_id = cluster_resolver.task_id
            rpc_layer = cluster_resolver.rpc_layer

            # TODO(yuefengz): support GPU clusters.
            server_config = config_pb2.ConfigProto()
            server_config.device_count['GPU'] = 0

            if collective_leader:
                server_config.experimental.collective_group_leader = collective_leader
                server_config.experimental.collective_nccl = False

                logging.info(
                    'Enabling collective ops with cluster_spec = %r, task_type = %r, '
                    'task_id = %r, rpc_layer = %r, collective_leader = %s',
                    cluster_spec, task_type, task_id, rpc_layer,
                    collective_leader)
            else:
                logging.info(
                    'Starting server with cluster_spec = %r, task_type = %r, '
                    'task_id = %r, rpc_layer = %r', cluster_spec, task_type,
                    task_id, rpc_layer)

            server_lib.Server(cluster_spec,
                              job_name=task_type,
                              protocol=rpc_layer,
                              task_index=task_id,
                              config=server_config,
                              start=True)

            start_event = start_events[task_type][task_id]
            start_event.set()

            finish_event = finish_events[task_type][task_id]
            finish_event.wait()

            os._exit(0)  # pylint: disable=protected-access

        self._task_function = task_function
        self._mpr = None
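
The constructor above only stores the pieces; the methods that use them are not included in this snippet. Below is a hypothetical start()/stop() pair, sketched from the MultiProcessRunner usage in the surrounding examples; the method names, the per-task event wiring, and the constructor arguments passed through are assumptions.

    def start(self):
        """Launches one in-process server per task and waits until all are up."""
        for task_type, addresses in self._cluster_spec.items():
            self._start_events[task_type] = {}
            self._finish_events[task_type] = {}
            for task_id in range(len(addresses)):
                self._start_events[task_type][task_id] = self._mpr_manager.Event()
                self._finish_events[task_type][task_id] = self._mpr_manager.Event()

        self._mpr = multi_process_runner.MultiProcessRunner(
            self._task_function,
            self._cluster_spec,
            args=(self._start_events, self._finish_events),
            rpc_layer=self._rpc_layer,
            return_output=False)
        self._mpr.start()
        # Block until every task_function has started its server.
        for events in self._start_events.values():
            for event in events.values():
                event.wait()

    def stop(self):
        """Unblocks every task_function so its process exits, then joins them."""
        for events in self._finish_events.values():
            for event in events.values():
                event.set()
        self._mpr.join()
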
Example #12
    def test_basic_run(self, input_arg):
        has_chief = False
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            has_chief=has_chief, num_workers=CLUSTER_SIZE)
        maintenance_event = multi_process_runner.manager().Event()
        training_finished = multi_process_runner.manager().Event()
        training_restarted = multi_process_runner.manager().Event()

        checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt/')

        if _is_oss():
            rpc_layer = 'grpc'
        else:
            rpc_layer = 'grpc+loas'

        mpr = multi_process_runner.MultiProcessRunner(
            self.worker_fn,
            cluster_spec,
            args=(checkpoint_dir, cluster_spec, input_arg, maintenance_event,
                  training_finished, False, training_restarted),
            rpc_layer=rpc_layer,
            return_output=True,
            dependence_on_chief=has_chief)

        logging.info('Cluster starting.')
        mpr.start()

        while (not maintenance_event.is_set()) and (
                not training_finished.is_set()):
            time.sleep(1)

        raise_if_not_all_exit(0, mpr)

        if not training_finished.is_set():
            logging.info('restarting workers')
            training_restarted.set()
            for worker_id in range(CLUSTER_SIZE):
                mpr.start_single_process('worker', worker_id, cluster_spec)
            logging.info('workers restarted')

        mpr.join(timeout=250)
        self.assertTrue(training_finished.is_set())
Example #13
    def test_multiple_workers_preempted_consecutively(self, grace_period,
                                                      input_arg):
        has_chief = False
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            has_chief=has_chief, num_workers=CLUSTER_SIZE)
        maintenance_event = multi_process_runner.manager().Event()
        training_finished = multi_process_runner.manager().Event()
        training_restarted = multi_process_runner.manager().Event()

        checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt/')

        if _is_oss():
            rpc_layer = 'grpc'
        else:
            rpc_layer = 'grpc+loas'

        termination_config = failure_handling.TerminationConfig(
            grace_period=grace_period)
        mpr = multi_process_runner.MultiProcessRunner(
            self.worker_fn,
            cluster_spec,
            args=(checkpoint_dir, cluster_spec, input_arg, maintenance_event,
                  training_finished, True, training_restarted,
                  termination_config),
            rpc_layer=rpc_layer,
            return_output=True,
            dependence_on_chief=has_chief)

        logging.info('Cluster starting.')
        mpr.start()

        raise_if_not_all_exit(grace_period, mpr)

        maintenance_event.set()
        logging.info('restarting workers')
        training_restarted.set()
        for worker_id in range(CLUSTER_SIZE):
            mpr.start_single_process('worker', worker_id, cluster_spec)
        logging.info('workers restarted')

        mpr.join(timeout=250)
        self.assertTrue(training_finished.is_set())
Example #14
  def test_preemption_checkpointing(self):
    has_chief = False
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=has_chief,
        num_workers=CLUSTER_SIZE)
    training_started_event = multi_process_runner.manager().Event()

    checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

    if _is_oss():
      rpc_layer = 'grpc'
    else:
      rpc_layer = 'grpc+loas'

    mpr = multi_process_runner.MultiProcessRunner(
        self.worker_fn,
        cluster_spec,
        args=(checkpoint_dir, cluster_spec, [training_started_event]),
        rpc_layer=rpc_layer,
        return_output=True,
        dependence_on_chief=has_chief)

    logging.info('Cluster starting.')
    mpr.start()
    while not training_started_event.is_set():
      time.sleep(1)

    logging.info('sending sigterm')
    killed_worker = random.randrange(0, CLUSTER_SIZE)
    os.kill(mpr.get_process_id('worker', killed_worker), signal.SIGTERM)

    logging.info('sigterm sent')
    time.sleep(5)

    logging.info('restarting workers')
    for worker_id in range(CLUSTER_SIZE):
      mpr.start_single_process('worker', worker_id, cluster_spec)
    logging.info('workers restarted')

    stdout = mpr.join().stdout
    all_start_point = []
    for msg in stdout:
      matched_group = re.search(r'.*Restored training at (\d+)', msg)

      if matched_group:
        all_start_point.append(int(matched_group.group(1)))

    # Remove duplicate logs created due to the presence of multiple workers.
    start_points = all_start_point[::CLUSTER_SIZE]

    # assert that after restarting, we don't repeat previous training steps
    self.assertNotEqual(start_points[-1], 0)
Example #15
    def test_auto_restart(self):
        def proc_func(counter):
            counter.value += 1
            if counter.value == 1:
                raise ValueError

        manager = multi_process_runner.manager()
        counter = manager.Value(int, 0)
        mpr = multi_process_runner.MultiProcessRunner(
            proc_func,
            multi_worker_test_base.create_cluster_spec(num_workers=1),
            args=(counter, ),
            auto_restart=True)
        mpr.start()
        mpr.join()
        self.assertEqual(counter.value, 2)
Example #16
  def __init__(self, cluster_resolver):
    self._cluster_resolver = cluster_resolver
    self._cluster_spec = cluster_resolver.cluster_spec().as_dict()
    self._rpc_layer = cluster_resolver.rpc_layer
    self._start_events = {}
    self._finish_events = {}
    self._mpr_manager = multi_process_runner.manager()

    def task_function(start_events, finish_events):
      cluster_resolver = TFConfigClusterResolver()
      cluster_spec = cluster_resolver.cluster_spec()
      task_type = cluster_resolver.task_type
      task_id = cluster_resolver.task_id
      rpc_layer = cluster_resolver.rpc_layer

      logging.info(
          'Starting server with cluster_spec = %r, task_type = %r, '
          'task_id = %r, rpc_layer = %r', cluster_spec, task_type, task_id,
          rpc_layer)

      # TODO(yuefengz): support GPU clusters.
      server_config = config_pb2.ConfigProto()
      server_config.device_count['GPU'] = 0

      # Set the environment variable to prevent hanging upon job failure and
      # restart. Note that it defaults to 'use_caller' at Google, but defaults
      # to False in OSS.
      os.environ['GRPC_FAIL_FAST'] = 'use_caller'

      server_lib.Server(
          cluster_spec,
          job_name=task_type,
          protocol=rpc_layer,
          task_index=task_id,
          config=server_config,
          start=True)

      start_event = start_events[task_type][task_id]
      start_event.set()

      finish_event = finish_events[task_type][task_id]
      finish_event.wait()

      os._exit(0)  # pylint: disable=protected-access

    self._task_function = task_function
    self._mpr = None
Example #17
    def test_auto_restart_and_chief(self):
        # If the chief has exited with a zero exit code, auto restart should
        # stop restarting other tasks even if they fail.

        def proc_func():
            time.sleep(1)
            if multi_worker_test_base.get_task_type() != 'chief':
                raise ValueError

        manager = multi_process_runner.manager()
        mpr = multi_process_runner.MultiProcessRunner(
            proc_func,
            multi_worker_test_base.create_cluster_spec(has_chief=True,
                                                       num_workers=1),
            auto_restart=True)
        mpr.start()
        with self.assertRaises(ValueError):
            mpr.join(timeout=10)
    def test_numpy_fetched_after_worker_failure(self):
        def fn(first_fetch_occurred_event, worker_terminated_event):
            os.environ["GRPC_FAIL_FAST"] = "use_caller"

            cluster_resolver = TFConfigClusterResolver()
            if cluster_resolver.task_type != "chief":
                utils.start_server(cluster_resolver, "grpc")
            strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
                cluster_resolver)
            ps_coordinator = coordinator_lib.ClusterCoordinator(strategy)

            with strategy.scope():
                v = variables.Variable(initial_value=0, dtype=dtypes.int32)

            @def_function.function
            def worker_fn():
                return v + 1, v - 1

            remote_value = ps_coordinator.schedule(worker_fn)
            logging.info("result (1st fetch): %r", remote_value.fetch())
            first_fetch_occurred_event.set()
            worker_terminated_event.wait()
            logging.info("result (2nd fetch): %r", remote_value.fetch())

        manager = multi_process_runner.manager()
        first_fetch_occurred_event = manager.Event()
        worker_terminated_event = manager.Event()
        mpr = multi_process_runner.MultiProcessRunner(
            fn,
            multi_worker_test_base.create_cluster_spec(has_chief=True,
                                                       num_workers=1,
                                                       num_ps=1,
                                                       has_eval=False),
            args=(first_fetch_occurred_event, worker_terminated_event),
            rpc_layer="grpc",
            return_output=True,
            use_dill_for_args=False)

        mpr.start()
        first_fetch_occurred_event.wait()
        mpr.terminate("worker", 0)
        worker_terminated_event.set()
        self.assertTrue(
            any("result (2nd fetch)" in msg for msg in mpr.join().stdout))
Example #19
    def test_quick_recover(self):
        # This test simulates the case when a worker fails but recovers quickly
        # before the next collective.
        #
        # It's not guaranteed that the cluster only restarts once when one worker
        # fails. The external job management system is expected to keep restarting
        # failed workers.

        def worker_fn(attempts):
            # Set a long check alive interval to better simulate the case when a
            # worker fails and recovers during a check alive interval.
            mwms_lib.CollectiveAllReduceExtended._check_alive_interval = 30
            mwms_lib.CollectiveAllReduceExtended._check_alive_initial_timeout = 30

            context.context().configure_coordination_service(
                COORDINATION_SERVICE)
            strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
            task_id, attempt = get_attempt(strategy, attempts)

            @tf.function
            def replica_fn():
                ctx = tf.distribute.get_replica_context()
                # Use a large tensor because a small tensor may hang regardless
                # of when the worker recovers.
                value = tf.ones((64, 64))
                ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value, value])

            strategy.run(replica_fn)
            # worker-1 dies here.
            if attempt == 1 and task_id == 1:
                quick_exit(1)
            strategy.run(replica_fn)

        cluster_spec = multi_worker_test_base.create_cluster_spec(
            num_workers=2)
        attempts = multi_process_runner.manager().dict()
        mpr = multi_process_runner.MultiProcessRunner(worker_fn,
                                                      cluster_spec,
                                                      rpc_layer=RPC_PROTOCOL,
                                                      args=(attempts, ),
                                                      auto_restart=True)
        mpr.start()
        mpr.join(timeout=90)
    def test_quick_recover(self):
        # This test simulates the case when a worker fails but recovers quickly
        # before the next collective.
        #
        # It's not guaranteed that the cluster only restarts once when one worker
        # fails. The external job management system is expected to keep restarting
        # failed workers.

        def worker_fn(attempts):
            strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
            task_id, attempt = get_attempt(strategy, attempts)

            if attempt == 2 and task_id == 1:
                multi_process_runner.barrier().wait()

            @tf.function
            def replica_fn():
                ctx = tf.distribute.get_replica_context()
                # Use a large tensor because a small tensor may hang regardless
                # of when the worker recovers.
                value = tf.ones((64, 64))
                ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value, value])

            strategy.run(replica_fn)
            # worker-1 dies here.
            if attempt == 1 and task_id == 1:
                quick_exit(1)
            # Make worker-0 wait for worker-1 to restart before entering the
            # next collective to simulate a quick recovery of worker-1.
            if attempt == 1 and task_id == 0:
                multi_process_runner.barrier().wait()
            strategy.run(replica_fn)

        cluster_spec = multi_worker_test_base.create_cluster_spec(
            num_workers=2)
        attempts = multi_process_runner.manager().dict()
        mpr = multi_process_runner.MultiProcessRunner(worker_fn,
                                                      cluster_spec,
                                                      args=(attempts, ),
                                                      auto_restart=True)
        mpr.start()
        mpr.join(timeout=90)
Example #21
    def test_auto_restart_terminate(self):
        # Tasks terminated by the user should also be restarted.

        def proc_func(counter):
            counter.value += 1
            if counter.value == 1:
                time.sleep(100)

        manager = multi_process_runner.manager()
        counter = manager.Value(int, 0)

        mpr = multi_process_runner.MultiProcessRunner(
            proc_func,
            multi_worker_test_base.create_cluster_spec(has_chief=False,
                                                       num_workers=1),
            args=(counter, ),
            auto_restart=True)
        mpr.start()
        time.sleep(3)
        mpr.terminate('worker', 0)
        mpr.join(timeout=20)
        self.assertEqual(counter.value, 2)
Example #22
    def test_grace_period_continue_training(self, input_arg, mwms_mode):
        checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt/')
        grace_period = 7
        if _is_oss():
            rpc_layer = 'grpc'
        else:
            rpc_layer = 'grpc+loas'

        termination_config = failure_handling.TerminationConfig(
            grace_period=grace_period)

        if mwms_mode == 'multi_worker':
            has_chief = False
            cluster_spec = multi_worker_test_base.create_cluster_spec(
                has_chief=has_chief, num_workers=CLUSTER_SIZE)

            checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt/')
            maintenance_event = multi_process_runner.manager().Event()
            training_finished = multi_process_runner.manager().Event()
            training_restarted = multi_process_runner.manager().Event()

            mpr = multi_process_runner.MultiProcessRunner(
                self.worker_fn,
                cluster_spec,
                args=(checkpoint_dir, cluster_spec, input_arg,
                      maintenance_event, training_finished, False,
                      training_restarted, termination_config),
                rpc_layer=rpc_layer,
                return_output=True,
                dependence_on_chief=has_chief)

            logging.info('Cluster starting.')
            mpr.start()

            while (not maintenance_event.is_set()) and (
                    not training_finished.is_set()):
                time.sleep(1)

            raise_if_not_all_exit(grace_period, mpr)

            if not training_finished.is_set():
                logging.info('restarting workers')
                training_restarted.set()
                for worker_id in range(CLUSTER_SIZE):
                    mpr.start_single_process('worker', worker_id, cluster_spec)
                logging.info('workers restarted')

            mpr.join(timeout=250)

            self.assertTrue(training_finished.is_set())

        else:
            maintenance_event = threading.Event()
            training_finished = threading.Event()
            training_restarted = threading.Event()

            cluster_spec = server_lib.ClusterSpec({})
            caught_exit = False
            try:
                self.worker_fn(checkpoint_dir, cluster_spec, input_arg,
                               maintenance_event, training_finished, False,
                               training_restarted, termination_config)
            except SystemExit as exit_error:
                caught_exit = True
                # We cannot use assertRaises instead, since termination is not
                # always triggered.
                self.assertEqual(exit_error.code, 143)  # pylint:disable=g-assert-in-except

            if maintenance_event.is_set() and not training_finished.is_set():
                self.assertTrue(caught_exit)

                logging.info('restarting workers')
                training_restarted.set()
                self.worker_fn(checkpoint_dir, cluster_spec, input_arg,
                               maintenance_event, training_finished, False,
                               training_restarted, termination_config)
                self.assertTrue(training_finished.is_set())
Example #23
    def test_multiple_workers_preempted_consecutively(self, grace_period):
        has_chief = False
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            has_chief=has_chief, num_workers=CLUSTER_SIZE)
        maintenance_event = multi_process_runner.manager().Event()
        training_finished = multi_process_runner.manager().Event()

        checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

        if _is_oss():
            rpc_layer = 'grpc'
        else:
            rpc_layer = 'grpc+loas'

        termination_config = failure_handling.TerminationConfig(
            time_till_termination=grace_period)
        mpr = multi_process_runner.MultiProcessRunner(
            self.worker_fn,
            cluster_spec,
            args=(checkpoint_dir, cluster_spec, maintenance_event,
                  training_finished, True, termination_config),
            rpc_layer=rpc_layer,
            return_output=True,
            dependence_on_chief=has_chief)

        logging.info('Cluster starting.')
        mpr.start()

        # Wait for the whole cluster to exit, with a timeout.
        waiting_time = 0
        exit_process_count = 0
        # This is added to mitigate the fact that our step time is too short
        # in the test.
        while exit_process_count != CLUSTER_SIZE and waiting_time < max(
                grace_period + 15, 40):
            exit_process_count = 0
            for worker_id in range(CLUSTER_SIZE):
                if not mpr.process_exists('worker', worker_id):
                    exit_process_count += 1
            waiting_time += 1
            time.sleep(1)

        if waiting_time >= max(grace_period + 15, 40):
            raise RuntimeError(
                'Waited a long time, but at least one worker still exists. '
                'Considering the size of our model, this should not'
                ' happen.')

        maintenance_event.set()
        logging.info('restarting workers')
        for worker_id in range(CLUSTER_SIZE):
            mpr.start_single_process('worker', worker_id, cluster_spec)
        logging.info('workers restarted')

        stdout = mpr.join(timeout=250).stdout
        found_message = 0
        checkpoint_count = []
        for msg in stdout:
            matched_group = re.search(r'.*has received termination notice*',
                                      msg)
            checkpoint_group = re.search(r'.*RUN_TO_CHECKPOINT set to (\d+)',
                                         msg)
            if matched_group:
                found_message += 1
            if checkpoint_group:
                checkpoint_count.append(int(checkpoint_group.group(1)))

        self.assertGreaterEqual(found_message, 1)
        if grace_period > 0:
            self.assertLen(set(checkpoint_count), 2)
    def _testStrategyRun(self, failure_task_type):
        def fn(functions_scheduled_event):
            # TODO(b/170664373): This is needed for TF2 parameter server training in
            # OSS. Remove this when resolved.
            os.environ["GRPC_FAIL_FAST"] = "use_caller"

            cluster_resolver = TFConfigClusterResolver()
            if cluster_resolver.task_type != "chief":
                utils.start_server(cluster_resolver, "grpc")
            strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
                cluster_resolver)
            ps_client = coordinator_lib.ClusterCoordinator(strategy)

            with strategy.scope():
                v = variables.Variable(initial_value=1)

                @def_function.function
                def worker_fn(input_tensor):
                    def replica_fn(input_tensor):
                        return input_tensor + v

                    run_result = strategy.run(replica_fn,
                                              args=(input_tensor, ))
                    check_ops.assert_equal_v2(run_result, 4)
                    return run_result

            for i in range(5000):
                if i % 500 == 0:
                    logging.info("Scheduling function-{}...".format(i))
                result = ps_client.schedule(worker_fn,
                                            args=(constant_op.constant(3), ))
            functions_scheduled_event.set()
            logging.info("Joining...")
            ps_client.join()
            logging.info("Finished joining.")
            if result.fetch() != 4:
                raise AssertionError(
                    "Unexpected RemoteValue result: {}".format(result.fetch()))
            logging.info("testStrategyRun succeeded")

        manager = multi_process_runner.manager()
        functions_scheduled_event = manager.Event()
        mpr = multi_process_runner.MultiProcessRunner(
            fn,
            multi_worker_test_base.create_cluster_spec(has_chief=True,
                                                       num_workers=1,
                                                       num_ps=1,
                                                       has_eval=False),
            args=(functions_scheduled_event, ),
            rpc_layer="grpc",
            return_output=True)
        mpr.start()

        if failure_task_type is not None:
            functions_scheduled_event.wait()
            logging.info("Before interrupting {}-0.".format(failure_task_type))
            mpr.terminate(failure_task_type, 0)

            if failure_task_type == "ps":
                with self.assertRaises(errors.UnavailableError):
                    mpr.join()
                return

            time.sleep(10)
            logging.info("Before restarting {}-0.".format(failure_task_type))
            mpr.start_single_process(task_type="worker", task_id=0)

        self.assertTrue(
            any([
                "testStrategyRun succeeded" in msg for msg in mpr.join().stdout
            ]))
Example #25
    def test_grace_period_continue_training(self):
        grace_period = 5
        has_chief = False
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            has_chief=has_chief, num_workers=CLUSTER_SIZE)
        training_started_event = multi_process_runner.manager().Event()

        checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

        if _is_oss():
            rpc_layer = 'grpc'
        else:
            rpc_layer = 'grpc+loas'

        termination_config = failure_handling.TerminationConfig(
            time_till_termination=grace_period)
        mpr = multi_process_runner.MultiProcessRunner(
            self.worker_fn,
            cluster_spec,
            args=(checkpoint_dir, cluster_spec, [training_started_event], None,
                  termination_config),
            rpc_layer=rpc_layer,
            return_output=True,
            dependence_on_chief=has_chief)

        logging.info('Cluster starting.')
        mpr.start()
        while not training_started_event.is_set():
            time.sleep(1)

        killed_worker = random.randrange(0, CLUSTER_SIZE)
        logging.info('sending SIGTERM')
        os.kill(mpr.get_process_id('worker', killed_worker), signal.SIGTERM)
        logging.info('SIGTERM sent')

        # Wait for the whole cluster to exit within the given grace period
        # (plus a buffer, since our per-step time here is very small).
        waiting_time = 0
        exit_process_count = 0
        while exit_process_count != CLUSTER_SIZE and waiting_time < grace_period + 10:
            exit_process_count = 0
            for worker_id in range(CLUSTER_SIZE):
                if not mpr.process_exists('worker', worker_id):
                    exit_process_count += 1
            waiting_time += 1
            time.sleep(1)

        if waiting_time >= grace_period + 10:
            raise RuntimeError('Waited longer than the grace period.')

        logging.info('restarting workers')
        for worker_id in range(CLUSTER_SIZE):
            mpr.start_single_process('worker', worker_id, cluster_spec)
        logging.info('workers restarted')

        stdout = mpr.join(timeout=250).stdout
        all_start_point = []
        checkpoint_count = []
        for msg in stdout:
            # TODO(wxinyi): remove the string matching and assert checkpoint number.
            matched_group = re.search(r'.*Restored training at (\d+)', msg)
            checkpoint_group = re.search(r'.*RUN_TO_CHECKPOINT set to (\d+)',
                                         msg)

            if matched_group:
                all_start_point.append(int(matched_group.group(1)))

            if checkpoint_group:
                checkpoint_count.append(int(checkpoint_group.group(1)))

        # Remove duplicate logs created due to the presence of multiple workers.
        start_points = all_start_point[::CLUSTER_SIZE]

        # assert that after restarting, we don't repeat previous training steps
        self.assertNotEqual(start_points[-1], 0)

        # One for timing, another for final call.
        self.assertLen(set(checkpoint_count), 2)
Example #26
    def test_grace_period_continue_training(self):
        grace_period = 7
        has_chief = False
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            has_chief=has_chief, num_workers=CLUSTER_SIZE)

        checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')
        maintenance_event = multi_process_runner.manager().Event()
        training_finished = multi_process_runner.manager().Event()

        if _is_oss():
            rpc_layer = 'grpc'
        else:
            rpc_layer = 'grpc+loas'

        termination_config = failure_handling.TerminationConfig(
            time_till_termination=grace_period)
        mpr = multi_process_runner.MultiProcessRunner(
            self.worker_fn,
            cluster_spec,
            args=(checkpoint_dir, cluster_spec, maintenance_event,
                  training_finished, False, termination_config),
            rpc_layer=rpc_layer,
            return_output=True,
            dependence_on_chief=has_chief)

        logging.info('Cluster starting.')
        mpr.start()

        while (not maintenance_event.is_set()) and (
                not training_finished.is_set()):
            time.sleep(1)

        # This is added to mitigate the fact that our step time is too short
        # in the test.
        time.sleep(grace_period + 10)
        if not training_finished.is_set():
            logging.info('restarting workers')
            for worker_id in range(CLUSTER_SIZE):
                mpr.start_single_process('worker', worker_id, cluster_spec)
            logging.info('workers restarted')

        if maintenance_event.is_set():
            stdout = mpr.join(timeout=250).stdout
            all_start_point = []
            checkpoint_count = []
            for msg in stdout:
                matched_group = re.search(r'.*Start training at (\d+)', msg)
                checkpoint_group = re.search(
                    r'.*RUN_TO_CHECKPOINT set to (\d+)', msg)

                if matched_group:
                    all_start_point.append(int(matched_group.group(1)))

                if checkpoint_group:
                    checkpoint_count.append(int(checkpoint_group.group(1)))

            # Remove duplicate logs created due to the presence of multiple
            # workers.
            start_points = all_start_point[::CLUSTER_SIZE]

            # if maintenance_event is set at the very end of training and training
            # completes, there won't be a restart.
            if len(start_points) > 1:
                # assert that after restarting, we don't repeat previous training steps
                self.assertNotEqual(start_points[-1], 0)

                # One for timing, another for final call.
                self.assertLen(set(checkpoint_count), 2)
Example #27
  def test_two_workers_preempted_consecutively(self):
    has_chief = False
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=has_chief,
        num_workers=CLUSTER_SIZE)
    maintenance_event = multi_process_runner.manager().Event()
    training_finished = multi_process_runner.manager().Event()

    checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

    if _is_oss():
      rpc_layer = 'grpc'
    else:
      rpc_layer = 'grpc+loas'

    mpr = multi_process_runner.MultiProcessRunner(
        self.worker_fn,
        cluster_spec,
        args=(checkpoint_dir, cluster_spec,
              maintenance_event,
              training_finished, True),
        rpc_layer=rpc_layer,
        return_output=True,
        dependence_on_chief=has_chief)

    logging.info('Cluster starting.')
    mpr.start()

    time.sleep(5)

    # Wait for the whole cluster to exit, with a timeout.
    waiting_time = 0
    exit_process_count = 0
    while exit_process_count != CLUSTER_SIZE and waiting_time < 40:
      exit_process_count = 0
      for worker_id in range(CLUSTER_SIZE):
        if not mpr.process_exists('worker', worker_id):
          exit_process_count += 1
      waiting_time += 1
      time.sleep(1)

    if waiting_time >= 40:
      raise RuntimeError('Waited a long time, but at least one worker still '
                         'exists. Considering the size of our model, this '
                         'should not happen.')

    maintenance_event.set()
    logging.info('restarting workers')
    for worker_id in range(CLUSTER_SIZE):
      mpr.start_single_process('worker', worker_id, cluster_spec)
    logging.info('workers restarted')

    stdout = mpr.join().stdout
    found_message = 0
    for msg in stdout:
      matched_group = re.search(r'.*has received termination notice*', msg)

      if matched_group:
        found_message += 1

    self.assertGreaterEqual(found_message, 1)
  def test_grace_period_continue_training(self, input_arg, mwms_mode):
    if _is_oss():
      rpc_layer = 'grpc'
    else:
      rpc_layer = 'grpc+loas'

    checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

    if mwms_mode == 'multi_worker':
      grace_period = 5
      termination_config = failure_handling.TerminationConfig(
          grace_period=grace_period)
      has_chief = False
      cluster_spec = multi_worker_test_base.create_cluster_spec(
          has_chief=has_chief,
          num_workers=CLUSTER_SIZE)
      training_started_event = multi_process_runner.manager().Event()
      training_restarted = multi_process_runner.manager().Event()
      training_finished = multi_process_runner.manager().Event()

      mpr = multi_process_runner.MultiProcessRunner(
          self.worker_fn,
          cluster_spec,
          args=(checkpoint_dir, cluster_spec, input_arg,
                [training_started_event], None, training_restarted,
                training_finished, termination_config),
          rpc_layer=rpc_layer,
          return_output=True,
          dependence_on_chief=has_chief)

      logging.info('Cluster starting.')
      mpr.start()
      while not training_started_event.is_set():
        time.sleep(1)

      killed_worker = random.randrange(0, CLUSTER_SIZE)
      logging.info('sending SIGTERM')
      os.kill(mpr.get_process_id('worker', killed_worker), signal.SIGTERM)
      logging.info('SIGTERM sent')

      raise_if_not_all_exit(grace_period, mpr)

      logging.info('restarting workers')
      training_restarted.set()
      for worker_id in range(CLUSTER_SIZE):
        mpr.start_single_process('worker', worker_id, cluster_spec)
      logging.info('workers restarted')

      mpr.join(timeout=250)

    else:
      # This is because a single worker trains very fast given the size of the
      # "model" here. With a longer grace period, the training simply finishes
      # within the grace period, so we can't verify the exit behavior.
      grace_period = 1
      termination_config = failure_handling.TerminationConfig(
          grace_period=grace_period)
      cluster_spec = server_lib.ClusterSpec({})

      training_started_event = threading.Event()
      training_restarted = threading.Event()
      training_finished = threading.Event()
      def sending_sigterm(training_started_event):
        while not training_started_event.is_set():
          time.sleep(1)
        logging.info('sending sigterm')
        training_started_event.set()
        os.kill(os.getpid(), signal.SIGTERM)

      preemption_sender_thread = threading.Thread(
          target=sending_sigterm, args=(training_started_event,))
      preemption_sender_thread.start()

      caught_exit = False
      try:
        self.worker_fn(checkpoint_dir, cluster_spec, input_arg,
                       [training_started_event], None, training_restarted,
                       training_finished, termination_config)

      except SystemExit as exit_error:
        caught_exit = True
        # We cannot use assertRaises instead, since termination is not always
        # triggered.
        self.assertEqual(exit_error.code, 42)  # pylint: disable=g-assert-in-except

      preemption_sender_thread.join(10)
      if not training_finished.is_set():
        self.assertTrue(caught_exit)

        logging.info('restarting workers')
        training_restarted.set()
        self.worker_fn(checkpoint_dir, cluster_spec, input_arg,
                       [training_started_event], None, training_restarted,
                       training_finished, termination_config)
Example #29
  def _test_translate_ps_failure_error(self,
                                       test_schedule=False,
                                       test_join=False):

    def proc_func(functions_scheduled_event, test_finished_event):
      cluster_resolver = TFConfigClusterResolver()
      if cluster_resolver.task_type != "chief":
        utils.start_server(cluster_resolver, "grpc")
      strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
          cluster_resolver)
      ps_client = client_lib.Client(strategy)

      with strategy.scope():
        v = variables.Variable(initial_value=0, dtype=dtypes.int32)

      @def_function.function
      def worker_fn():
        # An ever-running function.
        for _ in math_ops.range(100000):
          v.assign_add(1)

      # Keep the two workers occupied.
      ps_client.schedule(worker_fn)
      ps_client.schedule(worker_fn)
      # Now the main process can terminate.
      functions_scheduled_event.set()

      # Verify that join and schedule indeed raise
      # ParameterServerFailureError.
      try:
        if test_join:
          ps_client.join()
        if test_schedule:
          while ps_client.cluster._closure_queue._error is None:
            time.sleep(1)
          ps_client.schedule(worker_fn)
      except client_lib.ParameterServerFailureError:
        # The following verifies that after the PS fails, continuing to execute
        # functions on workers should fail and indicate a PS failure.
        for worker_id in range(3):
          with ops.device("/job:worker/replica:0/task:{}".format(worker_id)):
            try:
              # Executing a function after PS fails should result in a PS
              # failure.
              worker_fn()
            except Exception as e:  # pylint: disable=broad-except
              if client_lib._is_ps_failure(e):
                if worker_id < 2:
                  continue
                logging.info("_test_translate_ps_failure_error ends properly.")
                # Now we can safely exit the test.
                test_finished_event.set()
                return
            raise RuntimeError("Executing a function after PS fails, should "
                               "result in a PS failure.")

      raise RuntimeError("ParameterServerFailureError supposed to be raised.")

    manager = multi_process_runner.manager()
    functions_scheduled_event = manager.Event()
    test_finished_event = manager.Event()
    mpr = multi_process_runner.MultiProcessRunner(
        proc_func,
        multi_worker_test_base.create_cluster_spec(
            has_chief=True, num_workers=3, num_ps=1, has_eval=False),
        args=(functions_scheduled_event, test_finished_event),
        rpc_layer="grpc",
        list_stdout=True,
        use_dill_for_args=False)

    mpr.start()
    functions_scheduled_event.wait()
    mpr.terminate("ps", 0)
    while mpr.process_exists("ps", 0):
      time.sleep(0.01)
    test_finished_event.wait()
    self.assertTrue(
        any("_test_translate_ps_failure_error ends properly" in msg
            for msg in mpr.join().stdout))
  def test_preemption_checkpointing(self, input_arg, mwms_mode):
    has_chief = False

    if _is_oss():
      rpc_layer = 'grpc'
    else:
      rpc_layer = 'grpc+loas'

    checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

    if mwms_mode == 'multi_worker':
      cluster_spec = multi_worker_test_base.create_cluster_spec(
          has_chief=has_chief,
          num_workers=CLUSTER_SIZE)
      training_started_event = multi_process_runner.manager().Event()
      training_restarted = multi_process_runner.manager().Event()
      training_finished = multi_process_runner.manager().Event()

      mpr = multi_process_runner.MultiProcessRunner(
          self.worker_fn,
          cluster_spec,
          args=(checkpoint_dir, cluster_spec, input_arg,
                [training_started_event
                ], None, training_restarted, training_finished),
          rpc_layer=rpc_layer,
          return_output=True,
          dependence_on_chief=has_chief)

      logging.info('Cluster starting.')
      mpr.start()
      while not training_started_event.is_set():
        time.sleep(1)

      logging.info('sending sigterm')
      killed_worker = random.randrange(0, CLUSTER_SIZE)
      os.kill(mpr.get_process_id('worker', killed_worker), signal.SIGTERM)

      logging.info('sigterm sent')
      raise_if_not_all_exit(0, mpr)

      logging.info('restarting workers')
      training_restarted.set()
      for worker_id in range(CLUSTER_SIZE):
        mpr.start_single_process('worker', worker_id, cluster_spec)
      logging.info('workers restarted')

      mpr.join(timeout=270)

    else:
      cluster_spec = server_lib.ClusterSpec({})

      training_started_event = threading.Event()
      training_restarted = threading.Event()
      training_finished = threading.Event()

      def sending_sigterm(training_started_event):
        while not training_started_event.is_set():
          time.sleep(1)
        logging.info('sending sigterm')
        training_started_event.set()
        os.kill(os.getpid(), signal.SIGTERM)

      preemption_sender_thread = threading.Thread(
          target=sending_sigterm, args=(training_started_event,))
      preemption_sender_thread.start()

      caught_exit = False
      try:
        self.worker_fn(checkpoint_dir, cluster_spec, input_arg,
                       [training_started_event], None, training_restarted,
                       training_finished)

      except SystemExit as exit_error:
        caught_exit = True
        # We cannot use assertRaises instead, since termination is not always
        # triggered.
        self.assertEqual(exit_error.code, 42)  # pylint: disable=g-assert-in-except

      preemption_sender_thread.join(10)
      if not training_finished.is_set():
        self.assertTrue(caught_exit)

        logging.info('restarting workers')
        training_restarted.set()
        self.worker_fn(checkpoint_dir, cluster_spec, input_arg,
                       [training_started_event], None, training_restarted,
                       training_finished)