Ejemplo n.º 1
0
    def test_mnist_training_cpu(self):
        """Train an MNIST model on an 8-way "batch" mesh of logical CPUs.

        The first physical CPU is split into 8 logical devices, a 1-D mesh
        named "batch" is built over them, and a model whose layout map is
        fully replicated is trained with Adam. The test then checks that the
        returned per-step/epoch losses never increase.
        """
        num_devices = 8

        # Split the first physical CPU into `num_devices` logical CPUs so the
        # mesh below has distinct devices to place shards on.
        physical_cpus = tf.config.list_physical_devices("CPU")
        logical_configs = [tf.config.LogicalDeviceConfiguration()] * num_devices
        tf.config.set_logical_device_configuration(physical_cpus[0],
                                                   logical_configs)

        device_names = ["CPU:%d" % idx for idx in range(num_devices)]
        mesh = mesh_util.create_mesh(
            devices=device_names,
            mesh_dims=[("batch", num_devices)],
        )

        # Needed by Keras initializers: switch to the TF random generator and
        # fix the seed.
        backend.enable_tf_random_generator()
        tf_utils.set_random_seed(1337)

        replicated_layouts = (
            integration_test_utils.get_all_replicated_layout_map(mesh))
        model = integration_test_utils.get_model_with_layout_map(
            replicated_layouts)

        optimizer = optimizer_lib.Adam(learning_rate=0.001, mesh=mesh)
        # Build the optimizer's state for the model's weights before training.
        optimizer.build(model.trainable_variables)

        train_losses = integration_test_utils.train_mnist_model_batch_sharded(
            model,
            optimizer,
            mesh,
            num_epochs=3,
            steps_per_epoch=100,
            global_batch_size=64,
        )

        # Losses must be non-increasing: the list equals its own descending
        # sort.
        descending = sorted(train_losses, reverse=True)
        self.assertEqual(train_losses, descending)
Ejemplo n.º 2
0
 def setUp(self):
     """Seed randomness and build the 2x2 CPU test mesh.

     Stores on ``self.mesh`` the result of ``configTestMesh`` for a mesh
     with dimensions 'X' and 'Y' laid over a 2x2 grid of CPU devices.
     """
     # Zero-argument super() is equivalent inside the class body and avoids
     # hard-coding the class name.
     super().setUp()
     # Use the TF random generator with a fixed seed so runs are repeatable.
     backend.enable_tf_random_generator()
     tf_utils.set_random_seed(1337)
     global_ids = test_util.create_device_ids_array((2, 2))
     # Every global device id is also listed as local (single-client setup).
     local_device_ids = np.ravel(global_ids).tolist()
     mesh_dict = {
         'CPU':
         dtensor.Mesh(['X', 'Y'], global_ids, local_device_ids,
                      test_util.create_device_list((2, 2), 'CPU'))
     }
     # NOTE(review): configTestMesh presumably selects the mesh matching the
     # current test device type; only a 'CPU' entry is supplied here.
     self.mesh = self.configTestMesh(mesh_dict)
Ejemplo n.º 3
0
 def setUp(self):
     """Prepare deterministic RNG state and a 2x2 CPU DTensor mesh.

     The mesh (dimensions "X" and "Y") is passed through ``configTestMesh``
     and the result is stored on ``self.mesh``.
     """
     super().setUp()

     # Route randomness through the TF generator with a fixed seed.
     backend.enable_tf_random_generator()
     tf_utils.set_random_seed(1337)

     mesh_shape = (2, 2)
     device_ids = test_util.create_device_ids_array(mesh_shape)
     # The local id list covers every global id (single-client setup).
     local_ids = np.ravel(device_ids).tolist()
     cpu_devices = test_util.create_device_list(mesh_shape, "CPU")
     cpu_mesh = dtensor.Mesh(["X", "Y"], device_ids, local_ids, cpu_devices)

     self.mesh = self.configTestMesh({"CPU": cpu_mesh})
Ejemplo n.º 4
0
    def setUp(self):
        """Build a 2x2 CPU mesh plus replicated and batch-sharded layouts.

        Attributes set:
          self.mesh: result of ``configTestMesh`` for a mesh with dimensions
            'X' and 'Y' over a 2x2 grid of CPU devices.
          self.layout_2d / self.layout_1d: fully replicated layouts of rank 2
            and rank 1 on ``self.mesh``.
          self.sharded_2d / self.sharded_1d: rank-2 and rank-1 layouts that
            are batch-sharded over mesh dimension 'X'.
        """
        # Zero-argument super() is equivalent inside the class body and avoids
        # hard-coding the class name.
        super().setUp()
        # Use the TF random generator with a fixed seed so runs are repeatable.
        backend.enable_tf_random_generator()
        tf_utils.set_random_seed(1337)
        global_ids = test_util.create_device_ids_array((2, 2))
        # Every global device id is also listed as local (single-client setup).
        local_device_ids = np.ravel(global_ids).tolist()
        mesh_dict = {
            'CPU':
            dtensor.Mesh(['X', 'Y'], global_ids, local_device_ids,
                         test_util.create_device_list((2, 2), 'CPU'))
        }
        # NOTE(review): configTestMesh presumably selects the mesh matching the
        # current test device type; only a 'CPU' entry is supplied here.
        self.mesh = self.configTestMesh(mesh_dict)
        self.layout_2d = dtensor.Layout.replicated(self.mesh, rank=2)
        self.layout_1d = dtensor.Layout.replicated(self.mesh, rank=1)

        self.sharded_2d = dtensor.Layout.batch_sharded(self.mesh, 'X', rank=2)
        self.sharded_1d = dtensor.Layout.batch_sharded(self.mesh, 'X', rank=1)