def setUp(self):
    super(FaultToleranceTest, self).setUp()

    # Set the environment variable to prevent hanging upon job failure and
    # restart. Note that it defaults to 'use_caller' at Google, but defaults
    # to False in OSS.
    os.environ["GRPC_FAIL_FAST"] = "use_caller"

    self._cluster = multi_worker_test_base.create_multi_process_cluster(
        num_workers=FaultToleranceTest.NUM_WORKERS,
        num_ps=FaultToleranceTest.NUM_PS,
        rpc_layer="grpc")
    self._cluster_def = self._cluster.cluster_resolver.cluster_spec().as_dict()
    self._cluster_def["chief"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        server_lib.ClusterSpec(self._cluster_def), rpc_layer="grpc")

    # The strategy's constructor connects to the cluster.
    self.strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        cluster_resolver)
    self.cluster_coord = cluster_coordinator.ClusterCoordinator(self.strategy)

    self.thread_coord = thread_coordinator.Coordinator(
        clean_stop_exception_types=[])
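
A matching tearDown would shut the multi-process cluster down again; a minimal sketch, assuming the object returned by create_multi_process_cluster exposes a stop() method:

def tearDown(self):
    # Shut down the worker and ps processes started in setUp; `stop()` is
    # assumed to be the cluster's shutdown hook.
    self._cluster.stop()
    self._cluster = None
    super(FaultToleranceTest, self).tearDown()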
Example #2
    def test_dataset_creator_usage_in_parameter_server_model_fit(self):
        cluster_def = multi_worker_test_base.create_in_process_cluster(
            num_workers=2, num_ps=1, rpc_layer="grpc")
        cluster_def["chief"] = [
            "localhost:%d" % multi_worker_test_base.pick_unused_port()
        ]
        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            SimpleClusterResolver(ClusterSpec(cluster_def), rpc_layer="grpc"))
        with strategy.scope():
            model = sequential.Sequential([core_layers.Dense(10)])
        model.compile(gradient_descent.SGD(), loss="mse")

        def dataset_fn(input_context):
            global_batch_size = 64
            batch_size = input_context.get_per_replica_batch_size(
                global_batch_size)
            dataset = dataset_ops.DatasetV2.from_tensors(([1.], [1.])).repeat()
            dataset = dataset.shard(input_context.num_input_pipelines,
                                    input_context.input_pipeline_id)
            dataset = dataset.batch(batch_size)
            dataset = dataset.prefetch(2)
            return dataset

        history = model.fit(dataset_creator.DatasetCreator(dataset_fn),
                            epochs=10,
                            steps_per_epoch=10,
                            verbose=0)
        self.assertLen(history.history["loss"], 10)
Example #3
def make_parameter_server_cluster(num_workers, num_ps):
  cluster_def = multi_worker_test_base.create_in_process_cluster(
      num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
  cluster_def["chief"] = [
      "localhost:%d" % multi_worker_test_base.pick_unused_port()
  ]
  return SimpleClusterResolver(ClusterSpec(cluster_def), rpc_layer="grpc")
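
Typical usage of this helper, mirroring the other examples in this listing: the returned resolver is handed to the strategy, whose constructor connects to the cluster.

resolver = make_parameter_server_cluster(num_workers=2, num_ps=1)
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(resolver)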
Example #4
  def testClusterCoordinatorMetrics(self):

    metric_utils.enable_metrics = True

    cluster_def = multi_worker_test_base.create_in_process_cluster(
        num_workers=1, num_ps=1, rpc_layer=self.get_rpc_layer())
    cluster_def['chief'] = [
        'localhost:%d' % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def), rpc_layer=self.get_rpc_layer())
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        cluster_resolver)
    cluster = coordinator_lib.Cluster(strategy)

    @def_function.function
    def func():
      time.sleep(0.5)
      return 3

    result = cluster.schedule(func, args=None, kwargs=None)
    result = cluster.schedule(func, args=None, kwargs=None)
    cluster.join()
    self.assertEqual(result.fetch(), 3)

    # Tracing should happen exactly once even though the function is scheduled
    # twice; closure execution and remote_value fetching happen once per
    # schedule.
    metric_tracing = metric_utils.get_metric_summary('function_tracing')
    self.assertEqual(metric_tracing['num'], 1)
    # Tracing time should exceed the sleep time in the Python function.
    self.assertGreater(metric_tracing['sum'], 0.5)
    metric_closure = metric_utils.get_metric_summary('closure_execution')
    self.assertEqual(metric_closure['num'], 2)
    metric_remote_value = metric_utils.get_metric_summary('remote_value_fetch')
    self.assertEqual(metric_remote_value['num'], 2)
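
The assertions above read summary metrics that expose a count ('num') and a running total ('sum'). A minimal sketch of that bookkeeping; the actual implementation inside metric_utils may differ:

import time


class SummaryMetric(object):
  """Tracks how often an event occurred and the total time it took."""

  def __init__(self):
    self.num = 0    # Number of occurrences.
    self.sum = 0.0  # Total duration across occurrences, in seconds.

  def record(self, start_time):
    self.num += 1
    self.sum += time.time() - start_time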
Example #5
def make_coordinator(num_workers, num_ps):
    cluster_def = multi_worker_test_base.create_in_process_cluster(
        num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
    cluster_def["chief"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(ClusterSpec(cluster_def),
                                             rpc_layer="grpc")
    return tf.distribute.experimental.coordinator.ClusterCoordinator(
        tf.distribute.experimental.ParameterServerStrategy(cluster_resolver))
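
A coordinator built this way can then dispatch tf.functions to workers through the public schedule/join API; a minimal sketch:

coordinator = make_coordinator(num_workers=2, num_ps=1)


@tf.function
def worker_fn():
    return tf.constant(3)


# `schedule` returns a RemoteValue immediately; `join` blocks until every
# scheduled closure has finished executing on the workers.
remote_value = coordinator.schedule(worker_fn)
coordinator.join()
assert remote_value.fetch() == 3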
Example #6
  def testArbitraryCurrentTaskType(self):
    cluster_def = multi_worker_test_base._create_cluster(
        num_workers=1, num_ps=1)
    cluster_def["chief"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def), rpc_layer="grpc", task_type="foobar")
    with self.assertRaisesRegex(ValueError, "Unrecognized task_type: foobar"):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
Example #7
  def testArbitraryJobName(self):
    cluster_def = multi_worker_test_base.create_cluster_spec(
        num_workers=1, num_ps=1, has_chief=True)
    cluster_def["some_arbitrary_name"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def), rpc_layer="grpc")
    with self.assertRaisesRegex(ValueError, "Disallowed task type found in"):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
Example #8
def make_client(num_workers, num_ps):
    # TODO(rchao): Test the internal rpc_layer version.
    cluster_def = multi_worker_test_base.create_in_process_cluster(
        num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
    cluster_def["chief"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(ClusterSpec(cluster_def),
                                             rpc_layer="grpc")
    return parameter_server_client.ParameterServerClient(cluster_resolver)
Example #9
def make_client(num_workers, num_ps):
    cluster_def = multi_worker_test_base.create_in_process_cluster(
        num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
    cluster_def["chief"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(ClusterSpec(cluster_def),
                                             rpc_layer="grpc")
    return client_lib.Client(
        parameter_server_strategy_v2.ParameterServerStrategyV2(
            cluster_resolver))
Example #10
  def testLessThanOneWorker(self):
    cluster_def = multi_worker_test_base._create_cluster(
        num_workers=0, num_ps=1)
    cluster_def["chief"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def), rpc_layer="grpc", task_type="ps", task_id=0)
    with self.assertRaisesRegex(ValueError,
                                "There must be at least one worker."):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
Example #11
def make_coordinator(num_workers, num_ps):
  # TODO(rchao): Test the internal rpc_layer version.
  cluster_def = multi_worker_test_base.create_in_process_cluster(
      num_workers=num_workers, num_ps=num_ps, rpc_layer='grpc')
  cluster_def['chief'] = [
      'localhost:%d' % multi_worker_test_base.pick_unused_port()
  ]
  cluster_resolver = SimpleClusterResolver(
      ClusterSpec(cluster_def), rpc_layer='grpc')
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      cluster_resolver)
  return coordinator_lib.ClusterCoordinator(strategy)
Example #12
  def testMoreThanOneChief(self):
    cluster_def = multi_worker_test_base.create_cluster_spec(
        num_workers=1, num_ps=1)
    chief_ports = [multi_worker_test_base.pick_unused_port() for _ in range(3)]
    cluster_def["chief"] = ["localhost:%s" % port for port in chief_ports]
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def),
        rpc_layer="grpc",
        task_type="chief",
        task_id=1)
    with self.assertRaisesRegex(ValueError,
                                "There must be at most one 'chief' job."):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)
Example #13
    def test_dataset_creator_usage_in_parameter_server_model_fit(self):
        cluster_def = multi_worker_test_base.create_in_process_cluster(
            num_workers=2, num_ps=1, rpc_layer="grpc")
        cluster_def["chief"] = [
            "localhost:%d" % multi_worker_test_base.pick_unused_port()
        ]
        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            SimpleClusterResolver(ClusterSpec(cluster_def), rpc_layer="grpc"))
        with strategy.scope():
            model = sequential.Sequential([core_layers.Dense(10)])
        model.compile(gradient_descent.SGD(), loss="mse")

        history = model.fit(dataset_creator.DatasetCreator(
            self._get_dataset_fn()),
                            epochs=10,
                            steps_per_epoch=10,
                            verbose=0)
        self.assertLen(history.history["loss"], 10)
Example #14
def job_count_to_cluster_spec(job_count_dict):
    """Converts a job count dict to a cluster spec.

    Args:
      job_count_dict: Dict mapping a task type to the number of tasks of that
        type, e.g. {'worker': 1, 'ps': 1} for a cluster with one worker and
        one ps.

    Returns:
      The converted cluster spec dict.
    """

    cluster_spec = {}
    for task_type, count in job_count_dict.items():
        cluster_spec[task_type] = [
            'localhost:{}'.format(multi_worker_test_base.pick_unused_port())
            for _ in range(count)
        ]
    return cluster_spec
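
For example (ports are freshly picked, so the exact values will vary):

spec = job_count_to_cluster_spec({'worker': 2, 'ps': 1})
# spec == {'worker': ['localhost:15222', 'localhost:15223'],
#          'ps': ['localhost:15224']}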
Example #15
  def setUp(self, num_workers, num_ps):
    super(BaseFaultToleranceTest, self).setUp()

    self._cluster = multi_worker_test_base.create_multi_process_cluster(
        num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
    self._cluster_def = self._cluster.cluster_resolver.cluster_spec().as_dict()
    self._cluster_def["chief"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        server_lib.ClusterSpec(self._cluster_def), rpc_layer="grpc")

    # The strategy's constructor connects to the cluster.
    self.strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        cluster_resolver)
    self.cluster_coord = cluster_coordinator.ClusterCoordinator(self.strategy)

    self.thread_coord = thread_coordinator.Coordinator(
        clean_stop_exception_types=[])
    self.num_workers = num_workers
    self.num_ps = num_ps
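
Concrete suites would then pin the cluster shape when calling this parameterized setUp; a minimal sketch with a hypothetical subclass name:

class TwoWorkerFaultToleranceTest(BaseFaultToleranceTest):
  """Hypothetical subclass fixing the cluster shape for one suite."""

  def setUp(self):
    super(TwoWorkerFaultToleranceTest, self).setUp(num_workers=2, num_ps=1)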
Example #16
    def testClientMetrics(self):
        if sys.version_info >= (3, 8) and platform.system() == 'Windows':
            # TODO(b/165013260): Fix this
            self.skipTest(
                'Test is currently broken on Windows with Python 3.8')

        metric_utils.enable_metrics = True

        cluster_def = multi_worker_test_base.create_in_process_cluster(
            num_workers=1, num_ps=1, rpc_layer=self.get_rpc_layer())
        cluster_def['chief'] = [
            'localhost:%d' % multi_worker_test_base.pick_unused_port()
        ]
        cluster_resolver = SimpleClusterResolver(
            ClusterSpec(cluster_def), rpc_layer=self.get_rpc_layer())
        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            cluster_resolver)
        cluster = client.Cluster(strategy)

        @def_function.function
        def func():
            time.sleep(0.5)
            return 3

        result = cluster.schedule(func, args=None, kwargs=None)
        result = cluster.schedule(func, args=None, kwargs=None)
        cluster.join()
        self.assertEqual(result._get_value().numpy(), 3)

        # Tracing should happen exactly once even though the function is
        # scheduled twice; closure execution and remote_value fetching happen
        # once per schedule.
        metric_tracing = metric_utils.get_metric_summary('function_tracing')
        self.assertEqual(metric_tracing['num'], 1)
        # Tracing time should exceed the sleep time in the Python function.
        self.assertGreater(metric_tracing['sum'], 0.5)
        metric_closure = metric_utils.get_metric_summary('closure_execution')
        self.assertEqual(metric_closure['num'], 2)
        metric_remote_value = metric_utils.get_metric_summary(
            'remote_value_fetch')
        self.assertEqual(metric_remote_value['num'], 2)
Example #17
    def testFaultToleranceInSyncStrategy(self, strategy_cls, file_format,
                                         preemption_callback):
        """Test fault-tolerance with multi-threading using sync dist-strat.

    This test simulates multi-worker training that is interrupted by a
    preemption, by having two threads, each of which represents a chief and a
    non-chief worker, where the non-chief raises an error in the middle of
    training loop. Upon excepting the error, a new thread with a new cluster
    spec is created to simulate the recovered non-chief worker. Meanwhile, the
    chief worker cannot proceed and hangs since the non-chief worker has
    crashed. To simulate a restart of the chief, a new thread has been prepared
    to run to take over chief with the help of a condition variable. It is
    expected that after the restart of both chief and non-chief workers, the
    training continues from the epoch they previously failed at. The test
    concludes by verifying the preemption-interrupted training can finish with
    the same loss and accuracy had the preemption not occurred.

    Arguments:
      strategy_cls: The strategy class to use.
      file_format: `h5` or `tf`.
      preemption_callback: The callback to simulate preemption.
    """
        def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
            with test.mock.patch.object(dc, '_run_std_server',
                                        self._make_mock_run_std_server()):
                # Condition variable that blocks the thread that represents the
                # restarted chief.
                cv = kwargs.get('cv', None)
                # `before_restart` is True for the threads that represent the original
                # chief and non-chief worker, and False for threads that represent the
                # restarted chief and non-chief workers.
                before_restart = kwargs['before_restart']
                if kwargs['new_chief']:
                    # `new_chief` is only True for the restarted chief thread.
                    # It waits until the non-chief is preempted and restarted,
                    # simulating the causality where the chief's restart
                    # results from the non-chief's failure.
                    cv.acquire()
                    while not hasattr(cv, 'preempted'):
                        cv.wait()
                    cv.release()

                # Model building under strategy scope. The following is the
                # code we expect the user to run on every worker.
                strategy = get_strategy_object(strategy_cls)
                batch_size = 64
                steps = 3
                train_ds, _ = _mnist_synthetic_dataset(batch_size, steps)
                with strategy.scope():
                    model = _get_model((28, 28, 1))

                # Function to start a new thread. This will be called twice
                # in the following code: once for the restart of the
                # non-chief, and once for the restart of the chief as a result
                # of the non-chief's restart (so training can continue in
                # sync).
                def start_new_thread(new_chief=False):
                    new_thread_tf_config = json.loads(os.environ['TF_CONFIG'])
                    new_thread_tf_config['cluster']['worker'] = kwargs[
                        'reserved_ports']
                    return self._run_task_in_thread(
                        task_fn=_independent_worker_fn,
                        cluster_spec=None,
                        task_type=None,
                        task_id=None,
                        tf_config=new_thread_tf_config,
                        before_restart=False,
                        cv=cv,
                        new_chief=new_chief)

                if test_base.is_chief() and before_restart:
                    # The chief starts a new thread (blocked by a condition
                    # variable until the non-chief's new thread is started).
                    # The thread for the (recovered) chief is started before
                    # entering `fit()` because the original chief thread will
                    # eventually hang and be ignored.
                    start_new_thread(new_chief=True)

                try:

                    class CkptSavedEpochAssertingCallback(callbacks.Callback):
                        def __init__(self, test_obj):
                            super(CkptSavedEpochAssertingCallback,
                                  self).__init__()
                            self.test_obj = test_obj

                        def on_epoch_begin(self, epoch, logs=None):
                            # `_ckpt_saved_epoch` attribute is set at the end of every epoch.
                            self.test_obj.assertEqual(
                                self.model._ckpt_saved_epoch is None,
                                epoch == 0)

                    callbacks_list = [
                        callbacks.ModelCheckpoint(
                            filepath=saving_filepath,
                            save_weights_only=True,
                            load_weights_on_restart=True),
                        CkptSavedEpochAssertingCallback(self)
                    ]
                    if before_restart:
                        callbacks_list.append(preemption_callback())

                    self.assertIsNone(model._ckpt_saved_epoch)
                    history = model.fit(x=train_ds,
                                        epochs=num_epoch,
                                        steps_per_epoch=steps,
                                        callbacks=callbacks_list)
                    self.assertIsNone(model._ckpt_saved_epoch)

                    # The training `history` is collected so the runs can be
                    # compared against each other. The training results (loss
                    # and accuracy) are expected to be the same with or
                    # without preemption.
                    self._histories.append(history.history)

                except RuntimeError:
                    # pylint: disable=g-assert-in-except
                    self.assertTrue(before_restart)
                    # Reset the barrier so the new threads simulating recovery can
                    # continue.
                    self._barrier._counter = 0
                    self._barrier._flag = False

                    # Now that the non-chief has been preempted, it notifies the thread
                    # that simulates the restarted chief to start so they can be back in
                    # sync.
                    cv.acquire()
                    cv.preempted = True
                    cv.notify()
                    cv.release()

                    # At this point we should discard the original non-chief
                    # thread, start the new thread that simulates the restarted
                    # non-chief, then join that thread and return.
                    self.join_independent_workers([start_new_thread()])
                    return

                # Successful end of a `fit()` call.
                self._successful_thread_ends += 1
                self.assertFalse(before_restart)

        # Common parameters
        num_workers = 2
        num_epoch = 3
        # History list storing the results for the preemption and
        # no-preemption cases.
        self._histories = []
        # Pass `saving_filepath` from the parent thread to ensure every worker has
        # the same filepath to save.
        saving_filepath = os.path.join(self.get_temp_dir(),
                                       'checkpoint.' + file_format)
        strategy = get_strategy_object(strategy_cls)

        # Case 1: Training for `num_epoch` epochs without preemptions.
        cluster_spec = test_base.create_cluster_spec(num_workers=num_workers)
        self._barrier = dc._Barrier(2)
        self._successful_thread_ends = 0
        threads = self.run_multiple_tasks_in_threads(
            _independent_worker_fn,
            cluster_spec,
            saving_filepath=saving_filepath,
            before_restart=False,
            new_chief=False)
        if os.path.exists(saving_filepath):
            os.remove(saving_filepath)
        threads_to_join = []
        if strategy.extended.experimental_between_graph:
            for ts in threads.values():
                threads_to_join.extend(ts)
        else:
            threads_to_join = [threads['worker'][0]]
        self.join_independent_workers(threads_to_join)
        self.assertEqual(self._successful_thread_ends, 2)

        # Case 2: Training for `num_epoch` epochs with preemptions.
        # The preemption is simulated at both epoch and batch boundaries.
        cluster_spec = test_base.create_cluster_spec(num_workers=num_workers)
        cv = threading.Condition()
        self._barrier = dc._Barrier(2)
        # Ports reserved for new threads simulating recovery.
        reserved_ports = [
            'localhost:%s' % test_base.pick_unused_port()
            for _ in range(num_workers)
        ]
        self._successful_thread_ends = 0
        threads = self.run_multiple_tasks_in_threads(
            _independent_worker_fn,
            cluster_spec,
            saving_filepath=saving_filepath,
            reserved_ports=reserved_ports,
            before_restart=True,
            cv=cv,
            new_chief=False)
        if os.path.exists(saving_filepath):
            os.remove(saving_filepath)
        threads_to_join = []
        if strategy.extended.experimental_between_graph:
            # Only join the non-chief thread since the first thread for chief will
            # eventually hang and be ignored.
            threads_to_join = [threads['worker'][1]]
        else:
            threads_to_join = [threads['worker'][0]]
        self.join_independent_workers(threads_to_join)
        self.assertEqual(self._successful_thread_ends, 2)

        def assert_all_elements_are_identical(list_to_check):
            first_item = list_to_check[0]
            for item in list_to_check[1:]:
                self.assertAllClose(first_item, item, rtol=1e-5, atol=1e-5)

        # Important: the preemption-interrupted and non-interrupted cases
        # should give the same final results.
        assert_all_elements_are_identical(
            [history['acc'][-1] for history in self._histories])
        assert_all_elements_are_identical(
            [history['loss'][-1] for history in self._histories])
        # `self._histories` collects one history per successful worker run:
        # 2 workers x 2 cases = 4 entries.
        self.assertLen(self._histories, 4)
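
The chief-restart handshake in the test above reduces to a standard condition-variable pattern; a self-contained sketch of just that coordination:

import threading

cv = threading.Condition()


def restarted_chief():
    # Mirrors the `new_chief` branch: block until the preempted worker
    # sets the flag and notifies.
    with cv:
        while not getattr(cv, 'preempted', False):
            cv.wait()


def preempted_worker():
    # Mirrors the except branch: set the flag, then wake the waiter.
    with cv:
        cv.preempted = True
        cv.notify()


chief_thread = threading.Thread(target=restarted_chief)
chief_thread.start()
preempted_worker()
chief_thread.join()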
Example #18
    def testFaultToleranceInSyncStrategy(self, strategy_cls, file_format,
                                         preemption_callback,
                                         save_weights_only,
                                         load_weights_on_restart):
        """Test fault-tolerance with multi-threading using sync dist-strat.

    This test simulates multi-worker training that is interrupted by a
    preemption, by having two threads, each of which represents a chief and a
    non-chief worker, where the non-chief raises an error in the middle of
    training loop. Upon excepting the error, a new thread with a new cluster
    spec is created to simulate the recovered non-chief worker. Meanwhile, the
    chief worker cannot proceed and hangs since the non-chief worker has
    crashed. To simulate a restart of the chief, a new thread has been prepared
    to run to take over chief with the help of a condition variable. It is
    expected that after the restart of both chief and non-chief workers, the
    training continues from the epoch they previously failed at. The test
    concludes by verifying the preemption-interrupted training can finish with
    the same loss and accuracy had the preemption not occurred.

    TODO(rchao): Add test to check preemption on chief (possibly using multi
    processes).

    TODO(rchao): Add test to check fault-tolerance with multiple `model.fit()`.

    Arguments:
      strategy_cls: The strategy class to use.
      file_format: `h5` or `tf`.
      preemption_callback: The callback to simulate preemption.
      save_weights_only: The argument for `model.fit()`'s `save_weights_only`.
      load_weights_on_restart: The argument for `model.fit()`'s
        `load_weights_on_restart`.
    """
        def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
            with test.mock.patch.object(dc, '_run_std_server',
                                        self._make_mock_run_std_server()):
                # `before_restart` is True for the threads that represent the original
                # chief and non-chief worker, and False for threads that represent the
                # restarted chief and non-chief workers.
                before_restart = kwargs['before_restart']

                # Model building under strategy scope. The following is the
                # code we expect the user to run on every worker.
                strategy = get_strategy_object(strategy_cls)
                batch_size = 64
                steps = 3
                train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
                    batch_size, steps)

                with strategy.scope():
                    model = multi_worker_testing_utils.get_mnist_model(
                        (28, 28, 1))

                # Function to start a new thread. This will be called twice
                # in the following code: once for the restart of the
                # non-chief, and once for the restart of the chief as a result
                # of the non-chief's restart (so training can continue in
                # sync).
                def start_new_thread(new_chief):
                    new_thread_tf_config = json.loads(os.environ['TF_CONFIG'])

                    # Update the ports in new chief and new worker threads.
                    new_thread_tf_config['cluster']['worker'] = kwargs[
                        'reserved_ports']

                    # Since both new chief and new worker threads are started from the
                    # worker thread, we need to overwrite the tf config task index.
                    new_thread_tf_config['task'][
                        'index'] = 0 if new_chief else 1
                    return self._run_task_in_thread(
                        task_fn=_independent_worker_fn,
                        cluster_spec=None,
                        task_type=None,
                        task_id=None,
                        tf_config=new_thread_tf_config,
                        before_restart=False,
                        new_chief=new_chief)

                try:

                    class CkptSavedEpochAssertingCallback(callbacks.Callback):
                        def __init__(self, test_obj):
                            super(CkptSavedEpochAssertingCallback,
                                  self).__init__()
                            self.test_obj = test_obj

                        def on_epoch_begin(self, epoch, logs=None):
                            # `_ckpt_saved_epoch` attribute is set at the end of every epoch.
                            self.test_obj.assertEqual(
                                K.eval(self.model._ckpt_saved_epoch) ==
                                training_state.CKPT_SAVED_EPOCH_UNUSED_VALUE,
                                epoch == 0)

                    callbacks_list = [
                        callbacks.ModelCheckpoint(
                            filepath=saving_filepath,
                            save_weights_only=save_weights_only,
                            load_weights_on_restart=load_weights_on_restart),
                        CkptSavedEpochAssertingCallback(self)
                    ]
                    if before_restart:
                        callbacks_list.append(preemption_callback())

                    self.assertFalse(
                        hasattr(model, training_state.CKPT_SAVED_EPOCH))
                    history = model.fit(x=train_ds,
                                        epochs=num_epoch,
                                        steps_per_epoch=steps,
                                        callbacks=callbacks_list)
                    self.assertFalse(
                        hasattr(model, training_state.CKPT_SAVED_EPOCH))

                    # The training `history` is collected so the runs can be
                    # compared against each other. The training results (loss
                    # and accuracy) are expected to be the same with or
                    # without preemption.
                    self._histories.append(history.history)

                except RuntimeError:
                    # pylint: disable=g-assert-in-except
                    self.assertTrue(before_restart)
                    # Reset the barrier so the new threads simulating recovery can
                    # continue.
                    self._barrier._counter = 0
                    self._barrier._flag = False

                    # At this point the original non-chief thread starts the
                    # new threads that simulate the restarted chief and
                    # non-chief, blocks until they finish (join), and returns.
                    new_chief_thread = start_new_thread(new_chief=True)
                    new_worker_thread = start_new_thread(new_chief=False)
                    self.join_independent_workers(
                        [new_chief_thread, new_worker_thread])
                    return

                # Successful end of a `fit()` call.
                with self._lock:
                    self._successful_thread_ends += 1
                self.assertFalse(before_restart)

        # Common parameters
        num_workers = 2
        num_epoch = 3
        # History list storing the results for the preemption and
        # no-preemption cases.
        self._histories = []
        # Lock to prevent a race condition between the two threads.
        self._lock = threading.Lock()
        strategy = get_strategy_object(strategy_cls)

        def handler(signum, frame):
            del signum, frame
            # `session.run()` within `model.fit()` can time out. Skip the test
            # in that case, since the timeout does not indicate a failure of
            # what this test verifies.
            self.skipTest('Skipping test due to `session.run()` timeout.')

        signal.signal(signal.SIGALRM, handler)
        # Fire the alarm after 4 minutes, before the test itself times out
        # and fails.
        signal.alarm(240)

        def get_saving_dir_and_filepath():
            saving_dir = tempfile.mkdtemp(prefix=self.get_temp_dir())
            saving_filepath = os.path.join(saving_dir,
                                           'checkpoint.' + file_format)
            return saving_dir, saving_filepath

        # Case 1: Training for `num_epoch` epochs without preemptions.
        cluster_spec = test_base.create_cluster_spec(num_workers=num_workers)
        self._barrier = dc._Barrier(2)
        self._successful_thread_ends = 0
        # Get a new temporary filepath to save the checkpoint to.
        saving_dir, saving_filepath = get_saving_dir_and_filepath()
        threads = self.run_multiple_tasks_in_threads(
            _independent_worker_fn,
            cluster_spec,
            # Pass `saving_filepath` from the parent thread to ensure every worker
            # has the same filepath to save.
            saving_filepath=saving_filepath,
            before_restart=False,
            new_chief=False)
        threads_to_join = []
        if strategy.extended.experimental_between_graph:
            for ts in threads.values():
                threads_to_join.extend(ts)
        else:
            threads_to_join = [threads['worker'][0]]
        self.join_independent_workers(threads_to_join)

        # `self.test_skipped_reason` could be set when a non-main thread attempts
        # to skip the test.
        # `multi_worker_test_base.skip_if_grpc_server_cant_be_started()` is an
        # example of where this can be set. Since raising `SkipTest` in a non-main
        # thread doesn't actually skip the test, we check if the test should be
        # skipped here once we have joined the threads.
        if getattr(self, 'test_skipped_reason', None) is not None:
            self.skipTest(self.test_skipped_reason)

        self.assertTrue(
            training_state.remove_checkpoint_if_exists(saving_dir,
                                                       saving_filepath))
        self.assertEqual(self._successful_thread_ends, 2)

        # Case 2: Training for `num_epoch` epochs with preemptions.
        # The preemption is simulated at both epoch and batch boundaries.
        cluster_spec = test_base.create_cluster_spec(num_workers=num_workers)
        self._barrier = dc._Barrier(2)
        # Ports reserved for new threads simulating recovery.
        reserved_ports = [
            'localhost:%s' % test_base.pick_unused_port()
            for _ in range(num_workers)
        ]
        self._successful_thread_ends = 0
        # Get a new temporary filepath to save the checkpoint to.
        saving_dir, saving_filepath = get_saving_dir_and_filepath()
        threads = self.run_multiple_tasks_in_threads(
            _independent_worker_fn,
            cluster_spec,
            # Pass `saving_filepath` from the parent thread to ensure every worker
            # has the same filepath to save.
            saving_filepath=saving_filepath,
            reserved_ports=reserved_ports,
            before_restart=True,
            new_chief=False)
        threads_to_join = []
        if strategy.extended.experimental_between_graph:
            # Only join the non-chief thread since the first thread for chief will
            # eventually hang and be ignored.
            threads_to_join = [threads['worker'][1]]
        else:
            threads_to_join = [threads['worker'][0]]
        self.join_independent_workers(threads_to_join)
        if getattr(self, 'test_skipped_reason', None) is not None:
            self.skipTest(self.test_skipped_reason)

        self.assertTrue(
            training_state.remove_checkpoint_if_exists(saving_dir,
                                                       saving_filepath))
        self.assertEqual(self._successful_thread_ends, 2)

        def assert_all_elements_are_identical(list_to_check):
            first_item = list_to_check[0]
            for item in list_to_check[1:]:
                self.assertAllClose(first_item, item, rtol=2e-5, atol=1e-5)

        # Important: the preemption-interrupted and non-interrupted cases
        # should give the same final results.
        assert_all_elements_are_identical(
            [history['acc'][-1] for history in self._histories])
        assert_all_elements_are_identical(
            [history['loss'][-1] for history in self._histories])
        # `self._histories` collects one history per successful worker run:
        # 2 workers x 2 cases = 4 entries.
        self.assertLen(self._histories, 4)

        # Results from case 1 should have 3 full epochs.
        self.assertLen(self._histories[0]['acc'], 3)
        # Results from case 2 should only have 2 full epochs because it restarted at
        # epoch 1.
        self.assertLen(self._histories[-1]['acc'], 2)
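
The SIGALRM escape hatch used above follows a common Unix-only pattern: install a handler, arm the alarm, and clear it once the guarded work finishes. A minimal standalone sketch:

import signal
import time


def _on_alarm(signum, frame):
    del signum, frame
    raise TimeoutError('The guarded operation took too long.')


signal.signal(signal.SIGALRM, _on_alarm)  # SIGALRM is unavailable on Windows.
signal.alarm(2)  # Deliver SIGALRM in 2 seconds.
try:
    time.sleep(1)  # The "slow" work; finishes before the alarm fires here.
finally:
    signal.alarm(0)  # Always clear any pending alarm.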
Example #19
def run(proc_func,
        count_dict,
        proc_flags=None,
        timeout=200,
        time_to_exit=None,
        return_std_stream=False,
        args=None,
        kwargs=None):
  """Run functions on local sub-processes.

  Args:
    proc_func: Function to be run on the processes. This will be run on
      processes for all task types.
    count_dict: Dict for task_type/count of such task type.
    proc_flags: Dict that contains the key/values of the flags used on the
      processes.
    timeout: Timeout in seconds. If a subprocess takes longer than this to
      complete, an error is raised.
    time_to_exit: If set, subprocesses are forced to exit at approximately
      this many seconds after `run()` is called, via the `signal.alarm()`
      API. This simulates an interruption of a process, so in such cases no
      error is raised. Note that this is best-effort at the Python level,
      since the Python signal handler does not get executed inside the
      low-level (C) signal handler, so its execution can be delayed.
    return_std_stream: Boolean, whether the messages streamed to stdout and
      stderr in subprocesses are captured. If True, the messages are stored
      in a list returned as the second element.
    args: Positional arguments to be sent to functions run on processes.
    kwargs: Keyword arguments to be sent to functions run on processes.

  Returns:
    If `return_std_stream` is False, a list that stores the return data added
    by processes through `multi_process_runner.add_return_data(data)` call;
    if `return_std_stream` is True, a two-element tuple of
    `(return_data_list, std_stream_data_list)`, where `return_data_list` stores
    the return data added by processes through
    `multi_process_runner.add_return_data(data)` call, and
    `std_stream_data_list` stores the messages streamed to stdout and stderr
    in the subprocesses.

  Raises:
    RuntimeError: If any of the subprocesses raises an error, or if any of
      the subprocesses does not return or error out within `timeout` seconds.

  TODO(rchao): Open source this with a solution to handle multi_process_lib.
  """

  assert callable(proc_func)
  processes = []
  cluster_spec = {}
  args = args or ()
  kwargs = kwargs or {}

  for task_type, count in count_dict.items():
    cluster_spec[task_type] = [
        'localhost:{}'.format(multi_worker_test_base.pick_unused_port())
        for _ in range(count)
    ]

  def wrapper_func(tf_config_as_json, proc_func, proc_flags, time_to_exit,
                   *arg, **kwargs):
    """The wrapper function that actually gets run on the process(es)."""

    os.environ['TF_CONFIG'] = tf_config_as_json
    if proc_flags is not None:
      for flag_key, flag_value in proc_flags.items():
        setattr(flags.FLAGS, flag_key, flag_value)

    stdout_collector = _LogCollector(
        sys.__stdout__) if return_std_stream else None
    stderr_collector = _LogCollector(
        sys.__stderr__) if return_std_stream else None

    def finish_wrapper_func_properly(finish_message=_FINISH_PROPERLY_MESSAGE):
      """Call to finish `wrapper_func` properly."""
      # Clear the alarm.
      signal.alarm(0)
      if (return_std_stream and stdout_collector is not None and
          stderr_collector is not None):
        # If stdout and stderr are to be collected, add them to std stream
        # queue.
        _add_std_stream_data_flattened(stdout_collector.log)
        _add_std_stream_data_flattened(stderr_collector.log)
        # Un-redirect stdout and stderr.
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__
      _get_internal_queue().put(finish_message)

    if time_to_exit is not None:

      def handler(signum, frame):
        del signum, frame
        finish_wrapper_func_properly()
        # pylint: disable=protected-access
        os._exit(0)

      signal.signal(signal.SIGALRM, handler)
      signal.alarm(time_to_exit)

    if return_std_stream:
      sys.stdout = stdout_collector
      sys.stderr = stderr_collector

    try:
      proc_func(*arg, **kwargs)
    # pylint: disable=broad-except
    except Exception as e:
      # Capture all exceptions to be reported to parent process.
      finish_wrapper_func_properly(
          'Exception raised by subprocess: {}: {} {}'.format(
              e.__class__.__name__, str(e), traceback.format_exc()))
      return

    finish_wrapper_func_properly()

  # Start the requested number of processes for each task type in `count_dict`.
  for task_type, count in count_dict.items():
    for task_id in range(count):
      tf_config_as_json = json.dumps({
          'cluster': cluster_spec,
          'task': {
              'type': task_type,
              'index': task_id
          }
      })
      p = multi_process_lib.Process(
          target=wrapper_func,
          args=(tf_config_as_json, proc_func, proc_flags, time_to_exit) + args,
          kwargs=kwargs)
      p.start()
      processes.append(p)

  internal_queue_results = []
  for _ in range(len(processes)):
    try:
      internal_queue_results.append(
          _get_internal_queue().get(timeout=timeout))
    except Queue.Empty:
      raise RuntimeError(
          'One or more subprocesses timed out. Please inspect logs for '
          'subprocess debugging info. Timeout = {} sec.'.format(timeout))

  for internal_queue_result in internal_queue_results:
    if internal_queue_result.startswith('Exception raised by subprocess'):
      raise RuntimeError(internal_queue_result)
    assert internal_queue_result == _FINISH_PROPERLY_MESSAGE

  def queue_to_list(queue_to_convert):
    """Convert `queue.Queue` to `list`."""
    list_to_return = []
    while True:
      try:
        list_to_return.append(queue_to_convert.get(block=False))
      except Queue.Empty:
        break
    return list_to_return

  if return_std_stream:
    return tuple(
        queue_to_list(multi_process_lib.get_user_data()[queue_name])
        for queue_name in
        [AvailableQueues.PUBLIC_QUEUE, AvailableQueues.STD_STREAM_QUEUE])
  else:
    return queue_to_list(
        multi_process_lib.get_user_data()[AvailableQueues.PUBLIC_QUEUE])
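
A sketch of how run() might be invoked, per its docstring; the `proc_func` body here is hypothetical:

def proc_func():
  # Runs once in every subprocess. TF_CONFIG is already set by the wrapper,
  # so distributed setup code can read it from the environment.
  multi_process_runner.add_return_data(os.environ['TF_CONFIG'])


# One worker and one ps process, each forced to exit roughly 5 seconds
# after `run()` is called.
return_data = run(proc_func, count_dict={'worker': 1, 'ps': 1}, time_to_exit=5)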