Пример #1
0
    def train(self):
        """Run personalized (Ditto-style) federated training under attack.

        Every client ``idx`` keeps a personalized model
        ``self.local_models[idx]`` trained with SGD plus a proximal pull
        towards ``self.global_model`` (weight ``self.lam``, or chosen per step
        from ``[0.1, 1, 2]`` by validation loss when ``self.dynamic_lam`` is
        set). In the same local iterations a copy of the global model is
        trained with plain SGD, and the per-client deltas are aggregated into
        ``self.global_model``.

        ``self.num_corrupted`` distinct clients (drawn once with seed
        ``1234567 + self.seed``) get corrupted training labels and, depending
        on ``self.boosting`` / ``self.random_updates``, send scaled or random
        updates. Aggregation uses ``self.aggregate`` when ``self.q != 0``,
        otherwise the rule selected by the ``median`` / ``k_norm`` / ``krum``
        / ``mkrum`` flags (plain averaging by default).

        Metrics are written to stdout every ``self.eval_every`` rounds.
        """
        print('---{} workers per communication round---'.format(
            self.clients_per_round))

        np.random.seed(1234567 + self.seed)
        corrupt_id = np.random.choice(range(len(self.clients)),
                                      size=self.num_corrupted,
                                      replace=False)
        print(corrupt_id)
        # the corrupted set is fixed for the whole run, so compute its
        # complement once instead of at every evaluation
        non_corrupt_id = np.setdiff1d(range(len(self.clients)), corrupt_id)

        if self.dataset == 'shakespeare':
            for c in self.clients:
                c.train_data['y'], c.train_data['x'] = process_y(
                    c.train_data['y']), process_x(c.train_data['x'])
                c.test_data['y'], c.test_data['x'] = process_y(
                    c.test_data['y']), process_x(c.test_data['x'])

        # The round loop below runs for i in 0..num_rounds inclusive, i.e.
        # num_rounds + 1 times, and a selected client draws local_iters
        # batches per round. Budget for all of them so a client selected in
        # every round (e.g. full participation) cannot exhaust its batch
        # stream (next() would raise StopIteration).
        num_batches = (self.num_rounds + 1) * self.local_iters

        batches = {}
        for idx, c in enumerate(self.clients):
            if idx in corrupt_id:
                c.train_data['y'] = np.asarray(c.train_data['y'])
                if self.dataset == 'celeba':
                    c.train_data['y'] = 1 - c.train_data['y']
                elif self.dataset == 'femnist':
                    c.train_data['y'] = np.random.randint(
                        0, 62, len(c.train_data['y']))  # [0, 62)
                elif self.dataset == 'shakespeare':
                    c.train_data['y'] = np.random.randint(
                        0, 80, len(c.train_data['y']))
                elif self.dataset == "vehicle":
                    c.train_data['y'] = c.train_data['y'] * -1
                elif self.dataset == "fmnist":
                    c.train_data['y'] = np.random.randint(
                        0, 10, len(c.train_data['y']))

            if self.dataset == 'celeba':
                # due to a different data storage format
                batches[c] = gen_batch_celeba(c.train_data, self.batch_size,
                                              num_batches)
            else:
                batches[c] = gen_batch(c.train_data, self.batch_size,
                                       num_batches)

        for i in range(self.num_rounds + 1):
            if i % self.eval_every == 0 and i > 0:
                # evaluate every client on its own personalized model
                tmp_models = [
                    self.local_models[idx] for idx in range(len(self.clients))
                ]

                num_train, num_correct_train, loss_vector = self.train_error(
                    tmp_models)
                avg_train_loss = np.dot(loss_vector,
                                        num_train) / np.sum(num_train)
                num_test, num_correct_test, _ = self.test(tmp_models)
                tqdm.write('At round {} training accu: {}, loss: {}'.format(
                    i,
                    np.sum(num_correct_train) * 1.0 / np.sum(num_train),
                    avg_train_loss))
                tqdm.write('At round {} test accu: {}'.format(
                    i,
                    np.sum(num_correct_test) * 1.0 / np.sum(num_test)))
                tqdm.write('At round {} malicious test accu: {}'.format(
                    i,
                    np.sum(num_correct_test[corrupt_id]) * 1.0 /
                    np.sum(num_test[corrupt_id])))
                tqdm.write('At round {} benign test accu: {}'.format(
                    i,
                    np.sum(num_correct_test[non_corrupt_id]) * 1.0 /
                    np.sum(num_test[non_corrupt_id])))
                print(
                    "variance of the performance: ",
                    np.var(num_correct_test[non_corrupt_id] /
                           num_test[non_corrupt_id]))

            # weighted sampling
            indices, selected_clients = self.select_clients(
                round=i, num_clients=self.clients_per_round)

            csolns = []
            losses = []

            for idx in indices:
                w_global_idx = copy.deepcopy(self.global_model)
                c = self.clients[idx]
                for _ in range(self.local_iters):
                    data_batch = next(batches[c])

                    # local (personalized) step
                    self.client_model.set_params(self.local_models[idx])
                    _, grads, _ = c.solve_sgd(data_batch)

                    if self.dynamic_lam:
                        model_tmp = copy.deepcopy(self.local_models[idx])
                        model_best = copy.deepcopy(self.local_models[idx])
                        # +inf so the best candidate is always accepted; the
                        # previous finite sentinel (10000) silently skipped
                        # the update whenever every validation loss was >= it
                        best_loss = float('inf')
                        # pick a lambda locally based on validation data
                        for candidate_lam in [0.1, 1, 2]:
                            for layer in range(len(grads[1])):
                                eff_grad = grads[1][layer] + candidate_lam * (
                                    self.local_models[idx][layer] -
                                    self.global_model[layer])
                                model_tmp[layer] = self.local_models[idx][
                                    layer] - self.learning_rate * eff_grad

                            c.set_params(model_tmp)
                            val_loss = c.get_val_loss()
                            if val_loss < best_loss:
                                best_loss = val_loss
                                model_best = copy.deepcopy(model_tmp)

                        # model_best is already a private deep copy
                        self.local_models[idx] = model_best

                    else:
                        for layer in range(len(grads[1])):
                            eff_grad = grads[1][layer] + self.lam * (
                                self.local_models[idx][layer] -
                                self.global_model[layer])
                            self.local_models[idx][layer] = self.local_models[
                                idx][layer] - self.learning_rate * eff_grad

                    # global step on this client's copy of the global model
                    self.client_model.set_params(w_global_idx)
                    loss = c.get_loss()
                    losses.append(loss)
                    _, grads, _ = c.solve_sgd(data_batch)
                    w_global_idx = self.client_model.get_params()

                # get the difference (global model updates)
                diff = [
                    u - v for (u, v) in zip(w_global_idx, self.global_model)
                ]

                # send the malicious updates
                if idx in corrupt_id:
                    if self.boosting:
                        # scale malicious updates
                        diff = [self.clients_per_round * u for u in diff]
                    elif self.random_updates:
                        # send random updates
                        stdev_ = get_stdev(diff)
                        diff = [
                            np.random.normal(0, stdev_, size=u.shape)
                            for u in diff
                        ]

                if self.q == 0:
                    csolns.append(diff)
                else:
                    # weight by exp(q * loss) of the client's last local step
                    csolns.append((np.exp(self.q * loss), diff))

            if self.q != 0:
                avg_updates = self.aggregate(csolns)
            else:
                if self.gradient_clipping:
                    csolns = l2_clip(csolns)

                expected_num_mali = int(self.clients_per_round *
                                        self.num_corrupted / len(self.clients))

                if self.median:
                    avg_updates = self.median_average(csolns)
                elif self.k_norm:
                    avg_updates = self.k_norm_average(
                        self.clients_per_round - expected_num_mali, csolns)
                elif self.krum:
                    avg_updates = self.krum_average(
                        self.clients_per_round - expected_num_mali - 2, csolns)
                elif self.mkrum:
                    m = self.clients_per_round - expected_num_mali
                    avg_updates = self.mkrum_average(
                        self.clients_per_round - expected_num_mali - 2, m,
                        csolns)
                else:
                    avg_updates = self.simple_average(csolns)

            # update the global model
            for layer in range(len(avg_updates)):
                self.global_model[layer] += avg_updates[layer]
Пример #2
0
    def train(self):
        """Run federated averaging under attack, then fine-tune per client.

        ``self.num_corrupted`` distinct clients (drawn once with seed
        ``1234567 + self.seed``) get corrupted training labels and may send
        boosted or random updates. Each round the selected clients train a
        copy of ``self.latest_model`` for ``self.local_iters`` SGD steps. Their
        deltas are aggregated (``self.aggregate`` when ``self.q > 0``,
        otherwise the robust rule chosen by the instance flags) and added to
        ``self.latest_model``.

        After training, every client fine-tunes the final global model with
        SGD plus a proximal term (weight ``self.lam``) towards it. The final
        per-client test accuracy is written to stdout.
        """
        print('Training with {} workers ---'.format(self.clients_per_round))

        np.random.seed(1234567 + self.seed)
        corrupt_id = np.random.choice(range(len(self.clients)),
                                      size=self.num_corrupted,
                                      replace=False)
        print(corrupt_id)
        # fixed for the whole run; also needed by the final report below
        non_corrupt_id = np.setdiff1d(range(len(self.clients)), corrupt_id)

        def finetune_steps(client):
            # number of fine-tuning SGD steps `client` takes after training
            return max(
                int(self.finetune_iters * client.train_samples /
                    self.batch_size), self.finetune_iters)

        batches = {}
        for idx, c in enumerate(self.clients):
            if idx in corrupt_id:
                c.train_data['y'] = np.asarray(c.train_data['y'])
                if self.dataset == 'celeba':
                    c.train_data['y'] = 1 - c.train_data['y']
                elif self.dataset == 'femnist':
                    c.train_data['y'] = np.random.randint(
                        0, 62, len(c.train_data['y']))  # [0, 62)
                elif self.dataset == 'fmnist':  # fashion mnist
                    c.train_data['y'] = np.random.randint(
                        0, 10, len(c.train_data['y']))

            # One stream serves both phases: up to num_rounds + 1 rounds of
            # local_iters batches (rounds run for i in 0..num_rounds
            # inclusive), then finetune_steps(c) fine-tuning batches. The old
            # fixed slack of 350 could be exhausted by fine-tuning on large
            # clients, so budget for the real demand (never less than before).
            num_batches = ((self.num_rounds + 1) * self.local_iters +
                           max(350, finetune_steps(c)))
            if self.dataset == 'celeba':
                batches[c] = gen_batch_celeba(c.train_data, self.batch_size,
                                              num_batches)
            else:
                batches[c] = gen_batch(c.train_data, self.batch_size,
                                       num_batches)

        initialization = copy.deepcopy(self.clients[0].get_params())

        for i in range(self.num_rounds + 1):
            if i % self.eval_every == 0:
                # have set the latest model for all clients
                num_test, num_correct_test, _ = self.test()
                num_train, num_correct_train, loss_vector = self.train_error()

                avg_loss = np.dot(loss_vector, num_train) / np.sum(num_train)

                tqdm.write('At round {} training accu: {}, loss: {}'.format(
                    i,
                    np.sum(num_correct_train) * 1.0 / np.sum(num_train),
                    avg_loss))
                tqdm.write('At round {} test accu: {}'.format(
                    i,
                    np.sum(num_correct_test) * 1.0 / np.sum(num_test)))
                tqdm.write('At round {} malicious test accu: {}'.format(
                    i,
                    np.sum(num_correct_test[corrupt_id]) * 1.0 /
                    np.sum(num_test[corrupt_id])))
                tqdm.write('At round {} benign test accu: {}'.format(
                    i,
                    np.sum(num_correct_test[non_corrupt_id]) * 1.0 /
                    np.sum(num_test[non_corrupt_id])))
                print(
                    "variance of the performance: ",
                    np.var(num_correct_test[non_corrupt_id] /
                           num_test[non_corrupt_id]))

            indices, selected_clients = self.select_clients(
                round=i,
                corrupt_id=corrupt_id,
                num_clients=self.clients_per_round)

            csolns = []
            losses = []

            for idx in indices:
                c = self.clients[idx]

                # communicate the latest model
                c.set_params(self.latest_model)
                weights_before = copy.deepcopy(self.latest_model)
                loss = c.get_loss()  # compute loss on the whole training data
                losses.append(loss)

                for _ in range(self.local_iters):
                    data_batch = next(batches[c])
                    c.solve_sgd(data_batch)

                new_weights = c.get_params()

                grads = [(u - v) * 1.0
                         for u, v in zip(new_weights, weights_before)]

                if idx in corrupt_id:
                    if self.boosting:  # model replacement
                        grads = [self.clients_per_round * u for u in grads]
                    elif self.random_updates:
                        # send random updates
                        stdev_ = get_stdev(grads)
                        grads = [
                            np.random.normal(0, stdev_, size=u.shape)
                            for u in grads
                        ]

                if self.q > 0:
                    csolns.append((np.exp(self.q * loss), grads))
                else:
                    csolns.append(grads)

            if self.q > 0:
                overall_updates = self.aggregate(csolns)
            else:
                if self.gradient_clipping:
                    csolns = l2_clip(csolns)

                expected_num_mali = int(self.clients_per_round *
                                        self.num_corrupted / len(self.clients))
                if self.median:
                    overall_updates = self.median_average(csolns)
                elif self.k_norm:
                    overall_updates = self.k_norm_average(
                        self.clients_per_round - expected_num_mali, csolns)
                elif self.k_loss:
                    overall_updates = self.k_loss_average(
                        self.clients_per_round - expected_num_mali, losses,
                        csolns)
                elif self.krum:
                    overall_updates = self.krum_average(
                        self.clients_per_round - expected_num_mali - 2, csolns)
                elif self.mkrum:
                    m = self.clients_per_round - expected_num_mali
                    overall_updates = self.mkrum_average(
                        self.clients_per_round - expected_num_mali - 2, m,
                        csolns)
                else:
                    overall_updates = self.simple_average(csolns)

            self.latest_model = [
                (u + v) for u, v in zip(self.latest_model, overall_updates)
            ]

            # the distance is only reported on eval rounds, so only compute
            # the (full-model) norm there
            if i % self.eval_every == 0:
                distance = np.linalg.norm(
                    process_grad(self.latest_model) -
                    process_grad(initialization))
                print('distance to initialization:', distance)

        # local finetuning
        init_model = copy.deepcopy(self.latest_model)

        after_test_accu = []
        test_samples = []
        for c in self.clients:
            c.set_params(init_model)
            local_model = copy.deepcopy(init_model)
            for _ in range(finetune_steps(c)):
                c.set_params(local_model)
                data_batch = next(batches[c])
                _, grads, _ = c.solve_sgd(data_batch)
                # proximal SGD step pulling the personal model towards the
                # final global model
                for j in range(len(grads[1])):
                    eff_grad = grads[1][j] + self.lam * (local_model[j] -
                                                         init_model[j])
                    local_model[j] = local_model[
                        j] - self.learning_rate * self.decay_factor * eff_grad
            c.set_params(local_model)
            tc, _, num_test = c.test()
            after_test_accu.append(tc)
            test_samples.append(num_test)

        after_test_accu = np.asarray(after_test_accu)
        test_samples = np.asarray(test_samples)
        tqdm.write('final test accu: {}'.format(
            np.sum(after_test_accu) * 1.0 / np.sum(test_samples)))
        tqdm.write('final malicious test accu: {}'.format(
            np.sum(after_test_accu[corrupt_id]) * 1.0 /
            np.sum(test_samples[corrupt_id])))
        tqdm.write('final benign test accu: {}'.format(
            np.sum(after_test_accu[non_corrupt_id]) * 1.0 /
            np.sum(test_samples[non_corrupt_id])))
        print(
            "variance of the performance: ",
            np.var(after_test_accu[non_corrupt_id] /
                   test_samples[non_corrupt_id]))
Пример #3
0
    def train(self):
        """Run federated training of ``self.latest_model`` under attack.

        ``self.num_corrupted`` distinct clients (drawn once with seed
        ``1234567 + self.seed``) get corrupted training labels and may send
        boosted or random updates. Each round the selected clients run
        ``self.local_iters`` SGD steps from ``self.latest_model``. Their deltas
        are combined with ``self.aggregate`` when ``self.q != 0``, otherwise
        with the rule selected by the ``median`` / ``k_norm`` / ``k_loss`` /
        ``krum`` / ``mkrum`` / ``fedmgda`` flags (plain averaging by default).
        The combined update is then added to ``self.latest_model``.

        Metrics are written to stdout every ``self.eval_every`` rounds.
        """
        print('Training with {} workers ---'.format(self.clients_per_round))

        np.random.seed(1234567 + self.seed)
        corrupt_id = np.random.choice(range(len(self.clients)),
                                      size=self.num_corrupted,
                                      replace=False)
        print(corrupt_id)
        # fixed for the whole run, so compute it once
        non_corrupt_id = np.setdiff1d(range(len(self.clients)), corrupt_id)

        if self.dataset == 'shakespeare':
            for c in self.clients:
                c.train_data['y'], c.train_data['x'] = process_y(
                    c.train_data['y']), process_x(c.train_data['x'])
                c.test_data['y'], c.test_data['x'] = process_y(
                    c.test_data['y']), process_x(c.test_data['x'])

        # Rounds run for i in 0..num_rounds inclusive (num_rounds + 1 rounds)
        # and a selected client draws local_iters batches per round; budget
        # for all of them so a client selected every round cannot exhaust its
        # stream (next() would raise StopIteration).
        num_batches = (self.num_rounds + 1) * self.local_iters

        batches = {}

        for idx, c in enumerate(self.clients):
            if idx in corrupt_id:
                c.train_data['y'] = np.asarray(c.train_data['y'])
                if self.dataset == 'celeba':
                    c.train_data['y'] = 1 - c.train_data['y']
                elif self.dataset == 'femnist':
                    c.train_data['y'] = np.random.randint(
                        0, 62, len(c.train_data['y']))
                elif self.dataset == 'shakespeare':
                    c.train_data['y'] = np.random.randint(
                        0, 80, len(c.train_data['y']))
                elif self.dataset == "vehicle":
                    c.train_data['y'] = c.train_data['y'] * -1
                elif self.dataset == 'fmnist':
                    c.train_data['y'] = np.random.randint(
                        0, 10, len(c.train_data['y']))

            # need to deal with celeba data loading a bit differently
            if self.dataset == 'celeba':
                batches[c] = gen_batch_celeba(c.train_data, self.batch_size,
                                              num_batches)
            else:
                batches[c] = gen_batch(c.train_data, self.batch_size,
                                       num_batches)

        for i in range(self.num_rounds + 1):
            if i % self.eval_every == 0:
                # have set the latest model for all clients
                num_test, num_correct_test, test_loss_vector = self.test()
                avg_test_loss = np.dot(test_loss_vector,
                                       num_test) / np.sum(num_test)
                num_train, num_correct_train, train_loss_vector = (
                    self.train_error())
                avg_train_loss = np.dot(train_loss_vector,
                                        num_train) / np.sum(num_train)

                tqdm.write('At round {} training accu: {}, loss: {}'.format(
                    i,
                    np.sum(num_correct_train) * 1.0 / np.sum(num_train),
                    avg_train_loss))
                tqdm.write('At round {} test loss: {}'.format(
                    i, avg_test_loss))
                tqdm.write('At round {} test accu: {}'.format(
                    i,
                    np.sum(num_correct_test) * 1.0 / np.sum(num_test)))
                tqdm.write('At round {} malicious test accu: {}'.format(
                    i,
                    np.sum(num_correct_test[corrupt_id]) * 1.0 /
                    np.sum(num_test[corrupt_id])))
                tqdm.write('At round {} benign test accu: {}'.format(
                    i,
                    np.sum(num_correct_test[non_corrupt_id]) * 1.0 /
                    np.sum(num_test[non_corrupt_id])))
                tqdm.write('At round {} variance: {}'.format(
                    i,
                    np.var(num_correct_test[non_corrupt_id] * 1.0 /
                           num_test[non_corrupt_id])))

            indices, selected_clients = self.select_clients(
                round=i,
                corrupt_id=corrupt_id,
                num_clients=self.clients_per_round)

            csolns = []
            losses = []

            for idx in indices:
                c = self.clients[idx]

                # communicate the latest model
                c.set_params(self.latest_model)
                weights_before = copy.deepcopy(self.latest_model)

                loss = c.get_loss()  # training loss
                losses.append(loss)

                for _ in range(self.local_iters):
                    data_batch = next(batches[c])
                    c.solve_sgd(data_batch)

                new_weights = c.get_params()

                # the client's update is the change in its parameters
                grads = [(u - v) * 1.0
                         for u, v in zip(new_weights, weights_before)]

                if idx in corrupt_id:
                    if self.boosting:  # model replacement
                        grads = [self.clients_per_round * u for u in grads]
                    elif self.random_updates:
                        # send random updates
                        stdev_ = get_stdev(grads)
                        grads = [
                            np.random.normal(0, stdev_, size=u.shape)
                            for u in grads
                        ]

                if self.q == 0:
                    csolns.append(grads)
                else:
                    csolns.append((np.exp(self.q * loss), grads))

            if self.q != 0:
                overall_updates = self.aggregate(csolns)
            else:
                if self.gradient_clipping:
                    csolns = l2_clip(csolns)

                expected_num_mali = int(self.clients_per_round *
                                        self.num_corrupted / len(self.clients))

                if self.median:
                    overall_updates = self.median_average(csolns)
                elif self.k_norm:
                    overall_updates = self.k_norm_average(
                        self.clients_per_round - expected_num_mali, csolns)
                elif self.k_loss:
                    overall_updates = self.k_loss_average(
                        self.clients_per_round - expected_num_mali, losses,
                        csolns)
                elif self.krum:
                    overall_updates = self.krum_average(
                        self.clients_per_round - expected_num_mali - 2, csolns)
                elif self.mkrum:
                    m = self.clients_per_round - expected_num_mali
                    overall_updates = self.mkrum_average(
                        self.clients_per_round - expected_num_mali - 2, m,
                        csolns)
                elif self.fedmgda:
                    overall_updates = self.fedmgda_average(csolns)
                else:
                    overall_updates = self.simple_average(csolns)

            self.latest_model = [
                (u + v) for u, v in zip(self.latest_model, overall_updates)
            ]
Пример #4
0
    def train(self):
        """Run APFL-style training (global model + per-client interpolation).

        For each selected client and local iteration:

        - a copy of ``self.global_model`` takes one SGD step;
        - the client's local model ``self.local_models[idx]`` takes a step
          (scaled by ``self.alpha``) using gradients evaluated at
          ``self.interpolation[idx]``;
        - the interpolation is reset to
          ``alpha * local + (1 - alpha) * global_copy``.

        The global deltas are averaged (or median-aggregated) into
        ``self.global_model``. ``self.num_corrupted`` distinct clients get
        corrupted labels and may send boosted or random updates.

        Metrics are written to stdout every ``self.eval_every`` rounds.
        """
        print('---{} workers per communication round---'.format(
            self.clients_per_round))

        # NOTE(review): sibling trainers seed with 1234567 + self.seed; this
        # one ignores self.seed — confirm whether that is intended.
        np.random.seed(1234567)
        # replace=False: with replacement, duplicate ids meant fewer distinct
        # corrupted clients than num_corrupted and double-counted clients in
        # the malicious/benign metrics below (siblings all use replace=False)
        corrupt_id = np.random.choice(range(len(self.clients)),
                                      size=self.num_corrupted,
                                      replace=False)
        print(corrupt_id)
        # fixed for the whole run, so compute it once
        non_corrupt_id = np.setdiff1d(range(len(self.clients)), corrupt_id)

        # Rounds run for i in 0..num_rounds inclusive (num_rounds + 1 rounds)
        # and a selected client draws local_iters batches per round; budget
        # for all of them so a client selected every round cannot exhaust its
        # stream (next() would raise StopIteration).
        num_batches = (self.num_rounds + 1) * self.local_iters

        batches = {}
        for idx, c in enumerate(self.clients):
            if idx in corrupt_id:
                c.train_data['y'] = np.asarray(c.train_data['y'])
                if self.dataset == 'celeba':
                    c.train_data['y'] = 1 - c.train_data['y']
                elif self.dataset == 'femnist':
                    c.train_data['y'] = np.random.randint(
                        0, 62, len(c.train_data['y']))  # [0, 62)
                elif self.dataset == 'shakespeare':
                    c.train_data['y'] = np.random.randint(
                        0, 80, len(c.train_data['y']))
                elif self.dataset == "vehicle":
                    c.train_data['y'] = c.train_data['y'] * -1
                elif self.dataset == "fmnist":
                    c.train_data['y'] = np.random.randint(
                        0, 10, len(c.train_data['y']))

            if self.dataset == 'celeba':
                # due to a different data storage format
                batches[c] = gen_batch_celeba(c.train_data, self.batch_size,
                                              num_batches)
            else:
                batches[c] = gen_batch(c.train_data, self.batch_size,
                                       num_batches)

        for i in range(self.num_rounds + 1):
            if i % self.eval_every == 0 and i > 0:
                # evaluate each client on its alpha-mix of local and global
                tmp_models = [[
                    self.alpha * self.local_models[idx][layer] +
                    (1 - self.alpha) * self.global_model[layer]
                    for layer in range(len(self.local_models[idx]))
                ] for idx in range(len(self.clients))]
                # NOTE(review): sibling trainers unpack three values from
                # self.test(); this class's test() is assumed to return two.
                num_test, num_correct_test = self.test(tmp_models)
                num_train, num_correct_train, loss_vector = self.train_error(
                    tmp_models)
                avg_loss = np.dot(loss_vector, num_train) / np.sum(num_train)
                print(num_correct_test / num_test)

                tqdm.write('At round {} training accu: {}, loss: {}'.format(
                    i,
                    np.sum(num_correct_train) * 1.0 / np.sum(num_train),
                    avg_loss))
                tqdm.write('At round {} test accu: {}'.format(
                    i,
                    np.sum(num_correct_test) * 1.0 / np.sum(num_test)))
                tqdm.write('At round {} malicious test accu: {}'.format(
                    i,
                    np.sum(num_correct_test[corrupt_id]) * 1.0 /
                    np.sum(num_test[corrupt_id])))
                tqdm.write('At round {} benign test accu: {}'.format(
                    i,
                    np.sum(num_correct_test[non_corrupt_id]) * 1.0 /
                    np.sum(num_test[non_corrupt_id])))
                print(
                    "variance of the performance: ",
                    np.var(num_correct_test[non_corrupt_id] /
                           num_test[non_corrupt_id]))

            # weighted sampling
            indices, selected_clients = self.select_clients(
                round=i, num_clients=self.clients_per_round)

            csolns = []

            for idx in indices:
                c = self.clients[idx]

                # server sends the current global model to selected devices
                w_global_idx = copy.deepcopy(self.global_model)

                self.client_model.set_params(self.global_model)

                for _ in range(self.local_iters):
                    # first sample a mini-batch
                    data_batch = next(batches[c])

                    # optimize the global model
                    self.client_model.set_params(w_global_idx)
                    c.solve_sgd(data_batch)
                    w_global_idx = self.client_model.get_params()

                    # optimize for the local model (wrt to the interpolation)
                    self.client_model.set_params(self.interpolation[idx])
                    _, grads, _ = c.solve_sgd(
                        data_batch)  # grads: (num_samples, real_grads)
                    for layer in range(len(self.local_models[idx])):
                        self.local_models[idx][layer] = self.local_models[idx][
                            layer] - self.alpha * self.learning_rate * grads[
                                1][layer]

                    # update the interpolation
                    for layer in range(len(self.local_models[idx])):
                        self.interpolation[idx][
                            layer] = self.alpha * self.local_models[idx][
                                layer] + (1 - self.alpha) * w_global_idx[layer]

                diff = [
                    u - v for (u, v) in zip(w_global_idx, self.global_model)
                ]

                # send the malicious updates
                if idx in corrupt_id:
                    if self.boosting:
                        # scale malicious updates
                        diff = [self.clients_per_round * u for u in diff]
                    elif self.random_updates:
                        # send random updates
                        stdev_ = get_stdev(diff)
                        diff = [
                            np.random.normal(0, stdev_, size=u.shape)
                            for u in diff
                        ]

                csolns.append(diff)

            if self.gradient_clipping:
                csolns = l2_clip(csolns)
            if self.median:
                avg_updates = self.median_average(csolns)
            else:
                avg_updates = self.simple_average(csolns)
            for layer in range(len(avg_updates)):
                self.global_model[layer] += avg_updates[layer]
Пример #5
0
    def train(self):
        """Run federated averaging under attack, then KL-regularized fine-tuning.

        ``self.num_corrupted`` distinct clients (drawn once with seed
        ``1234567 + self.seed``) get corrupted labels and may send boosted or
        random updates. Each round the selected clients run
        ``self.local_iters`` SGD steps from ``self.latest_model``. The averaged
        deltas (after optional L2 clipping) are added to ``self.latest_model``.

        Afterwards every client fine-tunes the final global model with SGD.
        The gradients are augmented by ``self.lam`` times the gradient returned
        by ``c.get_kl_grads``, evaluated against the global model's softmax
        outputs. The final test accuracy is written to stdout.
        """
        print('---{} workers per communication round---'.format(self.clients_per_round))

        np.random.seed(1234567+self.seed)
        corrupt_id = np.random.choice(range(len(self.clients)), size=self.num_corrupted, replace=False)
        # fixed for the whole run; used by per-round and final reports
        non_corrupt_id = np.setdiff1d(range(len(self.clients)), corrupt_id)

        def finetune_steps(client):
            # number of fine-tuning SGD steps `client` takes after training
            return max(int(self.finetune_iters * client.train_samples / self.batch_size),
                       self.finetune_iters)

        batches = {}
        for idx, c in enumerate(self.clients):
            if idx in corrupt_id:
                c.train_data['y'] = np.asarray(c.train_data['y'])
                if self.dataset == 'celeba':
                    c.train_data['y'] = 1 - c.train_data['y']
                elif self.dataset == 'femnist':
                    c.train_data['y'] = np.random.randint(0, 62, len(c.train_data['y']))

            # One stream serves both phases: up to num_rounds + 1 rounds of
            # local_iters batches (rounds run for i in 0..num_rounds
            # inclusive), then finetune_steps(c) fine-tuning batches. The old
            # fixed slack of 300 could be exhausted by fine-tuning on large
            # clients, so budget for the real demand (never less than before).
            num_batches = (self.num_rounds + 1) * self.local_iters + max(300, finetune_steps(c))
            if self.dataset == 'celeba':
                batches[c] = gen_batch_celeba(c.train_data, self.batch_size, num_batches)
            else:
                batches[c] = gen_batch(c.train_data, self.batch_size, num_batches)

        for i in range(self.num_rounds + 1):
            if i % self.eval_every == 0 and i > 0:

                num_test, num_correct_test, _ = self.test()  # have set the latest model for all clients
                num_train, num_correct_train, loss_vector = self.train_error()

                avg_loss = np.dot(loss_vector, num_train) / np.sum(num_train)

                tqdm.write('At round {} training accu: {}, loss: {}'.format(i, np.sum(num_correct_train) * 1.0 / np.sum(
                    num_train), avg_loss))
                tqdm.write('At round {} test accu: {}'.format(i, np.sum(num_correct_test) * 1.0 / np.sum(num_test)))
                tqdm.write('At round {} malicious test accu: {}'.format(i, np.sum(
                    num_correct_test[corrupt_id]) * 1.0 / np.sum(num_test[corrupt_id])))
                tqdm.write('At round {} benign test accu: {}'.format(i, np.sum(
                    num_correct_test[non_corrupt_id]) * 1.0 / np.sum(num_test[non_corrupt_id])))
                print("variance of the performance: ",
                      np.var(num_correct_test[non_corrupt_id] / num_test[non_corrupt_id]))

            # weighted sampling
            indices, selected_clients = self.select_clients(round=i, corrupt_id=corrupt_id, num_clients=self.clients_per_round)

            csolns = []
            for idx in indices:
                w_global_idx = copy.deepcopy(self.latest_model)
                c = self.clients[idx]
                c.set_params(w_global_idx)
                for _ in range(self.local_iters):
                    data_batch = next(batches[c])
                    c.solve_sgd(data_batch)
                # NOTE(review): reads from self.client_model after training via
                # c; assumes c wraps the shared self.client_model — confirm.
                w_global_idx = self.client_model.get_params()

                # get the difference (global model updates)
                diff = [u - v for (u, v) in zip(w_global_idx, self.latest_model)]

                # send the malicious updates
                if idx in corrupt_id:
                    if self.boosting:
                        # scale malicious updates
                        diff = [self.clients_per_round * u for u in diff]
                    elif self.random_updates:
                        # send random updates
                        stdev_ = get_stdev(diff)
                        diff = [np.random.normal(0, stdev_, size=u.shape) for u in diff]

                csolns.append(diff)

            if self.gradient_clipping:
                csolns = l2_clip(csolns)

            avg_updates = self.simple_average(csolns)

            # update the global model
            for layer in range(len(avg_updates)):
                self.latest_model[layer] += avg_updates[layer]

        # local finetuning based on KL
        after_test_accu = []
        test_samples = []
        for c in self.clients:

            c.set_params(self.latest_model)
            # softmax outputs of the final global model, used as the KL target
            output2 = copy.deepcopy(c.get_softmax())
            # start to finetune
            local_model = copy.deepcopy(self.latest_model)

            for _ in range(finetune_steps(c)):
                data_batch = next(batches[c])
                c.set_params(local_model)
                kl_grads = c.get_kl_grads(output2)
                _, grads, _ = c.solve_sgd(data_batch)

                for j in range(len(grads[1])):
                    eff_grad = grads[1][j] + self.lam * kl_grads[j]
                    local_model[j] = local_model[j] - self.learning_rate * eff_grad

            c.set_params(local_model)
            tc, _, num_test = c.test()
            after_test_accu.append(tc)
            test_samples.append(num_test)

        after_test_accu = np.asarray(after_test_accu)
        test_samples = np.asarray(test_samples)
        tqdm.write('final test accu: {}'.format(np.sum(after_test_accu) * 1.0 / np.sum(test_samples)))
        tqdm.write('final malicious test accu: {}'.format(np.sum(
            after_test_accu[corrupt_id]) * 1.0 / np.sum(test_samples[corrupt_id])))
        tqdm.write('final benign test accu: {}'.format(np.sum(
            after_test_accu[non_corrupt_id]) * 1.0 / np.sum(test_samples[non_corrupt_id])))
        print("variance of the performance: ",
              np.var(after_test_accu[non_corrupt_id] / test_samples[non_corrupt_id]))