def _model_setup(test_obj, file_format):
    """Build a synthetic-MNIST Keras model plus the pieces a test needs.

    The dataset and the model are both created inside the scope of a freshly
    constructed `tf.distribute.MultiWorkerMirroredStrategy`.

    Args:
      test_obj: The `TestCase` testing object; its `get_temp_dir()` supplies
        the directory the checkpoint path is placed in.
      file_format: File format for checkpoints, 'tf' or 'h5'. It is appended
        to "checkpoint." to form the checkpoint file name.

    Returns:
      A `(model, saving_filepath, train_ds, steps)` tuple, where `train_ds` is
      the training dataset and `steps` is the step count that was passed to
      the dataset builder.
    """
    steps = 2
    batch_size = 64
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        # TODO(b/142509827): In rare cases this errors out at C++ level with the
        # "Connect failed" error message.
        train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
            batch_size, steps)
        input_shape = (28, 28, 1)
        model = multi_worker_testing_utils.get_mnist_model(input_shape)

    # The path is computed once here, in the parent thread, so that every
    # worker ends up saving to the same location.
    temp_dir = test_obj.get_temp_dir()
    checkpoint_name = "checkpoint." + file_format
    saving_filepath = os.path.join(temp_dir, checkpoint_name)
    return model, saving_filepath, train_ds, steps
Exemplo n.º 2
0
    def testSimpleModelIndependentWorkerSync(self, strategy):
        """Fit the MNIST model under `strategy` and check that training helps.

        The worker count is read from the cluster spec stored in the
        `TF_CONFIG` environment variable and handed to a
        `MultiWorkerVerificationCallback`. That callback is attached to `fit`
        and asked to verify itself once training has finished.

        Args:
          strategy: Distribution strategy whose scope the model is built in.
        """
        num_epochs = 2
        steps = 2
        batch_size = 64

        tf_config = json.loads(os.environ["TF_CONFIG"])
        worker_count = len(tf_config["cluster"]["worker"])
        verification_callback = MultiWorkerVerificationCallback(
            num_epoch=num_epochs, num_worker=worker_count)
        between_graph = strategy.extended.experimental_between_graph
        verification_callback.is_between_graph = between_graph

        train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
            batch_size, steps)
        with strategy.scope():
            model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))

        # The loss measured before any training is the baseline that the
        # post-training loss has to beat.
        loss_before, _ = model.evaluate(train_ds, steps=steps)
        fit_kwargs = dict(
            x=train_ds,
            epochs=num_epochs,
            steps_per_epoch=steps,
            callbacks=[verification_callback],
        )
        history = model.fit(**fit_kwargs)
        self.assertIsInstance(history, keras.callbacks.History)
        loss_after, _ = model.evaluate(train_ds, steps=steps)
        self.assertLess(loss_after, loss_before)

        verification_callback.verify(self)
Exemplo n.º 3
0
 def testCheckpointExists(self, file_format, save_weights_only):
   """Check that `ModelCheckpoint` leaves a checkpoint behind after `fit`.

   The checkpoint path must not exist before training. Afterwards, either
   the path itself or the path with an '.index' suffix must exist.

   Args:
     file_format: Extension appended to 'checkpoint.' to form the file name.
     save_weights_only: Forwarded unchanged to `ModelCheckpoint`.
   """
   with self.cached_session():
     train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(64, 2)
     model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))

     saving_filepath = os.path.join(
         self.get_temp_dir(), 'checkpoint.' + file_format)
     checkpoint_callback = callbacks.ModelCheckpoint(
         filepath=saving_filepath, save_weights_only=save_weights_only)

     # Nothing may be present at the target path before training starts.
     self.assertFalse(tf.io.gfile.exists(saving_filepath))

     model.fit(
         x=train_ds,
         epochs=2,
         steps_per_epoch=2,
         callbacks=[checkpoint_callback])

     # Both locations are probed unconditionally; finding either one counts
     # as the checkpoint having been written.
     found_at_path = tf.io.gfile.exists(saving_filepath)
     found_index_file = tf.io.gfile.exists(saving_filepath + '.index')
     self.assertTrue(any((found_at_path, found_index_file)))
Exemplo n.º 4
0
 def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
   """Simulates an Independent Worker inside of a thread.

   While `dc._run_std_server` is patched with a mock, this builds a strategy
   from `strategy_cls`, trains the MNIST model for `num_epoch` epochs and
   asserts that the loss went down.

   NOTE(review): `self`, `strategy_cls`, `verification_callback` and
   `num_epoch` are not defined here; presumably they are captured from the
   enclosing scope -- confirm against the surrounding function.

   Args:
     *args: Ignored.
     **kwargs: Only 'verification_callback' is read; its value is flattened
       into the list of callbacks passed to `fit`. Defaults to no callbacks.
   """
   std_server_patch = tf.compat.v1.test.mock.patch.object(
       dc, '_run_std_server', self._make_mock_run_std_server())
   with std_server_patch:
     strategy = strategy_cls()
     verification_callback.is_between_graph = (
         strategy.extended.experimental_between_graph)

     steps = 2
     batch_size = 64
     train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
         batch_size, steps)
     with strategy.scope():
       model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))

     # The pre-training loss is the baseline for the final comparison.
     loss_before, _ = model.evaluate(train_ds, steps=steps)

     requested_callbacks = kwargs.get('verification_callback', [])
     callbacks_for_fit = tf.nest.flatten(requested_callbacks)
     history = model.fit(
         x=train_ds,
         epochs=num_epoch,
         steps_per_epoch=steps,
         callbacks=callbacks_for_fit)
     self.assertIsInstance(history, keras.callbacks.History)

     loss_after, _ = model.evaluate(train_ds, steps=steps)
     self.assertLess(loss_after, loss_before)