Example No. 1
def get_dataset(args):
    """
    Returns train and test datasets and a user group which is a dict where
    the keys are the user index and the values are the corresponding data for
    each of those users.
    """
    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'

        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        train_dataset = datasets.CIFAR10(data_dir,
                                         train=True,
                                         download=True,
                                         transform=apply_transform)

        test_dataset = datasets.CIFAR10(data_dir,
                                        train=False,
                                        download=True,
                                        transform=apply_transform)

        # sample training data amongst users
        if args.iid:  # Sample IID user data from CIFAR
            user_groups = cifar_iid(train_dataset, args.num_users)
        else:  # Sample non-IID user data from CIFAR
            user_groups = cifar_noniid(train_dataset, args.num_users)

    elif args.dataset in ('mnist', 'fmnist'):
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
        else:
            data_dir = '../data/fmnist/'

        apply_transform = transforms.Compose([
            # transforms.Lambda(lambda x: AddGaussianNoise(x)),
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])
        train_dataset = datasets.MNIST(data_dir,
                                       train=True,
                                       download=True,
                                       transform=apply_transform)
        test_dataset = datasets.MNIST(data_dir,
                                      train=False,
                                      download=True,
                                      transform=apply_transform)

        if args.iid:  # Sample IID user data from MNIST
            user_groups = mnist_iid(train_dataset, args.num_users)
        else:  # Sample Non-IID user data from Mnist
            user_groups = mnist_noniid(train_dataset, args.num_users)

    return train_dataset, test_dataset, user_groups
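
A minimal usage sketch of get_dataset, assuming an argparse-style args object that provides the dataset, iid and num_users fields read above (the concrete values here are hypothetical, not from the original code):

from types import SimpleNamespace

args = SimpleNamespace(dataset='cifar', iid=True, num_users=100)
train_dataset, test_dataset, user_groups = get_dataset(args)

# user_groups maps each user index to that user's share of the training data
# (typically a set/array of sample indices produced by cifar_iid / cifar_noniid)
print(len(train_dataset), len(test_dataset), len(user_groups))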
Example No. 2
def getdataset(args, conf_dict):
    N_parties = conf_dict["N_parties"]
    N_samples_per_class = conf_dict["N_samples_per_class"]

    public_classes = conf_dict["public_classes"]
    private_classes = conf_dict["private_classes"]

    # Load the datasets
    if args.dataset == 'cifar':
        X_train_public, y_train_public, X_test_public, y_test_public = get_dataarray(
            args, dataset='cifar10')
        public_dataset = {"X": X_train_public, "y": y_train_public}

        X_train_private, y_train_private, X_test_private, y_test_private\
            = get_dataarray(args,dataset='cifar100')

    elif args.dataset == 'mnist':
        X_train_public, y_train_public, X_test_public, y_test_public = get_dataarray(
            args, dataset='mnist')
        public_dataset = {"X": X_train_public, "y": y_train_public}

        X_train_private, y_train_private, X_test_private, y_test_private \
            = get_dataarray(args,dataset='fmnist')

        y_train_private += len(
            public_classes)  # shift every label value by the constant len(public_classes)
        y_test_private += len(public_classes)

    # only keep samples whose labels belong to private_classes
    # sampling
    if True:  # always taken here; the else branch below is unreachable and kept only for reference
        X_train_private, y_train_private \
            = generate_partial_data(X=X_train_private, y=y_train_private,       # keep only the data belonging to private_classes
                                    class_in_use=private_classes,
                                    verbose=True)

        X_test_private, y_test_private \
            = generate_partial_data(X=X_test_private, y=y_test_private,
                                    class_in_use=private_classes,
                                    verbose=True)

        # relabel the selected private data for future convenience
        for index, cls_ in enumerate(private_classes):
            y_train_private[y_train_private == cls_] = index + len(
                public_classes)  # boolean-mask indexing
            y_test_private[y_test_private ==
                           cls_] = index + len(public_classes)
        del index, cls_

        print(pd.Series(y_train_private).value_counts())
        mod_private_classes = np.arange(len(private_classes)) + len(
            public_classes
        )  # mod_private_classes = [0 .. len(private_classes)-1], each element shifted by the constant len(public_classes)

        #------- 10 classes --- 4 model types, 5 clients per model type, 20 clients in total --- N_parties=20 --- len(private_classes)=10 ------- 10~12 ---------- 13~15 -------
        ## build mod_private_classes[0-19] with a single for loop ##
        # mod_private_classes = []
        # for idx in range(N_parties):
        #     #mod_private_classes.append(len(public_classes)+idx*(len(private_classes)//5)+np.arange(len(private_classes)//5))
        #     mod_private_classes.append(np.arange(len(private_classes)//5))
        #     last_random_num = -1
        #     for _ in mod_private_classes[idx]:
        #         while True:
        #             mod_private_classes[idx][_]=random.choice(range(10)) + len(public_classes)
        #             if mod_private_classes[idx][_] != last_random_num:   # only guards against drawing the same number twice in a row, i.e. it only works when each client takes two private classes
        #                 break
        #         last_random_num = mod_private_classes[idx][_]
        # del idx

        print('mod_private_classes:', mod_private_classes)

        # mod_private_classes1 = np.arange(len(private_classes)//2) + len(public_classes)
        # mod_private_classes2 = np.arange(len(private_classes)//2) + len(public_classes) + len(private_classes)//2
        # print("mod_private_class1:", mod_private_classes1)
        # print("mod_private_class2:", mod_private_classes2)
        #---------------------------------------

        # print("=" * 60)
        # # generate private data
        # private_data, total_private_data \
        #     = generate_bal_private_data(X_train_private, y_train_private,
        #                                 N_parties=N_parties,
        #                                 classes_in_use=mod_private_classes,
        #                                 N_samples_per_class=N_samples_per_class,
        #                                 data_overlap=False)

        print("=" * 60)
        ##### call generate_bal_private_data with mod_private_classes inside a for loop #####
        ##### change here for the second non-IID distribution scheme #####
        private_data_tmp = [None] * N_parties
        total_private_data_tmp = {}
        for idx in range(N_parties // 4):  #### 20 clients, 4 model types, split into 5 groups by data distribution; each group contains all 4 model types ####
            private_data_tmp[idx], total_private_data_tmp[idx] \
                = generate_bal_private_data(X_train_private, y_train_private,
                                            N_parties=4,
                                            classes_in_use=mod_private_classes,
                                            N_samples_per_class=idx,
                                            data_overlap=False)
        del idx
        ##### merge the lists and dicts #####
        private_data = []
        #print('private_data_tmp[0]:', private_data_tmp[0])
        for idx in range(N_parties // 4):
            #print('idx:', idx)
            private_data = private_data + private_data_tmp[idx]
            #total_private_data = dict(total_private_data, **total_private_data_tmp[idx])
        del idx

        dicts = []
        total_private_data = {}
        for idx in range(N_parties // 4):
            dicts.append(total_private_data_tmp[idx])
        for d in dicts:
            for k, v in d.items():
                #total_private_data.setdefault(k, []).append(v)
                try:
                    total_private_data.setdefault(k, []).extend(v)
                except TypeError:
                    total_private_data[k].append(v)
        ####################################

        # # generate private data
        # private_data1, total_private_data1 \
        #     = generate_bal_private_data(X_train_private, y_train_private,
        #                                 N_parties=N_parties//2,
        #                                 classes_in_use=mod_private_classes1,
        #                                 N_samples_per_class=N_samples_per_class,
        #                                 data_overlap=False)

        # print("=" * 60)
        # # generate private data
        # private_data2, total_private_data2 \
        #     = generate_bal_private_data(X_train_private, y_train_private,
        #                                 N_parties=N_parties//2,
        #                                 classes_in_use=mod_private_classes2,
        #                                 N_samples_per_class=N_samples_per_class,
        #                                 data_overlap=False)

        # #---------- merge the two lists ----------
        # private_data = private_data1 + private_data2
        # total_private_data = dict(total_private_data1, **total_private_data2)
        # #total_private_data = dict(total_private_data2, **total_private_data1)

###########################################################################################################################
# print("=" * 60)
# X_tmp, y_tmp = generate_partial_data(X=X_test_private, y=y_test_private,    # mainly mod_private_classes: relabel the selected private data
#                                      class_in_use=mod_private_classes,
#                                      verbose=True)
# private_test_data = {"X": X_tmp, "y": y_tmp}

        print("=" * 60)
        ##########################################################
        private_test_data_tmp = {}
        for idx in range(N_parties):
            X_tmp, y_tmp = generate_partial_data(
                X=X_test_private,
                y=y_test_private,  # mainly mod_private_classes: relabel the selected private data
                class_in_use=mod_private_classes,
                verbose=True)
            private_test_data_tmp[idx] = {
                "X{}".format(idx): X_tmp,
                "y{}".format(idx): y_tmp
            }
        del idx

        dicts_private_test_data = []
        private_test_data = {}
        for idx in range(N_parties):
            dicts_private_test_data.append(private_test_data_tmp[idx])
        for d in dicts_private_test_data:
            for k, v in d.items():
                #private_test_data.setdefault(k, []).append(v)
                try:
                    private_test_data.setdefault(k, []).extend(v)
                except TypeError:
                    private_test_data[k].append(v)
        # X_tmp1, y_tmp1 = generate_partial_data(X=X_test_private, y=y_test_private,    # mainly mod_private_classes: relabel the selected private data
        #                                      class_in_use=mod_private_classes1,
        #                                      verbose=True)
        # private_test_data1 = {"X1": X_tmp1, "y1": y_tmp1}

        # X_tmp2, y_tmp2 = generate_partial_data(X=X_test_private, y=y_test_private,    # mainly mod_private_classes: relabel the selected private data
        #                                      class_in_use=mod_private_classes2,
        #                                      verbose=True)
        # private_test_data2 = {"X2": X_tmp2, "y2": y_tmp2}

        # private_test_data = dict(private_test_data1, **private_test_data2)
        #private_test_data = dict(private_test_data2, **private_test_data1)

    else:
        X_train_private, y_train_private \
            = generate_partial_data(X=X_train_private, y=y_train_private,
                                    class_in_use=private_classes,
                                    verbose=True)

        X_test_private, y_test_private\
            = generate_partial_data(X=X_test_private, y=y_test_private,
                                    class_in_use=private_classes,
                                    verbose=True)

        # relabel the selected private data for future convenience
        for index, cls_ in enumerate(private_classes):
            y_train_private[y_train_private ==
                            cls_] = index + len(public_classes)
            y_test_private[y_test_private ==
                           cls_] = index + len(public_classes)
        del index, cls_

        # print(pd.Series(y_train_private).value_counts())
        mod_private_classes = np.arange(
            len(private_classes)) + len(public_classes)

        print("=" * 60)
        # generate private data
        if args.dataset == 'cifar':
            users_index = cifar_noniid(y_train_private, N_parties)

            private_data, total_private_data \
                = get_sample_data(X_train_private, y_train_private,users_index,N_samples_per_class*18)
        else:
            users_index = mnist_noniid(y_train_private, N_parties)

            private_data, total_private_data \
                = get_sample_data(X_train_private, y_train_private,users_index,N_samples_per_class*6)

        print("=" * 60)
        X_tmp, y_tmp = generate_partial_data(X=X_test_private,
                                             y=y_test_private,
                                             class_in_use=mod_private_classes,
                                             verbose=True)
        private_test_data = {"X": X_tmp, "y": y_tmp}


    return [X_train_public, y_train_public,X_test_public, y_test_public],\
           [public_dataset,private_data, total_private_data,private_test_data]
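
The merging loops above combine the per-group dictionaries by concatenating the values stored under each key. A small self-contained illustration of that setdefault/extend pattern (standalone example, not part of the original code):

group_dicts = [
    {"X": [1, 2], "y": [0, 0]},
    {"X": [3, 4], "y": [1, 1]},
]

merged = {}
for d in group_dicts:
    for k, v in d.items():
        # create an empty list for a new key, then append this group's values
        merged.setdefault(k, []).extend(v)

print(merged)  # {'X': [1, 2, 3, 4], 'y': [0, 0, 1, 1]}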
Example No. 3
def main(args):
    #####-Choose Variable-#####
    set_variable = args.set_num_Chosenusers
    set_variable0 = copy.deepcopy(args.set_epochs)
    set_variable1 = copy.deepcopy(args.set_degree_noniid)

    if not os.path.exists('./experiresult'):
        os.mkdir('./experiresult')

    # load dataset and split users
    dict_users, dict_users_train, dict_users_test = {}, {}, {}
    dataset_train, dataset_test = [], []
    if args.dataset == 'mnist':
        dataset_train = datasets.MNIST('./dataset/mnist/',
                                       train=True,
                                       download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307, ),
                                                                (0.3081, ))
                                       ]))
        dataset_test = datasets.MNIST('./dataset/mnist/',
                                      train=False,
                                      download=True,
                                      transform=transforms.Compose([
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.1307, ),
                                                               (0.3081, ))
                                      ]))
        # sample users
        if args.iid:
            dict_users = mnist_iid(args, dataset_train, args.num_users,
                                   args.num_items_train)
            # dict_users_test = mnist_iid(dataset_test, args.num_users, args.num_items_test)
            dict_sever = mnist_iid(args, dataset_test, args.num_users,
                                   args.num_items_test)
        else:
            dict_users = mnist_noniid(args, dataset_train, args.num_users,
                                      args.num_items_train)
            dict_sever = mnist_noniid(args, dataset_test, args.num_users,
                                      args.num_items_test)

    elif args.dataset == 'cifar':
        dict_users_train, dict_sever = {}, {}
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        dataset_train = datasets.CIFAR10('./dataset/cifar/',
                                         train=True,
                                         transform=transform,
                                         target_transform=None,
                                         download=True)
        dataset_test = copy.deepcopy(dataset_train)
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users,
                                   args.num_items_train)
            num_train = int(0.6 * args.num_items_train)
            for idx in range(args.num_users):
                dict_users_train[idx] = set(list(dict_users[idx])[:num_train])
                dict_sever[idx] = set(list(dict_users[idx])[num_train:])
        else:
            dict_users = cifar_noniid(args, dataset_train, args.num_users,
                                      args.num_items_train)
            dict_test = []
            num_train = int(0.6 * args.num_items_train)
            for idx in range(args.num_users):
                dict_users_train[idx] = set(list(dict_users[idx])[:num_train])
                dict_sever[idx] = set(list(dict_users[idx])[num_train:])

    # sample users
    if args.iid:
        dict_users = mnist_iid(args, dataset_train, args.num_users,
                               args.num_items_train)
        # dict_users_test = mnist_iid(dataset_test, args.num_users, args.num_items_test)
        dict_sever = mnist_iid(args, dataset_test, args.num_users,
                               args.num_items_test)
    else:
        dict_users = mnist_noniid(args, dataset_train, args.num_users,
                                  args.num_items_train)
        dict_sever = mnist_iid(args, dataset_test, args.num_users,
                               args.num_items_test)

    img_size = dataset_train[0][0].shape

    for v in range(len(set_variable)):
        final_train_loss = [[0 for i in range(len(set_variable1))]
                            for j in range(len(set_variable0))]
        final_train_accuracy = [[0 for i in range(len(set_variable1))]
                                for j in range(len(set_variable0))]
        final_test_loss = [[0 for i in range(len(set_variable1))]
                           for j in range(len(set_variable0))]
        final_test_accuracy = [[0 for i in range(len(set_variable1))]
                               for j in range(len(set_variable0))]
        final_com_cons = [[0 for i in range(len(set_variable1))]
                          for j in range(len(set_variable0))]
        args.num_Chosenusers = copy.deepcopy(set_variable[v])
        for s in range(len(set_variable0)):
            for j in range(len(set_variable1)):
                args.epochs = copy.deepcopy(set_variable0[s])
                args.degree_noniid = copy.deepcopy(set_variable1[j])
                print(args)
                loss_test, loss_train = [], []
                acc_test, acc_train = [], []
                for m in range(args.num_experiments):
                    # build model
                    net_glob = None
                    if args.model == 'cnn' and args.dataset == 'mnist':
                        if args.gpu != -1:
                            torch.cuda.set_device(args.gpu)
                            # net_glob = CNNMnist(args=args).cuda()
                            net_glob = CNN_test(args=args).cuda()
                        else:
                            net_glob = CNNMnist(args=args)
                    elif args.model == 'mlp' and args.dataset == 'mnist':
                        len_in = 1
                        for x in img_size:
                            len_in *= x
                        if args.gpu != -1:
                            torch.cuda.set_device(args.gpu)
                            net_glob = MLP1(dim_in=len_in,
                                            dim_hidden=256,
                                            dim_out=args.num_classes).cuda()
                        else:
                            net_glob = MLP1(dim_in=len_in,
                                            dim_hidden=256,
                                            dim_out=args.num_classes)
                    elif args.model == 'cnn' and args.dataset == 'cifar':
                        if args.gpu != -1:
                            net_glob = CNNCifar(args).cuda()
                        else:
                            net_glob = CNNCifar(args)
                    else:
                        exit('Error: unrecognized model')
                    print("Nerual Net:", net_glob)

                    net_glob.train()  # train() only switches to training mode; it does not change the weights
                    # copy weights
                    w_glob = net_glob.state_dict()
                    w_size = 0
                    w_size_all = 0
                    for k in w_glob.keys():
                        size = w_glob[k].size()
                        if (len(size) == 1):
                            nelements = size[0]
                        else:
                            nelements = size[0] * size[1]
                        w_size += nelements * 4
                        w_size_all += nelements
                        # print("Size ", k, ": ",nelements*4)
                    print("Weight Size:", w_size, " bytes")
                    print("Weight & Grad Size:", w_size * 2, " bytes")
                    print("Each user Training size:",
                          784 * 8 / 8 * args.local_bs, " bytes")
                    print("Total Training size:", 784 * 8 / 8 * 60000,
                          " bytes")
                    # training
                    loss_avg_list, acc_avg_list, list_loss, loss_avg, com_cons = [], [], [], [], []
                    ###  FedAvg Algorithm  ###
                    for iter in range(args.epochs):
                        print('\n', '*' * 20, f'Epoch: {iter}', '*' * 20)
                        if args.num_Chosenusers < args.num_users:
                            chosenUsers = random.sample(
                                range(args.num_users), args.num_Chosenusers)
                            chosenUsers.sort()
                        else:
                            chosenUsers = range(args.num_users)
                        print("\nChosen users:", chosenUsers)
                        w_locals, w_locals_1ep, loss_locals, acc_locals = [], [], [], []

                        values_global = []
                        for i in w_glob.keys():
                            values_global += list(
                                w_glob[i].view(-1).cpu().numpy())

                        for idx in range(len(chosenUsers)):
                            local = LocalUpdate(
                                args=args,
                                dataset=dataset_train,
                                idxs=dict_users[chosenUsers[idx]],
                                tb=summary)
                            w_1st_ep, w, loss, acc = local.update_weights(
                                net=copy.deepcopy(net_glob))
                            w_locals.append(copy.deepcopy(w))
                            ### get 1st ep local weights ###
                            w_locals_1ep.append(copy.deepcopy(w_1st_ep))
                            loss_locals.append(copy.deepcopy(loss))
                            # print("User ", chosenUsers[idx], " Acc:", acc, " Loss:", loss)
                            acc_locals.append(copy.deepcopy(acc))

                            # histogram for all clients
                            values_local = []
                            for i in w_glob.keys():
                                values_local += list(
                                    w[i].view(-1).cpu().numpy())
                            values_increment = [
                                values_local[i] - values_global[i]
                                for i in range(len(values_local))
                            ]
                            value_sequence = sorted(
                                [d for d in values_increment],
                                reverse=True)  # value sequence
                            hist, bin_edges = np.histogram(value_sequence,
                                                           bins=100)
                            # valueCount = collections.Counter(hist)
                            # val, cnt = zip(*valueCount.items())
                            #print(hist, bin_edges)
                            # fig, ax = plt.subplots()
                            plt.close()
                            # plt.bar(range(len(hist)), hist, width=0.80, color='b')
                            # plt.close()
                            # plt.hist(value_sequence,bin_edges,color='b',alpha=0.8, rwidth=0.8)
                            plt.hist(value_sequence,
                                     bin_edges,
                                     color='steelblue',
                                     edgecolor='black',
                                     alpha=0.8)
                            plt.savefig(
                                './histogra/histogra-{}-client-{}-iter-{}.pdf'.
                                format(args.model, idx, iter))
                            plt.show()

                        # malicious_users = [0, 3]
                        # w_locals = noise_add(args, w_locals, 0.001, malicious_users)
                        ### update global weights ###
                        # w_locals = users_sampling(args, w_locals, chosenUsers)
                        w_glob = average_weights(w_locals)

                        # val_min_dist, val_mah_dist = [], []
                        # for i in range(len(w_locals)):
                        #     val_min_dist.append(minkowski_distance(w_locals[i],w_glob,1))
                        #     val_mah_dist.append(mahala_distance(w_locals[i],w_glob,w_locals,5))
                        # print('Minkowski distance:', val_mah_dist)
                        # print('Mahala distance:', val_min_dist)

                        # copy weight to net_glob
                        net_glob.load_state_dict(w_glob)
                        # global test
                        list_acc, list_loss = [], []
                        net_glob.eval()
                        for c in range(args.num_users):
                            net_local = LocalUpdate(args=args,
                                                    dataset=dataset_test,
                                                    idxs=dict_sever[c],
                                                    tb=summary)
                            acc, loss = net_local.test(net=net_glob)
                            # acc, loss = net_local.test_gen(net=net_glob, idxs=dict_users[c], dataset=dataset_test)
                            list_acc.append(acc)
                            list_loss.append(loss)
                        # print("\nEpoch: {}, Global test loss {}, Global test acc: {:.2f}%".\
                        #      format(iter, sum(list_loss) / len(list_loss),100. * sum(list_acc) / len(list_acc)))
                        # print loss
                        loss_avg = sum(loss_locals) / len(loss_locals)
                        acc_avg = sum(acc_locals) / len(acc_locals)
                        loss_avg_list.append(loss_avg)
                        acc_avg_list.append(acc_avg)
                        print("\nTrain loss: {}, Train acc: {}".\
                              format(loss_avg_list[-1], acc_avg_list[-1]))
                        print("\nTest loss: {}, Test acc: {}".\
                              format(sum(list_loss) / len(list_loss), sum(list_acc) / len(list_acc)))

                        # if (iter+1)%20==0:
                        #     torch.save(net_glob.state_dict(),'./Train_model/glob_model_{}epochs.pth'.format(iter))

                    loss_train.append(loss_avg)
                    acc_train.append(acc_avg)
                    loss_test.append(sum(list_loss) / len(list_loss))
                    acc_test.append(sum(list_acc) / len(list_acc))
                    com_cons.append(iter + 1)
                # plot loss curve
                final_train_loss[s][j] = copy.deepcopy(
                    sum(loss_train) / len(loss_train))
                final_train_accuracy[s][j] = copy.deepcopy(
                    sum(acc_train) / len(acc_train))
                final_test_loss[s][j] = copy.deepcopy(
                    sum(loss_test) / len(loss_test))
                final_test_accuracy[s][j] = copy.deepcopy(
                    sum(acc_test) / len(acc_test))
                final_com_cons[s][j] = copy.deepcopy(
                    sum(com_cons) / len(com_cons))

            print('\nFinal train loss:', final_train_loss)
            print('\nFinal train accuracy:', final_train_accuracy)
            print('\nFinal test loss:', final_test_loss)
            print('\nFinal test accuracy:', final_test_accuracy)
        timeslot = int(time.time())
        data_test_loss = pd.DataFrame(index=set_variable0,
                                      columns=set_variable1,
                                      data=final_train_loss)
        data_test_loss.to_csv(
            './experiresult/' +
            'train_loss_{}_{}.csv'.format(set_variable[v], timeslot))
        data_test_loss = pd.DataFrame(index=set_variable0,
                                      columns=set_variable1,
                                      data=final_test_loss)
        data_test_loss.to_csv(
            './experiresult/' +
            'test_loss_{}_{}.csv'.format(set_variable[v], timeslot))
        data_test_acc = pd.DataFrame(index=set_variable0,
                                     columns=set_variable1,
                                     data=final_train_accuracy)
        data_test_acc.to_csv(
            './experiresult/' +
            'train_acc_{}_{}.csv'.format(set_variable[v], timeslot))
        data_test_acc = pd.DataFrame(index=set_variable0,
                                     columns=set_variable1,
                                     data=final_test_accuracy)
        data_test_acc.to_csv(
            './experiresult/' +
            'test_acc_{}_{}.csv'.format(set_variable[v], timeslot))
        data_test_acc = pd.DataFrame(index=set_variable0,
                                     columns=set_variable1,
                                     data=final_com_cons)
        data_test_acc.to_csv('./experiresult/' +
                             'aggregation_consuming_{}_{}.csv'.format(
                                 set_variable[v], timeslot))

        plt.close()
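
average_weights is called above but not defined in this snippet. A minimal sketch of unweighted FedAvg aggregation over PyTorch state_dicts, assuming every client shares the same model architecture (an illustrative sketch, not necessarily the repository's implementation):

import copy
import torch

def average_weights(w_locals):
    """Element-wise mean of a list of state_dicts with identical keys and shapes."""
    w_avg = copy.deepcopy(w_locals[0])
    for k in w_avg.keys():
        for w in w_locals[1:]:
            w_avg[k] += w[k]
        w_avg[k] = torch.div(w_avg[k], len(w_locals))
    return w_avg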
Example No. 4
def main(loc_ep, weighted, alg):
    # parse args
    args = args_parser()
    algo_list = ['fedavg', 'fedprox', 'fsvgr']
    # define paths
    path_project = os.path.abspath('..')

    summary = SummaryWriter('local')
    args.gpu = 0  # -1 (CPU only) or GPU = 0
    args.lr = 0.002  # 0.001 for cifar dataset
    args.model = 'mlp'  # 'mlp' or 'cnn'
    args.dataset = 'mnist'  # 'cifar' or 'mnist'
    args.num_users = 5
    args.epochs = 30  # numb of global iters
    args.local_ep = loc_ep  # numb of local iters
    args.local_bs = 1201  # Local Batch size (>=1200 = full dataset size of a user for mnist, 2000 for cifar)
    args.algorithm = alg  # 'fedavg', 'fedprox', 'fsvgr'
    args.iid = False
    args.verbose = False
    print("dataset:", args.dataset, " num_users:", args.num_users, " epochs:", args.epochs, "local_ep:", args.local_ep)

    # load dataset and split users
    dict_users = {}
    dataset_train = []
    if args.dataset == 'mnist':
        dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307,), (0.3081,))
                                       ]))
        # sample users
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        else:
            dict_users = mnist_noniid(dataset_train, args.num_users)
    elif args.dataset == 'cifar':
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('../data/cifar', train=True, transform=transform, target_transform=None,
                                         download=True)
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users)
        else:
            dict_users = cifar_noniid(dataset_train, args.num_users)
            # exit('Error: only consider IID setting in CIFAR10')
    else:
        exit('Error: unrecognized dataset')
    img_size = dataset_train[0][0].shape

    # build model
    net_glob = None
    if args.model == 'cnn' and args.dataset == 'cifar':
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            net_glob = CNNCifar(args=args).cuda()
        else:
            net_glob = CNNCifar(args=args)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            net_glob = CNNMnist(args=args).cuda()
        else:
            net_glob = CNNMnist(args=args)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            # net_glob = MLP1(dim_in=len_in, dim_hidden=128, dim_out=args.num_classes).cuda()
            net_glob = MLP1(dim_in=len_in, dim_hidden=256, dim_out=args.num_classes).cuda()
        else:
            # net_glob = MLP1(dim_in=len_in, dim_hidden=128, dim_out=args.num_classes)
            net_glob = MLP1(dim_in=len_in, dim_hidden=256, dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')
    print("Nerual Net:", net_glob)

    net_glob.train()  # Train() does not change the weight values
    # copy weights
    w_glob = net_glob.state_dict()

    # w_size = 0
    # for k in w_glob.keys():
    #     size = w_glob[k].size()
    #     if (len(size) == 1):
    #         nelements = size[0]
    #     else:
    #         nelements = size[0] * size[1]
    #     w_size += nelements * 4
    #     # print("Size ", k, ": ",nelements*4)
    # print("Weight Size:", w_size, " bytes")
    # print("Weight & Grad Size:", w_size * 2, " bytes")
    # print("Each user Training size:", 784 * 8 / 8 * args.local_bs, " bytes")
    # print("Total Training size:", 784 * 8 / 8 * 60000, " bytes")
    # # training
    global_grad = []
    user_grads = []
    loss_test = []
    acc_test = []
    cv_loss, cv_acc = [], []
    val_loss_pre, counter = 0, 0
    net_best = None
    val_acc_list, net_list = [], []
    # print(dict_users.keys())

    rs_avg_acc, rs_avg_loss, rs_glob_acc, rs_glob_loss= [], [], [], []

    ###  FedAvg Algorithm  ###
    if args.algorithm == 'fedavg':
        # for iter in tqdm(range(args.epochs)):
        for iter in range(args.epochs):
            w_locals, loss_locals, acc_locals, num_samples_list = [], [], [], []
            for idx in range(args.num_users):
                if(args.local_bs>1200):
                    local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary, bs=200*(4+idx)) #Batch_size bs = full data
                else:
                    local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary, bs=args.local_bs)
                num_samples, w, loss, acc = local.update_weights(net=copy.deepcopy(net_glob))
                num_samples_list.append(num_samples)
                w_locals.append(copy.deepcopy(w))
                loss_locals.append(copy.deepcopy(loss))
                # print("User ", idx, " Acc:", acc, " Loss:", loss)
                acc_locals.append(copy.deepcopy(acc))
            # update global weights
            if(weighted):
                w_glob = weighted_average_weights(w_locals, num_samples_list)
            else:
                w_glob = average_weights(w_locals)


            # copy weight to net_glob
            net_glob.load_state_dict(w_glob)
            # global test
            list_acc, list_loss = [], []
            net_glob.eval()
            # for c in tqdm(range(args.num_users)):
            for c in range(args.num_users):
                if (args.local_bs > 1200):
                    net_local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[c], tb=summary, bs=200*(4+c)) #Batch_size bs = full data
                else:
                    net_local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[c], tb=summary, bs=args.local_bs)
                acc, loss = net_local.test(net=net_glob)
                list_acc.append(acc)
                list_loss.append(loss)
            print("\nEpoch: {}, Global test loss {}, Global test acc: {:.2f}%".format(iter,
                                                                                      sum(list_loss) / len(list_loss),
                                                                                      100. * sum(list_acc) / len(
                                                                                          list_acc)))

            # print loss
            loss_avg = sum(loss_locals) / len(loss_locals)
            acc_avg = sum(acc_locals) / len(acc_locals)
            if iter % 1 == 0:
                print('\nUsers Train Average loss:', loss_avg)
                print('\nUsers Train Average accuracy:', acc_avg)
            # loss_test.append(sum(list_loss) / len(list_loss))
            # acc_test.append(sum(list_acc) / len(list_acc))

            rs_avg_acc.append(acc_avg)
            rs_avg_loss.append(loss_avg)
            rs_glob_acc.append(sum(list_acc) / len(list_acc))
            rs_glob_loss.append(sum(list_loss) / len(list_loss))
            # if (acc_avg >= 0.89):
            #     return iter+1

    ###  FedProx Algorithm  ###
    elif args.algorithm == 'fedprox':
        args.mu = 0.005  ### change mu 0.001
        args.limit = 0.3
        # for iter in tqdm(range(args.epochs)):
        for iter in range(args.epochs):
            w_locals, loss_locals, acc_locals, num_samples_list = [], [], [], []
            # m = max(int(args.frac * args.num_users), 1)
            # idxs_users = np.random.choice(range(args.num_users), m, replace=False)
            for idx in range(args.num_users):
                if(args.local_bs>1200):
                    local = LocalFedProxUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary, bs=200*(4+idx)) #Batch_size bs = full data
                else:
                    local = LocalFedProxUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary, bs=args.local_bs)

                num_samples, w, loss, acc = local.update_FedProx_weights(net=copy.deepcopy(net_glob))
                num_samples_list.append(num_samples)
                w_locals.append(copy.deepcopy(w))
                loss_locals.append(copy.deepcopy(loss))
                # print("User ", idx, " Acc:", acc, " Loss:", loss)
                acc_locals.append(copy.deepcopy(acc))
            # update global weights
            if(weighted):
                w_glob = weighted_average_weights(w_locals, num_samples_list)
            else:
                w_glob = average_weights(w_locals)

            # copy weight to net_glob
            net_glob.load_state_dict(w_glob)
            # global test
            list_acc, list_loss = [], []
            net_glob.eval()
            # for c in tqdm(range(args.num_users)):
            for c in range(args.num_users):
                if (args.local_bs > 1200):
                    net_local = LocalFedProxUpdate(args=args, dataset=dataset_train, idxs=dict_users[c], tb=summary,
                                            bs=200 * (4 + c))  # Batch_size bs = full data
                else:
                    net_local = LocalFedProxUpdate(args=args, dataset=dataset_train, idxs=dict_users[c], tb=summary,
                                            bs=args.local_bs)

                acc, loss = net_local.test(net=net_glob)
                list_acc.append(acc)
                list_loss.append(loss)
            print("\nEpoch: {}, Global test loss {}, Global test acc: {:.2f}%".format(iter,
                                                                                      sum(list_loss) / len(list_loss),
                                                                                      100. * sum(list_acc) / len(
                                                                                          list_acc)))

            # print loss
            loss_avg = sum(loss_locals) / len(loss_locals)
            acc_avg = sum(acc_locals) / len(acc_locals)
            if iter % 1 == 0:
                print('\nUsers train average loss:', loss_avg)
                print('\nUsers train average accuracy', acc_avg)
            # loss_test.append(sum(list_loss) / len(list_loss))
            # acc_test.append(sum(list_acc) / len(list_acc))

            rs_avg_acc.append(acc_avg)
            rs_avg_loss.append(loss_avg)
            rs_glob_acc.append(sum(list_acc) / len(list_acc))
            rs_glob_loss.append(sum(list_loss) / len(list_loss))
            # if (acc_avg >= 0.89):
            #     return iter+1

    ###  FSVGR Algorithm  ###
    elif args.algorithm == 'fsvgr':
        args.ag_scalar = 1.  # 0.001 or 0.1
        args.lg_scalar = 1
        args.threshold = 0.001
        # for iter in tqdm(range(args.epochs)):
        for iter in range(args.epochs):

            print("=========Global epoch {}=========".format(iter))
            w_locals, loss_locals, acc_locals = [], [], []
            """
            First communication round: the server sends w_t to the clients --> each client computes its gradient and sends
            it to the server --> the server computes the average global gradient and sends it back to the clients
            """
            for idx in range(args.num_users):
                local = LocalFSVGRUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary)
                num_sample, grad_k = local.calculate_global_grad(net=copy.deepcopy(net_glob))
                user_grads.append([num_sample, grad_k])
            global_grad = calculate_avg_grad(user_grads)

            """
            Second communication round: each client updates w_k_t+1 and sends it to the server --> the server updates the global w_t+1
            """
            for idx in range(args.num_users):
                print("Training user {}".format(idx))
                local = LocalFSVGRUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary)
                num_samples, w_k, loss, acc = local.update_FSVGR_weights(global_grad, idx, copy.deepcopy(net_glob),
                                                                         iter)
                w_locals.append(copy.deepcopy([num_samples, w_k]))
                print("Global_Epoch ", iter, "User ", idx, " Acc:", acc, " Loss:", loss)
                loss_locals.append(copy.deepcopy(loss))
                acc_locals.append(copy.deepcopy(acc))

            # w_t = net_glob.state_dict()
            w_glob = average_FSVRG_weights(w_locals, args.ag_scalar, copy.deepcopy(net_glob), args.gpu)

            # copy weight to net_glob
            net_glob.load_state_dict(w_glob)

            # global test
            list_acc, list_loss = [], []
            net_glob.eval()
            # for c in tqdm(range(args.num_users)):
            for c in range(args.num_users):
                net_local = LocalFSVGRUpdate(args=args, dataset=dataset_train, idxs=dict_users[c], tb=summary)
                acc, loss = net_local.test(net=net_glob)
                list_acc.append(acc)
                list_loss.append(loss)

            print("\nTest Global Weight:", list_acc)
            print("\nEpoch: {}, Global test loss {}, Global test acc: {:.2f}%".format(iter,
                                                                                      sum(list_loss) / len(list_loss),
                                                                                      100. * sum(list_acc) / len(
                                                                                          list_acc)))

            # print loss
            loss_avg = sum(loss_locals) / len(loss_locals)
            acc_avg = sum(acc_locals) / len(acc_locals)
            if iter % 1 == 0:
                print('\nEpoch: {}, Users train average loss: {}'.format(iter, loss_avg))
                print('\nEpoch: {}, Users train average accuracy: {}'.format(iter, acc_avg))
            loss_test.append(sum(list_loss) / len(list_loss))
            acc_test.append(sum(list_acc) / len(list_acc))

            if (acc_avg >= 0.89):
                return
    if(weighted):
        alg=alg+'1'
    simple_save_data(loc_ep, alg, rs_avg_acc, rs_avg_loss, rs_glob_acc, rs_glob_loss)
    plot_rs(loc_ep, alg)
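
weighted_average_weights is toggled by the weighted flag but is not defined in this snippet. A plausible sketch of sample-count-weighted FedAvg, under the assumption that num_samples_list[i] is the number of training samples held by client i (assumed signature and behaviour, not the repository's code):

import copy

def weighted_average_weights(w_locals, num_samples_list):
    """Average state_dicts with weights proportional to each client's sample count."""
    total = float(sum(num_samples_list))
    w_avg = copy.deepcopy(w_locals[0])
    for k in w_avg.keys():
        # start from client 0's contribution, then add the remaining clients
        w_avg[k] = w_locals[0][k] * (num_samples_list[0] / total)
        for i in range(1, len(w_locals)):
            w_avg[k] += w_locals[i][k] * (num_samples_list[i] / total)
    return w_avg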
Example No. 5
def get_dataset(args):
    """ Returns train and test datasets and a user group which is a dict where
    the keys are the user index and the values are the corresponding data for
    each of those users.
    """

    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])

        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                         transform=apply_transform)

        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                        transform=apply_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from CIFAR
            user_groups = cifar_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from CIFAR
            if args.unequal:
                # Choose unequal splits for every user
                raise NotImplementedError()
            else:
                # Choose equal splits for every user
                user_groups = cifar_noniid(train_dataset, args.num_users)

    elif args.dataset in ('mnist', 'fmnist'):
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
        else:
            data_dir = '../data/fmnist/'

        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        train_dataset = datasets.MNIST(data_dir, train=True, download=True,
                                       transform=apply_transform)

        test_dataset = datasets.MNIST(data_dir, train=False, download=True,
                                      transform=apply_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from Mnist
            user_groups = mnist_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from Mnist
            if args.unequal:
                # Choose unequal splits for every user
                user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
            else:
                # Choose equal splits for every user
                user_groups = mnist_noniid(train_dataset, args.num_users)

    return train_dataset, test_dataset, user_groups
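
mnist_iid and the other partition helpers are referenced here but not shown. A minimal sketch of an IID split that assigns each user an equal-sized, randomly drawn set of sample indices (assumed behaviour; the actual helpers may differ):

import numpy as np

def mnist_iid(dataset, num_users):
    """Assign each user an equal, randomly chosen set of sample indices."""
    num_items = len(dataset) // num_users
    all_idxs = list(range(len(dataset)))
    dict_users = {}
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        # remove the assigned indices so users do not overlap
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users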
Example No. 6
def getdataset(args, conf_dict):
    N_parties = conf_dict["N_parties"]
    N_samples_per_class = conf_dict["N_samples_per_class"]

    public_classes = conf_dict["public_classes"]
    private_classes = conf_dict["private_classes"]

    # Load the datasets
    if args.dataset == 'cifar':
        X_train_public, y_train_public, X_test_public, y_test_public = get_dataarray(
            args, dataset='cifar10')
        public_dataset = {"X": X_train_public, "y": y_train_public}

        X_train_private, y_train_private, X_test_private, y_test_private\
            = get_dataarray(args,dataset='cifar100')

    elif args.dataset == 'mnist':
        X_train_public, y_train_public, X_test_public, y_test_public = get_dataarray(
            args, dataset='mnist')
        public_dataset = {"X": X_train_public, "y": y_train_public}

        X_train_private, y_train_private, X_test_private, y_test_private \
            = get_dataarray(args,dataset='fmnist')

        y_train_private += len(
            public_classes)  # shift every label value by the constant len(public_classes)
        y_test_private += len(public_classes)

    # only keep samples whose labels belong to private_classes
    # sampling
    if args.iid:
        X_train_private, y_train_private \
            = generate_partial_data(X=X_train_private, y=y_train_private,       # keep only the data belonging to private_classes
                                    class_in_use=private_classes,
                                    verbose=True)

        X_test_private, y_test_private \
            = generate_partial_data(X=X_test_private, y=y_test_private,
                                    class_in_use=private_classes,
                                    verbose=True)

        # relabel the selected private data for future convenience
        for index, cls_ in enumerate(private_classes):
            y_train_private[y_train_private == cls_] = index + len(
                public_classes)  # boolean-mask indexing
            y_test_private[y_test_private ==
                           cls_] = index + len(public_classes)
        del index, cls_

        print(pd.Series(y_train_private).value_counts())
        mod_private_classes = np.arange(len(private_classes)) + len(
            public_classes
        )  # mod_private_classes = [0 .. len(private_classes)-1], each element shifted by the constant len(public_classes)

        #---------------------------------------10~12----------13~15-------
        mod_private_classes1 = np.arange(
            len(private_classes) // 2) + len(public_classes)
        mod_private_classes2 = np.arange(
            len(private_classes) //
            2) + len(public_classes) + len(private_classes) // 2
        print("mod_private_class1:", mod_private_classes1)
        print("mod_private_class2:", mod_private_classes2)
        #---------------------------------------

        # print("=" * 60)
        # # generate private data
        # private_data, total_private_data \
        #     = generate_bal_private_data(X_train_private, y_train_private,
        #                                 N_parties=N_parties,
        #                                 classes_in_use=mod_private_classes,
        #                                 N_samples_per_class=N_samples_per_class,
        #                                 data_overlap=False)

        print("=" * 60)
        # generate private data
        private_data1, total_private_data1 \
            = generate_bal_private_data(X_train_private, y_train_private,
                                        N_parties=N_parties//2,
                                        classes_in_use=mod_private_classes1,
                                        N_samples_per_class=N_samples_per_class,
                                        data_overlap=False)

        print("=" * 60)
        # generate private data
        private_data2, total_private_data2 \
            = generate_bal_private_data(X_train_private, y_train_private,
                                        N_parties=N_parties//2,
                                        classes_in_use=mod_private_classes2,
                                        N_samples_per_class=N_samples_per_class,
                                        data_overlap=False)

        #---------- merge the two lists ----------
        private_data = private_data1 + private_data2
        total_private_data = dict(total_private_data1, **total_private_data2)
        #total_private_data = dict(total_private_data2, **total_private_data1)

        # print("=" * 60)
        # X_tmp, y_tmp = generate_partial_data(X=X_test_private, y=y_test_private,    # mainly mod_private_classes: relabel the selected private data
        #                                      class_in_use=mod_private_classes,
        #                                      verbose=True)
        # private_test_data = {"X": X_tmp, "y": y_tmp}

        print("=" * 60)
        X_tmp1, y_tmp1 = generate_partial_data(
            X=X_test_private,
            y=y_test_private,  # mainly mod_private_classes: relabel the selected private data
            class_in_use=mod_private_classes1,
            verbose=True)
        private_test_data1 = {"X1": X_tmp1, "y1": y_tmp1}

        X_tmp2, y_tmp2 = generate_partial_data(
            X=X_test_private,
            y=y_test_private,  # mainly mod_private_classes: relabel the selected private data
            class_in_use=mod_private_classes2,
            verbose=True)
        private_test_data2 = {"X2": X_tmp2, "y2": y_tmp2}

        private_test_data = dict(private_test_data1, **private_test_data2)
        #private_test_data = dict(private_test_data2, **private_test_data1)

    else:
        X_train_private, y_train_private \
            = generate_partial_data(X=X_train_private, y=y_train_private,
                                    class_in_use=private_classes,
                                    verbose=True)

        X_test_private, y_test_private\
            = generate_partial_data(X=X_test_private, y=y_test_private,
                                    class_in_use=private_classes,
                                    verbose=True)

        # relabel the selected private data for future convenience
        for index, cls_ in enumerate(private_classes):
            y_train_private[y_train_private ==
                            cls_] = index + len(public_classes)
            y_test_private[y_test_private ==
                           cls_] = index + len(public_classes)
        del index, cls_

        # print(pd.Series(y_train_private).value_counts())
        mod_private_classes = np.arange(
            len(private_classes)) + len(public_classes)

        print("=" * 60)
        # generate private data
        if args.dataset == 'cifar':
            users_index = cifar_noniid(y_train_private, N_parties)

            private_data, total_private_data \
                = get_sample_data(X_train_private, y_train_private,users_index,N_samples_per_class*18)
        else:
            users_index = mnist_noniid(y_train_private, N_parties)

            private_data, total_private_data \
                = get_sample_data(X_train_private, y_train_private,users_index,N_samples_per_class*6)

        print("=" * 60)
        X_tmp, y_tmp = generate_partial_data(X=X_test_private,
                                             y=y_test_private,
                                             class_in_use=mod_private_classes,
                                             verbose=True)
        private_test_data = {"X": X_tmp, "y": y_tmp}


    return [X_train_public, y_train_public,X_test_public, y_test_public],\
           [public_dataset,private_data, total_private_data,private_test_data]
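
generate_partial_data is called throughout these examples but not shown. A minimal sketch that keeps only the samples whose labels fall in class_in_use, assuming X and y are NumPy arrays (assumed behaviour, not the repository's implementation):

import numpy as np

def generate_partial_data(X, y, class_in_use, verbose=False):
    """Return only the rows of X/y whose label is in class_in_use."""
    mask = np.isin(y, class_in_use)
    X_partial, y_partial = X[mask], y[mask]
    if verbose:
        print("data shape:", X_partial.shape, "labels shape:", y_partial.shape)
    return X_partial, y_partial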
Example No. 7
def getdataset(args, conf_dict):
    N_parties = conf_dict["N_parties"]
    N_samples_per_class = conf_dict["N_samples_per_class"]

    public_classes = conf_dict["public_classes"]
    private_classes = conf_dict["private_classes"]

    # Load the datasets
    if args.dataset == 'cifar':
        X_train_public, y_train_public, X_test_public, y_test_public = get_dataarray(
            args, dataset='cifar10')
        public_dataset = {"X": X_train_public, "y": y_train_public}

        X_train_private, y_train_private, X_test_private, y_test_private\
            = get_dataarray(args,dataset='cifar100')

    elif args.dataset == 'mnist':
        X_train_public, y_train_public, X_test_public, y_test_public = get_dataarray(
            args, dataset='mnist')
        public_dataset = {"X": X_train_public, "y": y_train_public}

        X_train_private, y_train_private, X_test_private, y_test_private \
            = get_dataarray(args,dataset='fmnist')

        y_train_private += len(public_classes)
        y_test_private += len(public_classes)

    # only keep samples whose labels belong to private_classes
    # sampling
    if args.iid:
        X_train_private, y_train_private \
            = generate_partial_data(X=X_train_private, y=y_train_private,
                                    class_in_use=private_classes,
                                    verbose=True)

        X_test_private, y_test_private \
            = generate_partial_data(X=X_test_private, y=y_test_private,
                                    class_in_use=private_classes,
                                    verbose=True)

        # relabel the selected private data for future convenience
        for index, cls_ in enumerate(private_classes):
            y_train_private[y_train_private ==
                            cls_] = index + len(public_classes)
            y_test_private[y_test_private ==
                           cls_] = index + len(public_classes)
        del index, cls_

        print(pd.Series(y_train_private).value_counts())
        mod_private_classes = np.arange(
            len(private_classes)) + len(public_classes)

        print("=" * 60)
        # generate private data
        private_data, total_private_data \
            = generate_bal_private_data(X_train_private, y_train_private,
                                        N_parties=N_parties,
                                        classes_in_use=mod_private_classes,
                                        N_samples_per_class=N_samples_per_class,
                                        data_overlap=False)
        print("=" * 60)
        X_tmp, y_tmp = generate_partial_data(X=X_test_private,
                                             y=y_test_private,
                                             class_in_use=mod_private_classes,
                                             verbose=True)
        private_test_data = {"X": X_tmp, "y": y_tmp}

    else:
        X_train_private, y_train_private \
            = generate_partial_data(X=X_train_private, y=y_train_private,
                                    class_in_use=private_classes,
                                    verbose=True)

        X_test_private, y_test_private\
            = generate_partial_data(X=X_test_private, y=y_test_private,
                                    class_in_use=private_classes,
                                    verbose=True)

        # relabel the selected private data for future convenience
        for index, cls_ in enumerate(private_classes):
            y_train_private[y_train_private ==
                            cls_] = index + len(public_classes)
            y_test_private[y_test_private ==
                           cls_] = index + len(public_classes)
        del index, cls_

        # print(pd.Series(y_train_private).value_counts())
        mod_private_classes = np.arange(
            len(private_classes)) + len(public_classes)

        print("=" * 60)
        # generate private data
        if args.dataset == 'cifar':
            users_index = cifar_noniid(y_train_private, N_parties)

            private_data, total_private_data \
                = get_sample_data(X_train_private, y_train_private,users_index,N_samples_per_class*18)
        else:
            users_index = mnist_noniid(y_train_private, N_parties)

            private_data, total_private_data \
                = get_sample_data(X_train_private, y_train_private,users_index,N_samples_per_class*6)

        print("=" * 60)
        X_tmp, y_tmp = generate_partial_data(X=X_test_private,
                                             y=y_test_private,
                                             class_in_use=mod_private_classes,
                                             verbose=True)
        private_test_data = {"X": X_tmp, "y": y_tmp}


    return [X_train_public, y_train_public, X_test_public, y_test_public],\
           [public_dataset, private_data, total_private_data, private_test_data]
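# `generate_bal_private_data` is likewise not shown. Judging from its arguments, it
# is assumed to give each of the N_parties a class-balanced private shard containing
# N_samples_per_class examples of every class in classes_in_use, optionally without
# overlap between parties. A rough, assumption-driven sketch:
import numpy as np

def generate_bal_private_data(X, y, N_parties, classes_in_use,
                              N_samples_per_class, data_overlap=False):
    """Assumed behavior: balanced per-party private shards plus their union."""
    rng = np.random.default_rng(0)
    per_party_idx = [[] for _ in range(N_parties)]
    combined_idx = []
    for cls in classes_in_use:
        cls_idx = np.where(y == cls)[0]
        picked = rng.choice(cls_idx, N_parties * N_samples_per_class,
                            replace=data_overlap)
        combined_idx.extend(picked.tolist())
        for party, shard in enumerate(np.split(picked, N_parties)):
            per_party_idx[party].extend(shard.tolist())
    private_data = [{"X": X[idx], "y": y[idx], "idx": np.array(idx)}
                    for idx in per_party_idx]
    combined_idx = np.array(combined_idx)
    total_private_data = {"X": X[combined_idx], "y": y[combined_idx],
                          "idx": combined_idx}
    return private_data, total_private_data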
Exemplo n.º 8
0
def main(args):
    #####-Choose Variable-#####
    set_variable = args.set_num_Chosenusers
    set_variable0 = copy.deepcopy(args.set_epochs)
    set_variable1 = copy.deepcopy(args.set_privacy_budget)

    if not os.path.exists('./exper_result'):
        os.mkdir('./exper_result')

    # load dataset and split users
    dataset_train, dataset_test = [], []
    dataset_train = datasets.MNIST('./dataset/mnist/',
                                   train=True,
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.1307, ),
                                                            (0.3081, ))
                                   ]))
    dataset_test = datasets.MNIST('./dataset/mnist/',
                                  train=False,
                                  download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.1307, ),
                                                           (0.3081, ))
                                  ]))

    # sample users
    if args.iid:
        dict_users = mnist_iid(dataset_train, args.num_users,
                               args.num_items_train)
        # dict_users_test = mnist_iid(dataset_test, args.num_users, args.num_items_test)
        dict_server = mnist_iid(dataset_test, args.num_users,
                                args.num_items_test)
    else:
        dict_users = mnist_noniid(dataset_train, args.num_users)
        dict_server = mnist_noniid(dataset_test, args.num_users)

    img_size = dataset_train[0][0].shape

    for v in range(len(set_variable)):
        final_train_loss = [[0 for i in range(len(set_variable1))]
                            for j in range(len(set_variable0))]
        final_train_accuracy = [[0 for i in range(len(set_variable1))]
                                for j in range(len(set_variable0))]
        final_test_loss = [[0 for i in range(len(set_variable1))]
                           for j in range(len(set_variable0))]
        final_test_accuracy = [[0 for i in range(len(set_variable1))]
                               for j in range(len(set_variable0))]
        final_com_cons = [[0 for i in range(len(set_variable1))]
                          for j in range(len(set_variable0))]
        args.num_Chosenusers = copy.deepcopy(set_variable[v])
        for s in range(len(set_variable0)):
            for j in range(len(set_variable1)):
                args.epochs = copy.deepcopy(set_variable0[s])
                args.privacy_budget = copy.deepcopy(set_variable1[j])
                print("dataset:", args.dataset, " num_users:", args.num_users, " num_chosen_users:", args.num_Chosenusers, " Privacy budget:", args.privacy_budget,\
                      " epochs:", args.epochs, "local_ep:", args.local_ep, "local train size", args.num_items_train, "batch size:", args.local_bs)
                loss_test, loss_train = [], []
                acc_test, acc_train = [], []
                for m in range(args.num_experiments):
                    # build model
                    net_glob = None
                    if args.model == 'cnn' and args.dataset == 'mnist':
                        if args.gpu != -1:
                            torch.cuda.set_device(args.gpu)
                            # net_glob = CNNMnist(args=args).cuda()
                            net_glob = CNN_test(args=args).cuda()
                        else:
                            net_glob = CNNMnist(args=args)
                    elif args.model == 'mlp':
                        len_in = 1
                        for x in img_size:
                            len_in *= x
                        if args.gpu != -1:
                            torch.cuda.set_device(args.gpu)
                            net_glob = MLP1(dim_in=len_in, dim_hidden=256,\
                                            dim_out=args.num_classes).cuda()
                        else:
                            net_glob = MLP1(dim_in=len_in, dim_hidden=256,\
                                            dim_out=args.num_classes)
                    else:
                        exit('Error: unrecognized model')
                    print("Nerual Net:", net_glob)

                    net_glob.train()  # train() does not change the weight values
                    # copy weights
                    w_glob = net_glob.state_dict()
                    w_size = 0
                    w_size_all = 0
                    for k in w_glob.keys():
                        size = w_glob[k].size()
                        if (len(size) == 1):
                            nelements = size[0]
                        else:
                            nelements = size[0] * size[1]
                        w_size += nelements * 4
                        w_size_all += nelements
                        # print("Size ", k, ": ",nelements*4)
                    print("Weight Size:", w_size, " bytes")
                    print("Weight & Grad Size:", w_size * 2, " bytes")
                    print("Each user Training size:",
                          784 * 8 / 8 * args.local_bs, " bytes")
                    print("Total Training size:", 784 * 8 / 8 * 60000,
                          " bytes")
                    # training
                    threshold_epochs = copy.deepcopy(args.epochs)
                    threshold_epochs_list, noise_list = [], []
                    loss_avg_list, acc_avg_list, list_loss, loss_avg = [], [], [], []
                    eps_tot_list, eps_tot = [], 0
                    com_cons = []
                    ###  FedAvg Algorithm  ###
                    ### Compute noise scale ###
                    noise_scale = copy.deepcopy(Privacy_account(args,\
                                            threshold_epochs, noise_list, 0))
                    for iter in range(args.epochs):
                        print('\n', '*' * 20, f'Epoch: {iter}', '*' * 20)
                        start_time = time.time()
                        if args.num_Chosenusers < args.num_users:
                            chosenUsers = random.sample(range(1,args.num_users)\
                                                        ,args.num_Chosenusers)
                            chosenUsers.sort()
                        else:
                            chosenUsers = range(args.num_users)
                        print("\nChosen users:", chosenUsers)

                        if iter >= 1 and args.para_est == True:
                            w_locals_before = copy.deepcopy(w_locals_org)

                        w_locals, w_locals_1ep, loss_locals, acc_locals = [], [], [], []
                        for idx in range(len(chosenUsers)):
                            local = LocalUpdate(args=args, dataset=dataset_train,\
                                    idxs=dict_users[chosenUsers[idx]], tb=summary)
                            w_1st_ep, w, loss, acc = local.update_weights(\
                                                    net=copy.deepcopy(net_glob))
                            w_locals.append(copy.deepcopy(w))
                            ### get 1st ep local weights ###
                            w_locals_1ep.append(copy.deepcopy(w_1st_ep))
                            loss_locals.append(copy.deepcopy(loss))
                            # print("User ", chosenUsers[idx], " Acc:", acc, " Loss:", loss)
                            acc_locals.append(copy.deepcopy(acc))

                        w_locals_org = copy.deepcopy(w_locals)

                        # estimate some parameters of the loss function
                        if iter >= 2 and args.para_est == True:
                            Lipz_s,Lipz_c,delta,_,_,_,_,_=para_estimate(args,\
                                        list_loss,loss_locals,w_locals_before,\
                                            w_locals_org,w_glob_before,w_glob)
                            print('Lipschitz smooth, lipschitz continuous, gradient divergence:',\
                                  sum(Lipz_s)/len(Lipz_s),sum(Lipz_c)/len(Lipz_c),sum(delta)/len(delta))

                        ### Clipping ###
                        for idx in range(len(chosenUsers)):
                            w_locals[idx] = copy.deepcopy(
                                clipping(args, w_locals[idx]))
                            # print(get_2_norm(w_locals[idx], w_glob))

                        ### perturb 'w_local' ###
                        w_locals = noise_add(args, noise_scale, w_locals)

                        ### update global weights ###
                        ### w_locals = users_sampling(args, w_locals, chosenUsers) ###
                        w_glob_before = copy.deepcopy(w_glob)
                        w_glob = average_weights(w_locals)

                        # copy weight to net_glob
                        net_glob.load_state_dict(w_glob)
                        # global test
                        list_acc, list_loss = [], []
                        net_glob.eval()
                        for c in range(args.num_users):
                            net_local = LocalUpdate(args=args,dataset=dataset_test,\
                                                    idxs=dict_server[c], tb=summary)
                            acc, loss = net_local.test(net=net_glob)
                            # acc, loss = net_local.test_gen(net=net_glob,\
                            # idxs=dict_users[c], dataset=dataset_test)
                            list_acc.append(acc)
                            list_loss.append(loss)
                        # print("\nEpoch:{},Global test loss:{}, Global test acc:{:.2f}%".\
                        #      format(iter, sum(list_loss) / len(list_loss),\
                        #      100. * sum(list_acc) / len(list_acc)))

                        # print loss
                        loss_avg = sum(loss_locals) / len(loss_locals)
                        acc_avg = sum(acc_locals) / len(acc_locals)
                        loss_avg_list.append(loss_avg)
                        acc_avg_list.append(acc_avg)
                        print("\nTrain loss: {}, Train acc: {}".\
                              format(loss_avg_list[-1], acc_avg_list[-1]))
                        print("\nTest loss: {}, Test acc: {}".\
                              format(sum(list_loss) / len(list_loss),\
                                     sum(list_acc) / len(list_acc)))

                        noise_list.append(noise_scale)
                        threshold_epochs_list.append(threshold_epochs)
                        print('\nNoise Scale:', noise_list)
                        print('\nThreshold epochs:', threshold_epochs_list)
                        ### optimal method ###
                        if args.dp_mechanism == 'CRD' and iter >= 1:
                            threshold_epochs = Adjust_T(args, loss_avg_list,\
                                                threshold_epochs_list, iter)
                            noise_scale = copy.deepcopy(Privacy_account(args,\
                                        threshold_epochs, noise_list, iter))

                        # print run time of each experiment
                        end_time = time.time()
                        print('Run time: %f second' % (end_time - start_time))

                        if iter >= threshold_epochs:
                            break
                    loss_train.append(loss_avg)
                    acc_train.append(acc_avg)
                    loss_test.append(sum(list_loss) / len(list_loss))
                    acc_test.append(sum(list_acc) / len(list_acc))
                    com_cons.append(iter + 1)

                # record results
                final_train_loss[s][j] = copy.deepcopy(
                    sum(loss_train) / len(loss_train))
                final_train_accuracy[s][j] = copy.deepcopy(
                    sum(acc_train) / len(acc_train))
                final_test_loss[s][j] = copy.deepcopy(
                    sum(loss_test) / len(loss_test))
                final_test_accuracy[s][j] = copy.deepcopy(
                    sum(acc_test) / len(acc_test))
                final_com_cons[s][j] = copy.deepcopy(
                    sum(com_cons) / len(com_cons))

            print('\nFinal train loss:', final_train_loss)
            print('\nFinal train acc:', final_train_accuracy)
            print('\nFinal test loss:', final_test_loss)
            print('\nFinal test acc:', final_test_accuracy)

        timeslot = int(time.time())
        data_test_loss = pd.DataFrame(index = set_variable0, columns =\
                                      set_variable1, data = final_train_loss)
        data_test_loss.to_csv('./exper_result/'+'train_loss_{}_{}_{}.csv'.\
                              format(set_variable[v],args.dp_mechanism,timeslot))
        data_test_loss = pd.DataFrame(index = set_variable0, columns =\
                                      set_variable1, data = final_test_loss)
        data_test_loss.to_csv('./exper_result/'+'test_loss_{}_{}_{}.csv'.\
                              format(set_variable[v],args.dp_mechanism,timeslot))
        data_test_acc = pd.DataFrame(index = set_variable0, columns =\
                                     set_variable1, data = final_train_accuracy)
        data_test_acc.to_csv('./exper_result/'+'train_acc_{}_{}_{}.csv'.\
                             format(set_variable[v],args.dp_mechanism,timeslot))
        data_test_acc = pd.DataFrame(index = set_variable0, columns =\
                                     set_variable1, data = final_test_accuracy)
        data_test_acc.to_csv('./exper_result/'+'test_acc_{}_{}_{}.csv'.\
                             format(set_variable[v],args.dp_mechanism,timeslot))
        data_test_acc = pd.DataFrame(index = set_variable0, columns =\
                                     set_variable1, data = final_com_cons)
        data_test_acc.to_csv('./exper_result/'+'aggregation_consuming_{}_{}_{}.csv'.\
                             format(set_variable[v],args.dp_mechanism,timeslot))
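# Helpers such as `average_weights` and `noise_add` are imported elsewhere in this
# project. A minimal sketch of the behavior the loop above relies on -- FedAvg-style
# parameter averaging and Gaussian perturbation of the clipped local weights -- is
# given below as an assumption, not the project's own implementation.
import copy
import torch

def average_weights(w_locals):
    """FedAvg aggregation: element-wise mean of the clients' state_dicts."""
    w_avg = copy.deepcopy(w_locals[0])
    for key in w_avg.keys():
        for w in w_locals[1:]:
            w_avg[key] = w_avg[key] + w[key]
        w_avg[key] = torch.div(w_avg[key], len(w_locals))
    return w_avg

def noise_add(args, noise_scale, w_locals):
    """Assumed behavior: add zero-mean Gaussian noise of scale `noise_scale`."""
    w_noised = copy.deepcopy(w_locals)
    for w in w_noised:
        for key in w.keys():
            w[key] = w[key].float() + torch.randn_like(w[key].float()) * noise_scale
    return w_noised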
Exemplo n.º 9
0
    # load dataset and split users
    if args.dataset == 'mnist':
        dataset_train = datasets.MNIST('../data/mnist/',
                                       train=True,
                                       download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307, ),
                                                                (0.3081, ))
                                       ]))
        # sample users
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        else:
            dict_users = mnist_noniid(dataset_train, args.num_users)
    elif args.dataset == 'cifar':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        dataset_train = datasets.CIFAR10('../data/cifar',
                                         train=True,
                                         transform=transform,
                                         target_transform=None,
                                         download=True)
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users)
        else:
            exit('Error: only consider IID setting in CIFAR10')
    else:
Exemplo n.º 10
0
def get_dataset(args):
    """ Returns train and test datasets and a user group which is a dict where
    the keys are the user index and the values are the corresponding data for
    each of those users.
    """

    if args.dataset == 'cifar10':
        data_dir = '../data/cifar10/'
        apply_transform_train = transforms.Compose([
            transforms.RandomCrop(24),
            transforms.RandomHorizontalFlip(0.5),
            transforms.ColorJitter(brightness=0.2,
                                   contrast=0.2,
                                   saturation=0.2),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2470, 0.2435, 0.2616))
        ])

        apply_transform_test = transforms.Compose([
            transforms.CenterCrop(24),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2470, 0.2435, 0.2616))
        ])

        train_dataset = datasets.CIFAR10(data_dir,
                                         train=True,
                                         download=True,
                                         transform=apply_transform_train)

        test_dataset = datasets.CIFAR10(data_dir,
                                        train=False,
                                        download=True,
                                        transform=apply_transform_test)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from CIFAR-10
            user_groups = cifar_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from CIFAR-10
            if args.hard:
                # Hard non-IID splits are not implemented for CIFAR-10
                raise NotImplementedError()
            else:
                # Choose equal splits for every user
                user_groups = cifar_noniid(train_dataset, args.num_users)

    elif args.dataset == 'mnist' or args.dataset == 'fmnist':
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
        else:
            data_dir = '../data/fmnist/'
        # TODO: accommodate the FMNIST case (mean, std). These are the MNIST stats;
        # Fashion-MNIST may need a different set of params. Should we take the
        # params from opt instead of hard-coding them?
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        train_dataset = datasets.MNIST(data_dir,
                                       train=True,
                                       download=True,
                                       transform=apply_transform)

        test_dataset = datasets.MNIST(data_dir,
                                      train=False,
                                      download=True,
                                      transform=apply_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from MNIST
            user_groups = mnist_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from MNIST
            if args.unequal:
                # Choose unequal splits for every user
                user_groups = mnist_noniid_unequal(train_dataset,
                                                   args.num_users)
            else:
                # Choose equal splits for every user
                user_groups = mnist_noniid(train_dataset, args.num_users)

    elif args.dataset == 'cub200':
        data_dir = '../data/cub200/'
        apply_transform_train = transforms.Compose([
            transforms.Resize(int(cf.imresize[args.net_type])),
            transforms.RandomRotation(10),
            transforms.RandomCrop(cf.imsize[args.net_type]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(cf.mean[args.dataset], cf.std[args.dataset]),
        ])

        apply_transform_test = transforms.Compose([
            transforms.Resize(cf.imresize[args.net_type]),
            transforms.CenterCrop(cf.imsize[args.net_type]),
            transforms.ToTensor(),
            transforms.Normalize(cf.mean[args.dataset], cf.std[args.dataset]),
        ])
        train_dataset = cub.CUB200(data_dir,
                                   year=2011,
                                   train=True,
                                   download=True,
                                   transform=apply_transform_train)

        test_dataset = cub.CUB200(data_dir,
                                  year=2011,
                                  train=False,
                                  download=True,
                                  transform=apply_transform_test)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from CUB-200
            user_groups = cub_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from CUB-200
            if args.hard:
                # Choose hard non-IID splits for every user
                user_groups = cub_noniid_hard(train_dataset, args.num_users)
            else:
                # Choose equal splits for every user
                user_groups = cub_noniid(train_dataset, args.num_users)

    return train_dataset, test_dataset, user_groups
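# A hypothetical call illustrating the returned `user_groups` structure described in
# the docstring (a dict mapping each user index to that user's sample indices). The
# SimpleNamespace stands in for the real argparse options and is an assumption.
from types import SimpleNamespace

args = SimpleNamespace(dataset='mnist', iid=True, unequal=False,
                       hard=False, num_users=10)
train_dataset, test_dataset, user_groups = get_dataset(args)
print(len(user_groups))        # 10 entries, one per user
print(len(user_groups[0]))     # the number of samples assigned to user 0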
Exemplo n.º 11
0
def get_dataset(args):
    """ Returns train and test datasets and a user group which is a dict where
    the keys are the user index and the values are the corresponding data for
    each of those users.
    """

    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        apply_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ])
        test_apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ])
        train_dataset = datasets.CIFAR10(data_dir,
                                         train=True,
                                         download=True,
                                         transform=apply_transform)

        test_dataset = datasets.CIFAR10(data_dir,
                                        train=False,
                                        download=True,
                                        transform=test_apply_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from CIFAR-10
            user_groups = cifar_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from CIFAR-10
            if args.unequal:
                # Unequal splits are not implemented for CIFAR-10
                raise NotImplementedError()
            else:
                # Choose equal splits for every user
                user_groups = cifar_noniid(train_dataset, args.num_users)
    elif args.dataset == 'mnist' or args.dataset == 'fmnist':
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
        else:
            data_dir = '../data/fmnist/'

        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        train_dataset = datasets.MNIST(data_dir,
                                       train=True,
                                       download=True,
                                       transform=apply_transform)

        test_dataset = datasets.MNIST(data_dir,
                                      train=False,
                                      download=True,
                                      transform=apply_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from MNIST
            user_groups = mnist_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from MNIST
            if args.unequal:
                # Choose unequal splits for every user
                user_groups = mnist_noniid_unequal(train_dataset,
                                                   args.num_users)
            else:
                # Choose equal splits for every user
                user_groups = mnist_noniid(train_dataset, args.num_users)
    elif args.dataset == 'brats2018':
        from data.brats2018.dataset import BRATS2018Dataset, InstitutionWiseBRATS2018Dataset
        # from torch.utils.data import random_split
        from sampling import brats2018_iid, brats2018_unbalanced
        data_dir = args.data_dir
        test_dataset = None
        if args.balanced:
            train_dataset = BRATS2018Dataset(training_dir=data_dir,
                                             img_dim=128)
            user_groups = brats2018_iid(dataset=train_dataset,
                                        num_users=args.num_users)
        else:
            # BRATS2018 data comes from 19 institutions; this unbalanced split is the default
            train_dataset = InstitutionWiseBRATS2018Dataset(
                training_dir=data_dir,
                img_dim=128,
                config_json='../data/brats2018/hgg_config.json')
            user_groups = brats2018_unbalanced(dataset=train_dataset,
                                               num_users=args.num_users)
    elif args.dataset == 'brats2018_data_aug':
        from data.brats2018.dataset import InstitutionWiseBRATS2018DatasetDataAugmentation
        # from torch.utils.data import random_split
        from sampling import brats2018_iid, brats2018_unbalanced
        data_dir = args.data_dir
        test_dataset = None
        if args.balanced:
            raise NotImplementedError
        else:
            # BRATS2018 data comes from 19 institutions; this unbalanced split is the default
            train_dataset = InstitutionWiseBRATS2018DatasetDataAugmentation(
                training_dir=data_dir,
                img_dim=128,
                config_json='../data/brats2018/hgg_config.json')
            user_groups = brats2018_unbalanced(dataset=train_dataset,
                                               num_users=args.num_users)
    return train_dataset, test_dataset, user_groups
Exemplo n.º 12
0
def get_dataset(args,
                tokenizer=None,
                max_seq_len=MAX_SEQUENCE_LENGTH,
                custom_sampling=None):
    """ Returns train and test datasets and a user group which is a dict where
    the keys are the user index and the values are the corresponding data for
    each of those users.
    """

    if args.task == 'nlp':
        assert args.dataset == "ade", "Parsed dataset not implemented."
        [complete_dataset] = ds.load_dataset("ade_corpus_v2",
                                             "Ade_corpus_v2_classification",
                                             split=["train"])
        # Rename column.
        complete_dataset = complete_dataset.rename_column("label", "labels")
        complete_dataset = complete_dataset.shuffle(seed=args.seed)
        # Split into train and test sets.
        split_examples = complete_dataset.train_test_split(
            test_size=args.test_frac)
        train_examples = split_examples["train"]
        test_examples = split_examples["test"]

        # Tokenize training set.
        train_dataset = train_examples.map(
            lambda examples: tokenizer(
                examples["text"],
                truncation=True,
                max_length=max_seq_len,
                padding="max_length",
            ),
            batched=True,
            remove_columns=["text"],
        )
        train_dataset.set_format(type="torch")

        # Tokenize test set.
        test_dataset = test_examples.map(
            lambda examples: tokenizer(
                examples["text"],
                truncation=True,
                max_length=max_seq_len,
                padding="max_length",
            ),
            batched=True,
            remove_columns=["text"],
        )
        test_dataset.set_format(type="torch")

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from Ade_corpus
            user_groups = ade_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from Ade_corpus
            if args.unequal:
                # Chose unequal splits for every user
                raise NotImplementedError()
            else:
                # Chose equal splits for every user
                user_groups = ade_noniid(train_dataset, args.num_users)

    elif args.task == 'cv':
        if args.dataset == 'cifar':
            data_dir = './data/cifar/'
            apply_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])

            train_dataset = datasets.CIFAR10(data_dir,
                                             train=True,
                                             download=True,
                                             transform=apply_transform)

            test_dataset = datasets.CIFAR10(data_dir,
                                            train=False,
                                            download=True,
                                            transform=apply_transform)

            # sample training data amongst users
            if custom_sampling is not None:
                user_groups = custom_sampling(dataset=train_dataset,
                                              num_users=args.num_users)
                assert len(
                    user_groups
                ) == args.num_users, "Incorrect number of users generated."
                check_client_sampled_data = []
                for client_idx, client_samples in user_groups.items():
                    assert len(client_samples) == len(
                        train_dataset
                    ) / args.num_users, "Incorrectly sampled client shard."
                    for record in client_samples:
                        check_client_sampled_data.append(record)
                assert len(set(check_client_sampled_data)) == len(
                    train_dataset), "Client shards are not i.i.d"
                print("Congratulations! You've got it :)")
            else:
                # sample training data amongst users
                if args.iid:
                    # Sample IID user data from CIFAR-10
                    user_groups = cifar_iid(train_dataset, args.num_users)
                else:
                    # Sample Non-IID user data from CIFAR-10
                    if args.unequal:
                        # Unequal splits are not implemented for CIFAR-10
                        raise NotImplementedError()
                    else:
                        # Choose equal splits for every user
                        user_groups = cifar_noniid(train_dataset,
                                                   args.num_users)

        elif args.dataset == 'mnist' or args.dataset == 'fmnist':
            if args.dataset == 'mnist':
                data_dir = './data/mnist/'
            else:
                data_dir = './data/fmnist/'

            apply_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])

            train_dataset = datasets.MNIST(data_dir,
                                           train=True,
                                           download=True,
                                           transform=apply_transform)

            test_dataset = datasets.MNIST(data_dir,
                                          train=False,
                                          download=True,
                                          transform=apply_transform)

            if args.iid:
                # Sample IID user data from MNIST
                user_groups = mnist_iid(train_dataset, args.num_users)
            else:
                # Sample Non-IID user data from MNIST
                if args.unequal:
                    # Choose unequal splits for every user
                    user_groups = mnist_noniid_unequal(train_dataset,
                                                       args.num_users)
                else:
                    # Choose equal splits for every user
                    user_groups = mnist_noniid(train_dataset, args.num_users)
        else:
            raise NotImplementedError(f"""Unrecognized dataset {args.dataset}.
                Options are: `cifar`, `mnist`, `fmnist`.
                """)
    else:
        raise NotImplementedError(f"""Unrecognized task {args.task}.
            Options are: `nlp` and `cv`.
            """)

    return train_dataset, test_dataset, user_groups
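# The `custom_sampling` hook above is checked for equal, non-overlapping client
# shards that together cover the training set. A sampler that would pass those
# assertions (assuming num_users divides the dataset size) might look like this;
# it is an illustration, not the project's own sampler.
import numpy as np

def custom_iid_sampling(dataset, num_users):
    """Shuffle all indices and hand each user an equally sized, disjoint shard."""
    num_items = len(dataset) // num_users
    all_idxs = np.random.permutation(len(dataset))
    return {user: set(all_idxs[user * num_items:(user + 1) * num_items].tolist())
            for user in range(num_users)}

# e.g. get_dataset(args, custom_sampling=custom_iid_sampling)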
Exemplo n.º 13
0
def get_dataset(args):
    """ Returns train and test datasets and a user group which is a dict where
    the keys are the user index and the values are the corresponding data for
    each of those users.
    """

    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'

        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)

        transforms_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
        transforms_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        train_dataset = datasets.CIFAR10(data_dir,
                                         train=True,
                                         download=True,
                                         transform=transforms_train)
        test_dataset = datasets.CIFAR10(data_dir,
                                        train=False,
                                        download=True,
                                        transform=transforms_test)
        """
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                       transform=apply_transform)

        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                      transform=apply_transform)
        """
        # sample training data amongst users
        if args.iid == 1:
            # Sample IID user data from CIFAR-10
            # user_groups = cifar_iid(train_dataset, args.num_users)
            user_groups = cifar10_iid(train_dataset, args.num_users, args=args)
        else:
            # Sample Non-IID user data from CIFAR-10
            if args.unequal:
                # Unequal splits are not implemented for CIFAR-10
                raise NotImplementedError()
            else:
                # Choose equal splits for every user
                if args.iid == 2:
                    # user_groups = partition_data(train_dataset, 'noniid-#label2', num_uers=args.num_users, alpha=1, args=args)
                    user_groups = cifar_noniid(train_dataset,
                                               num_users=args.num_users,
                                               args=args)
                else:
                    user_groups = partition_data(train_dataset,
                                                 'dirichlet',
                                                 num_uers=args.num_users,
                                                 alpha=1,
                                                 args=args)
        # build client train DataLoaders from the partitioned indices and the train dataset
        client_loader_dict = client_loader(train_dataset, user_groups, args)

    elif args.dataset == 'cifar100':
        data_dir = '../data/fed_cifar100'
        train_dataset, test_dataset, client_loader_dict = load_partition_data_federated_cifar100(
            data_dir=data_dir, batch_size=args.local_bs)

    elif args.dataset == 'mnist' or args.dataset == 'fmnist':
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
        else:
            data_dir = '../data/fmnist/'

        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        train_dataset = datasets.MNIST(data_dir,
                                       train=True,
                                       download=True,
                                       transform=apply_transform)

        test_dataset = datasets.MNIST(data_dir,
                                      train=False,
                                      download=True,
                                      transform=apply_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from MNIST
            user_groups = mnist_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from MNIST
            if args.unequal:
                # Choose unequal splits for every user
                user_groups = mnist_noniid_unequal(train_dataset,
                                                   args.num_users)
            else:
                # Choose equal splits for every user
                user_groups = mnist_noniid(train_dataset, args.num_users)

        # build client train DataLoaders from the partitioned indices and the train dataset
        client_loader_dict = client_loader(train_dataset, user_groups, args)

    return train_dataset, test_dataset, client_loader_dict
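# `client_loader` is not shown here. Assuming `user_groups` maps each client id to
# its sample indices and `args.local_bs` is the local batch size, a plausible
# minimal version (an assumption, not the original helper) would be:
from torch.utils.data import DataLoader, Subset

def client_loader(train_dataset, user_groups, args):
    """Assumed behavior: one shuffled DataLoader per client over its own indices."""
    return {
        client_id: DataLoader(Subset(train_dataset, [int(i) for i in idxs]),
                              batch_size=args.local_bs, shuffle=True)
        for client_id, idxs in user_groups.items()
    }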
Exemplo n.º 14
0
def get_dataset(args, n_list, k_list):
    """ Returns train and test datasets and a user group which is a dict where
    the keys are the user index and the values are the corresponding data for
    each of those users.
    """
    data_dir = args.data_dir + args.dataset
    if args.dataset == 'mnist':
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        train_dataset = datasets.MNIST(data_dir,
                                       train=True,
                                       download=True,
                                       transform=apply_transform)

        test_dataset = datasets.MNIST(data_dir,
                                      train=False,
                                      download=True,
                                      transform=apply_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from MNIST
            user_groups = mnist_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from MNIST
            if args.unequal:
                # Choose unequal splits for every user
                user_groups = mnist_noniid_unequal(args, train_dataset,
                                                   args.num_users)
            else:
                # Choose equal splits for every user
                user_groups, classes_list = mnist_noniid(
                    args, train_dataset, args.num_users, n_list, k_list)
                user_groups_lt = mnist_noniid_lt(args, test_dataset,
                                                 args.num_users, n_list,
                                                 k_list, classes_list)
                classes_list_gt = classes_list

    elif args.dataset == 'femnist':
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        train_dataset = femnist.FEMNIST(args,
                                        data_dir,
                                        train=True,
                                        download=True,
                                        transform=apply_transform)
        test_dataset = femnist.FEMNIST(args,
                                       data_dir,
                                       train=False,
                                       download=True,
                                       transform=apply_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from FEMNIST
            user_groups = femnist_iid(train_dataset, args.num_users)
            # print("TBD")
        else:
            # Sample Non-IID user data from FEMNIST
            if args.unequal:
                # Choose unequal splits for every user
                # user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
                user_groups = femnist_noniid_unequal(args, train_dataset,
                                                     args.num_users)
                # print("TBD")
            else:
                # Choose equal splits for every user
                user_groups, classes_list, classes_list_gt = femnist_noniid(
                    args, args.num_users, n_list, k_list)
                user_groups_lt = femnist_noniid_lt(args, args.num_users,
                                                   classes_list)

    elif args.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10(data_dir,
                                         train=True,
                                         download=True,
                                         transform=trans_cifar10_train)
        test_dataset = datasets.CIFAR10(data_dir,
                                        train=False,
                                        download=True,
                                        transform=trans_cifar10_val)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from CIFAR-10
            user_groups = cifar_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from CIFAR-10
            if args.unequal:
                # Unequal splits are not implemented for CIFAR-10
                raise NotImplementedError()
            else:
                # Choose equal splits for every user
                user_groups, classes_list, classes_list_gt = cifar10_noniid(
                    args, train_dataset, args.num_users, n_list, k_list)
                user_groups_lt = cifar10_noniid_lt(args, test_dataset,
                                                   args.num_users, n_list,
                                                   k_list, classes_list)

    elif args.dataset == 'cifar100':
        train_dataset = datasets.CIFAR100(data_dir,
                                          train=True,
                                          download=True,
                                          transform=trans_cifar100_train)
        test_dataset = datasets.CIFAR100(data_dir,
                                         train=False,
                                         download=True,
                                         transform=trans_cifar100_val)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from CIFAR-100
            user_groups = cifar_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from CIFAR-100
            if args.unequal:
                # Unequal splits are not implemented for CIFAR-100
                raise NotImplementedError()
            else:
                # Choose equal splits for every user
                user_groups, classes_list = cifar100_noniid(
                    args, train_dataset, args.num_users, n_list, k_list)
                user_groups_lt = cifar100_noniid_lt(test_dataset,
                                                    args.num_users,
                                                    classes_list)

    return train_dataset, test_dataset, user_groups, user_groups_lt, classes_list, classes_list_gt
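# A hypothetical call for this variant, which also returns per-user test groups and
# the class lists. The SimpleNamespace and the meaning of n_list / k_list (classes
# per user and samples per class) are assumptions for illustration only.
from types import SimpleNamespace

args = SimpleNamespace(dataset='cifar10', data_dir='../data/', iid=False,
                       unequal=False, num_users=20)
n_list = [5] * args.num_users      # assumed: number of classes per user
k_list = [500] * args.num_users    # assumed: samples per class per user
(train_dataset, test_dataset, user_groups,
 user_groups_lt, classes_list, classes_list_gt) = get_dataset(args, n_list, k_list)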