Example #1
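All four snippets on this page are excerpted from Keras' DTensor test suite and omit their imports. A plausible import block is sketched below; the exact module paths are assumptions that shift between TF/Keras versions:

import tensorflow.compat.v2 as tf

from keras import backend
from keras.dtensor import integration_test_utils
from keras.dtensor import optimizers as optimizer_lib  # aliased as `optimizers` in Examples #3 and #4
from keras.utils import tf_utils
from tensorflow.compat.v2.experimental import dtensor
from tensorflow.dtensor.python import mesh_util  # assumed internal path
from tensorflow.dtensor.python import tpu_util   # assumed internal path

This first test brings up the TPU system, builds a one-dimensional "batch" mesh over every available TPU core, and trains a batch-sharded MNIST model for three epochs, asserting that the recorded losses come back in descending order.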
    def DISABLED_test_mnist_training_tpu(self):
        # TODO(scottzhu): Enable TPU test once the dtensor_test rule is migrated
        # out of learning/brain
        tpu_util.dtensor_initialize_tpu_system()
        total_tpu_device_count = dtensor.num_global_devices("TPU")
        mesh_shape = [total_tpu_device_count]
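        # Build a 1-D mesh named "tpu_mesh" whose single "batch" dimension
        # spans every TPU core.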
        mesh = tpu_util.create_tpu_mesh(["batch"], mesh_shape, "tpu_mesh")

        # Needed by keras initializers.
        tf_utils.set_random_seed(1337)

        model = integration_test_utils.get_model_with_layout_map(
            integration_test_utils.get_all_replicated_layout_map(mesh))

        optimizer = optimizer_lib.Adam(learning_rate=0.001, mesh=mesh)
        optimizer.build(model.trainable_variables)

        train_losses = integration_test_utils.train_mnist_model_batch_sharded(
            model,
            optimizer,
            mesh,
            num_epochs=3,
            steps_per_epoch=100,
            global_batch_size=64,
        )
        # Make sure the losses are decreasing
        self.assertEqual(train_losses, sorted(train_losses, reverse=True))
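The closing assertion is a compact monotonicity check: the recorded losses must already equal themselves sorted in descending order, i.e. the loss never rose across the run. Example #2 below reuses the same idiom.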
Example #2
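The CPU variant of the same training loop: Example #1's TPU mesh is replaced by an 8-way "batch" mesh over logical CPU devices. `backend.enable_tf_random_generator()` switches Keras to `tf.random.Generator`-based RNG, which Keras' DTensor integration relies on for its initializers.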
    def test_mnist_training_cpu(self):
        devices = tf.config.list_physical_devices("CPU")
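        # Split the one physical CPU into eight logical devices.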
        tf.config.set_logical_device_configuration(
            devices[0],
            [
                tf.config.LogicalDeviceConfiguration(),
            ] * 8,
        )

        mesh = mesh_util.create_mesh(devices=["CPU:%d" % i for i in range(8)],
                                     mesh_dims=[("batch", 8)])

        backend.enable_tf_random_generator()
        # Needed by keras initializers.
        tf_utils.set_random_seed(1337)

        model = integration_test_utils.get_model_with_layout_map(
            integration_test_utils.get_all_replicated_layout_map(mesh))

        optimizer = optimizer_lib.Adam(learning_rate=0.001, mesh=mesh)
        optimizer.build(model.trainable_variables)

        train_losses = integration_test_utils.train_mnist_model_batch_sharded(
            model,
            optimizer,
            mesh,
            num_epochs=3,
            steps_per_epoch=100,
            global_batch_size=64,
        )
        # Make sure the losses are decreasing
        self.assertEqual(train_losses, sorted(train_losses, reverse=True))
Example #3
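A unit test of the optimizer's internal bookkeeping: `_build_index_dict` maps each variable's key to its position in the variable list, and the test checks the lookup for the variable at index 7. Note that it relies on a `self.mesh` built by the test fixture; see the sketch after the snippet.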
  def test_build_index_dict(self):
    optimizer = optimizers.Adam(mesh=self.mesh)
    variable_init_value = tf.ones(
        shape=(), dtype=tf.float32,
        layout=dtensor.Layout.replicated(self.mesh, rank=0))
    var_list = [dtensor.DVariable(variable_init_value, name=f'var{i}')
                for i in range(10)]
    # _build_index_dict maps each variable's key to its index in var_list.
    optimizer._build_index_dict(var_list)
    self.assertEqual(optimizer._index_dict[optimizer._var_key(var_list[7])], 7)
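Examples #3 and #4 reference a `self.mesh` that the excerpts never define. A minimal fixture sketch, modeled on the CPU mesh construction from Example #2 (the 8-way split and module paths are assumptions, not the original fixture):

  def setUp(self):
    super().setUp()
    # Split the single physical CPU into 8 logical devices and build a
    # 1-D "batch" mesh over them, mirroring Example #2.
    devices = tf.config.list_physical_devices('CPU')
    tf.config.set_logical_device_configuration(
        devices[0], [tf.config.LogicalDeviceConfiguration()] * 8)
    self.mesh = mesh_util.create_mesh(
        devices=['CPU:%d' % i for i in range(8)],
        mesh_dims=[('batch', 8)])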
Example #4
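A unit test of optimizer slot creation: `add_variable_from_reference` should return a zero-initialized state variable whose shared name combines the slot name with the reference's (`test/tmp`) and which inherits the reference's shape, dtype, and, crucially under DTensor, its layout.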
  def test_add_variable_from_reference(self):
    optimizer = optimizers.Adam(mesh=self.mesh)
    variable_init_value = tf.ones(
        [4, 4], dtype=tf.float32,
        layout=dtensor.Layout.replicated(self.mesh, rank=2))
    model_variable = dtensor.DVariable(variable_init_value,
                                       trainable=True,
                                       name='tmp')
    # The slot variable is zero-initialized and named '<slot>/<reference name>'.
    state_variable = optimizer.add_variable_from_reference(
        model_variable, 'test')
    self.assertEqual(state_variable._shared_name, 'test/tmp')
    self.assertAllClose(self.evaluate(state_variable), tf.zeros([4, 4]))
    # Make sure the variable contains the correct layout info
    self.assertEqual(state_variable.layout, model_variable.layout)
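For context, a hedged sketch of how such a DTensor-aware optimizer is driven once built. This uses the standard Keras experimental-optimizer API; `model`, `loss_fn`, and the batch-sharded inputs `x`, `y` are placeholders, not names from the excerpts above:

with tf.GradientTape() as tape:
    # Forward pass on batch-sharded DTensor inputs.
    loss = loss_fn(y, model(x))
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))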