Example #1
    def test_reduce_small_tensor(self):
        # This test simulates the case when a worker fails before or during reducing
        # a small tensor, e.g. when reading a metric.
        #
        # Note that this is written for a specific corner case that used to happen
        # only when all of the following conditions were met:
        #   - There are two workers.
        #   - They're reducing a small tensor. The definition of small varies
        #     per platform.
        #   - They're reducing a single tensor. Batched all-reduces are not affected.
        #   - It must be worker-1 that fails.
        # In this case, the all-reduce is effectively two send/recv operations: the
        # first from worker-0 to worker-1, and the second in the opposite direction.
        # The first blocks the second, and in send/recv the sending party is not
        # aware of failures on the receiving party.

        def worker_fn():
            strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
            value = tf.identity([1.])
            strategy.reduce("sum", value, axis=None)
            # worker-1 dies here.
            if strategy.cluster_resolver.task_id == 1:
                quick_exit(1)
            strategy.reduce("sum", value, axis=None)

        cluster_spec = multi_worker_test_base.create_cluster_spec(
            num_workers=2)
        mpr = multi_process_runner.MultiProcessRunner(worker_fn,
                                                      cluster_spec,
                                                      rpc_layer=RPC_PROTOCOL)
        mpr.start()
        # TODO(b/151232436): Always raise UnavailableError when a peer fails.
        with self.assertRaises(
            (tf.errors.UnavailableError, tf.errors.DeadlineExceededError)):
            mpr.join(timeout=60)
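The quick_exit helper used throughout these examples is not part of the listing. A minimal sketch of what it presumably does is shown below, assuming it simply terminates the worker process on the spot so that the failure looks like an abrupt crash to its peers (the real definition may differ):

import os


def quick_exit(code):
  # Hypothetical helper (assumption, not shown in this listing): terminate the
  # process immediately, skipping Python cleanup and atexit handlers, so the
  # peer failure looks like a hard crash to the other workers.
  os._exit(code)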
Example #2
    def test_creating_variable(self):
        # See PeerFailureTest.test_creating_variable

        def worker_fn(attempts):
            strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
            task_id, attempt = get_attempt(strategy, attempts)
            with strategy.scope():
                tf.Variable(1.)
                # worker-1 dies here.
                if attempt == 1 and task_id == 1:
                    quick_exit(1)
                v = tf.Variable(tf.random.uniform(()))
                return v.read_value().numpy()

        cluster_spec = multi_worker_test_base.create_cluster_spec(
            num_workers=2)
        attempts = multi_process_runner.manager().dict()
        mpr = multi_process_runner.MultiProcessRunner(worker_fn,
                                                      cluster_spec,
                                                      rpc_layer=RPC_PROTOCOL,
                                                      args=(attempts, ),
                                                      auto_restart=True)
        mpr.start()
        results = mpr.join(timeout=90).return_value
        self.assertEqual(results[0], results[1])
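get_attempt is another helper that is not shown. Since the examples pass in a multi_process_runner.manager().dict() and use auto_restart=True, it presumably counts how many times each task has started, so a worker can tell a restart from a fresh start. A sketch under that assumption:

def get_attempt(strategy, attempts):
  # Hypothetical sketch: record one more attempt for this task in the shared
  # manager dict and report which attempt this is.
  task_type = strategy.cluster_resolver.task_type
  task_id = strategy.cluster_resolver.task_id
  attempts[(task_type, task_id)] = attempts.get((task_type, task_id), 0) + 1
  return task_id, attempts[(task_type, task_id)]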
Example #3
    def test_creating_variable(self):
        # This test simulates the case when a worker fails before or during creating
        # a variable. Creating variables involves broadcasting the initial value from
        # the first replica to all replicas.

        def worker_fn():
            strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
            with strategy.scope():
                tf.Variable(1.)
                # worker-1 dies here.
                if strategy.cluster_resolver.task_id == 1:
                    quick_exit(1)
                v = tf.Variable(tf.random.uniform(()))
                return v.read_value().numpy()

        cluster_spec = multi_worker_test_base.create_cluster_spec(
            num_workers=2)
        mpr = multi_process_runner.MultiProcessRunner(worker_fn,
                                                      cluster_spec,
                                                      rpc_layer=RPC_PROTOCOL)
        mpr.start()
        # TODO(b/151232436): Always raise UnavailableError when a peer fails.
        with self.assertRaises(
            (tf.errors.UnavailableError, tf.errors.DeadlineExceededError)):
            mpr.join(timeout=30)
Example #4
  def test_termination_and_start_single_process(self):

    def proc_func():
      for i in range(0, 10):
        print(
            'index {}, iteration {}'.format(self._worker_idx(), i), flush=True)
        time.sleep(1)

    mpr = multi_process_runner.MultiProcessRunner(
        proc_func,
        multi_worker_test_base.create_cluster_spec(num_workers=2),
        list_stdout=True)
    mpr.start()
    time.sleep(3)
    mpr.terminate('worker', 0)
    mpr.start_single_process('worker', 0)
    std_stream_results = mpr.join().stdout

    # Worker 0 is terminated in the middle, but a new worker 0 is added, so it
    # should still have iteration 9 printed. Moreover, iteration 0 of worker 0
    # should happen twice.
    self.assertLen(
        [s for s in std_stream_results if 'index 0, iteration 0' in s], 2)
    self.assertIn('[worker-0]:    index 0, iteration 9\n', std_stream_results)
    self.assertIn('[worker-1]:    index 1, iteration 0\n', std_stream_results)
    self.assertIn('[worker-1]:    index 1, iteration 9\n', std_stream_results)
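The self._worker_idx() method used here and in later examples is not included either. Each subprocess started by MultiProcessRunner runs with a TF_CONFIG environment variable describing its task, so the method presumably reads the task index from there. A sketch of such a test-class method, under that assumption:

import json
import os


def _worker_idx(self):
  # Hypothetical sketch: return the task index recorded in this subprocess's
  # TF_CONFIG environment variable.
  return json.loads(os.environ['TF_CONFIG'])['task']['index']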
Example #5
  def test_reduce_small_tensor(self):
    # See PeerFailureTest.test_reduce_small_tensor

    def worker_fn(attempts):
      context.context().enable_coordination_service(COORDINATION_SERVICE)
      strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
      task_id, attempt = get_attempt(strategy, attempts)
      value = tf.identity([1.])
      strategy.reduce("sum", value, axis=None)
      # worker-1 dies here.
      if attempt == 1 and task_id == 1:
        quick_exit(1)
      return strategy.reduce("sum", value, axis=None).numpy()

    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    attempts = multi_process_runner.manager().dict()
    mpr = multi_process_runner.MultiProcessRunner(
        worker_fn,
        cluster_spec,
        rpc_layer=RPC_PROTOCOL,
        args=(attempts,),
        auto_restart=True)
    mpr.start()
    results = mpr.join(timeout=90).return_value
    self.assertAllEqual(results, [[2.], [2.]])
Example #6
  def test_start_in_process_as(self):

    def proc_func():
      for i in range(5):
        logging.info('%s-%d, i: %d', multi_worker_test_base.get_task_type(),
                     self._worker_idx(), i)
        time.sleep(1)

    mpr = multi_process_runner.MultiProcessRunner(
        proc_func,
        multi_worker_test_base.create_cluster_spec(
            has_chief=True, num_workers=1),
        list_stdout=True)

    def eval_func():
      time.sleep(1)
      mpr.start_single_process(task_type='evaluator', task_id=0)

    eval_thread = threading.Thread(target=eval_func)
    eval_thread.start()
    mpr.start_in_process_as(as_task_type='chief', as_task_id=0)
    eval_thread.join()
    list_to_assert = mpr.join().stdout
    for job in ['worker', 'evaluator']:
      for iteration in range(5):
        self.assertTrue(
            any('{}-0, i: {}'.format(job, iteration) in line
                for line in list_to_assert))
Example #7
    def test_creating_variable_broken(self):
        # This test simulates the case when a worker fails before or during creating
        # a variable. Creating variables involves broadcasting the initial value from
        # the first replica to all replicas.

        def worker_fn(attempts):
            strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
            task_id, attempt = get_attempt(strategy, attempts)
            with strategy.scope():
                tf.Variable(1.)
                # worker-1 dies here.
                if attempt == 1 and task_id == 1:
                    quick_exit(1)
                v = tf.Variable(tf.random.uniform(()))
                return v.read_value().numpy()

        cluster_spec = multi_worker_test_base.create_cluster_spec(
            num_workers=2)
        attempts = multi_process_runner.manager().dict()
        mpr = multi_process_runner.MultiProcessRunner(worker_fn,
                                                      cluster_spec,
                                                      args=(attempts, ),
                                                      auto_restart=True)
        mpr.start()
        # TODO(b/151232436): worker-0 should raise Unavailable instead of hanging.
        # Now after worker-1 fails, worker-0 waits on the second variable creation;
        # after worker-1 recovers, worker-1 waits on the first variable creation.
        with self.assertRaises(multi_process_runner.SubprocessTimeoutError):
            mpr.join(timeout=30)
Example #8
  def test_termination(self):

    def proc_func():
      for i in range(0, 10):
        print(
            'index {}, iteration {}'.format(self._worker_idx(), i), flush=True)
        time.sleep(5)

    mpr = multi_process_runner.MultiProcessRunner(
        proc_func,
        multi_worker_test_base.create_cluster_spec(num_workers=2),
        list_stdout=True)
    mpr.start()
    time.sleep(5)
    mpr.terminate('worker', 0)
    with self.assertRaises(
        multi_process_runner.UnexpectedSubprocessExitError) as cm:
      mpr.join()

    std_stream_results = cm.exception.mpr_result.stdout

    # Worker 0 is terminated in the middle, so it should not have iteration 9
    # printed.
    self.assertIn('[worker-0]:    index 0, iteration 0\n', std_stream_results)
    self.assertNotIn('[worker-0]:    index 0, iteration 9\n',
                     std_stream_results)
    self.assertIn('[worker-1]:    index 1, iteration 0\n', std_stream_results)
    self.assertIn('[worker-1]:    index 1, iteration 9\n', std_stream_results)
Example #9
    def test_error_propagation(self):
        error_worker = random.randint(0, CLUSTER_SIZE)
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            has_chief=False, num_workers=CLUSTER_SIZE)
        checkpoint_dir = self.get_temp_dir()

        def assert_raise_error():
            # Asserts that an error raised during a training step on one of the
            # workers is caught on all workers.
            with self.assertRaises(
                    errors_impl.ResourceExhaustedError) as error:
                self.worker_fn(checkpoint_dir,
                               cluster_spec,
                               raise_app_error_on_worker=error_worker)
            self.assertIn('Running out of resources', str(error.exception))

        if _is_oss():
            rpc_layer = 'grpc'
        else:
            rpc_layer = 'grpc+loas'

        mpr = multi_process_runner.MultiProcessRunner(
            assert_raise_error,
            cluster_spec,
            rpc_layer=rpc_layer,
            return_output=True,
            dependence_on_chief=False)

        logging.info('Cluster starting.')
        mpr.start()
        mpr.join(timeout=250)
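The _is_oss() helper that selects the RPC layer in this and later examples is also not shown. It only needs to distinguish the open-source build (which supports plain grpc) from the internal one. One plausible sketch, assuming the check is based on how the test binary was invoked:

import sys


def _is_oss():
  # Hypothetical sketch (assumption): treat anything not launched through an
  # internal bazel-style runner as the open-source build.
  return 'bazel' not in sys.argv[0]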
Example #10
    def start(self):
        """Starts one TensorFlow server for each task in the cluster_resolver.

        It will wait until all the servers are up before returning.
        """
        if self._mpr:
            raise ValueError('The cluster has already been started.')
        for task_type, task_addresses in self._cluster_spec.items():
            self._start_events[task_type] = []
            self._finish_events[task_type] = []
            for _ in task_addresses:
                self._start_events[task_type].append(self._mpr_manager.Event())
                self._finish_events[task_type].append(
                    self._mpr_manager.Event())

        self._mpr = multi_process_runner.MultiProcessRunner(
            self._task_function,
            self._cluster_spec,
            args=(self._start_events, self._finish_events),
            rpc_layer=self._rpc_layer,
            stream_stdout=False,
            list_stdout=False,
            use_dill_for_args=False)
        self._mpr.start()
        for task_type, task_addresses in self._cluster_spec.items():
            for i in range(len(task_addresses)):
                self._start_events[task_type][i].wait()
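The start() method above belongs to a small cluster-helper class whose other pieces are not shown. A task function compatible with the events it creates could look roughly like the sketch below, assuming each task should bring up a tf.distribute.Server, signal that it is up, and then block until asked to shut down (the names and details are illustrative):

import tensorflow as tf


def _task_function(start_events, finish_events):
  # Hypothetical sketch of the per-task body; the real self._task_function is
  # not part of this listing.
  resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
  tf.distribute.Server(
      resolver.cluster_spec(),
      job_name=resolver.task_type,
      task_index=resolver.task_id,
      protocol=resolver.rpc_layer or 'grpc',
      start=True)
  # Tell start() that this server is up, then wait for the shutdown signal.
  start_events[resolver.task_type][resolver.task_id].set()
  finish_events[resolver.task_type][resolver.task_id].wait()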
Example #11
    def test_reduce_small_tensor_broken(self):
        # This test simulates the case when a worker fails before or during reducing
        # a small tensor, e.g. when reading a metric.
        #
        # Note that this is a rather narrow corner case and only happens when all of
        # the following conditions are met:
        #   - There are two workers.
        #   - They're reducing a small tensor. The definition of small varies
        #     per platform.
        #   - They're reducing a single tensor. Batched all-reduces are not affected.
        #   - It must be worker-1 that fails.

        def worker_fn(attempts):
            strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
            task_id, attempt = get_attempt(strategy, attempts)
            value = tf.identity([1.])
            strategy.reduce("sum", value, axis=None)
            # worker-1 dies here.
            if attempt == 1 and task_id == 1:
                quick_exit(1)
            strategy.reduce("sum", value, axis=None)

        cluster_spec = multi_worker_test_base.create_cluster_spec(
            num_workers=2)
        attempts = multi_process_runner.manager().dict()
        mpr = multi_process_runner.MultiProcessRunner(worker_fn,
                                                      cluster_spec,
                                                      args=(attempts, ),
                                                      auto_restart=True)
        mpr.start()
        # TODO(b/151232436): worker-0 should raise Unavailable instead of hanging.
        # Now after worker-1 fails, worker-0 waits on the second reduce; after
        # worker-1 recovers, worker-1 waits on the first reduce.
        with self.assertRaises(multi_process_runner.SubprocessTimeoutError):
            mpr.join(timeout=30)
Example #12
    def test_termination(self):
        def fn():
            for i in range(0, 10):
                print('index {}, iteration {}'.format(self._worker_idx(), i),
                      flush=True)
                time.sleep(5)

        mpr = multi_process_runner.MultiProcessRunner(
            fn,
            multi_worker_test_base.create_cluster_spec(num_workers=2),
            return_output=True)
        mpr.start()
        time.sleep(5)
        mpr.terminate('worker', 0)

        std_stream_results = mpr.join().stdout

        # Worker 0 is terminated in the middle, so it should not have iteration 9
        # printed.
        self.assertIn('[worker-0]:    index 0, iteration 0\n',
                      std_stream_results)
        self.assertNotIn('[worker-0]:    index 0, iteration 9\n',
                         std_stream_results)
        self.assertIn('[worker-1]:    index 1, iteration 0\n',
                      std_stream_results)
        self.assertIn('[worker-1]:    index 1, iteration 9\n',
                      std_stream_results)
Example #13
    def test_grace_period_continue_training(self, input_arg):
        grace_period = 5
        has_chief = False
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            has_chief=has_chief, num_workers=CLUSTER_SIZE)
        training_started_event = multi_process_runner.manager().Event()
        training_restarted = multi_process_runner.manager().Event()
        training_finished = multi_process_runner.manager().Event()
        checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

        if _is_oss():
            rpc_layer = 'grpc'
        else:
            rpc_layer = 'grpc+loas'

        termination_config = failure_handling.TerminationConfig(
            grace_period=grace_period)
        mpr = multi_process_runner.MultiProcessRunner(
            self.worker_fn,
            cluster_spec,
            args=(checkpoint_dir, cluster_spec, input_arg,
                  [training_started_event], None, training_restarted,
                  training_finished, termination_config),
            rpc_layer=rpc_layer,
            return_output=True,
            dependence_on_chief=has_chief)

        logging.info('Cluster starting.')
        mpr.start()
        while not training_started_event.is_set():
            time.sleep(1)

        killed_worker = random.randrange(0, CLUSTER_SIZE)
        logging.info('sending SIGTERM')
        os.kill(mpr.get_process_id('worker', killed_worker), signal.SIGTERM)
        logging.info('SIGTERM sent')

        # Wait for the whole cluster to exit within the given grace period (plus a
        # buffer, since the per-step time here is very small).
        waiting_time = 0
        exit_process_count = 0
        while exit_process_count != CLUSTER_SIZE and waiting_time < grace_period + 10:
            exit_process_count = 0
            for worker_id in range(CLUSTER_SIZE):
                if not mpr.process_exists('worker', worker_id):
                    exit_process_count += 1
            waiting_time += 1
            time.sleep(1)

        if waiting_time == grace_period + 10:
            raise RuntimeError('Waited longer than the grace period.')

        logging.info('restarting workers')
        training_restarted.set()
        for worker_id in range(CLUSTER_SIZE):
            mpr.start_single_process('worker', worker_id, cluster_spec)
        logging.info('workers restarted')

        mpr.join(timeout=250)
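The polling loop above, and the similar one in the next example (Example #14), amount to a "wait until every worker has exited, with a deadline" step. Factored into a helper that uses only the process_exists API already shown, it might look like this (illustrative, not part of the original tests):

import time


def wait_for_all_workers_to_exit(mpr, num_workers, max_wait_secs):
  # Illustrative helper: poll once per second until every worker process has
  # exited or the deadline passes; return whether they all exited in time.
  for _ in range(max_wait_secs):
    if not any(
        mpr.process_exists('worker', worker_id)
        for worker_id in range(num_workers)):
      return True
    time.sleep(1)
  return False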
Example #14
    def test_multiple_workers_preempted_consecutively(self, grace_period,
                                                      input_arg):
        has_chief = False
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            has_chief=has_chief, num_workers=CLUSTER_SIZE)
        maintenance_event = multi_process_runner.manager().Event()
        training_finished = multi_process_runner.manager().Event()
        training_restarted = multi_process_runner.manager().Event()

        checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt/')

        if _is_oss():
            rpc_layer = 'grpc'
        else:
            rpc_layer = 'grpc+loas'

        termination_config = failure_handling.TerminationConfig(
            grace_period=grace_period)
        mpr = multi_process_runner.MultiProcessRunner(
            self.worker_fn,
            cluster_spec,
            args=(checkpoint_dir, cluster_spec, input_arg, maintenance_event,
                  training_finished, True, training_restarted,
                  termination_config),
            rpc_layer=rpc_layer,
            return_output=True,
            dependence_on_chief=has_chief)

        logging.info('Cluster starting.')
        mpr.start()

        # Wait for the whole cluster to exit, with a timeout.
        waiting_time = 0
        exit_process_count = 0
        # The extra buffer mitigates the very short per-step time in this test.
        while exit_process_count != CLUSTER_SIZE and waiting_time < max(
                grace_period + 15, 40):
            exit_process_count = 0
            for worker_id in range(CLUSTER_SIZE):
                if not mpr.process_exists('worker', worker_id):
                    exit_process_count += 1
            waiting_time += 1
            time.sleep(1)

        if waiting_time == max(grace_period + 15, 40):
            raise RuntimeError(
                'Waited a long time, but at least one worker still exists. '
                'Considering the size of our model, this should not happen.')

        maintenance_event.set()
        logging.info('restarting workers')
        training_restarted.set()
        for worker_id in range(CLUSTER_SIZE):
            mpr.start_single_process('worker', worker_id, cluster_spec)
        logging.info('workers restarted')

        mpr.join(timeout=250)
        self.assertTrue(training_finished.is_set())
Example #15
 def test_multi_process_runner_error_propagates_from_subprocesses(self):
   runner = multi_process_runner.MultiProcessRunner(
       proc_func_that_errors,
       multi_worker_test_base.create_cluster_spec(num_workers=1, num_ps=1),
       max_run_time=20)
   runner.start()
   with self.assertRaisesRegex(ValueError, 'This is an error.'):
     runner.join()
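proc_func_that_errors, used here and in Example #16, is not included in the listing. Judging from the assertions, it is presumably just a function that raises with the expected message, for example:

def proc_func_that_errors():
  # Hypothetical sketch, inferred from the assertions in Examples #15 and #16:
  # the subprocess body raises, and MultiProcessRunner re-raises the error in
  # the main test process on join().
  raise ValueError('This is an error.')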
Example #16
 def test_terminate_all_does_not_ignore_error(self):
     mpr = multi_process_runner.MultiProcessRunner(
         proc_func_that_errors,
         multi_worker_test_base.create_cluster_spec(num_workers=2),
         list_stdout=True)
     mpr.start()
     mpr.terminate_all()
     with self.assertRaisesRegex(ValueError, 'This is an error.'):
         mpr.join()
Example #17
 def test_terminate_all_does_not_ignore_error(self):
     mpr = multi_process_runner.MultiProcessRunner(
         fn_that_errors,
         multi_worker_test_base.create_cluster_spec(num_workers=2),
         return_output=True)
     mpr.start()
     time.sleep(60)
     mpr.terminate_all()
     with self.assertRaisesRegex(ValueError, 'This is an error.'):
         mpr.join()
Example #18
    def test_streaming(self):
        def proc_func():
            for i in range(5):
                logging.info('(logging) %s-%d, i: %d',
                             multi_worker_test_base.get_task_type(),
                             self._worker_idx(), i)
                print('(print) {}-{}, i: {}'.format(
                    multi_worker_test_base.get_task_type(), self._worker_idx(),
                    i),
                      flush=True)
                time.sleep(1)

        mpr = multi_process_runner.MultiProcessRunner(
            proc_func,
            multi_worker_test_base.create_cluster_spec(has_chief=True,
                                                       num_workers=2,
                                                       num_ps=2,
                                                       has_eval=True),
            list_stdout=True)
        mpr._dependence_on_chief = False

        mpr.start()
        mpr.start_single_process('worker', 2)
        mpr.start_single_process('ps', 2)
        mpr_result = mpr.join()

        list_to_assert = mpr_result.stdout

        for job in ['chief', 'evaluator']:
            for iteration in range(5):
                self.assertTrue(
                    any('(logging) {}-0, i: {}'.format(job, iteration) in line
                        for line in list_to_assert))
                self.assertTrue(
                    any('(print) {}-0, i: {}'.format(job, iteration) in line
                        for line in list_to_assert))

        for job in ['worker', 'ps']:
            for iteration in range(5):
                for task in range(3):
                    self.assertTrue(
                        any('(logging) {}-{}, i: {}'.format(
                            job, task, iteration) in line
                            for line in list_to_assert))
                    self.assertTrue(
                        any('(print) {}-{}, i: {}'.format(
                            job, task, iteration) in line
                            for line in list_to_assert))
                task = 3
                self.assertFalse(
                    any('(logging) {}-{}, i: {}'.format(job, task, iteration)
                        in line for line in list_to_assert))
                self.assertFalse(
                    any('(print) {}-{}, i: {}'.format(job, task, iteration) in
                        line for line in list_to_assert))
Example #19
    def test_timeout_none(self):
        def fn():
            time.sleep(250)
            raise ValueError('Worker 0 errored')

        mpr = multi_process_runner.MultiProcessRunner(
            fn, multi_worker_test_base.create_cluster_spec(num_workers=1))

        mpr.start()
        with self.assertRaisesRegex(ValueError, 'Worker 0 errored'):
            mpr.join(timeout=None)
Example #20
  def testCheckHealthInvalidPeer(self):

    def worker_fn():
      enable_collective_ops(cluster_resolver_lib.TFConfigClusterResolver())
      context.context().check_collective_ops_peer_health("localhost:12345",)

    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
    mpr.start_single_process("worker", 0)
    with self.assertRaises(errors.InvalidArgumentError):
      mpr.join()
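enable_collective_ops is another helper that the listing does not define. Based on how it is called, it presumably builds a ServerDef from the cluster resolver and enables collective ops for the current process, roughly along these lines (a sketch; details such as the protocol and any group-leader configuration are assumptions):

from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.eager import context


def enable_collective_ops(cluster_resolver):
  # Hypothetical sketch: start collective ops for this task using the cluster
  # layout described by the resolver.
  server_def = tensorflow_server_pb2.ServerDef(
      cluster=cluster_resolver.cluster_spec().as_cluster_def(),
      job_name=cluster_resolver.task_type,
      task_index=cluster_resolver.task_id,
      protocol='grpc')
  context.context().enable_collective_ops(server_def)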
Example #21
 def test_signal_doesnt_fire_after_process_exits(self):
     mpr = multi_process_runner.MultiProcessRunner(
         proc_func_that_does_nothing,
         multi_worker_test_base.create_cluster_spec(num_workers=1),
         max_run_time=10)
     mpr.start()
     mpr.join()
     with self.assertRaisesRegex(Queue.Empty, ''):
         # If the signal was fired, another message would have been added to the
         # internal queue, so verify that it's empty.
         mpr._get_process_status_queue().get(block=False)
Example #22
  def testCheckHealthPeerDown(self):

    def worker_fn():
      enable_collective_ops(cluster_resolver_lib.TFConfigClusterResolver())
      context.context().check_collective_ops_peer_health(
          "/job:worker/replica:0/task:1",)

    cluster_spec = multi_worker_test_base.create_cluster_spec(num_workers=2)
    mpr = multi_process_runner.MultiProcessRunner(worker_fn, cluster_spec)
    mpr.start_single_process("worker", 0)
    with self.assertRaises(errors.UnavailableError):
      mpr.join()
Example #23
    def test_error_reporting_overrides_timeout_reporting(self):
        def fn():
            if self._worker_idx() == 1:
                time.sleep(10000)
            raise ValueError('Worker 0 errored')

        mpr = multi_process_runner.MultiProcessRunner(
            fn, multi_worker_test_base.create_cluster_spec(num_workers=2))
        mpr.start()

        with self.assertRaisesRegex(ValueError, 'Worker 0 errored'):
            mpr.join(timeout=20)
Example #24
    def test_process_exists(self):
        def fn():
            time.sleep(100000)

        mpr = multi_process_runner.MultiProcessRunner(
            fn, multi_worker_test_base.create_cluster_spec(num_workers=1))
        mpr.start()
        self.assertTrue(mpr.process_exists('worker', 0))
        mpr.terminate('worker', 0)
        # Worker 0 should exit at some point, or else the test would time out.
        while mpr.process_exists('worker', 0):
            time.sleep(1)
Example #25
    def test_auto_restart_and_timeout(self):
        def proc_func():
            time.sleep(1)
            raise ValueError

        mpr = multi_process_runner.MultiProcessRunner(
            proc_func,
            multi_worker_test_base.create_cluster_spec(num_workers=1),
            auto_restart=True)
        mpr.start()
        with self.assertRaises(ValueError):
            mpr.join(timeout=10)
Example #26
  def test_preemption_checkpointing(self):
    has_chief = False
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=has_chief,
        num_workers=CLUSTER_SIZE)
    training_started_event = multi_process_runner.manager().Event()

    checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

    if _is_oss():
      rpc_layer = 'grpc'
    else:
      rpc_layer = 'grpc+loas'

    mpr = multi_process_runner.MultiProcessRunner(
        self.worker_fn,
        cluster_spec,
        args=(checkpoint_dir, cluster_spec, [training_started_event]),
        rpc_layer=rpc_layer,
        return_output=True,
        dependence_on_chief=has_chief)

    logging.info('Cluster starting.')
    mpr.start()
    while not training_started_event.is_set():
      time.sleep(1)

    logging.info('sending sigterm')
    killed_worker = random.randrange(0, CLUSTER_SIZE)
    os.kill(mpr.get_process_id('worker', killed_worker), signal.SIGTERM)

    logging.info('sigterm sent')
    time.sleep(5)

    logging.info('restarting workers')
    for worker_id in range(CLUSTER_SIZE):
      mpr.start_single_process('worker', worker_id, cluster_spec)
    logging.info('workers restarted')

    stdout = mpr.join().stdout
    all_start_point = []
    for msg in stdout:
      matched_group = re.search(r'.*Restored training at (\d+)', msg)

      if matched_group:
        all_start_point.append(int(matched_group.group(1)))

    # Remove duplicate logs created due to the presence of multiple workers.
    start_points = all_start_point[::CLUSTER_SIZE]

    # assert that after restarting, we don't repeat previous training steps
    self.assertNotEqual(start_points[-1], 0)
Example #27
  def test_basic_run(self):
    has_chief = False
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        has_chief=has_chief,
        num_workers=CLUSTER_SIZE)
    maintenance_event = multi_process_runner.manager().Event()
    training_finished = multi_process_runner.manager().Event()

    checkpoint_dir = os.path.join(self.get_temp_dir(), 'fh_ckpt')

    if _is_oss():
      rpc_layer = 'grpc'
    else:
      rpc_layer = 'grpc+loas'

    mpr = multi_process_runner.MultiProcessRunner(
        self.worker_fn,
        cluster_spec,
        args=(checkpoint_dir, cluster_spec, maintenance_event,
              training_finished),
        rpc_layer=rpc_layer,
        return_output=True,
        dependence_on_chief=has_chief)

    logging.info('Cluster starting.')
    mpr.start()

    while (not maintenance_event.is_set()) and (not training_finished.is_set()):
      time.sleep(1)

    time.sleep(5)
    if not training_finished.is_set():
      logging.info('restarting workers')
      for worker_id in range(CLUSTER_SIZE):
        mpr.start_single_process('worker', worker_id, cluster_spec)
      logging.info('workers restarted')

    stdout = mpr.join().stdout
    if maintenance_event.is_set():
      all_start_point = []
      for msg in stdout:
        matched_group = re.search(r'.*Start training at (\d+)', msg)

        if matched_group:
          all_start_point.append(int(matched_group.group(1)))

      # Remove duplicate logs created due to the presence of multiple workers.
      start_points = all_start_point[::CLUSTER_SIZE]

      if len(start_points) > 1:
        # assert that after restarting, we don't repeat previous training steps
        self.assertNotEqual(start_points[-1], 0)
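The stdout parsing in the last two examples follows the same pattern: pull the step number out of every matching log line, then drop the duplicates produced by having several workers log the same step. Factored out, that pattern looks roughly like this (illustrative only):

import re


def parse_logged_steps(stdout_lines, pattern, cluster_size):
  # Illustrative helper: extract the integer captured by `pattern` from every
  # matching stdout line, then keep one entry per step by skipping the copies
  # logged by the other workers.
  steps = []
  for line in stdout_lines:
    match = re.search(pattern, line)
    if match:
      steps.append(int(match.group(1)))
  return steps[::cluster_size]

With it, the start points in test_preemption_checkpointing would reduce to parse_logged_steps(stdout, r'.*Restored training at (\d+)', CLUSTER_SIZE).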
Example #28
    def test_exit_code_is_reported_by_subprocess(self):
        def proc_func_expected_to_exit_with_10():
            sys.exit(10)

        mpr = multi_process_runner.MultiProcessRunner(
            proc_func_expected_to_exit_with_10,
            multi_worker_test_base.create_cluster_spec(num_workers=1))
        mpr.start()

        with self.assertRaisesRegex(
                multi_process_runner.UnexpectedSubprocessExitError,
                'Subprocess worker-0 exited with exit code 10'):
            mpr.join()
Example #29
    def test_process_that_exits(self):
        def func_to_exit_in_10_sec():
            time.sleep(5)
            mpr._add_return_data('foo')
            time.sleep(20)
            mpr._add_return_data('bar')

        mpr = multi_process_runner.MultiProcessRunner(
            func_to_exit_in_10_sec,
            multi_worker_test_base.create_cluster_spec(num_workers=1),
            max_run_time=10)

        mpr.start()
        returned_data, _ = mpr.join()
        self.assertLen(returned_data, 1)
Example #30
    def test_timeout_none(self):

        if multi_process_runner.is_oss():
            self.skipTest('Intentionally skipping longer test in OSS.')

        def fn():
            time.sleep(250)
            raise ValueError('Worker 0 errored')

        mpr = multi_process_runner.MultiProcessRunner(
            fn, multi_worker_test_base.create_cluster_spec(num_workers=1))

        mpr.start()
        with self.assertRaisesRegex(ValueError, 'Worker 0 errored'):
            mpr.join(timeout=None)