Пример #1
0
 def __init__(self, client_index, device, model, args):
     """Set up a local trainer for one client.

     Stores the arguments, loads the client's data through an
     ``SQLDataProvider``, moves ``model`` to ``device`` and creates the
     cross-entropy loss plus the optimizer named by
     ``args.client_optimizer`` ("sgd" selects SGD, anything else Adam).
     """
     self.args = args
     self.device = device
     self.batch_size = args.batch_size
     self.client_index = client_index
     self.host = client_index
     self.round = 0
     self.used_x = []
     self.served_client = []

     # cache() is called before size(); presumably size() reports on the
     # data that was cached last -- TODO confirm against SQLDataProvider.
     self.provider = SQLDataProvider(args)
     self.train_local = self.provider.cache(client_index, False)
     self.local_sample_number = self.provider.size()

     self.model = model
     # logging.info(self.model)
     self.model.to(self.device)
     self.criterion = nn.CrossEntropyLoss().to(self.device)

     if self.args.client_optimizer != "sgd":
         # Adam only receives the parameters that require gradients.
         trainable = filter(lambda p: p.requires_grad,
                            self.model.parameters())
         self.optimizer = torch.optim.Adam(trainable,
                                           lr=self.args.lr,
                                           weight_decay=self.args.wd,
                                           amsgrad=True)
     else:
         self.optimizer = torch.optim.SGD(self.model.parameters(),
                                          lr=self.args.lr)
Пример #2
0
 def __init__(self, all_clients):
     """Create an empty context for ``all_clients`` and load the test data.

     ``args()`` is evaluated twice: once for ``self.args`` and once for the
     provider that caches the data stored under id 100 as ``test_data``.
     """
     self.all_clients = all_clients
     self.args = args()
     # Per-client results, filled in later and keyed by client id.
     self.model_stats = dict()
     self.models = dict()
     self.sample_dict = dict()
     self.data_dict = dict()
     test_provider = SQLDataProvider(args())
     self.test_data = test_provider.cache(100)
Пример #3
0
 def build(self):
     """Train one logistic-regression model per id in ``database_clients``.

     For every client the cached data, the value returned by ``train``, the
     model and the data length are stored in the instance dictionaries, and
     the result of ``infer`` on the test data is printed.

     NOTE(review): ``database_clients`` is read as a module-level name, not
     as an attribute of ``self`` -- confirm that this is intended.
     """
     for client_idx in database_clients:
         client_data = SQLDataProvider(args()).cache(client_idx)
         client_model = LogisticRegression(28 * 28, 10)
         self.model_stats[client_idx] = train(client_model,
                                              client_data.batch(8))
         self.models[client_idx] = client_model
         self.data_dict[client_idx] = client_data
         self.sample_dict[client_idx] = len(client_data)
         evaluation = infer(client_model, self.test_data.batch(8))
         print("model accuracy:", evaluation)
Пример #4
0
 def __init__(self):
     """Create an empty context with the fixed client ids and the test data."""
     # Ids 0-5, followed by the first four ids of every decade 10, 20, ..., 90.
     decade_clients = [
         tens * 10 + unit for tens in range(1, 10) for unit in range(4)
     ]
     self.database_clients = list(range(6)) + decade_clients
     # Per-client results, filled in later and keyed by client id.
     self.model_stats = dict()
     self.models = dict()
     self.sample_dict = dict()
     self.data_dict = dict()
     test_provider = SQLDataProvider(args())
     self.test_data = test_provider.cache(20)
Пример #5
0
 def build(self, test_models=False, round_idx=0):
     """Train one logistic-regression model per client in ``self.all_clients``.

     :param test_models: when true, print the result of ``tools.infer`` on
         the test data after each model is trained.
     :param round_idx: accepted but not used by this method.
     """
     print("Building Models --Started")
     for client_idx in self.all_clients:
         client_data = SQLDataProvider(args()).cache(client_idx)
         client_model = LogisticRegression(28 * 28, 10)
         self.model_stats[client_idx] = tools.train(client_model,
                                                    client_data.batch(8))
         self.models[client_idx] = client_model
         self.data_dict[client_idx] = client_data
         self.sample_dict[client_idx] = len(client_data)
         if not test_models:
             continue
         evaluation = tools.infer(client_model, self.test_data.batch(8))
         print("model accuracy:", evaluation)
     print("Building Models --Finished")
Пример #6
0
 def __init__(self, worker_num, device, model, args):
     """Set up the aggregator state for ``worker_num`` workers.

     Creates the data provider, empty per-worker bookkeeping dictionaries
     with every upload flag cleared, and stores the model returned by
     ``init_model`` both as ``model`` and as ``cached_model``.
     """
     self.args = args
     self.device = device
     self.worker_num = worker_num
     self.batch_size = args.batch_size
     self.provider = SQLDataProvider(args)
     self.model_dict = {}
     self.sample_num_dict = {}
     # No worker has uploaded a model yet.
     self.flag_client_model_uploaded_dict = {
         worker: False for worker in range(self.worker_num)
     }
     self.model, _unused_params = self.init_model(model)
     self.cached_model, _unused_params = self.init_model(model)
     self.model_influence = dict()
Пример #7
0
class Context:
    """Per-client logistic-regression models and helpers to score selections.

    ``build`` trains one model per id in ``self.database_clients``; the
    ``test_selection_*`` methods aggregate the stored stats of a chosen
    subset of clients and evaluate the result.
    """

    def __init__(self):
        """Create an empty context with the fixed client ids and the test data."""
        self.database_clients = [
            0, 1, 2, 3, 4, 5, 10, 11, 12, 13, 20, 21, 22, 23, 30, 31, 32, 33, 40, 41, 42, 43, 50, 51, 52, 53, 60, 61,
            62, 63, 70, 71, 72, 73, 80, 81, 82, 83, 90, 91, 92, 93
        ]
        # Per-client results, filled in by build() and keyed by client id.
        self.model_stats = {}
        self.models = {}
        self.sample_dict = {}
        self.data_dict = {}
        self.test_data = SQLDataProvider(args()).cache(20)

    def build(self):
        """Train one model per client id and record its data, stats and size.

        Prints the result of ``infer`` on the test data after each model.
        """
        # Iterate the ids stored on the instance by __init__; a bare
        # ``database_clients`` is not resolvable from inside a method unless
        # an unrelated module-level global happens to exist.
        for client_idx in self.database_clients:
            data = SQLDataProvider(args()).cache(client_idx)
            model = LogisticRegression(28 * 28, 10)
            trained = train(model, data.batch(8))
            self.data_dict[client_idx] = data
            self.model_stats[client_idx] = trained
            self.models[client_idx] = model
            self.sample_dict[client_idx] = len(data)
            print("model accuracy:", infer(model, self.test_data.batch(8)))

    def test_selection_accuracy(self, client_idx, title='test accuracy'):
        """Aggregate the selected clients' stats and evaluate the result.

        :param client_idx: selection passed to ``dict_select`` to choose the
            stats and sample counts that are aggregated.
        :param title: text printed inside the banner line.
        :return: the value returned by ``infer``; element 0 is printed as
            accuracy and element 1 as loss.
        """
        print('-----------------' + title + '-----------------')
        global_model_stats = aggregate(dict_select(client_idx, self.model_stats),
                                       dict_select(client_idx, self.sample_dict))
        global_model = LogisticRegression(28 * 28, 10)
        load(global_model, global_model_stats)
        print("test case:", client_idx)
        acc_loss = infer(global_model, self.test_data.batch(8))
        print("global model accuracy:", acc_loss[0])
        print("global model loss:", acc_loss[1])
        return acc_loss

    def test_selection_fitness(self, client_idx, title='test_fitness'):
        """Return the variance of the normalized per-client influences.

        Each selected client's stats are compared with the aggregate of the
        whole selection through ``influence_ecl``.

        :param client_idx: iterable of client ids forming the selection.
        :param title: accepted but not used by this method.
        :raises statistics.StatisticsError: when fewer than two influences
            are available, as raised by ``statistics.variance``.
        """
        aggregated = aggregate(dict_select(client_idx, self.model_stats), dict_select(client_idx, self.sample_dict))
        influences = [
            influence_ecl(aggregated, self.model_stats[key])
            for key in client_idx
        ]
        return statistics.variance(normalize(influences))
Пример #8
0
class Context:
    """Per-client logistic-regression models with clustering and scoring helpers."""

    def __init__(self, all_clients):
        """Create an empty context for ``all_clients`` and load the test data."""
        self.all_clients = all_clients
        self.args = args()
        # Per-client results, filled in by build() and keyed by client id.
        self.model_stats = dict()
        self.models = dict()
        self.sample_dict = dict()
        self.data_dict = dict()
        test_provider = SQLDataProvider(args())
        self.test_data = test_provider.cache(100)

    def build(self, test_models=False, round_idx=0):
        """Train one model per client in ``self.all_clients``.

        :param test_models: when true, print the result of ``tools.infer``
            on the test data after each model is trained.
        :param round_idx: accepted but not used by this method.
        """
        print("Building Models --Started")
        for client_idx in self.all_clients:
            client_data = SQLDataProvider(args()).cache(client_idx)
            client_model = LogisticRegression(28 * 28, 10)
            self.model_stats[client_idx] = tools.train(client_model,
                                                       client_data.batch(8))
            self.models[client_idx] = client_model
            self.data_dict[client_idx] = client_data
            self.sample_dict[client_idx] = len(client_data)
            if not test_models:
                continue
            evaluation = tools.infer(client_model, self.test_data.batch(8))
            print("model accuracy:", evaluation)
        print("Building Models --Finished")

    def cluster(self, cluster_size=10):
        """Cluster the clients with KMeans on their flattened 'linear.weight'.

        :param cluster_size: number of clusters requested from ``KMeans``.
        :return: dict mapping each client id to its cluster label.
        """
        print("Clustering Models --Started")
        client_ids = list(self.model_stats)
        # One flattened weight vector per client, in the same order as client_ids.
        weights = [
            stats['linear.weight'].numpy().flatten()
            for stats in self.model_stats.values()
        ]
        fitted = KMeans(n_clusters=cluster_size).fit(weights)
        clustered = dict(zip(client_ids, fitted.labels_))
        print("Clustering Models --Finished")
        return clustered

    def test_selection_accuracy(self, client_idx, title='test accuracy', output=True):
        """Aggregate the selected clients' stats and evaluate the result.

        :param client_idx: selection passed to ``tools.dict_select``.
        :param title: text printed inside the banner line.
        :param output: when true, also print the selection and its scores.
        :return: the value returned by ``tools.infer``.

        NOTE(review): the banner is printed even when ``output`` is false --
        confirm that this is intended.
        """
        rule = '-----------------'
        print(rule + title + rule)
        selected_stats = tools.dict_select(client_idx, self.model_stats)
        selected_sizes = tools.dict_select(client_idx, self.sample_dict)
        global_model_stats = tools.aggregate(selected_stats, selected_sizes)
        global_model = LogisticRegression(28 * 28, 10)
        tools.load(global_model, global_model_stats)
        acc_loss = tools.infer(global_model, self.test_data.batch(8))
        if not output:
            return acc_loss
        print("test case:", client_idx)
        print("global model accuracy:", acc_loss[0], 'loss:', acc_loss[1])
        return acc_loss

    def test_selection_fitness(self, client_idx, title='test_fitness', output=True):
        """Return the scaled variance of the normalized per-client influences.

        :param client_idx: iterable of client ids forming the selection.
        :param title: accepted but not used by this method.
        :param output: when true, print the selection and its fitness.
        :return: ``statistics.variance`` of the normalized influences,
            multiplied by 10 ** 5.
        """
        selected_stats = tools.dict_select(client_idx, self.model_stats)
        selected_sizes = tools.dict_select(client_idx, self.sample_dict)
        aggregated = tools.aggregate(selected_stats, selected_sizes)
        influences = [
            tools.influence_ecl(aggregated, self.model_stats[key])
            for key in client_idx
        ]
        variance = statistics.variance(tools.normalize(influences))
        fitness = variance * 10 ** 5
        if not output:
            return fitness
        print("test case:", client_idx)
        print("selection fitness:", fitness)
        return fitness

    def fitness(self, client_idx):
        """Return the selection fitness without printing anything."""
        return self.test_selection_fitness(client_idx, output=False)
Пример #9
0
class FedAVGTrainer(object):
    """Client-side trainer: runs local epochs on one client's data.

    The data comes from an ``SQLDataProvider``; the model, loss and
    optimizer are kept on the instance and reused across rounds.
    """

    def __init__(self, client_index, device, model, args):
        """Load the client's data and prepare model, loss and optimizer.

        :param client_index: id of the client whose data is loaded first.
        :param device: device the model and the loss are moved to.
        :param model: model trained in place by ``train``.
        :param args: settings object; ``batch_size``, ``lr``,
            ``client_optimizer``, ``wd``, ``epochs`` and ``is_mobile`` are
            read by this class.
        """
        self.provider = SQLDataProvider(args)
        self.client_index = client_index
        self.host = client_index
        self.batch_size = args.batch_size
        self.used_x = []
        self.round = 0
        # Store batches, exactly as update_dataset() does: train() iterates
        # train_local and unpacks every item as an (x, labels) batch.
        self.train_local = self.provider.cache(client_index,
                                               False).batch(self.batch_size)
        self.local_sample_number = self.provider.size()
        self.device = device
        self.args = args
        self.model = model
        self.model.to(self.device)
        self.served_client = []
        self.criterion = nn.CrossEntropyLoss().to(self.device)
        if self.args.client_optimizer == "sgd":
            self.optimizer = torch.optim.SGD(self.model.parameters(),
                                             lr=self.args.lr)
        else:
            # Adam only receives the parameters that require gradients.
            self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                     self.model.parameters()),
                                              lr=self.args.lr,
                                              weight_decay=self.args.wd,
                                              amsgrad=True)

    def update_model(self, weights):
        """Load ``weights`` (a state dict) into the local model."""
        logging.info("update_model. client_index = %d", self.client_index)
        self.model.load_state_dict(weights)

    def update_dataset(self, client_index):
        """Switch to ``client_index``: reload its batches and sample count."""
        self.client_index = client_index
        self.served_client.append(client_index)
        logging.info("update_dataset. client_index = %d", self.client_index)
        self.train_local = self.provider.cache(client_index,
                                               False).batch(self.batch_size)
        self.local_sample_number = self.provider.size()

    def update_round(self, round):
        """Record the current communication round."""
        self.round = round

    def train(self):
        """Run ``args.epochs`` local epochs and return the trained weights.

        :return: ``(weights, local_sample_number)`` where ``weights`` is the
            model's state dict after moving the model to the CPU, converted
            with ``transform_tensor_to_list`` when ``args.is_mobile == 1``.
        """
        self.model.to(self.device)
        # change to train mode
        self.model.train()

        epoch_loss = []
        for epoch in range(self.args.epochs):
            batch_loss = []
            for x, labels in self.train_local:
                x, labels = x.to(self.device), labels.to(self.device)
                self.optimizer.zero_grad()
                log_probs = self.model(x)
                loss = self.criterion(log_probs, labels)
                loss.backward()
                self.optimizer.step()
                batch_loss.append(loss.item())
            # An epoch without batches contributes nothing and is not logged.
            if batch_loss:
                epoch_loss.append(sum(batch_loss) / len(batch_loss))
                logging.info('(client {}. Local Training Epoch: {} '
                             '\tLoss: {:.6f}'.format(
                                 self.client_index, epoch,
                                 sum(epoch_loss) / len(epoch_loss)))

        weights = self.model.cpu().state_dict()
        # transform Tensor to list
        if self.args.is_mobile == 1:
            weights = transform_tensor_to_list(weights)
        return weights, self.local_sample_number
Пример #10
0
        for i in range(m):
            if data.y[i] == item[i]:
                similar_values += 1
        sim.append(round(similar_values / m, 2))
    return sim


def heatmap(dct):
    """Print one row of ``similarities`` per value stored in ``dct``.

    Each row is ``similarities(value.y, dct)``; the rows are collected in
    the iteration order of ``dct`` and printed as a single list.
    """
    rows = [similarities(entry.y, dct) for entry in dct.values()]
    print(rows)


# Load the shared evaluation data once and print its ``y`` attribute.
test_data = SQLDataProvider(args()).cache(99999)
print("test data:", test_data.y)
# Train one logistic-regression model per id in database_clients and record
# the cached data, the value returned by train(), the model and the data
# length in the module-level dictionaries, keyed by client id.
# NOTE(review): index, data, model and trained remain module-level names
# after the loop finishes.
for index, client_idx in enumerate(database_clients):
    data = SQLDataProvider(args()).cache(client_idx)
    model = LogisticRegression(28 * 28, 10)
    trained = train(model, data.batch(8))
    data_dict[client_idx] = data
    model_stats[client_idx] = trained
    models[client_idx] = model
    sample_dict[client_idx] = len(data)
    # Print the result of infer() for the freshly trained model on the test data.
    print("model accuracy:", infer(model, test_data.batch(8)))


def test_a_case(test_case, title='start evaluation'):
    """Print a banner with ``title`` and aggregate the stats picked by ``test_case``.

    NOTE(review): the visible definition ends right after the aggregation
    and the result is not used or returned here -- the snippet looks
    truncated, confirm against the full file.
    """
    print('-----------------' + title + '-----------------')
    # Pass test_case to dict_select for both the module-level model_stats
    # and sample_dict, then aggregate the two selections.
    global_model_stats = aggregate(dict_select(test_case, model_stats), dict_select(test_case, sample_dict))
Пример #11
0
class FedAVGAggregator(object):
    """Server-side aggregator with leave-one-out influence diagnostics.

    Collects the state dicts uploaded by workers, averages them weighted by
    their sample counts, evaluates models on data served by an
    ``SQLDataProvider`` and measures how the aggregate changes when a single
    upload is left out.
    """

    def __init__(self, worker_num, device, model, args):
        """Set up the bookkeeping for ``worker_num`` workers.

        :param worker_num: number of workers whose uploads are awaited.
        :param device: device models are moved to for inference.
        :param model: the global model.
        :param args: settings object; ``batch_size``,
            ``frequency_of_the_test``, ``comm_round``,
            ``client_num_in_total``, ``ci`` and ``influence_test_clients``
            are read by this class.
        """
        self.provider = SQLDataProvider(args)
        self.batch_size = args.batch_size
        self.worker_num = worker_num
        self.device = device
        self.args = args
        self.model_dict = dict()
        self.sample_num_dict = dict()
        # No worker has uploaded a model yet.
        self.flag_client_model_uploaded_dict = {
            idx: False for idx in range(self.worker_num)
        }
        # init_model returns the object it is given, so ``model`` and
        # ``cached_model`` refer to the same instance.
        self.model, _ = self.init_model(model)
        self.cached_model, _ = self.init_model(model)
        self.model_influence = {}

    def init_model(self, model):
        """Return ``(model, model.state_dict())`` without copying the model."""
        model_params = model.state_dict()
        return model, model_params

    def get_global_model_params(self):
        """Return the state dict of the global model."""
        return self.model.state_dict()

    def add_local_trained_result(self, index, model_params, sample_num):
        """Store a worker's uploaded parameters and mark it as received."""
        logging.info("add_model. index = %d", index)
        self.model_dict[index] = model_params
        self.sample_num_dict[index] = sample_num
        self.flag_client_model_uploaded_dict[index] = True

    def check_whether_all_receive(self):
        """Return True once every worker has uploaded, then clear the flags.

        The flags are left untouched while at least one upload is missing.
        """
        workers = range(self.worker_num)
        if not all(self.flag_client_model_uploaded_dict[idx] for idx in workers):
            return False
        for idx in workers:
            self.flag_client_model_uploaded_dict[idx] = False
        return True

    def aggregate_models(self, models_dict: dict):
        """Return the sample-weighted average of the given state dicts.

        :param models_dict: maps worker index to that worker's state dict;
            the weights come from ``self.sample_num_dict``.
        :return: a new mapping with one averaged tensor per parameter key.
            The state dicts in ``models_dict`` are not modified, so the same
            uploads can be aggregated repeatedly (leave-one-out).
        :raises IndexError: when ``models_dict`` is empty.
        """
        start_time = time.time()
        model_list = [(self.sample_num_dict[idx], models_dict[idx])
                      for idx in models_dict]
        training_num = sum(sample_num for sample_num, _ in model_list)

        logging.info("len of self.model_dict[idx] = " +
                     str(len(self.model_dict)))

        first_params = model_list[0][1]
        # Shallow copy: the container is new (same mapping type), so
        # replacing its entries leaves the first upload untouched.
        averaged_params = copy.copy(first_params)
        for k in first_params.keys():
            accumulated = None
            for local_sample_number, local_model_params in model_list:
                w = local_sample_number / training_num
                weighted = local_model_params[k] * w
                if accumulated is None:
                    accumulated = weighted
                else:
                    # ``accumulated`` is a fresh tensor, in-place add is safe.
                    accumulated += weighted
            averaged_params[k] = accumulated

        end_time = time.time()
        logging.info("aggregate time cost: %d", end_time - start_time)
        return averaged_params

    def aggregate(self):
        """Average all uploaded models and load the result into the global model."""
        averaged_params = self.aggregate_models(self.model_dict)
        self.model.load_state_dict(averaged_params)
        return averaged_params

    def client_sampling(self, round_idx, client_num_in_total,
                        client_num_per_round):
        """Pick the clients for a round, reproducibly per ``round_idx``.

        Seeds NumPy's global random state with ``round_idx`` so every
        comparison selects the same clients in the same round.
        """
        num_clients = min(client_num_per_round, client_num_in_total)
        np.random.seed(round_idx)
        client_indexes = np.random.choice(range(client_num_in_total),
                                          num_clients,
                                          replace=False)
        print(client_indexes)
        logging.info("client_indexes = %s" % str(client_indexes))
        return client_indexes

    def test_on_all_clients(self, round_idx):
        """Evaluate the global model on all clients (result is discarded)."""
        self.test_model_on_all_clients(self.model, round_idx)

    @staticmethod
    def _summarize_results(results, acc_key, loss_key):
        """Fold ``(correct, samples, loss)`` tuples into an accuracy/loss dict."""
        total_correct = sum(result[0] for result in results)
        total_samples = sum(result[1] for result in results)
        total_loss = sum(result[2] for result in results)
        return {
            acc_key: total_correct / total_samples,
            loss_key: total_loss / total_samples
        }

    def test_model_on_all_clients(self, model, round_idx):
        """Evaluate ``model`` on every client's data on evaluation rounds.

        Evaluation runs when ``round_idx`` is a multiple of
        ``args.frequency_of_the_test`` or is the last communication round.

        :return: ``(train_stats, test_stats)`` dictionaries, or ``None`` on
            rounds that are not evaluated.
        :raises ZeroDivisionError: when no samples were evaluated.
        """
        is_evaluation_round = (
            round_idx % self.args.frequency_of_the_test == 0
            or round_idx == self.args.comm_round - 1)
        if not is_evaluation_round:
            return None

        logging.info(
            "################local_test_on_all_clients : {}".format(
                round_idx))
        train_results = []
        test_results = []
        for client_idx in range(self.args.client_num_in_total):
            # train data
            train_data = self.provider.cache(client_idx).batch(
                self.batch_size)
            train_results.append(self._infer_model(model, train_data))

            # test data
            test_data = self.provider.cache(client_idx,
                                            True).batch(self.batch_size)
            test_results.append(self._infer_model(model, test_data))

            # Note: CI environment is CPU-based computing. The training speed
            # for RNN training is too slow in this setting, so we only test a
            # client to make sure there is no programming error.
            if self.args.ci == 1:
                break

        # test on training dataset
        train_stats = self._summarize_results(train_results, 'training_acc',
                                              'training_loss')
        logging.info(train_stats)

        # test on test dataset
        test_stats = self._summarize_results(test_results, 'test_acc',
                                             'test_loss')
        logging.info(test_stats)
        return train_stats, test_stats

    def _evaluate_stats(self, model, round_idx):
        """Return ``(train_stats, test_stats)``, both None on skipped rounds."""
        stats = self.test_model_on_all_clients(model, round_idx)
        if stats is None:
            return None, None
        return stats

    def test_on_all_sub_models(self, round_idx: int):
        """Record the influence of every upload by leaving it out once.

        The global model is evaluated first; then, for each uploaded model,
        the remaining uploads are aggregated into a sub-model whose stats
        and influence measures are printed and stored in
        ``self.model_influence["round_<round_idx>"]``.

        On rounds that ``test_model_on_all_clients`` does not evaluate, the
        stats are recorded as ``None`` instead of failing.

        :return: always the empty string.
        """
        print("####round: " + str(round_idx) + "####")
        round_key = "round_" + str(round_idx)
        train_stats, test_stats = self._evaluate_stats(self.model, round_idx)
        print("round", round_idx)
        print("global_model_train", train_stats)
        print("global_model_test", test_stats)
        self.model_influence[round_key] = {}
        self.model_influence[round_key]["."] = {
            "train_stats": train_stats,
            "test_stats": test_stats
        }
        print("model[.]:", train_stats)
        for idx in self.model_dict:
            remaining = {
                key: params
                for key, params in self.model_dict.items() if key != idx
            }
            if not remaining:
                # A single upload leaves nothing to aggregate once removed.
                continue
            model_params = self.aggregate_models(remaining)
            sub_model = LogisticRegression(28 * 28, 10)
            sub_model.load_state_dict(model_params)
            train_stats, test_stats = self._evaluate_stats(
                sub_model, round_idx)
            print("model[" + str(idx) + "]:", train_stats)
            print("model[" + str(idx) + "].train", train_stats)
            print("model[" + str(idx) + "].test", test_stats)
            influence = self._influence(sub_model)
            print("model[" + str(idx) + "].influence", influence)
            influence_no_real, influence_both, _ = influence
            influence_ecl = self._influence_ecl(sub_model)
            self.model_influence[round_key][str(idx)] = {
                "train_stats": train_stats,
                "influence_no_real": influence_no_real,
                "influence_real": influence_both,
                "influence_ecl": influence_ecl.numpy()
            }
        print("####end of round: " + str(round_idx) + "####")
        return ""

    def _influence_ecl(self, model):
        """Return the L2 distance between the 'linear.weight' tensors of the
        global model and ``model``, as a CPU tensor."""
        original = self.model.state_dict()
        sub = model.state_dict()
        l2_norm = torch.dist(original["linear.weight"], sub["linear.weight"],
                             2)
        return l2_norm.cpu()

    # noinspection PyUnresolvedReferences
    def _influence(self, model):
        """Average three prediction-difference measures between ``model``
        and the global model over the clients that have data.

        At most ``args.influence_test_clients`` clients are evaluated when
        that value is positive; otherwise all clients are used. Clients
        whose data is empty are skipped and do not count.

        :return: ``(no_labels, correct_in_both, correct_in_original)`` as
            floats, averaged over the clients actually evaluated; all zeros
            when no client was evaluated.
        """
        client_limit = self.args.influence_test_clients
        measures = (
            # ignores the real labels
            self._influence_function_no_labels,
            # only samples both models label correctly
            self._influence_function_only_correct_labels_both,
            # only samples the global model labels correctly
            self._influence_function_only_correct_labels_original,
        )
        totals = [torch.tensor(0, dtype=torch.float) for _ in measures]
        evaluated_clients = 0
        for client_idx in range(self.args.client_num_in_total):
            # Same cache flag that test_model_on_all_clients uses for the
            # data it names "test data".
            data = self.provider.cache(client_idx, True)
            if len(data) == 0:
                continue
            batches = data.batch(self.batch_size)

            deletion_prediction = self.predict(model, batches)
            original_prediction = self.predict(self.model, batches)
            deletion_labels = self._predictions_to_label(deletion_prediction)
            original_labels = self._predictions_to_label(original_prediction)
            real_labels = data.y

            for slot, measure in enumerate(measures):
                totals[slot] += measure(deletion_prediction,
                                        original_prediction, deletion_labels,
                                        original_labels, real_labels)

            evaluated_clients += 1
            if 0 < client_limit <= evaluated_clients:
                break

        if evaluated_clients == 0:
            return 0.0, 0.0, 0.0
        # Divide by the number of clients that contributed, not by the
        # configured limit, which may be zero ("no limit") or larger than
        # the number of clients that actually had data.
        return tuple(total.item() / evaluated_clients for total in totals)

    def _influence_of_predictions(self, left, right):
        """Return the mean over samples of the mean absolute difference
        between matching rows of ``left`` and ``right`` (0 when empty)."""
        count = len(left)
        if count == 0:
            return torch.tensor(0, dtype=torch.float)

        per_sample = torch.abs(left - right).reshape(count, -1).mean(dim=1)
        return (per_sample.sum() / count).to(torch.float)

    def _conditional_influence_function(self, deletion_prediction,
                                        original_prediction, deletion_labels,
                                        original_labels, real_labels,
                                        condition):
        """Compute the prediction influence over the samples for which
        ``condition(deletion_label, original_label, real_label)`` holds."""
        selected = [
            index for index, _ in enumerate(real_labels)
            if condition(deletion_labels[index], original_labels[index],
                         real_labels[index])
        ]
        if not selected:
            empty = torch.tensor([])
            return self._influence_of_predictions(empty, empty)
        new_deletion_predictions = torch.stack(
            [deletion_prediction[index] for index in selected])
        new_original_predictions = torch.stack(
            [original_prediction[index] for index in selected])
        return self._influence_of_predictions(new_deletion_predictions,
                                              new_original_predictions)

    def _influence_function_no_labels(self, deletion_prediction,
                                      original_prediction, deletion_labels,
                                      original_labels, real_labels):
        """Influence over all samples, regardless of the labels."""
        return self._conditional_influence_function(
            deletion_prediction, original_prediction, deletion_labels,
            original_labels, real_labels, lambda d, o, r: True)

    def _influence_function_only_correct_labels_both(self, deletion_prediction,
                                                     original_prediction,
                                                     deletion_labels,
                                                     original_labels,
                                                     real_labels):
        """Influence over the samples both models label correctly."""
        return self._conditional_influence_function(
            deletion_prediction, original_prediction, deletion_labels,
            original_labels, real_labels, lambda d, o, r: d == r == o)

    def _influence_function_only_correct_labels_original(
            self, deletion_prediction, original_prediction, deletion_labels,
            original_labels, real_labels):
        """Influence over the samples the global model labels correctly."""
        return self._conditional_influence_function(
            deletion_prediction, original_prediction, deletion_labels,
            original_labels, real_labels, lambda d, o, r: o == r)

    def _predictions_to_label(self, predictions):
        """Return, as a float tensor, the argmax index of every prediction row."""
        count = len(predictions)
        if count == 0:
            return torch.tensor([])
        flattened = predictions.reshape(count, -1)
        return torch.argmax(flattened, dim=1).float()

    def _infer(self, test_data):
        """Evaluate the global model on ``test_data``."""
        return self._infer_model(self.model, test_data)

    def predict(self, model, data):
        """Return the concatenated outputs of ``model`` for all batches.

        :param data: iterable of ``(x, target)`` batches; targets are ignored.
        :return: CPU tensor with the batch outputs stacked along dimension 0,
            or an empty tensor when ``data`` yields no batches.
        """
        model.eval()
        model.to(self.device)
        outputs = []
        with torch.no_grad():
            for x, _target in data:
                outputs.append(model(x.to(self.device)))
        if not outputs:
            return torch.tensor([])
        return torch.cat(outputs).cpu()

    def _infer_model(self, model, test_data):
        """Evaluate ``model`` on batches of ``(x, target)``.

        :return: ``(correct, total, loss_sum)`` as floats, where
            ``loss_sum`` is the cross-entropy loss of each batch multiplied
            by the batch size and summed.
        """
        model.eval()
        model.to(self.device)

        test_loss = test_acc = test_total = 0.
        criterion = nn.CrossEntropyLoss().to(self.device)
        with torch.no_grad():
            for x, target in test_data:
                x = x.to(self.device)
                target = target.to(self.device)
                pred = model(x)
                loss = criterion(pred, target)
                _, predicted = torch.max(pred, -1)
                correct = predicted.eq(target).sum()

                test_acc += correct.item()
                test_loss += loss.item() * target.size(0)
                test_total += target.size(0)

        return test_acc, test_total, test_loss