Code example #1
        print('image size:', img_size)
        dataset_test = datasets.CIFAR10('../data/cifar',
                                        train=False,
                                        transform=transform,
                                        target_transform=None,
                                        download=True)

        test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
    else:
        exit('Error: unrecognized dataset')

    # build model
    if args.model == 'cnn' and args.dataset == 'cifar':
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            net_glob = CNNCifar(args=args).cuda()
        else:
            net_glob = CNNCifar(args=args)

    elif args.model == 'cnn' and args.dataset == 'mnist':
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            net_glob = CNNMnist(args=args).cuda()
        else:
            net_glob = CNNMnist(args=args)

    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        if args.gpu != -1:
Code example #2
File: fedavg.py  Project: liaojie-box/Files
def main(args):
    #####-Choose Variable-#####
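    # Hyperparameter sweep: these lists are iterated over below -- the number of chosen
    # clients per round, the number of global epochs, and the degree of non-IID data.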
    set_variable = args.set_num_Chosenusers
    set_variable0 = copy.deepcopy(args.set_epochs)
    set_variable1 = copy.deepcopy(args.set_degree_noniid)

    os.makedirs('./experiresult', exist_ok=True)
    os.makedirs('./histogra', exist_ok=True)  # output directory for the per-client weight histograms saved below

    # load dataset and split users
    dict_users, dict_users_train, dict_users_test = {}, {}, {}
    dataset_train, dataset_test = [], []
    if args.dataset == 'mnist':
        dataset_train = datasets.MNIST('./dataset/mnist/',
                                       train=True,
                                       download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307, ),
                                                                (0.3081, ))
                                       ]))
        dataset_test = datasets.MNIST('./dataset/mnist/',
                                      train=False,
                                      download=True,
                                      transform=transforms.Compose([
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.1307, ),
                                                               (0.3081, ))
                                      ]))
        # sample users
        if args.iid:
            dict_users = mnist_iid(args, dataset_train, args.num_users,
                                   args.num_items_train)
            # dict_users_test = mnist_iid(dataset_test, args.num_users, args.num_items_test)
            dict_sever = mnist_iid(args, dataset_test, args.num_users,
                                   args.num_items_test)
        else:
            dict_users = mnist_noniid(args, dataset_train, args.num_users,
                                      args.num_items_train)
            dict_sever = mnist_noniid(args, dataset_test, args.num_users,
                                      args.num_items_test)

    elif args.dataset == 'cifar':
        dict_users_train, dict_sever = {}, {}
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        dataset_train = datasets.CIFAR10('./dataset/cifar/',
                                         train=True,
                                         transform=transform,
                                         target_transform=None,
                                         download=True)
        dataset_test = copy.deepcopy(dataset_train)
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users,
                                   args.num_items_train)
            num_train = int(0.6 * args.num_items_train)
            for idx in range(args.num_users):
                dict_users_train[idx] = set(list(dict_users[idx])[:num_train])
                dict_sever[idx] = set(list(dict_users[idx])[num_train:])
        else:
            dict_users = cifar_noniid(args, dataset_train, args.num_users,
                                      args.num_items_train)
            dict_test = []
            num_train = int(0.6 * args.num_items_train)
            for idx in range(args.num_users):
                dict_users_train[idx] = set(list(dict_users[idx])[:num_train])
                dict_sever[idx] = set(list(dict_users[idx])[num_train:])

    # re-sample users (MNIST only; the CIFAR train/test split was already built above)
    if args.dataset == 'mnist':
        if args.iid:
            dict_users = mnist_iid(args, dataset_train, args.num_users,
                                   args.num_items_train)
            # dict_users_test = mnist_iid(dataset_test, args.num_users, args.num_items_test)
            dict_sever = mnist_iid(args, dataset_test, args.num_users,
                                   args.num_items_test)
        else:
            dict_users = mnist_noniid(args, dataset_train, args.num_users,
                                      args.num_items_train)
            dict_sever = mnist_iid(args, dataset_test, args.num_users,
                                   args.num_items_test)

    img_size = dataset_train[0][0].shape
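    # per-sample tensor shape, e.g. torch.Size([1, 28, 28]) for MNIST or torch.Size([3, 32, 32]) for CIFAR-10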

    for v in range(len(set_variable)):
        final_train_loss = [[0 for i in range(len(set_variable1))]
                            for j in range(len(set_variable0))]
        final_train_accuracy = [[0 for i in range(len(set_variable1))]
                                for j in range(len(set_variable0))]
        final_test_loss = [[0 for i in range(len(set_variable1))]
                           for j in range(len(set_variable0))]
        final_test_accuracy = [[0 for i in range(len(set_variable1))]
                               for j in range(len(set_variable0))]
        final_com_cons = [[0 for i in range(len(set_variable1))]
                          for j in range(len(set_variable0))]
        args.num_Chosenusers = copy.deepcopy(set_variable[v])
        for s in range(len(set_variable0)):
            for j in range(len(set_variable1)):
                args.epochs = copy.deepcopy(set_variable0[s])
                args.degree_noniid = copy.deepcopy(set_variable1[j])
                print(args)
                loss_test, loss_train = [], []
                acc_test, acc_train = [], []
                for m in range(args.num_experiments):
                    # build model
                    net_glob = None
                    if args.model == 'cnn' and args.dataset == 'mnist':
                        if args.gpu != -1:
                            torch.cuda.set_device(args.gpu)
                            # net_glob = CNNMnist(args=args).cuda()
                            net_glob = CNN_test(args=args).cuda()
                        else:
                            net_glob = CNNMnist(args=args)
                    elif args.model == 'mlp' and args.dataset == 'mnist':
                        len_in = 1
                        for x in img_size:
                            len_in *= x
                        if args.gpu != -1:
                            torch.cuda.set_device(args.gpu)
                            net_glob = MLP1(dim_in=len_in,
                                            dim_hidden=256,
                                            dim_out=args.num_classes).cuda()
                        else:
                            net_glob = MLP1(dim_in=len_in,
                                            dim_hidden=256,
                                            dim_out=args.num_classes)
                    elif args.model == 'cnn' and args.dataset == 'cifar':
                        if args.gpu != -1:
                            net_glob = CNNCifar(args).cuda()
                        else:
                            net_glob = CNNCifar(args)
                    else:
                        exit('Error: unrecognized model')
                    print("Nerual Net:", net_glob)

                    net_glob.train()  # train() only switches the module to training mode; it does not change the weights
                    # copy weights
                    w_glob = net_glob.state_dict()
                    w_size = 0
                    w_size_all = 0
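                    # rough model-size estimate: total parameter count times 4 bytes per float32 value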
                    for k in w_glob.keys():
                        nelements = w_glob[k].numel()
                        w_size += nelements * 4
                        w_size_all += nelements
                        # print("Size ", k, ": ",nelements*4)
                    print("Weight Size:", w_size, " bytes")
                    print("Weight & Grad Size:", w_size * 2, " bytes")
                    print("Each user Training size:",
                          784 * 8 / 8 * args.local_bs, " bytes")
                    print("Total Training size:", 784 * 8 / 8 * 60000,
                          " bytes")
                    # training
                    loss_avg_list, acc_avg_list, list_loss, loss_avg, com_cons = [], [], [], [], []
                    ###  FedAvg Algorithm  ###
                    for iter in range(args.epochs):
                        print('\n', '*' * 20, f'Epoch: {iter}', '*' * 20)
                        if args.num_Chosenusers < args.num_users:
                            chosenUsers = random.sample(
                                range(args.num_users), args.num_Chosenusers)
                            chosenUsers.sort()
                        else:
                            chosenUsers = range(args.num_users)
                        print("\nChosen users:", chosenUsers)
                        w_locals, w_locals_1ep, loss_locals, acc_locals = [], [], [], []

                        # flatten the current global weights into one vector; each client's
                        # weight delta is histogrammed against it below
                        values_global = []
                        for i in w_glob.keys():
                            values_global += list(
                                w_glob[i].view(-1).cpu().numpy())

                        for idx in range(len(chosenUsers)):
                            local = LocalUpdate(
                                args=args,
                                dataset=dataset_train,
                                idxs=dict_users[chosenUsers[idx]],
                                tb=summary)
                            w_1st_ep, w, loss, acc = local.update_weights(
                                net=copy.deepcopy(net_glob))
                            w_locals.append(copy.deepcopy(w))
                            ### get 1st ep local weights ###
                            w_locals_1ep.append(copy.deepcopy(w_1st_ep))
                            loss_locals.append(copy.deepcopy(loss))
                            # print("User ", chosenUsers[idx], " Acc:", acc, " Loss:", loss)
                            acc_locals.append(copy.deepcopy(acc))

                            # histogram for all clients
                            values_local = []
                            for i in w_glob.keys():
                                values_local += list(
                                    w[i].view(-1).cpu().numpy())
                            values_increment = [
                                values_local[i] - values_global[i]
                                for i in range(len(values_local))
                            ]
                            value_sequence = sorted(
                                [d for d in values_increment],
                                reverse=True)  # value sequence
                            hist, bin_edges = np.histogram(value_sequence,
                                                           bins=100)
                            # valueCount = collections.Counter(hist)
                            # val, cnt = zip(*valueCount.items())
                            #print(hist, bin_edges)
                            # fig, ax = plt.subplots()
                            plt.close()
                            # plt.bar(range(len(hist)), hist, width=0.80, color='b')
                            # plt.close()
                            # plt.hist(value_sequence,bin_edges,color='b',alpha=0.8, rwidth=0.8)
                            plt.hist(value_sequence,
                                     bin_edges,
                                     color='steelblue',
                                     edgecolor='black',
                                     alpha=0.8)
                            plt.savefig(
                                './histogra/histogra-{}-client-{}-iter-{}.pdf'.
                                format(args.model, idx, iter))
                            plt.show()

                        # malicious_users = [0, 3]
                        # w_locals = noise_add(args, w_locals, 0.001, malicious_users)
                        ### update global weights ###
                        # w_locals = users_sampling(args, w_locals, chosenUsers)
                        w_glob = average_weights(w_locals)

                        # val_min_dist, val_mah_dist = [], []
                        # for i in range(len(w_locals)):
                        #     val_min_dist.append(minkowski_distance(w_locals[i],w_glob,1))
                        #     val_mah_dist.append(mahala_distance(w_locals[i],w_glob,w_locals,5))
                        # print('Minkowski distance:', val_mah_dist)
                        # print('Mahala distance:', val_min_dist)

                        # copy weight to net_glob
                        net_glob.load_state_dict(w_glob)
                        # global test
                        list_acc, list_loss = [], []
                        net_glob.eval()
                        for c in range(args.num_users):
                            net_local = LocalUpdate(args=args,
                                                    dataset=dataset_test,
                                                    idxs=dict_sever[c],
                                                    tb=summary)
                            acc, loss = net_local.test(net=net_glob)
                            # acc, loss = net_local.test_gen(net=net_glob, idxs=dict_users[c], dataset=dataset_test)
                            list_acc.append(acc)
                            list_loss.append(loss)
                        # print("\nEpoch: {}, Global test loss {}, Global test acc: {:.2f}%".\
                        #      format(iter, sum(list_loss) / len(list_loss),100. * sum(list_acc) / len(list_acc)))
                        # print loss
                        loss_avg = sum(loss_locals) / len(loss_locals)
                        acc_avg = sum(acc_locals) / len(acc_locals)
                        loss_avg_list.append(loss_avg)
                        acc_avg_list.append(acc_avg)
                        print("\nTrain loss: {}, Train acc: {}".\
                              format(loss_avg_list[-1], acc_avg_list[-1]))
                        print("\nTest loss: {}, Test acc: {}".\
                              format(sum(list_loss) / len(list_loss), sum(list_acc) / len(list_acc)))

                        # if (iter+1)%20==0:
                        #     torch.save(net_glob.state_dict(),'./Train_model/glob_model_{}epochs.pth'.format(iter))

                    loss_train.append(loss_avg)
                    acc_train.append(acc_avg)
                    loss_test.append(sum(list_loss) / len(list_loss))
                    acc_test.append(sum(list_acc) / len(list_acc))
                    com_cons.append(iter + 1)
                # plot loss curve
                final_train_loss[s][j] = copy.deepcopy(
                    sum(loss_train) / len(loss_train))
                final_train_accuracy[s][j] = copy.deepcopy(
                    sum(acc_train) / len(acc_train))
                final_test_loss[s][j] = copy.deepcopy(
                    sum(loss_test) / len(loss_test))
                final_test_accuracy[s][j] = copy.deepcopy(
                    sum(acc_test) / len(acc_test))
                final_com_cons[s][j] = copy.deepcopy(
                    sum(com_cons) / len(com_cons))

            print('\nFinal train loss:', final_train_loss)
            print('\nFinal train accuracy:', final_train_accuracy)
            print('\nFinal test loss:', final_test_loss)
            print('\nFinal test accuracy:', final_test_accuracy)
        timeslot = int(time.time())
        data_train_loss = pd.DataFrame(index=set_variable0,
                                       columns=set_variable1,
                                       data=final_train_loss)
        data_train_loss.to_csv(
            './experiresult/' +
            'train_loss_{}_{}.csv'.format(set_variable[v], timeslot))
        data_test_loss = pd.DataFrame(index=set_variable0,
                                      columns=set_variable1,
                                      data=final_test_loss)
        data_test_loss.to_csv(
            './experiresult/' +
            'test_loss_{}_{}.csv'.format(set_variable[v], timeslot))
        data_train_acc = pd.DataFrame(index=set_variable0,
                                      columns=set_variable1,
                                      data=final_train_accuracy)
        data_train_acc.to_csv(
            './experiresult/' +
            'train_acc_{}_{}.csv'.format(set_variable[v], timeslot))
        data_test_acc = pd.DataFrame(index=set_variable0,
                                     columns=set_variable1,
                                     data=final_test_accuracy)
        data_test_acc.to_csv(
            './experiresult/' +
            'test_acc_{}_{}.csv'.format(set_variable[v], timeslot))
        data_com_cons = pd.DataFrame(index=set_variable0,
                                     columns=set_variable1,
                                     data=final_com_cons)
        data_com_cons.to_csv('./experiresult/' +
                             'aggregation_consuming_{}_{}.csv'.format(
                                 set_variable[v], timeslot))

        plt.close()
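Note: helpers such as average_weights, mnist_iid/mnist_noniid, cifar_iid/cifar_noniid, LocalUpdate and the summary writer are imported from the project's own modules and are not shown in this listing. As a rough guide, an unweighted FedAvg aggregation consistent with how average_weights is called here (each entry of w_locals being a state_dict of identically shaped tensors) could be sketched as follows; the project's actual implementation may differ:

import copy
import torch

def average_weights(w_locals):
    # element-wise mean of the clients' state_dicts (unweighted FedAvg)
    w_avg = copy.deepcopy(w_locals[0])
    for k in w_avg.keys():
        for w in w_locals[1:]:
            w_avg[k] += w[k]
        w_avg[k] = torch.div(w_avg[k], len(w_locals))
    return w_avg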
Code example #3
File: main_fed.py  Project: zhaoyang626/OnDevAI
def main(loc_ep, weighted, alg):
    # parse args
    args = args_parser()
    algo_list = ['fedavg', 'fedprox', 'fsvgr']
    # define paths
    path_project = os.path.abspath('..')

    summary = SummaryWriter('local')
    args.gpu = 0  # -1 (CPU only) or GPU = 0
    args.lr = 0.002  # 0.001 for cifar dataset
    args.model = 'mlp'  # 'mlp' or 'cnn'
    args.dataset = 'mnist'  # 'cifar' or 'mnist'
    args.num_users = 5
    args.epochs = 30  # numb of global iters
    args.local_ep = loc_ep  # numb of local iters
    args.local_bs = 1201  # Local Batch size (>=1200 = full dataset size of a user for mnist, 2000 for cifar)
    args.algorithm = alg  # 'fedavg', 'fedprox', 'fsvgr'
    args.iid = False
    args.verbose = False
    print("dataset:", args.dataset, " num_users:", args.num_users, " epochs:", args.epochs, "local_ep:", args.local_ep)

    # load dataset and split users
    dict_users = {}
    dataset_train = []
    if args.dataset == 'mnist':
        dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307,), (0.3081,))
                                       ]))
        # sample users
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        else:
            dict_users = mnist_noniid(dataset_train, args.num_users)
    elif args.dataset == 'cifar':
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('../data/cifar', train=True, transform=transform, target_transform=None,
                                         download=True)
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users)
        else:
            dict_users = cifar_noniid(dataset_train, args.num_users)
            # exit('Error: only consider IID setting in CIFAR10')
    else:
        exit('Error: unrecognized dataset')
    img_size = dataset_train[0][0].shape

    # build model
    net_glob = None
    if args.model == 'cnn' and args.dataset == 'cifar':
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            net_glob = CNNCifar(args=args).cuda()
        else:
            net_glob = CNNCifar(args=args)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            net_glob = CNNMnist(args=args).cuda()
        else:
            net_glob = CNNMnist(args=args)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            # net_glob = MLP1(dim_in=len_in, dim_hidden=128, dim_out=args.num_classes).cuda()
            net_glob = MLP1(dim_in=len_in, dim_hidden=256, dim_out=args.num_classes).cuda()
        else:
            # net_glob = MLP1(dim_in=len_in, dim_hidden=128, dim_out=args.num_classes)
            net_glob = MLP1(dim_in=len_in, dim_hidden=256, dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')
    print("Nerual Net:", net_glob)

    net_glob.train()  # train() only switches the module to training mode; it does not change the weights
    # copy weights
    w_glob = net_glob.state_dict()

    # w_size = 0
    # for k in w_glob.keys():
    #     size = w_glob[k].size()
    #     if (len(size) == 1):
    #         nelements = size[0]
    #     else:
    #         nelements = size[0] * size[1]
    #     w_size += nelements * 4
    #     # print("Size ", k, ": ",nelements*4)
    # print("Weight Size:", w_size, " bytes")
    # print("Weight & Grad Size:", w_size * 2, " bytes")
    # print("Each user Training size:", 784 * 8 / 8 * args.local_bs, " bytes")
    # print("Total Training size:", 784 * 8 / 8 * 60000, " bytes")
    # # training
    global_grad = []
    user_grads = []
    loss_test = []
    acc_test = []
    cv_loss, cv_acc = [], []
    val_loss_pre, counter = 0, 0
    net_best = None
    val_acc_list, net_list = [], []
    # print(dict_users.keys())

    rs_avg_acc, rs_avg_loss, rs_glob_acc, rs_glob_loss = [], [], [], []

    ###  FedAvg Algorithm  ###
    if args.algorithm == 'fedavg':
        # for iter in tqdm(range(args.epochs)):
        for iter in range(args.epochs):
            w_locals, loss_locals, acc_locals, num_samples_list = [], [], [], []
            for idx in range(args.num_users):
                if(args.local_bs>1200):
                    local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary, bs=200*(4+idx)) #Batch_size bs = full data
                else:
                    local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary, bs=args.local_bs)
                num_samples, w, loss, acc = local.update_weights(net=copy.deepcopy(net_glob))
                num_samples_list.append(num_samples)
                w_locals.append(copy.deepcopy(w))
                loss_locals.append(copy.deepcopy(loss))
                # print("User ", idx, " Acc:", acc, " Loss:", loss)
                acc_locals.append(copy.deepcopy(acc))
            # update global weights
            if(weighted):
                w_glob = weighted_average_weights(w_locals, num_samples_list)
            else:
                w_glob = average_weights(w_locals)


            # copy weight to net_glob
            net_glob.load_state_dict(w_glob)
            # global test
            list_acc, list_loss = [], []
            net_glob.eval()
            # for c in tqdm(range(args.num_users)):
            for c in range(args.num_users):
                if (args.local_bs > 1200):
                    net_local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[c], tb=summary, bs=200*(4+c)) #Batch_size bs = full data
                else:
                    net_local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[c], tb=summary, bs=args.local_bs)
                acc, loss = net_local.test(net=net_glob)
                list_acc.append(acc)
                list_loss.append(loss)
            print("\nEpoch: {}, Global test loss {}, Global test acc: {:.2f}%".format(iter,
                                                                                      sum(list_loss) / len(list_loss),
                                                                                      100. * sum(list_acc) / len(
                                                                                          list_acc)))

            # print loss
            loss_avg = sum(loss_locals) / len(loss_locals)
            acc_avg = sum(acc_locals) / len(acc_locals)
            if iter % 1 == 0:
                print('\nUsers Train Average loss:', loss_avg)
                print('\nUsers Train Average accuracy:', acc_avg)
            # loss_test.append(sum(list_loss) / len(list_loss))
            # acc_test.append(sum(list_acc) / len(list_acc))

            rs_avg_acc.append(acc_avg)
            rs_avg_loss.append(loss_avg)
            rs_glob_acc.append(sum(list_acc) / len(list_acc))
            rs_glob_loss.append(sum(list_loss) / len(list_loss))
            # if (acc_avg >= 0.89):
            #     return iter+1

    ###  FedProx Algorithm  ###
    elif args.algorithm == 'fedprox':
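        # FedProx: mu is the proximal-term coefficient, adding (mu/2)*||w - w_glob||^2 to each local objective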
        args.mu = 0.005  ### change mu 0.001
        args.limit = 0.3
        # for iter in tqdm(range(args.epochs)):
        for iter in range(args.epochs):
            w_locals, loss_locals, acc_locals, num_samples_list = [], [], [], []
            # m = max(int(args.frac * args.num_users), 1)
            # idxs_users = np.random.choice(range(args.num_users), m, replace=False)
            for idx in range(args.num_users):
                if(args.local_bs>1200):
                    local = LocalFedProxUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary, bs=200*(4+idx)) #Batch_size bs = full data
                else:
                    local = LocalFedProxUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary, bs=args.local_bs)

                num_samples, w, loss, acc = local.update_FedProx_weights(net=copy.deepcopy(net_glob))
                num_samples_list.append(num_samples)
                w_locals.append(copy.deepcopy(w))
                loss_locals.append(copy.deepcopy(loss))
                # print("User ", idx, " Acc:", acc, " Loss:", loss)
                acc_locals.append(copy.deepcopy(acc))
            # update global weights
            if(weighted):
                w_glob = weighted_average_weights(w_locals, num_samples_list)
            else:
                w_glob = average_weights(w_locals)

            # copy weight to net_glob
            net_glob.load_state_dict(w_glob)
            # global test
            list_acc, list_loss = [], []
            net_glob.eval()
            # for c in tqdm(range(args.num_users)):
            for c in range(args.num_users):
                if (args.local_bs > 1200):
                    net_local = LocalFedProxUpdate(args=args, dataset=dataset_train, idxs=dict_users[c], tb=summary,
                                            bs=200 * (4 + c))  # Batch_size bs = full data
                else:
                    net_local = LocalFedProxUpdate(args=args, dataset=dataset_train, idxs=dict_users[c], tb=summary,
                                            bs=args.local_bs)

                acc, loss = net_local.test(net=net_glob)
                list_acc.append(acc)
                list_loss.append(loss)
            print("\nEpoch: {}, Global test loss {}, Global test acc: {:.2f}%".format(iter,
                                                                                      sum(list_loss) / len(list_loss),
                                                                                      100. * sum(list_acc) / len(
                                                                                          list_acc)))

            # print loss
            loss_avg = sum(loss_locals) / len(loss_locals)
            acc_avg = sum(acc_locals) / len(acc_locals)
            if iter % 1 == 0:
                print('\nUsers train average loss:', loss_avg)
                print('\nUsers train average accuracy:', acc_avg)
            # loss_test.append(sum(list_loss) / len(list_loss))
            # acc_test.append(sum(list_acc) / len(list_acc))

            rs_avg_acc.append(acc_avg)
            rs_avg_loss.append(loss_avg)
            rs_glob_acc.append(sum(list_acc) / len(list_acc))
            rs_glob_loss.append(sum(list_loss) / len(list_loss))
            # if (acc_avg >= 0.89):
            #     return iter+1

    ###  FSVGR Algorithm  ###
    elif args.algorithm == 'fsvgr':
        args.ag_scalar = 1.  # 0.001 or 0.1
        args.lg_scalar = 1
        args.threshold = 0.001
        # for iter in tqdm(range(args.epochs)):
        for iter in range(args.epochs):

            print("=========Global epoch {}=========".format(iter))
            w_locals, loss_locals, acc_locals = [], [], []
            user_grads = []  # reset so gradients from earlier rounds are not averaged in again
            """
            First communication round: server send w_t to client --> client calculate gradient and send
            to sever --> server calculate average global gradient and send to client
            """
            for idx in range(args.num_users):
                local = LocalFSVGRUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary)
                num_sample, grad_k = local.calculate_global_grad(net=copy.deepcopy(net_glob))
                user_grads.append([num_sample, grad_k])
            global_grad = calculate_avg_grad(user_grads)

            """
            Second communication round: client update w_k_t+1 and send to server --> server update global w_t+1
            """
            for idx in range(args.num_users):
                print("Training user {}".format(idx))
                local = LocalFSVGRUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary)
                num_samples, w_k, loss, acc = local.update_FSVGR_weights(global_grad, idx, copy.deepcopy(net_glob),
                                                                         iter)
                w_locals.append(copy.deepcopy([num_samples, w_k]))
                print("Global_Epoch ", iter, "User ", idx, " Acc:", acc, " Loss:", loss)
                loss_locals.append(copy.deepcopy(loss))
                acc_locals.append(copy.deepcopy(acc))

            # w_t = net_glob.state_dict()
            w_glob = average_FSVRG_weights(w_locals, args.ag_scalar, copy.deepcopy(net_glob), args.gpu)

            # copy weight to net_glob
            net_glob.load_state_dict(w_glob)

            # global test
            list_acc, list_loss = [], []
            net_glob.eval()
            # for c in tqdm(range(args.num_users)):
            for c in range(args.num_users):
                net_local = LocalFSVGRUpdate(args=args, dataset=dataset_train, idxs=dict_users[c], tb=summary)
                acc, loss = net_local.test(net=net_glob)
                list_acc.append(acc)
                list_loss.append(loss)

            print("\nTest Global Weight:", list_acc)
            print("\nEpoch: {}, Global test loss {}, Global test acc: {:.2f}%".format(iter,
                                                                                      sum(list_loss) / len(list_loss),
                                                                                      100. * sum(list_acc) / len(
                                                                                          list_acc)))

            # print loss
            loss_avg = sum(loss_locals) / len(loss_locals)
            acc_avg = sum(acc_locals) / len(acc_locals)
            if iter % 1 == 0:
                print('\nEpoch: {}, Users train average loss: {}'.format(iter, loss_avg))
                print('\nEpoch: {}, Users train average accuracy: {}'.format(iter, acc_avg))
            loss_test.append(sum(list_loss) / len(list_loss))
            acc_test.append(sum(list_acc) / len(list_acc))

            if (acc_avg >= 0.89):
                return
    if(weighted):
        alg=alg+'1'
    simple_save_data(loc_ep, alg, rs_avg_acc, rs_avg_loss, rs_glob_acc, rs_glob_loss)
    plot_rs(loc_ep, alg)
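Note: when weighted is true this example aggregates with weighted_average_weights, i.e. each client's model is weighted by its local sample count from num_samples_list. A minimal sketch under that assumption (the project's own helper is not shown and may differ):

import copy

def weighted_average_weights(w_locals, num_samples_list):
    # FedAvg with per-client weights n_k / n_total
    total = float(sum(num_samples_list))
    w_avg = copy.deepcopy(w_locals[0])
    for key in w_avg.keys():
        w_avg[key] = sum(w_locals[i][key] * (num_samples_list[i] / total)
                         for i in range(len(w_locals)))
    return w_avg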
Code example #4
                                         transform=transform,
                                         target_transform=None,
                                         download=True)
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users)
        else:
            exit('Error: only the IID setting is supported for CIFAR10')
    else:
        exit('Error: unrecognized dataset')
    img_size = dataset_train[0][0].shape

    # build model
    if args.model == 'cnn' and args.dataset == 'cifar':
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            net_glob = CNNCifar(args=args).cuda()
        else:
            net_glob = CNNCifar(args=args)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            net_glob = CNNMnist(args=args).cuda()
        else:
            net_glob = CNNMnist(args=args)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            net_glob = MLP(dim_in=len_in,
Code example #5
                   ]))
        img_size = dataset_train[0][0].shape
    elif args.dataset == 'cifar':
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('../data/cifar', train=True, transform=transform, target_transform=None, download=True)
        img_size = dataset_train[0][0].shape
    else:
        exit('Error: unrecognized dataset')

    # build model
    if args.model == 'cnn' and args.dataset == 'cifar':
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            net_glob = CNNCifar(args=args).cuda()
        else:
            net_glob = CNNCifar(args=args)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            net_glob = CNNMnist(args=args).cuda()
        else:
            net_glob = CNNMnist(args=args)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).cuda()
Code example #6
            dict_users = cifar_noniid(dataset_train, args.num_users)
    else:
        exit('Error: unrecognized dataset')
    img_size = dataset_train[0][0].shape

    # BUILD MODEL
    if args.model == 'cnn' and args.dataset == 'mnist':
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            net_glob = CNNMnist(args=args).cuda()
        else:
            net_glob = CNNMnist(args=args)
    elif args.model == 'cnn' and args.dataset == 'cifar':
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            net_glob = CNNCifar(args=args).cuda()
        else:
            net_glob = CNNCifar(args=args)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            net_glob = MLP(dim_in=len_in,
                           dim_hidden=64,
                           dim_out=args.num_classes).cuda()
        else:
            net_glob = MLP(dim_in=len_in,
                           dim_hidden=64,
                           dim_out=args.num_classes)
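Note: the MLP, MLP1, CNNMnist, CNNCifar and CNN_test classes used throughout these examples come from the projects' model definitions and are not reproduced here. A minimal MLP sketch consistent with the MLP(dim_in, dim_hidden, dim_out) constructor calls above (layer names are illustrative, not the projects' own):

import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, dim_in, dim_hidden, dim_out):
        super(MLP, self).__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        x = x.view(x.size(0), -1)           # flatten the image to a vector of length dim_in
        x = self.relu(self.layer_input(x))
        return self.layer_hidden(x)         # raw class scores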