コード例 #1
0
    def setup_clients(self):
        """Build one ``Client`` per user of the configured dataset.

        The dataset is chosen by ``self.dataset_name`` ('cifar10', 'mnist'
        or 'femnist') and its loaders are created with ``self.batch_size``.
        Each client receives the train/test loader stored under its user id,
        plus this object's ``seed``, ``model_name``, ``lr``, ``epoch``,
        ``lr_decay`` and ``decay_step``.

        Returns:
            list of ``Client`` objects, one per entry of the ``users``
            sequence returned by the dataset's loader factory, in that order.

        Raises:
            ValueError: if ``self.dataset_name`` is not a supported dataset.
        """
        def default_transform():
            # A fresh pipeline per call so the train and test loaders never
            # share a single transform object. The mean/std values match the
            # commonly used ImageNet channel statistics.
            return transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])

        if self.dataset_name == 'cifar10':
            # TODO: make the total number of clients configurable
            # (the original note says the default is 100 for cifar10).
            users, trainloaders, testloaders = get_cifar10_dataloaders(
                batch_size=self.batch_size,
                train_transform=default_transform(),
                test_transform=default_transform())
        elif self.dataset_name == 'mnist':
            # No transforms are passed explicitly for mnist.
            users, trainloaders, testloaders = get_mnist_dataloaders(
                batch_size=self.batch_size,
                train_transform=None,
                test_transform=None)
        elif self.dataset_name == 'femnist':
            users, trainloaders, testloaders = get_femnist_dataloaders(
                batch_size=self.batch_size)
        else:
            # Fail fast: silently returning no clients hides config typos.
            raise ValueError(
                f"unsupported dataset_name: {self.dataset_name!r}")

        clients = [
            Client(user_id=user_id,
                   seed=self.seed,
                   trainloader=trainloaders[user_id],
                   testloader=testloaders[user_id],
                   model_name=self.model_name,
                   lr=self.lr,
                   epoch=self.epoch,
                   lr_decay=self.lr_decay,
                   decay_step=self.decay_step) for user_id in users
        ]
        return clients
コード例 #2
0
    def setup_datasets(self):
        """Create the per-user data loaders for ``self.dataset_name``.

        Supported names are 'cifar10', 'cifar100', 'mnist', 'femnist',
        'flickr' and 'celeba'. Every branch except 'flickr' uses
        ``self.batch_size``.

        Returns:
            tuple ``(users, trainloaders, testloaders)`` unpacked from the
            selected ``get_*_dataloaders`` factory.

        Raises:
            ValueError: if ``self.dataset_name`` is not a supported dataset.
        """
        def normalized_transform():
            # A fresh pipeline per call so the train and test loaders never
            # share a single transform object. The mean/std values match the
            # commonly used ImageNet channel statistics.
            return transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])

        if self.dataset_name == 'cifar10':
            users, trainloaders, testloaders = get_cifar10_dataloaders(
                batch_size=self.batch_size,
                train_transform=normalized_transform(),
                test_transform=normalized_transform())
        elif self.dataset_name == 'cifar100':
            users, trainloaders, testloaders = get_cifar100_dataloaders(
                batch_size=self.batch_size,
                train_transform=normalized_transform(),
                test_transform=normalized_transform())
        elif self.dataset_name == 'mnist':
            users, trainloaders, testloaders = get_mnist_dataloaders(
                batch_size=self.batch_size,
                train_transform=None,
                test_transform=None)
        elif self.dataset_name == 'femnist':
            users, trainloaders, testloaders = get_femnist_dataloaders(
                batch_size=self.batch_size)
        elif self.dataset_name == 'flickr':
            # NOTE(review): split ratio and batch size are fixed here and
            # ignore self.batch_size; confirm this is intentional.
            users, trainloaders, testloaders = get_flickr_dataloaders(
                split_ratio=0.9, batch_size=10)
        elif self.dataset_name == 'celeba':
            users, trainloaders, testloaders = get_celeba_dataloaders(
                batch_size=self.batch_size)
        else:
            # Fail fast: empty loaders for an unknown name hide config typos.
            raise ValueError(
                f"unsupported dataset_name: {self.dataset_name!r}")
        return users, trainloaders, testloaders
コード例 #3
0
ファイル: SERVER_BASE.py プロジェクト: tdye24/Fed101
    def setup_datasets(self):
        """Create the per-user data loaders for ``self.dataset_name``.

        Image datasets are normalised with one of two pipelines defined
        below: per-channel mean/std of 0.5 for the CIFAR variants, and mean
        0.1307 / std 0.3081 for the 'mnist_malposition' and
        'mnist_wo_malposition' variants. All other datasets are built by
        their factory without explicit transforms.

        Returns:
            tuple ``(users, trainloaders, testloaders)`` unpacked from the
            selected ``get_*_dataloaders`` factory.

        Raises:
            ValueError: if ``self.dataset_name`` is not a supported dataset.
        """
        def cifar_transform():
            # Fresh pipeline per call so train and test loaders never share
            # a single transform object.
            return transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            ])

        def mnist_transform():
            # Fresh pipeline per call, single-channel normalisation.
            return transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])

        name = self.dataset_name
        if name == 'cifar10':
            users, trainloaders, testloaders = get_cifar10_dataloaders(
                batch_size=self.batch_size,
                train_transform=cifar_transform(),
                test_transform=cifar_transform())
        elif name == 'cifar10_dirichlet':
            # NOTE(review): the number of users is fixed at 30 here;
            # self.alpha is presumably the Dirichlet concentration — confirm.
            users, trainloaders, testloaders = get_cifar10_dirichlet_dataloaders(
                users_num=30,
                alpha=self.alpha,
                batch_size=self.batch_size,
                train_transform=cifar_transform(),
                test_transform=cifar_transform())
        elif name == 'cifar10_ld':
            # NOTE(review): no batch_size is passed, so the factory's own
            # default applies rather than self.batch_size.
            users, trainloaders, testloaders = get_cifar10_ld_dataloaders()
        elif name == 'cifar100':
            users, trainloaders, testloaders = get_cifar100_dataloaders(
                batch_size=self.batch_size,
                train_transform=cifar_transform(),
                test_transform=cifar_transform())
        elif name == 'cifar100_superclass':
            users, trainloaders, testloaders = get_cifar100_superclass_dataloaders(
                batch_size=self.batch_size,
                train_transform=cifar_transform(),
                test_transform=cifar_transform())
        elif name == 'mnist_malposition':
            users, trainloaders, testloaders = get_mnist_malposition_dataloaders(
                batch_size=self.batch_size,
                train_transform=mnist_transform(),
                test_transform=mnist_transform())
        elif name == 'mnist_wo_malposition':
            users, trainloaders, testloaders = get_mnist_wo_malposition_dataloaders(
                batch_size=self.batch_size,
                train_transform=mnist_transform(),
                test_transform=mnist_transform())
        elif name == 'mnist_ld':
            users, trainloaders, testloaders = get_mnist_ld_dataloaders(
                batch_size=self.batch_size)
        elif name == 'mnist_fedml':
            users, trainloaders, testloaders = get_mnist_fedml_dataloaders(
                batch_size=self.batch_size)
        elif name == 'femnist':
            users, trainloaders, testloaders = get_femnist_dataloaders(
                batch_size=self.batch_size)
        elif name == 'flickr':
            # NOTE(review): split ratio and batch size are fixed here and
            # ignore self.batch_size; confirm this is intentional.
            users, trainloaders, testloaders = get_flickr_dataloaders(
                split_ratio=0.9, batch_size=10)
        elif name == 'celeba':
            users, trainloaders, testloaders = get_celeba_dataloaders(
                batch_size=self.batch_size)
        elif name == 'har':
            users, trainloaders, testloaders = get_har_dataloaders(
                batch_size=self.batch_size)
        else:
            # Fail fast: empty loaders for an unknown name hide config typos.
            raise ValueError(f"unsupported dataset_name: {name!r}")
        return users, trainloaders, testloaders