# ============== EVAL ==============
        # Calculate avg training accuracy over all users at every epoch
        list_acc, list_loss = [], []
        global_model.eval()
        # print("========== idx ========== ", idx)
        for c in range(args.num_users):
            # for c in range(cluster_size):
            # C = np.random.choice(keylist, int(args.frac * args.num_users), replace=False) # random set of clients
            # print("C: ", C)
            # for c in C:
            local_model = LocalUpdate(args=args,
                                      dataset=train_dataset,
                                      idxs=user_groups[c],
                                      logger=logger)
            acc, loss = local_model.inference(model=global_model,
                                              dtype=torch.float16)
            list_acc.append(acc)
            list_loss.append(loss)
        train_accuracy.append(sum(list_acc) / len(list_acc))
        # Add
        testacc_check = 100 * train_accuracy[-1]
        epoch = epoch + 1

        # print global training loss after every 'i' rounds
        if (epoch + 1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}% \n'.format(100 *
                                                      train_accuracy[-1]))

    print('\n Total Run Time: {0:0.4f}'.format(time.time() - start_time))
コード例 #2
0
ファイル: federated_main.py プロジェクト: BSAqua/federated
        # update global weights
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # Calculate avg training accuracy over all users at every epoch
        list_acc, list_loss = [], []
        global_model.eval()
        for c in range(args.num_users):
            local_model = LocalUpdate(args=args,
                                      dataset=train_dataset,
                                      idxs=user_groups[idx],
                                      logger=logger)
            acc, loss = local_model.inference(model=global_model)
            list_acc.append(acc)
            list_loss.append(loss)
        train_accuracy.append(sum(list_acc) / len(list_acc))

        # print global training loss after every 'i' rounds
        if (epoch + 1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}% \n'.format(100 *
                                                      train_accuracy[-1]))

    # Test inference after completion of training
    test_acc, test_loss = test_inference(args, global_model, test_dataset)

    print(f' \n Results after {args.epochs} global rounds of training:')
コード例 #3
0
def train(args, global_model, raw_data_train, raw_data_test):
    """Train one model per cluster with FedSEM-style user reassignment.

    Every round, each cluster model is trained by averaging the local weights
    of the users currently assigned to it. Afterwards every user is reassigned
    to the cluster whose weights minimise ``weight_dist`` against that user's
    most recent local weights. Finally each user is evaluated on
    ``raw_data_test`` with its assigned cluster model and the mean accuracy is
    logged.

    Args:
        args: Experiment configuration. ``args.clusters`` and ``args.epochs``
            are read here; ``args`` is also forwarded to ``LocalUpdate``.
        global_model: Initial model. It is deep-copied once per cluster, so
            this instance itself is not trained.
        raw_data_train: Training data. Element ``[2]`` must be a mapping whose
            keys form the user list (presumably user ids).
        raw_data_test: Data used for the per-round evaluation.
    """
    start_time = time.time()
    user_list = list(raw_data_train[2].keys())
    # Most recent local weights per user, index-aligned with user_list. Every
    # user belongs to some cluster, so all entries are set after round one.
    user_weights = [None for _ in range(len(user_list))]
    # Round-robin initial assignment of users to clusters.
    user_assignments = [i % args.clusters for i in range(len(user_list))]

    global_models = [copy.deepcopy(global_model) for _ in range(args.clusters)]
    for m in global_models:
        m.to(device)

    train_loss, train_accuracy = [], []
    for epoch in range(args.epochs):
        print(f"Global Training Round: {epoch + 1}/{args.epochs}")
        local_losses = []
        for modelidx, cluster_model in tqdm(enumerate(global_models)):
            local_weights = []
            for useridx, (user, user_assign) in enumerate(
                    zip(user_list, user_assignments)):
                if user_assign == modelidx:
                    local_model = LocalUpdate(args=args,
                                              raw_data=raw_data_train,
                                              user=user)
                    w, loss = local_model.update_weights(
                        copy.deepcopy(cluster_model))
                    local_weights.append(w)
                    local_losses.append(loss)
                    user_weights[useridx] = w
            # A cluster may end up with no users after reassignment; keep its
            # weights unchanged in that case.
            if local_weights:
                cluster_model.load_state_dict(average_weights(local_weights))

        train_loss.append(sum(local_losses) / len(local_losses))

        # FedSEM cluster reassignment step
        print("Calculating User Assignments")
        dists = np.zeros((len(user_list), len(global_models)))
        for cidx, cluster_model in enumerate(global_models):
            # The cluster's state dict does not change while users are
            # compared against it, so build it once per cluster.
            cluster_state = cluster_model.state_dict()
            for ridx, user_weight in enumerate(user_weights):
                dists[ridx, cidx] = weight_dist(user_weight, cluster_state)

        user_assignments = list(np.argmin(dists, axis=1))
        print("Cluster: number of clients in that cluster index")
        print(Counter(user_assignments))
        print("")

        # Evaluate every user on raw_data_test with its assigned cluster model.
        # NOTE: despite its name, train_accuracy stores this test accuracy.
        test_acc = []
        for modelidx, cluster_model in enumerate(global_models):
            for user, user_assign in zip(user_list, user_assignments):
                if modelidx == user_assign:
                    local_model = LocalUpdate(args=args,
                                              raw_data=raw_data_test,
                                              user=user)
                    acc, _ = local_model.inference(model=cluster_model)
                    test_acc.append(acc)

        train_accuracy.append(sum(test_acc) / len(test_acc))
        wandb.log({
            "Train Loss": train_loss[-1],
            "Test Accuracy": (100 * train_accuracy[-1])
        })
        print(
            f"Train Loss: {train_loss[-1]:.4f}\t Test Accuracy: {(100 * train_accuracy[-1]):.2f}%"
        )

    print(f"Results after {args.epochs} global rounds of training:")
    print("Avg Train Accuracy: {:.2f}%".format(100 * train_accuracy[-1]))
    print(f"Total Run Time: {(time.time() - start_time):0.4f}")
コード例 #4
0
        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # Calculate avg test accuracy over LOCAL data of a fraction of users at every epoch
        last_time = time.time()
        test_users = int(args.local_test_frac * args.num_users)
        print_log('Testing global model on {} users'.format(test_users))
        list_acc, list_iou = [], []
        global_model.eval()

        for c in tqdm(range(test_users)):
            local_model = LocalUpdate(args=args,
                                      dataset=train_dataset,
                                      idxs=user_groups[idx])
            acc, iou = local_model.inference(model=global_model)
            list_acc.append(acc)
            list_iou.append(iou)
        local_test_accuracy.append(sum(list_acc) / len(list_acc))
        local_test_iou.append(sum(list_iou) / len(list_iou))

        # print global training loss after every 'i' rounds
        if (epoch + 1) % args.test_frequency == 0:
            print_log(
                '\nAvg Training Stats after {} global rounds:'.format(epoch +
                                                                      1))
            print_log('Training Loss : {}'.format(np.mean(
                np.array(train_loss))))
            print_log('Local Test Accuracy: {:.2f}% '.format(
                local_test_accuracy[-1]))
            print_log('Local Test IoU: {:.2f}%'.format(local_test_iou[-1]))
コード例 #5
0
        # update global weights
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # Calculate avg training accuracy over all users at every epoch
        list_acc, list_loss = [], []
        global_model.eval()
        for c in range(args.num_users):
            local_model = LocalUpdate(args=args,
                                      dataset=train_dataset,
                                      idxs=user_groups[idx],
                                      logger=logger)  # 只是返回了local_model的类
            acc, loss = local_model.inference(
                model=global_model
            )  # 这一步只是用了local_model的数据集,即用global_model在training dataset上做测试
            list_acc.append(acc)
            list_loss.append(loss)
        train_accuracy.append(sum(list_acc) / len(list_acc))

        # print global training loss after every 'i' rounds
        if (epoch + 1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            # print(f'Training Loss : {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}% \n'.format(100 *
                                                      train_accuracy[-1]))
            test_acc, test_loss = test_inference(args, global_model,
                                                 train_dataset)
            print(
                "test accuracy for training set: {} after {} epochs\n".format(
コード例 #6
0
ファイル: cfl.py プロジェクト: joseph-x-li/fedlearn
def train(args, global_model, raw_data_train, raw_data_test, max_users=100):
    """Train with Clustered Federated Learning (CFL).

    All users start in a single cluster.

    On every ``args.cfl_split_every``-th round, every user of each cluster
    trains (with ``local_ep_override=args.cfl_local_epochs``). A cluster is
    split in two via ``cluster_clients`` when all three conditions hold:

    - the mean update norm is below ``args.cfl_e1``;
    - the max update norm is above ``args.cfl_e2``;
    - the cluster has more than ``args.cfl_min_size`` users.

    Otherwise its model becomes the average of its users' weights.

    All other rounds run FedAvg inside each cluster over a sample of its
    users. After every round each user is evaluated on ``raw_data_test`` with
    its cluster's model, and the mean accuracy is logged.

    Args:
        args: Experiment configuration (read here and forwarded to
            ``LocalUpdate``).
        global_model: Initial model. It is deep-copied to seed the first
            cluster.
        raw_data_train: Training data. Element ``[2]`` must be a mapping whose
            keys form the user list.
        raw_data_test: Data used for the per-round evaluation.
        max_users: Only the first ``max_users`` users are used (default 100,
            the previously hard-coded limit). Pass ``None`` to use all users.

    Raises:
        ValueError: if ``args.frac == -1`` and ``args.cpr`` exceeds the
            number of users.
    """
    start_time = time.time()
    user_list = list(raw_data_train[2].keys())[:max_users]
    nusers = len(user_list)
    cluster_models = [copy.deepcopy(global_model)]
    del global_model
    cluster_models[0].to(device)
    # cluster_assignments[i] lists the users served by cluster_models[i];
    # all users are assigned to the single initial cluster.
    cluster_assignments = [user_list.copy()]

    if args.cfl_wsharing:
        shaccumulator = Accumulator()

    if args.frac == -1:
        m = args.cpr
        if m > nusers:
            raise ValueError(
                f"Clients Per Round: {args.cpr} is greater than number of users: {nusers}"
            )
    else:
        m = max(int(args.frac * nusers), 1)
    print(f"Training {m} users each round")
    print(f"Trying to split after every {args.cfl_split_every} rounds")

    train_loss, train_accuracy = [], []
    for epoch in range(args.epochs):
        # CFL
        if (epoch + 1) % args.cfl_split_every == 0:
            all_losses = []
            new_cluster_models, new_cluster_assignments = [], []
            for cidx, (cluster_model, assignments) in enumerate(
                    tzip(cluster_models,
                         cluster_assignments,
                         desc="Try to split each cluster")):
                # First, train all models in cluster
                local_weights = []
                for user in tqdm(assignments,
                                 desc="Train ALL users in the cluster",
                                 leave=False):
                    local_model = LocalUpdate(args=args,
                                              raw_data=raw_data_train,
                                              user=user)
                    w, loss = local_model.update_weights(
                        copy.deepcopy(cluster_model),
                        local_ep_override=args.cfl_local_epochs)
                    local_weights.append(copy.deepcopy(w))
                    all_losses.append(loss)

                # record shared weights so far
                if args.cfl_wsharing:
                    shaccumulator.add(local_weights)

                weight_updates = subtract_weights(local_weights,
                                                  cluster_model.state_dict(),
                                                  args)
                similarities = pairwise_cossim(weight_updates)

                max_norm = compute_max_update_norm(weight_updates)
                mean_norm = compute_mean_update_norm(weight_updates)

                split = mean_norm < args.cfl_e1 and max_norm > args.cfl_e2 and len(
                    assignments) > args.cfl_min_size
                print(f"CIDX: {cidx}[{len(assignments)}] elem")
                print(
                    f"mean_norm: {(mean_norm):.4f}; max_norm: {(max_norm):.4f}"
                )
                print(f"split? {split}")
                if split:
                    c1, c2 = cluster_clients(similarities)
                    assignments1 = [assignments[i] for i in c1]
                    assignments2 = [assignments[i] for i in c2]
                    new_cluster_assignments += [assignments1, assignments2]
                    print(
                        f"Cluster[{cidx}][{len(assignments)}] -> ({len(assignments1)}, {len(assignments2)})"
                    )

                    local_weights1 = [local_weights[i] for i in c1]
                    local_weights2 = [local_weights[i] for i in c2]

                    # The existing model object becomes the first child
                    # cluster; a deep copy becomes the second.
                    cluster_model.load_state_dict(
                        average_weights(local_weights1))
                    new_cluster_models.append(cluster_model)

                    cluster_model2 = copy.deepcopy(cluster_model)
                    cluster_model2.load_state_dict(
                        average_weights(local_weights2))
                    new_cluster_models.append(cluster_model2)

                else:
                    cluster_model.load_state_dict(
                        average_weights(local_weights))
                    new_cluster_models.append(cluster_model)
                    new_cluster_assignments.append(assignments)

            # Write everything
            cluster_models = new_cluster_models
            if args.cfl_wsharing:
                shaccumulator.write(cluster_models)
                shaccumulator.flush()
            cluster_assignments = new_cluster_assignments
            train_loss.append(sum(all_losses) / len(all_losses))

        # Regular FedAvg
        else:
            all_losses = []

            # Do FedAvg for each cluster
            for cluster_model, assignments in tzip(
                    cluster_models,
                    cluster_assignments,
                    desc="Train each cluster through FedAvg"):
                if not assignments:
                    # Defensive: an empty cluster has nothing to train or
                    # average, so leave its model unchanged.
                    continue
                # After a split a cluster can hold fewer than m users; never
                # request more users than the cluster has (random.sample
                # raises ValueError otherwise).
                n_sampled = min(m, len(assignments))
                if args.sample_dist == "uniform":
                    sampled_users = random.sample(assignments, n_sampled)
                else:
                    xs = np.linspace(-args.sigm_domain, args.sigm_domain,
                                     len(assignments))
                    sigmdist = 1 / (1 + np.exp(-xs))
                    # replace=False so that, like the uniform branch, distinct
                    # users are trained in a round.
                    sampled_users = np.random.choice(assignments,
                                                     n_sampled,
                                                     replace=False,
                                                     p=sigmdist /
                                                     sigmdist.sum())

                local_weights = []
                for user in tqdm(sampled_users,
                                 desc="Training Selected Users",
                                 leave=False):
                    local_model = LocalUpdate(args=args,
                                              raw_data=raw_data_train,
                                              user=user)
                    w, loss = local_model.update_weights(
                        copy.deepcopy(cluster_model))
                    local_weights.append(copy.deepcopy(w))
                    all_losses.append(loss)

                # update global and shared weights
                if args.cfl_wsharing:
                    shaccumulator.add(local_weights)
                new_cluster_weights = average_weights(local_weights)
                cluster_model.load_state_dict(new_cluster_weights)

            if args.cfl_wsharing:
                shaccumulator.write(cluster_models)
                shaccumulator.flush()
            train_loss.append(sum(all_losses) / len(all_losses))

        # Evaluate every user on raw_data_test with its cluster's model,
        # regardless of whether this was a CFL step. NOTE: despite its name,
        # train_accuracy stores this test accuracy.
        test_acc = []
        for cluster_model, assignments in zip(cluster_models,
                                              cluster_assignments):
            for user in assignments:
                local_model = LocalUpdate(args=args,
                                          raw_data=raw_data_test,
                                          user=user)
                acc, _ = local_model.inference(model=cluster_model)
                test_acc.append(acc)
        train_accuracy.append(sum(test_acc) / len(test_acc))

        wandb.log({
            "Train Loss": train_loss[-1],
            "Test Accuracy": (100 * train_accuracy[-1]),
            "Clusters": len(cluster_models)
        })
        print(
            f"Train Loss: {train_loss[-1]:.4f}\t Test Accuracy: {(100 * train_accuracy[-1]):.2f}%"
        )

    print(f"Results after {args.epochs} global rounds of training:")
    print("Avg Train Accuracy: {:.2f}%".format(100 * train_accuracy[-1]))
    print(f"Total Run Time: {(time.time() - start_time):0.4f}")
コード例 #7
0
def main():
    start_time = time.time()

    # define paths
    path_project = os.path.abspath('..')
    logger = SummaryWriter('../logs')

    args = args_parser()
    exp_details(args)

    if args.gpu:
        torch.cuda.set_device(0)
    device = 'cuda' if args.gpu else 'cpu'

    # load dataset and user groups
    train_dataset, test_dataset, user_groups = get_dataset(args)

    args.num_users = len(user_groups)

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural netork
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)

    elif args.model == 'mlp':
        # Multi-layer preceptron
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
            global_model = MLP(dim_in=len_in,
                               dim_hidden=64,
                               dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()

    # copy weights
    global_weights = global_model.state_dict()

    # Training
    train_loss, train_accuracy = [], []
    val_acc_list, net_list = [], []
    cv_loss, cv_acc = [], []
    print_every = 2
    val_loss_pre, counter = 0, 0

    #Beolvassuk, hogy éppen mely résztvevők vesznek részt a tanításban (0 jelentése, hogy benne van, 1 az hogy nincs)
    users = []
    fp = open('users.txt', "r")
    x = fp.readline().split(' ')
    for i in x:
        if i != '':
            users.append(int(i))
    fp.close()

    #for epoch in tqdm(range(args.epochs)):
    for epoch in range(args.epochs):
        local_weights, local_losses = [], []
        #print(f'\n | Global Training Round : {epoch+1} |\n')

        global_model.train()
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        for idx in idxs_users:
            local_model = LocalUpdate(args=args,
                                      dataset=train_dataset,
                                      idxs=user_groups[idx],
                                      logger=logger)
            w, loss = local_model.update_weights(
                model=copy.deepcopy(global_model), global_round=epoch)
            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))

        global_weights = average_weights(local_weights)

        # update global weights
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # Calculate avg training accuracy over all users at every epoch
        list_acc, list_loss = [], []
        global_model.eval()
        for c in range(args.num_users):
            local_model = LocalUpdate(args=args,
                                      dataset=train_dataset,
                                      idxs=user_groups[idx],
                                      logger=logger)
            acc, loss = local_model.inference(model=global_model)
            list_acc.append(acc)
            list_loss.append(loss)
        train_accuracy.append(sum(list_acc) / len(list_acc))

        # print global training loss after every 'i' rounds
        '''if (epoch+1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))'''

    # Test inference after completion of training

    #Beolvassuk hogy mely résztvevőnek mely labeleket osztottuk ki.
    ftrain = open('traindataset.txt')
    testlabels = []
    line = ftrain.readline()
    while line != "":
        sor = line.split(' ')
        array = []
        for i in sor:
            array.append(int(i))
        testlabels.append(array)
        line = ftrain.readline()
    ftrain.close()

    print("USERS LABELS")
    print(testlabels)

    #Minden lehetséges koalícióra lefut a tesztelés
    for j in range((2**args.num_users) - 1):
        binary = numberToBinary(j, len(users))

        test_acc, test_loss = test_inference(args, global_model, test_dataset,
                                             testlabels, binary, len(binary))

        #Teszt eredmények kiírása
        print("RESZTVEVOK")
        print(users)
        print("TEST NUMBER")
        print(j)
        print("TEST BINARY")
        print(binary)
        print("TEST LABELS")
        print(testlabels)
        print("Test Accuracy")
        print("{:.2f}%".format(100 * test_acc))
        print()

    # Saving the objects train_loss and train_accuracy:
    '''file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.\
コード例 #8
0
def train(args, global_model, raw_data_train, raw_data_test):
    """Train ``global_model`` in place with plain FedAvg.

    Each round samples ``m`` distinct users. ``m`` is ``args.cpr`` when
    ``args.frac == -1``; otherwise it is ``args.frac`` of all users, with a
    minimum of one. Sampling is uniform, or sigmoid-weighted when
    ``args.sample_dist`` is not ``"uniform"``.

    Each sampled user trains a deep copy of the global model, and the average
    of their weights is loaded back into ``global_model``. After every round
    the global model is evaluated on ``raw_data_test`` for every user, and the
    mean accuracy is logged.

    Args:
        args: Experiment configuration (read here and forwarded to
            ``LocalUpdate``).
        global_model: Model to train; it is moved to ``device`` and updated
            in place.
        raw_data_train: Training data. Element ``[2]`` must be a mapping whose
            keys form the user list.
        raw_data_test: Data used for the per-round evaluation.

    Raises:
        ValueError: if ``args.frac == -1`` and ``args.cpr`` exceeds the
            number of users.
    """
    start_time = time.time()
    user_list = list(raw_data_train[2].keys())
    global_model.to(device)

    if args.frac == -1:
        m = args.cpr
        if m > len(user_list):
            raise ValueError(
                f"Clients Per Round: {args.cpr} is greater than number of users: {len(user_list)}"
            )
    else:
        m = max(int(args.frac * len(user_list)), 1)
    print(f"Training {m} users each round")

    train_loss, train_accuracy = [], []
    for epoch in range(args.epochs):
        local_weights, local_losses = [], []
        print(f"Global Training Round: {epoch + 1}/{args.epochs}")

        if args.sample_dist == "uniform":
            sampled_users = random.sample(user_list, m)
        else:
            # Sigmoid-skewed participation. For a positive args.sigm_domain,
            # users later in user_list are more likely to be picked.
            xs = np.linspace(-args.sigm_domain, args.sigm_domain,
                             len(user_list))
            sigmdist = 1 / (1 + np.exp(-xs))
            # replace=False so that, like the uniform branch, m distinct
            # users train each round (the default would allow duplicates).
            sampled_users = np.random.choice(user_list,
                                             m,
                                             replace=False,
                                             p=sigmdist / sigmdist.sum())

        for user in tqdm(sampled_users):
            local_model = LocalUpdate(args=args,
                                      raw_data=raw_data_train,
                                      user=user)
            w, loss = local_model.update_weights(copy.deepcopy(global_model))
            local_weights.append(copy.deepcopy(w))
            local_losses.append(loss)

        # update global weights
        global_weights = average_weights(local_weights)
        global_model.load_state_dict(global_weights)

        train_loss.append(sum(local_losses) / len(local_losses))

        # Evaluate the global model on raw_data_test for every user.
        # NOTE: despite its name, train_accuracy stores this test accuracy.
        test_acc = []
        for user in user_list:
            local_model = LocalUpdate(args=args,
                                      raw_data=raw_data_test,
                                      user=user)
            acc, _ = local_model.inference(model=global_model)
            test_acc.append(acc)

        train_accuracy.append(sum(test_acc) / len(test_acc))
        wandb.log({
            "Train Loss": train_loss[-1],
            "Test Accuracy": (100 * train_accuracy[-1])
        })
        print(
            f"Train Loss: {train_loss[-1]:.4f}\t Test Accuracy: {(100 * train_accuracy[-1]):.2f}%"
        )

    print(f"Results after {args.epochs} global rounds of training:")
    print("Avg Train Accuracy: {:.2f}%".format(100 * train_accuracy[-1]))
    print(f"Total Run Time: {(time.time() - start_time):0.4f}")
コード例 #9
0
def main():
    start_time = time.time()

    # define paths
    path_project = os.path.abspath('..')
    logger = SummaryWriter('../logs')
    args = args_parser()
    args = adatok.arguments(args)
    exp_details(args)
    if args.gpu:
        torch.cuda.set_device(args.gpu)
    device = 'cuda' if args.gpu else 'cpu'

    # load dataset and user groups
    train_dataset, test_dataset, user_groups = get_dataset(args)

    if adatok.data.image_initialization == True:
        adatok.data.image_initialization = False
        return

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural netork
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)

    elif args.model == 'mlp':
        # Multi-layer preceptron
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
            global_model = MLP(dim_in=len_in,
                               dim_hidden=64,
                               dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    #print(global_model)

    # copy weights
    global_weights = global_model.state_dict()

    # Training
    train_loss, train_accuracy = [], []
    val_acc_list, net_list = [], []
    cv_loss, cv_acc = [], []
    print_every = 2
    val_loss_pre, counter = 0, 0

    for epoch in tqdm(range(args.epochs)):
        local_weights, local_losses = [], []
        #print(f'\n | Global Training Round : {epoch+1} |\n')

        global_model.train()
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        for idx in idxs_users:
            local_model = LocalUpdate(args=args,
                                      dataset=train_dataset,
                                      idxs=user_groups[idx],
                                      logger=logger)
            w, loss = local_model.update_weights(
                model=copy.deepcopy(global_model), global_round=epoch)
            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))

        # update global weights
        global_weights = average_weights(local_weights)

        # update global weights
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # Calculate avg training accuracy over all users at every epoch
        list_acc, list_loss = [], []
        global_model.eval()
        for c in range(args.num_users):
            local_model = LocalUpdate(args=args,
                                      dataset=train_dataset,
                                      idxs=user_groups[idx],
                                      logger=logger)
            acc, loss = local_model.inference(model=global_model)
            list_acc.append(acc)
            list_loss.append(loss)
        train_accuracy.append(sum(list_acc) / len(list_acc))

        # print global training loss after every 'i' rounds
        '''if (epoch+1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))'''

        # Test inference after completion of training
        for i in adatok.data.test_groups_in_binary:
            adatok.data.actual_test_group_in_binary = i
            test_acc, test_loss = test_inference(args, global_model,
                                                 test_dataset)
            print("Resoults")
            print(epoch)
            print(adatok.data.actual_train_group_in_binary)
            print(adatok.data.actual_test_group_in_binary)
            print(test_acc)
            print(test_loss)
    '''