def DISABLED_test_mnist_training_tpu(self):
    # TODO(scottzhu): Enable TPU test once the dtensor_test rule is migrated
    # out of learning/brain
    """Trains MNIST batch-sharded on a TPU mesh and checks loss decreases."""
    tpu_util.dtensor_initialize_tpu_system()
    device_count = dtensor.num_global_devices("TPU")
    # One-dimensional "batch" mesh spanning every global TPU device.
    mesh = tpu_util.create_tpu_mesh(["batch"], [device_count], "tpu_mesh")

    # Needed by keras initializers.
    tf_utils.set_random_seed(1337)

    model = integration_test_utils.get_model_with_layout_map(
        integration_test_utils.get_all_replicated_layout_map(mesh))
    optimizer = optimizer_lib.Adam(learning_rate=0.001, mesh=mesh)
    optimizer.build(model.trainable_variables)

    train_losses = integration_test_utils.train_mnist_model_batch_sharded(
        model,
        optimizer,
        mesh,
        num_epochs=3,
        steps_per_epoch=100,
        global_batch_size=64,
    )
    # Make sure the losses are decreasing (i.e. the per-epoch losses are
    # already in non-increasing order).
    self.assertEqual(train_losses, sorted(train_losses, reverse=True))
def test_mnist_training_cpu(self):
    """Trains MNIST batch-sharded across 8 logical CPUs; loss must decrease."""
    cpu_devices = tf.config.list_physical_devices("CPU")
    # Split the first physical CPU into 8 logical devices so an 8-way
    # batch-sharded mesh can be built on a single host.
    tf.config.set_logical_device_configuration(
        cpu_devices[0],
        [tf.config.LogicalDeviceConfiguration()] * 8,
    )
    mesh = mesh_util.create_mesh(
        devices=["CPU:%d" % i for i in range(8)],
        mesh_dims=[("batch", 8)],
    )

    backend.enable_tf_random_generator()
    # Needed by keras initializers.
    tf_utils.set_random_seed(1337)

    model = integration_test_utils.get_model_with_layout_map(
        integration_test_utils.get_all_replicated_layout_map(mesh))
    optimizer = optimizer_lib.Adam(learning_rate=0.001, mesh=mesh)
    optimizer.build(model.trainable_variables)

    train_losses = integration_test_utils.train_mnist_model_batch_sharded(
        model,
        optimizer,
        mesh,
        num_epochs=3,
        steps_per_epoch=100,
        global_batch_size=64,
    )
    # Make sure the losses are decreasing (i.e. the per-epoch losses are
    # already in non-increasing order).
    self.assertEqual(train_losses, sorted(train_losses, reverse=True))
def test_build_index_dict(self):
    """Checks _build_index_dict maps each variable's key to its list index."""
    optimizer = optimizers.Adam(mesh=self.mesh)
    init_value = tf.ones(
        shape=(),
        dtype=tf.float32,
        layout=dtensor.Layout.replicated(self.mesh, rank=0))
    variables = [
        dtensor.DVariable(init_value, name=f'var{i}') for i in range(10)
    ]

    optimizer._build_index_dict(variables)
    # The key for the variable at position 7 should map back to index 7.
    self.assertEqual(
        optimizer._index_dict[optimizer._var_key(variables[7])], 7)
def test_add_variable_from_reference(self):
    """Checks slot variables get the right name, zero value, and layout."""
    optimizer = optimizers.Adam(mesh=self.mesh)
    init_value = tf.ones(
        [4, 4],
        dtype=tf.float32,
        layout=dtensor.Layout.replicated(self.mesh, rank=2))
    model_variable = dtensor.DVariable(
        init_value, trainable=True, name='tmp')

    state_variable = optimizer.add_variable_from_reference(
        model_variable, 'test')

    self.assertEqual(state_variable._shared_name, 'test/tmp')
    # Slot variables are created zero-initialized regardless of the model
    # variable's value.
    self.assertAllClose(self.evaluate(state_variable), tf.zeros([4, 4]))
    # Make sure the variable contains the correct layout info
    self.assertEqual(state_variable.layout, model_variable.layout)