コード例 #1
0
 def test_build_randomized_omega(self):
     """Check that ``build_randomized_omega`` produces random compressions.

     ``omega_k`` is seeded with two entries, as if the algorithm were at
     iteration 2.  A single call to ``build_randomized_omega`` must leave two
     compressed values in ``omega`` and three entries in ``omega_k``.

     Quantization is random, so the two values in ``omega`` may coincide by
     chance.  The call is therefore attempted up to ``max_tries`` times in
     total.  The test fails only if the quantizations are still identical
     after the last attempt.
     """
     max_tries = 5
     artemis_update = ArtemisUpdate(self.params, self.workers)
     # NOTE(review): test_send_back_global_informations_and_update sets
     # workers_sub_set to (worker, cost_model) pairs, while this test uses the
     # bare worker list; confirm which shape build_randomized_omega expects.
     artemis_update.workers_sub_set = self.workers
     artemis_update.value_to_compress = torch.FloatTensor(
         list(range(0, 100, 10)))
     # Initialize omega_k with two values (as if we are at iteration 2).
     artemis_update.omega_k = [
         [
             torch.FloatTensor(list(range(0, 100, 10))),
             torch.FloatTensor(list(range(10))),
         ],
         [
             torch.FloatTensor(list(range(0, 20, 2))),
             torch.FloatTensor(list(range(10))),
         ],
     ]

     def quantizations_identical():
         # Truthy iff every element of the two compressed values is equal.
         return torch.all(artemis_update.omega[0].eq(artemis_update.omega[1]))

     artemis_update.build_randomized_omega(self.cost_models)
     self.assertEqual(
         len(artemis_update.omega), 2,
         "The number of compressed value kept on central server must be equal to 2."
     )
     # Checked right after the first call, so the result does not depend on
     # whether the retries below also modify omega_k (not visible here).
     self.assertEqual(
         len(artemis_update.omega_k), 3,
         "omega_k should hold three entries after one build_randomized_omega call."
     )

     # We want two *different* quantizations of the value to compress, but
     # random quantization can occasionally yield identical vectors.  Retry,
     # for at most max_tries calls in total, until they differ.
     for _ in range(max_tries - 1):
         if not quantizations_identical():
             break
         artemis_update.build_randomized_omega(self.cost_models)
     # Assert on the vectors themselves.  Asserting on the retry counter would
     # wrongly fail when the final attempt did produce distinct values.
     self.assertFalse(
         quantizations_identical(),
         "After 5 try, the two different quantizations are still identical."
     )
コード例 #2
0
 def test_send_back_global_informations_and_update(self):
     """Workers end with distinct models and an advanced participation index.

     Both workers start with ``idx_last_update`` set to 1.  After
     ``send_back_global_informations_and_update`` runs, the two workers'
     ``local_update.model_param`` are expected to differ, and both indices are
     expected to equal 2.

     The models can coincide on a given call.  (Presumably the update is
     randomized; this is not confirmed here.)  The call is therefore repeated
     up to four more times, i.e. five calls in total.  The indices are reset
     to 1 before each repetition.
     """
     artemis_update = ArtemisUpdate(self.params, self.workers)

     def reset_last_update_indices():
         self.workers[0].idx_last_update = 1
         self.workers[1].idx_last_update = 1

     def models_match():
         # Truthy iff the two workers hold element-wise equal parameters.
         first = artemis_update.workers[0].local_update.model_param
         second = artemis_update.workers[1].local_update.model_param
         return torch.all(first.eq(second))

     reset_last_update_indices()
     artemis_update.workers_sub_set = [
         (self.workers[idx], self.cost_models[idx])
         for idx in range(self.params.nb_devices)
     ]
     artemis_update.omega_k = [
         [torch.FloatTensor([0, 50]), torch.FloatTensor([0, 10])],
         [torch.FloatTensor([2, 4]), torch.FloatTensor([10, 20])],
     ]
     artemis_update.step = 1 / 10

     artemis_update.send_back_global_informations_and_update(self.cost_models)
     # Four retries at most, so at most five calls overall.
     for _ in range(4):
         if not models_match():
             break
         reset_last_update_indices()
         artemis_update.send_back_global_informations_and_update(
             self.cost_models)

     self.assertFalse(
         models_match(),
         "The models on workers are expected to be different.")
     both_indices_advanced = (self.workers[0].idx_last_update == 2
                              and self.workers[1].idx_last_update == 2)
     self.assertTrue(
         both_indices_advanced,
         "Index of last participation of each worker should be updated to 2"
     )