def load_data(args, dataset_name):
    """Load a partitioned federated dataset by name.

    Dispatches *dataset_name* to the matching ``load_partition_data_*``
    loader and overwrites ``args.client_num_in_total`` with the client count
    reported by the loader: for shallow NN or linear models we uniformly
    sample a fraction of clients each round (as the original FedAvg paper),
    so the loader's count is authoritative.

    Args:
        args: parsed arguments; reads ``batch_size``, ``dataset`` and
            ``data_dir``; mutates ``client_num_in_total``.
        dataset_name: one of ``"mnist"``, ``"femnist"``, ``"shakespeare"``,
            ``"fed_shakespeare"``, ``"fed_cifar100"``, ``"stackoverflow_lr"``,
            ``"stackoverflow_nwp"``.  Any other value falls back to the
            federated EMNIST loader (original behaviour, kept for backward
            compatibility).

    Returns:
        list: ``[train_data_num, test_data_num, train_data_global,
        test_data_global, train_data_local_num_dict, train_data_local_dict,
        test_data_local_dict, class_num]``.
    """
    # Table of loaders; zero-argument closures let each entry carry its own
    # argument list.  This replaces seven copy-pasted elif branches that
    # differed only in the loader call.
    loaders = {
        "mnist": lambda: load_partition_data_mnist(args.batch_size),
        "femnist": lambda: load_partition_data_federated_emnist(args.dataset, args.data_dir),
        "shakespeare": lambda: load_partition_data_shakespeare(args.batch_size),
        "fed_shakespeare": lambda: load_partition_data_federated_shakespeare(args.dataset, args.data_dir),
        "fed_cifar100": lambda: load_partition_data_federated_cifar100(args.dataset, args.data_dir),
        "stackoverflow_lr": lambda: load_partition_data_federated_stackoverflow_lr(args.dataset, args.data_dir),
        "stackoverflow_nwp": lambda: load_partition_data_federated_stackoverflow_nwp(args.dataset, args.data_dir),
    }
    # Every original branch (including the else) logged this same line once.
    logging.info("load_data. dataset_name = %s", dataset_name)
    # NOTE(review): unknown dataset names silently load federated EMNIST,
    # matching the original else-branch; consider raising ValueError instead.
    loader = loaders.get(
        dataset_name,
        lambda: load_partition_data_federated_emnist(args.dataset, args.data_dir),
    )
    (client_num, train_data_num, test_data_num, train_data_global, test_data_global,
     train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
     class_num) = loader()
    # The loader knows the true client count; make it visible to the caller.
    args.client_num_in_total = client_num
    return [train_data_num, test_data_num, train_data_global, test_data_global,
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
            class_num]
def load_data(args, dataset_name):
    """Load a partitioned dataset for federated training.

    ``"mnist"`` and ``"shakespeare"`` use dedicated federated loaders which
    also report the real client count; that count overwrites
    ``args.client_num_in_total`` because for shallow NN or linear models we
    uniformly sample a fraction of clients each round (as the original
    FedAvg paper).  Every other name is served by a CIFAR-style loader that
    partitions a centralized dataset across ``args.client_num_in_total``
    clients; unknown names default to the CIFAR-10 loader (original
    behaviour).

    Args:
        args: parsed arguments; reads ``batch_size``, ``dataset``,
            ``data_dir``, ``partition_method``, ``partition_alpha`` and
            ``client_num_in_total``; may mutate ``client_num_in_total``.
        dataset_name: dataset key selecting the loader.

    Returns:
        list: ``[train_data_num, test_data_num, train_data_global,
        test_data_global, train_data_local_num_dict, train_data_local_dict,
        test_data_local_dict, class_num]``.
    """
    if dataset_name == "mnist":
        logging.info("load_data. dataset_name = %s", dataset_name)
        (client_num, train_data_num, test_data_num, train_data_global, test_data_global,
         train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
         class_num) = load_partition_data_mnist(args.batch_size)
        args.client_num_in_total = client_num
    elif dataset_name == "shakespeare":
        logging.info("load_data. dataset_name = %s", dataset_name)
        (client_num, train_data_num, test_data_num, train_data_global, test_data_global,
         train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
         class_num) = load_partition_data_shakespeare(args.batch_size)
        args.client_num_in_total = client_num
    else:
        # CIFAR-style loaders share one signature; pick by name with a
        # CIFAR-10 fallback instead of the original elif ladder.  The
        # original did not log in this branch, so neither do we.
        cifar_loaders = {
            "cifar10": load_partition_data_cifar10,
            "cifar100": load_partition_data_cifar100,
            "cinic10": load_partition_data_cinic10,
        }
        data_loader = cifar_loaders.get(dataset_name, load_partition_data_cifar10)
        (train_data_num, test_data_num, train_data_global, test_data_global,
         train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
         class_num) = data_loader(args.dataset, args.data_dir, args.partition_method,
                                  args.partition_alpha, args.client_num_in_total,
                                  args.batch_size)
    return [train_data_num, test_data_num, train_data_global, test_data_global,
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
            class_num]
def load_mnist_data(batch_size):
    """Load the federated MNIST partition for the given batch size.

    Args:
        batch_size: mini-batch size forwarded to ``load_partition_data_mnist``.

    Returns:
        list: ``[train_data_num, test_data_num, train_data_global,
        test_data_global, train_data_local_num_dict, train_data_local_dict,
        test_data_local_dict, class_num]``.
    """
    logging.info("load_data. dataset_name = mnist")
    # The leading element of the loader's tuple is the client count; it is
    # not part of this function's result.  (Fix: the original stored it in a
    # dead local ``client_num_in_total`` that was never read.)
    (_client_num, train_data_num, test_data_num, train_data_global, test_data_global,
     train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
     class_num) = load_partition_data_mnist(batch_size)
    return [train_data_num, test_data_num, train_data_global, test_data_global,
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
            class_num]
def load_data(args):
    """Load the partitioned MNIST dataset selected by ``args.dataset``.

    Args:
        args: parsed arguments; reads ``dataset`` and ``batch_size`` and
            overwrites ``client_num_in_total`` with the loader's client count.

    Returns:
        list: ``[train_data_num, test_data_num, train_data_global,
        test_data_global, train_data_local_num_dict, train_data_local_dict,
        test_data_local_dict, class_num]``.

    Raises:
        ValueError: if ``args.dataset`` is not ``"mnist"``.  (Bug fix: the
        original fell through its single ``if`` and crashed at the return
        with ``UnboundLocalError`` on ``dataset``.)
    """
    if args.dataset != "mnist":
        raise ValueError("load_data: unsupported dataset %r" % (args.dataset,))
    logging.info("load_data. dataset_name = %s", args.dataset)
    (client_num, train_data_num, test_data_num, train_data_global, test_data_global,
     train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
     class_num) = load_partition_data_mnist(args.batch_size)
    # The loader reports the true client count.
    args.client_num_in_total = client_num
    return [
        train_data_num, test_data_num, train_data_global, test_data_global,
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
        class_num
    ]
def load_data(args, dataset_name):
    """Load a partitioned federated dataset, with centralized and
    full-batch post-processing.

    Dispatches *dataset_name* to the matching ``load_partition_data_*``
    loader (CIFAR-style loaders for the remaining names), then optionally:

    * **centralized** (``args.client_num_in_total == 1``): merges every
      client's local data onto client 0;
    * **full batch** (``args.batch_size <= 0``): loads with a temporary
      batch size of 128 and merges batches via ``combine_batches``,
      restoring the caller's original ``args.batch_size`` afterwards.

    Args:
        args: parsed arguments; reads ``batch_size``, ``dataset``,
            ``data_dir``, ``partition_method``, ``partition_alpha`` and
            ``client_num_in_total``; may mutate ``client_num_in_total``
            and (temporarily) ``batch_size``.
        dataset_name: dataset key; unknown names default to the CIFAR-10
            loader (original behaviour).

    Returns:
        list: ``[train_data_num, test_data_num, train_data_global,
        test_data_global, train_data_local_num_dict, train_data_local_dict,
        test_data_local_dict, class_num]``.
    """
    # Centralized training: a single client owns all the data.
    centralized = args.client_num_in_total == 1

    # Full-batch training is requested with a non-positive batch size; load
    # with a temporary batch size and merge batches afterwards.
    args_batch_size = args.batch_size
    full_batch = args.batch_size <= 0
    if full_batch:
        args.batch_size = 128  # temporary batch size

    # Federated loaders, deduplicating the original seven elif branches.
    # Each reports its own client count, which (for shallow NN or linear
    # models sampled per-round as in the original FedAvg paper) overrides
    # the caller's client_num_in_total.
    federated_loaders = {
        "mnist": lambda: load_partition_data_mnist(args.batch_size),
        "femnist": lambda: load_partition_data_federated_emnist(args.dataset, args.data_dir),
        "shakespeare": lambda: load_partition_data_shakespeare(args.batch_size),
        "fed_shakespeare": lambda: load_partition_data_federated_shakespeare(args.dataset, args.data_dir),
        "fed_cifar100": lambda: load_partition_data_federated_cifar100(args.dataset, args.data_dir),
        "stackoverflow_lr": lambda: load_partition_data_federated_stackoverflow_lr(args.dataset, args.data_dir),
        "stackoverflow_nwp": lambda: load_partition_data_federated_stackoverflow_nwp(args.dataset, args.data_dir),
    }

    if dataset_name in federated_loaders:
        logging.info("load_data. dataset_name = %s", dataset_name)
        (client_num, train_data_num, test_data_num, train_data_global, test_data_global,
         train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
         class_num) = federated_loaders[dataset_name]()
        args.client_num_in_total = client_num
    else:
        # CIFAR-style loaders; unknown names fall back to CIFAR-10.  The
        # original did not log in this branch, so neither do we.
        cifar_loaders = {
            "cifar10": load_partition_data_cifar10,
            "cifar100": load_partition_data_cifar100,
            "cinic10": load_partition_data_cinic10,
        }
        data_loader = cifar_loaders.get(dataset_name, load_partition_data_cifar10)
        (train_data_num, test_data_num, train_data_global, test_data_global,
         train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
         class_num) = data_loader(args.dataset, args.data_dir, args.partition_method,
                                  args.partition_alpha, args.client_num_in_total,
                                  args.batch_size)

    if centralized:
        # Concatenate every client's local data onto a single client 0,
        # preserving client-id order.
        train_data_local_num_dict = {0: sum(train_data_local_num_dict.values())}
        train_data_local_dict = {
            0: [batch for cid in sorted(train_data_local_dict.keys())
                for batch in train_data_local_dict[cid]]
        }
        test_data_local_dict = {
            0: [batch for cid in sorted(test_data_local_dict.keys())
                for batch in test_data_local_dict[cid]]
        }
        args.client_num_in_total = 1

    if full_batch:
        # Collapse the temporary mini-batches into one batch per loader and
        # restore the caller's requested (non-positive) batch size.
        train_data_global = combine_batches(train_data_global)
        test_data_global = combine_batches(test_data_global)
        train_data_local_dict = {cid: combine_batches(batches)
                                 for cid, batches in train_data_local_dict.items()}
        test_data_local_dict = {cid: combine_batches(batches)
                                for cid, batches in test_data_local_dict.items()}
        args.batch_size = args_batch_size

    return [train_data_num, test_data_num, train_data_global, test_data_global,
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
            class_num]
def load_data(args, dataset_name):
    """Load and partition the dataset named *dataset_name* for federated training.

    Dispatches on *dataset_name* to the matching ``load_partition_data_*``
    loader.  Loaders that report a client count overwrite
    ``args.client_num_in_total`` as a side effect; the gld23k/gld160k
    branches hard-code it (233 / 1262).  When ``args.client_num_in_total``
    is 1 on entry, all client data is merged onto client 0 (centralized
    training); when ``args.batch_size <= 0``, data is loaded with a
    temporary batch size of 128 and re-combined into full batches.

    Returns a list:
    ``[train_data_num, test_data_num, train_data_global, test_data_global,
    train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
    class_num]``.
    """
    # check if the centralized training is enabled (one client owns all data)
    centralized = True if args.client_num_in_total == 1 else False
    # check if the full-batch training is enabled
    args_batch_size = args.batch_size
    if args.batch_size <= 0:
        full_batch = True
        args.batch_size = 128  # temporary batch size
    else:
        full_batch = False
    if dataset_name in ["mnist", "fmnist", "emnist"]:
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_mnist(
                args.dataset, args.data_dir, args.partition_method,
                args.partition_alpha, args.client_num_in_total, args.batch_size
            )
        """ For shallow NN or linear models, we uniformly sample a fraction of clients each round (as the original FedAvg paper) """
        args.client_num_in_total = client_num
    elif dataset_name == "har":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_ucihar(
                args.dataset, args.data_dir, args.partition_method,
                args.partition_alpha, args.client_num_in_total, args.batch_size
            )
        """ For shallow NN or linear models, we uniformly sample a fraction of clients each round (as the original FedAvg paper) """
        args.client_num_in_total = client_num
    elif dataset_name == "har_subject":
        # UCI-HAR partitioned by subject rather than by the generic scheme.
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_ucihar_subject(
                args.dataset, args.data_dir, args.partition_method,
                args.partition_alpha, args.client_num_in_total, args.batch_size
            )
        """ For shallow NN or linear models, we uniformly sample a fraction of clients each round (as the original FedAvg paper) """
        args.client_num_in_total = client_num
    elif dataset_name == "chmnist":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_chmnist(
                args.dataset, args.data_dir, args.partition_method,
                args.partition_alpha, args.client_num_in_total, args.batch_size
            )
        args.client_num_in_total = client_num
    elif dataset_name == "adult":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_uciadult(
                args.dataset, args.data_dir, args.partition_method,
                args.partition_alpha, args.client_num_in_total, args.batch_size
            )
        args.client_num_in_total = client_num
    elif dataset_name in ["purchase100", "texas100"]:
        # Both membership-inference benchmark datasets share one loader.
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_purchase(
                args.dataset, args.data_dir, args.partition_method,
                args.partition_alpha, args.client_num_in_total, args.batch_size
            )
        args.client_num_in_total = client_num
    elif dataset_name == "femnist":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        # Unlike the other federated loaders here, this one caps the client
        # count via a keyword argument.
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_federated_emnist(
                args.dataset, args.data_dir,
                client_num_in_total = args.client_num_in_total,
            )
        args.client_num_in_total = client_num
    elif dataset_name == "shakespeare":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_shakespeare(args.batch_size)
        args.client_num_in_total = client_num
    elif dataset_name == "fed_shakespeare":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_federated_shakespeare(args.dataset, args.data_dir)
        args.client_num_in_total = client_num
    elif dataset_name == "fed_cifar100":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_federated_cifar100(args.dataset, args.data_dir)
        args.client_num_in_total = client_num
    elif dataset_name == "stackoverflow_lr":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_federated_stackoverflow_lr(args.dataset, args.data_dir)
        args.client_num_in_total = client_num
    elif dataset_name == "stackoverflow_nwp":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_federated_stackoverflow_nwp(args.dataset, args.data_dir)
        args.client_num_in_total = client_num
    elif dataset_name == "ILSVRC2012":
        # ImageNet loader returns no client count; args.client_num_in_total
        # is left as supplied by the caller.
        logging.info("load_data. dataset_name = %s" % dataset_name)
        train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_ImageNet(dataset=dataset_name,
                                                     data_dir=args.data_dir,
                                                     partition_method=None,
                                                     partition_alpha=None,
                                                     client_number=args.client_num_in_total,
                                                     batch_size=args.batch_size)
    elif dataset_name == "gld23k":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        # Google Landmarks 23k has a fixed federated split of 233 clients.
        args.client_num_in_total = 233
        fed_train_map_file = os.path.join(args.data_dir, 'mini_gld_train_split.csv')
        fed_test_map_file = os.path.join(args.data_dir, 'mini_gld_test.csv')
        train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_landmarks(dataset=dataset_name,
                                                      data_dir=args.data_dir,
                                                      fed_train_map_file=fed_train_map_file,
                                                      fed_test_map_file=fed_test_map_file,
                                                      partition_method=None,
                                                      partition_alpha=None,
                                                      client_number=args.client_num_in_total,
                                                      batch_size=args.batch_size)
    elif dataset_name == "gld160k":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        # Google Landmarks 160k has a fixed federated split of 1262 clients.
        args.client_num_in_total = 1262
        fed_train_map_file = os.path.join(args.data_dir, 'federated_train.csv')
        fed_test_map_file = os.path.join(args.data_dir, 'test.csv')
        train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_landmarks(dataset=dataset_name,
                                                      data_dir=args.data_dir,
                                                      fed_train_map_file=fed_train_map_file,
                                                      fed_test_map_file=fed_test_map_file,
                                                      partition_method=None,
                                                      partition_alpha=None,
                                                      client_number=args.client_num_in_total,
                                                      batch_size=args.batch_size)
    else:
        # CIFAR-style loaders; any unrecognised name falls back to CIFAR-10.
        if dataset_name == "cifar10":
            data_loader = load_partition_data_cifar10
        elif dataset_name == "cifar100":
            data_loader = load_partition_data_cifar100
        elif dataset_name == "cinic10":
            data_loader = load_partition_data_cinic10
        else:
            data_loader = load_partition_data_cifar10
        train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = data_loader(args.dataset, args.data_dir, args.partition_method,
                                    args.partition_alpha, args.client_num_in_total,
                                    args.batch_size)

    if centralized:
        # Merge all clients' local data onto client 0 (ordered by client id).
        train_data_local_num_dict = {
            0: sum(user_train_data_num for user_train_data_num in train_data_local_num_dict.values())
        }
        train_data_local_dict = {
            0: [
                batch for cid in sorted(train_data_local_dict.keys())
                for batch in train_data_local_dict[cid]
            ]
        }
        test_data_local_dict = {
            0: [
                batch for cid in sorted(test_data_local_dict.keys())
                for batch in test_data_local_dict[cid]
            ]
        }
        args.client_num_in_total = 1

    if full_batch:
        # Collapse the temporary mini-batches into single full batches and
        # restore the caller's original (non-positive) batch size.
        train_data_global = combine_batches(train_data_global)
        test_data_global = combine_batches(test_data_global)
        train_data_local_dict = {
            cid: combine_batches(train_data_local_dict[cid])
            for cid in train_data_local_dict.keys()
        }
        test_data_local_dict = {
            cid: combine_batches(test_data_local_dict[cid])
            for cid in test_data_local_dict.keys()
        }
        args.batch_size = args_batch_size

    dataset = [
        train_data_num, test_data_num, train_data_global, test_data_global,
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
        class_num
    ]
    return dataset
def load_data(args, dataset_name):
    """Load and partition the dataset named *dataset_name*.

    Dispatches on *dataset_name* to the matching ``load_partition_data_*``
    loader.  Federated loaders report a client count which overwrites
    ``args.client_num_in_total``; ImageNet and the CIFAR-style fallback use
    the caller-supplied count; gld23k/gld160k hard-code it (233 / 1262).

    Returns a list:
    ``[train_data_num, test_data_num, train_data_global, test_data_global,
    train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
    class_num]``.
    """
    if dataset_name == "mnist":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_mnist(args.batch_size)
        """ For shallow NN or linear models, we uniformly sample a fraction of clients each round (as the original FedAvg paper) """
        args.client_num_in_total = client_num
    elif dataset_name == "femnist":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_federated_emnist(args.dataset, args.data_dir)
        args.client_num_in_total = client_num
    elif dataset_name == "shakespeare":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_shakespeare(args.batch_size)
        args.client_num_in_total = client_num
    elif dataset_name == "fed_shakespeare":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_federated_shakespeare(args.dataset, args.data_dir)
        args.client_num_in_total = client_num
    elif dataset_name == "fed_cifar100":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_federated_cifar100(args.dataset, args.data_dir)
        args.client_num_in_total = client_num
    elif dataset_name == "stackoverflow_lr":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_federated_stackoverflow_lr(args.dataset, args.data_dir)
        args.client_num_in_total = client_num
    elif dataset_name == "stackoverflow_nwp":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_federated_stackoverflow_nwp(args.dataset, args.data_dir)
        args.client_num_in_total = client_num
    elif dataset_name == "ILSVRC2012":
        # ImageNet loader returns no client count (8-tuple, not 9).
        logging.info("load_data. dataset_name = %s" % dataset_name)
        train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_ImageNet(dataset=dataset_name,
                                                     data_dir=args.data_dir,
                                                     partition_method=None,
                                                     partition_alpha=None,
                                                     client_number=args.client_num_in_total,
                                                     batch_size=args.batch_size)
    elif dataset_name == "gld23k":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        # Fixed federated split of 233 clients for Google Landmarks 23k.
        args.client_num_in_total = 233
        fed_train_map_file = os.path.join(args.data_dir, 'mini_gld_train_split.csv')
        fed_test_map_file = os.path.join(args.data_dir, 'mini_gld_test.csv')
        # NOTE(review): this mutates args.data_dir in place, so calling
        # load_data twice with the same args would resolve a wrong path —
        # consider a local variable instead; verify against callers.
        args.data_dir = os.path.join(args.data_dir, 'images')
        train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_landmarks(dataset=dataset_name,
                                                      data_dir=args.data_dir,
                                                      fed_train_map_file=fed_train_map_file,
                                                      fed_test_map_file=fed_test_map_file,
                                                      partition_method=None,
                                                      partition_alpha=None,
                                                      client_number=args.client_num_in_total,
                                                      batch_size=args.batch_size)
    elif dataset_name == "gld160k":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        # Fixed federated split of 1262 clients for Google Landmarks 160k.
        args.client_num_in_total = 1262
        fed_train_map_file = os.path.join(args.data_dir, 'federated_train.csv')
        fed_test_map_file = os.path.join(args.data_dir, 'test.csv')
        # NOTE(review): same in-place args.data_dir mutation as gld23k above.
        args.data_dir = os.path.join(args.data_dir, 'images')
        train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_landmarks(dataset=dataset_name,
                                                      data_dir=args.data_dir,
                                                      fed_train_map_file=fed_train_map_file,
                                                      fed_test_map_file=fed_test_map_file,
                                                      partition_method=None,
                                                      partition_alpha=None,
                                                      client_number=args.client_num_in_total,
                                                      batch_size=args.batch_size)
    else:
        # CIFAR-style loaders; any unrecognised name falls back to CIFAR-10.
        if dataset_name == "cifar10":
            data_loader = load_partition_data_cifar10
        elif dataset_name == "cifar100":
            data_loader = load_partition_data_cifar100
        elif dataset_name == "cinic10":
            data_loader = load_partition_data_cinic10
        else:
            data_loader = load_partition_data_cifar10
        train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = data_loader(args.dataset, args.data_dir, args.partition_method,
                                    args.partition_alpha, args.client_num_in_total,
                                    args.batch_size)
    dataset = [
        train_data_num, test_data_num, train_data_global, test_data_global,
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
        class_num
    ]
    return dataset
def load_data(args, dataset_name):
    """Load and partition the dataset named *dataset_name*.

    Dispatches on *dataset_name* to the matching loader.  Federated loaders
    report a client count which overwrites ``args.client_num_in_total``;
    gld23k/gld160k hard-code it (233 / 1262); the ImageNet branch supports
    a distributed-centralized loader when ``args.data_parallel == 1``.

    Returns a list:
    ``[train_data_num, test_data_num, train_data_global, test_data_global,
    train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
    class_num]``.
    """
    if dataset_name in ["mnist", "fmnist", "emnist"]:
        logging.info("load_data. dataset_name = %s" % dataset_name)
        # This variant's MNIST-family loader also takes a training-data
        # ratio (presumably the train/validation split fraction — confirm
        # against the loader's definition).
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_mnist(
                args.dataset, args.data_dir, args.partition_method,
                args.partition_alpha, args.client_num_in_total, args.batch_size,
                training_data_ratio=args.training_data_ratio,
            )
        """ For shallow NN or linear models, we uniformly sample a fraction of clients each round (as the original FedAvg paper) """
        args.client_num_in_total = client_num
    elif dataset_name == "har":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_ucihar(
                args.dataset, args.data_dir, args.partition_method,
                args.partition_alpha, args.client_num_in_total, args.batch_size
            )
        """ For shallow NN or linear models, we uniformly sample a fraction of clients each round (as the original FedAvg paper) """
        args.client_num_in_total = client_num
    elif dataset_name == "chmnist":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_chmnist(
                args.dataset, args.data_dir, args.partition_method,
                args.partition_alpha, args.client_num_in_total, args.batch_size
            )
        args.client_num_in_total = client_num
    elif dataset_name == "adult":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_uciadult(
                args.dataset, args.data_dir, args.partition_method,
                args.partition_alpha, args.client_num_in_total, args.batch_size
            )
        args.client_num_in_total = client_num
    elif dataset_name in ["purchase100", "texas100"]:
        # Both membership-inference benchmark datasets share one loader.
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_purchase(
                args.dataset, args.data_dir, args.partition_method,
                args.partition_alpha, args.client_num_in_total, args.batch_size
            )
        args.client_num_in_total = client_num
    elif dataset_name == "femnist":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_federated_emnist(args.dataset, args.data_dir)
        args.client_num_in_total = client_num
    elif dataset_name == "shakespeare":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_shakespeare(args.batch_size)
        args.client_num_in_total = client_num
    elif dataset_name == "fed_shakespeare":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_federated_shakespeare(args.dataset, args.data_dir)
        args.client_num_in_total = client_num
    elif dataset_name == "fed_cifar100":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_federated_cifar100(args.dataset, args.data_dir)
        args.client_num_in_total = client_num
    elif dataset_name == "stackoverflow_lr":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_federated_stackoverflow_lr(args.dataset, args.data_dir)
        args.client_num_in_total = client_num
    elif dataset_name == "stackoverflow_nwp":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_federated_stackoverflow_nwp(args.dataset, args.data_dir)
        args.client_num_in_total = client_num
    elif dataset_name in ["ILSVRC2012", "ILSVRC2012_hdf5"]:
        # ImageNet: either a distributed centralized loader (data-parallel
        # training across world_size ranks) or the partitioned loader.
        # Both return 8-tuples (no client count).
        if args.data_parallel == 1:
            logging.info("load_data. dataset_name = %s" % dataset_name)
            train_data_num, test_data_num, train_data_global, test_data_global, \
                train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
                class_num = distributed_centralized_ImageNet_loader(dataset=dataset_name,
                                                                    data_dir=args.data_dir,
                                                                    world_size=args.world_size,
                                                                    rank=args.rank,
                                                                    batch_size=args.batch_size)
        else:
            logging.info("load_data. dataset_name = %s" % dataset_name)
            train_data_num, test_data_num, train_data_global, test_data_global, \
                train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
                class_num = load_partition_data_ImageNet(dataset=dataset_name,
                                                         data_dir=args.data_dir,
                                                         partition_method=None,
                                                         partition_alpha=None,
                                                         client_number=args.client_num_in_total,
                                                         batch_size=args.batch_size)
    elif dataset_name == "gld23k":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        # Fixed federated split of 233 clients for Google Landmarks 23k.
        args.client_num_in_total = 233
        fed_train_map_file = os.path.join(
            args.data_dir, 'data_user_dict/gld23k_user_dict_train.csv')
        fed_test_map_file = os.path.join(
            args.data_dir, 'data_user_dict/gld23k_user_dict_test.csv')
        # NOTE(review): mutates args.data_dir in place; a second call with
        # the same args would resolve a wrong path — verify against callers.
        args.data_dir = os.path.join(args.data_dir, 'images')
        train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_landmarks(dataset=dataset_name,
                                                      data_dir=args.data_dir,
                                                      fed_train_map_file=fed_train_map_file,
                                                      fed_test_map_file=fed_test_map_file,
                                                      partition_method=None,
                                                      partition_alpha=None,
                                                      client_number=args.client_num_in_total,
                                                      batch_size=args.batch_size)
    elif dataset_name == "gld160k":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        # Fixed federated split of 1262 clients for Google Landmarks 160k.
        args.client_num_in_total = 1262
        fed_train_map_file = os.path.join(
            args.data_dir, 'data_user_dict/gld160k_user_dict_train.csv')
        fed_test_map_file = os.path.join(
            args.data_dir, 'data_user_dict/gld160k_user_dict_test.csv')
        # NOTE(review): same in-place args.data_dir mutation as gld23k above.
        args.data_dir = os.path.join(args.data_dir, 'images')
        train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = load_partition_data_landmarks(dataset=dataset_name,
                                                      data_dir=args.data_dir,
                                                      fed_train_map_file=fed_train_map_file,
                                                      fed_test_map_file=fed_test_map_file,
                                                      partition_method=None,
                                                      partition_alpha=None,
                                                      client_number=args.client_num_in_total,
                                                      batch_size=args.batch_size)
    else:
        # CIFAR-style loaders; any unrecognised name falls back to CIFAR-10.
        if dataset_name == "cifar10":
            data_loader = load_partition_data_cifar10
        elif dataset_name == "cifar100":
            data_loader = load_partition_data_cifar100
        elif dataset_name == "cinic10":
            data_loader = load_partition_data_cinic10
        else:
            data_loader = load_partition_data_cifar10
        train_data_num, test_data_num, train_data_global, test_data_global, \
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
            class_num = data_loader(args.dataset, args.data_dir, args.partition_method,
                                    args.partition_alpha, args.client_num_in_total,
                                    args.batch_size)
    dataset = [
        train_data_num, test_data_num, train_data_global, test_data_global,
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
        class_num
    ]
    return dataset
def load_data(args, dataset_name):
    """Load and partition a federated dataset selected by *dataset_name*.

    Dispatches to the matching ``load_partition_data_*`` loader, then
    optionally post-processes the result for centralized or full-batch
    training.

    Args:
        args: parsed arguments; reads ``batch_size``, ``data_dir``,
            ``partition_method``, ``partition_alpha`` and
            ``client_num_in_total``. NOTE: ``args`` is mutated in place —
            ``client_num_in_total`` and ``batch_size`` may be overwritten
            depending on the dataset and mode.
        dataset_name: dataset key; unrecognized names fall through to the
            cifar10 loader.

    Returns:
        list ``[train_data_num, test_data_num, train_data_global,
        test_data_global, train_data_local_num_dict, train_data_local_dict,
        test_data_local_dict, class_num, feature_name]`` — ``feature_name``
        is only populated by the cervical_cancer loader, otherwise ``None``.
    """
    # check if the centralized training is enabled
    # (if only one client participates in total, treat it as centralized training)
    centralized = True if args.client_num_in_total == 1 else False

    # check if the full-batch training is enabled.
    # NOTE(review): a non-positive batch_size is the sentinel for full-batch
    # mode ("why <= 0?" per the original author); the real value is restored
    # at the end of the full_batch post-processing below.
    args_batch_size = args.batch_size
    if args.batch_size <= 0:
        full_batch = True
        args.batch_size = 128  # temporary batch size used while loading
    else:
        full_batch = False

    # assign an initial value; only the cervical_cancer branch overwrites it
    feature_name = None

    if dataset_name == "mnist":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num = load_partition_data_mnist(args.batch_size)
        """
        For shallow NN or linear models, we uniformly sample a fraction of clients each round (as the original FedAvg paper)
        """
        args.client_num_in_total = client_num
    elif dataset_name == "mnist_test":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        # Partition with the dataset's built-in dataloader:
        # client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
        # train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        # class_num = load_partition_data_mnist(args.batch_size)
        # args.client_num_in_total = client_num
        # More flexible, user-specified data partitioning:
        train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num = load_partition_data_mnist_asign(args.dataset, args.data_dir, args.partition_method,
                                                    args.partition_alpha, args.client_num_in_total,
                                                    args.batch_size)
    elif dataset_name == "cervical_cancer":
        # The only loader that also returns feature names for the tabular data.
        train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num, feature_name = load_partition_data_cervical(args.dataset, args.data_dir,
                                                               args.partition_method, args.partition_alpha,
                                                               args.client_num_in_total, args.batch_size)
    elif dataset_name == "femnist":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num = load_partition_data_federated_emnist(args.dataset, args.data_dir)
        args.client_num_in_total = client_num
    elif dataset_name == "shakespeare":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num = load_partition_data_shakespeare(args.batch_size)
        args.client_num_in_total = client_num
    elif dataset_name == "fed_shakespeare":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num = load_partition_data_federated_shakespeare(args.dataset, args.data_dir)
        args.client_num_in_total = client_num
    elif dataset_name == "fed_cifar100":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num = load_partition_data_federated_cifar100(args.dataset, args.data_dir)
        args.client_num_in_total = client_num
    elif dataset_name == "stackoverflow_lr":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num = load_partition_data_federated_stackoverflow_lr(args.dataset, args.data_dir)
        args.client_num_in_total = client_num
    elif dataset_name == "stackoverflow_nwp":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num = load_partition_data_federated_stackoverflow_nwp(args.dataset, args.data_dir)
        args.client_num_in_total = client_num
    elif dataset_name == "ILSVRC2012":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num = load_partition_data_ImageNet(dataset=dataset_name, data_dir=args.data_dir,
                                                 partition_method=None, partition_alpha=None,
                                                 client_number=args.client_num_in_total,
                                                 batch_size=args.batch_size)
    elif dataset_name == "gld23k":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        # gld23k ships with a fixed 233-client federated split
        args.client_num_in_total = 233
        fed_train_map_file = os.path.join(args.data_dir, 'mini_gld_train_split.csv')
        fed_test_map_file = os.path.join(args.data_dir, 'mini_gld_test.csv')
        train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num = load_partition_data_landmarks(dataset=dataset_name, data_dir=args.data_dir,
                                                  fed_train_map_file=fed_train_map_file,
                                                  fed_test_map_file=fed_test_map_file,
                                                  partition_method=None, partition_alpha=None,
                                                  client_number=args.client_num_in_total,
                                                  batch_size=args.batch_size)
    elif dataset_name == "gld160k":
        logging.info("load_data. dataset_name = %s" % dataset_name)
        # gld160k ships with a fixed 1262-client federated split
        args.client_num_in_total = 1262
        fed_train_map_file = os.path.join(args.data_dir, 'federated_train.csv')
        fed_test_map_file = os.path.join(args.data_dir, 'test.csv')
        train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num = load_partition_data_landmarks(dataset=dataset_name, data_dir=args.data_dir,
                                                  fed_train_map_file=fed_train_map_file,
                                                  fed_test_map_file=fed_test_map_file,
                                                  partition_method=None, partition_alpha=None,
                                                  client_number=args.client_num_in_total,
                                                  batch_size=args.batch_size)
    else:
        # cifar-family datasets share one loader signature;
        # any unrecognized name silently falls back to cifar10.
        if dataset_name == "cifar10":
            data_loader = load_partition_data_cifar10
        elif dataset_name == "cifar100":
            data_loader = load_partition_data_cifar100
        elif dataset_name == "cinic10":
            data_loader = load_partition_data_cinic10
        else:
            data_loader = load_partition_data_cifar10
        train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num = data_loader(args.dataset, args.data_dir, args.partition_method,
                                args.partition_alpha, args.client_num_in_total, args.batch_size)

    # For centralized training (only one participating device), remap every
    # client's local data onto a single client id 0.
    if centralized:
        train_data_local_num_dict = {
            0: sum(user_train_data_num for user_train_data_num in train_data_local_num_dict.values())
        }
        train_data_local_dict = {
            0: [batch for cid in sorted(train_data_local_dict.keys()) for batch in train_data_local_dict[cid]]
        }
        test_data_local_dict = {
            0: [batch for cid in sorted(test_data_local_dict.keys()) for batch in test_data_local_dict[cid]]
        }
        args.client_num_in_total = 1

    # For full-batch training, merge each loader's batches into one batch and
    # restore the caller's original (non-positive) batch_size sentinel.
    if full_batch:
        train_data_global = combine_batches(train_data_global)
        test_data_global = combine_batches(test_data_global)
        train_data_local_dict = {
            cid: combine_batches(train_data_local_dict[cid]) for cid in train_data_local_dict.keys()
        }
        test_data_local_dict = {
            cid: combine_batches(test_data_local_dict[cid]) for cid in test_data_local_dict.keys()
        }
        args.batch_size = args_batch_size

    dataset = [
        train_data_num, test_data_num, train_data_global, test_data_global,
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
        class_num, feature_name
    ]
    return dataset
def load_data(args, dataset_name):
    """Load and partition a federated dataset selected by *dataset_name*.

    Dispatches to the matching ``load_partition_data_*`` loader. When
    ``args.batch_size <= 0`` the function runs in full-batch mode: it loads
    with a temporary batch size, then merges each loader's batches into a
    single batch and restores the original ``batch_size``.

    Args:
        args: parsed arguments; reads ``batch_size``, ``data_dir``,
            ``dataset``, ``partition_method``, ``partition_alpha`` and
            ``client_num_in_total``. ``args.batch_size`` may be temporarily
            overwritten (restored in full-batch mode).
        dataset_name: dataset key; unrecognized names fall through to the
            cifar10 loader.

    Returns:
        list ``[client_num, train_data_num, test_data_num,
        train_data_global, test_data_global, train_data_local_num_dict,
        train_data_local_dict, test_data_local_dict, class_num]``.
    """
    # check if the full-batch training is enabled; a non-positive batch_size
    # is the sentinel for full-batch mode
    args_batch_size = args.batch_size
    if args.batch_size <= 0:
        full_batch = True
        args.batch_size = 128  # temporary batch size used while loading
    else:
        full_batch = False

    logger.info("-------------dataset loading------------")
    if dataset_name == "mnist":
        logger.debug("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num = load_partition_data_mnist(args.batch_size)
        """
        For shallow NN or linear models, we uniformly sample a fraction of clients each round (as the original FedAvg paper)
        """
    elif dataset_name == "femnist":
        logger.debug("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num = load_partition_data_federated_emnist(args.dataset, args.data_dir)
    elif dataset_name == "shakespeare":
        logger.debug("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num = load_partition_data_shakespeare(args.batch_size)
    elif dataset_name == "fed_shakespeare":
        logger.debug("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num = load_partition_data_federated_shakespeare(args.dataset, args.data_dir)
    elif dataset_name == "fed_cifar100":
        logger.debug("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num = load_partition_data_federated_cifar100(args.dataset, args.data_dir)
    elif dataset_name == "stackoverflow_lr":
        logger.debug("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num = load_partition_data_federated_stackoverflow_lr(args.dataset, args.data_dir)
    elif dataset_name == "stackoverflow_nwp":
        logger.debug("load_data. dataset_name = %s" % dataset_name)
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num = load_partition_data_federated_stackoverflow_nwp(args.dataset, args.data_dir)
    else:
        # cifar-family datasets share one loader signature;
        # any unrecognized name silently falls back to cifar10.
        if dataset_name == "cifar10":
            data_loader = load_partition_data_cifar10
        elif dataset_name == "cifar100":
            data_loader = load_partition_data_cifar100
        elif dataset_name == "cinic10":
            data_loader = load_partition_data_cinic10
        else:
            data_loader = load_partition_data_cifar10
        # BUG FIX: the original passed the bare name `client_num_in_total`,
        # which is undefined in this scope and raised NameError for every
        # dataset handled by this branch; every other call site in the file
        # reads it from `args`.
        client_num, train_data_num, test_data_num, train_data_global, test_data_global, \
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict, \
        class_num = data_loader(args.dataset, args.data_dir, args.partition_method,
                                args.partition_alpha, args.client_num_in_total, args.batch_size)

    # For full-batch training, merge each loader's batches into one batch and
    # restore the caller's original (non-positive) batch_size sentinel.
    if full_batch:
        train_data_global = combine_batches(train_data_global)
        test_data_global = combine_batches(test_data_global)
        train_data_local_dict = {
            cid: combine_batches(train_data_local_dict[cid]) for cid in train_data_local_dict.keys()
        }
        test_data_local_dict = {
            cid: combine_batches(test_data_local_dict[cid]) for cid in test_data_local_dict.keys()
        }
        args.batch_size = args_batch_size

    dataset = [
        client_num, train_data_num, test_data_num, train_data_global,
        test_data_global, train_data_local_num_dict, train_data_local_dict,
        test_data_local_dict, class_num
    ]
    return dataset