Example #1
0
    def get_average_model(self, workers_idx):
        """Return a new model whose parameters are the element-wise average
        of the given workers' models.

        Improvements over the previous version: iterates over ``state_dict``
        entries instead of naming each layer (conv1/conv2/fc1/fc2), so the
        code does not need editing when FLNet's architecture changes, and it
        avoids ``Tensor.set_``, which made the temporary model share storage
        with the first worker's tensors.

        Args:
            workers_idx (list[str]): List of worker ids to average over.

        Returns:
            FLNet: a fresh model on ``self.device`` holding the averaged
            parameters.

        Raises:
            ValueError: if ``workers_idx`` is empty (the old code raised
            ``ZeroDivisionError`` in that case).
        """
        logging.info("Finding an average model for {} workers.".format(
            len(workers_idx)))
        if not workers_idx:
            raise ValueError("workers_idx must not be empty")
        tmp_model = FLNet().to(self.device)
        self.getback_model(self.workers_model, workers_idx)

        with torch.no_grad():
            avg_state = None
            for ww_id in workers_idx:
                state = self.workers_model[ww_id].state_dict()
                if avg_state is None:
                    # Clone so the running sums never alias worker storage.
                    avg_state = {
                        key: val.detach().clone()
                        for key, val in state.items()
                    }
                else:
                    for key in avg_state:
                        avg_state[key] += state[key]
            for key in avg_state:
                avg_state[key] = avg_state[key] / len(workers_idx)
            tmp_model.load_state_dict(avg_state)

        return tmp_model
Example #2
0
    def wieghted_avg_model(self, W, workers_model):
        """Return a new model whose parameters are the weighted sum of the
        given workers' model parameters.

        ``W`` is indexed by the enumeration order of ``workers_model``
        (dict insertion order), exactly as before. The accumulation now
        iterates over ``state_dict`` entries generically instead of naming
        each FLNet layer, so the method survives architecture changes.
        (Method name typo is kept: it is the public interface.)

        Args:
            W (sequence[float]): per-worker weights aligned with the
                iteration order of ``workers_model``.
            workers_model (dict): worker id -> model.

        Returns:
            FLNet: a fresh model on ``self.device`` holding the weighted
            average.
        """
        self.getback_model(workers_model)
        tmp_model = FLNet().to(self.device)
        with torch.no_grad():
            # Zero-initialised accumulators, one per state_dict entry
            # (replaces the eight explicit fill_(0) calls).
            acc = {
                key: torch.zeros_like(val)
                for key, val in tmp_model.state_dict().items()
            }
            for counter, (ww_id,
                          worker_model) in enumerate(workers_model.items()):
                for key, val in worker_model.state_dict().items():
                    acc[key] += W[counter] * val
            tmp_model.load_state_dict(acc)

        return tmp_model
Example #3
0
def train_workers_with_attack(federated_train_loader, models, workers_idx,
                              attackers_idx, round_no, args):
    """Train each selected worker's model for ``args.epochs`` epochs,
    applying model-poisoning attacks for workers listed in ``attackers_idx``.

    Attack types (``args.attack_type``):
        1 -- replace the worker's model with a freshly initialised FLNet.
        2 -- negate all of the worker's model parameters
             (via ``utils.negative_parameters``).
    Attackers skip actual training: their model is poisoned on the first
    batch and the batch loop breaks for that epoch.

    Args:
        federated_train_loader (dict): worker id -> federated dataloader.
        models (dict): worker id -> model; mutated in place for attackers.
        workers_idx (list[str]): ids of workers participating this round.
        attackers_idx (list[str]): ids of attacker workers.
        round_no (int): current round number (progress-bar display only).
        args: run configuration (lr, weight_decay, epochs, attack_type,
            device, ...).

    Returns:
        dict: worker id -> mean training loss over recorded batches.
        Attackers never record a loss, so they are absent from the result.
    """
    attackers_here = [ii for ii in workers_idx if ii in attackers_idx]
    workers_opt = dict()
    # BUG FIX: was `workers_loss = data = defaultdict(lambda: [])`, which
    # accidentally aliased `data` (immediately shadowed by the batch-loop
    # unpacking below) to the loss dict.
    workers_loss = defaultdict(list)
    for ii in workers_idx:
        workers_opt[ii] = torch.optim.SGD(params=models[ii].parameters(),
                                          lr=args.lr,
                                          weight_decay=args.weight_decay)
    with tqdm(total=args.epochs,
              leave=False,
              colour="yellow",
              ncols=80,
              desc="Epoch\t",
              bar_format=TQDM_R_BAR) as t2:
        for epoch in range(args.epochs):
            t2.set_postfix(Rounds=round_no, Epochs=epoch)
            with tqdm(total=len(workers_idx),
                      ncols=80,
                      desc="Workers\t",
                      leave=False,
                      bar_format=TQDM_R_BAR) as t3:
                t3.set_postfix(ordered_dict={
                    'ATK':
                    "{}/{}".format(len(attackers_here), len(workers_idx))
                })
                for ww_id, fed_dataloader in federated_train_loader.items():
                    if ww_id in workers_idx:
                        with tqdm(total=len(fed_dataloader),
                                  ncols=80,
                                  colour='red',
                                  desc="Batch\t",
                                  leave=False,
                                  bar_format=TQDM_R_BAR) as t4:
                            for batch_idx, (
                                    data, target) in enumerate(fed_dataloader):
                                ww = data.location
                                model = models[ww.id]
                                data, target = data.to("cpu"), target.to("cpu")
                                if ww.id in attackers_idx:
                                    # Poison the model instead of training.
                                    if args.attack_type == 1:
                                        models[ww.id] = FLNet().to(args.device)
                                    elif args.attack_type == 2:
                                        ss = utils.negative_parameters(
                                            models[ww.id].state_dict())
                                        models[ww.id].load_state_dict(ss)
                                    t4.set_postfix(
                                        ordered_dict={
                                            'Worker':
                                            ww.id,
                                            'ATK':
                                            "[T]" if ww.id in
                                            attackers_idx else "[F]",
                                            'BatchID':
                                            batch_idx,
                                            'Loss':
                                            '-'
                                        })
                                    #TODO: Be careful about the break
                                    break
                                else:
                                    model.train()
                                    model.send(ww.id)
                                    opt = workers_opt[ww.id]
                                    opt.zero_grad()
                                    output = model(data)
                                    loss = F.nll_loss(output, target)
                                    loss.backward()
                                    opt.step()
                                    model.get()  # <-- NEW: get the model back
                                    loss = loss.get(
                                    )  # <-- NEW: get the loss back
                                    workers_loss[ww.id].append(loss.item())
                                    t4.set_postfix(
                                        ordered_dict={
                                            'Worker':
                                            ww.id,
                                            'ATK':
                                            "[T]" if ww.id in
                                            attackers_idx else "[F]",
                                            'BatchID':
                                            batch_idx,
                                            'Loss':
                                            loss.item()
                                        })
                                t4.update()
                        t3.update()
            t2.update()

    # Mean per worker; keys exist only for workers that recorded a loss,
    # so no division by zero is possible here.
    for ii in workers_loss:
        workers_loss[ii] = sum(workers_loss[ii]) / len(workers_loss[ii])
    return workers_loss
Example #4
0
def main(start_round):
    """Run one breakdown of federated-learning rounds starting at *start_round*.

    Loads precomputed artifacts (attackers list, per-worker datasets, server
    public dataset) from ``args.log_dir`` — this study requires all of them
    to exist and exits otherwise — then, for each round: selects workers,
    trains them (with attacks), aggregates their models into the server
    model, and optionally logs/saves results.

    Args:
        start_round: round number to resume from (string or int).

    Returns:
        int: the exclusive end round of this breakdown (``round_end``),
        i.e. the next starting point.
    """
    logging.info("Total number of users: {}".format(args.total_users_num))
    workers_idx = ["worker_" + str(i) for i in range(args.total_users_num)]
    workers = create_workers(hook, workers_idx)
    server = create_workers(hook, ['server'])
    server = server['server']
    # if args.local_log:
    #     utils.check_write_to_file(args.log_dir, "all_users", workers_idx)

    # Attackers list must have been produced by an earlier run.
    attackers_idx = None
    if utils.find_file(args.log_dir, "attackers"):
        logging.info("attackers list was found. Loading from file...")
        attackers_idx = utils.load_object(args.log_dir, "attackers")
    else:
        logging.error("This should not be happened in this study.")
        exit(1)
    #     attackers_idx = utils.get_workers_idx(workers_idx, args.attackers_num, [])
    #     if args.local_log:
    #         utils.save_object(args.log_dir, "attackers", attackers_idx)

    # Per-worker datasets must also be cached on disk (code that originally
    # built them is kept below, commented out).
    mapped_datasets = dict()
    if utils.find_file(args.log_dir, "mapped_datasets"):
        logging.info("mapped_datasets was found. Loading from file...")
        mapped_datasets = utils.load_object(args.log_dir, "mapped_datasets")
    else:
        logging.error("This should not be happened in this study.")
        exit(1)
        # Now sort the dataset and distribute among users
        # mapped_ds_itr = utils.map_shards_to_worker(
        #     utils.split_randomly_dataset(
        #         utils.sort_mnist_dataset(
        #             utils.fraction_of_datasets(
        #                 {"dataset": utils.load_mnist_dataset(
        #                     train=True,
        #                     transform=transforms.Compose([transforms.ToTensor(),]))},
        #                 args.load_fraction, [])
        #         ),
        #         args.shards_num),
        #     workers_idx,
        #     args.shards_per_worker_num)

        # mapping to users and performin attacks
        # for mapped_ds in mapped_ds_itr:
        #     for ww_id, dataset in mapped_ds.items():
        #         if ww_id in attackers_idx:
        #             mapped_datasets.update(
        #                 {ww_id: FLCustomDataset(
        #                     utils.attack_shuffle_pixels(dataset.data),
        #                     dataset.targets,
        #                     transform=transforms.Compose([
        #                         transforms.ToTensor()])
        #                 )}
        #             )
        #         else:
        #             mapped_datasets.update(mapped_ds)

        # if args.local_log:
        #     utils.save_object(args.log_dir, "mapped_datasets", mapped_datasets)

    # Server-side public dataset, also required to be precomputed.
    server_pub_dataset = None
    if utils.find_file(args.log_dir, "server_pub_dataset"):
        logging.info("server_pub_dataset was found. Loading from file...")
        server_pub_dataset = utils.load_object(args.log_dir,
                                               "server_pub_dataset")
    else:
        logging.error("This should not be happened in this study.")
        exit(1)
        # if args.server_pure:
        #     server_pub_dataset = utils.fraction_of_datasets(mapped_datasets, args.server_data_fraction)
        # else:
        #     logging.info("Server data is NOT pure.")
        #     server_pub_dataset = utils.fraction_of_datasets(
        #         mapped_datasets, args.server_data_fraction, attackers_idx)
        # if args.local_log:
        #     utils.save_object(args.log_dir, "server_pub_dataset", server_pub_dataset)

    # Federated dataloaders: one for the server's public data, one per worker.
    federated_server_loader = dict()
    federated_server_loader['server'] = sy.FederatedDataLoader(
        server_pub_dataset.federate([server]),
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False)

    federated_train_loader = dict()
    logging.info("Creating federated dataloaders for workers...")
    for ww_id, fed_dataset in mapped_datasets.items():
        federated_train_loader[ww_id] = sy.FederatedDataLoader(
            fed_dataset.federate([workers[ww_id]]),
            batch_size=args.batch_size,
            shuffle=False,
            drop_last=False)

    test_loader = utils.get_dataloader(utils.load_mnist_dataset(
        train=False, transform=transforms.Compose([
            transforms.ToTensor(),
        ])),
                                       args.test_batch_size,
                                       shuffle=True,
                                       drop_last=False)

    previous_round = int(start_round)
    logging.info("Use explicit starting round number: {}".format(start_round))
    # if utils.find_file(args.log_dir, "accuracy"):
    #     previous_round = int(utils.get_last_round_num(args.log_dir, "accuracy"))
    #     logging.info("Previous complete execution was found. Last run is: {}".format(previous_round))

    # ROUNDS_BREAKDOWN caps how many rounds a single invocation runs;
    # round 0 restarts at itself, later rounds resume at previous + 1.
    round_start = previous_round if previous_round == 0 else previous_round + 1
    round_end = round_start + ROUNDS_BREAKDOWN if round_start + ROUNDS_BREAKDOWN < args.rounds else args.rounds

    # Resume the server model from "R<round>_server_model" if it was saved.
    server_model, server_model_name = FLNet().to(
        args.device), "R{}_server_model".format(previous_round)
    server_model_path = args.log_dir + "models"
    if utils.find_file(server_model_path, server_model_name):
        logging.info(
            "server_model was found. Loading {} from the file...".format(
                server_model_name))
        server_model.load_state_dict(
            torch.load(server_model_path + "/" + server_model_name))
    server_opt = dict()
    server_opt['server'] = torch.optim.SGD(params=server_model.parameters(),
                                           lr=args.lr,
                                           weight_decay=args.weight_decay)
    test_loss, test_acc = 0.0, 0.0
    with tqdm(total=min(ROUNDS_BREAKDOWN, args.rounds - round_start),
              leave=True,
              colour="green",
              ncols=80,
              desc="Round\t",
              bar_format=TQDM_R_BAR) as t1:
        for round_no in range(round_start, round_end):
            # Each round trains a fresh copy of the server model on a random
            # subset of workers.
            workers_to_be_used = random.sample(workers_idx,
                                               args.selected_users_num)
            workers_model = dict()
            for ww_id in workers_to_be_used:
                workers_model[ww_id] = deepcopy(server_model)

            # logging.info("Workers for this round: {}".format(workers_to_be_used))
            if args.local_log:
                utils.save_object(
                    args.log_dir,
                    "R{}_p_pca_workers".format(round_no) if args.server_pure else \
                            "R{}_np_pca_workers".format(round_no) ,
                    workers_to_be_used
                )
            train_loss = train_workers_with_attack(federated_train_loader,
                                                   workers_model,
                                                   workers_to_be_used,
                                                   attackers_idx, round_no,
                                                   args)

            # Find the best weights and update the server model
            weights = dict()
            if args.mode == "avg":
                # Each worker takes two shards of 300 random.samples. Total of 600 random.samples
                # per worker. Total number of random.samples is 60000.
                # weights = [600.0 / 60000] * args.selected_users_num
                for ww_id in workers_to_be_used:
                    weights[ww_id] = 1.0 / args.selected_users_num
                train_loss = sum(train_loss.values()) / len(train_loss)
            elif args.mode == "opt":
                # models should be returned from the workers before calling the following functions:
                # Train server
                train_workers_with_attack(federated_server_loader,
                                          {'server': server_model}, ['server'],
                                          [], round_no, args)
                pass

                weights = utils.find_best_weights_opt(server_model,
                                                      workers_model)

                # NOTE(review): assumes every key of train_loss appears in
                # weights — verify against find_best_weights_opt.
                loss = 0
                for ww in train_loss:
                    loss += weights[ww] * train_loss[ww]
                train_loss = loss
                # if args.local_log:
                #     utils.write_to_file(args.log_dir, "opt_weights", weights, round_no=round_no)

            # logging.info("Update server model in this round...")
            server_model.load_state_dict(
                utils.wieghted_avg_model(weights, workers_model,
                                         workers_to_be_used))

            # Apply the server model to the test dataset
            # logging.info("Starting model evaluation on the test dataset...")
            # test_loss, test_acc = test(server_model, test_loader, round_no, args)

            if args.local_log:
                # utils.write_to_file(args.log_dir, "train_loss", train_loss, round_no=round_no)
                # utils.save_model(
                #     server_model.state_dict(),
                #     "{}/{}".format(args.log_dir, "models"),
                #     "pca_niid_np_{}_{}_server_R{}".format(args.mode, args.attackers_num, round_no) if not args.server_pure else \
                #             "pca_niid_p_{}_{}_server_R{}".format(args.mode, args.attackers_num, round_no)
                # )
                for ww_id, ww_model in workers_model.items():
                    utils.save_model(
                        ww_model.state_dict(),
                        "{}/{}/workers_p_R{}".format(args.log_dir, "models", round_no) \
                            if args.server_pure else \
                                "{}/{}/workers_np_R{}".format(args.log_dir, "models", round_no),
                        "{}_model".format(ww_id)
                    )
            if args.neptune_log:
                neptune.log_metric("train_loss", train_loss)

            print()
            # NOTE(review): test() above is commented out, so these stay at
            # their initial 0.0 values.
            logging.info('Test Average loss: {:.4f}, Accuracy: {:.0f}%'.format(
                test_loss, test_acc))
            print()

            t1.set_postfix(test_acc=test_acc, test_loss=test_loss)
            t1.update()
    return round_end
Example #5
0
 def create_model(self):
     # Instantiate a fresh FLNet and hand it back on the configured device.
     logging.info("Creating a model...")
     net = FLNet()
     return net.to(self.device)
Example #6
0
 def create_server_model(self):
     # Build a new FLNet on the configured device and keep it as the
     # server-side model.
     logging.info("Creating a model for the server...")
     model = FLNet()
     self.server_model = model.to(self.device)