Example #1
def main(loc_ep, weighted, alg):
    # parse args
    args = args_parser()
    algo_list = ['fedavg', 'fedprox', 'fsvgr']
    # define paths
    path_project = os.path.abspath('..')

    summary = SummaryWriter('local')
    args.gpu = 0  # -1 for CPU only, or the GPU index (here 0)
    args.lr = 0.002  # 0.001 for cifar dataset
    args.model = 'mlp'  # 'mlp' or 'cnn'
    args.dataset = 'mnist'  # 'cifar' or 'mnist'
    args.num_users = 5
    args.epochs = 30  # number of global iterations
    args.local_ep = loc_ep  # number of local iterations
    args.local_bs = 1201  # local batch size (>= 1200 means a user's full dataset: 1200 for mnist, 2000 for cifar)
    args.algorithm = alg  # 'fedavg', 'fedprox', 'fsvgr'
    args.iid = False
    args.verbose = False
    print("dataset:", args.dataset, " num_users:", args.num_users, " epochs:", args.epochs, "local_ep:", args.local_ep)

    # load dataset and split users
    dict_users = {}
    dataset_train = []
    if args.dataset == 'mnist':
        dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307,), (0.3081,))
                                       ]))
        # sample users
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        else:
            dict_users = mnist_noniid(dataset_train, args.num_users)
    elif args.dataset == 'cifar':
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('../data/cifar', train=True, transform=transform, target_transform=None,
                                         download=True)
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users)
        else:
            dict_users = cifar_noniid(dataset_train, args.num_users)
            # exit('Error: only consider IID setting in CIFAR10')
    else:
        exit('Error: unrecognized dataset')
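    # shape of one transformed sample, e.g. torch.Size([1, 28, 28]) for MNIST;
    # used below to size the MLP input layer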
    img_size = dataset_train[0][0].shape

    # build model
    net_glob = None
    if args.model == 'cnn' and args.dataset == 'cifar':
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            net_glob = CNNCifar(args=args).cuda()
        else:
            net_glob = CNNCifar(args=args)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            net_glob = CNNMnist(args=args).cuda()
        else:
            net_glob = CNNMnist(args=args)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            # net_glob = MLP1(dim_in=len_in, dim_hidden=128, dim_out=args.num_classes).cuda()
            net_glob = MLP1(dim_in=len_in, dim_hidden=256, dim_out=args.num_classes).cuda()
        else:
            # net_glob = MLP1(dim_in=len_in, dim_hidden=128, dim_out=args.num_classes)
            net_glob = MLP1(dim_in=len_in, dim_hidden=256, dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')
    print("Nerual Net:", net_glob)

    net_glob.train()  # switch to training mode; this does not change the weight values
    # copy weights
    w_glob = net_glob.state_dict()

    # w_size = 0
    # for k in w_glob.keys():
    #     size = w_glob[k].size()
    #     if (len(size) == 1):
    #         nelements = size[0]
    #     else:
    #         nelements = size[0] * size[1]
    #     w_size += nelements * 4
    #     # print("Size ", k, ": ",nelements*4)
    # print("Weight Size:", w_size, " bytes")
    # print("Weight & Grad Size:", w_size * 2, " bytes")
    # print("Each user Training size:", 784 * 8 / 8 * args.local_bs, " bytes")
    # print("Total Training size:", 784 * 8 / 8 * 60000, " bytes")
    # training
    global_grad = []
    user_grads = []
    loss_test = []
    acc_test = []
    cv_loss, cv_acc = [], []
    val_loss_pre, counter = 0, 0
    net_best = None
    val_acc_list, net_list = [], []
    # print(dict_users.keys())

    # per-epoch results: client-average and global-test accuracy/loss
    rs_avg_acc, rs_avg_loss, rs_glob_acc, rs_glob_loss = [], [], [], []

    ###  FedAvg Algorithm  ###
    if args.algorithm == 'fedavg':
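        # FedAvg (Federated Averaging, McMahan et al.): every global epoch each user
        # trains locally and the server averages the returned weights.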
        # for iter in tqdm(range(args.epochs)):
        for iter in range(args.epochs):
            w_locals, loss_locals, acc_locals, num_samples_list = [], [], [], []
            for idx in range(args.num_users):
                if args.local_bs > 1200:
                    local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary, bs=200 * (4 + idx))  # bs = user's full local dataset
                else:
                    local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary, bs=args.local_bs)
                num_samples, w, loss, acc = local.update_weights(net=copy.deepcopy(net_glob))
                num_samples_list.append(num_samples)
                w_locals.append(copy.deepcopy(w))
                loss_locals.append(copy.deepcopy(loss))
                # print("User ", idx, " Acc:", acc, " Loss:", loss)
                acc_locals.append(copy.deepcopy(acc))
            # update global weights
            if weighted:
                w_glob = weighted_average_weights(w_locals, num_samples_list)
            else:
                w_glob = average_weights(w_locals)


            # copy weight to net_glob
            net_glob.load_state_dict(w_glob)
            # global test
            list_acc, list_loss = [], []
            net_glob.eval()
            # for c in tqdm(range(args.num_users)):
            for c in range(args.num_users):
                if args.local_bs > 1200:
                    net_local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[c], tb=summary, bs=200 * (4 + c))  # bs = user's full local dataset
                else:
                    net_local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[c], tb=summary, bs=args.local_bs)
                acc, loss = net_local.test(net=net_glob)
                list_acc.append(acc)
                list_loss.append(loss)
            print("\nEpoch: {}, Global test loss {}, Global test acc: {:.2f}%".format(iter,
                                                                                      sum(list_loss) / len(list_loss),
                                                                                      100. * sum(list_acc) / len(
                                                                                          list_acc)))

            # print loss
            loss_avg = sum(loss_locals) / len(loss_locals)
            acc_avg = sum(acc_locals) / len(acc_locals)
            if iter % 1 == 0:  # print every global epoch
                print('\nUsers train average loss:', loss_avg)
                print('\nUsers train average accuracy:', acc_avg)
            # loss_test.append(sum(list_loss) / len(list_loss))
            # acc_test.append(sum(list_acc) / len(list_acc))

            rs_avg_acc.append(acc_avg)
            rs_avg_loss.append(loss_avg)
            rs_glob_acc.append(sum(list_acc) / len(list_acc))
            rs_glob_loss.append(sum(list_loss) / len(list_loss))
            # if (acc_avg >= 0.89):
            #     return iter+1

    ###  FedProx Algorithm  ###
    elif args.algorithm == 'fedprox':
        args.mu = 0.005  # proximal-term coefficient mu (alternative value: 0.001)
        args.limit = 0.3
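        # FedProx augments each client's local objective with a proximal term
        # (mu/2) * ||w - w_glob||^2, which limits client drift from the global model.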
        # for iter in tqdm(range(args.epochs)):
        for iter in range(args.epochs):
            w_locals, loss_locals, acc_locals, num_samples_list = [], [], [], []
            # m = max(int(args.frac * args.num_users), 1)
            # idxs_users = np.random.choice(range(args.num_users), m, replace=False)
            for idx in range(args.num_users):
                if args.local_bs > 1200:
                    local = LocalFedProxUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary, bs=200 * (4 + idx))  # bs = user's full local dataset
                else:
                    local = LocalFedProxUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary, bs=args.local_bs)

                num_samples, w, loss, acc = local.update_FedProx_weights(net=copy.deepcopy(net_glob))
                num_samples_list.append(num_samples)
                w_locals.append(copy.deepcopy(w))
                loss_locals.append(copy.deepcopy(loss))
                # print("User ", idx, " Acc:", acc, " Loss:", loss)
                acc_locals.append(copy.deepcopy(acc))
            # update global weights
            if weighted:
                w_glob = weighted_average_weights(w_locals, num_samples_list)
            else:
                w_glob = average_weights(w_locals)

            # copy weight to net_glob
            net_glob.load_state_dict(w_glob)
            # global test
            list_acc, list_loss = [], []
            net_glob.eval()
            # for c in tqdm(range(args.num_users)):
            for c in range(args.num_users):
                if args.local_bs > 1200:
                    net_local = LocalFedProxUpdate(args=args, dataset=dataset_train, idxs=dict_users[c], tb=summary,
                                                   bs=200 * (4 + c))  # bs = user's full local dataset
                else:
                    net_local = LocalFedProxUpdate(args=args, dataset=dataset_train, idxs=dict_users[c], tb=summary,
                                                   bs=args.local_bs)

                acc, loss = net_local.test(net=net_glob)
                list_acc.append(acc)
                list_loss.append(loss)
            print("\nEpoch: {}, Global test loss {}, Global test acc: {:.2f}%".format(iter,
                                                                                      sum(list_loss) / len(list_loss),
                                                                                      100. * sum(list_acc) / len(
                                                                                          list_acc)))

            # print loss
            loss_avg = sum(loss_locals) / len(loss_locals)
            acc_avg = sum(acc_locals) / len(acc_locals)
            if iter % 1 == 0:  # print every global epoch
                print('\nUsers train average loss:', loss_avg)
                print('\nUsers train average accuracy:', acc_avg)
            # loss_test.append(sum(list_loss) / len(list_loss))
            # acc_test.append(sum(list_acc) / len(list_acc))

            rs_avg_acc.append(acc_avg)
            rs_avg_loss.append(loss_avg)
            rs_glob_acc.append(sum(list_acc) / len(list_acc))
            rs_glob_loss.append(sum(list_loss) / len(list_loss))
            # if (acc_avg >= 0.89):
            #     return iter+1

    ###  FSVGR Algorithm  ###
    elif args.algorithm == 'fsvgr':
        args.ag_scalar = 1.  # alternatives: 0.001 or 0.1
        args.lg_scalar = 1
        args.threshold = 0.001
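        # FSVGR (federated SVRG): each global epoch takes two communication rounds --
        # clients first report full local gradients so the server can form a global
        # gradient, then clients run variance-reduced local updates against it.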
        # for iter in tqdm(range(args.epochs)):
        for iter in range(args.epochs):

            print("=========Global epoch {}=========".format(iter))
            w_locals, loss_locals, acc_locals = [], [], []
            user_grads = []  # reset each epoch so stale client gradients are not averaged in
            """
            First communication round: server send w_t to client --> client calculate gradient and send
            to sever --> server calculate average global gradient and send to client
            """
            for idx in range(args.num_users):
                local = LocalFSVGRUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary)
                num_sample, grad_k = local.calculate_global_grad(net=copy.deepcopy(net_glob))
                user_grads.append([num_sample, grad_k])
            global_grad = calculate_avg_grad(user_grads)

            """
            Second communication round: client update w_k_t+1 and send to server --> server update global w_t+1
            """
            for idx in range(args.num_users):
                print("Training user {}".format(idx))
                local = LocalFSVGRUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary)
                num_samples, w_k, loss, acc = local.update_FSVGR_weights(global_grad, idx, copy.deepcopy(net_glob),
                                                                         iter)
                w_locals.append(copy.deepcopy([num_samples, w_k]))
                print("Global_Epoch ", iter, "User ", idx, " Acc:", acc, " Loss:", loss)
                loss_locals.append(copy.deepcopy(loss))
                acc_locals.append(copy.deepcopy(acc))

            # w_t = net_glob.state_dict()
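            # server-side FSVRG aggregation: combine the clients' updated weights with
            # the current global model, scaled by args.ag_scalar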
            w_glob = average_FSVRG_weights(w_locals, args.ag_scalar, copy.deepcopy(net_glob), args.gpu)

            # copy weight to net_glob
            net_glob.load_state_dict(w_glob)

            # global test
            list_acc, list_loss = [], []
            net_glob.eval()
            # for c in tqdm(range(args.num_users)):
            for c in range(args.num_users):
                net_local = LocalFSVGRUpdate(args=args, dataset=dataset_train, idxs=dict_users[c], tb=summary)
                acc, loss = net_local.test(net=net_glob)
                list_acc.append(acc)
                list_loss.append(loss)

            print("\nTest Global Weight:", list_acc)
            print("\nEpoch: {}, Global test loss {}, Global test acc: {:.2f}%".format(iter,
                                                                                      sum(list_loss) / len(list_loss),
                                                                                      100. * sum(list_acc) / len(
                                                                                          list_acc)))

            # print loss
            loss_avg = sum(loss_locals) / len(loss_locals)
            acc_avg = sum(acc_locals) / len(acc_locals)
            if iter % 1 == 0:
                print('\nEpoch: {}, Users train average loss: {}'.format(iter, loss_avg))
                print('\nEpoch: {}, Users train average accuracy: {}'.format(iter, acc_avg))
            loss_test.append(sum(list_loss) / len(list_loss))
            acc_test.append(sum(list_acc) / len(list_acc))

            if acc_avg >= 0.89:  # early stop at the target accuracy (note: skips the save/plot below)
                return
    if weighted:
        alg = alg + '1'  # tag weighted-averaging runs in the saved results' name
    simple_save_data(loc_ep, alg, rs_avg_acc, rs_avg_loss, rs_glob_acc, rs_glob_loss)
    plot_rs(loc_ep, alg)
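
The aggregation helpers average_weights and weighted_average_weights used above are imported from elsewhere in the project and not shown. Below is a minimal sketch of what FedAvg-style aggregation of client state_dicts can look like; these are assumed implementations consistent with how the helpers are called above, not the project's own code.

import copy
import torch

def average_weights(w_locals):
    # uniform average of a list of model state_dicts
    w_avg = copy.deepcopy(w_locals[0])
    for k in w_avg.keys():
        for w in w_locals[1:]:
            w_avg[k] += w[k]
        w_avg[k] = torch.div(w_avg[k], len(w_locals))
    return w_avg

def weighted_average_weights(w_locals, num_samples_list):
    # average of state_dicts, each client weighted by its number of training samples
    total = float(sum(num_samples_list))
    w_avg = copy.deepcopy(w_locals[0])
    for k in w_avg.keys():
        w_avg[k] = w_locals[0][k] * (num_samples_list[0] / total)
        for w, n in zip(w_locals[1:], num_samples_list[1:]):
            w_avg[k] = w_avg[k] + w[k] * (n / total)
    return w_avg

Loading the result back with net_glob.load_state_dict(w_glob), as the loops above do, completes one aggregation round.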
Example #2
        w_locals, loss_locals = [], []
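        # sample a random fraction args.frac of the users for this round (at least one)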
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        for idx in idxs_users:
            local = LocalUpdate(args=args,
                                dataset=dataset_train,
                                idxs=dict_users[idx],
                                tb=summary)
            w, loss = local.update_weights(net=copy.deepcopy(net_glob))
            w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))
        # update global weights
        w_glob = average_weights(w_locals)

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)

        # print loss
        loss_avg = sum(loss_locals) / len(loss_locals)
        if iter % 10 == 0:  # print every 10 epochs; assumes the enclosing loop counter is named iter, as in Example #1
            print('\nTrain loss:', loss_avg)
        loss_train.append(loss_avg)

    # plot loss curve
    plt.figure()
    plt.plot(range(len(loss_train)), loss_train)
    plt.ylabel('train_loss')
    plt.savefig('../save/fed_{}_{}_{}_C{}_iid{}.png'.format(
        args.dataset, args.model, args.epochs, args.frac, args.iid))

    # testing
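
The snippet breaks off at the testing step. Below is a minimal sketch of a final global evaluation; the helper name evaluate_global, the held-out dataset_test, and the assumption that the model returns raw logits are all illustrative, not part of the original file.

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

def evaluate_global(net_glob, dataset_test):
    # run the aggregated global model over a held-out test set
    net_glob.eval()
    loader = DataLoader(dataset_test, batch_size=128)
    correct, total, test_loss = 0, 0, 0.0
    with torch.no_grad():
        for data, target in loader:
            logits = net_glob(data)  # assumes the model outputs raw logits
            test_loss += F.cross_entropy(logits, target, reduction='sum').item()
            correct += (logits.argmax(dim=1) == target).sum().item()
            total += target.size(0)
    print('Test loss: {:.4f}, Test accuracy: {:.2f}%'.format(
        test_loss / total, 100. * correct / total))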