コード例 #1
0
 def imagenet(dataset_root, nb_classes, split: str = 'train'):
     """Build an ``UnlabeledImagenet`` dataset with the AMDIM ImageNet-128 train transform.

     The ``AMDIMTrainTransformsImageNet128`` pipeline is attached regardless of
     which split is requested.

     Args:
         dataset_root: root directory handed to ``UnlabeledImagenet``.
         nb_classes: forwarded to ``UnlabeledImagenet`` as ``nb_classes``.
         split: either ``'train'`` or ``'val'``.

     Returns:
         The constructed ``UnlabeledImagenet`` dataset.

     Raises:
         AssertionError: if ``split`` is not ``'train'`` or ``'val'`` (only
             while assertions are enabled, i.e. not under ``python -O``).
     """
     allowed_splits = ('train', 'val')
     assert split in allowed_splits
     amdim_tf = amdim_transforms.AMDIMTrainTransformsImageNet128()
     return UnlabeledImagenet(dataset_root, nb_classes=nb_classes, split=split, transform=amdim_tf)
コード例 #2
0
 def imagenet(dataset_root, nb_classes, patch_size, patch_overlap, split: str = 'train'):
     """Build an ``UnlabeledImagenet`` dataset that yields ImageNet-128 patches.

     Every requested split gets the same ``TransformsImageNet128Patches``
     pipeline, configured with ``patch_size`` and ``patch_overlap``.

     Args:
         dataset_root: root directory handed to ``UnlabeledImagenet``.
         nb_classes: forwarded to ``UnlabeledImagenet`` as ``nb_classes``.
         patch_size: forwarded to the patch transform as ``patch_size``.
         patch_overlap: forwarded to the patch transform as ``overlap``.
         split: either ``'train'`` or ``'val'``.

     Returns:
         The constructed ``UnlabeledImagenet`` dataset.

     Raises:
         AssertionError: if ``split`` is not ``'train'`` or ``'val'`` (only
             while assertions are enabled).
     """
     assert split in ('train', 'val')
     patch_tf = amdim_transforms.TransformsImageNet128Patches(patch_size=patch_size, overlap=patch_overlap)
     return UnlabeledImagenet(dataset_root, nb_classes=nb_classes, split=split, transform=patch_tf)
コード例 #3
0
    def val_dataloader(self, num_images_per_class=50, add_normalize=False):
        """Build the validation ``DataLoader``.

        Validation images are requested from ``UnlabeledImagenet`` with
        ``split='val'``. The per-class count is passed as
        ``num_imgs_per_class_val_split``.

        Args:
            num_images_per_class: forwarded to ``UnlabeledImagenet`` as
                ``num_imgs_per_class_val_split``.
            add_normalize: accepted for signature compatibility; not used by
                this method.

        Returns:
            A non-shuffling ``DataLoader`` over the validation dataset.
        """
        # Fall back to the module's default transforms when none were supplied.
        transforms = self._default_transforms(
        ) if self.val_transforms is None else self.val_transforms

        dataset = UnlabeledImagenet(
            self.data_dir,
            num_imgs_per_class_val_split=num_images_per_class,
            meta_dir=self.meta_dir,
            split='val',
            transform=transforms)
        loader = DataLoader(dataset,
                            batch_size=self.batch_size,
                            # fixed order so validation results are reproducible
                            shuffle=False,
                            num_workers=self.num_workers,
                            # honour the configured setting instead of forcing
                            # pinning, consistent with the training loader
                            pin_memory=self.pin_memory)
        return loader
コード例 #4
0
    def train_dataloader(self, num_images_per_class=-1, add_normalize=False):
        """Build the training ``DataLoader``.

        Training images are requested from ``UnlabeledImagenet`` with
        ``split='train'``. Shuffling, last-batch dropping and memory pinning all
        follow the module's ``shuffle``, ``drop_last`` and ``pin_memory``
        attributes.

        Args:
            num_images_per_class: forwarded to ``UnlabeledImagenet`` as
                ``num_imgs_per_class``.
            add_normalize: accepted for signature compatibility; not used by
                this method.

        Returns:
            A ``DataLoader`` over the training dataset.
        """
        if self.train_transforms is None:
            train_tf = self._default_transforms()
        else:
            train_tf = self.train_transforms

        train_set = UnlabeledImagenet(
            self.data_dir,
            num_imgs_per_class=num_images_per_class,
            meta_dir=self.meta_dir,
            split='train',
            transform=train_tf,
        )
        return DataLoader(
            train_set,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            drop_last=self.drop_last,
            pin_memory=self.pin_memory,
        )
コード例 #5
0
    def test_dataloader(self):
        """Build the test ``DataLoader``.

        Requests ``split='test'`` from ``UnlabeledImagenet`` with
        ``num_imgs_per_class=-1`` (presumably "no per-class limit"; confirm in
        ``UnlabeledImagenet``). This split is intended to be the ImageNet-2012
        validation set. That mapping is not done here; it presumably happens
        inside ``UnlabeledImagenet``.

        When ``test_transforms`` is unset, ``self.val_transform()`` is used.

        Returns:
            A non-shuffling ``DataLoader`` covering every test sample.
        """
        transforms = self.val_transform(
        ) if self.test_transforms is None else self.test_transforms

        dataset = UnlabeledImagenet(self.data_dir,
                                    num_imgs_per_class=-1,
                                    meta_dir=self.meta_dir,
                                    split='test',
                                    transform=transforms)
        loader = DataLoader(dataset,
                            batch_size=self.batch_size,
                            shuffle=False,
                            num_workers=self.num_workers,
                            # Keep the final partial batch: dropping it would
                            # silently exclude up to batch_size - 1 test images
                            # from evaluation.
                            drop_last=False,
                            pin_memory=True)
        return loader
コード例 #6
0
    def train_dataloader(self):
        """Build the training ``DataLoader`` from the ImageNet-2012 train split.

        Part of the train split is reserved for validation. The hold-out is
        presumably performed inside ``UnlabeledImagenet`` (confirm); this method
        only requests ``split='train'`` with ``num_imgs_per_class=-1``.

        When ``train_transforms`` is unset, ``self.train_transform()`` is used.

        Returns:
            A shuffling ``DataLoader`` that drops the last incomplete batch and
            pins memory.
        """
        if self.train_transforms is not None:
            train_tf = self.train_transforms
        else:
            train_tf = self.train_transform()

        train_set = UnlabeledImagenet(
            self.data_dir,
            num_imgs_per_class=-1,
            meta_dir=self.meta_dir,
            split='train',
            transform=train_tf,
        )
        return DataLoader(
            train_set,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True,
            pin_memory=True,
        )
コード例 #7
0
    def val_dataloader(self):
        """Build the validation ``DataLoader``.

        Uses the part of the ImageNet-2012 train split that was not used for
        training. The number of held-out images per class is
        ``self.num_imgs_per_val_class``, passed to ``UnlabeledImagenet`` as
        ``num_imgs_per_class_val_split``.

        When ``val_transforms`` is unset, ``self.val_transform()`` is used.

        Returns:
            A non-shuffling ``DataLoader`` over the validation images.
        """
        # Validation must use the evaluation transform, not the training one
        # (the training pipeline presumably applies random augmentation, which
        # would make validation metrics noisy and non-reproducible).
        transforms = self.val_transform(
        ) if self.val_transforms is None else self.val_transforms

        dataset = UnlabeledImagenet(
            self.data_dir,
            num_imgs_per_class_val_split=self.num_imgs_per_val_class,
            meta_dir=self.meta_dir,
            split='val',
            transform=transforms)
        loader = DataLoader(dataset,
                            batch_size=self.batch_size,
                            shuffle=False,
                            num_workers=self.num_workers,
                            pin_memory=True)
        return loader