def _model_setup(test_obj, file_format):
    """Prepare an MNIST Keras model and its training inputs for a test.

    Both the synthetic dataset and the model are created inside the scope of a
    freshly constructed `tf.distribute.MultiWorkerMirroredStrategy`.

    Args:
      test_obj: The `TestCase` driving the test; only its `get_temp_dir()` is
        used, to pick the directory for the checkpoint path.
      file_format: Checkpoint file format, 'tf' or 'h5'. It is appended as the
        extension of the returned checkpoint path.

    Returns:
      A 4-tuple `(model, saving_filepath, train_ds, steps)`. `train_ds` is the
      synthetic training dataset and `steps` is the step count that was handed
      to `mnist_synthetic_dataset` when building it.
    """
    batch = 64
    num_steps = 2
    with tf.distribute.MultiWorkerMirroredStrategy().scope():
        # TODO(b/142509827): In rare cases this errors out at C++ level with
        # the "Connect failed" error message.
        dataset, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
            batch, num_steps)
        mnist_model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))
    # The path is computed here, in the parent thread, so that every worker
    # is handed the identical filepath to save to.
    checkpoint_path = os.path.join(
        test_obj.get_temp_dir(), "checkpoint." + file_format)
    return mnist_model, checkpoint_path, dataset, num_steps
def testSimpleModelIndependentWorkerSync(self, strategy):
    """Train an MNIST model under `strategy` and check that the loss drops.

    A `MultiWorkerVerificationCallback` is sized from the worker list in the
    `TF_CONFIG` environment variable, passed to `fit`, and verified at the end.

    Args:
      strategy: Distribution strategy whose scope the model is created in.
    """
    # The callback's expected epoch count and the epochs given to `fit` must
    # match, so both read this single value.
    num_epoch = 2
    cluster = json.loads(os.environ["TF_CONFIG"])["cluster"]
    verifier = MultiWorkerVerificationCallback(
        num_epoch=num_epoch,
        num_worker=len(cluster["worker"]),
    )
    verifier.is_between_graph = strategy.extended.experimental_between_graph

    num_steps = 2
    dataset, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
        64, num_steps)
    with strategy.scope():
        model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))

    loss_before, _ = model.evaluate(dataset, steps=num_steps)
    history = model.fit(
        x=dataset,
        epochs=num_epoch,
        steps_per_epoch=num_steps,
        callbacks=[verifier],
    )
    self.assertIsInstance(history, keras.callbacks.History)
    loss_after, _ = model.evaluate(dataset, steps=num_steps)
    self.assertLess(loss_after, loss_before)

    verifier.verify(self)
def testCheckpointExists(self, file_format, save_weights_only):
    """Check that `ModelCheckpoint` leaves a checkpoint on disk after `fit`.

    Args:
      file_format: Used as the extension of the checkpoint path (presumably
        'tf' or 'h5' — confirm against the test's parameterization).
      save_weights_only: Forwarded unchanged to `callbacks.ModelCheckpoint`.
    """
    with self.cached_session():
        dataset, _ = multi_worker_testing_utils.mnist_synthetic_dataset(64, 2)
        model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))
        checkpoint_path = os.path.join(
            self.get_temp_dir(), 'checkpoint.' + file_format)
        checkpointer = callbacks.ModelCheckpoint(
            filepath=checkpoint_path, save_weights_only=save_weights_only)

        # Nothing may exist at the target before training starts.
        self.assertFalse(tf.io.gfile.exists(checkpoint_path))
        model.fit(
            x=dataset, epochs=2, steps_per_epoch=2, callbacks=[checkpointer])

        # Either the path itself or a companion `<path>.index` file counts as
        # a written checkpoint.
        found_at_path = tf.io.gfile.exists(checkpoint_path)
        found_index_file = tf.io.gfile.exists(checkpoint_path + '.index')
        self.assertTrue(found_at_path or found_index_file)
def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
    """Simulates an Independent Worker inside of a thread.

    While `dc._run_std_server` is patched with the mock returned by
    `self._make_mock_run_std_server()`, this builds a strategy from
    `strategy_cls`, records its between-graph flag on the enclosing
    `verification_callback`, trains an MNIST model for `num_epoch` epochs and
    asserts that the evaluation loss went down.

    Positional arguments are ignored. The optional `verification_callback`
    keyword argument is flattened with `tf.nest.flatten` and used as the
    callbacks list for `fit` (empty when absent).
    """
    mock_std_server = self._make_mock_run_std_server()
    with tf.compat.v1.test.mock.patch.object(
            dc, '_run_std_server', mock_std_server):
        dist_strategy = strategy_cls()
        verification_callback.is_between_graph = (
            dist_strategy.extended.experimental_between_graph)

        num_steps = 2
        dataset, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
            64, num_steps)
        with dist_strategy.scope():
            model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))

        loss_before, _ = model.evaluate(dataset, steps=num_steps)
        fit_callbacks = tf.nest.flatten(
            kwargs.get('verification_callback', []))
        history = model.fit(
            x=dataset,
            epochs=num_epoch,
            steps_per_epoch=num_steps,
            callbacks=fit_callbacks)
        self.assertIsInstance(history, keras.callbacks.History)
        loss_after, _ = model.evaluate(dataset, steps=num_steps)
        self.assertLess(loss_after, loss_before)