Exemplo n.º 1
0
Arquivo: mnist.py Projeto: bacox/fltk
 def init_test_dataset(self):
     """Create the MNIST test dataset together with its sampler and loader.

     Populates ``self.test_dataset``, ``self.test_sampler`` and
     ``self.test_loader``. The images are only converted to tensors (no
     normalization), and the dataset is fetched into
     ``self.get_args().get_data_path()`` with ``download=True``.
     """
     mode_label = "distributed" if self.args.get_distributed() else ""
     self.logger.debug(f"Loading '{mode_label}' MNIST test data")

     data_root = self.get_args().get_data_path()
     to_tensor = transforms.Compose([transforms.ToTensor()])
     self.test_dataset = datasets.MNIST(root=data_root, train=False, download=True, transform=to_tensor)

     self.test_sampler = get_sampler(self.test_dataset, self.args)
     self.test_loader = DataLoader(
         self.test_dataset,
         batch_size=self.args.test_batch_size,
         sampler=self.test_sampler,
     )
Exemplo n.º 2
0
 def init_test_dataset(self):
     """Create this client's Shakespeare test split, sampler and loader.

     The split is read through ``PickleDataset`` for the dataset named
     'shakespeare', selecting the partition whose client id is
     ``self.args.rank``. Populates ``self.test_dataset``,
     ``self.test_sampler`` and ``self.test_loader``.
     """
     mode_label = "distributed" if self.args.get_distributed() else ""
     self.logger.debug(f"Loading '{mode_label}' Shakespeare test data")

     # data_root / pickle_root are passed as None; presumably PickleDataset
     # then falls back to its own default locations — TODO confirm.
     pickled = PickleDataset(dataset_name='shakespeare', data_root=None, pickle_root=None)
     self.test_dataset = pickled.get_dataset_pickle(dataset_type="test", client_id=self.args.rank)

     self.test_sampler = get_sampler(self.test_dataset, self.args)
     self.test_loader = DataLoader(
         self.test_dataset,
         batch_size=self.args.test_batch_size,
         sampler=self.test_sampler,
     )
Exemplo n.º 3
0
    def init_test_dataset(self):
        """Create the CIFAR100 test dataset together with its sampler and loader.

        Test images are converted to tensors and then normalized per channel
        (no augmentation). Populates ``self.test_dataset``,
        ``self.test_sampler`` and ``self.test_loader``.
        """
        mode_label = "distributed" if self.args.get_distributed() else ""
        self.logger.debug(f"Loading '{mode_label}' CIFAR100 test data")

        # Presumably the CIFAR100 per-channel mean/std — TODO confirm source.
        channel_norm = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
        pipeline = transforms.Compose([transforms.ToTensor(), channel_norm])

        self.test_dataset = datasets.CIFAR100(
            root=self.get_args().get_data_path(),
            train=False,
            download=True,
            transform=pipeline,
        )
        self.test_sampler = get_sampler(self.test_dataset, self.args)
        self.test_loader = DataLoader(
            self.test_dataset,
            batch_size=self.args.test_batch_size,
            sampler=self.test_sampler,
        )
Exemplo n.º 4
0
    def load_test_dataset(self):
        """Load the CIFAR100 test split and materialize it in a single batch.

        Builds the CIFAR100 test dataset (tensor conversion + per-channel
        normalization) and a sampler over it. It then wraps both in a
        ``DataLoader`` whose batch size equals the dataset length, registers
        the sampler on ``self.args``, and converts the loader with
        ``self.get_tuple_from_data_loader``.

        :return: whatever ``self.get_tuple_from_data_loader`` produces for the
            single-batch test loader.
        """
        self.logger.debug("Loading CIFAR100 test data")

        # Presumably the CIFAR100 per-channel mean/std — TODO confirm source.
        normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
        transform = transforms.Compose([
            transforms.ToTensor(),
            normalize
        ])
        test_dataset = datasets.CIFAR100(root=self.get_args().get_data_path(), train=False, download=True,
                                         transform=transform)
        # Build the sampler over the dataset loaded here. This method never
        # sets ``self.test_dataset``, so that attribute may be missing or
        # refer to a different dataset.
        sampler = get_sampler(test_dataset, self.args)
        # batch_size == len(dataset): every index the sampler yields lands in one batch.
        test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), sampler=sampler)
        self.args.set_sampler(sampler)

        test_data = self.get_tuple_from_data_loader(test_loader)

        self.logger.debug("Finished loading CIFAR100 test data")

        return test_data
Exemplo n.º 5
0
    def load_train_dataset(self):
        """Load the CIFAR100 training split and materialize it in a single batch.

        Builds the CIFAR100 training dataset with random horizontal-flip and
        random-crop augmentation, followed by tensor conversion and per-channel
        normalization. It creates a sampler over that training set and wraps
        both in a ``DataLoader`` whose batch size equals the dataset length.
        Finally it registers the sampler on ``self.args`` and converts the
        loader with ``self.get_tuple_from_data_loader``.

        :return: whatever ``self.get_tuple_from_data_loader`` produces for the
            single-batch training loader.
        """
        # Computed once and reused by both the start and finish log messages.
        dist_loader_text = "distributed" if self.args.get_distributed() else ""
        self.logger.debug(f"Loading '{dist_loader_text}' CIFAR100 train data")

        # Presumably the CIFAR100 per-channel mean/std — TODO confirm source.
        normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            # 32x32 random crop after padding each border by 4 pixels.
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize
        ])

        train_dataset = datasets.CIFAR100(root=self.get_args().get_data_path(), train=True, download=True,
                                          transform=transform)
        # The sampler must be built over the training set loaded here, not
        # ``self.test_dataset``. Otherwise its indices are derived from a
        # different dataset, or the attribute may not exist yet.
        sampler = get_sampler(train_dataset, self.args)

        # batch_size == len(dataset): every index the sampler yields lands in one batch.
        # NOTE(review): the random augmentations run only when the loader is
        # iterated. If get_tuple_from_data_loader iterates it once, a single
        # augmented copy is frozen — confirm this is intended.
        train_loader = DataLoader(train_dataset, batch_size=len(train_dataset), sampler=sampler)
        self.args.set_sampler(sampler)

        train_data = self.get_tuple_from_data_loader(train_loader)
        self.logger.debug(f"Finished loading '{dist_loader_text}' CIFAR100 train data")

        return train_data