Exemplo n.º 1
0
    def get_average_model(self, workers_idx):
        """Build a model whose layer parameters are the mean over the given workers.

        The selected workers' models are first retrieved through
        ``self.getback_model``. The weight and bias of ``conv1``, ``conv2``,
        ``fc1`` and ``fc2`` are then summed over the workers and every
        parameter of the new model is divided by the number of workers.

        Args:
            workers_idx (list[str]): List of workers ids
        Returns:
            FLNet: A newly created model holding the averaged parameters.
        Raises:
            ValueError: If ``workers_idx`` is empty (there is nothing to
                average; dividing by zero would only produce inf/nan weights).
        """
        if len(workers_idx) == 0:
            raise ValueError("Cannot average the models of zero workers.")

        logging.info("Finding an average model for {} workers.".format(
            len(workers_idx)))
        tmp_model = FLNet().to(self.device)
        self.getback_model(self.workers_model, workers_idx)

        layer_names = ("conv1", "conv2", "fc1", "fc2")
        tensor_names = ("weight", "bias")

        with torch.no_grad():
            for layer_name in layer_names:
                for tensor_name in tensor_names:
                    total = None
                    for ww_id in workers_idx:
                        worker_layer = getattr(self.workers_model[ww_id],
                                               layer_name)
                        contribution = getattr(worker_layer, tensor_name).data
                        if total is None:
                            # Clone so the result never shares storage with
                            # the first worker's own parameters.
                            total = contribution.clone()
                        else:
                            total = total + contribution
                    target_layer = getattr(tmp_model, layer_name)
                    getattr(target_layer, tensor_name).set_(total)

        # Turn the per-layer sums into means.
        for param in tmp_model.parameters():
            param.data = param.data / len(workers_idx)

        return tmp_model
Exemplo n.º 2
0
    def wieghted_avg_model(self, W, workers_model):
        """Combine the workers' models into one model using the weights ``W``.

        All models in ``workers_model`` are first retrieved through
        ``self.getback_model``. The weight and bias of ``conv1``, ``conv2``,
        ``fc1`` and ``fc2`` of a fresh ``FLNet`` are zeroed and then
        accumulated as ``sum(W[k] * worker_k)``, where ``k`` is the position
        of the worker in the iteration order of ``workers_model``.

        Args:
            W: Indexable collection of coefficients, one per worker, addressed
                by position.
            workers_model (dict): Maps a worker id to that worker's model.
        Returns:
            The model holding the weighted sum of the workers' parameters.
        """
        self.getback_model(workers_model)
        tmp_model = FLNet().to(self.device)
        layer_names = ("conv1", "conv2", "fc1", "fc2")

        with torch.no_grad():
            # Start from an all-zero accumulator.
            for layer_name in layer_names:
                accumulator = getattr(tmp_model, layer_name)
                accumulator.weight.data.fill_(0)
                accumulator.bias.data.fill_(0)

            for position, worker_model in enumerate(workers_model.values()):
                for layer_name in layer_names:
                    accumulator = getattr(tmp_model, layer_name)
                    source = getattr(worker_model, layer_name)
                    accumulator.weight.data = (
                        accumulator.weight.data +
                        W[position] * source.weight.data)
                    accumulator.bias.data = (
                        accumulator.bias.data +
                        W[position] * source.bias.data)

        return tmp_model
Exemplo n.º 3
0
def train_workers_with_attack(federated_train_loader, models, workers_idx,
                              attackers_idx, round_no, args):
    """Train the selected workers locally; attackers tamper with their model instead.

    For ``args.epochs`` epochs, every dataloader in ``federated_train_loader``
    whose key is in ``workers_idx`` is iterated. A worker found in
    ``attackers_idx`` does not train: with ``args.attack_type == 1`` its entry
    in ``models`` is replaced by a fresh ``FLNet``, with ``== 2`` its state
    dict is passed through ``utils.negative_parameters``; the batch loop for
    that worker then stops after the first batch. Any other worker runs one
    SGD step per batch on its remote model.

    Args:
        federated_train_loader (dict): Maps a worker id to its dataloader.
        models (dict): Maps a worker id to its model; modified in place.
        workers_idx: Ids of the workers taking part.
        attackers_idx: Ids of the attacking workers.
        round_no: Round number, only shown in the progress bar.
        args: Settings object; ``lr``, ``weight_decay``, ``epochs``,
            ``attack_type`` and ``device`` are read.
    Returns:
        defaultdict: Maps each worker that trained at least one batch to the
        mean of its batch losses.
    """
    attackers_here = [w_id for w_id in workers_idx if w_id in attackers_idx]
    workers_loss = defaultdict(list)
    workers_opt = {
        w_id: torch.optim.SGD(params=models[w_id].parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay)
        for w_id in workers_idx
    }

    def batch_status(worker_id, batch_no, loss_text):
        # Postfix displayed on the per-batch progress bar.
        return {
            'Worker': worker_id,
            'ATK': "[T]" if worker_id in attackers_idx else "[F]",
            'BatchID': batch_no,
            'Loss': loss_text
        }

    with tqdm(total=args.epochs,
              leave=False,
              colour="yellow",
              ncols=80,
              desc="Epoch\t",
              bar_format=TQDM_R_BAR) as epoch_bar:
        for epoch in range(args.epochs):
            epoch_bar.set_postfix(Rounds=round_no, Epochs=epoch)
            with tqdm(total=len(workers_idx),
                      ncols=80,
                      desc="Workers\t",
                      leave=False,
                      bar_format=TQDM_R_BAR) as worker_bar:
                worker_bar.set_postfix(ordered_dict={
                    'ATK':
                    "{}/{}".format(len(attackers_here), len(workers_idx))
                })
                for ww_id, fed_dataloader in federated_train_loader.items():
                    if ww_id not in workers_idx:
                        continue
                    with tqdm(total=len(fed_dataloader),
                              ncols=80,
                              colour='red',
                              desc="Batch\t",
                              leave=False,
                              bar_format=TQDM_R_BAR) as batch_bar:
                        for batch_idx, (data,
                                        target) in enumerate(fed_dataloader):
                            ww = data.location
                            model = models[ww.id]
                            data, target = data.to("cpu"), target.to("cpu")

                            if ww.id in attackers_idx:
                                # Attackers corrupt their model instead of
                                # training it.
                                if args.attack_type == 1:
                                    models[ww.id] = FLNet().to(args.device)
                                elif args.attack_type == 2:
                                    ss = utils.negative_parameters(
                                        models[ww.id].state_dict())
                                    models[ww.id].load_state_dict(ss)
                                batch_bar.set_postfix(ordered_dict=batch_status(
                                    ww.id, batch_idx, '-'))
                                #TODO: Be careful about the break
                                break

                            model.train()
                            model.send(ww.id)
                            opt = workers_opt[ww.id]
                            opt.zero_grad()
                            output = model(data)
                            loss = F.nll_loss(output, target)
                            loss.backward()
                            opt.step()
                            model.get()  # bring the model back
                            loss = loss.get()  # bring the loss back
                            loss_value = loss.item()
                            workers_loss[ww.id].append(loss_value)
                            batch_bar.set_postfix(ordered_dict=batch_status(
                                ww.id, batch_idx, loss_value))
                            batch_bar.update()
                    worker_bar.update()
            epoch_bar.update()

    # Mean per worker
    for worker_id in workers_loss:
        batch_losses = workers_loss[worker_id]
        workers_loss[worker_id] = sum(batch_losses) / len(batch_losses)
    return workers_loss
Exemplo n.º 4
0
def main(start_round):
    """Run one breakdown of federated-learning rounds, starting at ``start_round``.

    The attackers list, the per-worker datasets and the server's public
    dataset are loaded from ``args.log_dir``; if any of them is missing the
    process exits with status 1. Each round samples
    ``args.selected_users_num`` workers, hands each a deep copy of the server
    model, trains them via ``train_workers_with_attack``, derives the
    aggregation weights according to ``args.mode`` and loads the result of
    ``utils.wieghted_avg_model`` into the server model.

    Relies on the module-level names ``args``, ``hook`` and
    ``ROUNDS_BREAKDOWN``.

    Args:
        start_round: Number of the last finished round (converted with
            ``int``); 0 means starting from scratch.
    Returns:
        The exclusive end round of this run (``round_end``).
    """
    logging.info("Total number of users: {}".format(args.total_users_num))
    workers_idx = ["worker_" + str(i) for i in range(args.total_users_num)]
    workers = create_workers(hook, workers_idx)
    server = create_workers(hook, ['server'])
    server = server['server']
    # if args.local_log:
    #     utils.check_write_to_file(args.log_dir, "all_users", workers_idx)

    # The attackers list must already exist on disk for this study.
    # NOTE(review): `exit` is the interactive helper; `sys.exit` is the usual
    # choice in scripts -- confirm before changing.
    attackers_idx = None
    if utils.find_file(args.log_dir, "attackers"):
        logging.info("attackers list was found. Loading from file...")
        attackers_idx = utils.load_object(args.log_dir, "attackers")
    else:
        logging.error("This should not be happened in this study.")
        exit(1)
    #     attackers_idx = utils.get_workers_idx(workers_idx, args.attackers_num, [])
    #     if args.local_log:
    #         utils.save_object(args.log_dir, "attackers", attackers_idx)

    # Per-worker datasets, also expected to be present on disk.
    mapped_datasets = dict()
    if utils.find_file(args.log_dir, "mapped_datasets"):
        logging.info("mapped_datasets was found. Loading from file...")
        mapped_datasets = utils.load_object(args.log_dir, "mapped_datasets")
    else:
        logging.error("This should not be happened in this study.")
        exit(1)
        # Now sort the dataset and distribute among users
        # mapped_ds_itr = utils.map_shards_to_worker(
        #     utils.split_randomly_dataset(
        #         utils.sort_mnist_dataset(
        #             utils.fraction_of_datasets(
        #                 {"dataset": utils.load_mnist_dataset(
        #                     train=True,
        #                     transform=transforms.Compose([transforms.ToTensor(),]))},
        #                 args.load_fraction, [])
        #         ),
        #         args.shards_num),
        #     workers_idx,
        #     args.shards_per_worker_num)

        # mapping to users and performin attacks
        # for mapped_ds in mapped_ds_itr:
        #     for ww_id, dataset in mapped_ds.items():
        #         if ww_id in attackers_idx:
        #             mapped_datasets.update(
        #                 {ww_id: FLCustomDataset(
        #                     utils.attack_shuffle_pixels(dataset.data),
        #                     dataset.targets,
        #                     transform=transforms.Compose([
        #                         transforms.ToTensor()])
        #                 )}
        #             )
        #         else:
        #             mapped_datasets.update(mapped_ds)

        # if args.local_log:
        #     utils.save_object(args.log_dir, "mapped_datasets", mapped_datasets)

    # Public dataset held by the server, also expected to be present on disk.
    server_pub_dataset = None
    if utils.find_file(args.log_dir, "server_pub_dataset"):
        logging.info("server_pub_dataset was found. Loading from file...")
        server_pub_dataset = utils.load_object(args.log_dir,
                                               "server_pub_dataset")
    else:
        logging.error("This should not be happened in this study.")
        exit(1)
        # if args.server_pure:
        #     server_pub_dataset = utils.fraction_of_datasets(mapped_datasets, args.server_data_fraction)
        # else:
        #     logging.info("Server data is NOT pure.")
        #     server_pub_dataset = utils.fraction_of_datasets(
        #         mapped_datasets, args.server_data_fraction, attackers_idx)
        # if args.local_log:
        #     utils.save_object(args.log_dir, "server_pub_dataset", server_pub_dataset)

    # The server gets a single-entry loader dict so it can be trained through
    # the same routine as the workers.
    federated_server_loader = dict()
    federated_server_loader['server'] = sy.FederatedDataLoader(
        server_pub_dataset.federate([server]),
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False)

    federated_train_loader = dict()
    logging.info("Creating federated dataloaders for workers...")
    for ww_id, fed_dataset in mapped_datasets.items():
        federated_train_loader[ww_id] = sy.FederatedDataLoader(
            fed_dataset.federate([workers[ww_id]]),
            batch_size=args.batch_size,
            shuffle=False,
            drop_last=False)

    # NOTE(review): test_loader is built but not used below, because the
    # evaluation call is commented out.
    test_loader = utils.get_dataloader(utils.load_mnist_dataset(
        train=False, transform=transforms.Compose([
            transforms.ToTensor(),
        ])),
                                       args.test_batch_size,
                                       shuffle=True,
                                       drop_last=False)

    previous_round = int(start_round)
    logging.info("Use explicit starting round number: {}".format(start_round))
    # if utils.find_file(args.log_dir, "accuracy"):
    #     previous_round = int(utils.get_last_round_num(args.log_dir, "accuracy"))
    #     logging.info("Previous complete execution was found. Last run is: {}".format(previous_round))

    # Resume one past the last finished round (or at 0 on a fresh start) and
    # run at most ROUNDS_BREAKDOWN rounds, capped at args.rounds.
    round_start = previous_round if previous_round == 0 else previous_round + 1
    round_end = round_start + ROUNDS_BREAKDOWN if round_start + ROUNDS_BREAKDOWN < args.rounds else args.rounds

    # A saved server model of the previous round is loaded when it exists;
    # otherwise the freshly initialised FLNet is used.
    server_model, server_model_name = FLNet().to(
        args.device), "R{}_server_model".format(previous_round)
    server_model_path = args.log_dir + "models"
    if utils.find_file(server_model_path, server_model_name):
        logging.info(
            "server_model was found. Loading {} from the file...".format(
                server_model_name))
        server_model.load_state_dict(
            torch.load(server_model_path + "/" + server_model_name))
    # NOTE(review): server_opt is created but not referenced again in this
    # function.
    server_opt = dict()
    server_opt['server'] = torch.optim.SGD(params=server_model.parameters(),
                                           lr=args.lr,
                                           weight_decay=args.weight_decay)
    # These stay at 0.0 because the evaluation call below is commented out.
    test_loss, test_acc = 0.0, 0.0
    with tqdm(total=min(ROUNDS_BREAKDOWN, args.rounds - round_start),
              leave=True,
              colour="green",
              ncols=80,
              desc="Round\t",
              bar_format=TQDM_R_BAR) as t1:
        for round_no in range(round_start, round_end):
            # Every selected worker starts the round from a copy of the
            # current server model.
            workers_to_be_used = random.sample(workers_idx,
                                               args.selected_users_num)
            workers_model = dict()
            for ww_id in workers_to_be_used:
                workers_model[ww_id] = deepcopy(server_model)

            # logging.info("Workers for this round: {}".format(workers_to_be_used))
            if args.local_log:
                utils.save_object(
                    args.log_dir,
                    "R{}_p_pca_workers".format(round_no) if args.server_pure else \
                            "R{}_np_pca_workers".format(round_no) ,
                    workers_to_be_used
                )
            train_loss = train_workers_with_attack(federated_train_loader,
                                                   workers_model,
                                                   workers_to_be_used,
                                                   attackers_idx, round_no,
                                                   args)

            # Find the best weights and update the server model
            # NOTE(review): for any mode other than "avg"/"opt", `weights`
            # stays empty and `train_loss` stays a per-worker mapping.
            weights = dict()
            if args.mode == "avg":
                # Each worker takes two shards of 300 random.samples. Total of 600 random.samples
                # per worker. Total number of random.samples is 60000.
                # weights = [600.0 / 60000] * args.selected_users_num
                for ww_id in workers_to_be_used:
                    weights[ww_id] = 1.0 / args.selected_users_num
                # NOTE(review): divides by zero if no selected worker
                # reported a loss -- confirm that cannot happen.
                train_loss = sum(train_loss.values()) / len(train_loss)
            elif args.mode == "opt":
                # models should be returned from the workers before calling the following functions:
                # Train server
                train_workers_with_attack(federated_server_loader,
                                          {'server': server_model}, ['server'],
                                          [], round_no, args)
                pass

                weights = utils.find_best_weights_opt(server_model,
                                                      workers_model)

                # Weighted training loss over the workers that reported one.
                loss = 0
                for ww in train_loss:
                    loss += weights[ww] * train_loss[ww]
                train_loss = loss
                # if args.local_log:
                #     utils.write_to_file(args.log_dir, "opt_weights", weights, round_no=round_no)

            # logging.info("Update server model in this round...")
            server_model.load_state_dict(
                utils.wieghted_avg_model(weights, workers_model,
                                         workers_to_be_used))

            # Apply the server model to the test dataset
            # logging.info("Starting model evaluation on the test dataset...")
            # test_loss, test_acc = test(server_model, test_loader, round_no, args)

            if args.local_log:
                # utils.write_to_file(args.log_dir, "train_loss", train_loss, round_no=round_no)
                # utils.save_model(
                #     server_model.state_dict(),
                #     "{}/{}".format(args.log_dir, "models"),
                #     "pca_niid_np_{}_{}_server_R{}".format(args.mode, args.attackers_num, round_no) if not args.server_pure else \
                #             "pca_niid_p_{}_{}_server_R{}".format(args.mode, args.attackers_num, round_no)
                # )
                # Only the workers' models are saved; saving the server model
                # is commented out above.
                for ww_id, ww_model in workers_model.items():
                    utils.save_model(
                        ww_model.state_dict(),
                        "{}/{}/workers_p_R{}".format(args.log_dir, "models", round_no) \
                            if args.server_pure else \
                                "{}/{}/workers_np_R{}".format(args.log_dir, "models", round_no),
                        "{}_model".format(ww_id)
                    )
            if args.neptune_log:
                neptune.log_metric("train_loss", train_loss)

            print()
            logging.info('Test Average loss: {:.4f}, Accuracy: {:.0f}%'.format(
                test_loss, test_acc))
            print()

            t1.set_postfix(test_acc=test_acc, test_loss=test_loss)
            t1.update()
    return round_end
Exemplo n.º 5
0
 def create_model(self):
     """Create a new ``FLNet`` and return it after moving it to ``self.device``."""
     logging.info("Creating a model...")
     fresh_net = FLNet()
     return fresh_net.to(self.device)
Exemplo n.º 6
0
 def create_server_model(self):
     """Create a new ``FLNet`` on ``self.device`` and store it as ``self.server_model``."""
     logging.info("Creating a model for the server...")
     fresh_net = FLNet()
     self.server_model = fresh_net.to(self.device)
Exemplo n.º 7
0
class FederatedLearning():

    # Initializing variables
    # batch_size, test_batch_size, lr, wieght_decay, momentum,
    def __init__(self, neptune_enable, log_enable, log_interval, output_dir,
                 random_seed):
        """Set up the PySyft hook, the CPU device, the RNG seed and the logging options.

        Args:
            neptune_enable: Stored as ``self.neptune_enable``.
            log_enable: Stored as ``self.log_enable``.
            log_interval: Stored as ``self.log_interval``; ``train_server``
                logs every ``log_interval`` batches.
            output_dir: Stored as ``self.log_file_path``; ``train_server``
                prepends it to the log file name.
            random_seed: Passed to ``torch.manual_seed`` and stored as
                ``self.seed``.
        """
        logging.info("Initializing Federated Learning class...")

        self.hook = sy.TorchHook(torch)
        use_cuda = False
        self.kwargs = {
            'num_workers': 1,
            'pin_memory': True
        } if use_cuda else {}
        self.device = torch.device("cuda" if use_cuda else "cpu")
        torch.manual_seed(random_seed)

        self.workers = dict()
        # State read by other methods of this class; defined here so those
        # methods do not fail with AttributeError on a fresh instance.
        self.workers_model = dict()
        self.server = None
        self.server_model = None
        self.log_interval = log_interval
        # NOTE(review): the training methods also read self.lr and
        # self.weight_decay. They are not constructor parameters, so the
        # caller has to assign them before training.
        self.seed = random_seed
        self.log_enable = log_enable
        self.neptune_enable = neptune_enable
        self.log_file_path = output_dir

    def create_workers(self, workers_id_list):
        logging.info("Creating workers...")
        for worker_id in workers_id_list:
            if worker_id not in self.workers:
                logging.debug("Creating the worker: {}".format(worker_id))
                self.workers[worker_id] = sy.VirtualWorker(self.hook,
                                                           id=worker_id)
            else:
                logging.debug(
                    "Worker {} exists. Skip creating this worker".format(
                        worker_id))

    # def create_server(self):
    #     logging.info("Creating the server...")
    #     self.server = sy.VirtualWorker(self.hook, id="server")

    def create_server_model(self):
        """Create a new ``FLNet`` on ``self.device`` and store it as ``self.server_model``."""
        logging.info("Creating a model for the server...")
        fresh_net = FLNet()
        self.server_model = fresh_net.to(self.device)

    def create_model(self):
        """Create a new ``FLNet`` and return it after moving it to ``self.device``."""
        logging.info("Creating a model...")
        fresh_net = FLNet()
        return fresh_net.to(self.device)

    def create_workers_model(self, selected_workers_id):
        logging.info("Creating a model for {} worker(s)...".format(
            len(selected_workers_id)))
        for worker_id in selected_workers_id:
            if worker_id not in self.workers_model:
                logging.debug(
                    "Creating a (copy) model of server for worker {}".format(
                        worker_id))
                self.workers_model[worker_id] = deepcopy(self.server_model)
            else:
                logging.debug(
                    "The model for worker {} exists".format(worker_id))

    ############################ MNIST RELATED FUNCS ###############################

    def create_federated_mnist(self, dataset, destination_idx, batch_size,
                               shuffle):
        """Federate a dataset to the server or to selected workers and wrap it in a loader.

        Args:
            dataset (FLCustomDataset): Dataset to be federated
            destination_idx (list[str]): Ids of the destinations. If it
                contains "server" the dataset goes to ``self.server`` only;
                otherwise it goes to every worker in ``self.workers`` whose
                id is listed.
            batch_size: Passed to ``sy.FederatedDataLoader``.
            shuffle: Passed to ``sy.FederatedDataLoader``.
        Returns:
            The ``sy.FederatedDataLoader`` built over the federated dataset
            (created with ``drop_last=True``).
        """
        if "server" in destination_idx:
            workers = [self.server]
        else:
            workers = [
                worker for worker_id, worker in self.workers.items()
                if worker_id in destination_idx
            ]

        fed_dataloader = sy.FederatedDataLoader(dataset.federate(workers),
                                                batch_size=batch_size,
                                                shuffle=shuffle,
                                                drop_last=True)

        return fed_dataloader

    def create_mnist_fed_datasets(self, raw_dataset):
        """Build one federated ``sy.BaseDataset`` per worker from raw data.

        Args:
            raw_dataset (dict): Maps a worker id to a dict with the images
                under ``'x'`` and the labels under ``'y'``, e.g.
                ``raw_dataset['worker_1']['x']``.
        Returns:
            dict: Maps each worker id to its dataset, federated to
            ``self.workers[worker_id]``. Each dataset carries a
            ToTensor + Normalize transform using the mean and std of that
            worker's own images.
        """
        fed_datasets = dict()

        for ww_id, ww_data in raw_dataset.items():
            images = tensor(ww_data['x'], dtype=float32)
            labels = tensor(ww_data['y'].ravel(), dtype=int64)
            normalize = transforms.Normalize((ww_data['x'].mean(), ),
                                             (ww_data['x'].std(), ))
            pipeline = transforms.Compose([transforms.ToTensor(), normalize])
            local_dataset = sy.BaseDataset(images, labels, transform=pipeline)
            fed_datasets[ww_id] = local_dataset.federate(
                [self.workers[ww_id]])

        return fed_datasets

    ############################ FEMNIST RELATED FUNCS ###############################

    def create_femnist_dataset(self,
                               raw_data,
                               workers_idx,
                               shuffle=True,
                               drop_last=True):
        """Merge the data of the given workers into a single test dataset.

        Args:
            raw_data (dict of str): dict contains processed train and test data categorized based on user id
                # raw_data['f0_12345']['x'], raw_data['f0_12345']['y']
            workers_idx: Ids of the workers whose data is merged.
            shuffle: Currently unused; kept for interface compatibility.
            drop_last: Currently unused; kept for interface compatibility.
        Returns:
            FLCustomDataset: All images (reshaped to ``(-1, 28, 28)``) and
            labels of the given workers, with a ToTensor + Normalize
            transform using the mean and std of the merged images.
        """
        logging.info("Creating 1 test dataset from {} workers".format(
            len(workers_idx)))
        # raw_data = utils.extract_data(raw_data, workers_idx)
        # Seed both lists with an empty chunk so that concatenating also works
        # when workers_idx is empty.
        image_chunks = [np.array([], dtype=np.float32).reshape(-1, 28, 28)]
        label_chunks = [np.array([], dtype=np.int64)]

        for worker_id in workers_idx:
            image_chunks.append(
                np.array(raw_data[worker_id]['x'],
                         dtype=np.float32).reshape(-1, 28, 28))
            # Labels live under 'y'; reading 'x' here would turn pixel values
            # into labels.
            label_chunks.append(
                np.array(raw_data[worker_id]['y'], dtype=np.int64).ravel())

        # Concatenate once instead of re-copying the growing arrays on every
        # iteration.
        server_images = np.concatenate(image_chunks)
        server_labels = np.concatenate(label_chunks)

        test_dataset = FLCustomDataset(server_images,
                                       server_labels,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize(
                                               (server_images.mean(), ),
                                               (server_images.std(), ))
                                       ]))

        return test_dataset

    def create_femnist_fed_dataset(self, raw_data, workers_idx, percentage):
        """Build the server's dataset from a random share of each worker's data.

        Assume this only used for preparing aggregated dataset for the server
        Args:
            raw_data (dict): Maps a worker id to a dict with the images under
                ``'x'`` and the labels under ``'y'``.
            workers_idx: Ids of the workers contributing data.
            percentage (float): Out of 100, amount of public data of each user
        Returns:
            The merged ``sy.BaseDataset`` federated to ``self.server``, with a
            ToTensor + Normalize transform using the mean and std of the
            merged images.
        """
        logging.info(
            "Creating the dataset from {}% of {} selected users' data...".
            format(percentage, len(workers_idx)))
        # Fraction of public data of each user, which be shared by the server.
        # Both lists start with an empty chunk so the result is well defined
        # even without any worker.
        image_parts = [tensor([], dtype=float32).view(-1, 28, 28)]
        label_parts = [tensor([], dtype=int64)]

        for worker_id in workers_idx:
            worker_data = raw_data[worker_id]
            worker_samples_num = len(worker_data['y'])
            num_samples_for_server = math.floor(
                (percentage / 100.0) * worker_samples_num)
            logging.debug(
                "Sending {} samples from worker {} with total {}".format(
                    num_samples_for_server, worker_id, worker_samples_num))
            picked = sample(range(worker_samples_num), num_samples_for_server)
            picked_images = [worker_data['x'][i] for i in picked]
            picked_labels = [worker_data['y'][i] for i in picked]
            image_parts.append(
                tensor(picked_images, dtype=float32).view(-1, 28, 28))
            label_parts.append(tensor(picked_labels, dtype=int64).view(-1))

        server_images = cat(image_parts)
        server_labels = cat(label_parts)

        logging.info(
            "Selected {} samples in total for the server from {} users.".
            format(server_images.shape, len(workers_idx)))

        normalize = transforms.Normalize((server_images.mean().item(), ),
                                         (server_images.std().item(), ))
        server_dataset = sy.BaseDataset(server_images,
                                        server_labels,
                                        transform=transforms.Compose([
                                            transforms.ToTensor(),
                                            normalize
                                        ]))
        return server_dataset.federate([self.server])

    def create_femnist_fed_datasets(self, raw_dataset, workers_idx):
        """Build one federated ``sy.BaseDataset`` per listed worker.

        Args:
            raw_dataset (dict): Maps a worker id to a dict with the images
                under ``'x'`` and the labels under ``'y'``.
            workers_idx: Ids of the workers to build a dataset for.
        Returns:
            dict: Maps each worker id to its dataset, federated to
            ``self.workers[worker_id]``. Each dataset carries a
            ToTensor + Normalize transform using the mean and std of that
            worker's own images.
        """
        fed_datasets = dict()

        for worker_id in workers_idx:
            images = tensor(raw_dataset[worker_id]['x'], dtype=float32)
            labels = tensor(raw_dataset[worker_id]['y'].ravel(), dtype=int64)
            normalize = transforms.Normalize(
                (raw_dataset[worker_id]['x'].mean(), ),
                (raw_dataset[worker_id]['x'].std(), ))
            pipeline = transforms.Compose([transforms.ToTensor(), normalize])
            local_dataset = sy.BaseDataset(images, labels, transform=pipeline)
            fed_datasets[worker_id] = local_dataset.federate(
                [self.workers[worker_id]])

        return fed_datasets

    def create_femnist_datasets(self, raw_dataset, workers_idx):
        """Build one (non-federated) ``sy.BaseDataset`` per listed worker.

        Args:
            raw_dataset (dict): Maps a worker id to a dict with the images
                under ``'x'`` and the labels under ``'y'``.
            workers_idx: Ids of the workers to build a dataset for.
        Returns:
            dict: Maps each worker id to its dataset. Each dataset carries a
            ToTensor + Normalize transform using the mean and std of that
            worker's own images.
        """
        datasets = dict()

        for worker_id in workers_idx:
            images = tensor(raw_dataset[worker_id]['x'], dtype=float32)
            labels = tensor(raw_dataset[worker_id]['y'].ravel(), dtype=int64)
            normalize = transforms.Normalize(
                (raw_dataset[worker_id]['x'].mean(), ),
                (raw_dataset[worker_id]['x'].std(), ))
            pipeline = transforms.Compose([transforms.ToTensor(), normalize])
            datasets[worker_id] = sy.BaseDataset(images,
                                                 labels,
                                                 transform=pipeline)

        return datasets

    ############################ GENERAL FUNC ################################

    def send_model(self, model, location, location_id):
        """Make sure a model (or every model of a dict) resides at ``location``.

        A model whose ``location`` is ``None`` is sent with ``send``; a model
        whose current location has a different id is relocated with ``move``;
        a model already at ``location_id`` is left alone.

        Args:
            model: A single model, or a dict mapping ids to models.
            location: Destination passed to ``send`` / ``move``.
            location_id: Id of the destination, compared with
                ``model.location.id``.
        """
        # Treat a single model as a one-element collection so both cases share
        # the same logic. For a dict the calls are made on each contained
        # model, never on the dict itself.
        candidates = model.values() if isinstance(model, dict) else (model, )
        for candidate in candidates:
            if candidate.location is None:
                candidate.send(location)
            elif candidate.location.id != location_id:
                candidate.move(location)

    def getback_model(self, model, selected_client_ids=None):
        """Call ``get()`` on every model that currently has a location.

        Args:
            model: A single model, or a dict mapping client ids to models.
            selected_client_ids: When ``model`` is a dict, restricts the
                operation to these ids; ``None`` means all entries. Ignored
                for a single model.
        """
        if not isinstance(model, dict):
            if model.location is not None:
                model.get()
            return

        if selected_client_ids is None:
            candidates = model.values()
        else:
            candidates = (model[ww_id] for ww_id in selected_client_ids)

        for candidate in candidates:
            if candidate.location is not None:
                candidate.get()

    def train_server(self, server_dataloader, round_no, epochs_num):
        """Train ``self.server_model`` on the server's data with plain SGD.

        The model is first placed on ``self.server`` via ``send_model``. Every
        ``self.log_interval`` batches the loss is fetched with ``get()`` and
        reported to neptune, to the ``server_train`` log file and to the
        logger, depending on ``self.neptune_enable`` / ``self.log_enable``.
        The model is not fetched back at the end.

        Args:
            server_dataloader: Yields ``(data, target)`` batches; its ``len``
                and ``batch_size`` are used for the progress message.
            round_no: Round number, used in log output only.
            epochs_num: Number of passes over ``server_dataloader``.
        """
        self.send_model(self.server_model, self.server, "server")
        server_opt = optim.SGD(self.server_model.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)
        for epoch_no in range(epochs_num):
            for batch_idx, (data, target) in enumerate(server_dataloader):
                self.server_model.train()
                data, target = data.to(self.device), target.to(self.device)
                server_opt.zero_grad()
                output = self.server_model(data)
                loss = F.nll_loss(output, target)
                loss.backward()
                server_opt.step()

                if batch_idx % self.log_interval == 0:
                    # The loss lives with the remote model; fetch it before
                    # reporting.
                    loss = loss.get()
                    if self.neptune_enable:
                        neptune.log_metric('train_w0_loss', loss)
                    if self.log_enable:
                        TO_FILE = '{} {} {} [server] {}\n'.format(
                            round_no, epoch_no, batch_idx, loss)
                        # The context manager closes the file even if the
                        # write fails.
                        with open(self.log_file_path + "server_train",
                                  "a") as log_file:
                            log_file.write(TO_FILE)
                    logging.info(
                        'Train Round: {}, Epoch: {} [server] [{}: {}/{} ({:.0f}%)]\tLoss: {:.6f}'
                        .format(
                            round_no, epoch_no, batch_idx,
                            batch_idx * server_dataloader.batch_size,
                            len(server_dataloader) *
                            server_dataloader.batch_size,
                            100. * batch_idx / len(server_dataloader),
                            loss.item()))
        # Always need to get back the model
        # self.getback_model(self.server_model)
        print()

    def train_workers(self, federated_train_loader, workers_model, round_no,
                      epochs_num):
        """Train each worker's model with SGD on that worker's data loader.

        Each model in ``workers_model`` is sent to ``self.workers[ww_id]``
        unless its ``location.id`` already equals its key, and gets its own
        SGD optimizer. Loaders whose key has no entry in ``workers_model`` are
        skipped. The models are not fetched back by this method.

        Every ``self.log_interval`` batches the loss is retrieved with
        ``get()`` and reported to neptune (if ``self.neptune_enable``), to the
        ``<worker_id>_train`` log file (if ``self.log_enable``) and to
        ``logging``.

        Args:
            federated_train_loader (dict): Maps worker id to an iterable of
                ``(data, target)`` batches exposing ``batch_size`` and
                ``len()``.
            workers_model (dict): Maps worker id to that worker's model.
            round_no: Round number, used only for logging.
            epochs_num (int): Number of passes over each loader.
        """
        workers_opt = {}
        for ww_id, ww_model in workers_model.items():
            if ww_model.location is None \
                    or ww_model.location.id != ww_id:
                ww_model.send(self.workers[ww_id])
            workers_opt[ww_id] = optim.SGD(params=ww_model.parameters(),
                                           lr=self.lr,
                                           weight_decay=self.weight_decay)

        for epoch_no in range(epochs_num):
            for ww_id, fed_dataloader in federated_train_loader.items():
                if ww_id not in workers_model:
                    continue
                for batch_idx, (data, target) in enumerate(fed_dataloader):
                    # The batch's own location decides which model/optimizer
                    # is used, not the loader key.
                    worker_id = data.location.id
                    worker_opt = workers_opt[worker_id]
                    workers_model[worker_id].train()
                    data, target = data.to(self.device), target.to(
                        self.device)
                    worker_opt.zero_grad()
                    output = workers_model[worker_id](data)
                    loss = F.nll_loss(output, target)
                    loss.backward()
                    worker_opt.step()

                    if batch_idx % self.log_interval != 0:
                        continue

                    # Bring the loss value back so it can be logged locally.
                    loss = loss.get()
                    if self.neptune_enable:
                        neptune.log_metric("train_loss_" + str(worker_id),
                                           loss)
                    if self.log_enable:
                        # Context manager closes the file even if write()
                        # fails.
                        with open(
                                self.log_file_path + str(worker_id) +
                                "_train", "a") as log_file:
                            log_file.write('{} {} {} {} {}\n'.format(
                                round_no, epoch_no, batch_idx, worker_id,
                                loss))
                    logging.info(
                        'Train Round: {}, Epoch: {} [{}] [{}: {}/{} ({:.0f}%)]\tLoss: {:.6f}'
                        .format(
                            round_no, epoch_no, worker_id, batch_idx,
                            batch_idx * fed_dataloader.batch_size,
                            len(fed_dataloader) *
                            fed_dataloader.batch_size,
                            100. * batch_idx / len(fed_dataloader),
                            loss.item()))
        # Blank line on stdout to separate training output.
        print()

    def save_workers_model(self, workers_idx, round_no):
        """Fetch the given workers' models and save each one to disk.

        Models are visited in the iteration order of ``self.workers_model``
        and written through ``save_model`` under the name
        ``R<round_no>_<worker_id>``.

        Args:
            workers_idx: Ids of the workers whose models are saved.
            round_no: Round number embedded in the saved file name.
        """
        self.getback_model(self.workers_model, workers_idx)
        logging.info("Saving models {}".format(workers_idx))
        for worker_id, worker_model in self.workers_model.items():
            if worker_id not in workers_idx:
                continue
            file_name = "R{}_{}".format(round_no, worker_id)
            self.save_model(worker_model, file_name)

    def save_model(self, model, name):
        """Serialize ``model`` with ``torch.save`` under the models directory.

        The target directory is ``self.log_file_path + "models"``; it is
        created on demand, together with any missing parent directories.

        Args:
            model: Object handed to ``torch.save``.
            name (str): File name inside the models directory.
        """
        parent_dir = "{}{}".format(self.log_file_path, "models")
        if not os.path.isdir(parent_dir):
            logging.debug("Create a directory for model(s).")
            # makedirs also creates missing parents (os.mkdir raised
            # FileNotFoundError in that case) and exist_ok tolerates the
            # directory appearing between the check and the call.
            os.makedirs(parent_dir, exist_ok=True)
        full_path = "{}/{}".format(parent_dir, name)
        logging.debug("Saving the model into " + full_path)
        torch.save(model, full_path)

    def test(self, model, test_loader, worker_id, round_no):
        """Evaluate ``model`` on ``test_loader`` and report loss and accuracy.

        The model is fetched back with ``getback_model`` and switched to eval
        mode. The summed NLL loss is divided by ``len(test_loader.dataset)``
        and the accuracy is expressed as a percentage of that same length.
        Results go to neptune (if ``self.neptune_enable``), to the
        ``<worker_id>_test`` log file (if ``self.log_enable``) and to
        ``logging``.

        Args:
            model: Model to evaluate; its output is fed to ``F.nll_loss``.
            test_loader: Iterable of ``(data, target)`` batches exposing a
                ``dataset`` attribute with a length.
            worker_id: Identifier used in metric and log-file names.
            round_no: Round number written to the log file.

        Returns:
            float: Accuracy in percent.
        """
        self.getback_model(model)
        model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(self.device), target.to(
                    self.device, dtype=torch.int64)
                output = model(data)
                # Sum up the batch loss; averaged over the dataset below.
                test_loss += F.nll_loss(output, target,
                                        reduction='sum').item()
                # Index of the max log-probability is the predicted class.
                pred = output.argmax(1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        sample_count = len(test_loader.dataset)
        test_loss /= sample_count
        test_acc = 100. * correct / sample_count

        if self.neptune_enable:
            neptune.log_metric("test_loss_" + str(worker_id), test_loss)
            neptune.log_metric("test_acc_" + str(worker_id), test_acc)
        if self.log_enable:
            # Context manager closes the file even if write() fails.
            with open(self.log_file_path + str(worker_id) + "_test",
                      "a") as log_file:
                log_file.write('{} {} "{{/*Accuracy:}}\\n{}%" {}\n'.format(
                    round_no, test_loss, test_acc, test_acc))

        logging.info(
            'Test Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
                test_loss, correct, sample_count, test_acc))
        return test_acc

    def wieghted_avg_model(self, W, workers_model):
        """Build a new ``FLNet`` holding the W-weighted sum of worker models.

        All worker models are fetched back first. The result starts from
        zeroed ``conv1``/``conv2``/``fc1``/``fc2`` weights and biases and
        accumulates ``W[k] * parameter`` for the k-th model in the iteration
        order of ``workers_model``.

        Args:
            W: Indexable collection of coefficients, one per entry of
                ``workers_model`` (matched by position).
            workers_model (dict): Maps worker id to model.

        Returns:
            FLNet: The combined model, created on ``self.device``.
        """
        self.getback_model(workers_model)
        tmp_model = FLNet().to(self.device)
        layer_names = ("conv1", "conv2", "fc1", "fc2")

        with torch.no_grad():
            # Start from an all-zero accumulator.
            for layer_name in layer_names:
                blank = getattr(tmp_model, layer_name)
                blank.weight.data.fill_(0)
                blank.bias.data.fill_(0)

            for position, worker_model in enumerate(workers_model.values()):
                coeff = W[position]
                for layer_name in layer_names:
                    acc = getattr(tmp_model, layer_name)
                    src = getattr(worker_model, layer_name)
                    acc.weight.data = (acc.weight.data +
                                       coeff * src.weight.data)
                    acc.bias.data = acc.bias.data + coeff * src.bias.data

        return tmp_model

    def update_models(self, workers_idx, weighted_avg_model):
        """Overwrite the selected models' parameters with the averaged ones.

        For every id in ``workers_idx`` the target is ``self.server_model``
        when the id is the string ``"server"`` and
        ``self.workers_model[id]`` otherwise. The target is fetched back and
        its ``conv1``/``conv2``/``fc1``/``fc2`` weights and biases are
        replaced by copies of those of ``weighted_avg_model``.

        Args:
            workers_idx: Ids of the models to update (may include
                ``"server"``).
            weighted_avg_model: Model providing the new parameter values.
        """
        self.getback_model(weighted_avg_model)
        layer_names = ("conv1", "conv2", "fc1", "fc2")
        with torch.no_grad():
            for worker_id in workers_idx:
                if worker_id == "server":
                    target_model = self.server_model
                else:
                    target_model = self.workers_model[worker_id]
                self.getback_model(target_model)

                for layer_name in layer_names:
                    source = getattr(weighted_avg_model, layer_name)
                    target = getattr(target_model, layer_name)
                    # set_() adopts the given tensor's storage. Cloning gives
                    # each model its own copy; without it every updated model
                    # aliased weighted_avg_model (and each other), so an
                    # in-place change to one silently changed all of them.
                    target.weight.set_(source.weight.data.clone())
                    target.bias.set_(source.bias.data.clone())

    def normalize_weights(self, list_of_ids, **kwargs):
        """Collect and scale the parameters of the given workers' models.

        Each model's parameters are flattened into column vectors (8 slots per
        model). For every slot the columns of all models are placed side by
        side, in the order of ``list_of_ids`` followed by the optional
        ``w0_model`` as the last column, and the resulting matrix is passed
        through ``MinMaxScaler().fit(...).transform(...)``.

        Args:
            list_of_ids: Ids of the workers (keys of ``self.workers_model``).
            **kwargs: Only ``w0_model`` is read: an optional extra model whose
                parameters are appended as the last column.

        Returns:
            list: 8 arrays, one per parameter slot, each of shape
            (number of elements in the slot, number of models).
        """
        self.getback_model(self.workers_model, list_of_ids)
        w0_model = kwargs.get("w0_model")

        def as_columns(net):
            # One (n, 1) column per parameter tensor, in parameters() order.
            columns = [[] for _ in range(8)]
            for slot, tensor in enumerate(net.parameters()):
                columns[slot] = tensor.data.numpy().reshape(-1, 1)
            return columns

        workers_params = {}
        for worker_id in list_of_ids:
            worker_model = self.workers_model[worker_id]
            self.getback_model(worker_model)
            workers_params[worker_id] = as_columns(worker_model)

        if w0_model is not None:
            workers_params['w0_model'] = as_columns(w0_model)

        # Seed every slot with an (n, 0) matrix that the columns are
        # appended to.
        workers_all_params = []
        for slot in range(8):
            row_count = workers_params[list_of_ids[0]][slot].shape[0]
            seed = np.array([]).reshape(row_count, 0)
            workers_all_params.append(seed)
            logging.debug("all_dparams: {}".format(seed.shape))

        for columns in workers_params.values():
            for slot in range(8):
                workers_all_params[slot] = np.concatenate(
                    (workers_all_params[slot], columns[slot]), 1)

        return [
            MinMaxScaler().fit(stacked).transform(stacked)
            for stacked in workers_all_params
        ]

    def find_best_weights(self, referenced_model, workers_to_be_used):
        """Solve for per-worker weights relative to a reference model.

        The normalised parameters of the workers and of ``referenced_model``
        are obtained from ``normalize_weights`` (the reference is the last
        column of every layer matrix). The objective minimised is the sum,
        over all layers, of the dot product between ``W`` and the per-worker
        L2 distances to the reference column, subject to ``0 <= W <= 1`` and
        ``sum(W) == 1``. The problem is solved with the MOSEK solver.

        Args:
            referenced_model: Model used as the reference.
            workers_to_be_used: Ids of the workers to weight.

        Returns:
            The value of ``W`` after solving, one entry per worker in
            ``workers_to_be_used``.
        """
        # Last column of each normalised layer corresponds to the w0_model.
        normalized_weights = self.normalize_weights(workers_to_be_used,
                                                    w0_model=referenced_model)

        W = cp.Variable(len(workers_to_be_used))
        total_distance = None
        for layer in normalized_weights:
            reference_column = layer[:, -1].reshape(-1, 1)
            workers_columns = layer[:, :-1]
            # Broadcasting subtracts the reference from every worker column,
            # replacing the explicit column-by-column replication.
            term = cp.matmul(
                cp.norm2(workers_columns - reference_column, axis=0), W)
            total_distance = (term if total_distance is None else
                              total_distance + term)
        objective = cp.Minimize(total_distance)

        constraints = [0 <= W, W <= 1, sum(W) == 1]
        prob = cp.Problem(objective, constraints)
        prob.solve(solver=cp.MOSEK)
        logging.info(W.value)
        logging.info("")
        if self.log_enable:
            # Context manager closes the file even if write() fails.
            with open(self.log_file_path + "opt_weights", "a") as log_file:
                log_file.write('{}\n'.format(
                    np.array2string(W.value).replace('\n', '')))
        return W.value

########################################################################
##################### Trusted Users ####################################

    def get_average_param(self, all_params, indexes):
        """Average the selected columns of every layer matrix.

        Args:
            all_params: Sequence of 2-D arrays, one per layer, whose columns
                are indexed by the entries of ``indexes``.
            indexes: Column positions to average over.

        Returns:
            list: One 1-D array per layer holding the row-wise mean of the
            selected columns.
        """
        logging.info("Finding an average model for {} workers.".format(
            len(indexes)))
        avg_param = []
        for layer_no, layer in enumerate(all_params):
            # Stack the chosen columns as rows, then flip back to columns.
            selected = np.array([layer[:, col] for col in indexes]).T
            logging.debug("params[{}] shape: {}".format(
                layer_no, selected.shape))
            avg_param.append(np.mean(selected, axis=1))

        return avg_param

    def get_index_number(self, workers_idx, selected_idx):
        """Return the positions in ``workers_idx`` of the selected ids.

        Args:
            workers_idx: Ordered collection of worker ids.
            selected_idx: Container of ids to look for.

        Returns:
            list[int]: Positions, in ascending order, of every element of
            ``workers_idx`` that is contained in ``selected_idx``.
        """
        return [
            position for position, worker in enumerate(workers_idx)
            if worker in selected_idx
        ]

    def find_best_weights_from_trusted_idx_normalized_last_layer(
            self, workers_idx, trusted_idx):
        """Weight the non-trusted workers against the trusted average.

        The normalised parameters of all workers are obtained from
        ``normalize_weights``; the trusted workers' columns are averaged with
        ``get_average_param``. Using only the layer at index 7, the weights
        ``W`` of the remaining workers are chosen to minimise the L2 norm of
        ``params @ W - trusted_average`` subject to ``0 <= W <= 1`` and
        ``sum(W) == 1``. The problem is solved with the MOSEK solver.

        Args:
            workers_idx (list[str]): Ids of all workers.
            trusted_idx (list[str]): Ids of the trusted workers.

        Returns:
            The value of ``W`` after solving. Entries follow the order in
            which the non-trusted workers appear in ``workers_idx``.
        """
        trusted_workers_position = self.get_index_number(
            workers_idx, trusted_idx)
        # len(all_params) == 8;
        # all_params[k].shape == (num of elements in layer k, num of users)
        all_params = self.normalize_weights(workers_idx)
        avg_param = self.get_average_param(all_params,
                                           trusted_workers_position)

        # Not trusted users (i.e. normal users + attackers). The set
        # difference has no defined order, so positions are recomputed
        # against workers_idx.
        workers_to_be_used = list(set(workers_idx) - set(trusted_idx))
        workers_to_be_used_position = self.get_index_number(
            workers_idx, workers_to_be_used)

        # Only layer index 7 enters the objective, so only its matrix is
        # built. NOTE(review): presumably the last layer's bias, per the
        # method name -- confirm against FLNet's parameter order.
        last_layer_params = np.transpose(
            np.array([
                all_params[7][:, i] for i in workers_to_be_used_position
            ]))

        W = cp.Variable(len(workers_to_be_used))

        objective = cp.Minimize(
            cp.norm2(cp.matmul(last_layer_params, W) - avg_param[7]))

        constraints = [0 <= W, W <= 1, sum(W) == 1]
        prob = cp.Problem(objective, constraints)
        prob.solve(solver=cp.MOSEK)
        logging.info(W.value)
        logging.info("")
        if self.log_enable:
            # Context manager closes the file even if write() fails.
            with open(self.log_file_path + "opt_weights", "a") as log_file:
                log_file.write('{}\n'.format(
                    np.array2string(W.value).replace('\n', '')))
        return W.value

    def get_average_model(self, workers_idx):
        """Build a new ``FLNet`` whose parameters average the given workers'.

        The selected workers' models are fetched back, their
        ``conv1``/``conv2``/``fc1``/``fc2`` weights and biases are summed into
        a fresh model, and every parameter of that model is then divided by
        ``len(workers_idx)``.

        Args:
            workers_idx (list[str]): List of workers ids
        Returns:
            FLNet: The averaged model, created on ``self.device``.
        """
        logging.info("Finding an average model for {} workers.".format(
            len(workers_idx)))
        tmp_model = FLNet().to(self.device)
        self.getback_model(self.workers_model, workers_idx)

        layer_names = ("conv1", "conv2", "fc1", "fc2")
        with torch.no_grad():
            for position, ww_id in enumerate(workers_idx):
                worker_model = self.workers_model[ww_id]
                is_first = position == 0
                for layer_name in layer_names:
                    acc = getattr(tmp_model, layer_name)
                    src = getattr(worker_model, layer_name)
                    if is_first:
                        # The first worker initialises the running sum.
                        acc.weight.set_(src.weight.data)
                        acc.bias.set_(src.bias.data)
                    else:
                        acc.weight.set_(acc.weight.data + src.weight.data)
                        acc.bias.set_(acc.bias.data + src.bias.data)

        worker_count = len(workers_idx)
        for param in tmp_model.parameters():
            param.data = param.data / worker_count

        return tmp_model