Example #1
def multi_train_local_dif(q_l, q_w, arguments, idx, loss, local_train_loader,
                          non_iid, model):

    print("training user: " + str(idx) + " ye22: " + str(non_iid[idx]))

    local = LocalUpdate(args=arguments,
                        user_num=idx,
                        loss_func=loss,
                        dataset=local_train_loader.dataset,
                        idxs=non_iid[idx])

    w, loss = local.train(net=model.to(arguments.device))

    # hand the trained weights and the loss back to the parent process
    q_l.put(loss)
    q_w.put(w)

    time.sleep(5)
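All of these snippets revolve around a LocalUpdate helper. The real class differs from repository to repository (the constructor arguments vary across the examples below), but a minimal sketch of the common pattern, assuming a DatasetSplit wrapper and a plain SGD loop, looks like this:

# Minimal sketch of the LocalUpdate pattern assumed by these examples.
# DatasetSplit and the constructor signature are assumptions, not any
# particular repository's canonical implementation.
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

class DatasetSplit(Dataset):
    """Expose the subset of `dataset` selected by the index list `idxs`."""
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        return self.dataset[self.idxs[item]]

class LocalUpdate(object):
    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs),
                                    batch_size=args.local_bs, shuffle=True)

    def train(self, net):
        net.train()
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr,
                                    momentum=self.args.momentum)
        epoch_loss = []
        for _ in range(self.args.local_ep):
            batch_loss = []
            for images, labels in self.ldr_train:
                images = images.to(self.args.device)
                labels = labels.to(self.args.device)
                net.zero_grad()
                loss = self.loss_func(net(images), labels)
                loss.backward()
                optimizer.step()
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)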
Example #2
def test_img_local(net_g, dataset_test, dict_users, args):
    net_g.eval()
    # testing
    test_loss = 0
    correct = 0

    acc_locals, local_data_sizes, loss_locals = [], [], []
    idxs_users = range(args.num_users)
    for idx in idxs_users:
        local = LocalUpdate(args=args,
                            dataset=dataset_test,
                            idxs=dict_users[idx],
                            train=False)
        (acc, local_data_size), loss = local.test(net=net_g.to(args.device))
        acc_locals.append(acc)
        local_data_sizes.append(local_data_size)
        loss_locals.append(loss)
    test_loss = sum([i * j for i, j in zip(local_data_sizes, loss_locals)
                     ]) / sum(local_data_sizes)
    correct = sum([i * j for i, j in zip(local_data_sizes, acc_locals)])
    accuracy = correct / sum(local_data_sizes)
    if args.verbose:
        print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.
              format(test_loss, correct, sum(local_data_sizes), accuracy))

    return acc_locals, accuracy, test_loss
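The aggregation in test_img_local is a data-size-weighted mean over clients; a tiny worked check of that formula with hypothetical numbers:

# Size-weighted mean: clients with more test samples count proportionally more.
local_data_sizes = [100, 300]  # hypothetical per-client test-set sizes
loss_locals = [0.8, 0.4]       # hypothetical per-client losses
test_loss = sum(n * l for n, l in zip(local_data_sizes, loss_locals)) / sum(local_data_sizes)
# (100 * 0.8 + 300 * 0.4) / 400 = 0.5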
Example #3
def train(net_glob, db, w_glob, args):
    # training
    loss_train = []
    cv_loss, cv_acc = [], []
    val_loss_pre, counter = 0, 0
    net_best = None
    best_loss = None
    val_acc_list, net_list = [], []

    # originally assign clients and Fed Avg -> mediator Fed Avg
    if args.all_clients:
        print("Aggregation over all clients")
        w_locals = [w_glob for i in range(len(db.dp.mediator))]
    # 3 : for each synchronization round r=1; 2; . . . ; R do
    for iter in range(args.epochs):
        # reset the per-round accumulators before the mediator loop; resetting
        # them inside it would keep only the last mediator's results
        loss_locals = []
        if not args.all_clients:
            w_locals = []
        # 4 : for each mediator m in 1; 2; . . . ; M parallelly do
        for i, mdt in enumerate(db.mediator):
            # 5- :
            need_index = [db.dp.local_train_index[k] for k in mdt]
            local = LocalUpdate(args=args,
                                dataset=db.dp,
                                idxs=np.hstack(need_index))
            w, loss = local.train(net=copy.deepcopy(net_glob).to(
                args.device))  # for lEpoch in range(E): done inside local.train
            if args.all_clients:
                w_locals[i] = copy.deepcopy(w)
            else:
                w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))

        # update global weights
        w_glob = FedAvg(w_locals)

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)

        # print loss
        loss_avg = sum(loss_locals) / len(loss_locals)
        print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg))
        loss_train.append(loss_avg)

    # plot loss curve
    plt.figure()
    plt.plot(range(len(loss_train)), loss_train)
    plt.ylabel('train_loss')
    plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(
        args.dataset, args.model, args.epochs, args.frac, args.iid))
    return net_glob
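The train() loop above and many of the later examples call FedAvg(w_locals); a minimal sketch of that helper, assuming equally weighted clients and PyTorch state dicts:

# Coordinate-wise average of the clients' state dicts (equal weighting is an
# assumption -- sample-size-weighted variants also appear, see Example #21).
import copy
import torch

def FedAvg(w_locals):
    w_avg = copy.deepcopy(w_locals[0])
    for k in w_avg.keys():
        for w in w_locals[1:]:
            w_avg[k] += w[k]
        w_avg[k] = torch.div(w_avg[k], len(w_locals))
    return w_avg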
Example #4
def multi_train_local_dif(q_l, q_w, arguments, idx, loss, data_loader,
                          distribution, model):

    print("training user: " + str(idx))

    local = LocalUpdate(args=arguments,
                        user_num=idx,
                        loss_func=loss,
                        dataset=data_loader.dataset,
                        idxs=distribution[idx])
    w, loss = local.train(net=model.to(arguments.device))
    q_l.put(loss)
    q_w.put(w)
    time.sleep(5)
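Examples #1 and #4 are queue-based worker functions; a plausible driver, assuming torch.multiprocessing with one process per sampled user (launch_round is a hypothetical name, not part of the original code):

# Hypothetical driver for the multi_train_local_dif workers above.
import torch.multiprocessing as mp

def launch_round(arguments, loss_fn, data_loader, distribution, model, user_ids):
    q_l, q_w = mp.Queue(), mp.Queue()
    procs = []
    for idx in user_ids:
        p = mp.Process(target=multi_train_local_dif,
                       args=(q_l, q_w, arguments, idx, loss_fn,
                             data_loader, distribution, model))
        p.start()
        procs.append(p)
    # drain the queues before joining to avoid a feeder-thread deadlock
    losses = [q_l.get() for _ in user_ids]
    weights = [q_w.get() for _ in user_ids]
    for p in procs:
        p.join()
    return weights, losses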
Example #5
def run(rank, world_size, loss_train, acc_train, dataset_train, idxs_users, net_glob, grc):
    # net_glob.load_state_dict(torch.load('net_state_dict.pt'))
    if rank == 0:
        #compressor, epoch, dgc
        foldername = f'{args.compressor}epoch{args.epochs}ratio{args.gsr}'
        tb = SummaryWriter("runs/" + foldername)
    round = 0
    for i in idxs_users:
        #for each epoch
        idx = dict_users[i[rank]]

        epoch_loss = torch.zeros(1)
        optimizer = torch.optim.SGD(net_glob.parameters(), lr=args.lr, momentum=args.momentum)

        local = LocalUpdate(args=args, dataset=dataset_train, idxs=idx) #create LocalUpdate class
        train_loss = local.train(net=net_glob) #train local
        for index, (name, parameter) in enumerate(net_glob.named_parameters()):
            grad = parameter.grad.data
            grc.acc(grad)
            new_tensor = grc.step(grad, name)
            grad.copy_(new_tensor)
        optimizer.step()
        net_glob.zero_grad()
        epoch_loss += train_loss
        dist.reduce(epoch_loss, 0, dist.ReduceOp.SUM)

        net_glob.eval()
        train_acc = torch.zeros(1)
        acc, loss = local.inference(net_glob, dataset_train, idx)
        train_acc += acc
        dist.reduce(train_acc, 0, dist.ReduceOp.SUM)

        if rank == 0:
            torch.save(net_glob.state_dict(), 'net_state_dict.pt')
            epoch_loss /= world_size
            train_acc /= world_size
            loss_train[round] = epoch_loss[0]
            acc_train[round] = train_acc[0]
            tb.add_scalar("Loss", epoch_loss[0], round)
            tb.add_scalar("Accuracy", train_acc[0], round)
            tb.add_scalar("Uncompressed Size", grc.uncompressed_size, round)
            tb.add_scalar("Compressed Size", grc.size, round)
            if round % 50 == 0:
                print('Round {:3d}, Rank {:1d}, Average loss {:.6f}, Average Accuracy {:.2f}%'.format(round, dist.get_rank(), epoch_loss[0], train_acc[0]))
        round+=1
    if rank == 0:
        tb.close()
        print("Printing Compression Stats...")
        grc.printr()
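run() assumes an already-initialized process group (it calls dist.reduce and dist.get_rank); a minimal spawn harness, assuming the gloo backend on a single machine, might look like:

# Hypothetical harness; the backend, address, and port are assumptions.
import torch.distributed as dist
import torch.multiprocessing as mp

def init_process(rank, world_size, fn, *fn_args):
    dist.init_process_group(backend='gloo',
                            init_method='tcp://127.0.0.1:29500',
                            rank=rank, world_size=world_size)
    fn(rank, world_size, *fn_args)

# mp.spawn(init_process,
#          args=(world_size, run, loss_train, acc_train, dataset_train,
#                idxs_users, net_glob, grc),
#          nprocs=world_size)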
Example #6
def clustering_encoder(dict_users, dataset_train, ae_model_dict, args):

    idxs_users = np.arange(args.num_users)

    centers = np.zeros((args.num_users, 2, 2))
    embedding_matrix = np.zeros((len(dict_users[0])*args.num_users, 2))
    for user_id in tqdm(idxs_users, desc='Clustering in progress ...'):
        local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[user_id])
        
        user_dataset_train = local.ldr_train.dataset
            
        encoder = Encoder(ae_model_dict['model'], ae_model_dict['name'], 
                          args.model_root_dir, args.manifold_dim, 
                          user_dataset_train, user_id, device=args.device)
         
        encoder.autoencoder()
        #encoder.manifold_approximation_umap()
        #reducer = encoder.umap_reducer
        # embedding1 = encoder.umap_embedding
        embedding1 = encoder.ae_embedding_np
        
        # ----------------------------------
        # use Kmeans to cluster the data into 2 clusters
        X = list(embedding1)
        embedding_matrix[user_id*len(dict_users[0]): len(dict_users[0])*(user_id + 1),:] = embedding1
        kmeans = KMeans(n_clusters=2, random_state=0).fit(np.array(X))
        centers[user_id,:,:] = kmeans.cluster_centers_
    
    clustering_matrix_soft = np.zeros((args.num_users, args.num_users))
    clustering_matrix = np.zeros((args.num_users, args.num_users))

    for idx0 in idxs_users:
        for idx1 in idxs_users:
            c0 = centers[idx0]
            c1 = centers[idx1]
        
            distance = min_matching_distance(c0, c1)
            
            clustering_matrix_soft[idx0][idx1] = distance
        
            if distance < 1:
                clustering_matrix[idx0][idx1] = 1
            else:
                clustering_matrix[idx0][idx1] = 0

    return clustering_matrix, clustering_matrix_soft, centers, embedding_matrix
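Examples #6, #8, and #9 compare per-user KMeans centers through a min_matching_distance helper; one plausible definition, assuming it takes the minimum over one-to-one matchings of the two center sets of the mean Euclidean distance:

# Hypothetical min_matching_distance: smallest mean pairwise distance over
# all one-to-one matchings of the two center sets (fine for small k).
from itertools import permutations
import numpy as np

def min_matching_distance(c0, c1):
    k = len(c0)
    best = np.inf
    for perm in permutations(range(k)):
        d = np.mean([np.linalg.norm(c0[i] - c1[j]) for i, j in enumerate(perm)])
        best = min(best, d)
    return best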
Example #7
def clustering_perfect(num_users, dict_users, dataset_train, cluster, args):
    idxs_users = np.arange(num_users)
    ar_label = np.zeros((args.num_users, args.num_classes))-1
    for idx in idxs_users:
        local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
        label_matrix = np.empty(0, dtype=int)
        for batch_idx, (images, labels) in enumerate(local.ldr_train):
            label_matrix = np.concatenate((label_matrix, labels.numpy()), axis=0)
        label_matrix = np.unique(label_matrix)
        ar_label[idx][0:len(label_matrix)] = label_matrix
    
    clustering_matrix = np.zeros((num_users, num_users))
    for idx in idxs_users:
        for idx0 in idxs_users:
            if np.all(ar_label[idx0][0:len(cluster.T)] == ar_label[idx][0:len(cluster.T)]):
                clustering_matrix[idx][idx0] = 1
                
    return clustering_matrix
Example #8
def clustering_pca_kmeans(dict_users, cluster, dataset_train, args):
    idxs_users = np.random.choice(args.num_users, args.num_users, replace=False)
    
    centers = np.empty((0, args.latent_dim), dtype=int)
    center_dict = {}
    embedding_matrix = np.zeros((len(dict_users[0])*args.num_users, args.latent_dim))
    
    for user_id in tqdm(idxs_users, desc='Clustering in progress ...'):
        local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[user_id])
        
        user_dataset_train = local.ldr_train.dataset
        
        user_data_np = np.squeeze(np.array([item[0].view(1, -1).numpy() for item in user_dataset_train]))
        if args.latent_dim > len(user_dataset_train):
            user_data_np = np.repeat(user_data_np, np.ceil(args.latent_dim/len(user_dataset_train)),axis=0) 
        pca = PCA(n_components=args.latent_dim)
        embedding = pca.fit_transform(user_data_np)
        
        kmeans = KMeans(n_clusters=5, random_state=43).fit(embedding)
        centers = np.vstack((centers, kmeans.cluster_centers_))
        
        center_dict[user_id] = kmeans.cluster_centers_
    
    clustering_matrix_soft = np.zeros((args.num_users, args.num_users))
    clustering_matrix = np.zeros((args.num_users, args.num_users))

    c_dict = center_dict

    for idx0 in tqdm(idxs_users, desc='Creating clustering matrix'):
        c0 = c_dict[idx0]
        for idx1 in idxs_users:
            c1 = c_dict[idx1]
        
            distance = min_matching_distance(c0, c1)
            
            clustering_matrix_soft[idx0][idx1] = distance
        
            if distance < 1.2:
                clustering_matrix[idx0][idx1] = 1
            else:
                clustering_matrix[idx0][idx1] = 0
                
    return clustering_matrix, clustering_matrix_soft, centers, c_dict
Example #9
def clustering_umap(num_users, dict_users, dataset_train, args):
    reducer = pickle.load(open(f'{args.model_root_dir}/umap_reducer_EMNIST.p', "rb"))

    idxs_users = np.arange(num_users)
    
    input_dim = dataset_train[0][0].shape[-1]
    channel_dim = dataset_train[0][0].shape[0]
    
    centers = np.zeros((num_users, 2, 2))
    for idx in tqdm(idxs_users, desc='Clustering progress'):
        images_matrix = np.empty((0, channel_dim*input_dim*input_dim))
        local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
        for batch_idx, (images, labels) in enumerate(local.ldr_train):  # TODO: concatenate the matrices
            # if batch_idx == 3:  # TODO: ablation test
            #     break
            ne = images.numpy().flatten().T.reshape((len(labels), channel_dim*input_dim*input_dim))
            images_matrix = np.vstack((images_matrix, ne))
        embedding1 = reducer.transform(images_matrix)
        X = list(embedding1)
        kmeans = KMeans(n_clusters=2, random_state=0).fit(np.array(X))
        centers[idx,:,:] = kmeans.cluster_centers_
    
    clustering_matrix_soft = np.zeros((num_users, num_users))
    clustering_matrix = np.zeros((num_users, num_users))

    for idx0 in tqdm(idxs_users, desc='Clustering matrix generation'):
        for idx1 in idxs_users:
            c0 = centers[idx0]
            c1 = centers[idx1]

            distance = min_matching_distance(c0, c1)

            clustering_matrix_soft[idx0][idx1] = distance
        
            if distance < 1:
                clustering_matrix[idx0][idx1] = 1
            else:
                clustering_matrix[idx0][idx1] = 0

    return clustering_matrix, clustering_matrix_soft, centers
Example #10
def run(rank, world_size, loss_train, acc_train, epoch, dataset_train, idx,
        net_glob):
    net_glob.load_state_dict(torch.load('net_state_dict.pt'))
    dgc_trainer = DGC(model=net_glob,
                      rank=rank,
                      size=world_size,
                      momentum=args.momentum,
                      full_update_layers=[4],
                      percentage=args.dgc)
    dgc_trainer.load_state_dict(torch.load('dgc_state_dict.pt'))

    epoch_loss = torch.zeros(1)
    for iter in range(args.local_ep):
        local = LocalUpdate(args=args, dataset=dataset_train,
                            idxs=idx)  #create LocalUpdate class
        b_loss = local.train(net=net_glob, world_size=world_size,
                             rank=rank)  #train local
        epoch_loss += b_loss
        if rank == 0:
            print("Local Epoch: {}, Local Epoch Loss: {}".format(iter, b_loss))
    dgc_trainer.gradient_update()
    epoch_loss /= args.local_ep
    dist.reduce(epoch_loss, 0, dist.ReduceOp.SUM)

    net_glob.eval()
    train_acc = torch.zeros(1)
    local = LocalUpdate(args=args, dataset=dataset_train,
                        idxs=idx)  #create LocalUpdate class
    acc, loss = local.inference(net_glob, dataset_train, idx)
    train_acc += acc
    dist.reduce(train_acc, 0, dist.ReduceOp.SUM)

    if rank == 0:
        torch.save(net_glob.state_dict(), 'net_state_dict.pt')
        torch.save(dgc_trainer.state_dict(), 'dgc_state_dict.pt')
        epoch_loss /= world_size
        train_acc /= world_size
        loss_train[epoch] = epoch_loss[0]
        acc_train[epoch] = train_acc[0]
        print(
            'Round {:3d}, Rank {:1d}, Average loss {:.6f}, Average Accuracy {:.2f}%'
            .format(epoch, dist.get_rank(), epoch_loss[0], train_acc[0]))
Example #11
    #    dict_users=vehicle_iid(Vehicle_train(),USER)
    is_support = torch.cuda.is_available()
    if is_support:
        device = torch.device('cuda:0')
    dict_users = balanced_dataset(USER)
    #    dict_users=unbalanced_dataset(USER)
    net_glob = CNNLSTM()
    net_glob.train()
    w_glob = net_glob.state_dict()
    loss_train = []
    acc = 0
    for iter in range(EPOCH):
        w_locals, loss_locals = [], []
        idxs_users = USER
        for idx in range(idxs_users):
            local = LocalUpdate(dataset=Vehicle_train(datasets[idx]),
                                idxs=dict_users[idx])
            #            local = LocalUpdate( dataset=Vehicle_train(unbalanced_datasets[idx]), idxs=dict_users[idx])
            w, loss = local.train(net=copy.deepcopy(net_glob))
            w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))
        w_glob = FedAvg(w_locals)
        net_glob.load_state_dict(w_glob)
        net_glob.eval()
        acc_test, loss_test = test_img(net_glob, Vehicle_test())
        loss_avg = sum(loss_locals) / len(loss_locals)
        print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg))
        loss_train.append(loss_avg)
    np.save('loss', loss_train)

    # testing
    #    net_glob=torch.load('model_fed.pkl')
Example #12
    final_net_list = copy.deepcopy(net_local_list)

    results = []
    results.append(np.array([-1, acc_test_local, acc_test_avg, acc_test_local, None, None]))
    print('Round {:3d}, Acc (local): {:.2f}, Acc (avg): {:.2f}, Acc (local-best): {:.2f}'.format(
        -1, acc_test_local, acc_test_avg, acc_test_local))

    for iter in range(args.epochs):
        w_glob = {}
        loss_locals = []
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        w_keys_epoch = w_glob_keys

        for idx in idxs_users:
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users_train[idx])
            net_local = net_local_list[idx]

            w_local, loss = local.train(net=net_local.to(args.device), lr=args.lr)
            loss_locals.append(copy.deepcopy(loss))

            # sum up weights
            if len(w_glob) == 0:
                w_glob = copy.deepcopy(w_local)
            else:
                for k in w_keys_epoch:
                    w_glob[k] += w_local[k]
        loss_avg = sum(loss_locals) / len(loss_locals)
        loss_train.append(loss_avg)

        # get weighted average for global weights
Example #13
        for idx_cluster, _users in clusters.items():
            idx_users, loss_local = [], []
            # initial model state for each cluster-head user
            net_tmp = copy.deepcopy(net_glob)
            # iterate over every user in the cluster
            for user_key, user_val in _users.items():
                idx_users.append(int(user_key))  # all user idxs inside this cluster
            # print(idx_users)

            # shuffle the in-cluster sequential order and randomly select a CH
            random.shuffle(idx_users)
            # each cluster is performed parallel
            start_time = time.time()
            for idx in idx_users:
                local = LocalUpdate(args=args,
                                    dataset=dataset_train,
                                    idxs=dict_users[idx])
                w, loss = local.train(net=copy.deepcopy(net_tmp).to(
                    args.device),
                                      stepsize=step_size)
                loss_local.append(copy.deepcopy(loss))
                # initialize the next node's model with the neighboring node's model
                net_tmp.load_state_dict(w)
            # once the users in a cluster finish training sequentially,
            # record the model each cluster uploads
            end_time = time.time()
            w_clusters.append(copy.deepcopy(w))
            loss_clusters.append(sum(loss_local) / len(loss_local))
            comp_time.append(end_time - start_time)
        loss_avg = sum(loss_clusters) / len(loss_clusters)
        # aggregate the model produced by each cluster
        start_time = time.time()
Example #14
        if not args.all_clients:
            w_locals = []

        round_idx = valid_list[round]
        user_idx_this_round = round_idx[np.where(round_idx != -1)]

        # random selection
        # user_idx_this_round = np.random.choice(range(args.num_users), 10, replace=False)  # pick m out of num_users


        if len(user_idx_this_round) > 0:

            for idx in user_idx_this_round:

                local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])

                weight, loss = local.train(net=copy.deepcopy(global_net).to(args.device))

                if args.all_clients:
                    w_locals[idx] = copy.deepcopy(weight)
                else:
                    w_locals.append(copy.deepcopy(weight))

                loss_locals.append(copy.deepcopy(loss))

            # update global weights
            w_glob = FedAvg(w_locals)

            # copy weight to net_glob
            global_net.load_state_dict(w_glob)
Example #15
    val_acc_list, net_list = [], []

    if args.all_clients:
        print("Aggregation over all clients")
        w_locals = [w_glob for i in range(args.num_users)]
    for iter in range(args.epochs):  # each round of global model updates
        loss_locals = []
        if not args.all_clients:
            w_locals = []
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(
            range(args.num_users), m,
            replace=False)  # randomly select a subset of clients; selecting all of them raises communication cost and may not improve results
        for idx in idxs_users:  # each client runs in parallel
            local = LocalUpdate(
                args=args, dataset=dataset_train, idxs=dict_users[idx]
            )  # models/Update.py: update on the client side and fetch the parameters this client trained
            w, loss = local.train(net=copy.deepcopy(net_glob).to(
                args.device))  # key step: the server hands the current global model to the client
            if args.all_clients:
                w_locals[idx] = copy.deepcopy(w)
            else:
                w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))
        # update global weights
        w_glob = FedAvg(w_locals)  # models/Fed.py: aggregate the parameters returned by all clients

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)

        # print loss

Example #16
    for iter in range(args.epochs):

        net_glob.train()
        net_ema_glob.train()

        epoch_comu = []
        w_locals, w_ema_locals, loss_locals, loss_consistent_locals = [], [], [], []
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        for idx in idxs_users:
            dict_userepoch[idx] = dict_userepoch[idx]+1
            local = LocalUpdate(args=args, dataset=dataset_train, dataset_ema=dataset_train_ema, 
                                idxs=dict_users[idx], idxs_labeled=dict_users_labeled[idx], 
                                pseudo_label=pseudo_label)
            
            w_dic, w_ema_dic, w_ema, loss, loss_consistent, diff_w_ema, comu_w, comu_w_ema = local.trainc(
                    net=copy.deepcopy(net_glob).to(args.device),
                    net_ema=copy.deepcopy(net_ema_glob).to(args.device),
                    args=args,
                    iter_glob=iter,
                    user_epoch=dict_userepoch[idx],
                    diff_w_old=diff_w_old)

            for i in list(w_ema.keys()):
                diff_w_old_dic.append(diff_w_ema[i])

Example #17
        print('# of current epoch is ' + str(currentEpoch))

        workerNow = np.random.choice(idxs_users, 1, replace=False).tolist()[0]

        staleFlag = np.random.randint(-1, 4, size=1)  # values fall in [-1, 3]

        print('The staleFlag of worker ' + str(workerNow) + ' is ' +
              str(staleFlag))

        if staleFlag <= 4:

            # judge the malicious node
            if workerNow not in maliciousN:
                local = LocalUpdate(args=args,
                                    dataset=dataset_train,
                                    idxs=dict_users[workerNow])
                w, loss = local.train(
                    net=copy.deepcopy(net_glob).to(args.device))
            else:
                w = torch.load(
                    './data/genesisGPUForCNN.pkl',
                    map_location=torch.device(
                        'cuda:{}'.format(args.gpu) if torch.cuda.is_available(
                        ) and args.gpu != -1 else 'cpu'))
                print('Training of malicious node device ' + str(workerNow) +
                      ' in iteration ' + str(currentEpoch) + ' has done!')

            # means that the alpha is 0.5
            w_fAvg.append(copy.deepcopy(base_glob))
            w_fAvg.append(copy.deepcopy(w))
Example #18
    val_acc_list, net_list = [], []
    dict_userepoch = {i: 0 for i in range(100)}

    for iter in range(args.epochs):

        net_glob.train()
        net_ema_glob.train()

        w_locals, w_ema_locals, loss_locals, loss_consistent_locals = [], [], [], []
        m = max(int(args.frac * args.num_users), 1)  # choose the users to train
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        for idx in idxs_users:
            dict_userepoch[idx] = dict_userepoch[idx] + 1
            local = LocalUpdate(args=args,
                                dataset=dataset_train,
                                dataset_ema=dataset_train_ema,
                                idxs=dict_users[idx],
                                idxs_labeled=dict_users_labeled[idx],
                                pseudo_label=pseudo_label)

            w, w_ema, loss, loss_consistent = local.train(
                net=copy.deepcopy(net_glob).to(args.device),
                net_ema=copy.deepcopy(net_ema_glob).to(args.device),
                args=args,
                iter_glob=iter + 1,
                user_epoch=dict_userepoch[idx])

            w_locals.append(copy.deepcopy(w))
            w_ema_locals.append(copy.deepcopy(w_ema))
            loss_locals.append(copy.deepcopy(loss))
            loss_consistent_locals.append(copy.deepcopy(loss_consistent))
Example #19
    val_acc_list, net_list = [], []

    acc_train = []
    loss_tr = []
    acc_test = []
    loss_test = []

    for iter in range(args.epochs):
        #net_glob.train()
        w_locals, loss_locals = [], []
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        for order, idx in enumerate(idxs_users):
            local = LocalUpdate(args=args,
                                dataset=dataset_train,
                                idxs=dict_users[idx],
                                net=copy.deepcopy(net_glob).to(args.device),
                                epochs=iter)
            #w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
            #w_locals.append(copy.deepcopy(w))
            #loss_locals.append(copy.deepcopy(loss))
            #w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
            w, loss = local.train()

            print(
                '\rEpochs: {}\tUserID: {}\tSequence: {}\tLoss: {:.6f}'.format(
                    iter, idx, order, loss))
            loss_locals.append(copy.deepcopy(loss))
            #w_locals.append(copy.deepcopy(w.state_dict()))
            nets_users[idx][0] = 1
            nets_users[idx][1] = copy.deepcopy(w)
Example #20
    for iter in range(args.epochs):
        if iter > 135:
            args.lr = 0.01
        w_locals, loss_locals, ac_locals, num_samples = [], [], [], []
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        pro_ground_truth = ground_truth_composition(dict_users, idxs_users, 26,
                                                    label_train)
        print(pro_ground_truth)

        for idx in idxs_users:
            local = LocalUpdate(args=args,
                                dataset=dataset_train,
                                label=label_train,
                                idxs=dict_users[idx],
                                alpha=ratio,
                                size_average=True)
            w, loss, ac = local.train(
                net=copy.deepcopy(net_glob).to(args.device))
            w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))
            ac_locals.append(copy.deepcopy(ac))
            num_samples.append(len(dict_users[idx]))

        # monitor
        cc_net, cc_loss = [], []
        aux_class = [i for i in range(26)]
        for i in aux_class:
            cc_local = LocalUpdate(args=args,
                                   dataset=dataset_train,
Example #21
def Proposed_G1(net_glob, dict_workers_index, dict_users_data,
                dict_labels_counter_mainFL, args, cost, dataset_train,
                dataset_test, valid_ds, loss_test_final_main,
                optimal_clients_number, optimal_delay):

    data_Global_DCFL = {
        "C": [],
        "Round": [],
        "Average Loss Train": [],
        "SDS Loss": [],
        "SDS Accuracy": [],
        "Workers Number": [],
        "Large Test Loss": [],
        "Large Test Accuracy": [],
        "Communication Cost": []
    }
    Final_LargeDataSetTest_DCFL = {
        "C": [],
        "Test Accuracy": [],
        "Test Loss": [],
        "Train Loss": [],
        "Train Accuracy": [],
        "Total Rounds": [],
        "Communication Cost": []
    }
    # copy weights
    # w_glob = net_glob.state_dict()

    temp = copy.deepcopy(net_glob)

    # training
    loss_train = []
    Loss_local_each_global_total = []
    selected_clients_costs_total = []

    loss_workers_total = np.zeros(shape=(args.num_users, 100 * args.epochs))

    workers_percent_dist = []
    workers_participation = np.zeros((args.num_users, 100 * args.epochs))
    workers = []
    for i in range(args.num_users):
        workers.append(i)

    n_k = np.zeros(shape=(args.num_users))
    for i in range(len(dict_users_data)):
        n_k[i] = len(dict_users_data[i])

    Global_Accuracy_Tracker = np.zeros(100 * args.epochs)
    Global_Loss_Tracker = np.zeros(100 * args.epochs)

    Goal_Loss = float(loss_test_final_main)

    net_glob.eval()
    acc_test_final, loss_test_final = test_img(net_glob, dataset_test, args)
    while_counter = float(loss_test_final)
    iter = 0

    total_rounds_dcfl = 0
    pre_net_glob = copy.deepcopy(net_glob)

    while abs(while_counter - Goal_Loss) >= 0.05:
        # print("G1 Loss is ", while_counter)
        selected_clients_costs_round = []
        w_locals, loss_locals = [], []
        m = max(int(args.frac * args.num_users), 1)

        x = net_glob
        x.eval()
        acc_test_global, loss_test_global = test_img(x, valid_ds, args)
        Loss_local_each_global_total.append(acc_test_global)
        Global_Accuracy_Tracker[iter] = acc_test_global
        Global_Loss_Tracker[iter] = loss_test_global
        workers_count = 0

        temp_w_locals = []
        temp_workers_loss = np.zeros(args.num_users)
        temp_workers_accuracy = np.zeros(args.num_users)
        temp_workers_loss_test = np.zeros(args.num_users)
        temp_workers_loss_difference = np.zeros((args.num_users, 2))
        flag = np.zeros(args.num_users)

        list_of_random_workers_newfl = []
        if iter < (args.epochs):
            for key, value in dict_workers_index.items():
                if key == iter:
                    list_of_random_workers_newfl = dict_workers_index[key]
        else:
            list_of_random_workers_newfl = random.sample(workers, m)

        initial_global_model = copy.deepcopy(net_glob).to(args.device)
        initial_global_model.eval()

        for idx in list_of_random_workers_newfl:
            local = LocalUpdate(args=args,
                                dataset=dataset_train,
                                idxs=dict_users_data[idx])
            w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))

            temp_w_locals.append(copy.deepcopy(w))
            temp_workers_loss[idx] = copy.deepcopy(loss)

            temp.load_state_dict(w)
            temp.eval()

            acc_test_local_after, loss_test_local_after = test_img(
                temp, valid_ds, args)
            temp_workers_accuracy[idx] = acc_test_local_after
            temp_workers_loss_test[idx] = loss_test_local_after
            temp_workers_loss_difference[idx, 0] = int(idx)
            temp_workers_loss_difference[idx, 1] = (loss_test_local_after)

        global_loss_diff = (Global_Loss_Tracker[iter])
        if global_loss_diff >= 0:
            # print("yes")
            for i in range(len(temp_w_locals)):
                if cost[int(temp_workers_loss_difference[i, 0])] <= optimal_delay and\
                        temp_workers_loss_difference[i, 1] >= global_loss_diff:
                    w_locals.append(copy.deepcopy(temp_w_locals[i]))
                    loss_locals.append(temp_workers_loss[int(
                        temp_workers_loss_difference[i, 0])])
                    flag[int(temp_workers_loss_difference[i, 0])] = 1
                    workers_count += 1
                    workers_participation[int(
                        temp_workers_loss_difference[i, 0])][iter] = 1
                    selected_clients_costs_round.append(cost[int(
                        temp_workers_loss_difference[i, 0])])
        if len(w_locals) < 1:
            for i in range(len(temp_w_locals)):
                w_locals.append(copy.deepcopy(temp_w_locals[i]))
                loss_locals.append(temp_workers_loss[int(
                    temp_workers_loss_difference[i, 0])])
                flag[int(temp_workers_loss_difference[i, 0])] = 1
                workers_count += 1
                workers_participation[int(
                    temp_workers_loss_difference[i, 0])][iter] = 1
                selected_clients_costs_round.append(cost[int(
                    temp_workers_loss_difference[i, 0])])

        # update global weights
        # w_glob = FedAvg(w_locals)

        for n in range(args.num_users - len(w_locals)):
            w_locals.append(pre_net_glob.state_dict())
        w_glob = fed_avg(w_locals, n_k)

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)

        #print("round completed")
        if len(loss_locals) > 0:
            loss_avg = sum(loss_locals) / len(loss_locals)
        else:
            loss_avg = None
        loss_train.append(loss_avg)
        workers_percent_dist.append(workers_count / args.num_users)
        print(iter, " round G1 fl finished")

        acc_test_final, loss_test_final = test_img(net_glob, dataset_test,
                                                   args)
        while_counter = loss_test_final

        data_Global_DCFL["Round"].append(iter)
        data_Global_DCFL["C"].append(args.frac)
        data_Global_DCFL["Average Loss Train"].append(loss_avg)
        data_Global_DCFL["SDS Accuracy"].append(Global_Accuracy_Tracker[iter])
        data_Global_DCFL["SDS Loss"].append(Global_Loss_Tracker[iter])
        data_Global_DCFL["Workers Number"].append(workers_count)
        data_Global_DCFL["Large Test Loss"].append(float(loss_test_final))
        data_Global_DCFL["Large Test Accuracy"].append(float(acc_test_final))
        data_Global_DCFL["Communication Cost"].append(
            sum(selected_clients_costs_round))

        selected_clients_costs_total.append(sum(selected_clients_costs_round))

        iter += 1
        total_rounds_dcfl = iter
        pre_net_glob = copy.deepcopy(net_glob)

    # plot workers percent of participating
    workers_percent_final = np.zeros(args.num_users)
    workers_name = np.zeros(args.num_users)
    for i in range(len(workers_participation[:, 1])):
        workers_percent_final[i] = sum(
            workers_participation[i, :]) / (iter - 1)
        workers_name[i] = i

    # testing
    net_glob.eval()
    acc_train_final, loss_train_final = test_img(net_glob, dataset_train, args)
    acc_test_final, loss_test_final = test_img(net_glob, dataset_test, args)

    Final_LargeDataSetTest_DCFL["C"].append(args.frac)
    Final_LargeDataSetTest_DCFL["Test Loss"].append(float(loss_test_final))
    Final_LargeDataSetTest_DCFL["Test Accuracy"].append(float(acc_test_final))
    Final_LargeDataSetTest_DCFL["Train Loss"].append(float(loss_train_final))
    Final_LargeDataSetTest_DCFL["Train Accuracy"].append(
        float(acc_train_final))
    Final_LargeDataSetTest_DCFL["Total Rounds"].append(int(total_rounds_dcfl))
    Final_LargeDataSetTest_DCFL["Communication Cost"].append(
        sum(selected_clients_costs_total))

    return Final_LargeDataSetTest_DCFL, data_Global_DCFL
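Example #21 aggregates with fed_avg(w_locals, n_k) rather than the uniform FedAvg; a sketch of a sample-size-weighted variant, assuming n_k holds one dataset size per entry of w_locals (which holds after the padding loop above):

# Hypothetical sample-size-weighted averaging in the spirit of FedAvg.
import copy
import numpy as np

def fed_avg(w_locals, n_k):
    weights = np.asarray(n_k, dtype=float)
    weights = weights / weights.sum()
    w_avg = copy.deepcopy(w_locals[0])
    for key in w_avg.keys():
        w_avg[key] = w_locals[0][key] * weights[0]
        for i in range(1, len(w_locals)):
            w_avg[key] += w_locals[i][key] * weights[i]
    return w_avg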
Example #22
        acc_test_fl, loss_test_flxx = test_img(net_glob_fl, dataset_test, args)
        print("Testing accuracy: {:.2f}".format(acc_test_fl))
        acc_train_fl_his.append(acc_test_fl)

        filename = 'result/MLP/' + "Accuracy_FedAvg_unbalance_MLP.csv"
        with open(filename, "a") as myfile:
            myfile.write(str(acc_test_fl) + ',')

        w_locals, loss_locals = [], []
        # M clients local update
        m = max(int(args.frac * args.num_users), 1)  # num of selected users
        idxs_users = np.random.choice(range(
            args.num_users), m, replace=False)  # select randomly m clients
        for idx in idxs_users:
            local = LocalUpdate(args=args,
                                dataset=dataset_train,
                                idxs=dict_users[idx])  # data select
            w, loss = local.train(
                net=copy.deepcopy(net_glob_fl).to(args.device))
            w_locals.append(copy.deepcopy(w))  # collect local model
            loss_locals.append(
                copy.deepcopy(loss))  # collect local loss function

        w_glob_fl = FedAvg(w_locals)  # update the global model
        net_glob_fl.load_state_dict(w_glob_fl)  # copy weight to net_glob

        # Loss
        loss = sum(loss_locals) / len(loss_locals)
        print('fl,iter = ', iter, 'loss=', loss)
        filename = 'result/MLP/' + "Loss_FedAvg_unbalance_MLP.csv"
        with open(filename, "a") as myfile:
Example #23
def main():
    # parse args
    args = args_parser()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    dataPath = args.datasetPath

    # random seed
    np.random.seed(args.seed)
    cudnn.benchmark = False
    cudnn.deterministic = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)

    # load dataset and split users
    if args.dataset == 'cifar10':
        _CIFAR_TRAIN_TRANSFORMS = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ]
        dataset_train = datasets.CIFAR10(
            dataPath,
            train=True,
            download=True,
            transform=transforms.Compose(_CIFAR_TRAIN_TRANSFORMS))

        _CIFAR_TEST_TRANSFORMS = [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ]
        dataset_test = datasets.CIFAR10(
            dataPath,
            train=False,
            transform=transforms.Compose(_CIFAR_TEST_TRANSFORMS))

        if args.iid == 0:  # IID
            dict_users = cifar_iid(dataset_train, args.num_users)
        elif args.iid == 2:  # non-IID
            dict_users = cifar_noniid_2(dataset_train, args.num_users)
        else:
            exit('Error: unrecognized class')

    elif args.dataset == 'emnist':
        _MNIST_TRAIN_TRANSFORMS = _MNIST_TEST_TRANSFORMS = [
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.Pad(2),
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ]
        dataset_train = datasets.EMNIST(
            dataPath,
            train=True,
            download=True,
            transform=transforms.Compose(_MNIST_TRAIN_TRANSFORMS),
            split='letters')
        dataset_test = datasets.EMNIST(
            dataPath,
            train=False,
            download=True,
            transform=transforms.Compose(_MNIST_TEST_TRANSFORMS),
            split='letters')

        dict_users = femnist_star(dataset_train, args.num_users)

    elif args.dataset == 'cifar100':
        _CIFAR_TRAIN_TRANSFORMS = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ]
        dataset_train = datasets.CIFAR100(
            dataPath,
            train=True,
            download=True,
            transform=transforms.Compose(_CIFAR_TRAIN_TRANSFORMS))

        _CIFAR_TEST_TRANSFORMS = [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ]
        dataset_test = datasets.CIFAR100(
            dataPath,
            train=False,
            transform=transforms.Compose(_CIFAR_TEST_TRANSFORMS))
        if args.iid == 0:  # IID
            dict_users = cifar_100_iid(dataset_train, args.num_users)
        elif args.iid == 2:  # non-IID
            dict_users = cifar_100_noniid(dataset_train, args.num_users)
    else:
        exit('Error: unrecognized dataset')

    # build model
    if args.dataset == 'cifar10':
        if args.model == "CNNStd5":
            net_glob = CNNCifarStd5().cuda()
        else:
            exit('Error: unrecognized model')
    elif args.dataset == 'emnist':
        if args.model == "CNNStd5":
            net_glob = CNNEmnistStd5().cuda()
        else:
            exit('Error: unrecognized model')
    elif args.dataset == 'cifar100':
        if args.model == "CNNStd5":
            net_glob = CNNCifar100Std5().cuda()
        else:
            exit('Error: unrecognized model')
    else:
        exit('Error: unrecognized model')

    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in net_glob.parameters()])))

    net_glob.train()

    learning_rate = args.lr
    test_acc = []
    avg_loss = []

    # Train
    for iter in range(args.epochs):

        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        w_locals, loss_locals = [], []
        for i, idx in enumerate(idxs_users):
            print('user: {:d}'.format(idx))
            local = LocalUpdate(args=args,
                                dataset=dataset_train,
                                idxs=dict_users[idx])
            w, loss = local.train(model=copy.deepcopy(net_glob).cuda(),
                                  lr=learning_rate)

            w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))

        # update global weights
        w_glob = FedAvg(w_locals)

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)

        # print loss
        loss_avg = sum(loss_locals) / len(loss_locals)
        print('Round {:3d}, Average loss {:.6f}'.format(iter, loss_avg))

        acc_test, _ = test_img(net_glob.cuda(), dataset_test, args)
        print("test accuracy: {:.4f}".format(acc_test))
        test_acc.append(acc_test)

        avg_loss.append(loss_avg)

        learning_rate = adjust_learning_rate(learning_rate, args.lr_drop)

    filename = './accuracy-' + str(args.dataset) + '-iid' + str(args.iid) + '-' + str(args.epochs) + '-seed' \
               + str(args.seed) + '-' + str(args.loss_type) + '-beta' + str(args.beta) + '-mu' + str(args.mu)
    save_result(test_acc, avg_loss, filename)
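Example #23 decays the learning rate once per round through adjust_learning_rate(learning_rate, args.lr_drop); a plausible sketch, assuming a simple multiplicative decay:

# Hypothetical multiplicative decay applied after every communication round.
def adjust_learning_rate(lr, lr_drop):
    return lr * lr_drop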
Example #24
                                         dataset=dataset_test,
                                         idxs=dict_users_test[idx])
            w = local.one_sgd_step(net=copy.deepcopy(
                net_local_list[user_idx]).to(args.device),
                                   lr=lr,
                                   beta=0.1)
            net_local_list[user_idx].load_state_dict(w)

        # fine-tuning
        if args.fine_tuning:
            local_ep_backup = args.local_ep
            args.local_ep = args.ft_ep

            for user_idx in range(args.num_users):
                local = LocalUpdate(args=args,
                                    dataset=dataset_train,
                                    idxs=dict_users_train[user_idx])
                w, loss = local.train(net=copy.deepcopy(
                    net_local_list[user_idx]).to(args.device),
                                      body_lr=lr,
                                      head_lr=lr)
                net_local_list[user_idx].load_state_dict(w)

            args.local_ep = local_ep_backup

        if (iter + 1) % args.test_freq == 0:
            acc_test, loss_test = test_img_local_all(net_local_list,
                                                     args,
                                                     dataset_test,
                                                     dict_users_test,
                                                     return_all=False)
Example #25
def main():
    # parse args
    args = args_parser()
    args.device = torch.device('cuda:{}'.format(
        args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

    # load dataset and split users
    if args.dataset == 'mnist':
        trans_mnist = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])
        dataset_train = datasets.MNIST('../data/mnist/',
                                       train=True,
                                       download=True,
                                       transform=trans_mnist)
        dataset_test = datasets.MNIST('../data/mnist/',
                                      train=False,
                                      download=True,
                                      transform=trans_mnist)
        print("type of test dataset", type(dataset_test))
        # sample users
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        else:
            dict_users, dict_labels_counter = mnist_noniid(
                dataset_train, args.num_users)
            dict_users_2, dict_labels_counter_2 = dict_users, dict_labels_counter
            #dict_users, dict_labels_counter = mnist_noniid_unequal(dataset_train, args.num_users)
    elif args.dataset == 'cifar':
        trans_cifar = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        dataset_train = datasets.CIFAR10('../data/cifar',
                                         train=True,
                                         download=True,
                                         transform=trans_cifar)
        dataset_test = datasets.CIFAR10('../data/cifar',
                                        train=False,
                                        download=True,
                                        transform=trans_cifar)
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users)
        else:
            exit('Error: only consider IID setting in CIFAR10')
    else:
        exit('Error: unrecognized dataset')
    img_size = dataset_train[0][0].shape

    # build model
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNNCifar(args=args).to(args.device)
        net_glob_2 = CNNCifar(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
        net_glob_2 = CNNMnist(args=args).to(args.device)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=200,
                       dim_out=args.num_classes).to(args.device)
    else:
        exit('Error: unrecognized model')

    #print(net_glob)

    #net_glob.train()

    acc_test, loss_test = test_img(net_glob, dataset_test, args)
    print("val test finished")
    print("{:.2f}".format(acc_test))
    temp = copy.deepcopy(net_glob)  # a plain alias would mutate net_glob through load_state_dict below

    #net_glob_2 = net_glob
    temp_2 = copy.deepcopy(net_glob_2)

    # copy weights
    w_glob = net_glob.state_dict()

    # training
    loss_train = []
    cv_loss, cv_acc = [], []
    val_loss_pre, counter = 0, 0
    net_best = None
    best_loss = None
    val_acc_list, net_list = [], []

    Loss_local_each_global_total = []

    test_ds, valid_ds = torch.utils.data.random_split(dataset_test,
                                                      (9500, 500))
    loss_workers_total = np.zeros(shape=(args.num_users, args.epochs))
    label_workers = {
        i: np.array([], dtype='int64')
        for i in range(args.num_users)
    }

    workers_percent = []
    workers_count = 0
    acc_test_global, loss_test_global = test_img(net_glob, valid_ds, args)
    selected_users_index = []

    for idx in range(args.num_users):
        # print("train started")
        local = LocalUpdate(args=args,
                            dataset=dataset_train,
                            idxs=dict_users[idx])
        w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
        # print(w)
        # print("train completed")

        # temp = FedAvg(w)
        temp.load_state_dict(w)
        temp.eval()
        acc_test_local, loss_test_local = test_img(temp, valid_ds, args)
        loss_workers_total[idx, 0] = acc_test_local  # pre-screening pass; iter is not defined yet, record in column 0

        if workers_count >= (args.num_users / 2):
            break
        elif acc_test_local >= (0.7 * acc_test_global):
            selected_users_index.append(idx)

    for iter in range(args.epochs):
        print("round started")
        Loss_local_each_global = []
        loss_workers = np.zeros((args.num_users, args.epochs))
        w_locals, loss_locals = [], []
        m = max(int(args.frac * args.num_users), 1)
        #idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        #if iter % 5 == 0:
        # Minoo
        x = net_glob
        x.eval()
        acc_test_global, loss_test_global = test_img(x, valid_ds, args)

        Loss_local_each_global_total.append(acc_test_global)

        for idx in selected_users_index:
            #print("train started")
            local = LocalUpdate(args=args,
                                dataset=dataset_train,
                                idxs=dict_users[idx])
            w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
            #print(w)
            #print("train completed")

            #temp = FedAvg(w)
            temp.load_state_dict(w)
            temp.eval()
            acc_test_local, loss_test_local = test_img(temp, valid_ds, args)
            loss_workers_total[idx, iter] = acc_test_local

            if workers_count >= (args.num_users / 2):
                break
            elif acc_test_local >= (0.7 * acc_test_global):
                w_locals.append(copy.deepcopy(w))
                loss_locals.append(copy.deepcopy(loss))
                print("Update Received")
                workers_count += 1

        # update global weights
        w_glob = FedAvg(w_locals)

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)

        print("round completed")
        loss_avg = sum(loss_locals) / len(loss_locals)
        print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg))
        loss_train.append(loss_avg)
        workers_percent.append(workers_count)

    # plot loss curve
    plt.figure()
    plt.plot(range(len(workers_percent)), workers_percent)
    plt.ylabel('train_loss')
    plt.savefig(
        './save/Newfed_WorkersPercent_0916_{}_{}_{}_C{}_iid{}.png'.format(
            args.dataset, args.model, args.epochs, args.frac, args.iid))
    # print(loss_workers_total)

    # plot loss curve
    # plt.figure()
    # plt.plot(range(len(loss_train)), loss_train)
    # plt.ylabel('train_loss')
    # plt.savefig('./save/Newfed_0916_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid))
    #

    plt.figure()
    for i in range(args.num_users):
        plot = plt.plot(range(len(loss_workers_total[i, :])),
                        loss_workers_total[i, :],
                        label="Worker {}".format(i))
    plot5 = plt.plot(range(len(Loss_local_each_global_total)),
                     Loss_local_each_global_total,
                     color='#000000',
                     label="Global")
    plt.legend(loc='best')
    plt.ylabel('Small Test Set Accuracy of workers')
    plt.xlabel('Number of Rounds')
    plt.savefig(
        './save/NewFed_2workers_Acc_0916_{}_{}_{}_C{}_iid{}.png'.format(
            args.dataset, args.model, args.epochs, args.frac, args.iid))

    # plt.figure()
    # bins = np.linspace(0, 9, 3)
    # a = dict_labels_counter[:, 0].ravel()
    # print(type(a))
    # b = dict_labels_counter[:, 1].ravel()
    # x_labels = ['0', '1', '2', '3','4','5','6','7','8','9']
    # # Set plot parameters
    # fig, ax = plt.subplots()
    # width = 0.1  # width of bar
    # x = np.arange(10)
    # ax.bar(x, dict_labels_counter[:, 0], width, color='#000080', label='Worker 1')
    # ax.bar(x + width, dict_labels_counter[:, 1], width, color='#73C2FB', label='Worker 2')
    # ax.bar(x + 2*width, dict_labels_counter[:, 2], width, color='#ff0000', label='Worker 3')
    # ax.bar(x + 3*width, dict_labels_counter[:, 3], width, color='#32CD32', label='Worker 4')
    # ax.set_ylabel('Number of Labels')
    # ax.set_xticks(x + width + width / 2)
    # ax.set_xticklabels(x_labels)
    # ax.set_xlabel('Labels')
    # ax.legend()
    # plt.grid(True, 'major', 'y', ls='--', lw=.5, c='k', alpha=.3)
    # fig.tight_layout()
    # plt.savefig(
    #     './save/Newfed_2workersLabels_0916_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac,
    #                                                                args.iid))

    # testing
    print("testing started")
    net_glob.eval()
    print("train test started")
    acc_train_final, loss_train_final = test_img(net_glob, dataset_train, args)
    print("train test finished")
    acc_test_final, loss_test_final = test_img(net_glob, dataset_test, args)
    print("val test finished")
    #print("Training accuracy: {:.2f}".format(acc_train))
    #print("Testing accuracy: {:.2f}".format(acc_test))
    print("{:.2f}".format(acc_test_final))
    #print("{:.2f".format(Loss_local_each_worker))

    # training
    w_glob_2 = net_glob_2.state_dict()

    loss_train_2 = []
    cv_loss_2, cv_acc_2 = [], []
    val_loss_pre_2, counter_2 = 0, 0
    net_best_2 = None
    best_loss_2 = None
    val_acc_list_2, net_list_2 = [], []

    Loss_local_each_global_total_2 = []

    loss_workers_total_2 = np.zeros(shape=(args.num_users, args.epochs))
    label_workers_2 = {
        i: np.array([], dtype='int64')
        for i in range(args.num_users)
    }

    for iter in range(args.epochs):
        print("round started")
        Loss_local_each_global_2 = []
        loss_workers_2 = np.zeros((args.num_users, args.epochs))
        w_locals_2, loss_locals_2 = [], []
        m_2 = max(int(args.frac * args.num_users), 1)
        idxs_users_2 = np.random.choice(range(args.num_users),
                                        m_2,
                                        replace=False)

        # Minoo
        x_2 = net_glob_2
        x_2.eval()
        acc_test_global_2, loss_test_global_2 = test_img(x_2, valid_ds, args)
        Loss_local_each_global_total_2.append(acc_test_global_2)

        for idx in idxs_users_2:
            #print("train started")
            local_2 = LocalUpdate(args=args,
                                  dataset=dataset_train,
                                  idxs=dict_users_2[idx])
            w_2, loss_2 = local_2.train(
                net=copy.deepcopy(net_glob_2).to(args.device))
            #print(w)
            #print("train completed")
            w_locals_2.append(copy.deepcopy(w_2))
            loss_locals_2.append(copy.deepcopy(loss_2))
            #temp = FedAvg(w)
            temp_2.load_state_dict(w_2)
            temp_2.eval()
            acc_test_local_2, loss_test_local_2 = test_img(
                temp_2, valid_ds, args)
            loss_workers_total_2[idx, iter] = acc_test_local_2

        # update global weights
        w_glob_2 = FedAvg(w_locals_2)

        # copy weight to net_glob
        net_glob_2.load_state_dict(w_glob_2)

        print("round completed")
        loss_avg_2 = sum(loss_locals_2) / len(loss_locals_2)
        print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg_2))
        loss_train_2.append(loss_avg_2)
        print("round completed")

        # plot loss curve
    plt.figure()
    plt.plot(range(len(loss_train_2)),
             loss_train_2,
             color='#000000',
             label="Main FL")
    plt.plot(range(len(loss_train)),
             loss_train,
             color='#ff0000',
             label="Centralized Algorithm")
    plt.ylabel('train_loss')
    plt.savefig('./save/main_fed_0916_{}_{}_{}_C{}_iid{}.png'.format(
        args.dataset, args.model, args.epochs, args.frac, args.iid))
    # print(loss_workers_total)

    plt.figure()
    for i in range(args.num_users):
        plot = plt.plot(range(len(loss_workers_total_2[i, :])),
                        loss_workers_total_2[i, :],
                        label="Worker {}".format(i))
    plot5 = plt.plot(range(len(Loss_local_each_global_total_2)),
                     Loss_local_each_global_total_2,
                     color='#000000',
                     label="Global")
    plt.legend(loc='best')
    plt.ylabel('Small Test Set Accuracy of workers')
    plt.xlabel('Number of Rounds')
    plt.savefig('./save/mainfed_Acc_0916_{}_{}_{}_C{}_iid{}.png'.format(
        args.dataset, args.model, args.epochs, args.frac, args.iid))

    # plt.figure()
    # bins = np.linspace(0, 9, 3)
    # a = dict_labels_counter_2[:, 0].ravel()
    # print(type(a))
    # b = dict_labels_counter_2[:, 1].ravel()
    # x_labels = ['0', '1', '2', '3','4','5','6','7','8','9']
    # # Set plot parameters
    # fig, ax = plt.subplots()
    # width = 0.1  # width of bar
    # x = np.arange(10)
    # ax.bar(x, dict_labels_counter_2[:, 0], width, color='#000080', label='Worker 1')
    # ax.bar(x + width, dict_labels_counter_2[:, 1], width, color='#73C2FB', label='Worker 2')
    # ax.bar(x + 2*width, dict_labels_counter_2[:, 2], width, color='#ff0000', label='Worker 3')
    # ax.bar(x + 3*width, dict_labels_counter_2[:, 3], width, color='#32CD32', label='Worker 4')
    # ax.set_ylabel('Number of Labels')
    # ax.set_xticks(x + width + width / 2)
    # ax.set_xticklabels(x_labels)
    # ax.set_xlabel('Labels')
    # ax.legend()
    # plt.grid(True, 'major', 'y', ls='--', lw=.5, c='k', alpha=.3)
    # fig.tight_layout()
    # plt.savefig(
    #     './save/main_fed_2workersLabels_0916_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac,
    #                                                                args.iid))

    # testing
    print("testing started")
    net_glob.eval()
    acc_train_final, loss_train_final = test_img(net_glob, dataset_train, args)
    print("train test finished")
    acc_test_final, loss_test_final = test_img(net_glob, dataset_test, args)
    print("test set evaluation finished")
    print("Testing accuracy: {:.2f}".format(acc_test_final))

    return loss_test_final, loss_train_final
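
# --- Helper assumed by these examples (not shown on this page) ---
# Every snippet aggregates client updates with FedAvg(w_locals). Below is a
# minimal sketch of the usual unweighted implementation (an element-wise
# average of the clients' state_dicts, as in the common federated-learning
# reference code); the actual FedAvg used by these examples may differ.
import copy
import torch

def FedAvg(w):
    # average a list of model state_dicts key by key
    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[k] += w[i][k]
        w_avg[k] = torch.div(w_avg[k], len(w))
    return w_avg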
    if args.mode == 'IIL':
        args.epochs = 1
    
    for iter in range(start_epoch, args.epochs):
        print("Global Epoch:", iter + 1)
        w_locals, loss_locals = [], []
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        for idx in idxs_users:
            print("User: {:3d}".format(idx))
            user_train_paths, user_train_set = [], []
            for i in dict_users[idx]:
                user_train_paths.append(train_file_paths[i])

            user_train_set = load_data_h5(train_file_paths=user_train_paths)
            local = LocalUpdate(args=args,
                                dataset=user_train_set,
                                idxs=dict_users[idx],
                                logwriter=writer,
                                user_id=idx,
                                testLoader=test_loader,
                                epoch=iter)
            w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
            w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))

        old_loc_epoch = args.local_ep
        if args.ibsr:
            args.local_ep = 1
            m = max(int(args.frac * num_ibsr_users), 1)
            ibsr_users = np.random.choice(range(num_ibsr_users), m, replace=False)
            for idx in ibsr_users:
                print("IBSR User: {:3d}".format(idx))
                user_train_paths, user_train_set = [], []
                for i in dict_ibsr_users[idx]:
                    user_train_paths.append(ibsr_paths[i])
    count_array_10 = []

    for iter in range(args.epochs):
        #agent_found_count = 0
        w_locals, loss_locals = [], []          # lists of local weights / losses
        w_locals_1, loss_locals_1 = [], []
        w_locals_5, loss_locals_5 = [], []
        w_locals_7, loss_locals_7 = [], []
        w_locals_10, loss_locals_10 = [], []
        # m = number of users sampled in one round/epoch; see utils.options for details
        m = max(int(args.frac * args.num_users), 1)
        # randomly select m of the args.num_users users. NEED TO REPLACE THIS WITH OUR SAMPLING MECHANISM
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        for idx in idxs_users:
            # one LocalUpdate per model variant, all sharing the same user's data
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
            local1 = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
            local5 = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
            local7 = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
            local10 = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])

            w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
            w1, loss1 = local1.train(net=copy.deepcopy(net_glob1).to(args.device))
            w5, loss5 = local5.train(net=copy.deepcopy(net_glob5).to(args.device))
            w7, loss7 = local7.train(net=copy.deepcopy(net_glob7).to(args.device))
            w10, loss10 = local10.train(net=copy.deepcopy(net_glob10).to(args.device))
            print("local training for user {} finished".format(idx))


            if idx == fixed_agent_1:
                if updates_recorded_1:
Example #28
0
def ICC_FL(net_glob, dict_workers_index, dict_users_data, dict_labels_counter_mainFL, args, cost,
           dataset_train, dataset_test, valid_ds, loss_test_final_main, optimal_delay):


    data_Global_DCFL = {"C": [], "Round": [], "Average Loss Train": [], "SDS Loss": [], "SDS Accuracy": [],
                        "Workers Number": [], "Large Test Loss": [], "Large Test Accuracy": [], "Communication Cost": []}
    Final_LargeDataSetTest_DCFL = {"C": [], "Test Accuracy": [], "Test Loss": [], "Train Loss": [],
                                   "Train Accuracy": [],
                                   "Total Rounds": [], "Communication Cost": []}
    # copy weights
    w_glob = net_glob.state_dict()

    temp = copy.deepcopy(net_glob)

    # training
    loss_train = []
    Loss_local_each_global_total = []

    selected_clients_costs_total = []
    loss_workers_total = np.zeros(shape=(args.num_users, 10 * args.epochs))

    workers_percent_dist = []
    workers_participation = np.zeros((args.num_users, 10 * args.epochs))
    workers = []
    for i in range(args.num_users):
        workers.append(i)

    n_k = np.zeros(shape=(args.num_users))
    for i in range(len(dict_users_data)):
        n_k[i] = len(dict_users_data[i])

    counter_threshold_decrease = np.zeros(10 * args.epochs)
    Global_Accuracy_Tracker = np.zeros(10 * args.epochs)
    Global_Loss_Tracker = np.zeros(10 * args.epochs)
    threshold = 1.0
    beta = 0.1    # delta-accuracy controller
    gamma = 0.05  # threshold decrease step

    Goal_Loss = float(loss_test_final_main)

    net_glob.eval()
    acc_test_final, loss_test_final = test_img(net_glob, dataset_test, args)
    while_counter = float(loss_test_final)
    iter = 0

    total_rounds_dcfl = 0
    pre_net_glob = copy.deepcopy(net_glob)

    while abs(while_counter - Goal_Loss) >= 0.05:
        selected_clients_costs_round = []
        w_locals, loss_locals = [], []
        m = max(int(args.frac * args.num_users), 1)
        counter_threshold = 0

        # evaluate the current global model on the small validation set
        x = net_glob
        x.eval()
        acc_test_global, loss_test_global = test_img(x, valid_ds, args)
        # note: this list stores accuracy, despite its name
        Loss_local_each_global_total.append(acc_test_global)
        Global_Accuracy_Tracker[iter] = acc_test_global
        Global_Loss_Tracker[iter] = loss_test_global
        # use logical `and` (bitwise `&` binds tighter than the comparison)
        if iter > 0 and (Global_Loss_Tracker[iter-1] - Global_Loss_Tracker[iter] <= beta):
            threshold = threshold - gamma
            if threshold <= 0.0:
                threshold = 1.0
        workers_count = 0


        temp_w_locals = []
        temp_workers_loss = np.zeros(args.num_users)
        temp_workers_accuracy = np.zeros(args.num_users)
        temp_workers_loss_test = np.zeros(args.num_users)
        temp_workers_loss_difference = np.zeros(args.num_users)
        flag = np.zeros(args.num_users)

        list_of_random_workers_newfl = []
        if iter < (args.epochs):
            for key, value in dict_workers_index.items():
                if key == iter:
                    list_of_random_workers_newfl = dict_workers_index[key]
        else:
            list_of_random_workers_newfl = random.sample(workers, m)


        for idx in list_of_random_workers_newfl:
            initial_global_model = copy.deepcopy(net_glob).to(args.device)
            initial_global_model.eval()
            acc_test_local_initial, loss_test_local_initial = test_img(initial_global_model, valid_ds, args)


            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users_data[idx])
            w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))

            temp_w_locals.append(copy.deepcopy(w))
            temp_workers_loss[idx] = copy.deepcopy(loss)

            temp.load_state_dict(w)
            temp.eval()

            acc_test_local_after, loss_test_local_after = test_img(temp, valid_ds, args)
            loss_workers_total[idx, iter] = loss_test_local_after
            temp_workers_accuracy[idx] = acc_test_local_after
            temp_workers_loss_test[idx] = loss_test_local_after
            temp_workers_loss_difference[idx] = abs(loss_test_local_after - loss_test_local_initial)

        # keep admitting clients until at least one local update passes the
        # filter: a client is selected if its test loss on the small validation
        # set is under the threshold and its cost is under the delay budget
        while len(w_locals) < 1:
            index = 0
            for idx in list_of_random_workers_newfl:
                if workers_count >= m:
                    break
                elif temp_workers_loss_test[idx] <= threshold and flag[idx] == 0 and cost[idx] <= optimal_delay:
                    w_locals.append(copy.deepcopy(temp_w_locals[index]))
                    loss_locals.append(temp_workers_loss[idx])
                    flag[idx] = 1
                    workers_count += 1
                    workers_participation[idx][iter] = 1
                    selected_clients_costs_round.append(cost[idx])
                index += 1
            if len(w_locals) < 1:
                threshold = threshold * 2  # relax the filter if nobody qualified




        # update global weights
        w_glob = FedAvg(w_locals)

        # for n in range(args.num_users - len(w_locals)):
        #     w_locals.append(pre_net_glob.state_dict())
        # w_glob = fed_avg(w_locals, n_k)

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)

        loss_avg = sum(loss_locals) / len(loss_locals)
        loss_train.append(loss_avg)
        workers_percent_dist.append(workers_count/args.num_users)


        counter_threshold_decrease[iter] = counter_threshold
        print(iter, " round dist fl finished")


        acc_test_final, loss_test_final = test_img(net_glob, dataset_test, args)
        while_counter = loss_test_final


        data_Global_DCFL["Round"].append(iter)
        data_Global_DCFL["C"].append(args.frac)
        data_Global_DCFL["Average Loss Train"].append(loss_avg)
        data_Global_DCFL["SDS Accuracy"].append(Global_Accuracy_Tracker[iter])
        data_Global_DCFL["SDS Loss"].append(Global_Loss_Tracker[iter])
        data_Global_DCFL["Workers Number"].append(workers_count)
        data_Global_DCFL["Large Test Loss"].append(float(loss_test_final))
        data_Global_DCFL["Large Test Accuracy"].append(float(acc_test_final))
        data_Global_DCFL["Communication Cost"].append(sum(selected_clients_costs_round))

        selected_clients_costs_total.append(sum(selected_clients_costs_round))

        iter += 1
        total_rounds_dcfl = iter

        pre_net_glob = copy.deepcopy(net_glob)

    # plot workers percent of participating
    workers_percent_final = np.zeros(args.num_users)
    workers_name = np.zeros(args.num_users)
    for i in range(len(workers_participation[:, 1])):
        # fraction of the `iter` completed rounds in which worker i participated
        workers_percent_final[i] = sum(workers_participation[i, :]) / max(iter, 1)
        workers_name[i] = i

    # selected_clients_costs_total.append(sum(selected_clients_costs_round))

    # testing
    net_glob.eval()
    acc_train_final, loss_train_final = test_img(net_glob, dataset_train, args)
    acc_test_final, loss_test_final = test_img(net_glob, dataset_test, args)


    Final_LargeDataSetTest_DCFL["C"].append(args.frac)
    Final_LargeDataSetTest_DCFL["Test Loss"].append(float(loss_test_final))
    Final_LargeDataSetTest_DCFL["Test Accuracy"].append(float(acc_test_final))
    Final_LargeDataSetTest_DCFL["Train Loss"].append(float(loss_train_final))
    Final_LargeDataSetTest_DCFL["Train Accuracy"].append(float(acc_train_final))
    Final_LargeDataSetTest_DCFL["Total Rounds"].append(int(total_rounds_dcfl))
    Final_LargeDataSetTest_DCFL["Communication Cost"].append(sum(selected_clients_costs_total))


    return Final_LargeDataSetTest_DCFL, data_Global_DCFL
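
# --- Helper assumed by these examples (not shown on this page) ---
# Every example calls test_img(net, dataset, args) and expects an
# (accuracy, loss) pair back. Below is a minimal sketch under the assumption
# that args carries a batch size (args.bs) and a device (args.device); the
# real helper may differ in details such as the accuracy scale.
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

def test_img(net_g, datatest, args):
    net_g.eval()
    test_loss, correct = 0.0, 0
    data_loader = DataLoader(datatest, batch_size=args.bs)
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(args.device), target.to(args.device)
            log_probs = net_g(data)
            # sum the batch losses, then average over the whole set below
            test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
            correct += (log_probs.argmax(dim=1) == target).sum().item()
    test_loss /= len(data_loader.dataset)
    accuracy = 100.0 * correct / len(data_loader.dataset)
    return accuracy, test_loss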
Example #29
0
        w_locals_5, loss_locals_5 = [], []
        w_locals_10, loss_locals_10 = [], []
        w_locals_15, loss_locals_15 = [], []
        w_locals_20, loss_locals_20 = [], []
        w_locals_25, loss_locals_25 = [], []
        w_locals_30, loss_locals_30 = [], []
        m = max(
            int(args.frac * args.num_users), 1
        )  #m = number of users used in one ROUND/EPOCH, check utils.options for more clarity on this
        idxs_users = np.random.choice(
            range(args.num_users), m, replace=False
        )  #Randomly selecting m users out of 32 users. NEED TO REPLACE THIS WITH OUR SAMPLING MECHANISM

        for idx in idxs_users:
            local = LocalUpdate(args=args,
                                dataset=dataset_train,
                                idxs=dict_users[idx])
            local1 = LocalUpdate(args=args,
                                 dataset=dataset_train,
                                 idxs=dict_users[idx])
            local5 = LocalUpdate(args=args,
                                 dataset=dataset_train,
                                 idxs=dict_users[idx])
            local10 = LocalUpdate(args=args,
                                  dataset=dataset_train,
                                  idxs=dict_users[idx])
            local15 = LocalUpdate(args=args,
                                  dataset=dataset_train,
                                  idxs=dict_users[idx])
            local20 = LocalUpdate(args=args,
                                  dataset=dataset_train,
Example #30
0
def mainFl(
    net_glob_mainFL: Any,
    dict_users_mainFL: Dict[int, Any],
    dict_labels_counter_mainFL,
    args,
    cost,
    dataset_train,
    dataset_test,
    small_shared_dataset
):
    """

    Args:
        net_glob_mainFL (torch.nn.Module): global model
        dict_users_mainFL (Dict: dict_users_mainFL[idx_user]): dict contains users data indexes, to access one user's data index,
        write dict_users_mainFL[idx_user]
        dict_labels_counter_mainFL: dict contains each users's labels total number, we do not use it right now
        args: all args. You can look for details in utils/options.py.
        cost: An array contains cost of sending locally updated models from users to server. We do not use it right now.
        dataset_train (torch dataset): Total train data set. We need it for train part, as in dict_users, we just have index of data.
        dataset_test (torch dataset): Total test dataset.
        small_shared_dataset (torch dataset): The small shared dataset that we just use it here for tracking and comparing with our
        algorithm, not for decision making.

    Returns:
        float(loss_test_final_main): final loss test over main test dataset
        dict_workers_index: index of selected workers in each round, we need it for use in other algorithms
        Final_LargeDataSetTest_MainFL: A dict contains macroscopic data to be saved after each total FL process (each C)
        data_Global_main: A dict contains microscopic data to be saved after each total FL process (each round)

    """

    data_Global_main = {"C": [], "Round": [], "Average Loss Train": [], "SDS Loss": [], "SDS Accuracy": [],
                        "Workers Number": [], "Large Test Loss": [], "Large Test Accuracy": [], "Communication Cost": []}
    Final_LargeDataSetTest_MainFL = {"C": [], "Test Accuracy": [], "Test Loss": [], "Train Loss": [],
                                     "Train Accuracy": [], "Total Rounds": [], "Communication Cost": []}

    # saving index of workers
    dict_workers_index = defaultdict(list)

    n_k = np.zeros(shape=(args.num_users))
    for i in range(len(dict_users_mainFL)):
        n_k[i] = len(dict_users_mainFL[i])
    # print(n_k)

    # Main FL

    # average of the selected clients' training losses in each round
    loss_train_mainFL = []
    # loss / accuracy of the global model over the small shared dataset per round
    Loss_local_each_global_total_mainFL = []
    Accuracy_local_each_global_total_mainFL = []
    # loss of each worker over the small shared dataset in each round
    loss_workers_total_mainFL = np.zeros(shape=(args.num_users, args.epochs))
    label_workers_mainFL = {i: np.array(
        [], dtype='int64') for i in range(args.num_users)}

    # test accuracy of the global model after each round
    validation_test_mainFed = []
    acc_test, loss_test = test_img(net_glob_mainFL, dataset_test, args)
    workers_participation_main_fd = np.zeros((args.num_users, args.epochs))
    workers_percent_main = []

    net_glob_mainFL.eval()
    acc_test_final_mainFL, loss_test_final_mainFL = test_img(
        net_glob_mainFL, dataset_test, args)
    print("main fl initial loss is ", loss_test_final_mainFL)

    # while counter initialization
    iter_mainFL = 0

    # assign an index to each worker in the workers_mainFL list
    workers_mainFL = []
    for i in range(args.num_users):
        workers_mainFL.append(i)

    temp_netglob_mainFL = copy.deepcopy(net_glob_mainFL)

    selected_clients_costs_total = []
    total_rounds_mainFL = 0

    pre_net_glob = copy.deepcopy(net_glob_mainFL)

    while iter_mainFL < (args.epochs):
        # print(f"iter {iter_mainFL} is started")
        selected_clients_costs_round = []
        w_locals_mainFL, loss_locals_mainFL = [], []
        m_mainFL = max(int(args.frac * args.num_users), 1)

        # select some clients randomly and save their indexes for reuse in the
        # other algorithms
        list_of_random_workers = random.sample(workers_mainFL, m_mainFL)
        for i in range(len(list_of_random_workers)):
            dict_workers_index[iter_mainFL].append(list_of_random_workers[i])

        # record the initial loss/accuracy of the global model over the small
        # shared dataset
        x_mainFL = copy.deepcopy(net_glob_mainFL)
        x_mainFL.eval()
        acc_test_global_mainFL, loss_test_global_mainFL = test_img(
            x_mainFL, small_shared_dataset, args)
        Loss_local_each_global_total_mainFL.append(loss_test_global_mainFL)
        Accuracy_local_each_global_total_mainFL.append(acc_test_global_mainFL)
        workers_count_mainFL = 0
        for idx in list_of_random_workers:
            # train each selected client locally
            local_mainFL = LocalUpdate(
                args=args, dataset=dataset_train, idxs=dict_users_mainFL[idx])
            w_mainFL, loss_mainFL = local_mainFL.train(
                net=copy.deepcopy(net_glob_mainFL).to(args.device))

            # copy the client's updated weights and training loss
            w_locals_mainFL.append(copy.deepcopy(w_mainFL))
            loss_locals_mainFL.append(loss_mainFL)

            # test the locally updated model over the small shared dataset and
            # record its loss and accuracy
            temp_netglob_mainFL.load_state_dict(w_mainFL)
            temp_netglob_mainFL.eval()
            acc_test_local_mainFL, loss_test_local_mainFL = test_img(
                temp_netglob_mainFL, small_shared_dataset, args)
            # loss_workers_total_mainFL[idx, iter_mainFL] = acc_test_local_mainFL
            # record this client's participation in the round
            workers_participation_main_fd[idx][iter_mainFL] = 1
            # count the clients participating in this round (equal to C*N)
            workers_count_mainFL += 1
            selected_clients_costs_round.append(cost[idx])

        # update global weights: pad w_locals with the previous global weights
        # for the clients that did not participate, then aggregate with a
        # sample-size-weighted average (plain FedAvg kept for reference):
        # w_glob_mainFL = FedAvg(w_locals_mainFL)
        for n in range(args.num_users - m_mainFL):
            w_locals_mainFL.append(pre_net_glob.state_dict())
        # NOTE: Updated weights (@author Nathaniel).
        w_glob_mainFL = fed_avg(w_locals_mainFL, n_k)

        # copy weight to net_glob
        net_glob_mainFL.load_state_dict(w_glob_mainFL)
        # print("after ", net_glob_mainFL)

        # calculating the average training loss across the selected clients
        loss_avg_mainFL = sum(loss_locals_mainFL) / len(loss_locals_mainFL)
        loss_train_mainFL.append(loss_avg_mainFL)

        # calculating test accuracy and loss over the main large test dataset
        acc_test_round_mainfed, loss_test_round_mainfed = test_img(
            net_glob_mainFL, dataset_test, args)
        validation_test_mainFed.append(acc_test_round_mainfed)
        workers_percent_main.append(workers_count_mainFL / args.num_users)
        # final per-round evaluation, also over the main test dataset
        acc_test_final_mainFL, loss_test_final_mainFL = test_img(
            net_glob_mainFL, dataset_test, args)

        data_Global_main["Round"].append(iter_mainFL)
        data_Global_main["C"].append(args.frac)
        data_Global_main["Average Loss Train"].append(float(loss_avg_mainFL))
        data_Global_main["SDS Loss"].append(float(loss_test_global_mainFL))
        data_Global_main["SDS Accuracy"].append(float(acc_test_global_mainFL))
        data_Global_main["Workers Number"].append(float(workers_count_mainFL))
        data_Global_main["Large Test Loss"].append(
            float(loss_test_final_mainFL))
        data_Global_main["Large Test Accuracy"].append(
            float(acc_test_final_mainFL))
        data_Global_main["Communication Cost"].append(
            sum(selected_clients_costs_round))

        # TODO: This doesn't make sense?
        selected_clients_costs_total.append(sum(selected_clients_costs_round))

        iter_mainFL += 1
        # total_rounds_mainFL = iter_mainFL
        pre_net_glob = copy.deepcopy(net_glob_mainFL)

        # print(f"iter {iter_mainFL} is finished")

    # calculating the percentage of each worker's participation
    workers_percent_final_mainFL = np.zeros(args.num_users)
    workers_name_mainFL = np.empty(args.num_users)
    for i in range(len(workers_participation_main_fd[:, 1])):
        workers_percent_final_mainFL[i] = sum(
            workers_participation_main_fd[i, :]) / args.epochs
        workers_name_mainFL[i] = i

    net_glob_mainFL.eval()
    # print("train test started")
    acc_train_final_main, loss_train_final_main = test_img(
        net_glob_mainFL, dataset_train, args)
    # print("train test finished")
    acc_test_final_main, loss_test_final_main = test_img(
        net_glob_mainFL, dataset_test, args)

    Final_LargeDataSetTest_MainFL["C"].append(args.frac)
    Final_LargeDataSetTest_MainFL["Test Loss"].append(
        float(loss_test_final_main))
    Final_LargeDataSetTest_MainFL["Test Accuracy"].append(
        float(acc_test_final_main))
    Final_LargeDataSetTest_MainFL["Train Loss"].append(
        float(loss_train_final_main))
    Final_LargeDataSetTest_MainFL["Train Accuracy"].append(
        float(acc_train_final_main))
    Final_LargeDataSetTest_MainFL["Communication Cost"].append(
        sum(selected_clients_costs_total))
    Final_LargeDataSetTest_MainFL["Total Rounds"].append(args.epochs)

    return float(loss_test_final_main), dict_workers_index, Final_LargeDataSetTest_MainFL, data_Global_main
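
# --- Helper assumed by Example #30 (not shown on this page) ---
# mainFl aggregates with fed_avg(w_locals_mainFL, n_k), a sample-size-weighted
# variant of FedAvg. A sketch consistent with how it is called here, assuming
# len(w) == len(n_k) and float model parameters, is:
import copy

def fed_avg(w, n_k):
    total = float(sum(n_k))
    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():
        # weight each client's tensor by its share of the total sample count
        w_avg[k] = w[0][k] * (n_k[0] / total)
        for i in range(1, len(w)):
            w_avg[k] += w[i][k] * (n_k[i] / total)
    return w_avg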