def val_dataloader(self) -> DataLoader:
        """Build the validation DataLoader from the held-out part of the imagenet2012 train split.

        The images per class that were withheld from training via
        ``num_imgs_per_val_class`` are served here (``UnlabeledImagenet`` with
        ``split="val"``). This method takes no arguments: batch size, worker
        count, ``drop_last`` and ``pin_memory`` are read from the instance.

        ``self.val_transforms`` is used when it is set; otherwise the
        transforms returned by ``self.val_transform()`` are applied.

        Returns:
            DataLoader: an unshuffled loader over the validation images.
        """
        # User-supplied transforms take precedence over the module's defaults.
        transforms = self.val_transform(
        ) if self.val_transforms is None else self.val_transforms

        dataset = UnlabeledImagenet(
            self.data_dir,
            num_imgs_per_class_val_split=self.num_imgs_per_val_class,
            meta_dir=self.meta_dir,
            split="val",
            transform=transforms,
        )
        loader: DataLoader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,  # keep a fixed order for evaluation
            num_workers=self.num_workers,
            drop_last=self.drop_last,
            pin_memory=self.pin_memory,
        )
        return loader
예제 #2
0
 def imagenet(dataset_root, nb_classes, split: str = 'train'):
     """Build an unlabeled ImageNet dataset with the AMDIM 128px train transforms.

     Args:
         dataset_root: root directory, forwarded to ``UnlabeledImagenet``.
         nb_classes: forwarded to ``UnlabeledImagenet`` as ``nb_classes``.
         split: either ``'train'`` or ``'val'``.

     Returns:
         The constructed ``UnlabeledImagenet`` dataset.

     Raises:
         ValueError: if ``split`` is not ``'train'`` or ``'val'``.
     """
     # Explicit check rather than ``assert`` so validation survives ``python -O``.
     if split not in ('train', 'val'):
         raise ValueError(
             "split must be 'train' or 'val', got {!r}".format(split))
     # NOTE(review): the train-time transform is applied for both splits,
     # including 'val' — confirm this is intended.
     dataset = UnlabeledImagenet(
         dataset_root,
         nb_classes=nb_classes,
         split=split,
         transform=amdim_transforms.AMDIMTrainTransformsImageNet128(),
     )
     return dataset
예제 #3
0
 def imagenet(dataset_root,
              nb_classes,
              patch_size,
              patch_overlap,
              split: str = 'train'):
     """Build an unlabeled ImageNet dataset that yields 128px image patches.

     Args:
         dataset_root: root directory, forwarded to ``UnlabeledImagenet``.
         nb_classes: forwarded to ``UnlabeledImagenet`` as ``nb_classes``.
         patch_size: forwarded to ``TransformsImageNet128Patches``.
         patch_overlap: forwarded to ``TransformsImageNet128Patches`` as
             ``overlap``.
         split: either ``'train'`` or ``'val'``.

     Returns:
         The constructed ``UnlabeledImagenet`` dataset.

     Raises:
         ValueError: if ``split`` is not ``'train'`` or ``'val'``.
     """
     # Explicit check rather than ``assert`` so validation survives ``python -O``.
     if split not in ('train', 'val'):
         raise ValueError(
             "split must be 'train' or 'val', got {!r}".format(split))
     # NOTE(review): this patch transform is used for both splits,
     # including 'val' — confirm this is intended.
     train_transform = amdim_transforms.TransformsImageNet128Patches(
         patch_size=patch_size, overlap=patch_overlap)
     dataset = UnlabeledImagenet(
         dataset_root,
         nb_classes=nb_classes,
         split=split,
         transform=train_transform,
     )
     return dataset
    def train_dataloader(self,
                         num_images_per_class: int = -1,
                         add_normalize: bool = False) -> DataLoader:
        """Create the training DataLoader over the imagenet2012 train split.

        Args:
            num_images_per_class: forwarded to ``UnlabeledImagenet`` as
                ``num_imgs_per_class`` (``-1`` by default; presumably "no
                limit" — confirm against ``UnlabeledImagenet``).
            add_normalize: accepted for interface compatibility only; this
                implementation never reads it.

        Returns:
            DataLoader: a loader whose shuffling follows ``self.shuffle``.
        """
        if self.train_transforms is not None:
            transforms = self.train_transforms
        else:
            # Fall back to the module's own default pipeline.
            transforms = self._default_transforms()

        train_set = UnlabeledImagenet(
            self.data_dir,
            num_imgs_per_class=num_images_per_class,
            meta_dir=self.meta_dir,
            split='train',
            transform=transforms,
        )
        return DataLoader(
            train_set,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            drop_last=self.drop_last,
            pin_memory=self.pin_memory,
        )
예제 #5
0
    def test_dataloader(self) -> DataLoader:
        """Create the test DataLoader, backed by the imagenet2012 validation split.

        ``self.test_transforms`` is used when set; otherwise the transforms
        from ``self.val_transform()`` are applied. The dataset is requested
        with ``split='test'`` and ``num_imgs_per_class=-1``, and the loader is
        never shuffled.
        """
        if self.test_transforms is not None:
            transforms = self.test_transforms
        else:
            transforms = self.val_transform()

        test_set = UnlabeledImagenet(
            self.data_dir,
            num_imgs_per_class=-1,
            meta_dir=self.meta_dir,
            split='test',
            transform=transforms,
        )
        return DataLoader(
            test_set,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=self.drop_last,
            pin_memory=self.pin_memory,
        )
예제 #6
0
    def train_dataloader(self) -> DataLoader:
        """Create the training DataLoader over the imagenet2012 train split.

        The dataset is built with
        ``num_imgs_per_class_val_split=self.num_imgs_per_val_class`` so that a
        portion of the train split is put away for validation.

        Returns:
            DataLoader: a loader whose shuffling follows ``self.shuffle``.
        """
        if self.train_transforms is None:
            transforms = self.train_transform()
        else:
            transforms = self.train_transforms

        train_set = UnlabeledImagenet(
            self.data_dir,
            num_imgs_per_class=-1,
            num_imgs_per_class_val_split=self.num_imgs_per_val_class,
            meta_dir=self.meta_dir,
            split='train',
            transform=transforms,
        )
        return DataLoader(
            train_set,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            drop_last=self.drop_last,
            pin_memory=self.pin_memory,
        )