Пример #1
0
 def setup_clients(self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict):
     """Create the per-round clients and append them to ``self.client_list``.

     One ``Client`` is built for every index in
     ``range(self.args.client_num_per_round)``. Each client receives the
     entry stored under its own index in each of the three mappings.

     Args:
         train_data_local_num_dict: Indexable by client index; the entry is
             passed to ``Client`` as its fourth argument (presumably the
             local training sample count -- confirm against ``Client``).
         train_data_local_dict: Indexable by client index; the entry is
             passed to ``Client`` as its second argument.
         test_data_local_dict: Indexable by client index; the entry is
             passed to ``Client`` as its third argument.
     """
     logging.info("############setup_clients (START)#############")
     clients_per_round = self.args.client_num_per_round
     for slot in range(clients_per_round):
         # Look the entries up in the same order the constructor takes them.
         local_train = train_data_local_dict[slot]
         local_test = test_data_local_dict[slot]
         local_sample_count = train_data_local_num_dict[slot]
         self.client_list.append(
             Client(slot, local_train, local_test, local_sample_count,
                    self.args, self.device)
         )
     logging.info("############setup_clients (END)#############")
Пример #2
0
 def _setup_clients(self, train_data_local_num_dict, train_data_local_dict,
                    test_data_local_dict, model_trainer):
     """Create the per-round clients and append them to ``self.client_list``.

     Logs a message when ``args.client_num_in_total`` is smaller than
     ``args.client_num_per_round``; the message is informational only and
     client creation proceeds regardless.

     Args:
         train_data_local_num_dict: Indexable by client index; the entry is
             passed to ``Client`` as its fourth argument.
         train_data_local_dict: Indexable by client index; the entry is
             passed to ``Client`` as its second argument.
         test_data_local_dict: Indexable by client index; the entry is
             passed to ``Client`` as its third argument.
         model_trainer: Forwarded unchanged to every ``Client``.
     """
     logging.info("############setup_clients (START)#############")

     # Sanity check: the total number of clients should not be smaller than
     # the number of clients that take part in each round.
     total_clients = self.args.client_num_in_total
     clients_per_round = self.args.client_num_per_round
     if total_clients < clients_per_round:
         logging.info(
             "client_num_in_total is less than client_num_per_round, please check params"
         )

     # Create as many clients as participate in each round; these form the
     # base client list used for every round.
     for slot in range(clients_per_round):
         local_train = train_data_local_dict[slot]
         local_test = test_data_local_dict[slot]
         local_sample_count = train_data_local_num_dict[slot]
         new_client = Client(slot, local_train, local_test,
                             local_sample_count, self.args, self.device,
                             model_trainer)
         self.client_list.append(new_client)

     logging.info("############setup_clients (END)#############")
Пример #3
0
    def setup_clients(self):
        """Build one ``Client`` per configured client index.

        For every index in ``range(self.args.client_number)`` this looks up
        the index list in ``self.net_dataidx_map``, obtains a train/test
        loader pair from ``get_dataloader``, logs their sizes, and appends a
        new ``Client`` to ``self.client_list``.
        """
        log = self.logger
        log.info("############setup_clients (START)#############")

        data_root = "./data/cifar10"
        # The fourth positional argument to get_dataloader is fixed at 32
        # (presumably the test batch size -- confirm against get_dataloader);
        # the third comes from args.batch_size.
        fixed_batch_size = 32

        for client_idx in range(self.args.client_number):
            log.info("######client idx = " + str(client_idx))
            sample_indices = self.net_dataidx_map[client_idx]
            local_sample_number = len(sample_indices)

            loaders = get_dataloader(
                self.args.dataset, data_root, self.args.batch_size,
                fixed_batch_size, sample_indices)
            train_loader, test_loader = loaders

            log.info('n_sample: %d' % local_sample_number)
            log.info('n_training: %d' % len(train_loader))
            log.info('n_test: %d' % len(test_loader))

            self.client_list.append(
                Client(train_loader, test_loader, local_sample_number,
                       self.args, self.logger, self.device)
            )

        log.info("############setup_clients (END)#############")