def __init__(self, client_index, device, model, args):
    """Create a local trainer for one client.

    Args:
        client_index: Client id used to fetch local data from the provider;
            also stored as ``self.host``.
        device: Device the model and the loss module are moved to.
        model: Model to train; moved to ``device`` in place.
        args: Run configuration. Reads ``batch_size``, ``client_optimizer``
            and ``lr``, plus ``wd`` when the optimizer is not ``"sgd"``.
    """
    self.provider = SQLDataProvider(args)
    self.client_index = client_index
    self.host = client_index
    self.batch_size = args.batch_size
    self.used_x = []
    self.round = 0
    # Batch the local data exactly as update_dataset() does, so that train()
    # always iterates (x, labels) mini-batches even when it runs before the
    # first update_dataset() call.
    # NOTE(review): the ``False`` flag presumably selects the training split
    # -- confirm against SQLDataProvider.cache.
    self.train_local = self.provider.cache(client_index, False).batch(
        self.batch_size)
    self.local_sample_number = self.provider.size()
    self.device = device
    self.args = args
    self.model = model
    self.model.to(self.device)
    self.served_client = []
    self.criterion = nn.CrossEntropyLoss().to(self.device)
    if self.args.client_optimizer == "sgd":
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=self.args.lr)
    else:
        # Adam only receives the parameters that are still trainable.
        self.optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=self.args.lr,
            weight_decay=self.args.wd,
            amsgrad=True)
def __init__(self, all_clients):
    """Prepare empty per-client registries and load the shared test data.

    Args:
        all_clients: Iterable of client ids stored on ``self.all_clients``.
    """
    self.args = args()
    self.all_clients = all_clients
    # Per-client registries, all keyed by client id once populated.
    self.model_stats, self.models = {}, {}
    self.sample_dict, self.data_dict = {}, {}
    # The provider gets its own freshly built argument object.
    test_provider = SQLDataProvider(args())
    self.test_data = test_provider.cache(100)
def build(self):
    """Train one logistic-regression model per configured client.

    For every id in ``self.database_clients`` the client's data is loaded,
    a fresh ``LogisticRegression(28 * 28, 10)`` is trained on it in batches
    of 8, and the data, the value returned by ``train``, the model and the
    sample count are stored in the per-client dictionaries. The accuracy of
    each model on ``self.test_data`` is printed.
    """
    # Iterate the instance's own client list; the bare name
    # ``database_clients`` is not defined in this scope.
    for client_idx in self.database_clients:
        data = SQLDataProvider(args()).cache(client_idx)
        model = LogisticRegression(28 * 28, 10)
        trained = train(model, data.batch(8))
        self.data_dict[client_idx] = data
        self.model_stats[client_idx] = trained
        self.models[client_idx] = model
        self.sample_dict[client_idx] = len(data)
        print("model accuracy:", infer(model, self.test_data.batch(8)))
def __init__(self):
    """Set the fixed client id list, empty registries and the test data."""
    # Ids 0..5 followed by the first four ids of every decade from 10 to 90
    # (10-13, 20-23, ..., 90-93).
    leading_ids = list(range(6))
    decade_ids = [
        decade + offset
        for decade in range(10, 100, 10)
        for offset in range(4)
    ]
    self.database_clients = leading_ids + decade_ids
    # Per-client registries, all keyed by client id once populated.
    self.model_stats, self.models = {}, {}
    self.sample_dict, self.data_dict = {}, {}
    test_provider = SQLDataProvider(args())
    self.test_data = test_provider.cache(20)
def build(self, test_models=False, round_idx=0):
    """Train a fresh logistic-regression model for every known client.

    Args:
        test_models: When true, print each model's result on the shared
            test data right after it is trained.
        round_idx: Accepted for interface compatibility; not read here.
    """
    print("Building Models --Started")
    for client_idx in self.all_clients:
        provider = SQLDataProvider(args())
        client_data = provider.cache(client_idx)
        client_model = LogisticRegression(28 * 28, 10)
        train_result = tools.train(client_model, client_data.batch(8))

        self.data_dict[client_idx] = client_data
        self.model_stats[client_idx] = train_result
        self.models[client_idx] = client_model
        self.sample_dict[client_idx] = len(client_data)

        if not test_models:
            continue
        print("model accuracy:",
              tools.infer(client_model, self.test_data.batch(8)))
    print("Building Models --Finished")
def __init__(self, worker_num, device, model, args):
    """Initialise aggregator state for ``worker_num`` workers.

    Args:
        worker_num: Number of workers whose uploads are tracked.
        device: Device stored for later inference.
        model: Model handed to ``init_model`` for both ``self.model`` and
            ``self.cached_model``.
        args: Run configuration; ``batch_size`` is read here.
    """
    self.provider = SQLDataProvider(args)
    self.batch_size = args.batch_size
    self.args = args
    self.device = device
    self.worker_num = worker_num

    # Per-worker uploads and their sample counts, filled in as they arrive.
    self.model_dict = {}
    self.sample_num_dict = {}
    # No worker has uploaded anything yet.
    self.flag_client_model_uploaded_dict = dict.fromkeys(
        range(worker_num), False)

    # init_model is invoked once per attribute; only the model half of the
    # returned pair is kept.
    self.model = self.init_model(model)[0]
    self.cached_model = self.init_model(model)[0]
    self.model_influence = {}
class Context:
    """Per-client logistic-regression models plus selection scoring helpers.

    ``build`` trains one model per id in ``database_clients``; the
    ``test_selection_*`` methods aggregate the models of a chosen subset of
    clients and score that selection.
    """

    def __init__(self):
        """Set the fixed client id list, empty registries and test data."""
        self.database_clients = [
            0, 1, 2, 3, 4, 5,
            10, 11, 12, 13,
            20, 21, 22, 23,
            30, 31, 32, 33,
            40, 41, 42, 43,
            50, 51, 52, 53,
            60, 61, 62, 63,
            70, 71, 72, 73,
            80, 81, 82, 83,
            90, 91, 92, 93,
        ]
        # All four registries are keyed by client id once build() has run.
        self.model_stats = {}
        self.models = {}
        self.sample_dict = {}
        self.data_dict = {}
        self.test_data = SQLDataProvider(args()).cache(20)

    def build(self):
        """Train one model per configured client and print its accuracy."""
        # Iterate the instance's own client list; the bare name
        # ``database_clients`` is not defined in this scope.
        for client_idx in self.database_clients:
            data = SQLDataProvider(args()).cache(client_idx)
            model = LogisticRegression(28 * 28, 10)
            trained = train(model, data.batch(8))
            self.data_dict[client_idx] = data
            self.model_stats[client_idx] = trained
            self.models[client_idx] = model
            self.sample_dict[client_idx] = len(data)
            print("model accuracy:", infer(model, self.test_data.batch(8)))

    def test_selection_accuracy(self, client_idx, title='test accuracy'):
        """Aggregate the selected clients' models and evaluate the result.

        Args:
            client_idx: Collection of client ids forming the selection.
            title: Text printed in the banner line.

        Returns:
            The value returned by ``infer``; its first two items are
            printed as accuracy and loss.
        """
        print('-----------------' + title + '-----------------')
        global_model_stats = aggregate(
            dict_select(client_idx, self.model_stats),
            dict_select(client_idx, self.sample_dict))
        global_model = LogisticRegression(28 * 28, 10)
        load(global_model, global_model_stats)
        print("test case:", client_idx)
        acc_loss = infer(global_model, self.test_data.batch(8))
        print("global model accuracy:", acc_loss[0])
        print("global model loss:", acc_loss[1])
        return acc_loss

    def test_selection_fitness(self, client_idx, title='test_fitness'):
        """Score a selection by the variance of its normalised influences.

        Args:
            client_idx: Iterable of client ids forming the selection.
            title: Accepted for interface symmetry; not read here.

        Returns:
            ``statistics.variance`` of the normalised per-client influence
            values.

        Raises:
            statistics.StatisticsError: When fewer than two values reach
                ``statistics.variance`` (presumably a selection with fewer
                than two clients -- depends on ``normalize``).
        """
        aggregated = aggregate(
            dict_select(client_idx, self.model_stats),
            dict_select(client_idx, self.sample_dict))
        influences = [
            influence_ecl(aggregated, self.model_stats[key])
            for key in client_idx
        ]
        return statistics.variance(normalize(influences))
class Context:
    """Per-client logistic-regression models with clustering and scoring.

    ``build`` trains one model per client, ``cluster`` groups clients by
    their trained weights, and the ``test_selection_*`` / ``fitness``
    methods score a chosen subset of clients.
    """

    def __init__(self, all_clients):
        """Store the client ids and load the shared test data.

        Args:
            all_clients: Iterable of client ids to build models for.
        """
        self.args = args()
        self.all_clients = all_clients
        # All four registries are keyed by client id once build() has run.
        self.model_stats = {}
        self.models = {}
        self.sample_dict = {}
        self.data_dict = {}
        self.test_data = SQLDataProvider(args()).cache(100)

    def build(self, test_models=False, round_idx=0):
        """Train a fresh model for every client in ``self.all_clients``.

        Args:
            test_models: When true, print each model's result on the
                shared test data right after training it.
            round_idx: Accepted for interface compatibility; not read here.
        """
        print("Building Models --Started")
        for client_idx in self.all_clients:
            data = SQLDataProvider(args()).cache(client_idx)
            model = LogisticRegression(28 * 28, 10)
            trained = tools.train(model, data.batch(8))
            self.data_dict[client_idx] = data
            self.model_stats[client_idx] = trained
            self.models[client_idx] = model
            self.sample_dict[client_idx] = len(data)
            if test_models:
                print("model accuracy:",
                      tools.infer(model, self.test_data.batch(8)))
        print("Building Models --Finished")

    def cluster(self, cluster_size=10):
        """Group clients with k-means over their flattened linear weights.

        Args:
            cluster_size: Number of clusters requested from ``KMeans``.

        Returns:
            Dict mapping each client id in ``self.model_stats`` to its
            cluster label.
        """
        print("Clustering Models --Started")
        client_ids = []
        weights = []
        for client_id, stats in self.model_stats.items():
            client_ids.append(client_id)
            # detach().cpu() keeps the conversion valid for tensors that
            # live on an accelerator or still track gradients; it is a
            # no-op for plain CPU tensors.
            weights.append(
                stats['linear.weight'].detach().cpu().numpy().flatten())
        kmeans = KMeans(n_clusters=cluster_size).fit(weights)
        # labels_ is aligned with the order the weights were appended in.
        clustered = dict(zip(client_ids, kmeans.labels_))
        print("Clustering Models --Finished")
        return clustered

    def test_selection_accuracy(self, client_idx, title='test accuracy',
                                output=True):
        """Aggregate the selected clients' models and evaluate the result.

        Args:
            client_idx: Collection of client ids forming the selection.
            title: Text printed in the banner line (printed regardless of
                ``output``).
            output: When true, also print the selection and its scores.

        Returns:
            The value returned by ``tools.infer``; its first two items are
            printed as accuracy and loss.
        """
        print('-----------------' + title + '-----------------')
        global_model_stats = tools.aggregate(
            tools.dict_select(client_idx, self.model_stats),
            tools.dict_select(client_idx, self.sample_dict))
        global_model = LogisticRegression(28 * 28, 10)
        tools.load(global_model, global_model_stats)
        acc_loss = tools.infer(global_model, self.test_data.batch(8))
        if output:
            print("test case:", client_idx)
            print("global model accuracy:", acc_loss[0], 'loss:',
                  acc_loss[1])
        return acc_loss

    def test_selection_fitness(self, client_idx, title='test_fitness',
                               output=True):
        """Score a selection by the variance of its normalised influences.

        Args:
            client_idx: Iterable of client ids forming the selection.
            title: Accepted for interface symmetry; not read here.
            output: When true, print the selection and its fitness.

        Returns:
            The variance of the normalised influence values, scaled by
            ``10 ** 5``.

        Raises:
            statistics.StatisticsError: When fewer than two values reach
                ``statistics.variance`` (presumably a selection with fewer
                than two clients -- depends on ``tools.normalize``).
        """
        aggregated = tools.aggregate(
            tools.dict_select(client_idx, self.model_stats),
            tools.dict_select(client_idx, self.sample_dict))
        influences = [
            tools.influence_ecl(aggregated, self.model_stats[key])
            for key in client_idx
        ]
        fitness = statistics.variance(tools.normalize(influences))
        fitness = fitness * 10 ** 5
        if output:
            print("test case:", client_idx)
            print("selection fitness:", fitness)
        return fitness

    def fitness(self, client_idx):
        """Return the selection fitness without printing anything."""
        return self.test_selection_fitness(client_idx, output=False)
class FedAVGTrainer(object):
    """Client-side trainer: holds one model and trains it on local data."""

    def __init__(self, client_index, device, model, args):
        """Create a local trainer for one client.

        Args:
            client_index: Client id used to fetch local data from the
                provider; also stored as ``self.host``.
            device: Device the model and the loss module are moved to.
            model: Model to train; moved to ``device`` in place.
            args: Run configuration. Reads ``batch_size``,
                ``client_optimizer`` and ``lr``, plus ``wd`` when the
                optimizer is not ``"sgd"``; ``epochs`` and ``is_mobile``
                are read later by ``train``.
        """
        self.provider = SQLDataProvider(args)
        self.client_index = client_index
        self.host = client_index
        self.batch_size = args.batch_size
        self.used_x = []
        self.round = 0
        # Batch the local data exactly as update_dataset() does, so that
        # train() always iterates (x, labels) mini-batches even when it runs
        # before the first update_dataset() call.
        self.train_local = self.provider.cache(client_index, False).batch(
            self.batch_size)
        self.local_sample_number = self.provider.size()
        self.device = device
        self.args = args
        self.model = model
        self.model.to(self.device)
        self.served_client = []
        self.criterion = nn.CrossEntropyLoss().to(self.device)
        if self.args.client_optimizer == "sgd":
            self.optimizer = torch.optim.SGD(self.model.parameters(),
                                             lr=self.args.lr)
        else:
            # Adam only receives the parameters that are still trainable.
            self.optimizer = torch.optim.Adam(
                filter(lambda p: p.requires_grad, self.model.parameters()),
                lr=self.args.lr,
                weight_decay=self.args.wd,
                amsgrad=True)

    def update_model(self, weights):
        """Load ``weights`` (a state dict) into the local model."""
        logging.info("update_model. client_index = %d", self.client_index)
        self.model.load_state_dict(weights)

    def update_dataset(self, client_index):
        """Switch to another client's data and record that it was served.

        Args:
            client_index: Id of the client whose data is loaded next.
        """
        self.client_index = client_index
        self.served_client.append(client_index)
        logging.info("update_dataset. client_index = %d", self.client_index)
        self.train_local = self.provider.cache(client_index, False).batch(
            self.batch_size)
        self.local_sample_number = self.provider.size()

    def update_round(self, round):
        """Remember the current communication round."""
        self.round = round

    def train(self):
        """Run ``args.epochs`` local epochs and return the trained weights.

        Returns:
            Tuple ``(weights, local_sample_number)``. ``weights`` is the
            model's state dict after the model has been moved to the CPU;
            it is passed through ``transform_tensor_to_list`` when
            ``args.is_mobile == 1``.
        """
        self.model.to(self.device)
        # change to train mode
        self.model.train()

        epoch_loss = []
        for epoch in range(self.args.epochs):
            batch_loss = []
            for x, labels in self.train_local:
                x, labels = x.to(self.device), labels.to(self.device)
                self.optimizer.zero_grad()
                log_probs = self.model(x)
                loss = self.criterion(log_probs, labels)
                loss.backward()
                self.optimizer.step()
                batch_loss.append(loss.item())
            # Record and log only epochs that produced at least one batch,
            # so the running mean is never taken over an empty list.
            if batch_loss:
                epoch_loss.append(sum(batch_loss) / len(batch_loss))
                logging.info(
                    "(client %s. Local Training Epoch: %s \tLoss: %.6f",
                    self.client_index, epoch,
                    sum(epoch_loss) / len(epoch_loss))

        weights = self.model.cpu().state_dict()

        # transform Tensor to list
        if self.args.is_mobile == 1:
            weights = transform_tensor_to_list(weights)
        return weights, self.local_sample_number
for i in range(m): if data.y[i] == item[i]: similar_values += 1 sim.append(round(similar_values / m, 2)) return sim def heatmap(dct): matrix = [] for key, data in dct.items(): simi = similarities(data.y, dct) matrix.append(simi) print(matrix) test_data = SQLDataProvider(args()).cache(99999) print("test data:", test_data.y) for index, client_idx in enumerate(database_clients): data = SQLDataProvider(args()).cache(client_idx) model = LogisticRegression(28 * 28, 10) trained = train(model, data.batch(8)) data_dict[client_idx] = data model_stats[client_idx] = trained models[client_idx] = model sample_dict[client_idx] = len(data) print("model accuracy:", infer(model, test_data.batch(8))) def test_a_case(test_case, title='start evaluation'): print('-----------------' + title + '-----------------') global_model_stats = aggregate(dict_select(test_case, model_stats), dict_select(test_case, sample_dict))
class FedAVGAggregator(object):
    """Server-side FedAvg aggregation plus leave-one-out influence analysis.

    Collects per-worker model parameters, averages them weighted by sample
    count, evaluates models on every client's data and measures how much the
    global model changes when one worker's upload is left out.
    """

    def __init__(self, worker_num, device, model, args):
        """Initialise aggregator state for ``worker_num`` workers.

        Args:
            worker_num: Number of workers whose uploads are tracked.
            device: Device models are moved to for inference.
            model: Global model; also stored as ``cached_model``.
            args: Run configuration. ``batch_size`` is read here; other
                methods read ``frequency_of_the_test``, ``comm_round``,
                ``client_num_in_total``, ``ci`` and
                ``influence_test_clients``.
        """
        self.provider = SQLDataProvider(args)
        self.batch_size = args.batch_size
        self.worker_num = worker_num
        self.device = device
        self.args = args
        self.model_dict = dict()
        self.sample_num_dict = dict()
        self.flag_client_model_uploaded_dict = dict()
        for idx in range(self.worker_num):
            self.flag_client_model_uploaded_dict[idx] = False
        # init_model returns the object it was given, so ``model`` and
        # ``cached_model`` refer to the same instance.
        self.model, _ = self.init_model(model)
        self.cached_model, _ = self.init_model(model)
        self.model_influence = {}

    def init_model(self, model):
        """Return ``(model, model.state_dict())`` without copying the model."""
        model_params = model.state_dict()
        return model, model_params

    def get_global_model_params(self):
        """Return the state dict of the current global model."""
        return self.model.state_dict()

    def add_local_trained_result(self, index, model_params, sample_num):
        """Record one worker's parameters and sample count as uploaded."""
        logging.info("add_model. index = %d", index)
        self.model_dict[index] = model_params
        self.sample_num_dict[index] = sample_num
        self.flag_client_model_uploaded_dict[index] = True

    def check_whether_all_receive(self):
        """Return True once every worker uploaded, then reset the flags.

        The flags are left untouched while any upload is still missing.
        """
        flags = self.flag_client_model_uploaded_dict
        if not all(flags[idx] for idx in range(self.worker_num)):
            return False
        for idx in range(self.worker_num):
            flags[idx] = False
        return True

    def aggregate_models(self, models_dict: dict):
        """Average parameter dicts weighted by each worker's sample count.

        Args:
            models_dict: Mapping of worker index to parameter dict. Every
                index must also be present in ``self.sample_num_dict``.

        Returns:
            A new dict with the weighted average of every parameter. The
            input parameter dicts are left unmodified.

        Raises:
            IndexError: When ``models_dict`` is empty.
        """
        start_time = time.time()
        model_list = []
        training_num = 0

        for idx in models_dict:
            model_list.append((self.sample_num_dict[idx], models_dict[idx]))
            training_num += self.sample_num_dict[idx]

        logging.info("len of self.model_dict[idx] = " +
                     str(len(self.model_dict)))

        # Accumulate into a fresh dict. Writing the average back into the
        # first worker's dict would overwrite that worker's upload inside
        # self.model_dict and skew every later leave-one-out aggregate.
        averaged_params = {}
        for k in model_list[0][1].keys():
            for i, (local_sample_number,
                    local_model_params) in enumerate(model_list):
                w = local_sample_number / training_num
                if i == 0:
                    averaged_params[k] = local_model_params[k] * w
                else:
                    averaged_params[k] += local_model_params[k] * w

        end_time = time.time()
        logging.info("aggregate time cost: %d" % (end_time - start_time))
        return averaged_params

    def aggregate(self):
        """Average all uploaded models and load the result into the model."""
        averaged_params = self.aggregate_models(self.model_dict)
        self.model.load_state_dict(averaged_params)
        return averaged_params

    def client_sampling(self, round_idx, client_num_in_total,
                        client_num_per_round):
        """Pick the clients for a round, deterministically per round index.

        Note: this reseeds NumPy's global random generator with
        ``round_idx``.
        """
        num_clients = min(client_num_per_round, client_num_in_total)
        # make sure for each comparison, we are selecting the same clients
        # each round
        np.random.seed(round_idx)
        client_indexes = np.random.choice(range(client_num_in_total),
                                          num_clients,
                                          replace=False)
        print(client_indexes)
        logging.info("client_indexes = %s" % str(client_indexes))
        return client_indexes

    def test_on_all_clients(self, round_idx):
        """Evaluate the global model on every client for this round."""
        self.test_model_on_all_clients(self.model, round_idx)

    def test_model_on_all_clients(self, model, round_idx):
        """Evaluate ``model`` on each client's train and test data.

        Evaluation only runs when ``round_idx`` is a multiple of
        ``args.frequency_of_the_test`` or is the last communication round.

        Returns:
            ``(train_stats, test_stats)`` dicts with accuracy and loss, or
            ``None`` when the round is not an evaluation round.
        """
        is_test_round = (round_idx % self.args.frequency_of_the_test == 0
                         or round_idx == self.args.comm_round - 1)
        if not is_test_round:
            return None

        logging.info(
            "################local_test_on_all_clients : {}".format(
                round_idx))
        train_num_samples = []
        train_tot_corrects = []
        train_losses = []

        test_num_samples = []
        test_tot_corrects = []
        test_losses = []
        for client_idx in range(self.args.client_num_in_total):
            # train data
            train_data = self.provider.cache(client_idx).batch(
                self.batch_size)
            train_tot_correct, train_num_sample, train_loss = \
                self._infer_model(model, train_data)
            train_tot_corrects.append(train_tot_correct)
            train_num_samples.append(train_num_sample)
            train_losses.append(train_loss)

            # test data
            test_data = self.provider.cache(client_idx, True).batch(
                self.batch_size)
            test_tot_correct, test_num_sample, test_loss = \
                self._infer_model(model, test_data)
            test_tot_corrects.append(test_tot_correct)
            test_num_samples.append(test_num_sample)
            test_losses.append(test_loss)

            # Note: CI environment is CPU-based computing. The training
            # speed for RNN training is too slow in this setting, so we only
            # test a client to make sure there is no programming error.
            if self.args.ci == 1:
                break

        # test on training dataset
        train_acc = sum(train_tot_corrects) / sum(train_num_samples)
        train_loss = sum(train_losses) / sum(train_num_samples)
        train_stats = {
            'training_acc': train_acc,
            'training_loss': train_loss
        }
        logging.info(train_stats)

        # test on test dataset
        test_acc = sum(test_tot_corrects) / sum(test_num_samples)
        test_loss = sum(test_losses) / sum(test_num_samples)
        test_stats = {'test_acc': test_acc, 'test_loss': test_loss}
        logging.info(test_stats)
        return train_stats, test_stats

    def test_on_all_sub_models(self, round_idx: int):
        """Run the leave-one-out analysis for the current round.

        Evaluates the global model, then for every uploaded worker model
        aggregates all *other* uploads into a sub-model, evaluates it and
        records its influence figures under
        ``self.model_influence["round_<round_idx>"]``.

        Returns:
            An empty string.
        """
        print("####round: " + str(round_idx) + "####")
        round_key = "round_" + str(round_idx)
        global_stats = self.test_model_on_all_clients(self.model, round_idx)
        if global_stats is None:
            # Not an evaluation round: there are no statistics to compare
            # the sub-models against, so skip the analysis.
            return ""
        train_stats, test_stats = global_stats
        print("round", round_idx)
        print("global_model_train", train_stats)
        print("global_model_test", test_stats)
        self.model_influence[round_key] = {}
        self.model_influence[round_key]["."] = {
            "train_stats": train_stats,
            "test_stats": test_stats
        }
        print("model[.]:", train_stats)
        for idx in self.model_dict:
            temp_model_dict = dict(self.model_dict)
            del temp_model_dict[idx]
            if not temp_model_dict:
                # A single upload leaves nothing to aggregate once removed.
                continue
            model_params = self.aggregate_models(temp_model_dict)
            sub_model = LogisticRegression(28 * 28, 10)
            sub_model.load_state_dict(model_params)
            train_stats, test_stats = self.test_model_on_all_clients(
                sub_model, round_idx)
            print("model[" + str(idx) + "]:", train_stats)
            print("model[" + str(idx) + "].train", train_stats)
            print("model[" + str(idx) + "].test", test_stats)
            influence = self._influence(sub_model)
            print("model[" + str(idx) + "].influence", influence)
            influence_no_real, influence_both, _ = influence
            influence_ecl = self._influence_ecl(sub_model)
            # NOTE: only the train statistics are stored per sub-model.
            self.model_influence[round_key][str(idx)] = {
                "train_stats": train_stats,
                "influence_no_real": influence_no_real,
                "influence_real": influence_both,
                "influence_ecl": influence_ecl.numpy(),
            }
        print("####end of round: " + str(round_idx) + "####")
        return ""

    def _influence_ecl(self, model):
        """Return the L2 distance between the ``linear.weight`` tensors of
        the global model and ``model``, as a CPU tensor."""
        original = self.model.state_dict()
        sub = model.state_dict()
        l2_norm = torch.dist(original["linear.weight"],
                             sub["linear.weight"], 2)
        return l2_norm.cpu()

    # noinspection PyUnresolvedReferences
    def _influence(self, model):
        """Average prediction difference between ``model`` and the global
        model over the evaluated clients.

        At most ``args.influence_test_clients`` non-empty clients are
        evaluated; a non-positive value means no limit.

        Returns:
            Tuple of three floats: influence over all samples, over samples
            both models label correctly, and over samples the global model
            labels correctly. All zeros when no client had data.
        """
        client_nums = self.args.influence_test_clients
        influence_no_labels = torch.tensor(0, dtype=torch.float)
        influence_correct_labels_both = torch.tensor(0, dtype=torch.float)
        influence_correct_labels_original = torch.tensor(
            0, dtype=torch.float)
        client_round = 0
        for client_idx in range(self.args.client_num_in_total):
            client_data = self.provider.cache(client_idx, True)
            if len(client_data) == 0:
                continue
            batches = client_data.batch(self.batch_size)
            deletion_prediction = self.predict(model, batches)
            original_prediction = self.predict(self.model, batches)
            deletion_labels = self._predictions_to_label(
                deletion_prediction)
            original_labels = self._predictions_to_label(
                original_prediction)
            real_labels = client_data.y
            # calculate the influence without taking into consideration the
            # real labels
            influence_no_labels += self._influence_function_no_labels(
                deletion_prediction, original_prediction, deletion_labels,
                original_labels, real_labels)
            # calculate the influence taking into consideration only the
            # correct labels in both predictions
            influence_correct_labels_both += \
                self._influence_function_only_correct_labels_both(
                    deletion_prediction, original_prediction,
                    deletion_labels, original_labels, real_labels)
            # calculate the influence taking into consideration only the
            # correct labels in original predictions
            influence_correct_labels_original += \
                self._influence_function_only_correct_labels_original(
                    deletion_prediction, original_prediction,
                    deletion_labels, original_labels, real_labels)
            client_round += 1
            if 0 < client_nums <= client_round:
                break

        if client_round == 0:
            return 0.0, 0.0, 0.0
        # Average over the clients actually evaluated: the configured limit
        # may be zero ("no limit") or larger than the number of clients that
        # have data, and neither is the right divisor.
        return (influence_no_labels.item() / client_round,
                influence_correct_labels_both.item() / client_round,
                influence_correct_labels_original.item() / client_round)

    def _influence_of_predictions(self, left, right):
        """Mean over samples of the mean absolute difference per sample.

        Returns a zero tensor when ``left`` is empty.
        """
        if len(left) == 0:
            return torch.tensor(0, dtype=torch.float)
        influence = torch.tensor(data=0.0, dtype=torch.float)
        for left_row, right_row in zip(left, right):
            influence += torch.mean(torch.abs(left_row - right_row))
        return influence / len(left)

    def _conditional_influence_function(self, deletion_prediction,
                                        original_prediction, deletion_labels,
                                        original_labels, real_labels,
                                        condition):
        """Influence restricted to the samples accepted by ``condition``.

        ``condition`` is called as ``condition(deletion_label,
        original_label, real_label)`` for every sample.
        """
        kept_deletion = []
        kept_original = []
        for index, real_label in enumerate(real_labels):
            if condition(deletion_labels[index], original_labels[index],
                         real_label):
                kept_deletion.append(
                    torch.unsqueeze(deletion_prediction[index], 0))
                kept_original.append(
                    torch.unsqueeze(original_prediction[index], 0))
        if not kept_deletion:
            return self._influence_of_predictions(torch.tensor([]),
                                                  torch.tensor([]))
        # Concatenate once instead of growing a tensor inside the loop.
        return self._influence_of_predictions(torch.cat(kept_deletion),
                                              torch.cat(kept_original))

    def _influence_function_no_labels(self, deletion_prediction,
                                      original_prediction, deletion_labels,
                                      original_labels, real_labels):
        """Influence over every sample, ignoring the labels."""
        return self._conditional_influence_function(
            deletion_prediction, original_prediction, deletion_labels,
            original_labels, real_labels, lambda d, o, r: True)

    def _influence_function_only_correct_labels_both(self,
                                                     deletion_prediction,
                                                     original_prediction,
                                                     deletion_labels,
                                                     original_labels,
                                                     real_labels):
        """Influence over samples both models label correctly."""
        return self._conditional_influence_function(
            deletion_prediction, original_prediction, deletion_labels,
            original_labels, real_labels, lambda d, o, r: d == r == o)

    def _influence_function_only_correct_labels_original(
            self, deletion_prediction, original_prediction, deletion_labels,
            original_labels, real_labels):
        """Influence over samples the global model labels correctly."""
        return self._conditional_influence_function(
            deletion_prediction, original_prediction, deletion_labels,
            original_labels, real_labels, lambda d, o, r: o == r)

    def _predictions_to_label(self, predictions):
        """Return the argmax of every prediction as a 1-D float tensor."""
        labels = [
            torch.argmax(prediction).reshape(1).float()
            for prediction in predictions
        ]
        if not labels:
            return torch.tensor([])
        return torch.cat(labels)

    def _infer(self, test_data):
        """Evaluate the global model on ``test_data``."""
        return self._infer_model(self.model, test_data)

    def predict(self, model, data):
        """Run ``model`` over all batches and return its outputs on the CPU.

        Args:
            model: Model to evaluate; switched to eval mode and moved to
                ``self.device``.
            data: Iterable of ``(x, target)`` batches; targets are unused.

        Returns:
            The concatenated model outputs, or an empty tensor when
            ``data`` yields no batches.
        """
        model.eval()
        model.to(self.device)
        outputs = []
        with torch.no_grad():
            for x, _target in data:
                outputs.append(model(x.to(self.device)))
        if not outputs:
            return torch.tensor([])
        return torch.cat(outputs).cpu()

    def _infer_model(self, model, test_data):
        """Evaluate ``model`` on batched data with cross-entropy loss.

        Returns:
            Tuple ``(correct, total, loss_sum)`` of floats: number of
            correct predictions, number of samples, and the sum of the
            per-batch mean loss multiplied by the batch size.
        """
        model.eval()
        model.to(self.device)

        test_loss = test_acc = test_total = 0.
        criterion = nn.CrossEntropyLoss().to(self.device)
        with torch.no_grad():
            for x, target in test_data:
                x = x.to(self.device)
                target = target.to(self.device)
                pred = model(x)
                loss = criterion(pred, target)
                _, predicted = torch.max(pred, -1)
                correct = predicted.eq(target).sum()

                test_acc += correct.item()
                test_loss += loss.item() * target.size(0)
                test_total += target.size(0)

        return test_acc, test_total, test_loss