# Add
        testacc_check = 100 * train_accuracy[-1]
        epoch = epoch + 1

        # print global training loss after every 'i' rounds
        if (epoch + 1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}% \n'.format(100 *
                                                      train_accuracy[-1]))

    print('\n Total Run Time: {0:0.4f}'.format(time.time() - start_time))

    # Test inference after completion of training
    test_acc, test_loss = test_inference(args,
                                         global_model,
                                         test_dataset,
                                         dtype=torch.float16)

    # print(f' \n Results after {args.epochs} global rounds of training:')
    print(f"\nAvg Training Stats after {epoch} global rounds:")
    print("|---- Avg Train Accuracy: {:.2f}%".format(100 * train_accuracy[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100 * test_acc))

    # Saving the objects train_loss and train_accuracy:
    file_name = '../save/objects_fp16/HFL8_{}_{}_{}_lr[{}]_C[{}]_iid[{}]_E[{}]_B[{}]_FP16.pkl'.\
    format(args.dataset, args.model, epoch, args.lr, args.frac, args.iid,
           args.local_ep, args.local_bs)

    with open(file_name, 'wb') as f:
        pickle.dump([train_loss, train_accuracy], f)
Beispiel #2
0
def main_test(args):
    """Run one federated-learning experiment and return its training traces.

    The function builds the global model selected by ``args.model``, then runs
    communication rounds up to ``args.epochs``.  In every round a random
    fraction ``args.frac`` of the clients is trained starting from a copy of
    the global model, the returned weights are combined with
    ``average_weights``/``aggregation``, and the new global model is scored
    with ``test_inference``.  Progress is written to a text log and a CSV
    file, intermediate weights are stored every 100 rounds and a final
    checkpoint is written at the end.

    Checkpoints always contain ``'global_model'`` and ``'a_iter'`` (index of
    the last finished round) and, when ``args.hold_normalize`` is set, one
    ``'model_<idx>'`` entry per client.  With ``args.resume`` the run restarts
    from the round after ``'a_iter'``; checkpoints without ``'a_iter'`` are
    treated as "no round finished yet".

    Args:
        args: Experiment configuration.  Attributes read here: ``save_path``,
            ``exp_folder``, ``dataset``, ``model``, ``epochs``, ``frac``,
            ``iid``, ``local_ep``, ``local_bs``, ``seed``, ``gpu``,
            ``optimizer``, ``norm``, ``num_classes``, ``num_users``,
            ``resume``, ``hold_normalize`` and ``verbose``.  Note that
            ``args.save_path`` is extended in place with ``args.exp_folder``
            and ``args.num_classes`` is forced to 100 for ``cifar100``.

    Returns:
        tuple: ``(train_loss, val_acc_list, client_loss[0],
        client_conv_grad[0], client_fc_grad[0], client_total_grad_norm[0])``:
        the per-round average client loss, the per-round value returned by
        ``test_inference`` and the traces recorded for client 0 only.
    """
    now = datetime.datetime.now().strftime('%Y-%m-%d-%H%M%S')

    logger = SummaryWriter('../logs')

    # Checkpoint locations.
    args.save_path = os.path.join(args.save_path, args.exp_folder)
    os.makedirs(args.save_path, exist_ok=True)
    save_path_tmp = os.path.join(args.save_path, 'tmp_{}'.format(now))
    os.makedirs(save_path_tmp, exist_ok=True)
    checkpoint_path = os.path.join(
        args.save_path,
        '{}_{}_T[{}]_C[{}]_iid[{}]_E[{}]_B[{}]'.format(
            args.dataset, args.model, args.epochs, args.frac, args.iid,
            args.local_ep, args.local_bs))

    # Fix every random seed so that runs are reproducible.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    device = torch.device("cuda:{}".format(args.gpu) if torch.cuda.is_available() else "cpu")

    # Text log file.
    log_path = os.path.join('../logs', args.exp_folder)
    os.makedirs(log_path, exist_ok=True)
    loggertxt = get_logger(
        os.path.join(log_path, '{}_{}_{}_{}.log'.format(args.model, args.optimizer, args.norm, now)))
    logging.info(args)

    # CSV log with one row (loss, accuracy) per round.
    csv_save = '../csv/' + now
    csv_path = os.path.join(csv_save, 'accuracy.csv')
    csv_logger_keys = ['train_loss', 'accuracy']
    csvlogger = CSVLogger(csv_path, csv_logger_keys)

    # Load dataset and the per-client data loaders.
    train_dataset, test_dataset, client_loader_dict = get_dataset(args)

    # CIFAR-100 always has 100 classes, whatever was configured.
    if args.dataset == 'cifar100':
        args.num_classes = 100

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural network
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)
        elif args.dataset == 'cifar100':
            global_model = CNNCifar(args=args)
        else:
            # Previously this fell through and failed later with a NameError.
            exit('Error: unrecognized dataset')
    elif args.model == 'mlp':
        # Multi-layer perceptron: the input size is the product of the
        # dimensions of one sample.
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        # Build the model once, after the full input size is known.
        global_model = MLP(dim_in=len_in, dim_hidden=64,
                           dim_out=args.num_classes)
    elif args.model == 'cnn_vc':
        global_model = CNNCifar_fedVC(args=args)
    elif args.model == 'cnn_vcbn':
        global_model = CNNCifar_VCBN(args=args)
    elif args.model == 'cnn_vcgn':
        global_model = CNNCifar_VCGN(args=args)
    elif args.model == 'resnet18_ws':
        global_model = resnet18(num_classes=args.num_classes, weight_stand=1)
    elif args.model == 'resnet18':
        global_model = resnet18(num_classes=args.num_classes, weight_stand=0)
    elif args.model == 'resnet32':
        global_model = ResNet32_test(num_classes=args.num_classes)
    elif args.model == 'resnet18_mabn':
        global_model = resnet18_mabn(num_classes=args.num_classes)
    elif args.model == 'vgg':
        global_model = vgg11()
    elif args.model == 'cnn_ws':
        global_model = CNNCifar_WS(args=args)
    else:
        exit('Error: unrecognized model')

    loggertxt.info(global_model)

    # Per-client model copies, intended for FedBN-style training in which the
    # normalisation layers are not communicated.
    # NOTE(review): these copies are only touched when resuming and when
    # saving; the training loop below never updates them.
    client_models = [copy.deepcopy(global_model) for _ in range(args.num_users)]

    # Set the model to train mode and send it to the device.
    global_model.to(device)
    global_model.train()

    # Training traces.
    train_loss = []
    val_acc_list = []

    # Per-client traces used to study how normalisation helps training.
    client_loss = [[] for _ in range(args.num_users)]
    client_conv_grad = [[] for _ in range(args.num_users)]
    client_fc_grad = [[] for _ in range(args.num_users)]
    client_total_grad_norm = [[] for _ in range(args.num_users)]

    def _save_checkpoint(last_round):
        """Write a resumable checkpoint marking ``last_round`` as finished."""
        state = {'global_model': global_model.state_dict(), 'a_iter': last_round}
        if args.hold_normalize:
            for client_idx, client_model in enumerate(client_models):
                state['model_{}'.format(client_idx)] = client_model.state_dict()
        torch.save(state, checkpoint_path)

    # Resume from an earlier checkpoint if requested.
    if args.resume:
        checkpoint = torch.load(checkpoint_path)
        global_model.load_state_dict(checkpoint['global_model'])
        if args.hold_normalize:
            for client_idx in range(args.num_users):
                client_models[client_idx].load_state_dict(checkpoint['model_{}'.format(client_idx)])
        else:
            for client_idx in range(args.num_users):
                client_models[client_idx].load_state_dict(checkpoint['global_model'])
        # Checkpoints without 'a_iter' restart from round 0.
        resume_iter = int(checkpoint.get('a_iter', -1)) + 1
        print('Resume trainig form epoch {}'.format(resume_iter))
    else:
        resume_iter = 0

    # Start training (rounds that were already finished are skipped).
    for epoch in tqdm(range(resume_iter, args.epochs)):
        local_weights, local_losses = [], []
        if args.verbose:
            print(f'\n | Global Training Round : {epoch + 1} |\n')

        global_model.train()
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        for idx in idxs_users:
            torch.cuda.empty_cache()

            # Every selected client starts from a fresh copy of the global model.
            local_model = LocalUpdate(args=args, logger=logger, train_loader=client_loader_dict[idx], device=device)
            w, loss, batch_loss, conv_grad, fc_grad, total_gard_norm = local_model.update_weights(
                model=copy.deepcopy(global_model), global_round=epoch, idx_user=idx)
            local_weights.append(copy.deepcopy(w))
            # Loss reported by the client for this round (presumably the mean
            # of its batch losses, e.g. 0.35 -- confirm in LocalUpdate).
            local_losses.append(copy.deepcopy(loss))

            # Keep the per-client history for the loss / gradient plots.
            client_loss[idx].append(batch_loss)
            client_conv_grad[idx].append(conv_grad)
            client_fc_grad[idx].append(fc_grad)
            client_total_grad_norm[idx].append(total_gard_norm)

            del local_model
            del w

        # Combine the client weights into the new global weights.
        global_weights = average_weights(local_weights, client_loader_dict, idxs_users)
        # NOTE(review): server-side optimizer step.  The state dict is taken
        # from a freshly created optimizer, so it carries no momentum; how
        # ``aggregation`` prepares the model for ``opt.step()`` is not visible
        # here.  The statement order is kept exactly as it was.
        opt = OptRepo.name2cls('sgd')(global_model.parameters(), lr=10, momentum=0.9)
        opt.zero_grad()
        opt_state = opt.state_dict()
        global_weights = aggregation(global_weights, global_model)
        global_model.load_state_dict(global_weights)
        opt = OptRepo.name2cls('sgd')(global_model.parameters(), lr=10, momentum=0.9)
        opt.load_state_dict(opt_state)
        opt.step()

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # Score the new global model on the test set after every round.
        global_model.eval()
        round_accuracy = test_inference(args, global_model, test_dataset, device=device)
        val_acc_list.append(round_accuracy)

        loggertxt.info(f' \nAvg Training Stats after {epoch + 1} global rounds:')
        loggertxt.info(f'Training Loss : {loss_avg}')
        loggertxt.info('Train Accuracy: {:.2f}% \n'.format(100 * round_accuracy))
        csvlogger.write_row([loss_avg, 100 * round_accuracy])

        if (epoch + 1) % 100 == 0:
            tmp_save_path = os.path.join(save_path_tmp, 'tmp_{}.pt'.format(epoch + 1))
            torch.save(global_model.state_dict(), tmp_save_path)
            # Also refresh the resumable checkpoint so that an interrupted
            # run can continue from here.
            _save_checkpoint(epoch)

    # Test inference after completion of training.
    test_acc = test_inference(args, global_model, test_dataset, device=device)

    print(' Saving checkpoints to {}...'.format(checkpoint_path))
    _save_checkpoint(max(resume_iter, args.epochs) - 1)

    loggertxt.info(f' \n Results after {args.epochs} global rounds of training:')
    loggertxt.info("|---- Test Accuracy: {:.2f}%".format(100 * test_acc))

    # Averaging the traces over clients does not work when ``frac`` is not 1
    # and is expected to break when clients have different numbers of
    # batches, so only client 0 is returned for now.
    return train_loss, val_acc_list, client_loss[0], client_conv_grad[0], client_fc_grad[0], client_total_grad_norm[0]
Beispiel #3
0
                '\nAvg Training Stats after {} global rounds:'.format(epoch +
                                                                      1))
            print_log('Training Loss : {}'.format(np.mean(
                np.array(train_loss))))
            print_log('Local Test Accuracy: {:.2f}% '.format(
                local_test_accuracy[-1]))
            print_log('Local Test IoU: {:.2f}%'.format(local_test_iou[-1]))
            print_log('Run Time: {0:0.4f}\n'.format(
                (time.time() - last_time) // 60))

    # torch.save(weights, 'weights.pt')# comment off for checking weights update

    # GLOBAL test dataset evaluation after completion of training
        if not args.train_only and (epoch + 1) % args.test_frequency == 0:
            print_log('\nTesting global model on global test dataset')
            test_acc, test_iou, confmat = test_inference(
                args, global_model, test_loader)
            print_log(confmat)
            print_log('\nResults after {} global rounds of training:'.format(
                args.epochs))
            print_log("|---- Global Test Accuracy: {:.2f}%".format(test_acc))
            print_log("|---- Global Test IoU: {:.2f}%".format(test_iou))
            print_log('\n Total Run Time: {0:0.4f}'.format(
                (time.time() - start_time) // 60))

    # Plot Loss curve
    if args.epochs > 1:
        plt.figure()
        plt.title('Training Loss vs Communication rounds')
        plt.plot(range(len(train_loss)), train_loss, color='r')
        plt.ylabel('Training loss')
        plt.xlabel('Communication Rounds')
Beispiel #4
0
        # for idx in idxs_users:
        #     local_model = LocalUpdate(args=args, dataset=train_dataset[idx],
        #                               idxs=user_groups[idx], logger=logger)
        #     acc, loss = local_model.inference(model=global_model)
        #     list_acc.append(acc)
        #     list_loss.append(loss)
        # train_accuracy.append(sum(list_acc)/len(list_acc))

        # print global training loss after every 'i' rounds
        if (epoch+1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            # print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))

    # Test inference after completion of training
    test_acc, test_loss = test_inference(args, global_model, test_dataset)

    print(f' \n Results after {args.epochs} global rounds of training:')
    # print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))

    ######## Timing starts ########
    start_time = time.time() # start the timer 

    accuracy_dict[tuple(idxs_users)] = test_acc
    submodel_dict[tuple(idxs_users)] = copy.deepcopy(global_model)
    err = 0.01 # convergence criteria
    
    # initialize the tmc scores. it will be a (0xn) array where n is the number of users
    mem_tmc = np.zeros((0, args.num_users))
        # for idx in idxs_users:
        #     local_model = LocalUpdate(args=args, dataset=train_dataset[idx],
        #                               idxs=user_groups[idx], logger=logger)
        #     acc, loss = local_model.inference(model=global_model)
        #     list_acc.append(acc)
        #     list_loss.append(loss)
        # train_accuracy.append(sum(list_acc)/len(list_acc))

        # print global training loss after every 'i' rounds
        if (epoch + 1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            # print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))

    # Test inference after completion of training
    test_acc, test_loss = test_inference(args, global_model, test_dataset)

    print(f' \n Results after {args.epochs} global rounds of training:')
    # print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100 * test_acc))

    ######## Timing starts ########
    start_time = time.time()  # start the timer

    # accuracy for the full dataset is the same as global accuracy
    accuracy_dict[powerset[-1]] = test_acc

    # Test inference for the sub-models in submodel_dict
    for subset in powerset[:-1]:
        test_acc, test_loss = test_inference(args, submodel_dict[subset],
                                             test_dataset)
def poisoned_NoDefense(nb_attackers, seed=1):
    """Federated training with poisoning clients and no server-side defense.

    In every round ``m = max(int(frac * num_users), 1)`` clients are sampled.
    The first ``nb_attackers`` of them submit the parameters of a pre-trained
    model loaded from ``../save/all_5_model.pth``, with their difference to
    the current global parameters scaled by ``m / nb_attackers``.  The
    remaining clients submit the result of ``LocalUpdate.update_weights``
    (called with ``change=1``).  All submissions are averaged and added to
    the global weights, and the global model is evaluated after each round.

    The list of test accuracies (starting with the placeholder ``0.1``) is
    written to ``../save/RandomAttack/NoDefense_iid_<dataset>_<model>_
    attackers<nb_attackers>_seed<seed>.txt``.

    Args:
        nb_attackers: Number of sampled clients per round that act as
            attackers.
        seed: Seed for ``torch`` and ``numpy``; also part of the output
            file name.
    """
    logger = SummaryWriter('../logs')

    args = args_parser()
    exp_details(args)

    # set seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    # device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load dataset and user groups
    train_dataset, test_dataset, user_groups = get_dataset(args)

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural network
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)
        else:
            # Previously this fell through and failed later with a NameError.
            exit('Error: unrecognized dataset')
    elif args.model == 'mlp':
        # Multi-layer perceptron: input size is the product of the sample dims.
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        # Build the model once, after the full input size is known.
        global_model = MLP(dim_in=len_in, dim_hidden=64,
                           dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # State dict of the global model; its tensors are updated in place below.
    global_weights = global_model.state_dict()

    # Pre-trained model whose parameters the attackers push towards.
    dummy_model = copy.deepcopy(global_model)
    dummy_model.load_state_dict(torch.load('../save/all_5_model.pth'))

    # testing accuracy for global model
    testing_accuracy = [0.1]

    for epoch in tqdm(range(args.epochs)):
        local_del_w = []
        print(f'\n | Global Training Round : {epoch+1} |\n')

        global_model.train()
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        # Adversary updates
        for idx in idxs_users[0:nb_attackers]:
            print("evil")
            # NOTE(review): the LocalUpdate object is not used by the attack;
            # it is still constructed in case its constructor has side
            # effects (e.g. on the random state) -- confirm and drop if not.
            local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=user_groups[idx], logger=logger)

            w = copy.deepcopy(dummy_model)
            # Scale the distance to the current global parameters by
            # m / nb_attackers, then add the global parameters back.
            for del_w, w_old in zip(w.parameters(), global_model.parameters()):
                del_w.data -= w_old.data
                del_w.data *= m / nb_attackers
                del_w.data += w_old.data
            local_del_w.append(copy.deepcopy(w.state_dict()))

        # Non-adversarial updates
        for idx in idxs_users[nb_attackers:]:
            print("good")
            local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=user_groups[idx], logger=logger)
            del_w, _ = local_model.update_weights(model=copy.deepcopy(global_model), change=1)
            local_del_w.append(copy.deepcopy(del_w))

        # average local updates
        average_del_w = average_weights(local_del_w)

        # Update global model: w_{t+1} = w_{t} + average_del_w
        for param, param_del_w in zip(global_weights.values(), average_del_w.values()):
            param += param_del_w
        global_model.load_state_dict(global_weights)

        # test accuracy
        test_acc, test_loss = test_inference(args, global_model, test_dataset)
        testing_accuracy.append(test_acc)

        print("Test accuracy")
        print(testing_accuracy)

    # save test accuracy (the file name uses ``seed``; the former ``s`` was
    # undefined and raised a NameError after training had finished)
    np.savetxt('../save/RandomAttack/NoDefense_iid_{}_{}_attackers{}_seed{}.txt'.
               format(args.dataset, args.model, nb_attackers, seed), testing_accuracy)
            list_acc.append(acc)
            list_loss.append(loss)
        train_accuracy.append(sum(list_acc)/len(list_acc))
        '''
        # print global training loss after every 'i' rounds
        if (epoch+1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))



        #print("DICTIONARY     :       ", rm_dict)
        # I moved this section inside
        # Test inference after completion of training
        test_acc, test_l = test_inference(args, model=global_model, test_dataset=test_dataset, avg_rm=average_rm, avg_rv=average_rv)
        test_accuracy.append(test_acc)
        test_loss.append(test_l)
        print(f' \n Results after {args.epochs} global rounds of training:')
        #print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
        print("|---- Avg Train Accuracy: {:.2f}%".format(100*np.mean(np.array(train_accuracy))))
        print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))

    # Saving the objects train_loss and train_accuracy:
    

    file_directory = "/home/janati/Desktop/github/FL-achwin/Federated-Learning-PyTorch/save/objects"
    file_name = os.path.join(file_directory, '{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.\
        format(args.dataset, args.model, args.epochs, args.frac, args.iid,
               args.local_ep, args.local_bs))
Beispiel #8
0
def central_DP_FL(norm_bound, noise_scale, seed=1):
    """Federated training with central differential privacy.

    In every round the sampled clients return an update and its norm from
    ``LocalUpdate.update_weights`` (called with ``change=1``).  Each update
    is divided by ``max(1, norm / norm_bound)``, the clipped updates are
    averaged and added to the global weights, and Gaussian noise with
    standard deviation ``noise_scale * norm_bound / len(idxs_users)`` is
    added to every weight tensor.  The global model is evaluated after each
    round.

    The list of test accuracies (starting with the placeholder ``0.1``) is
    written to ``../save/NoAttacks/GDP_iid<iid>_<dataset>_<model>_
    norm<norm_bound>_scale<noise_scale>_seed<seed>.txt``.

    Args:
        norm_bound: Clipping bound applied to the norm of each client update.
        noise_scale: Multiplier of the Gaussian noise added by the server.
        seed: Seed for ``torch`` and ``numpy``; also part of the output
            file name.
    """
    logger = SummaryWriter('../logs')

    args = args_parser()
    exp_details(args)

    # set seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    # device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load dataset and user groups
    train_dataset, test_dataset, user_groups = get_dataset(args)

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural network
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)
        else:
            # Previously this fell through and failed later with a NameError.
            exit('Error: unrecognized dataset')
    elif args.model == 'mlp':
        # Multi-layer perceptron: input size is the product of the sample dims.
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        # Build the model once, after the full input size is known.
        global_model = MLP(dim_in=len_in,
                           dim_hidden=64,
                           dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # State dict of the global model; its tensors are updated in place below.
    global_weights = global_model.state_dict()

    # testing accuracy for global model
    testing_accuracy = [0.1]

    for epoch in tqdm(range(args.epochs)):
        local_del_w, local_norms = [], []
        print(f'\n | Global Training Round : {epoch+1} |\n')

        global_model.train()
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        # Local Client Updates
        for idx in idxs_users:
            local_model = LocalUpdate(args=args,
                                      dataset=train_dataset,
                                      idxs=user_groups[idx],
                                      logger=logger)
            del_w, zeta = local_model.update_weights(
                model=copy.deepcopy(global_model), change=1)
            local_del_w.append(copy.deepcopy(del_w))
            local_norms.append(copy.deepcopy(zeta))

        # norm bound (an alternative would be the median of the norms:
        # min(norm_bound, np.median(local_norms)))
        clip_factor = norm_bound
        print(clip_factor)

        # clip weight updates: shrink only the updates whose norm exceeds
        # the bound
        for client_del_w, client_norm in zip(local_del_w, local_norms):
            shrink = max(1, client_norm / clip_factor)
            for param in client_del_w.values():
                param /= shrink

        # average the clipped weight updates
        average_del_w = average_weights(local_del_w)

        # Update global model using clipped weight updates, and add noise
        # w_{t+1} = w_{t} + avg(del_w1 + del_w2 + ... + del_wc) + Noise
        for param, param_del_w in zip(global_weights.values(),
                                      average_del_w.values()):
            param += param_del_w
            # Draw the noise on the tensor's own device; CPU noise cannot be
            # added in place to weights that live on the GPU.
            noise = torch.randn(param.size(), device=param.device)
            param += noise * noise_scale * clip_factor / len(idxs_users)
        global_model.load_state_dict(global_weights)

        # test accuracy
        test_acc, test_loss = test_inference(args, global_model, test_dataset)
        testing_accuracy.append(test_acc)

        print("Test accuracy")
        print(testing_accuracy)

    # save test accuracy (the file name uses ``seed``; the former ``s`` was
    # undefined and raised a NameError after training had finished)
    np.savetxt(
        '../save/NoAttacks/GDP_iid{}_{}_{}_norm{}_scale{}_seed{}.txt'.format(
            args.iid, args.dataset, args.model, norm_bound, noise_scale, seed),
        testing_accuracy)
Beispiel #9
0
def non_private_FL(seed=1):
    """Plain federated training without privacy mechanism or attackers.

    In every round the sampled clients return an update from
    ``LocalUpdate.update_weights`` (called with ``change=1``); the updates
    are averaged and added to the global weights, and the global model is
    evaluated with ``test_inference``.

    The list of test accuracies (starting with the placeholder ``0.1``) is
    written to ``../save/NoAttacks/NonPrivate_iid<iid>_<dataset>_<model>_
    seed<seed>.txt``.

    Args:
        seed: Seed for ``torch`` and ``numpy``; also part of the output
            file name.
    """
    logger = SummaryWriter('../logs')

    args = args_parser()
    exp_details(args)

    # set seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    # device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load dataset and user groups
    train_dataset, test_dataset, user_groups = get_dataset(args)

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural network
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)
        else:
            # Previously this fell through and failed later with a NameError.
            exit('Error: unrecognized dataset')
    elif args.model == 'mlp':
        # Multi-layer perceptron: input size is the product of the sample dims.
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        # Build the model once, after the full input size is known.
        global_model = MLP(dim_in=len_in,
                           dim_hidden=64,
                           dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()

    # State dict of the global model; its tensors are updated in place below.
    global_weights = global_model.state_dict()

    # testing accuracy for global model
    testing_accuracy = [0.1]

    for epoch in tqdm(range(args.epochs)):
        local_del_w = []
        print(f'\n | Global Training Round : {epoch+1} |\n')

        global_model.train()
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        # Local Client Updates
        for idx in idxs_users:
            local_model = LocalUpdate(args=args,
                                      dataset=train_dataset,
                                      idxs=user_groups[idx],
                                      logger=logger)
            del_w, _ = local_model.update_weights(
                model=copy.deepcopy(global_model), change=1)
            local_del_w.append(del_w)

        # average local updates
        average_del_w = average_weights(local_del_w)

        # Update global model: w_{t+1} = w_{t} + average_del_w
        for param, param_del_w in zip(global_weights.values(),
                                      average_del_w.values()):
            param += param_del_w
        global_model.load_state_dict(global_weights)

        # test accuracy
        test_acc, test_loss = test_inference(args, global_model, test_dataset)
        testing_accuracy.append(test_acc)

        print("Test & Backdoor accuracy")
        print(testing_accuracy)

    # save accuracy (the file name uses ``seed``; the former ``s`` was
    # undefined and raised a NameError after training had finished)
    np.savetxt(
        '../save/NoAttacks/NonPrivate_iid{}_{}_{}_seed{}.txt'.format(
            args.iid, args.dataset, args.model, seed), testing_accuracy)
def poisoned_LDP(nb_attackers, norm_bound, noise_scale, seed=1):
    """Federated training with local DP clients, some of which are attackers.

    In every round ``m = max(int(frac * num_users), 1)`` clients are sampled.
    The first ``nb_attackers`` of them train with
    ``LocalUpdate.poisoned_ldp`` and the remaining ones with
    ``LocalUpdate.dp_sgd``; both receive ``norm_bound`` and ``noise_scale``.
    The returned weights are averaged into the new global weights and the
    global model is evaluated after each round.

    The list of test accuracies (starting with the placeholder ``0.1``) is
    written to ``../save/RandomAttack/LDP_FL_<dataset>_<model>_
    norm<norm_bound>_scale<noise_scale>_attackers<nb_attackers>_
    seed<seed>.txt``.

    Args:
        nb_attackers: Number of sampled clients per round that act as
            attackers.
        norm_bound: Norm bound forwarded to the client training routines.
        noise_scale: Noise scale forwarded to the client training routines.
        seed: Seed for ``torch`` and ``numpy``; also part of the output
            file name.
    """
    logger = SummaryWriter('../logs')

    args = args_parser()
    exp_details(args)

    # set seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    # device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load dataset and user groups
    train_dataset, test_dataset, user_groups = get_dataset(args)

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural network
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)
        else:
            # Previously this fell through and failed later with a NameError.
            exit('Error: unrecognized dataset')
    elif args.model == 'mlp':
        # Multi-layer perceptron: input size is the product of the sample dims.
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        # Build the model once, after the full input size is known.
        global_model = MLP(dim_in=len_in,
                           dim_hidden=64,
                           dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # testing accuracy for global model
    testing_accuracy = [0.1]

    for epoch in tqdm(range(args.epochs)):
        local_w = []
        print(f'\n | Global Training Round : {epoch + 1} |\n')

        global_model.train()
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        # Adversary updates
        print("Evil")
        for idx in idxs_users[0:nb_attackers]:
            print(idx)
            local_model = LocalUpdate(args=args,
                                      dataset=train_dataset,
                                      idxs=user_groups[idx],
                                      logger=logger)
            w, _ = local_model.poisoned_ldp(model=copy.deepcopy(global_model),
                                            norm_bound=norm_bound,
                                            noise_scale=noise_scale)
            local_w.append(copy.deepcopy(w))

        # Non-adversarial updates
        print("Good")
        for idx in idxs_users[nb_attackers:]:
            print(idx)
            local_model = LocalUpdate(args=args,
                                      dataset=train_dataset,
                                      idxs=user_groups[idx],
                                      logger=logger)
            w, _ = local_model.dp_sgd(model=copy.deepcopy(global_model),
                                      norm_bound=norm_bound,
                                      noise_scale=noise_scale)
            local_w.append(copy.deepcopy(w))

        # update global weights
        global_weights = average_weights(local_w)
        global_model.load_state_dict(global_weights)

        # test accuracy
        test_acc, test_loss = test_inference(args, global_model, test_dataset)
        testing_accuracy.append(test_acc)

        print("Test accuracy")
        print(testing_accuracy)

    # save accuracy (the file name uses ``seed``; the former ``s`` was
    # undefined and raised a NameError after training had finished)
    np.savetxt(
        '../save/RandomAttack/LDP_FL_{}_{}_norm{}_scale{}_attackers{}_seed{}.txt'
        .format(args.dataset, args.model, norm_bound, noise_scale,
                nb_attackers, seed), testing_accuracy)
        # update global weights
        global_weights = average_weights_baseline(local_weights)

        # update global weights
        global_model.load_state_dict(global_weights, strict=False)

        loss_avg = sum(local_losses) / len(local_losses)
        accuracy_avg = sum(local_accuracies) / len(local_accuracies)
        train_loss.append(loss_avg)
        train_accuracy.append(accuracy_avg)

        # print global training loss after every 'i' rounds
        if (epoch + 1) % save_log == 0:

            # Test inference
            test_acc, test_l = test_inference(args, global_model, test_dataset)
            test_accuracy.append(test_acc)
            test_loss.append(test_l)

            logging.info('Epoch : %d', epoch + 1)
            logging.info('|---- Avg Train Accuracy: {:.2f}'.format(
                100 * np.mean(np.array(train_accuracy))))
            logging.info('|---- Training Loss : {:.2f}'.format(
                np.mean(np.array(train_loss))))
            logging.info('|---- Test Accuracy: {:.2f}'.format(100 * test_acc))
            logging.info('|---- Test loss: {:.2f}'.format(test_l))

            if (epoch + 1) % print_every == 0:
                print(
                    f' \n Results after {epoch+1} global rounds of training:')
                #print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
Beispiel #12
0
def main():
    """Train a global model with federated averaging and test it every round.

    Parses the arguments, loads the dataset and the per-user index groups,
    builds the model selected by ``args.model`` / ``args.dataset`` and runs
    ``args.epochs`` global rounds. In each round a random fraction of the
    users trains a copy of the global model, the resulting weights are
    averaged into the global model, the average training accuracy over all
    users is recorded, and the global model is tested once for every entry
    of ``adatok.data.test_groups_in_binary``.

    Returns early (before any model is built) when
    ``adatok.data.image_initialization`` is set; the flag is cleared on the
    way out.
    """
    # Not read below; retained as timing instrumentation.
    start_time = time.time()

    logger = SummaryWriter('../logs')
    args = args_parser()
    args = adatok.arguments(args)
    exp_details(args)
    if args.gpu:
        torch.cuda.set_device(args.gpu)
    device = 'cuda' if args.gpu else 'cpu'

    # load dataset and user groups
    train_dataset, test_dataset, user_groups = get_dataset(args)

    # Initialization-only run: clear the flag and stop before training.
    if adatok.data.image_initialization == True:
        adatok.data.image_initialization = False
        return

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural network
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)
        else:
            # Without this guard global_model stays unbound and the failure
            # only surfaces later as an UnboundLocalError.
            exit('Error: unrecognized dataset')

    elif args.model == 'mlp':
        # Multi-layer perceptron: the input size is the product of the
        # sample's dimensions.
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        # Build the model once, after len_in is final (it used to be rebuilt
        # for every dimension inside the loop).
        global_model = MLP(dim_in=len_in,
                           dim_hidden=64,
                           dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()

    # Training
    train_loss, train_accuracy = [], []

    for epoch in tqdm(range(args.epochs)):
        local_weights, local_losses = [], []

        global_model.train()
        # Sample at least one user per round.
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        for idx in idxs_users:
            local_model = LocalUpdate(args=args,
                                      dataset=train_dataset,
                                      idxs=user_groups[idx],
                                      logger=logger)
            # Each client trains its own copy so the global model is only
            # changed by the aggregation step below.
            w, loss = local_model.update_weights(
                model=copy.deepcopy(global_model), global_round=epoch)
            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))

        # update global weights
        global_weights = average_weights(local_weights)
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # Calculate avg training accuracy over all users at every epoch
        list_acc = []
        global_model.eval()
        for c in range(args.num_users):
            # Evaluate on user c's own shard. Indexing with the stale `idx`
            # from the sampling loop evaluated the same shard num_users times.
            local_model = LocalUpdate(args=args,
                                      dataset=train_dataset,
                                      idxs=user_groups[c],
                                      logger=logger)
            acc, loss = local_model.inference(model=global_model)
            list_acc.append(acc)
        train_accuracy.append(sum(list_acc) / len(list_acc))

        # Test the current global model once per configured test group.
        for i in adatok.data.test_groups_in_binary:
            adatok.data.actual_test_group_in_binary = i
            test_acc, test_loss = test_inference(args, global_model,
                                                 test_dataset)
            print("Resoults")
            print(epoch)
            print(adatok.data.actual_train_group_in_binary)
            print(adatok.data.actual_test_group_in_binary)
            print(test_acc)
            print(test_loss)
    '''
Beispiel #13
0
def main():
    """Train a single (non-federated) baseline model and track test accuracy.

    Builds the model selected by ``args.model`` / ``args.dataset``, trains it
    on the whole training set for up to ``args.epochs`` epochs, evaluates it
    with ``test_inference`` after every epoch, checkpoints whenever the test
    accuracy improves, and stops early once ``args.target_acc`` is reached.
    Loss and test-accuracy curves are saved as PNG files under
    ``<model_path>/save``.

    This function relies on names that are not defined inside it
    (``args``, ``device``, ``use_cuda``, ``cf``); they must be provided by
    the enclosing module.
    """
    model_path = 'results/%s/%s/%s/seed_%d' % (args.dataset, args.method, args.net_type, args.seed)
    if not os.path.isdir(model_path):
        mkdir_p(model_path)
    # load datasets
    train_dataset, test_dataset, _ = get_dataset(args)

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural network
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar10':
            global_model = CNNCifar(args=args)
        elif args.dataset == 'cub200':
            if args.net_type == 'resnet':
                global_model = models.resnet18(pretrained=True)
                # Replace the classification head to match the dataset.
                global_model.fc = torch.nn.Linear(global_model.fc.in_features, cf.num_classes[args.dataset])
            else:
                # Previously fell through with global_model unbound.
                exit('Error: unrecognized net_type')
        else:
            # Previously fell through with global_model unbound.
            exit('Error: unrecognized dataset')
    elif args.model == 'mlp':
        # Multi-layer perceptron: the input size is the product of the
        # sample's dimensions.
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        # Build the model once, after len_in is final (it used to be rebuilt
        # for every dimension inside the loop).
        global_model = MLP(dim_in=len_in, dim_hidden=64,
                           dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)
    wandb.watch(global_model)

    # Training
    # Set optimizer and criterion
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(global_model.parameters(), lr=args.lr,
                                    momentum=cf.momentum[args.dataset], weight_decay=5e-4)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(global_model.parameters(), lr=args.lr,
                                     weight_decay=1e-4)
    else:
        # Previously fell through with optimizer unbound.
        exit('Error: unrecognized optimizer')

    # One batch here has the size of all per-round client batches combined.
    trainloader = DataLoader(train_dataset, batch_size=int(args.local_bs * (args.num_users * args.frac)), shuffle=True, num_workers=args.workers,
                             pin_memory=use_cuda, drop_last=True)

    criterion = torch.nn.CrossEntropyLoss().to(device)

    epoch_loss = []
    test_acc_lst = []
    best_acc = 0
    # NOTE(review): the optimizer above was created with the previous
    # args.lr; this overwrite only affects what adjust_learning_rate and the
    # progress print see. Confirm that the ordering is intentional.
    args.lr = cf.lr[args.dataset]

    for epoch in tqdm(range(args.epochs)):
        global_model.train()
        batch_loss = []

        # adjust learning rate per global round
        if epoch != 0:
            adjust_learning_rate([optimizer], args, epoch)

        for batch_idx, (images, labels) in enumerate(trainloader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = global_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Read the scalar once and reuse it for printing and logging.
            loss_value = loss.item()
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLr: {:.4f}'.format(
                    epoch+1, (batch_idx+1) * len(images), len(trainloader.dataset),
                    100. * (batch_idx+1) / len(trainloader), loss_value, args.lr))
            batch_loss.append(loss_value)

            wandb.log({'Train Loss': loss_value})

        loss_avg = sum(batch_loss)/len(batch_loss)
        print('\nTrain loss: \n', loss_avg)
        epoch_loss.append(loss_avg)

        test_acc, test_loss = test_inference(args, global_model, test_dataset)
        test_acc_lst.append(test_acc)

        # Checkpoint only when the test accuracy improves.
        if test_acc > best_acc:
            best_acc = test_acc
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': global_model.state_dict(),
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            }, dir=model_path, filename='checkpoint.pth.tar')

        print('\nTrain Epoch: {}, Test acc: {:.2f}%, Best Test acc: {:.2f}%'.format(epoch + 1, test_acc, best_acc))

        # log test accuracy at wandb
        wandb.log({'Test Acc': test_acc,
                   'Best Acc': best_acc})

        # if model achieves target test acc, stop training
        if best_acc >= args.target_acc:
            print('Total Global round: ', epoch+1)
            break

    if not os.path.isdir(os.path.join(model_path, 'save')):
        mkdir_p(os.path.join(model_path, 'save'))
    # Plot loss
    plt.figure()
    plt.plot(range(len(epoch_loss)), epoch_loss)
    plt.xlabel('epochs')
    plt.ylabel('Train loss')
    plt.savefig(os.path.join(model_path, 'save/nn_{}_{}_{}_loss.png'.format(args.dataset, args.model,
                                                 args.epochs)))

    # Plot test acc per epoch
    plt.figure()
    plt.plot(range(len(test_acc_lst)), test_acc_lst)
    plt.xlabel('epochs')
    plt.ylabel('Test accuracy')
    plt.savefig(os.path.join(model_path, 'save/nn_{}_{}_{}_acc.png'.format(args.dataset, args.model,
                                                                            args.epochs)))
    # Report the best accuracy seen during training.
    print('Test on', len(test_dataset), 'samples')
    print("Best Test Accuracy: {:.2f}%".format(best_acc))
                                       idxs=user_groups[idx],
                                       logger=logger,
                                       device=device)
            acc, loss = client_shard.inference(model=global_model)
            list_acc.append(acc)
            list_loss.append(loss)
        train_accuracy.append(sum(list_acc) / len(list_acc))
        ###
        # print global training loss after every 'i' rounds
        if (epoch + 1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch + 1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}% \n'.format(100 * train_accuracy[-1]))

    # Test inference after completion of training
    test_acc, test_loss = test_inference(args=args, model=global_model, test_dataset=test_dataset, device=device)

    print(f' \n Results after {args.epochs} global rounds of training:')
    print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))
    ### }}} End of Federated learning.

    # Saving the objects train_loss and train_accuracy:
    makedirs('./save/objects/', exist_ok=True)
    file_name = './save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.\
        format(args.dataset, args.model, args.epochs, args.frac, args.iid,
               args.local_ep, args.local_bs)

    with open(file_name, 'wb') as f:
        pickle.dump([train_loss, train_accuracy], f)
Beispiel #15
0
                                      idxs=user_groups[idx],
                                      logger=logger)
            acc, loss = local_model.inference(model=global_model)
            list_acc.append(acc)
            list_loss.append(loss)
        train_accuracy.append(sum(list_acc) / len(list_acc))

        # print global training loss after every 'i' rounds
        if (epoch + 1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}% \n'.format(100 *
                                                      train_accuracy[-1]))

    # Test inference after completion of training
    test_acc, test_loss = test_inference(args, global_model,
                                         test_dataset)  # 在测试集上测试全局模型的准确率

    print(f' \n Results after {args.epochs} global rounds of training:')
    print("|---- Avg Train Accuracy: {:.2f}%".format(
        100 * train_accuracy[-1]))  # -1是因为
    print("|---- Test Accuracy: {:.2f}%".format(100 * test_acc))

    # Saving the objects train_loss and train_accuracy:
    file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.\
        format(args.dataset, args.model, args.epochs, args.frac, args.iid,
               args.local_ep, args.local_bs)

    with open(file_name, 'wb') as f:
        pickle.dump([train_loss, train_accuracy], f)

    print('\n Total Run Time: {0:0.4f}'.format(time.time() - start_time))
def main():
    """Train with federated averaging, then test the model per coalition.

    Parses the arguments, loads the dataset and the per-user index groups
    (``args.num_users`` is overwritten with the number of groups), builds the
    model and runs ``args.epochs`` global rounds of federated averaging.
    Afterwards it reads the participation flags from ``users.txt`` and the
    per-participant label lists from ``traindataset.txt`` (both relative to
    the working directory) and calls ``test_inference`` once for every value
    ``j`` in ``range(2**args.num_users - 1)``, passing the binary encoding of
    ``j`` produced by ``numberToBinary``.
    """
    # Not read below; retained as timing instrumentation.
    start_time = time.time()

    logger = SummaryWriter('../logs')

    args = args_parser()
    exp_details(args)

    if args.gpu:
        torch.cuda.set_device(0)
    device = 'cuda' if args.gpu else 'cpu'

    # load dataset and user groups
    train_dataset, test_dataset, user_groups = get_dataset(args)

    args.num_users = len(user_groups)

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural network
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)
        else:
            # Without this guard global_model stays unbound and the failure
            # only surfaces later as an UnboundLocalError.
            exit('Error: unrecognized dataset')

    elif args.model == 'mlp':
        # Multi-layer perceptron: the input size is the product of the
        # sample's dimensions.
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        # Build the model once, after len_in is final (it used to be rebuilt
        # for every dimension inside the loop).
        global_model = MLP(dim_in=len_in,
                           dim_hidden=64,
                           dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()

    # Training
    train_loss, train_accuracy = [], []

    # Read which participants currently take part in training
    # (0 means the participant is included, 1 means it is not).
    # split() without arguments also tolerates trailing spaces/newlines,
    # which int() would otherwise choke on.
    with open('users.txt', "r") as fp:
        users = [int(token) for token in fp.readline().split()]

    for epoch in range(args.epochs):
        local_weights, local_losses = [], []

        global_model.train()
        # Sample at least one user per round.
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        for idx in idxs_users:
            local_model = LocalUpdate(args=args,
                                      dataset=train_dataset,
                                      idxs=user_groups[idx],
                                      logger=logger)
            # Each client trains its own copy so the global model is only
            # changed by the aggregation step below.
            w, loss = local_model.update_weights(
                model=copy.deepcopy(global_model), global_round=epoch)
            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))

        # update global weights
        global_weights = average_weights(local_weights)
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # Calculate avg training accuracy over all users at every epoch
        list_acc = []
        global_model.eval()
        for c in range(args.num_users):
            # Evaluate on user c's own shard. Indexing with the stale `idx`
            # from the sampling loop evaluated the same shard num_users times.
            local_model = LocalUpdate(args=args,
                                      dataset=train_dataset,
                                      idxs=user_groups[c],
                                      logger=logger)
            acc, loss = local_model.inference(model=global_model)
            list_acc.append(acc)
        train_accuracy.append(sum(list_acc) / len(list_acc))

    # Test inference after completion of training

    # Read which labels were assigned to which participant: one line per
    # participant, whitespace-separated integers.
    testlabels = []
    with open('traindataset.txt') as ftrain:
        for line in ftrain:
            testlabels.append([int(token) for token in line.split()])

    print("USERS LABELS")
    print(testlabels)

    # Run the test for every coalition encoded by j.
    # NOTE(review): the range stops at 2**num_users - 2, so the all-ones
    # pattern is never tested; given the "1 = not included" convention above
    # this presumably skips the empty coalition -- confirm.
    for j in range((2**args.num_users) - 1):
        binary = numberToBinary(j, len(users))

        test_acc, test_loss = test_inference(args, global_model, test_dataset,
                                             testlabels, binary, len(binary))

        # Print the test results
        print("RESZTVEVOK")
        print(users)
        print("TEST NUMBER")
        print(j)
        print("TEST BINARY")
        print(binary)
        print("TEST LABELS")
        print(testlabels)
        print("Test Accuracy")
        print("{:.2f}%".format(100 * test_acc))
        print()

    # Saving the objects train_loss and train_accuracy:
    '''file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.\
        list_acc, list_loss = [], []
        global_model.eval()
        # for c in range(args.num_users):
        #     local_model = LocalUpdate(args=args, dataset=train_dataset,
        #                               idxs=user_groups[idx], logger=logger) # 只是返回了local_model的类
        #     acc, loss = local_model.inference(model=global_model) # 这一步只是用了local_model的数据集,即用global_model在training dataset上做测试
        #     list_acc.append(acc)
        #     list_loss.append(loss)
        # train_accuracy.append(sum(list_acc)/len(list_acc))

        # print global training loss after every 'i' rounds
        if (epoch + 1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            # print(f'Training Loss : {np.mean(np.array(train_loss))}')
            # print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))
            test_acc, test_loss = test_inference(args, global_model,
                                                 train_dataset)
            print(
                "test accuracy for training set: {} after {} epochs\n".format(
                    test_acc, epoch + 1))
            test_acc, test_loss = test_inference(args, global_model,
                                                 test_dataset)
            print("test accuracy for test set: {} after {} epochs\n".format(
                test_acc, epoch + 1))

    # Test inference after completion of training
    test_acc, test_loss = test_inference(args, global_model, test_dataset)

    print(f' \n Results after {args.epochs} global rounds of training:')
    print("|---- Avg Train Accuracy: {:.2f}%".format(100 * train_accuracy[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100 * test_acc))
def AdjustedTMC_oneiteration(test_acc, idxs_users, submodel_dict, accuracy_dict,
                             fraction, args, test_dataset, tolerance=0.005,
                             num_tolerance=2, random_score=0):
    """
    Runs one iteration of Adjusted OR-TMC-Shapley algorithm.

    For every participant ``i`` a permutation is formed that starts with ``i``
    followed by the remaining participants in random order. Walking along the
    permutation, the score of each growing prefix (as a sorted tuple) is looked
    up in ``accuracy_dict`` or, if missing, computed with ``test_inference`` on
    the corresponding submodel; the score gain is credited to the participant
    that was just added. The walk is truncated once the score has been within
    ``tolerance * test_acc`` of ``test_acc`` for ``num_tolerance`` consecutive
    steps. Participants after a truncation keep a contribution of 0.

    Participant indexes are treated as 1-based: participant ``idx`` maps to
    position ``idx - 1`` of ``fraction`` and of the returned array.

    :param test_acc: test accuracy of the global model
    :param idxs_users: the array containing the indexes of participants
    :param submodel_dict: dictionary of submodels keyed by sorted index tuples;
        must contain the single-participant models ``(i,)`` and a template
        model under ``()`` whenever a subset model has to be built. Newly
        built subset models are stored in it (mutated in place).
    :param accuracy_dict: cache of subset scores keyed by sorted index tuples;
        newly computed scores are stored in it (mutated in place)
    :param fraction: list of fraction of data from each user (all users)
    :param args: passed through to test_inference
    :param test_dataset: the test dataset used to evaluate the submodels
    :param tolerance: percentage difference from the test_acc of the global model
    :param num_tolerance: number of consecutive times the performance tolerance
        has to be met to truncate the loop
    :param random_score: the accuracy of the randomly initialized model

    :returns: a 1D array with the marginal contribution of each participant,
        averaged over the ``len(idxs_users)`` permutations
    """
    num_participants = len(idxs_users)
    average_marginal_contribs = np.zeros(num_participants)
    # Loop-invariant: absolute distance to test_acc that counts as "close".
    performance_tolerance = tolerance * test_acc

    for i in idxs_users:
        # Permutation that starts with participant i. With a single
        # participant the permuted remainder is an empty float array, which
        # would turn every index into a float, so cast back to integers.
        idxs = np.concatenate(
            (np.array([i]),
             np.random.permutation([x for x in idxs_users if x != i
                                    ]))).astype(int)

        # marginal contributions of all participants start at 0
        marginal_contribs = np.zeros(num_participants)

        # counts consecutive steps within the performance tolerance
        truncation_counter = 0

        # the score before any data comes in is the random score
        new_score = random_score

        for n, idx in enumerate(idxs):
            old_score = new_score

            # participants seen so far, as a sorted tuple usable as dict key
            subset = tuple(np.sort(idxs[:n + 1], kind='mergesort'))

            if accuracy_dict.get(subset) is None:
                model = submodel_dict.get(subset)
                if model is None:
                    # Build the subset model as the fraction-weighted average
                    # of the single-participant models, loaded into a copy of
                    # the template model, and keep it for later reference.
                    subset_weights = average_weights(
                        [submodel_dict[(k, )].state_dict() for k in subset],
                        [fraction[k - 1] for k in subset])
                    model = copy.deepcopy(submodel_dict[()])
                    model.load_state_dict(subset_weights)
                    submodel_dict[subset] = copy.deepcopy(model)
                else:
                    # Evaluate a copy so the stored submodel stays untouched.
                    model = copy.deepcopy(model)

                # get the new score and cache it
                new_score, _ = test_inference(args, model, test_dataset)
                accuracy_dict[subset] = new_score
            else:
                new_score = accuracy_dict[subset]

            marginal_contribs[idx - 1] = new_score - old_score

            # distance to the full score (the test_acc of the global model)
            distance_to_full_score = np.abs(new_score - test_acc)
            if distance_to_full_score <= performance_tolerance:
                truncation_counter += 1
                # truncate once the score has been close often enough
                if truncation_counter >= num_tolerance:
                    break
            else:
                truncation_counter = 0

        average_marginal_contribs += marginal_contribs

    # Average over the permutations actually evaluated (one per participant),
    # not over args.num_users, which may differ from len(idxs_users).
    average_marginal_contribs /= num_participants
    return average_marginal_contribs