# Example 1
from federated_learning.arguments import Arguments
from federated_learning.nets import Cifar10CNN
from federated_learning.nets import Cifar100ResNet
from federated_learning.nets import FashionMNISTCNN
from federated_learning.nets import FashionMNISTResNet
from federated_learning.nets import Cifar10ResNet
from federated_learning.nets import Cifar100VGG, STL10VGG, MNISTCNN
# from federated_learning.nets import TRECCNN
import os
import torch
from loguru import logger

if __name__ == '__main__':
    args = Arguments(logger)
    if not os.path.exists(args.get_default_model_folder_path()):
        os.mkdir(args.get_default_model_folder_path())

    # ---------------------------------
    # ----------- Cifar10CNN ----------
    # ---------------------------------
    full_save_path = os.path.join(args.get_default_model_folder_path(),
                                  "Cifar10CNN.model")
    torch.save(Cifar10CNN().state_dict(), full_save_path)
    # ---------------------------------
    # --------- Cifar10ResNet ---------
    # ---------------------------------
    full_save_path = os.path.join(args.get_default_model_folder_path(),
                                  "Cifar10ResNet.model")
    torch.save(Cifar10ResNet().state_dict(), full_save_path)
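
    # Illustrative round-trip check (a sketch, not part of the original
    # script): the .model files hold plain state_dicts, so a network can
    # be restored with load_state_dict.
    restored = Cifar10ResNet()
    restored.load_state_dict(torch.load(full_save_path))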

# Example 2

from loguru import logger
import pathlib
import os
from federated_learning.arguments import Arguments
from federated_learning.datasets import CIFAR10Dataset
from federated_learning.datasets import FashionMNISTDataset
from federated_learning.utils import generate_train_loader
from federated_learning.utils import generate_test_loader
from federated_learning.utils import save_data_loader_to_file

if __name__ == '__main__':
    args = Arguments(logger)

    # ---------------------------------
    # ------------ CIFAR10 ------------
    # ---------------------------------
    dataset = CIFAR10Dataset(args)
    TRAIN_DATA_LOADER_FILE_PATH = "data_loaders/cifar10/train_data_loader.pickle"
    TEST_DATA_LOADER_FILE_PATH = "data_loaders/cifar10/test_data_loader.pickle"

    pathlib.Path("data_loaders/cifar10").mkdir(parents=True, exist_ok=True)

    train_data_loader = generate_train_loader(args, dataset)
    test_data_loader = generate_test_loader(args, dataset)

    with open(TRAIN_DATA_LOADER_FILE_PATH, "wb") as f:
        save_data_loader_to_file(train_data_loader, f)

    with open(TEST_DATA_LOADER_FILE_PATH, "wb") as f:
        save_data_loader_to_file(test_data_loader, f)
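
    # Illustrative round-trip check (a sketch, not part of the original
    # script): the loaders are written in binary mode to *.pickle files,
    # so save_data_loader_to_file is assumed to pickle them; reading one
    # back would then look like:
    import pickle
    with open(TEST_DATA_LOADER_FILE_PATH, "rb") as f:
        reloaded_test_loader = pickle.load(f)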

# Example 3

def run_exp(replacement_method, num_poisoned_workers, KWARGS,
            client_selection_strategy, idx):
    log_files, results_files, models_folders, worker_selections_files = generate_experiment_ids(
        idx, 1)

    # Initialize logger
    handler = logger.add(log_files[0], enqueue=True)

    args = Arguments(logger)
    args.set_model_save_path(models_folders[0])
    args.set_num_poisoned_workers(num_poisoned_workers)
    args.set_round_worker_selection_strategy_kwargs(KWARGS)
    args.set_client_selection_strategy(client_selection_strategy)
    args.log()

    train_data_loader = load_train_data_loader(logger, args)
    benign_data_loader = load_benign_data_loader(logger, args)
    malicious_data_loader = load_malicious_data_loader(logger, args)

    test_data_loader = load_test_data_loader(logger, args)

    # Distribute the training batches over the workers according to the
    # configured distribution method: "bias" (class-biased split), "iid"
    # (equal split), "noniid_1" / "noniid_2" (one or two classes per
    # worker), or "noniid_mal" (benign and malicious loaders distributed
    # separately); anything else falls back to the equal split.

    if args.get_distribution_method() == "bias":
        distributed_train_dataset = distribute_batches_bias(
            train_data_loader, args.get_num_workers())
    elif args.get_distribution_method() == "iid":
        distributed_train_dataset = distribute_batches_equally(
            train_data_loader, args.get_num_workers())
    elif args.get_distribution_method() == "noniid_1":
        distributed_train_dataset = distribute_batches_1_class(
            train_data_loader, args.get_num_workers(), args=args)
    elif args.get_distribution_method() == "noniid_2":
        distributed_train_dataset = distribute_batches_2_class(
            train_data_loader, args.get_num_workers(), args=args)
    elif args.get_distribution_method() == "noniid_mal":
        distributed_train_dataset = distribute_batches_noniid_mal(
            benign_data_loader,
            malicious_data_loader,
            args.get_num_workers(),
            args=args)
    else:
        distributed_train_dataset = distribute_batches_equally(
            train_data_loader, args.get_num_workers())

    distributed_train_dataset = convert_distributed_data_into_numpy(
        distributed_train_dataset)

    poisoned_workers = identify_random_elements(
        args.get_num_workers(), args.get_num_poisoned_workers())
    distributed_train_dataset = poison_data(logger, distributed_train_dataset,
                                            args.get_num_workers(),
                                            poisoned_workers,
                                            replacement_method,
                                            args.get_poison_effort())

    train_data_loaders = generate_data_loaders_from_distributed_dataset(
        distributed_train_dataset, args.get_batch_size())

    clients = create_clients(args, train_data_loaders, test_data_loader,
                             distributed_train_dataset)

    results, worker_selection = run_machine_learning(clients, args,
                                                     poisoned_workers)
    save_results(results, results_files[0])
    save_results(worker_selection, worker_selections_files[0])

    logger.remove(handler)
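

# Illustrative driver (a sketch, not from the original file): run_exp is
# typically invoked from an attack script with an experiment id, a label-
# replacement method, and a worker-selection strategy. replace_1_with_9
# and RandomSelectionStrategy are assumed names modeled on the project's
# label-flipping utilities.
if __name__ == '__main__':
    from federated_learning.utils import replace_1_with_9
    from federated_learning.worker_selection import RandomSelectionStrategy

    START_EXP_IDX = 3000  # experiment ids also index the output files
    NUM_EXP = 1
    NUM_POISONED_WORKERS = 5
    KWARGS = {"NUM_WORKERS_PER_ROUND": 5}

    for experiment_id in range(START_EXP_IDX, START_EXP_IDX + NUM_EXP):
        run_exp(replace_1_with_9, NUM_POISONED_WORKERS, KWARGS,
                RandomSelectionStrategy(), experiment_id)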

# Example 4

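# Module-level context assumed by this example (a sketch: the constant
# values are illustrative placeholders and the federated_learning.utils
# import paths are assumptions):
import os

import matplotlib.pyplot as plt
from loguru import logger

from federated_learning.arguments import Arguments
from federated_learning.utils import (get_model_files_for_epoch,
                                      get_model_files_for_suffix,
                                      load_models)

MODELS_PATH = "default_models"  # folder holding per-epoch model snapshots
EPOCHS = [1, 5, 10]             # epochs whose models are inspected
POISONED_WORKER_IDS = []        # ids of the poisoned workers
SAVE_NAME = "gradients.png"     # output figure path
SAVE_SIZE = (18, 14)            # figure size in inches
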
def plot_gradients_2d(gradients):
    fig = plt.figure()

    # Poisoned workers are drawn as large blue crosses, benign workers as
    # orange dots.
    for (worker_id, gradient) in gradients:
        if worker_id in POISONED_WORKER_IDS:
            plt.scatter(gradient[0],
                        gradient[1],
                        color="blue",
                        marker="x",
                        s=1000,
                        linewidth=5)
        else:
            plt.scatter(gradient[0], gradient[1], color="orange", s=180)

    fig.set_size_inches(SAVE_SIZE, forward=False)
    plt.grid(False)
    plt.margins(0, 0)
    plt.savefig(SAVE_NAME, bbox_inches='tight', pad_inches=0.1)


if __name__ == '__main__':
    args = Arguments(logger)
    args.log()

    model_files = sorted(os.listdir(MODELS_PATH))
    logger.debug("Number of models: {}", len(model_files))

    param_diff = []
    worker_ids = []

    for epoch in EPOCHS:
        start_model_files = get_model_files_for_epoch(model_files, epoch)
        start_model_file = get_model_files_for_suffix(
            start_model_files, args.get_epoch_save_start_suffix())[0]
        start_model_file = os.path.join(MODELS_PATH, start_model_file)
        start_model = load_models(args, [start_model_file])[0]