def DISABLED_test_mnist_training_tpu(self): # TODO(scottzhu): Enable TPU test once the dtensor_test rule is migrated # out of learning/brain tpu_util.dtensor_initialize_tpu_system() total_tpu_device_count = dtensor.num_global_devices("TPU") mesh_shape = [total_tpu_device_count] mesh = tpu_util.create_tpu_mesh(["batch"], mesh_shape, "tpu_mesh") # Needed by keras initializers. tf_utils.set_random_seed(1337) model = integration_test_utils.get_model_with_layout_map( integration_test_utils.get_all_replicated_layout_map(mesh)) optimizer = optimizer_lib.Adam(learning_rate=0.001, mesh=mesh) optimizer.build(model.trainable_variables) train_losses = integration_test_utils.train_mnist_model_batch_sharded( model, optimizer, mesh, num_epochs=3, steps_per_epoch=100, global_batch_size=64, ) # Make sure the losses are decreasing self.assertEqual(train_losses, sorted(train_losses, reverse=True))
def test_mnist_training_cpu(self): devices = tf.config.list_physical_devices("CPU") tf.config.set_logical_device_configuration( devices[0], [ tf.config.LogicalDeviceConfiguration(), ] * 8, ) mesh = mesh_util.create_mesh(devices=["CPU:%d" % i for i in range(8)], mesh_dims=[("batch", 8)]) backend.enable_tf_random_generator() # Needed by keras initializers. tf_utils.set_random_seed(1337) model = integration_test_utils.get_model_with_layout_map( integration_test_utils.get_all_replicated_layout_map(mesh)) optimizer = optimizer_lib.Adam(learning_rate=0.001, mesh=mesh) optimizer.build(model.trainable_variables) train_losses = integration_test_utils.train_mnist_model_batch_sharded( model, optimizer, mesh, num_epochs=3, steps_per_epoch=100, global_batch_size=64, ) # Make sure the losses are decreasing self.assertEqual(train_losses, sorted(train_losses, reverse=True))