from federated_learning.arguments import Arguments
from federated_learning.nets import Cifar10CNN
from federated_learning.nets import Cifar100ResNet
from federated_learning.nets import FashionMNISTCNN
from federated_learning.nets import FashionMNISTResNet
from federated_learning.nets import Cifar10ResNet
from federated_learning.nets import Cifar100VGG, STL10VGG, MNISTCNN
# from federated_learning.nets import TRECCNN
import os
import torch
from loguru import logger


def _save_default_model(model_folder, file_name, model):
    """Save ``model.state_dict()`` to ``os.path.join(model_folder, file_name)``."""
    torch.save(model.state_dict(), os.path.join(model_folder, file_name))


if __name__ == '__main__':
    args = Arguments(logger)
    model_folder = args.get_default_model_folder_path()

    # makedirs(exist_ok=True) replaces the exists()/mkdir() pair. It also
    # creates missing parent directories, and it cannot fail if another process
    # creates the folder between the check and the mkdir call.
    os.makedirs(model_folder, exist_ok=True)

    # ---------------------------------
    # ----------- Cifar10CNN ----------
    # ---------------------------------
    _save_default_model(model_folder, "Cifar10CNN.model", Cifar10CNN())

    # ---------------------------------
    # --------- Cifar10ResNet ---------
    # ---------------------------------
    _save_default_model(model_folder, "Cifar10ResNet.model", Cifar10ResNet())
from loguru import logger
import pathlib
import os
from federated_learning.arguments import Arguments
from federated_learning.datasets import CIFAR10Dataset
from federated_learning.datasets import FashionMNISTDataset
from federated_learning.utils import generate_train_loader
from federated_learning.utils import generate_test_loader
from federated_learning.utils import save_data_loader_to_file


def _pickle_loader(loader, destination):
    """Open ``destination`` for binary writing and hand it to ``save_data_loader_to_file``."""
    with open(destination, "wb") as handle:
        save_data_loader_to_file(loader, handle)


if __name__ == '__main__':
    args = Arguments(logger)

    # ---------------------------------
    # ------------ CIFAR10 ------------
    # ---------------------------------
    dataset = CIFAR10Dataset(args)
    TRAIN_DATA_LOADER_FILE_PATH = "data_loaders/cifar10/train_data_loader.pickle"
    TEST_DATA_LOADER_FILE_PATH = "data_loaders/cifar10/test_data_loader.pickle"

    output_dir = "data_loaders/cifar10"
    if not os.path.exists(output_dir):
        pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Both loaders are built (train first, then test) before either is
    # written to disk. The tuple is evaluated left to right.
    loaders_and_paths = (
        (generate_train_loader(args, dataset), TRAIN_DATA_LOADER_FILE_PATH),
        (generate_test_loader(args, dataset), TEST_DATA_LOADER_FILE_PATH),
    )
    for loader, destination in loaders_and_paths:
        _pickle_loader(loader, destination)
def _distribute_train_dataset(args, train_data_loader, benign_data_loader, malicious_data_loader):
    """Split training data across workers according to ``args.get_distribution_method()``.

    "iid" and any unrecognised method fall back to ``distribute_batches_equally``.
    An unrecognised method also logs a warning, so a typo in the config is visible.
    """
    method = args.get_distribution_method()
    num_workers = args.get_num_workers()

    if method == "bias":
        return distribute_batches_bias(train_data_loader, num_workers)
    if method == "noniid_1":
        return distribute_batches_1_class(train_data_loader, num_workers, args=args)
    if method == "noniid_2":
        return distribute_batches_2_class(train_data_loader, num_workers, args=args)
    if method == "noniid_mal":
        # This is the only strategy that uses the separate benign/malicious loaders.
        return distribute_batches_noniid_mal(
            benign_data_loader, malicious_data_loader, num_workers, args=args)
    if method != "iid":
        logger.warning("Unknown distribution method {}; distributing batches equally", method)
    return distribute_batches_equally(train_data_loader, num_workers)


def run_exp(replacement_method, num_poisoned_workers, KWARGS, client_selection_strategy, idx):
    """Run one poisoning experiment end to end and persist its outputs.

    The steps are:
    - generate the experiment's log/result/model/worker-selection paths;
    - attach a dedicated loguru sink and configure ``Arguments``;
    - distribute and poison the training data;
    - run ``run_machine_learning`` and save its results.

    Args:
        replacement_method: Forwarded to ``poison_data``. Presumably the label
            replacement strategy applied by poisoned workers (TODO confirm).
        num_poisoned_workers: Number of workers selected for poisoning.
        KWARGS: Keyword arguments for the round worker-selection strategy.
        client_selection_strategy: Passed to ``args.set_client_selection_strategy``.
        idx: Experiment index forwarded to ``generate_experiment_ids``.

    Returns:
        None. Results go to ``results_files[0]`` and the worker selections to
        ``worker_selections_files[0]``.
    """
    log_files, results_files, models_folders, worker_selections_files = generate_experiment_ids(idx, 1)

    # Initialize the per-experiment logger sink.
    handler = logger.add(log_files[0], enqueue=True)
    try:
        args = Arguments(logger)
        args.set_model_save_path(models_folders[0])
        args.set_num_poisoned_workers(num_poisoned_workers)
        args.set_round_worker_selection_strategy_kwargs(KWARGS)
        args.set_client_selection_strategy(client_selection_strategy)
        args.log()

        train_data_loader = load_train_data_loader(logger, args)
        benign_data_loader = load_benign_data_loader(logger, args)
        malicious_data_loader = load_malicious_data_loader(logger, args)
        test_data_loader = load_test_data_loader(logger, args)

        distributed_train_dataset = _distribute_train_dataset(
            args, train_data_loader, benign_data_loader, malicious_data_loader)
        distributed_train_dataset = convert_distributed_data_into_numpy(distributed_train_dataset)

        poisoned_workers = identify_random_elements(
            args.get_num_workers(), args.get_num_poisoned_workers())
        # get_poison_effort must be called. Passing `args.get_poison_effort`
        # without parentheses hands poison_data the bound method instead of
        # the configured value.
        distributed_train_dataset = poison_data(
            logger,
            distributed_train_dataset,
            args.get_num_workers(),
            poisoned_workers,
            replacement_method,
            args.get_poison_effort(),
        )

        train_data_loaders = generate_data_loaders_from_distributed_dataset(
            distributed_train_dataset, args.get_batch_size())

        clients = create_clients(args, train_data_loaders, test_data_loader, distributed_train_dataset)

        results, worker_selection = run_machine_learning(clients, args, poisoned_workers)
        save_results(results, results_files[0])
        save_results(worker_selection, worker_selections_files[0])
    finally:
        # Detach the sink even if the run fails. Otherwise later experiments in
        # the same process would keep logging into this experiment's file.
        logger.remove(handler)
gradient[1], color="blue", marker="x", s=1000, linewidth=5) else: plt.scatter(gradient[0], gradient[1], color="orange", s=180) fig.set_size_inches(SAVE_SIZE, forward=False) plt.grid(False) plt.margins(0, 0) plt.savefig(SAVE_NAME, bbox_inches='tight', pad_inches=0.1) if __name__ == '__main__': args = Arguments(logger) args.log() model_files = sorted(os.listdir(MODELS_PATH)) logger.debug("Number of models: {}", str(len(model_files))) param_diff = [] worker_ids = [] for epoch in EPOCHS: start_model_files = get_model_files_for_epoch(model_files, epoch) start_model_file = get_model_files_for_suffix( start_model_files, args.get_epoch_save_start_suffix())[0] start_model_file = os.path.join(MODELS_PATH, start_model_file) start_model = load_models(args, [start_model_file])[0]