Code example #1
File: utils.py  Project: meryemJanatiIdrissi/FedBS
def get_dataset(args):
    """ Returns train and test datasets and a user group which is a dict where
    the keys are the user index and the values are the corresponding data for
    each of those users.
    """

    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                       transform=apply_transform)

        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                      transform=apply_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from CIFAR-10
            user_groups = cifar_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from CIFAR-10
            if args.unequal:
                # Choose unequal splits for every user
                user_groups = cifar_noniid_unequal(train_dataset, args.num_users)

            elif args.unequalZipf:
                user_groups = cifar_noniid_unequal_zipf(train_dataset, args.num_users)

            else:
                # Choose equal splits for every user
                user_groups = cifar_noniid(train_dataset, args.num_users)

    elif args.dataset == 'mnist' or args.dataset == 'fmnist':
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
        else:
            data_dir = '../data/fmnist/'

        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        train_dataset = datasets.MNIST(data_dir, train=True, download=True,
                                       transform=apply_transform)

        test_dataset = datasets.MNIST(data_dir, train=False, download=True,
                                      transform=apply_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from Mnist
            user_groups = mnist_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from Mnist
            if args.unequal:
                # Choose unequal splits for every user
                user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
            else:
                # Choose equal splits for every user
                user_groups = mnist_noniid(train_dataset, args.num_users)

    return train_dataset, test_dataset, user_groups
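The sampling helpers called above (cifar_iid, cifar_noniid, mnist_iid, and so on) are defined elsewhere in the project; they are assumed to return a dict mapping each user index to a set of training-sample indices. A minimal sketch of the IID case under that assumption (illustrative only, not the project's actual implementation):

import numpy as np

def iid_split(dataset, num_users):
    # Hypothetical helper: give each user an equal random share of the indices.
    num_items = len(dataset) // num_users
    all_idxs = list(range(len(dataset)))
    user_groups = {}
    for i in range(num_users):
        user_groups[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - user_groups[i])
    return user_groups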
Code example #2
def get_dataset(args):
    """ Returns train and test datasets and a user group which is a dict where
    the keys are the user index and the values are the corresponding data for
    each of those users.
    """

    if args.dataset == 'cifar10':
        data_dir = '../data/cifar10/'
        apply_transform_train = transforms.Compose([
            transforms.RandomCrop(24),
            transforms.RandomHorizontalFlip(0.5),
            transforms.ColorJitter(brightness=0.2,
                                   contrast=0.2,
                                   saturation=0.2),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2470, 0.2435, 0.2616))
        ])

        apply_transform_test = transforms.Compose([
            transforms.CenterCrop(24),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2470, 0.2435, 0.2616))
        ])

        train_dataset = datasets.CIFAR10(data_dir,
                                         train=True,
                                         download=True,
                                         transform=apply_transform_train)

        test_dataset = datasets.CIFAR10(data_dir,
                                        train=False,
                                        download=True,
                                        transform=apply_transform_test)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from CIFAR-10
            user_groups = cifar_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from CIFAR-10
            if args.hard:
                # Hard non-IID splits are not implemented yet
                raise NotImplementedError()
            else:
                # Choose equal splits for every user
                user_groups = cifar_noniid(train_dataset, args.num_users)

    elif args.dataset == 'mnist' or args.dataset == 'fmnist':
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
        else:
            data_dir = '../data/fmnist/'
        # TODO:1 Accommodate the FMNIST case: these are the MNIST mean/std, and Fashion-MNIST may need different values.
        # Shall we use the params from opt instead of hard-coding them?
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        train_dataset = datasets.MNIST(data_dir,
                                       train=True,
                                       download=True,
                                       transform=apply_transform)

        test_dataset = datasets.MNIST(data_dir,
                                      train=False,
                                      download=True,
                                      transform=apply_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from Mnist
            user_groups = mnist_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from Mnist
            if args.unequal:
                # Choose unequal splits for every user
                user_groups = mnist_noniid_unequal(train_dataset,
                                                   args.num_users)
            else:
                # Choose equal splits for every user
                user_groups = mnist_noniid(train_dataset, args.num_users)

    elif args.dataset == 'cub200':
        data_dir = '../data/cub200/'
        apply_transform_train = transforms.Compose([
            transforms.Resize(int(cf.imresize[args.net_type])),
            transforms.RandomRotation(10),
            transforms.RandomCrop(cf.imsize[args.net_type]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(cf.mean[args.dataset], cf.std[args.dataset]),
        ])

        apply_transform_test = transforms.Compose([
            transforms.Resize(cf.imresize[args.net_type]),
            transforms.CenterCrop(cf.imsize[args.net_type]),
            transforms.ToTensor(),
            transforms.Normalize(cf.mean[args.dataset], cf.std[args.dataset]),
        ])
        train_dataset = cub.CUB200(data_dir,
                                   year=2011,
                                   train=True,
                                   download=True,
                                   transform=apply_transform_train)

        test_dataset = cub.CUB200(data_dir,
                                  year=2011,
                                  train=False,
                                  download=True,
                                  transform=apply_transform_test)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from CUB-200
            user_groups = cub_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from CUB-200
            if args.hard:
                # Choose hard non-IID splits for every user
                user_groups = cub_noniid_hard(train_dataset, args.num_users)
            else:
                # Choose equal splits for every user
                user_groups = cub_noniid(train_dataset, args.num_users)

    return train_dataset, test_dataset, user_groups
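A minimal usage sketch for the values returned above: each client trains on a DataLoader built from its own index set. Here args.local_bs (the local batch size) is an assumption borrowed from the later examples, not an argument defined in this one.

from torch.utils.data import DataLoader, Subset

# Hypothetical usage: a loader over the samples assigned to client 0.
train_dataset, test_dataset, user_groups = get_dataset(args)
client_idxs = [int(i) for i in user_groups[0]]
client_loader = DataLoader(Subset(train_dataset, client_idxs),
                           batch_size=args.local_bs, shuffle=True)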
Code example #3
def get_dataset(args):
    """ Returns train and test datasets and a user group which is a dict where
    the keys are the user index and the values are the corresponding data for
    each of those users.
    """

    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                       transform=apply_transform)

        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                      transform=apply_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from CIFAR-10
            user_groups = cifar_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from CIFAR-10
            if args.unequal:
                # Choose unequal splits for every user
                raise NotImplementedError()
            else:
                # Choose equal splits for every user
                user_groups = cifar_noniid(train_dataset, args.num_users)

    elif args.dataset == 'medmnist':
        train_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(mean=[.5], std=[.5])])

        test_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(mean=[.5], std=[.5])])
        input_root = args.input_root

        train_dataset = BreastMNIST(root=input_root,
                                    split='train',
                                    transform=train_transform,
                                    download=True)

        test_dataset = BreastMNIST(root=input_root,
                                   split='test',
                                   transform=test_transform,
                                   download=True)

        user_groups = medmnist_iid(train_dataset, args.num_users)
        print('--------------------------------')
        print(train_dataset, test_dataset, user_groups)
        print('--------------------------------')

    elif args.dataset == 'mnist' or args.dataset == 'fmnist':
        print('<<<<>>>>>')
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
        else:
            data_dir = '../data/fmnist/'

        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        train_dataset = datasets.MNIST(data_dir, train=True, download=True,
                                       transform=apply_transform)

        test_dataset = datasets.MNIST(data_dir, train=False, download=True,
                                      transform=apply_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from Mnist
            user_groups = mnist_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from Mnist
            if args.unequal:
                # Choose unequal splits for every user
                user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
            else:
                # Choose equal splits for every user
                user_groups = mnist_noniid(train_dataset, args.num_users)


    return train_dataset, test_dataset, user_groups
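The equal non-IID helpers such as mnist_noniid and cifar_noniid are commonly implemented with the FedAvg-style shard scheme: sort the sample indices by label, cut them into shards, and hand each user a few shards. A hedged illustration under assumed names and a shard count of two per user (not this project's code):

import numpy as np

def shard_noniid_split(labels, num_users, shards_per_user=2, seed=0):
    # Sort indices by label so each shard is dominated by a single class.
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    num_shards = num_users * shards_per_user
    shard_size = len(labels) // num_shards
    sorted_idxs = np.argsort(labels)
    shard_ids = rng.permutation(num_shards)
    user_groups = {}
    for i in range(num_users):
        own = shard_ids[i * shards_per_user:(i + 1) * shards_per_user]
        user_groups[i] = np.concatenate(
            [sorted_idxs[s * shard_size:(s + 1) * shard_size] for s in own])
    return user_groups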
Code example #4
def get_dataset(args):
    """ Returns train and test datasets and a user group which is a dict where
    the keys are the user index and the values are the corresponding data for
    each of those users.
    """

    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'
        apply_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ])
        test_apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ])
        train_dataset = datasets.CIFAR10(data_dir,
                                         train=True,
                                         download=True,
                                         transform=apply_transform)

        test_dataset = datasets.CIFAR10(data_dir,
                                        train=False,
                                        download=True,
                                        transform=test_apply_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from CIFAR-10
            user_groups = cifar_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from CIFAR-10
            if args.unequal:
                # Choose unequal splits for every user
                raise NotImplementedError()
            else:
                # Choose equal splits for every user
                user_groups = cifar_noniid(train_dataset, args.num_users)
    elif args.dataset == 'mnist' or args.dataset == 'fmnist':
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
        else:
            data_dir = '../data/fmnist/'

        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        train_dataset = datasets.MNIST(data_dir,
                                       train=True,
                                       download=True,
                                       transform=apply_transform)

        test_dataset = datasets.MNIST(data_dir,
                                      train=False,
                                      download=True,
                                      transform=apply_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from Mnist
            user_groups = mnist_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from Mnist
            if args.unequal:
                # Choose unequal splits for every user
                user_groups = mnist_noniid_unequal(train_dataset,
                                                   args.num_users)
            else:
                # Choose equal splits for every user
                user_groups = mnist_noniid(train_dataset, args.num_users)
    elif args.dataset == 'brats2018':
        from data.brats2018.dataset import BRATS2018Dataset, InstitutionWiseBRATS2018Dataset
        # from torch.utils.data import random_split
        from sampling import brats2018_iid, brats2018_unbalanced
        data_dir = args.data_dir
        test_dataset = None
        if args.balanced:
            train_dataset = BRATS2018Dataset(training_dir=data_dir,
                                             img_dim=128)
            user_groups = brats2018_iid(dataset=train_dataset,
                                        num_users=args.num_users)
        else:
            # The BRATS2018 data comes from 19 institutions. This is the default.
            train_dataset = InstitutionWiseBRATS2018Dataset(
                training_dir=data_dir,
                img_dim=128,
                config_json='../data/brats2018/hgg_config.json')
            user_groups = brats2018_unbalanced(dataset=train_dataset,
                                               num_users=args.num_users)
    elif args.dataset == 'brats2018_data_aug':
        from data.brats2018.dataset import InstitutionWiseBRATS2018DatasetDataAugmentation
        # from torch.utils.data import random_split
        from sampling import brats2018_iid, brats2018_unbalanced
        data_dir = args.data_dir
        test_dataset = None
        if args.balanced:
            raise NotImplementedError
        else:
            # The BRATS2018 data comes from 19 institutions. This is the default.
            train_dataset = InstitutionWiseBRATS2018DatasetDataAugmentation(
                training_dir=data_dir,
                img_dim=128,
                config_json='../data/brats2018/hgg_config.json')
            user_groups = brats2018_unbalanced(dataset=train_dataset,
                                               num_users=args.num_users)
    return train_dataset, test_dataset, user_groups
Code example #5
def get_dataset(args,
                tokenizer=None,
                max_seq_len=MAX_SEQUENCE_LENGTH,
                custom_sampling=None):
    """ Returns train and test datasets and a user group which is a dict where
    the keys are the user index and the values are the corresponding data for
    each of those users.
    """

    if args.task == 'nlp':
        assert args.dataset == "ade", "Parsed dataset not implemented."
        [complete_dataset] = ds.load_dataset("ade_corpus_v2",
                                             "Ade_corpus_v2_classification",
                                             split=["train"])
        # Rename column.
        complete_dataset = complete_dataset.rename_column("label", "labels")
        complete_dataset = complete_dataset.shuffle(seed=args.seed)
        # Split into train and test sets.
        split_examples = complete_dataset.train_test_split(
            test_size=args.test_frac)
        train_examples = split_examples["train"]
        test_examples = split_examples["test"]

        # Tokenize training set.
        train_dataset = train_examples.map(
            lambda examples: tokenizer(
                examples["text"],
                truncation=True,
                max_length=max_seq_len,
                padding="max_length",
            ),
            batched=True,
            remove_columns=["text"],
        )
        train_dataset.set_format(type="torch")

        # Tokenize test set.
        test_dataset = test_examples.map(
            lambda examples: tokenizer(
                examples["text"],
                truncation=True,
                max_length=max_seq_len,
                padding="max_length",
            ),
            batched=True,
            remove_columns=["text"],
        )
        test_dataset.set_format(type="torch")

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from Ade_corpus
            user_groups = ade_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from Ade_corpus
            if args.unequal:
                # Choose unequal splits for every user
                raise NotImplementedError()
            else:
                # Choose equal splits for every user
                user_groups = ade_noniid(train_dataset, args.num_users)

    elif args.task == 'cv':
        if args.dataset == 'cifar':
            data_dir = './data/cifar/'
            apply_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])

            train_dataset = datasets.CIFAR10(data_dir,
                                             train=True,
                                             download=True,
                                             transform=apply_transform)

            test_dataset = datasets.CIFAR10(data_dir,
                                            train=False,
                                            download=True,
                                            transform=apply_transform)

            # sample training data amongst users
            if custom_sampling is not None:
                user_groups = custom_sampling(dataset=train_dataset,
                                              num_users=args.num_users)
                assert len(
                    user_groups
                ) == args.num_users, "Incorrect number of users generated."
                check_client_sampled_data = []
                for client_idx, client_samples in user_groups.items():
                    assert len(client_samples) == len(
                        train_dataset
                    ) / args.num_users, "Incorrectly sampled client shard."
                    for record in client_samples:
                        check_client_sampled_data.append(record)
                assert len(set(check_client_sampled_data)) == len(
                    train_dataset), "Client shards are not i.i.d"
                print("Congratulations! You've got it :)")
            else:
                # sample training data amongst users
                if args.iid:
                    # Sample IID user data from CIFAR-10
                    user_groups = cifar_iid(train_dataset, args.num_users)
                else:
                    # Sample Non-IID user data from CIFAR-10
                    if args.unequal:
                        # Choose unequal splits for every user
                        raise NotImplementedError()
                    else:
                        # Choose equal splits for every user
                        user_groups = cifar_noniid(train_dataset,
                                                   args.num_users)

        elif args.dataset == 'mnist' or args.dataset == 'fmnist':
            if args.dataset == 'mnist':
                data_dir = './data/mnist/'
            else:
                data_dir = './data/fmnist/'

            apply_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])

            train_dataset = datasets.MNIST(data_dir,
                                           train=True,
                                           download=True,
                                           transform=apply_transform)

            test_dataset = datasets.MNIST(data_dir,
                                          train=False,
                                          download=True,
                                          transform=apply_transform)

            if args.iid:
                # Sample IID user data from Mnist
                user_groups = mnist_iid(train_dataset, args.num_users)
            else:
                # Sample Non-IID user data from Mnist
                if args.unequal:
                    # Choose unequal splits for every user
                    user_groups = mnist_noniid_unequal(train_dataset,
                                                       args.num_users)
                else:
                    # Choose equal splits for every user
                    user_groups = mnist_noniid(train_dataset, args.num_users)
        else:
            raise NotImplementedError(f"""Unrecognized dataset {args.dataset}.
                Options are: `cifar`, `mnist`, `fmnist`.
                """)
    else:
        raise NotImplementedError(f"""Unrecognised task {args.task}.
            Options are: `nlp` and `cv`.
            """)

    return train_dataset, test_dataset, user_groups
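The NLP branch above expects a Hugging Face tokenizer to be passed in. A plausible invocation, assuming args.task == 'nlp' and args.dataset == 'ade' (the model name here is only an assumption):

from transformers import AutoTokenizer

# Hypothetical call into the NLP branch; "bert-base-uncased" is an assumption.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
train_dataset, test_dataset, user_groups = get_dataset(
    args, tokenizer=tokenizer, max_seq_len=128)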
Code example #6
def get_dataset(args):
    """ Returns train and test datasets and a user group which is a dict where
    the keys are the user index and the values are the corresponding data for
    each of those users.
    """

    if args.dataset == 'cifar':
        data_dir = '../data/cifar/'

        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)

        transforms_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
        transforms_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        train_dataset = datasets.CIFAR10(data_dir,
                                         train=True,
                                         download=True,
                                         transform=transforms_train)
        test_dataset = datasets.CIFAR10(data_dir,
                                        train=False,
                                        download=True,
                                        transform=transforms_test)
        """
        apply_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                       transform=apply_transform)

        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True,
                                      transform=apply_transform)
        """
        # sample training data amongst users
        if args.iid == 1:
            # Sample IID user data from CIFAR-10
            # user_groups = cifar_iid(train_dataset, args.num_users)
            user_groups = cifar10_iid(train_dataset, args.num_users, args=args)
        else:
            # Sample Non-IID user data from CIFAR-10
            if args.unequal:
                # Choose unequal splits for every user
                raise NotImplementedError()
            else:
                # Choose equal splits for every user
                if args.iid == 2:
                    #                    user_groups = partition_data(train_dataset, 'noniid-#label2', num_uers=args.num_users, alpha=1, args=args)
                    user_groups = cifar_noniid(train_dataset,
                                               num_users=args.num_users,
                                               args=args)
                else:
                    user_groups = partition_data(train_dataset,
                                                 'dirichlet',
                                                 num_uers=args.num_users,
                                                 alpha=1,
                                                 args=args)
        # Build per-client train dataloaders from the partitioned indices and the train dataset
        client_loader_dict = client_loader(train_dataset, user_groups, args)

    elif args.dataset == 'cifar100':
        data_dir = '../data/fed_cifar100'
        train_dataset, test_dataset, client_loader_dict = load_partition_data_federated_cifar100(
            data_dir=data_dir, batch_size=args.local_bs)

    elif args.dataset == 'mnist' or args.dataset == 'fmnist':
        if args.dataset == 'mnist':
            data_dir = '../data/mnist/'
        else:
            data_dir = '../data/fmnist/'

        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        train_dataset = datasets.MNIST(data_dir,
                                       train=True,
                                       download=True,
                                       transform=apply_transform)

        test_dataset = datasets.MNIST(data_dir,
                                      train=False,
                                      download=True,
                                      transform=apply_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from Mnist
            user_groups = mnist_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from Mnist
            if args.unequal:
                # Choose unequal splits for every user
                user_groups = mnist_noniid_unequal(train_dataset,
                                                   args.num_users)
            else:
                # Choose equal splits for every user
                user_groups = mnist_noniid(train_dataset, args.num_users)

        # Build per-client train data loaders from the partitioned indices and the train dataset
        client_loader_dict = client_loader(train_dataset, user_groups, args)

    return train_dataset, test_dataset, client_loader_dict
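partition_data(train_dataset, 'dirichlet', ...) is not shown in this example. A common way to realize a Dirichlet label-skew partition is sketched below; the function name and details are assumptions for illustration, not this repository's implementation:

import numpy as np

def dirichlet_partition(labels, num_users, alpha=1.0, seed=0):
    # Sample per-class proportions from Dir(alpha) and split each class's indices accordingly.
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    user_groups = {i: [] for i in range(num_users)}
    for c in np.unique(labels):
        class_idxs = rng.permutation(np.where(labels == c)[0])
        proportions = rng.dirichlet(alpha * np.ones(num_users))
        cut_points = (np.cumsum(proportions)[:-1] * len(class_idxs)).astype(int)
        for i, part in enumerate(np.split(class_idxs, cut_points)):
            user_groups[i].extend(part.tolist())
    return user_groups

Smaller alpha values give more skewed client label distributions; alpha=1, as passed above, gives a moderate skew.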
Code example #7
File: utils.py  Project: yuetan031/fedproto
def get_dataset(args, n_list, k_list):
    """ Returns train and test datasets and a user group which is a dict where
    the keys are the user index and the values are the corresponding data for
    each of those users.
    """
    data_dir = args.data_dir + args.dataset
    if args.dataset == 'mnist':
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        train_dataset = datasets.MNIST(data_dir,
                                       train=True,
                                       download=True,
                                       transform=apply_transform)

        test_dataset = datasets.MNIST(data_dir,
                                      train=False,
                                      download=True,
                                      transform=apply_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from Mnist
            user_groups = mnist_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from Mnist
            if args.unequal:
                # Choose unequal splits for every user
                user_groups = mnist_noniid_unequal(args, train_dataset,
                                                   args.num_users)
            else:
                # Choose equal splits for every user
                user_groups, classes_list = mnist_noniid(
                    args, train_dataset, args.num_users, n_list, k_list)
                user_groups_lt = mnist_noniid_lt(args, test_dataset,
                                                 args.num_users, n_list,
                                                 k_list, classes_list)
                classes_list_gt = classes_list

    elif args.dataset == 'femnist':
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        train_dataset = femnist.FEMNIST(args,
                                        data_dir,
                                        train=True,
                                        download=True,
                                        transform=apply_transform)
        test_dataset = femnist.FEMNIST(args,
                                       data_dir,
                                       train=False,
                                       download=True,
                                       transform=apply_transform)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from FEMNIST
            user_groups = femnist_iid(train_dataset, args.num_users)
            # print("TBD")
        else:
            # Sample Non-IID user data from FEMNIST
            if args.unequal:
                # Choose unequal splits for every user
                # user_groups = mnist_noniid_unequal(train_dataset, args.num_users)
                user_groups = femnist_noniid_unequal(args, train_dataset,
                                                     args.num_users)
                # print("TBD")
            else:
                # Choose equal splits for every user
                user_groups, classes_list, classes_list_gt = femnist_noniid(
                    args, args.num_users, n_list, k_list)
                user_groups_lt = femnist_noniid_lt(args, args.num_users,
                                                   classes_list)

    elif args.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10(data_dir,
                                         train=True,
                                         download=True,
                                         transform=trans_cifar10_train)
        test_dataset = datasets.CIFAR10(data_dir,
                                        train=False,
                                        download=True,
                                        transform=trans_cifar10_val)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from CIFAR-10
            user_groups = cifar_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from CIFAR-10
            if args.unequal:
                # Choose unequal splits for every user
                raise NotImplementedError()
            else:
                # Choose equal splits for every user
                user_groups, classes_list, classes_list_gt = cifar10_noniid(
                    args, train_dataset, args.num_users, n_list, k_list)
                user_groups_lt = cifar10_noniid_lt(args, test_dataset,
                                                   args.num_users, n_list,
                                                   k_list, classes_list)

    elif args.dataset == 'cifar100':
        train_dataset = datasets.CIFAR100(data_dir,
                                          train=True,
                                          download=True,
                                          transform=trans_cifar100_train)
        test_dataset = datasets.CIFAR100(data_dir,
                                         train=False,
                                         download=True,
                                         transform=trans_cifar100_val)

        # sample training data amongst users
        if args.iid:
            # Sample IID user data from CIFAR-100
            user_groups = cifar_iid(train_dataset, args.num_users)
        else:
            # Sample Non-IID user data from CIFAR-100
            if args.unequal:
                # Choose unequal splits for every user
                raise NotImplementedError()
            else:
                # Choose equal splits for every user
                user_groups, classes_list = cifar100_noniid(
                    args, train_dataset, args.num_users, n_list, k_list)
                user_groups_lt = cifar100_noniid_lt(test_dataset,
                                                    args.num_users,
                                                    classes_list)

    return train_dataset, test_dataset, user_groups, user_groups_lt, classes_list, classes_list_gt
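Unlike the earlier examples, this variant also returns per-user test groups and class lists, so a call site has to unpack six values. A minimal sketch based on the return statement above; n_list and k_list are additional per-client sampling parameters consumed by the non-IID helpers:

# Hypothetical call site for the six-value return signature.
train_dataset, test_dataset, user_groups, user_groups_lt, classes_list, classes_list_gt = \
    get_dataset(args, n_list, k_list)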
Code example #8
    # load dataset and split users
    if args.dataset == 'mnist':
        dataset_train = datasets.MNIST('../data/mnist/',
                                       train=True,
                                       download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307, ),
                                                                (0.3081, ))
                                       ]))
        # sample users
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        elif args.unequal:
            dict_users = mnist_noniid_unequal(dataset_train, args.num_users)
        else:
            dict_users = mnist_noniid(dataset_train, args.num_users)

    elif args.dataset == 'cifar':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        dataset_train = datasets.CIFAR10('../data/cifar',
                                         train=True,
                                         transform=transform,
                                         target_transform=None,
                                         download=True)
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users)