コード例 #1
0
    def reg_loader_subset(self, indices):
        """Return a DataLoader over the regularization set (unsupervised examples only).

        ``self.unsupset`` is used as-is; no transform is applied in this method.
        Example selection is delegated to ``ProtectedSubsetRandomSampler(indices)``
        (presumably a random ordering of ``indices``, per its name -- confirm).
        ``shuffle`` is left False because DataLoader does not allow
        ``shuffle=True`` together with an explicit sampler.

        :param indices: positions in ``self.unsupset`` the sampler draws from.
        :return: a ``torch.utils.data.DataLoader`` with batch size
            ``self.mini_batch_size()`` and 2 loader worker processes.
        """
        batch = self.mini_batch_size()
        source = self.unsupset
        index_sampler = ProtectedSubsetRandomSampler(indices)
        return torch.utils.data.DataLoader(
            source,
            batch_size=batch,
            shuffle=False,
            sampler=index_sampler,
            num_workers=2,
        )
コード例 #2
0
    def loader_for_dataset(self, dataset):
        """Wrap ``dataset`` in a DataLoader that covers every one of its examples.

        All positions ``0 .. len(dataset) - 1`` are handed to
        ``ProtectedSubsetRandomSampler``. ``num_workers`` is not passed, so
        DataLoader's default is used.

        :param dataset: any map-style dataset supporting ``len()``.
        :return: a ``torch.utils.data.DataLoader`` with batch size
            ``self.mini_batch_size()``.
        """
        batch = self.mini_batch_size()
        every_position = range(len(dataset))
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch,
            shuffle=False,
            sampler=ProtectedSubsetRandomSampler(every_position),
        )
コード例 #3
0
    def train_loader_subset(self, indices):
        """Return a DataLoader over the training set, restricted to ``indices``.

        Ordering comes from ``ProtectedSubsetRandomSampler(indices)`` rather
        than DataLoader's own ``shuffle`` flag, which must stay False whenever
        a sampler is supplied.

        :param indices: positions in ``self.trainset`` to draw examples from.
        :return: a ``torch.utils.data.DataLoader`` with batch size
            ``self.mini_batch_size()`` and 2 loader worker processes.
        """
        batch = self.mini_batch_size()
        source = self.trainset
        index_sampler = ProtectedSubsetRandomSampler(indices)
        return torch.utils.data.DataLoader(
            source,
            batch_size=batch,
            shuffle=False,
            sampler=index_sampler,
            num_workers=2,
        )
コード例 #4
0
 def test_loader_subset(self, indices):
     """Return a DataLoader over the test set, restricted to ``indices``.

     Example selection is delegated to ``ProtectedSubsetRandomSampler``;
     ``shuffle`` stays False because a sampler is supplied.

     :param indices: positions in ``self.testset`` to draw examples from.
     :return: a ``torch.utils.data.DataLoader`` with batch size
         ``self.mini_batch_size()`` and 2 loader worker processes.
     """
     batch = self.mini_batch_size()
     source = self.testset
     index_sampler = ProtectedSubsetRandomSampler(indices)
     return torch.utils.data.DataLoader(
         source,
         batch_size=batch,
         shuffle=False,
         sampler=index_sampler,
         num_workers=2,
     )
コード例 #5
0
ファイル: STL10_NT64Problem.py プロジェクト: fac2003/ureg
    def loader_for_dataset(self, dataset):
        """Wrap ``dataset`` in a DataLoader covering all of its examples.

        Every position ``0 .. len(dataset) - 1`` is handed to
        ``ProtectedSubsetRandomSampler``. Batches are assembled with
        ``stl10_collate``, and the worker count comes from ``self.num_workers``.

        :param dataset: any map-style dataset supporting ``len()``.
        :return: a ``torch.utils.data.DataLoader`` with batch size
            ``self.mini_batch_size()``.
        """
        batch = self.mini_batch_size()
        every_position = range(len(dataset))
        index_sampler = ProtectedSubsetRandomSampler(every_position)
        collate = stl10_collate
        workers = self.num_workers
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch,
            shuffle=False,
            sampler=index_sampler,
            collate_fn=collate,
            num_workers=workers,
        )
コード例 #6
0
ファイル: STL10_NT64Problem.py プロジェクト: fac2003/ureg
    def test_loader_subset(self, indices):
        """Return a DataLoader over the test set, limited to the examples in ``indices``.

        Reads ``self._testset`` and batches with ``stl10_collate``. Selection
        is delegated to ``ProtectedSubsetRandomSampler``, so ``shuffle`` stays
        False. The worker count comes from ``self.num_workers``.

        :param indices: positions in ``self._testset`` to draw examples from.
        :return: a ``torch.utils.data.DataLoader`` with batch size
            ``self.mini_batch_size()``.
        """
        batch = self.mini_batch_size()
        source = self._testset
        collate = stl10_collate
        index_sampler = ProtectedSubsetRandomSampler(indices)
        workers = self.num_workers
        return torch.utils.data.DataLoader(
            source,
            batch_size=batch,
            shuffle=False,
            sampler=index_sampler,
            collate_fn=collate,
            num_workers=workers,
        )