def _model_setup(test_obj, file_format):
    """Create the MNIST Keras model and the data the tests train it on.

    The synthetic dataset and the model are both created inside the scope of
    a `CollectiveAllReduceStrategy`.

    Args:
      test_obj: The `TestCase` testing object; its temp dir holds the
        checkpoint file.
      file_format: File format for checkpoints. 'tf' or 'h5'.

    Returns:
      A tuple of (model, saving_filepath, train_ds, steps) where train_ds is
      the training dataset.
    """
    steps = 2
    batch_size = 64
    input_shape = (28, 28, 1)
    strategy = collective_strategy.CollectiveAllReduceStrategy()
    with strategy.scope():
        # TODO(b/142509827): In rare cases this errors out at C++ level with
        # the "Connect failed" error message.
        train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
            batch_size, steps)
        model = multi_worker_testing_utils.get_mnist_model(input_shape)
    # The path is built here, in the parent thread, and handed to the workers
    # so that every worker saves to the same file.
    temp_dir = test_obj.get_temp_dir()
    saving_filepath = os.path.join(temp_dir, 'checkpoint.' + file_format)
    return model, saving_filepath, train_ds, steps
def _model_setup(test_obj, file_format):
    """Prepare a MNIST Keras model, its training data and a checkpoint path.

    Args:
      test_obj: The `TestCase` testing object, used for its temp directory.
      file_format: File format for checkpoints. 'tf' or 'h5'.

    Returns:
      A tuple of (model, saving_filepath, train_ds, steps) where train_ds is
      the training dataset.
    """
    batch_size, steps = 64, 2
    mnist_shape = (28, 28, 1)
    scope = collective_strategy.CollectiveAllReduceStrategy().scope()
    with scope:
        # TODO(b/142509827): In rare cases this errors out at C++ level with the
        # following error message:
        # subchannel.cc:1000] Connect failed: {"created":"@1570753640.827421717",
        # "description":"Failed to connect to remote host: Connection refused",
        # "errno":111,"file":"third_party/grpc/src/core/lib/iomgr/tcp_client_posix.cc",
        # "file_line":200,"os_error":"Connection refused","syscall":"connect",
        # "target_address":"ipv6:[::1]:17271"}
        train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
            batch_size, steps)
        model = multi_worker_testing_utils.get_mnist_model(mnist_shape)
    # saving_filepath is decided in the parent thread so that all workers
    # write to one and the same location.
    saving_filepath = os.path.join(
        test_obj.get_temp_dir(), 'checkpoint.' + file_format)
    return model, saving_filepath, train_ds, steps
        def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
            """Simulates an Independent Worker inside of a thread.

            Builds the MNIST model under a fresh strategy, fits it with
            validation data, and checks that `fit()` returned a `History`.
            Any `verification_callback` found in `kwargs` is passed to
            `fit()`.
            """
            # TODO(rchao/yuefengz): The following is run by both worker and ps
            # threads. The distribute coordinator should run std server
            # immediately without configuring the session (or building the
            # graph) on PS.
            mock_std_server = self._make_mock_run_std_server()
            with test.mock.patch.object(dc, '_run_std_server', mock_std_server):
                batch_size = 64
                steps = 2
                strategy = strategy_cls()
                between_graph = strategy.extended.experimental_between_graph
                verification_callback.is_between_graph = between_graph

                train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
                    batch_size, steps)
                val_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
                    batch_size, steps)
                with strategy.scope():
                    input_shape = (28, 28, 1)
                    model = multi_worker_testing_utils.get_mnist_model(
                        input_shape)

                    # TODO(b/123868066): Verify callback for model.evaluate().
                    requested_callbacks = kwargs.get('verification_callback', [])
                    fit_callbacks = nest.flatten(requested_callbacks)
                    history = model.fit(
                        x=train_ds,
                        epochs=num_epoch,
                        steps_per_epoch=steps,
                        validation_data=val_ds,
                        validation_steps=steps,
                        callbacks=fit_callbacks)
                self.assertIsInstance(history, keras.callbacks.History)
示例#4
0
def _run_standalone_client(test_obj, strategy, cluster_spec):
    """Train a clone of a MNIST model through the standalone-client coordinator.

    An original model is built under `strategy`. Each worker then clones it,
    evaluates, fits for two epochs, evaluates again, and asserts that loss
    did not increase and accuracy did not decrease.

    Args:
      test_obj: Test case object providing the assertion methods.
      strategy: Distribution strategy passed to the coordinator.
      cluster_spec: Cluster spec passed to the coordinator.
    """
    mnist_input_shape = (28, 28, 1)
    with strategy.scope():
        orig_model = multi_worker_testing_utils.get_mnist_model(
            mnist_input_shape)

    def worker_fn(strategy):
        """Per-worker body run by the distribute coordinator."""
        with ops.Graph().as_default():
            batch_size = 64
            steps = 2

            with strategy.scope():
                train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
                    batch_size, steps)
                model = _clone_and_build_model(orig_model, strategy)

                baseline = model.evaluate(train_ds, steps=steps)
                orig_loss, orig_acc = baseline

                # Workaround for the metrics issue (b/122928955) in async
                # training. This can only be used in standalone client mode.
                multi_worker_util.wait_for_other_workers()

                model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)

                multi_worker_util.wait_for_other_workers()

                after_training = model.evaluate(train_ds, steps=steps)
                trained_loss, trained_acc = after_training

            test_obj.assertLessEqual(trained_loss, orig_loss)
            test_obj.assertGreaterEqual(trained_acc, orig_acc)

    coordinator_mode = dc.CoordinatorMode.STANDALONE_CLIENT
    dc.run_distribute_coordinator(
        worker_fn, strategy, mode=coordinator_mode, cluster_spec=cluster_spec)
    def testCheckpointExists(self, file_format, save_weights_only):
        """Checks that `ModelCheckpoint` leaves a checkpoint behind after `fit()`.

        Args:
          file_format: Extension appended to the 'checkpoint.' file name.
          save_weights_only: Forwarded to `ModelCheckpoint`.
        """
        train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(64, 2)
        model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))
        saving_dir = self.get_temp_dir()
        saving_filepath = os.path.join(saving_dir, 'checkpoint.' + file_format)
        callbacks_list = [
            callbacks.ModelCheckpoint(filepath=saving_filepath,
                                      save_weights_only=save_weights_only)
        ]
        # Nothing may exist at the target path before training starts.
        self.assertFalse(file_io.file_exists_v2(saving_filepath))

        try:
            model.fit(x=train_ds,
                      epochs=2,
                      steps_per_epoch=2,
                      callbacks=callbacks_list)
        except NotFoundError as e:
            if 'Failed to create a NewWriteableFile' in e.message:
                self.skipTest(
                    'b/138941852, path not found error in Windows py35.')
            # Any other NotFoundError is a genuine failure. Re-raise it rather
            # than falling through to the existence checks, which could pass
            # or fail for reasons unrelated to the real error.
            raise
        # Either the file itself or a companion '.index' file counts as an
        # existing checkpoint.
        tf_saved_model_exists = file_io.file_exists_v2(saving_filepath)
        tf_weights_only_checkpoint_exists = file_io.file_exists_v2(
            saving_filepath + '.index')
        self.assertTrue(tf_saved_model_exists
                        or tf_weights_only_checkpoint_exists)
  def testCheckpointExists(self, file_format, save_weights_only):
    """Checks checkpoint creation by `fit()` and its subsequent removal.

    Args:
      file_format: Extension appended to the 'checkpoint.' file name.
      save_weights_only: Forwarded to `ModelCheckpoint`.
    """
    train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(64, 2)
    model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))
    saving_dir = self.get_temp_dir()
    saving_filepath = os.path.join(saving_dir, 'checkpoint.' + file_format)
    checkpoint_callback = callbacks.ModelCheckpoint(
        filepath=saving_filepath, save_weights_only=save_weights_only)
    # No checkpoint before training.
    self.assertFalse(training_state.checkpoint_exists(saving_filepath))

    model.fit(
        x=train_ds,
        epochs=2,
        steps_per_epoch=2,
        callbacks=[checkpoint_callback])
    # Training must have produced the checkpoint ...
    self.assertTrue(training_state.checkpoint_exists(saving_filepath))
    # ... which the helper reports as removed, and which is gone afterwards.
    was_removed = training_state.remove_checkpoint_if_exists(
        saving_dir, saving_filepath)
    self.assertTrue(was_removed)
    self.assertFalse(training_state.checkpoint_exists(saving_filepath))
示例#7
0
 def testCheckpointExists(self, file_format, save_weights_only):
   """Checks that `ModelCheckpoint` writes a checkpoint during `fit()`.

   Args:
     file_format: Extension appended to the 'checkpoint.' file name.
     save_weights_only: Forwarded to `ModelCheckpoint`.
   """
   with self.cached_session():
     train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(64, 2)
     model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))
     saving_dir = self.get_temp_dir()
     saving_filepath = os.path.join(saving_dir, 'checkpoint.' + file_format)
     checkpoint_callback = callbacks.ModelCheckpoint(
         filepath=saving_filepath, save_weights_only=save_weights_only)
     # The target path must be empty before training.
     self.assertFalse(file_io.file_exists_v2(saving_filepath))
     model.fit(
         x=train_ds,
         epochs=2,
         steps_per_epoch=2,
         callbacks=[checkpoint_callback])
     # Accept either the file itself or its '.index' companion file.
     saved_model_found = file_io.file_exists_v2(saving_filepath)
     index_file_found = file_io.file_exists_v2(saving_filepath + '.index')
     self.assertTrue(saved_model_found or index_file_found)
示例#8
0
        def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
            """Simulates an Independent Worker inside of a thread.

            Builds the MNIST model under the strategy and hands it, together
            with the dataset and training parameters, to `custom_callable`.
            `kwargs['saving_filepath']` is required and passed along.
            """
            mock_std_server = self._make_mock_run_std_server()
            with test.mock.patch.object(dc, '_run_std_server', mock_std_server):
                strategy = get_strategy_object(strategy_cls)
                batch_size = 64
                steps = 2
                train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
                    batch_size, steps)
                with strategy.scope():
                    input_shape = (28, 28, 1)
                    model = multi_worker_testing_utils.get_mnist_model(
                        input_shape)

                custom_callable(
                    model,
                    self,
                    train_ds,
                    num_epoch,
                    steps,
                    strategy,
                    saving_filepath=kwargs['saving_filepath'])
示例#9
0
    def testSimpleModelIndependentWorkerSync(self, strategy):
        """Fits the MNIST model under `strategy` and checks the loss decreases.

        The number of workers is read from the `TF_CONFIG` environment
        variable and given to the verification callback, which is verified
        at the end of the test.

        Args:
          strategy: Distribution strategy under which the model is built.
        """
        tf_config = json.loads(os.environ['TF_CONFIG'])
        worker_count = len(tf_config['cluster']['worker'])
        verification_callback = MultiWorkerVerificationCallback(
            num_epoch=2, num_worker=worker_count)
        verification_callback.is_between_graph = (
            strategy.extended.experimental_between_graph)
        batch_size = 64
        steps = 2
        train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
            batch_size, steps)
        with strategy.scope():
            model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))
        # Loss before training, used as the baseline.
        orig_loss, _ = model.evaluate(train_ds, steps=steps)
        history = model.fit(
            x=train_ds,
            epochs=2,
            steps_per_epoch=steps,
            callbacks=[verification_callback])
        self.assertIsInstance(history, keras.callbacks.History)
        # Loss after training must be strictly lower than the baseline.
        trained_loss, _ = model.evaluate(train_ds, steps=steps)
        self.assertLess(trained_loss, orig_loss)

        verification_callback.verify(self)
示例#10
0
 def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
   """Simulates an Independent Worker inside of a thread.

   Evaluates, fits and re-evaluates the MNIST model, asserting that `fit()`
   returned a `History` and that the loss went down. Any
   `verification_callback` found in `kwargs` is passed to `fit()`.
   """
   mock_std_server = self._make_mock_run_std_server()
   with test.mock.patch.object(dc, '_run_std_server', mock_std_server):
     strategy = strategy_cls()
     verification_callback.is_between_graph = (
         strategy.extended.experimental_between_graph)
     batch_size = 64
     steps = 2
     train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
         batch_size, steps)
     with strategy.scope():
       model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))
     # Loss before training, used as the baseline.
     orig_loss, _ = model.evaluate(train_ds, steps=steps)
     requested_callbacks = kwargs.get('verification_callback', [])
     fit_callbacks = nest.flatten(requested_callbacks)
     history = model.fit(
         x=train_ds,
         epochs=num_epoch,
         steps_per_epoch=steps,
         callbacks=fit_callbacks)
     self.assertIsInstance(history, keras.callbacks.History)
     # Loss after training must be strictly lower than the baseline.
     trained_loss, _ = model.evaluate(train_ds, steps=steps)
     self.assertLess(trained_loss, orig_loss)
        def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
            """Simulates a worker thread that may be preempted and restarted.

            Runs with `dc._run_std_server` replaced by the test's mock. The
            thread builds the MNIST model under a strategy scope and calls
            `fit()` with a `ModelCheckpoint` callback and a callback that
            asserts on the model's `_ckpt_saved_epoch`. Threads started with
            `before_restart=True` also get `preemption_callback()`.

            If `fit()` raises `RuntimeError`, the thread asserts it was a
            pre-restart thread, resets the barrier, signals `cv`, then starts
            and joins a replacement thread and returns. Otherwise it appends
            the training history to `self._histories` and counts itself as a
            successful thread end.

            Keys read from `kwargs`:
              cv: Optional condition variable; the restarted chief waits on
                it until a `preempted` attribute has been set on it.
              before_restart: True for the original threads, False for the
                restarted ones.
              new_chief: True only for the restarted chief thread.
              reserved_ports: Value stored as the `cluster.worker` entry of
                the TF config handed to restarted threads.

            `saving_filepath`, `save_weights_only`, `load_weights_on_restart`,
            `num_epoch`, `strategy_cls` and `preemption_callback` are free
            variables, presumably bound in the enclosing test method.
            """
            with test.mock.patch.object(dc, '_run_std_server',
                                        self._make_mock_run_std_server()):
                # Condition variable that blocks the thread that represents the
                # restarted chief.
                cv = kwargs.get('cv', None)
                # `before_restart` is True for the threads that represent the original
                # chief and non-chief worker, and False for threads that represent the
                # restarted chief and non-chief workers.
                before_restart = kwargs['before_restart']
                if kwargs['new_chief']:
                    # `new_chief` is only True for the restarted chief thread. It waits
                    # until non-chief is preempted and restarted to simulate the causality
                    # where chief's restart results from non-chief's failure.
                    cv.acquire()
                    # The `preempted` attribute is set on `cv` by the
                    # RuntimeError handler below.
                    while not hasattr(cv, 'preempted'):
                        cv.wait()
                    cv.release()

                # Model building under strategy scope. Following is the code we expect
                # the user runs on every worker.
                strategy = get_strategy_object(strategy_cls)
                batch_size = 64
                steps = 3
                train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
                    batch_size, steps)
                with strategy.scope():
                    model = multi_worker_testing_utils.get_mnist_model(
                        (28, 28, 1))

                # Function to start a new thread. This will be called twice in the
                # following code: one represents the restart of the non-chief, and one
                # represents the restart of the chief as a result of the restart of the
                # non-chief (so the training can continue in sync).
                def start_new_thread(new_chief=False):
                    """Starts a post-restart thread that runs this worker function again."""
                    new_thread_tf_config = json.loads(os.environ['TF_CONFIG'])
                    new_thread_tf_config['cluster']['worker'] = kwargs[
                        'reserved_ports']
                    return self._run_task_in_thread(
                        task_fn=_independent_worker_fn,
                        cluster_spec=None,
                        task_type=None,
                        task_id=None,
                        tf_config=new_thread_tf_config,
                        before_restart=False,
                        cv=cv,
                        new_chief=new_chief)

                if test_base.is_chief() and before_restart:
                    # Chief to start a new thread (that will be blocked by a condition
                    # variable until the non-chief's new thread is started). The thread
                    # for (recovered) chief is started before entering `fit()` because
                    # the original chief thread will eventually hang and be ignored.
                    start_new_thread(new_chief=True)

                try:

                    class CkptSavedEpochAssertingCallback(callbacks.Callback):
                        """Asserts on the model's `_ckpt_saved_epoch` at each epoch start."""

                        def __init__(self, test_obj):
                            super(CkptSavedEpochAssertingCallback,
                                  self).__init__()
                            self.test_obj = test_obj

                        def on_epoch_begin(self, epoch, logs=None):
                            # `_ckpt_saved_epoch` attribute is set at the end of every epoch.
                            # The saved epoch must equal the "unused" sentinel
                            # exactly when this is epoch 0.
                            self.test_obj.assertEqual(
                                K.eval(self.model._ckpt_saved_epoch) ==
                                training_state.CKPT_SAVED_EPOCH_UNUSED_VALUE,
                                epoch == 0)

                    callbacks_list = [
                        callbacks.ModelCheckpoint(
                            filepath=saving_filepath,
                            save_weights_only=save_weights_only,
                            load_weights_on_restart=load_weights_on_restart),
                        CkptSavedEpochAssertingCallback(self)
                    ]
                    if before_restart:
                        callbacks_list.append(preemption_callback())

                    # The saved-epoch attribute must be absent from the model
                    # both before and after `fit()`.
                    self.assertFalse(
                        hasattr(model, training_state.CKPT_SAVED_EPOCH))
                    history = model.fit(x=train_ds,
                                        epochs=num_epoch,
                                        steps_per_epoch=steps,
                                        callbacks=callbacks_list)
                    self.assertFalse(
                        hasattr(model, training_state.CKPT_SAVED_EPOCH))

                    # `history` of the training result is collected to be compared against
                    # each other. It is expected that the training results (loss and
                    # accuracy) are the same with or without preemption.
                    self._histories.append(history.history)

                except RuntimeError:
                    # NOTE(review): the RuntimeError is presumably raised by
                    # `preemption_callback()` to simulate preemption - confirm.
                    # pylint: disable=g-assert-in-except
                    self.assertTrue(before_restart)
                    # Reset the barrier so the new threads simulating recovery can
                    # continue.
                    # This writes the barrier's private attributes directly.
                    self._barrier._counter = 0
                    self._barrier._flag = False

                    # Now that the non-chief has been preempted, it notifies the thread
                    # that simulates the restarted chief to start so they can be back in
                    # sync.
                    cv.acquire()
                    cv.preempted = True
                    cv.notify()
                    cv.release()

                    # At this point we should discard the original non-chief thread, and
                    # start the new thread that simulates the restarted non-chief, hence
                    # joining the thread and return.
                    self.join_independent_workers([start_new_thread()])
                    return

                # Successful end of a `fit()` call.
                # NOTE(review): this increment is not guarded by a lock here;
                # confirm concurrent worker threads cannot race on it.
                self._successful_thread_ends += 1
                self.assertFalse(before_restart)
示例#12
0
        def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
            """Simulates a worker thread whose preemption restarts both workers.

            Runs with `dc._run_std_server` replaced by the test's mock. The
            thread builds the MNIST model under a strategy scope and calls
            `fit()` with a `ModelCheckpoint` callback and a callback that
            asserts on the model's `_ckpt_saved_epoch`. Threads started with
            `before_restart=True` also get `preemption_callback()`.

            If `fit()` raises `RuntimeError`, the thread asserts it was a
            pre-restart thread, resets the barrier, starts one new chief
            thread and one new worker thread, joins both, and returns.
            Otherwise it appends the training history to `self._histories`
            and, under `self._lock`, counts itself as a successful thread end.

            Keys read from `kwargs`:
              before_restart: True for the original threads, False for the
                restarted ones.
              reserved_ports: Value stored as the `cluster.worker` entry of
                the TF config handed to restarted threads.

            `saving_filepath`, `save_weights_only`, `load_weights_on_restart`,
            `num_epoch`, `strategy_cls` and `preemption_callback` are free
            variables, presumably bound in the enclosing test method.
            """
            with test.mock.patch.object(dc, '_run_std_server',
                                        self._make_mock_run_std_server()):
                # `before_restart` is True for the threads that represent the original
                # chief and non-chief worker, and False for threads that represent the
                # restarted chief and non-chief workers.
                before_restart = kwargs['before_restart']

                # Model building under strategy scope. Following is the code we expect
                # the user runs on every worker.
                strategy = get_strategy_object(strategy_cls)
                batch_size = 64
                steps = 3
                train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
                    batch_size, steps)

                with strategy.scope():
                    model = multi_worker_testing_utils.get_mnist_model(
                        (28, 28, 1))

                # Function to start a new thread. This will be called twice in the
                # following code: one represents the restart of the non-chief, and one
                # represents the restart of the chief as a result of the restart of the
                # non-chief (so the training can continue in sync).
                def start_new_thread(new_chief):
                    """Starts a post-restart thread as task index 0 (chief) or 1."""
                    new_thread_tf_config = json.loads(os.environ['TF_CONFIG'])

                    # Update the ports in new chief and new worker threads.
                    new_thread_tf_config['cluster']['worker'] = kwargs[
                        'reserved_ports']

                    # Since both new chief and new worker threads are started from the
                    # worker thread, we need to overwrite the tf config task index.
                    new_thread_tf_config['task'][
                        'index'] = 0 if new_chief else 1
                    return self._run_task_in_thread(
                        task_fn=_independent_worker_fn,
                        cluster_spec=None,
                        task_type=None,
                        task_id=None,
                        tf_config=new_thread_tf_config,
                        before_restart=False,
                        new_chief=new_chief)

                try:

                    class CkptSavedEpochAssertingCallback(callbacks.Callback):
                        """Asserts on the model's `_ckpt_saved_epoch` at each epoch start."""

                        def __init__(self, test_obj):
                            super(CkptSavedEpochAssertingCallback,
                                  self).__init__()
                            self.test_obj = test_obj

                        def on_epoch_begin(self, epoch, logs=None):
                            # `_ckpt_saved_epoch` attribute is set at the end of every epoch.
                            # The saved epoch must equal the "unused" sentinel
                            # exactly when this is epoch 0.
                            self.test_obj.assertEqual(
                                K.eval(self.model._ckpt_saved_epoch) ==
                                training_state.CKPT_SAVED_EPOCH_UNUSED_VALUE,
                                epoch == 0)

                    callbacks_list = [
                        callbacks.ModelCheckpoint(
                            filepath=saving_filepath,
                            save_weights_only=save_weights_only,
                            load_weights_on_restart=load_weights_on_restart),
                        CkptSavedEpochAssertingCallback(self)
                    ]
                    if before_restart:
                        callbacks_list.append(preemption_callback())

                    # The saved-epoch attribute must be absent from the model
                    # both before and after `fit()`.
                    self.assertFalse(
                        hasattr(model, training_state.CKPT_SAVED_EPOCH))
                    history = model.fit(x=train_ds,
                                        epochs=num_epoch,
                                        steps_per_epoch=steps,
                                        callbacks=callbacks_list)
                    self.assertFalse(
                        hasattr(model, training_state.CKPT_SAVED_EPOCH))

                    # `history` of the training result is collected to be compared against
                    # each other. It is expected that the training results (loss and
                    # accuracy) are the same with or without preemption.
                    self._histories.append(history.history)

                except RuntimeError:
                    # NOTE(review): the RuntimeError is presumably raised by
                    # `preemption_callback()` to simulate preemption - confirm.
                    # pylint: disable=g-assert-in-except
                    self.assertTrue(before_restart)
                    # Reset the barrier so the new threads simulating recovery can
                    # continue.
                    # This writes the barrier's private attributes directly.
                    self._barrier._counter = 0
                    self._barrier._flag = False

                    # At this point we block the original non-chief thread, and
                    # start the new threads that simulate the restarted chief and
                    # non-chief, joining the threads and return.
                    new_chief_thread = start_new_thread(new_chief=True)
                    new_worker_thread = start_new_thread(new_chief=False)
                    self.join_independent_workers(
                        [new_chief_thread, new_worker_thread])
                    return

                # Successful end of a `fit()` call.
                # The counter is updated while holding `self._lock`.
                with self._lock:
                    self._successful_thread_ends += 1
                self.assertFalse(before_restart)