Example #1
    def test_template(self, strategy_cls, file_format):
        num_workers = 2
        num_epoch = 2

        cluster_spec = test_base.create_cluster_spec(num_workers=num_workers,
                                                     test_obj=self)
        self._barrier = dc._Barrier(2)

        def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
            """Simulates an Independent Worker inside of a thread."""
            with test.mock.patch.object(dc, '_run_std_server',
                                        self._make_mock_run_std_server()):
                strategy = get_strategy_object(strategy_cls)
                batch_size = 64
                steps = 2
                train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
                    batch_size, steps)
                with strategy.scope():
                    model = multi_worker_testing_utils.get_mnist_model(
                        (28, 28, 1))

                custom_callable(model,
                                self,
                                train_ds,
                                num_epoch,
                                steps,
                                strategy,
                                saving_filepath=kwargs['saving_filepath'],
                                barrier=kwargs['barrier'],
                                threading_local=kwargs['threading_local'])

        # Pass saving_filepath from the parent thread to ensure every worker has the
        # same filepath to save.
        saving_filepath = os.path.join(self.get_temp_dir(),
                                       'checkpoint.' + file_format)
        barrier = dc._Barrier(2)
        threading_local = threading.local()
        threads = self.run_multiple_tasks_in_threads(
            _independent_worker_fn,
            cluster_spec,
            saving_filepath=saving_filepath,
            barrier=barrier,
            threading_local=threading_local)
        self.assertFalse(training_state.checkpoint_exists(saving_filepath))

        threads_to_join = []
        strategy = get_strategy_object(strategy_cls)
        if strategy.extended.experimental_between_graph:
            for ts in threads.values():
                threads_to_join.extend(ts)
        else:
            threads_to_join = [threads['worker'][0]]
        self.join_independent_workers(threads_to_join)
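
All of the snippets on this page follow the same skeleton: build a cluster spec, create a `dc._Barrier` sized to the number of simulated tasks, run an `_independent_worker_fn` once per task in its own thread via `run_multiple_tasks_in_threads`, and finally join the worker threads. The short standard-library sketch below shows that coordination pattern in isolation; the names (`independent_worker_fn`, `NUM_WORKERS`) and the trivial worker body are illustrative only, and it assumes `dc._Barrier` acts like an ordinary reusable barrier.

import threading

NUM_WORKERS = 2
# Stand-in for `self._barrier = dc._Barrier(NUM_WORKERS)` in the tests above.
barrier = threading.Barrier(NUM_WORKERS)
results = {}

def independent_worker_fn(task_id):
    """Simulates one independent worker inside a thread."""
    # Per-worker setup (strategy, dataset, model) would happen here.
    barrier.wait()  # No worker proceeds until every worker reaches this point.
    results[task_id] = 'done'

threads = [
    threading.Thread(target=independent_worker_fn, args=(i,))
    for i in range(NUM_WORKERS)
]
for t in threads:
    t.start()
for t in threads:
    t.join()  # Analogous to `self.join_independent_workers(threads_to_join)`.
assert len(results) == NUM_WORKERS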
Example #2
    def run_independent_workers(self,
                                worker_fn,
                                strategy_cls,
                                num_workers,
                                num_ps=None,
                                **kwargs):
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            num_workers=num_workers, num_ps=num_ps)
        self._barrier = dc._Barrier(num_workers + (num_ps or 0))  # pylint: disable=protected-access

        def _worker_fn(**kwargs):
            """Runs the worker function in a thread."""
            with test.mock.patch.object(dc, '_run_std_server',
                                        self._make_mock_run_std_server()):
                strategy = get_strategy_object(strategy_cls)
                with strategy.scope():
                    return worker_fn(**kwargs)

        threads = self.run_multiple_tasks_in_threads(_worker_fn, cluster_spec,
                                                     **kwargs)
        strategy = get_strategy_object(strategy_cls)
        if strategy.extended.experimental_between_graph:
            threads_to_join = threads.get('chief', []) + threads.get(
                'worker', [])
        else:
            threads_to_join = [
                threads['chief'][0]
                if 'chief' in threads else threads['worker'][0]
            ]
        self.join_independent_workers(threads_to_join)
    def test_complete_flow_indepedent_worker_between_graph(
            self, train_distribute_cls, eval_distribute_cls):
        train_distribute = train_distribute_cls(
            num_gpus_per_worker=context.num_gpus())

        if eval_distribute_cls:
            eval_distribute = eval_distribute_cls()
        else:
            eval_distribute = None

        cluster_spec = multi_worker_test_base.create_cluster_spec(
            num_workers=3, num_ps=2, has_eval=True)
        # 3 workers, 2 ps and 1 evaluator.
        self._barrier = dc._Barrier(6)

        threads = self._run_multiple_tasks_in_threads(cluster_spec,
                                                      train_distribute,
                                                      eval_distribute)
        for task_type, ts in threads.items():
            if task_type == PS:
                continue
            for t in ts:
                t.join()

        estimator = self._get_estimator(train_distribute, eval_distribute)
        self._inspect_train_and_eval_events(estimator)
Example #4
 def setUp(self):
   self._result_correct = 0
   self._lock = threading.Lock()
   self._worker_context = {}
   self._strategy_property = {}
   self._std_servers = {}
   self._barrier = distribute_coordinator._Barrier(NUM_WORKERS)
  def run_independent_workers(self,
                              worker_fn,
                              strategy_cls,
                              num_workers,
                              num_ps=None,
                              **kwargs):
    cluster_spec = multi_worker_test_base.create_cluster_spec(
        num_workers=num_workers, num_ps=num_ps)
    self._barrier = dc._Barrier(num_workers + (num_ps or 0))  # pylint: disable=protected-access

    def _worker_fn(**kwargs):
      """Runs the worker function in a thread."""
      with test.mock.patch.object(dc, '_run_std_server',
                                  self._make_mock_run_std_server()):
        strategy = get_strategy_object(strategy_cls)
        with strategy.scope():
          return worker_fn(**kwargs)

    threads = self.run_multiple_tasks_in_threads(_worker_fn, cluster_spec,
                                                 **kwargs)
    strategy = get_strategy_object(strategy_cls)
    if strategy.extended.experimental_between_graph:
      # Join every non-ps thread; the ps threads don't need to be joined.
      threads_to_join = []
      for task_type, ts in threads.items():
        if task_type != 'ps':
          threads_to_join.extend(ts)
    else:
      threads_to_join = [threads['worker'][0]]
    self.join_independent_workers(threads_to_join)
Example #7
    def testSimpleModelIndependentWorkerAsync(self, strategy_cls):
        num_workers = 2
        num_epoch = 2
        cluster_spec = test_base.create_cluster_spec(num_workers=num_workers,
                                                     num_ps=2,
                                                     test_obj=self)
        self._barrier = dc._Barrier(4)

        # The verification callback will be shared by multiple threads.
        verification_callback = MultiWorkerVerificationCallback(
            num_epoch=num_epoch, num_worker=num_workers)

        def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
            """Simulates an Independent Worker inside of a thread."""
            # TODO(rchao/yuefengz): The following is run by both worker and ps
            # threads. The distribute coordinator should run std server immediately
            # without configuring the session (or building the graph) on PS.
            with test.mock.patch.object(dc, '_run_std_server',
                                        self._make_mock_run_std_server()):
                batch_size = 64
                steps = 2
                strategy = strategy_cls()
                verification_callback.is_between_graph = \
                    strategy.extended.experimental_between_graph

                train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
                    batch_size, steps)
                val_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
                    batch_size, steps)
                with strategy.scope():
                    model = multi_worker_testing_utils.get_mnist_model(
                        (28, 28, 1))

                    # TODO(b/123868066): Verify callback for model.evaluate().
                    callbacks_for_fit = nest.flatten(
                        kwargs.get('verification_callback', []))
                    history = model.fit(x=train_ds,
                                        epochs=num_epoch,
                                        steps_per_epoch=steps,
                                        validation_data=val_ds,
                                        validation_steps=steps,
                                        callbacks=callbacks_for_fit)
                self.assertIsInstance(history, keras.callbacks.History)

        threads = self.run_multiple_tasks_in_threads(
            _independent_worker_fn,
            cluster_spec,
            verification_callback=verification_callback)

        threads_to_join = []
        for task_type, ts in threads.items():
            # This test can finish once the worker threads complete, and thus
            # the ps threads don't need to be joined.
            if task_type == 'ps':
                continue
            threads_to_join.extend(ts)
        self.join_independent_workers(threads_to_join)
        verification_callback.verify(self)
  def testSimpleModelIndependentWorkerAsync(self, strategy_cls):
    num_workers = 2
    num_epoch = 2
    cluster_spec = test_base.create_cluster_spec(
        num_workers=num_workers, num_ps=2)
    self._barrier = dc._Barrier(4)

    # The verification callback will be shared by multiple threads.
    verification_callback = MultiWorkerVerificationCallback(
        num_epoch=num_epoch, num_worker=num_workers)

    def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
      """Simulates an Independent Worker inside of a thread."""
      # TODO(rchao/yuefengz): The following is run by both worker and ps
      # threads. The distribute coordinator should run std server immediately
      # without configuring the session (or building the graph) on PS.
      with test.mock.patch.object(dc, '_run_std_server',
                                  self._make_mock_run_std_server()):
        batch_size = 64
        steps = 2
        strategy = strategy_cls()
        verification_callback.is_between_graph = \
            strategy.extended.experimental_between_graph

        train_ds, _ = _mnist_synthetic_dataset(batch_size, steps)
        val_ds, _ = _mnist_synthetic_dataset(batch_size, steps)
        with strategy.scope():
          model = _get_model((28, 28, 1))

          # TODO(b/123868066): Verify callback for model.evaluate().
          callbacks_for_fit = nest.flatten(
              kwargs.get('verification_callback', []))
          history = model.fit(
              x=train_ds,
              epochs=num_epoch,
              steps_per_epoch=steps,
              validation_data=val_ds,
              validation_steps=steps,
              callbacks=callbacks_for_fit)
        self.assertIsInstance(history, keras.callbacks.History)

    threads = self.run_multiple_tasks_in_threads(
        _independent_worker_fn,
        cluster_spec,
        verification_callback=verification_callback)

    threads_to_join = []
    for task_type, ts in threads.items():
      # This test can finish once the worker threads complete, and thus
      # the ps threads don't need to be joined.
      if task_type == 'ps':
        continue
      threads_to_join.extend(ts)
    self.join_independent_workers(threads_to_join)
    verification_callback.verify(self)
Example #9
    def test_complete_flow_independent_worker_between_graph(
            self, train_distribute_cls, eval_distribute_cls):
        if (context.num_gpus() < 2 and eval_distribute_cls
                == tf.distribute.experimental.MultiWorkerMirroredStrategy):
            self.skipTest(
                "`CollectiveAllReduceStrategy` needs at least two towers.")

        if (train_distribute_cls ==
                tf.distribute.experimental.ParameterServerStrategy):
            cluster_spec = multi_worker_test_base.create_cluster_spec(
                num_workers=3, num_ps=2, has_eval=True)
            # 3 workers, 2 ps and 1 evaluator.
            self._barrier = dc._Barrier(6)
        else:
            cluster_spec = multi_worker_test_base.create_cluster_spec(
                num_workers=3, num_ps=0, has_eval=True)
            # 3 workers and 1 evaluator.
            self._barrier = dc._Barrier(4)

        train_distribute = self._get_strategy_object(train_distribute_cls,
                                                     cluster_spec=cluster_spec)

        if eval_distribute_cls:
            eval_distribute = self._get_strategy_object(eval_distribute_cls,
                                                        eval_strategy=True)
        else:
            eval_distribute = None

        threads = self.run_multiple_tasks_in_threads(
            self._independent_worker_fn, cluster_spec, train_distribute,
            eval_distribute)
        threads_to_join = []
        for task_type, ts in threads.items():
            if task_type == PS:
                continue
            for t in ts:
                threads_to_join.append(t)
        self.join_independent_workers(threads_to_join)

        estimator = self._get_estimator(train_distribute, eval_distribute)
        self._inspect_train_and_eval_events(estimator)
    def test_complete_flow_indepedent_worker_between_graph(
            self, train_distribute_cls, eval_distribute_cls):
        train_distribute = train_distribute_cls(
            num_gpus_per_worker=context.num_gpus())

        if (context.num_gpus() < 2 and eval_distribute_cls
                == collective_all_reduce_strategy.CollectiveAllReduceStrategy):
            self.skipTest(
                "`CollectiveAllReduceStrategy` needs at least two towers.")

        if eval_distribute_cls:
            eval_distribute = eval_distribute_cls(
                num_gpus_per_worker=context.num_gpus())
        else:
            eval_distribute = None

        if (train_distribute_cls ==
                parameter_server_strategy.ParameterServerStrategy):
            cluster_spec = multi_worker_test_base.create_cluster_spec(
                num_workers=3, num_ps=2, has_eval=True)
            # 3 workers, 2 ps and 1 evaluator.
            self._barrier = dc._Barrier(6)
        else:
            cluster_spec = multi_worker_test_base.create_cluster_spec(
                num_workers=3, num_ps=0, has_eval=True)
            # 3 workers and 1 evaluator.
            self._barrier = dc._Barrier(4)

        threads = self._run_multiple_tasks_in_threads(cluster_spec,
                                                      train_distribute,
                                                      eval_distribute)
        for task_type, ts in threads.items():
            if task_type == PS:
                continue
            for t in ts:
                t.join()

        estimator = self._get_estimator(train_distribute, eval_distribute)
        self._inspect_train_and_eval_events(estimator)
  def test_complete_flow_independent_worker_between_graph(
      self, train_distribute_cls, eval_distribute_cls):
    if (context.num_gpus() < 2 and eval_distribute_cls ==
        collective_all_reduce_strategy.CollectiveAllReduceStrategy):
      self.skipTest("`CollectiveAllReduceStrategy` needs at least two towers.")

    train_distribute = self._get_strategy_object(train_distribute_cls)

    if eval_distribute_cls:
      eval_distribute = self._get_strategy_object(
          eval_distribute_cls, eval_strategy=True)
    else:
      eval_distribute = None

    if (train_distribute_cls == parameter_server_strategy
        .ParameterServerStrategy):
      cluster_spec = multi_worker_test_base.create_cluster_spec(
          num_workers=3, num_ps=2, has_eval=True)
      # 3 workers, 2 ps and 1 evaluator.
      self._barrier = dc._Barrier(6)
    else:
      cluster_spec = multi_worker_test_base.create_cluster_spec(
          num_workers=3, num_ps=0, has_eval=True)
      # 3 workers and 1 evaluator.
      self._barrier = dc._Barrier(4)

    threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn,
                                                 cluster_spec, train_distribute,
                                                 eval_distribute)
    threads_to_join = []
    for task_type, ts in threads.items():
      if task_type == PS:
        continue
      for t in ts:
        threads_to_join.append(t)
    self.join_independent_workers(threads_to_join)

    estimator = self._get_estimator(train_distribute, eval_distribute)
    self._inspect_train_and_eval_events(estimator)
Example #12
    def testSimpleModelIndependentWorkerSync(self, strategy_cls):
        num_workers = 2
        num_epoch = 2

        cluster_spec = tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
            num_workers=num_workers)
        self._barrier = dc._Barrier(2)

        # The verification callback will be shared by multiple threads.
        verification_callback = MultiWorkerVerificationCallback(
            num_epoch=num_epoch, num_worker=num_workers)

        def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
            """Simulates an Independent Worker inside of a thread."""
            with tf.compat.v1.test.mock.patch.object(
                    dc, '_run_std_server', self._make_mock_run_std_server()):
                strategy = strategy_cls()
                verification_callback.is_between_graph = \
                    strategy.extended.experimental_between_graph
                batch_size = 64
                steps = 2
                train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
                    batch_size, steps)
                with strategy.scope():
                    model = multi_worker_testing_utils.get_mnist_model(
                        (28, 28, 1))
                orig_loss, _ = model.evaluate(train_ds, steps=steps)
                callbacks_for_fit = tf.nest.flatten(
                    kwargs.get('verification_callback', []))
                history = model.fit(x=train_ds,
                                    epochs=num_epoch,
                                    steps_per_epoch=steps,
                                    callbacks=callbacks_for_fit)
                self.assertIsInstance(history, keras.callbacks.History)
                trained_loss, _ = model.evaluate(train_ds, steps=steps)
                self.assertLess(trained_loss, orig_loss)

        threads = self.run_multiple_tasks_in_threads(
            _independent_worker_fn,
            cluster_spec,
            verification_callback=verification_callback)

        threads_to_join = []
        strategy = strategy_cls()
        if strategy.extended.experimental_between_graph:
            for ts in threads.values():
                threads_to_join.extend(ts)
        else:
            threads_to_join = [threads['worker'][0]]
        self.join_independent_workers(threads_to_join)
        verification_callback.verify(self)
  def testSimpleModelIndependentWorkerSync(self, strategy_cls):
    num_workers = 2
    num_epoch = 2

    cluster_spec = test_base.create_cluster_spec(num_workers=num_workers)
    self._barrier = dc._Barrier(2)

    # The verification callback will be shared by multiple threads.
    verification_callback = MultiWorkerVerificationCallback(
        num_epoch=num_epoch, num_worker=num_workers)

    def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
      """Simulates an Independent Worker inside of a thread."""
      with test.mock.patch.object(dc, '_run_std_server',
                                  self._make_mock_run_std_server()):
        strategy = strategy_cls()
        verification_callback.is_between_graph = \
            strategy.extended.experimental_between_graph
        batch_size = 64
        steps = 2
        train_ds, _ = _mnist_synthetic_dataset(batch_size, steps)
        with strategy.scope():
          model = _get_model((28, 28, 1))
        orig_loss, _ = model.evaluate(train_ds, steps=steps)
        callbacks_for_fit = nest.flatten(
            kwargs.get('verification_callback', []))
        history = model.fit(
            x=train_ds,
            epochs=num_epoch,
            steps_per_epoch=steps,
            callbacks=callbacks_for_fit)
        self.assertIsInstance(history, keras.callbacks.History)
        trained_loss, _ = model.evaluate(train_ds, steps=steps)
        self.assertLess(trained_loss, orig_loss)

    threads = self.run_multiple_tasks_in_threads(
        _independent_worker_fn,
        cluster_spec,
        verification_callback=verification_callback)

    threads_to_join = []
    strategy = strategy_cls()
    if strategy.extended.experimental_between_graph:
      for ts in threads.values():
        threads_to_join.extend(ts)
    else:
      threads_to_join = [threads['worker'][0]]
    self.join_independent_workers(threads_to_join)
    verification_callback.verify(self)
  def test_template(self, strategy_cls):
    num_workers = 2
    num_epoch = 2

    cluster_spec = test_base.create_cluster_spec(num_workers=num_workers)
    self._barrier = dc._Barrier(2)

    def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
      """Simulates an Independent Worker inside of a thread."""
      with test.mock.patch.object(dc, '_run_std_server',
                                  self._make_mock_run_std_server()):
        strategy = get_strategy_object(strategy_cls)
        batch_size = 64
        steps = 2
        train_ds, _ = _mnist_synthetic_dataset(batch_size, steps)
        with strategy.scope():
          model = _get_model((28, 28, 1))

        custom_callable(
            model,
            self,
            train_ds,
            num_epoch,
            steps,
            strategy,
            saving_filepath=kwargs['saving_filepath'])

    # Pass saving_filepath from the parent thread to ensure every worker has the
    # same filepath to save.
    saving_filepath = os.path.join(self.get_temp_dir(), 'checkpoint.h5')
    threads = self.run_multiple_tasks_in_threads(
        _independent_worker_fn, cluster_spec, saving_filepath=saving_filepath)
    if os.path.exists(saving_filepath):
      os.remove(saving_filepath)

    threads_to_join = []
    strategy = get_strategy_object(strategy_cls)
    if strategy.extended.experimental_between_graph:
      for ts in threads.values():
        threads_to_join.extend(ts)
    else:
      threads_to_join = [threads['worker'][0]]
    self.join_independent_workers(threads_to_join)
Example #15
  def test_complete_flow_independent_worker_in_graph(self, train_distribute_cls,
                                                     eval_distribute_cls):
    train_distribute = self._get_strategy_object(train_distribute_cls)

    if eval_distribute_cls:
      eval_distribute = self._get_strategy_object(eval_distribute_cls)
    else:
      eval_distribute = None

    cluster_spec = multi_worker_test_base.create_cluster_spec(
        num_workers=3, num_ps=0, has_eval=True)
    # 3 workers and 1 evaluator.
    self._barrier = dc._Barrier(4)
    threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn,
                                                 cluster_spec, train_distribute,
                                                 eval_distribute)
    self.join_independent_workers([threads[WORKER][0], threads[EVALUATOR][0]])

    estimator = self._get_estimator(train_distribute, eval_distribute)
    self._inspect_train_and_eval_events(estimator)
Example #17
  def test_complete_flow_indepedent_worker_in_graph(self, train_distribute_cls,
                                                    eval_distribute_cls):
    train_distribute = train_distribute_cls(
        num_gpus_per_worker=context.num_gpus())

    if eval_distribute_cls:
      eval_distribute = eval_distribute_cls()
    else:
      eval_distribute = None

    cluster_spec = multi_worker_test_base.create_cluster_spec(
        num_workers=3, num_ps=0, has_eval=True)
    # 3 workers and 1 evaluator.
    self._barrier = dc._Barrier(4)
    threads = self._run_multiple_tasks_in_threads(
        cluster_spec, train_distribute, eval_distribute)
    threads[WORKER][0].join()
    threads[EVALUATOR][0].join()

    estimator = self._get_estimator(train_distribute, eval_distribute)
    self._inspect_train_and_eval_events(estimator)
  def test_complete_flow_indepedent_worker_in_graph(self, train_distribute_cls,
                                                    eval_distribute_cls):
    train_distribute = train_distribute_cls(
        num_gpus_per_worker=context.num_gpus())

    if eval_distribute_cls:
      eval_distribute = eval_distribute_cls(
          num_gpus_per_worker=context.num_gpus())
    else:
      eval_distribute = None

    cluster_spec = multi_worker_test_base.create_cluster_spec(
        num_workers=3, num_ps=0, has_eval=True)
    # 3 workers and 1 evaluator.
    self._barrier = dc._Barrier(4)
    threads = self._run_multiple_tasks_in_threads(
        cluster_spec, train_distribute, eval_distribute)
    threads[WORKER][0].join()
    threads[EVALUATOR][0].join()

    estimator = self._get_estimator(train_distribute, eval_distribute)
    self._inspect_train_and_eval_events(estimator)
  def testFaultToleranceInSyncStrategy(self, strategy_cls, file_format,
                                       preemption_callback):
    """Test fault-tolerance with multi-threading using sync dist-strat.

    This test simulates multi-worker training that is interrupted by a
    preemption, using two threads that represent a chief and a non-chief
    worker, where the non-chief raises an error in the middle of the training
    loop. Upon catching the error, a new thread with a new cluster spec is
    created to simulate the recovered non-chief worker. Meanwhile, the chief
    worker cannot proceed and hangs since the non-chief worker has crashed. To
    simulate a restart of the chief, a new thread is prepared to take over as
    chief with the help of a condition variable. It is expected that after the
    restart of both chief and non-chief workers, the training continues from
    the epoch at which they previously failed. The test concludes by verifying
    that the preemption-interrupted training finishes with the same loss and
    accuracy as if the preemption had not occurred.

    Arguments:
      strategy_cls: The strategy class to use.
      file_format: `h5` or `tf`.
      preemption_callback: The callback to simulate preemption.
    """

    def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
      with test.mock.patch.object(dc, '_run_std_server',
                                  self._make_mock_run_std_server()):
        # Condition variable that blocks the thread that represents the
        # restarted chief.
        cv = kwargs.get('cv', None)
        # `before_restart` is True for the threads that represent the original
        # chief and non-chief worker, and False for threads that represent the
        # restarted chief and non-chief workers.
        before_restart = kwargs['before_restart']
        if kwargs['new_chief']:
          # `new_chief` is only True for the restarted chief thread. It waits
          # until non-chief is preempted and restarted to simulate the causality
          # where chief's restart results from non-chief's failure.
          cv.acquire()
          while not hasattr(cv, 'preempted'):
            cv.wait()
          cv.release()

        # Model building under strategy scope. Following is the code we expect
        # the user to run on every worker.
        strategy = get_strategy_object(strategy_cls)
        batch_size = 64
        steps = 3
        train_ds, _ = _mnist_synthetic_dataset(batch_size, steps)
        with strategy.scope():
          model = _get_model((28, 28, 1))

        # Function to start a new thread. This will be called twice in the
        # following code: one represents the restart of the non-chief, and one
        # represents the restart of the chief as a result of the restart of the
        # non-chief (so the training can continue in sync).
        def start_new_thread(new_chief=False):
          new_thread_tf_config = json.loads(os.environ['TF_CONFIG'])
          new_thread_tf_config['cluster']['worker'] = kwargs['reserved_ports']
          return self._run_task_in_thread(
              task_fn=_independent_worker_fn,
              cluster_spec=None,
              task_type=None,
              task_id=None,
              tf_config=new_thread_tf_config,
              before_restart=False,
              cv=cv,
              new_chief=new_chief)

        if test_base.is_chief() and before_restart:
          # Chief to start a new thread (that will be blocked by a condition
          # variable until the non-chief's new thread is started). The thread
          # for (recovered) chief is started before entering `fit()` because
          # the original chief thread will eventually hang and be ignored.
          start_new_thread(new_chief=True)

        try:

          class CkptSavedEpochAssertingCallback(callbacks.Callback):

            def __init__(self, test_obj):
              super(CkptSavedEpochAssertingCallback, self).__init__()
              self.test_obj = test_obj

            def on_epoch_begin(self, epoch, logs=None):
              # `_ckpt_saved_epoch` attribute is set at the end of every epoch.
              self.test_obj.assertEqual(self.model._ckpt_saved_epoch is None,
                                        epoch == 0)

          callbacks_list = [
              callbacks.ModelCheckpoint(
                  filepath=saving_filepath,
                  save_weights_only=True,
                  load_weights_on_restart=True),
              CkptSavedEpochAssertingCallback(self)
          ]
          if before_restart:
            callbacks_list.append(preemption_callback())

          self.assertIsNone(model._ckpt_saved_epoch)
          history = model.fit(
              x=train_ds,
              epochs=num_epoch,
              steps_per_epoch=steps,
              callbacks=callbacks_list)
          self.assertIsNone(model._ckpt_saved_epoch)

          # The training `history` is collected so the results can be compared
          # against each other. It is expected that the training results (loss
          # and accuracy) are the same with or without preemption.
          self._histories.append(history.history)

        except RuntimeError:
          # pylint: disable=g-assert-in-except
          self.assertTrue(before_restart)
          # Reset the barrier so the new threads simulating recovery can
          # continue.
          self._barrier._counter = 0
          self._barrier._flag = False

          # Now that the non-chief has been preempted, it notifies the thread
          # that simulates the restarted chief to start so they can be back in
          # sync.
          cv.acquire()
          cv.preempted = True
          cv.notify()
          cv.release()

          # At this point we should discard the original non-chief thread, start
          # the new thread that simulates the restarted non-chief, join that
          # thread, and return.
          self.join_independent_workers([start_new_thread()])
          return

        # Successful end of a `fit()` call.
        self._successful_thread_ends += 1
        self.assertFalse(before_restart)

    # Common parameters
    num_workers = 2
    num_epoch = 3
    # History list storing the results for preemption and no preemption cases.
    self._histories = []
    # Pass `saving_filepath` from the parent thread to ensure every worker has
    # the same filepath to save.
    saving_filepath = os.path.join(self.get_temp_dir(),
                                   'checkpoint.' + file_format)
    strategy = get_strategy_object(strategy_cls)

    # Case 1: Training for `num_epoch` without preemptions.
    cluster_spec = test_base.create_cluster_spec(num_workers=num_workers)
    self._barrier = dc._Barrier(2)
    self._successful_thread_ends = 0
    threads = self.run_multiple_tasks_in_threads(
        _independent_worker_fn,
        cluster_spec,
        saving_filepath=saving_filepath,
        before_restart=False,
        new_chief=False)
    if os.path.exists(saving_filepath):
      os.remove(saving_filepath)
    threads_to_join = []
    if strategy.extended.experimental_between_graph:
      for ts in threads.values():
        threads_to_join.extend(ts)
    else:
      threads_to_join = [threads['worker'][0]]
    self.join_independent_workers(threads_to_join)
    self.assertEqual(self._successful_thread_ends, 2)

    # Case 2: Training for `num_epoch` epoch with preemptions.
    # The preemption is simulated at both epoch boundary and batch boundary.
    cluster_spec = test_base.create_cluster_spec(num_workers=num_workers)
    cv = threading.Condition()
    self._barrier = dc._Barrier(2)
    # Ports reserved for new threads simulating recovery.
    reserved_ports = [
        'localhost:%s' % test_base.pick_unused_port()
        for _ in range(num_workers)
    ]
    self._successful_thread_ends = 0
    threads = self.run_multiple_tasks_in_threads(
        _independent_worker_fn,
        cluster_spec,
        saving_filepath=saving_filepath,
        reserved_ports=reserved_ports,
        before_restart=True,
        cv=cv,
        new_chief=False)
    if os.path.exists(saving_filepath):
      os.remove(saving_filepath)
    threads_to_join = []
    if strategy.extended.experimental_between_graph:
      # Only join the non-chief thread since the first thread for chief will
      # eventually hang and be ignored.
      threads_to_join = [threads['worker'][1]]
    else:
      threads_to_join = [threads['worker'][0]]
    self.join_independent_workers(threads_to_join)
    self.assertEqual(self._successful_thread_ends, 2)

    def assert_all_elements_are_identical(list_to_check):
      first_item = list_to_check[0]
      for item in list_to_check[1:]:
        self.assertAllClose(first_item, item, rtol=1e-5, atol=1e-5)

    # Important: the results from preemption interrupted and non-interrupted
    # cases should give the same final results.
    assert_all_elements_are_identical(
        [history['acc'][-1] for history in self._histories])
    assert_all_elements_are_identical(
        [history['loss'][-1] for history in self._histories])
    # The length of `self._histories` is num_workers (2) * num_runs (2) = 4.
    self.assertLen(self._histories, 4)
Example #20
    def testFaultToleranceInSyncStrategy(self, strategy_cls, file_format,
                                         preemption_callback,
                                         save_weights_only,
                                         load_weights_on_restart):
        """Test fault-tolerance with multi-threading using sync dist-strat.

    This test simulates multi-worker training that is interrupted by a
    preemption, using two threads that represent a chief and a non-chief
    worker, where the non-chief raises an error in the middle of the training
    loop. Upon catching the error, a new thread with a new cluster spec is
    created to simulate the recovered non-chief worker. Meanwhile, the chief
    worker cannot proceed and hangs since the non-chief worker has crashed. To
    simulate a restart of the chief, a new thread is prepared to take over as
    chief with the help of a condition variable. It is expected that after the
    restart of both chief and non-chief workers, the training continues from
    the epoch at which they previously failed. The test concludes by verifying
    that the preemption-interrupted training finishes with the same loss and
    accuracy as if the preemption had not occurred.

    TODO(rchao): Add test to check preemption on chief (possibly using multi
    processes).

    TODO(rchao): Add test to check fault-tolerance with multiple `model.fit()`.

    Arguments:
      strategy_cls: The strategy class to use.
      file_format: `h5` or `tf`.
      preemption_callback: The callback to simulate preemption.
      save_weights_only: The argument for `model.fit()`'s `save_weights_only`.
      load_weights_on_restart: The argument for `model.fit()`'s
        `load_weights_on_restart`.
    """
        def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
            with test.mock.patch.object(dc, '_run_std_server',
                                        self._make_mock_run_std_server()):
                # `before_restart` is True for the threads that represent the original
                # chief and non-chief worker, and False for threads that represent the
                # restarted chief and non-chief workers.
                before_restart = kwargs['before_restart']

                # Model building under strategy scope. Following is the code we expect
                # the user to run on every worker.
                strategy = get_strategy_object(strategy_cls)
                batch_size = 64
                steps = 3
                train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
                    batch_size, steps)

                with strategy.scope():
                    model = multi_worker_testing_utils.get_mnist_model(
                        (28, 28, 1))

                # Function to start a new thread. This will be called twice in the
                # following code: one represents the restart of the non-chief, and one
                # represents the restart of the chief as a result of the restart of the
                # non-chief (so the training can continue in sync).
                def start_new_thread(new_chief):
                    new_thread_tf_config = json.loads(os.environ['TF_CONFIG'])

                    # Update the ports in new chief and new worker threads.
                    new_thread_tf_config['cluster']['worker'] = kwargs[
                        'reserved_ports']

                    # Since both new chief and new worker threads are started from the
                    # worker thread, we need to overwrite the tf config task index.
                    new_thread_tf_config['task'][
                        'index'] = 0 if new_chief else 1
                    return self._run_task_in_thread(
                        task_fn=_independent_worker_fn,
                        cluster_spec=None,
                        task_type=None,
                        task_id=None,
                        tf_config=new_thread_tf_config,
                        before_restart=False,
                        new_chief=new_chief)

                try:

                    class CkptSavedEpochAssertingCallback(callbacks.Callback):
                        def __init__(self, test_obj):
                            super(CkptSavedEpochAssertingCallback,
                                  self).__init__()
                            self.test_obj = test_obj

                        def on_epoch_begin(self, epoch, logs=None):
                            # `_ckpt_saved_epoch` attribute is set at the end of every epoch.
                            self.test_obj.assertEqual(
                                K.eval(self.model._ckpt_saved_epoch) ==
                                training_state.CKPT_SAVED_EPOCH_UNUSED_VALUE,
                                epoch == 0)

                    callbacks_list = [
                        callbacks.ModelCheckpoint(
                            filepath=saving_filepath,
                            save_weights_only=save_weights_only,
                            load_weights_on_restart=load_weights_on_restart),
                        CkptSavedEpochAssertingCallback(self)
                    ]
                    if before_restart:
                        callbacks_list.append(preemption_callback())

                    self.assertFalse(
                        hasattr(model, training_state.CKPT_SAVED_EPOCH))
                    history = model.fit(x=train_ds,
                                        epochs=num_epoch,
                                        steps_per_epoch=steps,
                                        callbacks=callbacks_list)
                    self.assertFalse(
                        hasattr(model, training_state.CKPT_SAVED_EPOCH))

                    # The training `history` is collected so the results can be
                    # compared against each other. It is expected that the training
                    # results (loss and accuracy) are the same with or without preemption.
                    self._histories.append(history.history)

                except RuntimeError:
                    # pylint: disable=g-assert-in-except
                    self.assertTrue(before_restart)
                    # Reset the barrier so the new threads simulating recovery can
                    # continue.
                    self._barrier._counter = 0
                    self._barrier._flag = False

                    # At this point we block the original non-chief thread, start
                    # the new threads that simulate the restarted chief and
                    # non-chief, join those threads, and return.
                    new_chief_thread = start_new_thread(new_chief=True)
                    new_worker_thread = start_new_thread(new_chief=False)
                    self.join_independent_workers(
                        [new_chief_thread, new_worker_thread])
                    return

                # Successful end of a `fit()` call.
                with self._lock:
                    self._successful_thread_ends += 1
                self.assertFalse(before_restart)

        # Common parameters
        num_workers = 2
        num_epoch = 3
        # History list storing the results for preemption and no preemption cases.
        self._histories = []
        # Lock required to prevent race condition between two threads.
        self._lock = threading.Lock()
        strategy = get_strategy_object(strategy_cls)

        def handler(signum, frame):
            del signum, frame
            # `session.run()` within `model.fit()` can time out. Skipping it as it
            # doesn't represent the failure of this test.
            self.skipTest('Skipping test due to `session.run()` timeout.')

        signal.signal(signal.SIGALRM, handler)
        # Set the alarm to fire within 5 minutes, before the test times out and fails.
        signal.alarm(240)

        def get_saving_dir_and_filepath():
            saving_dir = tempfile.mkdtemp(prefix=self.get_temp_dir())
            saving_filepath = os.path.join(saving_dir,
                                           'checkpoint.' + file_format)
            return saving_dir, saving_filepath

        # Case 1: Training for `num_epoch` without preemptions.
        cluster_spec = test_base.create_cluster_spec(num_workers=num_workers)
        self._barrier = dc._Barrier(2)
        self._successful_thread_ends = 0
        # Get a new temporary filepath to save the checkpoint to.
        saving_dir, saving_filepath = get_saving_dir_and_filepath()
        threads = self.run_multiple_tasks_in_threads(
            _independent_worker_fn,
            cluster_spec,
            # Pass `saving_filepath` from the parent thread to ensure every worker
            # has the same filepath to save.
            saving_filepath=saving_filepath,
            before_restart=False,
            new_chief=False)
        threads_to_join = []
        if strategy.extended.experimental_between_graph:
            for ts in threads.values():
                threads_to_join.extend(ts)
        else:
            threads_to_join = [threads['worker'][0]]
        self.join_independent_workers(threads_to_join)

        # `self.test_skipped_reason` could be set when a non-main thread attempts
        # to skip the test.
        # `multi_worker_test_base.skip_if_grpc_server_cant_be_started()` is an
        # example of where this can be set. Since raising `SkipTest` in a non-main
        # thread doesn't actually skip the test, we check if the test should be
        # skipped here once we have joined the threads.
        if getattr(self, 'test_skipped_reason', None) is not None:
            self.skipTest(self.test_skipped_reason)

        self.assertTrue(
            training_state.remove_checkpoint_if_exists(saving_dir,
                                                       saving_filepath))
        self.assertEqual(self._successful_thread_ends, 2)

        # Case 2: Training for `num_epoch` epoch with preemptions.
        # The preemption is simulated at both epoch boundary and batch boundary.
        cluster_spec = test_base.create_cluster_spec(num_workers=num_workers)
        self._barrier = dc._Barrier(2)
        # Ports reserved for new threads simulating recovery.
        reserved_ports = [
            'localhost:%s' % test_base.pick_unused_port()
            for _ in range(num_workers)
        ]
        self._successful_thread_ends = 0
        # Get a new temporary filepath to save the checkpoint to.
        saving_dir, saving_filepath = get_saving_dir_and_filepath()
        threads = self.run_multiple_tasks_in_threads(
            _independent_worker_fn,
            cluster_spec,
            # Pass `saving_filepath` from the parent thread to ensure every worker
            # has the same filepath to save.
            saving_filepath=saving_filepath,
            reserved_ports=reserved_ports,
            before_restart=True,
            new_chief=False)
        threads_to_join = []
        if strategy.extended.experimental_between_graph:
            # Only join the non-chief thread since the first thread for chief will
            # eventually hang and be ignored.
            threads_to_join = [threads['worker'][1]]
        else:
            threads_to_join = [threads['worker'][0]]
        self.join_independent_workers(threads_to_join)
        if getattr(self, 'test_skipped_reason', None) is not None:
            self.skipTest(self.test_skipped_reason)

        self.assertTrue(
            training_state.remove_checkpoint_if_exists(saving_dir,
                                                       saving_filepath))
        self.assertEqual(self._successful_thread_ends, 2)

        def assert_all_elements_are_identical(list_to_check):
            first_item = list_to_check[0]
            for item in list_to_check[1:]:
                self.assertAllClose(first_item, item, rtol=2e-5, atol=1e-5)

        # Important: the results from preemption interrupted and non-interrupted
        # cases should give the same final results.
        assert_all_elements_are_identical(
            [history['acc'][-1] for history in self._histories])
        assert_all_elements_are_identical(
            [history['loss'][-1] for history in self._histories])
        # The length of `self._histories` is num_workers (2) * num_runs (2) = 4.
        self.assertLen(self._histories, 4)

        # Results from case 1 should have 3 full epochs.
        self.assertLen(self._histories[0]['acc'], 3)
        # Results from case 2 should only have 2 full epochs because it restarted at
        # epoch 1.
        self.assertLen(self._histories[-1]['acc'], 2)
  def run_optimizer_comparison_with_simple_bias_model(
      self, strategy_cls, optimizer_class_1, optimizer_class_2):

    def get_input_datasets():
      # Simple training input.
      train_input = [[1]] * 16
      train_label = [[0]] * 16
      ds = dataset_ops.Dataset.from_tensor_slices((train_input, train_label))
      ds = maybe_shard_dataset(ds)
      # TODO(rchao): Investigate to figure out the reason for having 8 workers
      # instead of 2 as expected.
      return ds.batch(8, drop_remainder=True)

    def get_simple_bias_model():

      class Bias(base_layer.Layer):

        def build(self, input_shape):
          self.bias = self.add_variable('bias', (1,), initializer='zeros')

        def call(self, inputs):
          return inputs + self.bias

      model = sequential.Sequential()
      model.add(Bias(input_shape=(1,)))

      return model

    self._lock = threading.Lock()
    cluster_spec = test_base.create_cluster_spec(num_workers=2)
    self._barrier = dc._Barrier(2)

    def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
      """Simulates an Independent Worker inside a thread."""
      # TODO(rchao): Refactor to abstract the common boilerplate out.
      with test.mock.patch.object(dc, '_run_std_server',
                                  self._make_mock_run_std_server()):

        model = get_simple_bias_model()

        initial_weights = model.get_weights()

        def _get_model_results(optimizer, initial_weights):

          # Clear Keras session to reset device assignment
          keras.backend._SESSION.session = None
          strategy = strategy_cls()

          with strategy.scope():
            train_ds = get_input_datasets()
            model = get_simple_bias_model()
            model.set_weights(initial_weights)
            model.compile(loss='mae', optimizer=optimizer, metrics=['mae'])

          return {
              'trained_loss_and_accuracy':
                  model.fit(x=train_ds, epochs=20).history,
              'trained_weights':
                  model.get_weights(),
          }

        results1 = _get_model_results(optimizer_class_1(0.01), initial_weights)
        results2 = _get_model_results(optimizer_class_2(0.01), initial_weights)

        for key in results1:
          self.assertAllClose(
              results1[key],
              results2[key],
              atol=1e-5,
              rtol=1e-5,
              msg='Fail to assert {}'.format(key))

    threads = self.run_multiple_tasks_in_threads(_independent_worker_fn,
                                                 cluster_spec)

    threads_to_join = []
    strategy = strategy_cls()
    if strategy.extended.experimental_between_graph:
      for ts in threads.values():
        threads_to_join.extend(ts)
    else:
      threads_to_join = [threads['worker'][0]]
    self.join_independent_workers(threads_to_join)
Example #22
    def run_optimizer_comparison_with_simple_bias_model(
            self, strategy_cls, optimizer_class_1, optimizer_class_2):
        def get_input_datasets():
            # Simple training input.
            train_input = [[1]] * 16
            train_label = [[0]] * 16
            ds = dataset_ops.Dataset.from_tensor_slices(
                (train_input, train_label))
            # TODO(rchao): Investigate to figure out the reason for having 8 workers
            # instead of 2 as expected.
            return ds.batch(8, drop_remainder=True)

        def get_simple_bias_model():
            class Bias(base_layer.Layer):
                def build(self, input_shape):
                    self.bias = self.add_variable('bias', (1, ),
                                                  initializer='zeros')

                def call(self, inputs):
                    return inputs + self.bias

            model = sequential.Sequential()
            model.add(Bias(input_shape=(1, )))

            return model

        self._lock = threading.Lock()
        cluster_spec = test_base.create_cluster_spec(num_workers=2)
        self._barrier = dc._Barrier(2)

        def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
            """Simulates an Independent Worker inside a thread."""
            # TODO(rchao): Refactor to abstract the common boilerplate out.
            with test.mock.patch.object(dc, '_run_std_server',
                                        self._make_mock_run_std_server()):

                model = get_simple_bias_model()

                initial_weights = model.get_weights()

                def _get_model_results(optimizer, initial_weights):

                    # Clear Keras session to reset device assignment
                    keras.backend._SESSION.session = None
                    strategy = strategy_cls()

                    with strategy.scope():
                        train_ds = get_input_datasets()
                        model = get_simple_bias_model()
                        model.set_weights(initial_weights)
                        model.compile(loss='mae',
                                      optimizer=optimizer,
                                      metrics=['mae'])

                    return {
                        'trained_loss_and_accuracy':
                        model.fit(x=train_ds, epochs=20).history,
                        'trained_weights':
                        model.get_weights(),
                    }

                results1 = _get_model_results(optimizer_class_1(0.01),
                                              initial_weights)
                results2 = _get_model_results(optimizer_class_2(0.01),
                                              initial_weights)

                for key in results1:
                    self.assertAllClose(results1[key],
                                        results2[key],
                                        atol=1e-5,
                                        rtol=1e-5,
                                        msg='Fail to assert {}'.format(key))

        threads = self.run_multiple_tasks_in_threads(_independent_worker_fn,
                                                     cluster_spec)

        threads_to_join = []
        strategy = strategy_cls()
        if strategy.extended.experimental_between_graph:
            for ts in threads.values():
                threads_to_join.extend(ts)
        else:
            threads_to_join = [threads['worker'][0]]
        self.join_independent_workers(threads_to_join)
Example #23
    def testFaultToleranceInSyncStrategy(self, strategy_cls, file_format,
                                         preemption_callback):
        """Test fault-tolerance with multi-threading using sync dist-strat.

    This test simulates multi-worker training that is interrupted by a
    preemption, using two threads that represent a chief and a non-chief
    worker, where the non-chief raises an error in the middle of the training
    loop. Upon catching the error, a new thread with a new cluster spec is
    created to simulate the recovered non-chief worker. Meanwhile, the chief
    worker cannot proceed and hangs since the non-chief worker has crashed. To
    simulate a restart of the chief, a new thread is prepared to take over as
    chief with the help of a condition variable. It is expected that after the
    restart of both chief and non-chief workers, the training continues from
    the epoch at which they previously failed. The test concludes by verifying
    that the preemption-interrupted training finishes with the same loss and
    accuracy as if the preemption had not occurred.

    Arguments:
      strategy_cls: The strategy class to use.
      file_format: `h5` or `tf`.
      preemption_callback: The callback to simulate preemption.
    """
        def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
            with test.mock.patch.object(dc, '_run_std_server',
                                        self._make_mock_run_std_server()):
                # Condition variable that blocks the thread that represents the
                # restarted chief.
                cv = kwargs.get('cv', None)
                # `before_restart` is True for the threads that represent the original
                # chief and non-chief worker, and False for threads that represent the
                # restarted chief and non-chief workers.
                before_restart = kwargs['before_restart']
                if kwargs['new_chief']:
                    # `new_chief` is only True for the restarted chief thread. It waits
                    # until the non-chief is preempted and restarted, to simulate the
                    # causality where the chief's restart results from the non-chief's
                    # failure.
                    cv.acquire()
                    while not hasattr(cv, 'preempted'):
                        cv.wait()
                    cv.release()

                # Model building under strategy scope. The following is the code we
                # expect the user to run on every worker.
                strategy = get_strategy_object(strategy_cls)
                batch_size = 64
                steps = 3
                train_ds, _ = _mnist_synthetic_dataset(batch_size, steps)
                with strategy.scope():
                    model = _get_model((28, 28, 1))

                # Function to start a new thread. This will be called twice in the
                # following code: once for the restart of the non-chief, and once for
                # the restart of the chief as a result of the non-chief's restart (so
                # the training can continue in sync).
                def start_new_thread(new_chief=False):
                    new_thread_tf_config = json.loads(os.environ['TF_CONFIG'])
                    new_thread_tf_config['cluster']['worker'] = kwargs[
                        'reserved_ports']
                    return self._run_task_in_thread(
                        task_fn=_independent_worker_fn,
                        cluster_spec=None,
                        task_type=None,
                        task_id=None,
                        tf_config=new_thread_tf_config,
                        before_restart=False,
                        cv=cv,
                        new_chief=new_chief)

                if test_base.is_chief() and before_restart:
                    # The chief starts a new thread (which will be blocked by a
                    # condition variable until the non-chief's new thread is started).
                    # The thread for the (recovered) chief is started before entering
                    # `fit()` because the original chief thread will eventually hang
                    # and be ignored.
                    start_new_thread(new_chief=True)

                try:

                    class CkptSavedEpochAssertingCallback(callbacks.Callback):
                        def __init__(self, test_obj):
                            super(CkptSavedEpochAssertingCallback,
                                  self).__init__()
                            self.test_obj = test_obj

                        def on_epoch_begin(self, epoch, logs=None):
                            # `_ckpt_saved_epoch` attribute is set at the end of every epoch.
                            self.test_obj.assertEqual(
                                self.model._ckpt_saved_epoch is None,
                                epoch == 0)

                    callbacks_list = [
                        callbacks.ModelCheckpoint(
                            filepath=saving_filepath,
                            save_weights_only=True,
                            load_weights_on_restart=True),
                        CkptSavedEpochAssertingCallback(self)
                    ]
                    if before_restart:
                        callbacks_list.append(preemption_callback())

                    self.assertIsNone(model._ckpt_saved_epoch)
                    history = model.fit(x=train_ds,
                                        epochs=num_epoch,
                                        steps_per_epoch=steps,
                                        callbacks=callbacks_list)
                    self.assertIsNone(model._ckpt_saved_epoch)

                    # The `history` of each training run is collected so the runs can
                    # be compared against each other. It is expected that the training
                    # results (loss and accuracy) are the same with or without
                    # preemption.
                    self._histories.append(history.history)

                except RuntimeError:
                    # pylint: disable=g-assert-in-except
                    self.assertTrue(before_restart)
                    # Reset the barrier so the new threads simulating recovery can
                    # continue.
                    self._barrier._counter = 0
                    self._barrier._flag = False

                    # Now that the non-chief has been preempted, it notifies the thread
                    # that simulates the restarted chief to start so they can be back in
                    # sync.
                    cv.acquire()
                    cv.preempted = True
                    cv.notify()
                    cv.release()

                    # At this point we should discard the original non-chief thread and
                    # start the new thread that simulates the restarted non-chief, so we
                    # join that thread and return.
                    self.join_independent_workers([start_new_thread()])
                    return

                # Successful end of a `fit()` call.
                self._successful_thread_ends += 1
                self.assertFalse(before_restart)

        # Common parameters
        num_workers = 2
        num_epoch = 3
        # History list storing the results for preemption and no preemption cases.
        self._histories = []
        # Pass `saving_filepath` from the parent thread to ensure every worker has
        # the same filepath to save to.
        saving_filepath = os.path.join(self.get_temp_dir(),
                                       'checkpoint.' + file_format)
        strategy = get_strategy_object(strategy_cls)

        # Case 1: Training for `num_epoch` epochs without preemptions.
        cluster_spec = test_base.create_cluster_spec(num_workers=num_workers)
        self._barrier = dc._Barrier(2)
        self._successful_thread_ends = 0
        threads = self.run_multiple_tasks_in_threads(
            _independent_worker_fn,
            cluster_spec,
            saving_filepath=saving_filepath,
            before_restart=False,
            new_chief=False)
        if os.path.exists(saving_filepath):
            os.remove(saving_filepath)
        threads_to_join = []
        if strategy.extended.experimental_between_graph:
            for ts in threads.values():
                threads_to_join.extend(ts)
        else:
            threads_to_join = [threads['worker'][0]]
        self.join_independent_workers(threads_to_join)
        self.assertEqual(self._successful_thread_ends, 2)

        # Case 2: Training for `num_epoch` epochs with preemptions.
        # The preemption is simulated at both epoch and batch boundaries.
        cluster_spec = test_base.create_cluster_spec(num_workers=num_workers)
        cv = threading.Condition()
        self._barrier = dc._Barrier(2)
        # Ports reserved for new threads simulating recovery.
        reserved_ports = [
            'localhost:%s' % test_base.pick_unused_port()
            for _ in range(num_workers)
        ]
        self._successful_thread_ends = 0
        threads = self.run_multiple_tasks_in_threads(
            _independent_worker_fn,
            cluster_spec,
            saving_filepath=saving_filepath,
            reserved_ports=reserved_ports,
            before_restart=True,
            cv=cv,
            new_chief=False)
        if os.path.exists(saving_filepath):
            os.remove(saving_filepath)
        threads_to_join = []
        if strategy.extended.experimental_between_graph:
            # Only join the non-chief thread since the first thread for chief will
            # eventually hang and be ignored.
            threads_to_join = [threads['worker'][1]]
        else:
            threads_to_join = [threads['worker'][0]]
        self.join_independent_workers(threads_to_join)
        self.assertEqual(self._successful_thread_ends, 2)

        def assert_all_elements_are_identical(list_to_check):
            first_item = list_to_check[0]
            for item in list_to_check[1:]:
                self.assertAllClose(first_item, item, rtol=1e-5, atol=1e-5)

        # Important: the preemption-interrupted and non-interrupted runs should
        # give the same final results.
        assert_all_elements_are_identical(
            [history['acc'][-1] for history in self._histories])
        assert_all_elements_are_identical(
            [history['loss'][-1] for history in self._histories])
        # The length of `self._histories` should be num_workers * num_runs
        # (2 * 2 = 4): two workers for the run without preemption plus two for
        # the run with preemption.
        self.assertLen(self._histories, 4)
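# A standalone sketch (plain `threading` only, no TensorFlow) of the
# condition-variable handshake described in the docstring above: the thread that
# simulates the restarted chief blocks until the non-chief marks itself as
# preempted by setting a `preempted` attribute on the condition and notifying,
# mirroring how `cv.preempted` is used in the test. All names here are
# illustrative.
import threading
import time


cv = threading.Condition()


def restarted_chief():
    with cv:
        # Block until the non-chief signals that it has been "preempted".
        while not getattr(cv, 'preempted', False):
            cv.wait()
    print('Restarted chief resumes training.')


def preempted_non_chief():
    time.sleep(0.1)  # Simulate some training before the preemption hits.
    with cv:
        cv.preempted = True  # Record the preemption on the condition object.
        cv.notify()          # Wake the thread simulating the restarted chief.


chief = threading.Thread(target=restarted_chief)
worker = threading.Thread(target=preempted_non_chief)
chief.start()
worker.start()
chief.join()
worker.join()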