コード例 #1
0
 def get_dataset(self, name, train_ds=True):
     """Build the train/test dataset pair registered under ``name``.

     Every dataset is rooted at ``self.root_folder``, uses
     ``transforms.ToTensor()`` as its transform and is created with
     ``download=True``.

     Args:
         name: Dataset key; ``'cifar10'`` and ``'stl10'`` are registered.
         train_ds: Accepted in the signature, but never read by this method.

     Returns:
         A 2-tuple ``(train_dataset, test_dataset)``.

     Raises:
         InvalidDatasetSelection: If ``name`` is not a registered key.
     """

     def cifar10_pair():
         train_part = datasets.CIFAR10(self.root_folder,
                                       train=True,
                                       transform=transforms.ToTensor(),
                                       download=True)
         test_part = datasets.CIFAR10(self.root_folder,
                                      train=False,
                                      transform=transforms.ToTensor(),
                                      download=True)
         return train_part, test_part

     def stl10_pair():
         train_part = datasets.STL10(self.root_folder,
                                     split='train',
                                     transform=transforms.ToTensor(),
                                     download=True)
         test_part = datasets.STL10(self.root_folder,
                                    split='test',
                                    transform=transforms.ToTensor(),
                                    download=True)
         return train_part, test_part

     # The table stores builders rather than datasets, so nothing is
     # constructed until the requested entry is actually called.
     builders = {'cifar10': cifar10_pair, 'stl10': stl10_pair}

     try:
         build_pair = builders[name]
     except KeyError:
         raise InvalidDatasetSelection()
     return build_pair()
コード例 #2
0
    def get_test_dataset(self, name, n_views):
        """Build the dataset registered under ``name`` for evaluation use.

        Args:
            name: Dataset key; ``'cifar10'``, ``'stl10'`` and ``'mnist'`` are
                registered.
            n_views: Forwarded to ``ContrastiveLearningViewGenerator`` for the
                'stl10' and 'mnist' entries; the 'cifar10' entry uses a plain
                ``transforms.ToTensor()`` and does not read it.

        Returns:
            The dataset object created by the selected builder.

        Raises:
            InvalidDatasetSelection: If ``name`` is not a registered key.
        """

        def make_views(image_size):
            # Wrap the SimCLR pipeline for the given size in a view generator.
            return ContrastiveLearningViewGenerator(
                self.get_simclr_pipeline_transform(image_size), n_views)

        def build_cifar10():
            return datasets.CIFAR10(self.root_folder,
                                    train=False,
                                    transform=transforms.ToTensor(),
                                    download=True)

        def build_stl10():
            # NOTE(review): this selects the 'unlabeled' split, not 'test' --
            # confirm that is intended for a test dataset.
            return datasets.STL10(self.root_folder,
                                  split='unlabeled',
                                  transform=make_views(96),
                                  download=True)

        def build_mnist():
            # NOTE(review): no ``train`` argument is passed, so the dataset
            # class's default split is used -- confirm that is intended.
            return datasets.MNIST(self.root_folder,
                                  transform=make_views(32),
                                  download=True)

        # Builders are only invoked after the lookup, so just the requested
        # dataset is ever constructed.
        builders = {
            'cifar10': build_cifar10,
            'stl10': build_stl10,
            'mnist': build_mnist,
        }

        try:
            build = builders[name]
        except KeyError:
            raise InvalidDatasetSelection()
        return build()
コード例 #3
0
    def get_dataset(self, name, transform='none', train=True):
        """Return one split of the dataset registered under ``name``.

        Args:
            name: Dataset key, ``'cifar10'`` (image size 32) or ``'stl10'``
                (image size 96).
            transform: Either a preset name or a ready-made transform.
                ``'none'`` selects ``transforms.ToTensor()``, ``'cifar10'``
                selects ``get_cifar10_transform(size, train)`` and
                ``'simclr'`` selects ``get_simclr_transform(size,
                train=train)``. Any non-string value (for example a callable
                or ``None``) is forwarded to the dataset unchanged.
            train: Selects the split. CIFAR10 receives it as ``train=``;
                STL10 receives ``split='train'`` or ``split='test'``.

        Returns:
            The constructed ``datasets.CIFAR10`` or ``datasets.STL10`` object,
            rooted at ``self.root_folder`` and created with ``download=True``.

        Raises:
            InvalidDatasetSelection: If ``name`` is not a supported dataset.
            ValueError: If ``transform`` is a string that is not one of the
                known preset names.
        """
        # Resolve the dataset first so an unknown name is always reported as
        # InvalidDatasetSelection, whatever ``transform`` was passed.
        if name == 'cifar10':
            image_size = 32
        elif name == 'stl10':
            image_size = 96
        else:
            raise InvalidDatasetSelection()

        if transform == 'none':
            transform = transforms.ToTensor()
        elif transform == 'cifar10':
            transform = get_cifar10_transform(image_size, train)
        elif transform == 'simclr':
            transform = get_simclr_transform(image_size, train=train)
        elif isinstance(transform, str):
            # A misspelled preset used to be handed to the dataset as a plain
            # string and only failed later when the dataset tried to apply it.
            raise ValueError(
                "Unknown transform preset {!r}; expected 'none', 'cifar10' "
                "or 'simclr'".format(transform))

        if name == 'cifar10':
            return datasets.CIFAR10(self.root_folder,
                                    train=train,
                                    transform=transform,
                                    download=True)

        split = 'train' if train else 'test'
        return datasets.STL10(self.root_folder,
                              split=split,
                              transform=transform,
                              download=True)
コード例 #4
0
 def get_dataset(self, name, n_views):
     """Build the contrastive-learning dataset registered under ``name``.

     The lookup table maps each key to a zero-argument builder instead of a
     dataset instance. Building the table therefore constructs nothing; only
     the builder selected by ``name`` is called, so the other datasets are
     never instantiated.

     Args:
         name: Dataset key; ``'cifar10'`` and ``'stl10'`` are registered.
         n_views: Forwarded to ``ContrastiveLearningViewGenerator``.

     Returns:
         The dataset object created by the selected builder.

     Raises:
         InvalidDatasetSelection: If ``name`` is not a registered key.
     """

     def make_views(image_size):
         # Wrap the SimCLR pipeline for the given size in a view generator.
         return ContrastiveLearningViewGenerator(
             self.get_simclr_pipeline_transform(image_size), n_views)

     def build_cifar10():
         return datasets.CIFAR10(self.root_folder,
                                 train=True,
                                 transform=make_views(32),
                                 download=True)

     def build_stl10():
         return datasets.STL10(self.root_folder,
                               split='unlabeled',
                               transform=make_views(96),
                               download=True)

     # TODO: new datasets will be added here; the dataset image size will
     # need to be passed in to this function. See
     # https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
     builders = {'cifar10': build_cifar10, 'stl10': build_stl10}

     try:
         build = builders[name]
     except KeyError:
         raise InvalidDatasetSelection()
     # Calling the builder is what actually creates the dataset object.
     return build()
コード例 #5
0
 def get_dataset(self, name, n_views):
     """Build the contrastive-learning dataset registered under ``name``.

     Args:
         name: Dataset key; ``'cifar10'``, ``'stl10'``, ``'dldataset'`` and
             ``'dldataset_aug7'`` are registered.
         n_views: Forwarded to ``ContrastiveLearningViewGenerator``.

     Returns:
         The dataset object created by the selected builder.

     Raises:
         InvalidDatasetSelection: If ``name`` is not a registered key.
     """

     def make_views(image_size, **pipeline_options):
         # Wrap the SimCLR pipeline in a view generator; extra options are
         # passed straight through to ``get_simclr_pipeline_transform``.
         pipeline = self.get_simclr_pipeline_transform(
             image_size, **pipeline_options)
         return ContrastiveLearningViewGenerator(pipeline, n_views)

     def build_cifar10():
         return datasets.CIFAR10(self.root_folder,
                                 train=True,
                                 transform=make_views(32),
                                 download=True)

     def build_stl10():
         return datasets.STL10(self.root_folder,
                               split='unlabeled',
                               transform=make_views(96),
                               download=True)

     def build_dldataset():
         return CustomDataSet(self.root_folder + '/unlabeled',
                              transform=make_views(96))

     def build_dldataset_aug7():
         # Same folder as 'dldataset', but the pipeline gets ``num_aug=7``.
         return CustomDataSet(self.root_folder + '/unlabeled',
                              transform=make_views(96, num_aug=7))

     # Only the builder picked below is ever called, so unused datasets are
     # never constructed.
     builders = {
         'cifar10': build_cifar10,
         'stl10': build_stl10,
         'dldataset': build_dldataset,
         'dldataset_aug7': build_dldataset_aug7,
     }

     try:
         build = builders[name]
     except KeyError:
         raise InvalidDatasetSelection()
     return build()
コード例 #6
0
ファイル: head_dataset.py プロジェクト: fjbriones/SimCLR
    def get_num_classes(self, name):
        """Return the number of classes registered for dataset ``name``.

        Args:
            name: Dataset key; ``'cifar10'``, ``'stl10'`` and ``'mnist'`` are
                registered, each with 10 classes.

        Returns:
            The class count as an ``int``.

        Raises:
            InvalidDatasetSelection: If ``name`` is not a registered key.
        """
        # Every registered dataset currently shares the same class count.
        class_counts = dict.fromkeys(('cifar10', 'stl10', 'mnist'), 10)

        try:
            return class_counts[name]
        except KeyError:
            raise InvalidDatasetSelection()
コード例 #7
0
    def get_dataset(self, name):
        """Build the dataset registered under ``name`` wrapped in ``IndexDataset``.

        Args:
            name: Dataset key; ``'cifar10'`` and ``'stl10'`` are registered.

        Returns:
            ``IndexDataset`` wrapping the dataset created by the selected
            builder.

        Raises:
            InvalidDatasetSelection: If ``name`` is not a registered key.
        """

        def build_cifar10():
            return datasets.CIFAR10(
                self.root_folder,
                train=True,
                transform=self.get_simclr_pipeline_transform(32),
                download=True)

        def build_stl10():
            return datasets.STL10(
                self.root_folder,
                split='unlabeled',
                transform=self.get_simclr_pipeline_transform(96),
                download=True)

        # Builders are stored instead of datasets so that only the requested
        # one is constructed.
        builders = {'cifar10': build_cifar10, 'stl10': build_stl10}

        try:
            build = builders[name]
        except KeyError:
            raise InvalidDatasetSelection()
        return IndexDataset(build())
コード例 #8
0
    def get_dataset(self, name, n_views, train=True):
        """Return one split of ``name`` with a multi-view contrastive transform.

        Args:
            name: Dataset key, ``'cifar10'`` (image size 32) or ``'stl10'``
                (image size 96).
            n_views: Forwarded to ``ContrastiveLearningViewGenerator``.
            train: Selects the split. CIFAR10 receives it as ``train=``;
                STL10 receives ``split='train'`` or ``split='test'``.

        Returns:
            The constructed ``datasets.CIFAR10`` or ``datasets.STL10`` object,
            rooted at ``self.root_folder`` and created with ``download=True``.

        Raises:
            InvalidDatasetSelection: If ``name`` is not a supported dataset.
        """
        if name not in ('cifar10', 'stl10'):
            raise InvalidDatasetSelection()

        def make_views(image_size):
            # NOTE(review): the transform is always requested with
            # ``train=True``, independent of the ``train`` argument -- confirm
            # this is intended when the test split is selected.
            return ContrastiveLearningViewGenerator(
                get_simclr_transform(image_size, train=True), n_views)

        if name == 'cifar10':
            return datasets.CIFAR10(self.root_folder,
                                    train=train,
                                    transform=make_views(32),
                                    download=True)

        if train:
            split = 'train'
        else:
            split = 'test'
        return datasets.STL10(self.root_folder,
                              split=split,
                              transform=make_views(96),
                              download=True)
コード例 #9
0
    def get_dataset(self, name, n_views):
        """Build the contrastive-learning dataset registered under ``name``.

        Registered keys are ``'cifar10'``, ``'stl10'``, ``'mscoco'``,
        ``'mscocovalid'``, ``'mscocobaseline'`` and ``'mscocobaselinevalid'``.
        The MS-COCO entries read from ``<root_folder>/mscoco``: the plain keys
        use the ``train2017`` files and the ``*valid`` keys the ``val2017``
        files.

        Args:
            name: Dataset key, one of the names listed above.
            n_views: Forwarded to ``ContrastiveLearningViewGenerator``.

        Returns:
            The dataset object created by the selected builder.

        Raises:
            InvalidDatasetSelection: If ``name`` is not a registered key.
        """

        def make_views(image_size):
            # Wrap the SimCLR pipeline for the given size in a view generator.
            return ContrastiveLearningViewGenerator(
                self.get_simclr_pipeline_transform(image_size), n_views)

        def coco_path(*parts):
            # Path below the 'mscoco' directory of the root folder.
            return os.path.join(self.root_folder, 'mscoco', *parts)

        def coco_captioned(image_dir, captions_file, encodings_file):
            # Caption dataset: a sentence tokenizer is created on each call
            # and the encodings path is passed with generation switched off.
            return CocoDetection(
                coco_path(image_dir),
                annFile=coco_path('annotations', captions_file),
                tokenizer=SentenceTransformer('bert-base-nli-mean-tokens'),
                transform=make_views(96),
                encodingPath=coco_path(encodings_file),
                generateEncodings=False)

        def coco_baseline(image_dir, captions_file):
            # Baseline variant: no tokenizer and no encoding arguments.
            return CocoDetectionBaseline(
                coco_path(image_dir),
                annFile=coco_path('annotations', captions_file),
                transform=make_views(96))

        # Each value is a zero-argument builder, so only the dataset that is
        # actually requested gets constructed.
        builders = {
            'cifar10':
            lambda: datasets.CIFAR10(self.root_folder,
                                     train=True,
                                     transform=make_views(32),
                                     download=True),
            'stl10':
            lambda: datasets.STL10(self.root_folder,
                                   split='unlabeled',
                                   transform=make_views(96),
                                   download=True),
            'mscoco':
            lambda: coco_captioned('train2017',
                                   'captions_train2017.json',
                                   'encoded_captions_train2017.json'),
            'mscocovalid':
            lambda: coco_captioned('val2017',
                                   'captions_val2017.json',
                                   'encoded_captions_val2017.json'),
            'mscocobaseline':
            lambda: coco_baseline('train2017', 'captions_train2017.json'),
            'mscocobaselinevalid':
            lambda: coco_baseline('val2017', 'captions_val2017.json'),
        }

        try:
            build = builders[name]
        except KeyError:
            raise InvalidDatasetSelection()
        return build()