def _generate_artificial_anomalies_train_set(self, supervise_mode: str, noise_mode: str, oe_limit: int, train_set: Dataset, nom_class: int): """ This method generates offline artificial anomalies, i.e. it generates them once at the start of the training and adds them to the training set. It creates a balanced dataset, thus sampling as many anomalies as there are nominal samples. This is way faster than online generation, but lacks diversity (hence usually weaker performance). :param supervise_mode: the type of generated artificial anomalies. unsupervised: no anomalies, returns a subset of the original dataset containing only nominal samples. other: other classes, i.e. all the true anomalies! noise: pure noise images (can also be outlier exposure based). malformed_normal: add noise to nominal samples to create malformed nominal anomalies. malformed_normal_gt: like malformed_normal, but also creates artificial ground-truth maps that mark pixels anomalous where the difference between the original nominal sample and the malformed one is greater than a low threshold. :param noise_mode: the type of noise used, see :mod:`fcdd.datasets.noise_mode`. :param oe_limit: the number of different outlier exposure samples used in case of outlier exposure based noise. :param train_set: the training set that is to be extended with artificial anomalies. :param nom_class: the class considered nominal :return: """ if isinstance(train_set.targets, torch.Tensor): dataset_targets = train_set.targets.clone().data.cpu().numpy() else: # e.g. 
imagenet dataset_targets = np.asarray(train_set.targets) train_idx_normal = get_target_label_idx(dataset_targets, self.normal_classes) generated_noise = norm = None if supervise_mode not in ['unsupervised', 'other']: self.logprint('Generating artificial anomalies...') generated_noise = self._generate_noise( noise_mode, train_set.data[train_idx_normal].shape, oe_limit, self.root ) norm = train_set.data[train_idx_normal] if supervise_mode in ['other']: self._train_set = train_set elif supervise_mode in ['unsupervised']: if isinstance(train_set, GTMapADDataset): self._train_set = GTSubset(train_set, train_idx_normal) else: self._train_set = Subset(train_set, train_idx_normal) elif supervise_mode in ['noise']: self._train_set = apply_noise(self.outlier_classes, generated_noise, norm, nom_class, train_set) elif supervise_mode in ['malformed_normal']: self._train_set = apply_malformed_normal(self.outlier_classes, generated_noise, norm, nom_class, train_set) elif supervise_mode in ['malformed_normal_gt']: train_set, gtmaps = apply_malformed_normal( self.outlier_classes, generated_noise, norm, nom_class, train_set, gt=True ) self._train_set = GTMapADDatasetExtension(train_set, gtmaps) else: raise NotImplementedError('Supervise mode {} unknown.'.format(supervise_mode)) if supervise_mode not in ['unsupervised', 'other']: self.logprint('Artificial anomalies generated.')
def __init__(self, root: str, normal_class: int, preproc: str, nominal_label: int,
             supervise_mode: str, noise_mode: str, oe_limit: int, online_supervision: bool,
             logger: Logger = None):
    """
    AD dataset for ImageNet. Following Hendrycks et al. (https://arxiv.org/abs/1812.04606) this AD dataset
    is limited to 30 of the 1000 classes of Imagenet (see :attr:`ADImageNet.ad_classes`).

    :param root: root directory where data is found or is to be downloaded to
    :param normal_class: the class considered nominal (an index into the 30 AD classes)
    :param preproc: the kind of preprocessing pipeline ('', None, 'default', 'none' or 'aug1')
    :param nominal_label: the label that marks nominal samples in training. The scores in the heatmaps always
        rate label 1, thus usually the nominal label is 0, s.t. the scores are anomaly scores.
    :param supervise_mode: the type of generated artificial anomalies. Only 'unsupervised', 'other' and 'noise'
        are supported here.
        See :meth:`fcdd.datasets.bases.TorchvisionDataset._generate_artificial_anomalies_train_set`.
    :param noise_mode: the type of noise used, see :mod:`fcdd.datasets.noise_mode`.
    :param oe_limit: limits the number of different anomalies in case of Outlier Exposure (defined in noise_mode)
    :param online_supervision: whether to sample anomalies online in each epoch,
        or offline before training (same for all epochs in this case). Must be True for ImageNet.
    :param logger: logger
    :raises ValueError: if ``preproc`` is unknown.
    """
    assert online_supervision, 'ImageNet artificial anomaly generation needs to be online'
    assert supervise_mode in ['unsupervised', 'other', 'noise'], \
        'Noise mode "malformed_normal" is not supported for ImageNet because nominal images are loaded ' \
        'only if not replaced by some artificial anomaly (to speedup data preprocessing).'

    root = pt.join(root, self.base_folder)
    super().__init__(root, logger=logger)

    self.n_classes = 2  # 0: normal, 1: outlier
    self.shape = (3, 224, 224)
    self.raw_shape = (3, 256, 256)
    self.normal_classes = tuple([normal_class])
    self.outlier_classes = list(range(0, 30))
    self.outlier_classes.remove(normal_class)
    assert nominal_label in [0, 1]
    self.nominal_label = nominal_label
    self.anomalous_label = 1 if self.nominal_label == 0 else 0

    if self.nominal_label != 0:
        self.logprint('Swapping labels, i.e. anomalies are 0 and nominals are 1.')

    # Channel-wise mean and std (one entry per RGB channel, matching self.shape[0] == 3).
    # They are shared by all classes, so they must NOT be indexed with normal_class.
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    # different types of preprocessing pipelines, here just choose whether to use augmentations
    if preproc in ['', None, 'default', 'none']:
        test_transform = transform = transforms.Compose([
            transforms.Resize((self.shape[-2], self.shape[-1])),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
    elif preproc in ['aug1']:
        test_transform = transforms.Compose([
            transforms.Resize(self.raw_shape[-1]),
            transforms.CenterCrop(self.shape[-1]),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        transform = transforms.Compose([
            transforms.Resize(self.raw_shape[-1]),
            transforms.ColorJitter(brightness=0.01, contrast=0.01, saturation=0.01, hue=0.01),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(self.shape[-1]),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x + 0.001 * torch.randn_like(x)),
            transforms.Normalize(mean, std)
        ])
    else:
        raise ValueError('Preprocessing pipeline {} is not known.'.format(preproc))

    # maps a class index to the binary AD label (outlier classes -> anomalous label, else nominal label)
    target_transform = transforms.Lambda(
        lambda x: self.anomalous_label if x in self.outlier_classes else self.nominal_label
    )
    if supervise_mode not in ['unsupervised', 'other']:
        all_transform = OnlineSupervisor(self, supervise_mode, noise_mode, oe_limit)
    else:
        all_transform = None

    train_set = MyImageNet(
        root=self.root, split='train', normal_classes=self.normal_classes,
        transform=transform, target_transform=target_transform, all_transform=all_transform,
        logger=logger
    )
    self.train_ad_classes_idx = train_set.get_class_idx(self.ad_classes)
    train_set.targets = [  # t = nan if not in ad_classes else give id in order of ad_classes
        self.train_ad_classes_idx.index(t) if t in self.train_ad_classes_idx else np.nan
        for t in train_set.targets
    ]
    # anomalies are produced online by the OnlineSupervisor, so offline generation only
    # keeps the nominal samples here
    self._generate_artificial_anomalies_train_set(
        'unsupervised', noise_mode, oe_limit, train_set, normal_class,  # gets rid of true anomalous samples
    )

    self._test_set = MyImageNet(
        root=self.root, split='val', normal_classes=self.normal_classes,
        transform=test_transform, target_transform=target_transform,
        logger=logger
    )
    self.test_ad_classes_idx = self._test_set.get_class_idx(self.ad_classes)
    self._test_set.targets = [  # t = nan if not in ad_classes else give id in order of ad_classes
        self.test_ad_classes_idx.index(t) if t in self.test_ad_classes_idx else np.nan
        for t in self._test_set.targets
    ]
    # restrict the test set to samples of the AD classes (ids 0 .. len(ad_classes) - 1)
    self._test_set = Subset(
        self._test_set,
        get_target_label_idx(np.asarray(self._test_set.targets), list(range(len(self.ad_classes))))
    )
    self._test_set.fixed_random_order = MyImageNet.fixed_random_order
def __init__(self, root: str, normal_class: int, preproc: str, nominal_label: int,
             supervise_mode: str, noise_mode: str, oe_limit: int, online_supervision: bool,
             logger: Logger = None):
    """
    General-purpose AD dataset for custom image data stored in class folders.

    Two layouts are supported, selected by :attr:`ovr`:

    (1) one-vs-rest (ovr): one class is nominal and is tested against all other classes,
        which are treated as anomalous. Folder layout::

            root/custom/train/dog/xxx.png
            root/custom/train/dog/xxy.png
            root/custom/train/dog/xxz.png

            root/custom/train/cat/123.png
            root/custom/train/cat/nsdf3.png
            root/custom/train/cat/asd932_.png

    (2) general: every class folder has its own normal and anomalous sub-folders. Folder layout::

            root/custom/train/hazelnut/normal/xxx.png
            root/custom/train/hazelnut/normal/xxy.png
            root/custom/train/hazelnut/normal/xxz.png
            root/custom/train/hazelnut/anomalous/xxa.png    -- may be used during training for a semi-supervised setting

            root/custom/train/screw/normal/123.png
            root/custom/train/screw/normal/nsdf3.png
            root/custom/train/screw/anomalous/asd932_.png   -- may be used during training for a semi-supervised setting

    The test data follows the same structure with "train" replaced by "test".

    :param root: root directory where data is found.
    :param normal_class: the class considered nominal.
    :param preproc: the kind of preprocessing pipeline ('', None, 'default', 'none' or 'aug1').
    :param nominal_label: the label that marks nominal samples in training. The scores in the heatmaps always
        rate label 1, thus usually the nominal label is 0, s.t. the scores are anomaly scores.
    :param supervise_mode: the type of generated artificial anomalies.
        See :meth:`fcdd.datasets.bases.TorchvisionDataset._generate_artificial_anomalies_train_set`.
    :param noise_mode: the type of noise used, see :mod:`fcdd.datasets.noise_mode`.
    :param oe_limit: limits the number of different anomalies in case of Outlier Exposure (defined in noise_mode).
    :param online_supervision: whether to sample anomalies online in each epoch,
        or offline before training (same for all epochs in this case). Must be True here.
    :param logger: logger.
    :raises ValueError: if ``preproc`` is unknown.
    """
    assert online_supervision, 'Artificial anomaly generation for custom datasets needs to be online'
    self.trainpath = pt.join(root, self.base_folder, 'train')
    self.testpath = pt.join(root, self.base_folder, 'test')
    super().__init__(root, logger=logger)

    self.n_classes = 2  # binary task -- 0: normal, 1: outlier
    self.raw_shape = (3, 248, 248)
    # channels x height x width of a sample after image preprocessing
    self.shape = (3, 224, 224)
    self.normal_classes = (normal_class,)
    self.outlier_classes = list(range(len(extract_custom_classes(root))))
    self.outlier_classes.remove(normal_class)
    assert nominal_label in [0, 1]
    self.nominal_label = nominal_label
    self.anomalous_label = int(self.nominal_label == 0)

    # mean and std of the training data of the nominal class
    self.mean, self.std = self.extract_mean_std(self.trainpath, normal_class)

    def _normalize():
        # a fresh Normalize step using the dataset statistics computed above
        return transforms.Normalize(self.mean, self.std)

    if preproc in ('', None, 'default', 'none'):
        transform = transforms.Compose([
            transforms.Resize((self.shape[-2], self.shape[-1])),
            transforms.ToTensor(),
            _normalize(),
        ])
        test_transform = transform
    elif preproc in ('aug1',):
        test_transform = transforms.Compose([
            transforms.Resize(self.raw_shape[-1]),
            transforms.CenterCrop(self.shape[-1]),
            transforms.ToTensor(),
            _normalize(),
        ])
        transform = transforms.Compose([
            transforms.Resize(self.raw_shape[-1]),
            transforms.ColorJitter(brightness=0.01, contrast=0.01, saturation=0.01, hue=0.01),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(self.shape[-1]),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x + 0.001 * torch.randn_like(x)),
            _normalize(),
        ])
    # here you could define other pipelines with augmentations
    else:
        raise ValueError('Preprocessing pipeline {} is not known.'.format(preproc))

    def _to_ad_label(cls_idx):
        # outlier classes get the anomalous label, everything else the nominal one
        return self.anomalous_label if cls_idx in self.outlier_classes else self.nominal_label

    self.target_transform = transforms.Lambda(_to_ad_label)
    self.all_transform = (
        OnlineSupervisor(self, supervise_mode, noise_mode, oe_limit)
        if supervise_mode not in ['unsupervised', 'other'] else None
    )

    self._train_set = ImageFolderDataset(
        self.trainpath, supervise_mode, self.raw_shape, self.ovr, self.nominal_label, self.anomalous_label,
        normal_classes=self.normal_classes, transform=transform, target_transform=self.target_transform,
        all_transform=self.all_transform,
    )
    if supervise_mode == 'other':  # (semi)-supervised setting
        self.balance_dataset()
    else:
        # keep only samples that are labeled nominal AND belong to the normal class
        is_nominal = np.asarray(self._train_set.anomaly_labels) == self.nominal_label
        in_normal_class = np.isin(self._train_set.targets, self.normal_classes)
        keep_idx = np.argwhere(np.logical_and(is_nominal, in_normal_class)).flatten().tolist()
        self._train_set = Subset(self._train_set, keep_idx)

    self._test_set = ImageFolderDataset(
        self.testpath, supervise_mode, self.raw_shape, self.ovr, self.nominal_label, self.anomalous_label,
        normal_classes=self.normal_classes, transform=test_transform, target_transform=self.target_transform,
    )
    if not self.ovr:
        normal_test_idx = get_target_label_idx(self._test_set.targets, np.asarray(self.normal_classes))
        self._test_set = Subset(self._test_set, normal_test_idx)
def __init__(self, root: str, normal_class: int, preproc: str, nominal_label: int,
             supervise_mode: str, noise_mode: str, oe_limit: int, online_supervision: bool,
             logger: Logger = None, raw_shape: int = 240):
    """
    MVTec-AD anomaly detection dataset.

    When the MVTec data is not present in ``root`` it is downloaded and processed into torch tensors of the
    size given by ``raw_shape``, which speeds up data loading at the start of training.

    :param root: root directory where data is found or is to be downloaded to
    :param normal_class: the class considered nominal (an index in 0..14)
    :param preproc: the preprocessing pipeline: 'lcn', '' / None / 'default' / 'none', 'aug1' or 'lcnaug1'
    :param nominal_label: the label that marks nominal samples in training. The scores in the heatmaps always
        rate label 1, thus usually the nominal label is 0, s.t. the scores are anomaly scores.
    :param supervise_mode: the type of generated artificial anomalies.
        See :meth:`fcdd.datasets.bases.TorchvisionDataset._generate_artificial_anomalies_train_set`.
    :param noise_mode: the type of noise used, see :mod:`fcdd.datasets.noise_mode`.
    :param oe_limit: limits the number of different anomalies in case of Outlier Exposure (defined in noise_mode)
    :param online_supervision: whether to sample anomalies online in each epoch (via an OnlineSupervisor),
        or offline once before training (same for all epochs in this case).
    :param logger: logger
    :param raw_shape: height and width of the raw MVTec images before they pass the preprocessing pipeline.
        The non-augmenting pipelines ('lcn', default) require this to equal the network input size (224).
    :raises ValueError: if ``preproc`` is unknown.
    """
    super().__init__(root, logger=logger)

    self.n_classes = 2  # 0: normal, 1: outlier
    self.shape = (3, 224, 224)
    self.raw_shape = (3, raw_shape, raw_shape)
    self.normal_classes = (normal_class,)
    self.outlier_classes = list(range(15))
    self.outlier_classes.remove(normal_class)
    assert nominal_label in [0, 1], 'GT maps are required to be binary!'
    self.nominal_label = nominal_label
    self.anomalous_label = int(self.nominal_label == 0)

    # Per-class (channel-wise min, channel-wise max), used to rescale the output of
    # local_contrast_normalization(x, scale='l1') in the LCN pipelines.
    # (The original comment called this "gcn l1 norm"; values are precomputed -- NOTE(review): verify.)
    min_max_l1 = [
        [(-1.3336724042892456, -1.3107913732528687, -1.2445921897888184),
         (1.3779616355895996, 1.3779616355895996, 1.3779616355895996)],
        [(-2.2404820919036865, -2.3387579917907715, -2.2896201610565186),
         (4.573435306549072, 4.573435306549072, 4.573435306549072)],
        [(-3.184587001800537, -3.164201259613037, -3.1392977237701416),
         (1.6995097398757935, 1.6011602878570557, 1.5209171772003174)],
        [(-3.0334954261779785, -2.958242416381836, -2.7701096534729004),
         (6.503103256225586, 5.875098705291748, 5.814228057861328)],
        [(-3.100773334503174, -3.100773334503174, -3.100773334503174),
         (4.27892541885376, 4.27892541885376, 4.27892541885376)],
        [(-3.6565306186676025, -3.507692813873291, -2.7635035514831543),
         (18.966819763183594, 21.64590072631836, 26.408710479736328)],
        [(-1.5192601680755615, -2.2068002223968506, -2.3948357105255127),
         (11.564697265625, 10.976534843444824, 10.378695487976074)],
        [(-1.3207964897155762, -1.2889339923858643, -1.148416519165039),
         (6.854909896850586, 6.854909896850586, 6.854909896850586)],
        [(-0.9883341193199158, -0.9822461605072021, -0.9288841485977173),
         (2.290637969970703, 2.4007883071899414, 2.3044068813323975)],
        [(-7.236185073852539, -7.236185073852539, -7.236185073852539),
         (3.3777384757995605, 3.3777384757995605, 3.3777384757995605)],
        [(-3.2036616802215576, -3.221003532409668, -3.305514335632324),
         (7.022546768188477, 6.115569114685059, 6.310940742492676)],
        [(-0.8915618658065796, -0.8669204115867615, -0.8002046346664429),
         (4.4255571365356445, 4.642300128936768, 4.305730819702148)],
        [(-1.9086798429489136, -2.0004451274871826, -1.929288387298584),
         (5.463134765625, 5.463134765625, 5.463134765625)],
        [(-2.9547364711761475, -3.17536997795105, -3.143850803375244),
         (5.305514812469482, 4.535006523132324, 3.3618252277374268)],
        [(-1.2906527519226074, -1.2906527519226074, -1.2906527519226074),
         (2.515115737915039, 2.515115737915039, 2.515115737915039)]
    ]

    # channel-wise mean and std of the original images, one entry per class
    mean = [
        (0.53453129529953, 0.5307118892669678, 0.5491130352020264),
        (0.326835036277771, 0.41494372487068176, 0.46718254685401917),
        (0.6953922510147095, 0.6663950085639954, 0.6533040404319763),
        (0.36377236247062683, 0.35087138414382935, 0.35671544075012207),
        (0.4484519958496094, 0.4484519958496094, 0.4484519958496094),
        (0.2390524297952652, 0.17620408535003662, 0.17206747829914093),
        (0.3919542133808136, 0.2631213963031769, 0.22006843984127045),
        (0.21368788182735443, 0.23478130996227264, 0.24079132080078125),
        (0.30240726470947266, 0.3029524087905884, 0.32861486077308655),
        (0.7099748849868774, 0.7099748849868774, 0.7099748849868774),
        (0.4567880630493164, 0.4711957275867462, 0.4482630491256714),
        (0.19987481832504272, 0.18578395247459412, 0.19361256062984467),
        (0.38699793815612793, 0.276934415102005, 0.24219433963298798),
        (0.6718143820762634, 0.47696375846862793, 0.35050269961357117),
        (0.4014520049095154, 0.4014520049095154, 0.4014520049095154)
    ]
    std = [
        (0.3667600452899933, 0.3666728734970093, 0.34991779923439026),
        (0.15321789681911469, 0.21510766446590424, 0.23905669152736664),
        (0.23858436942100525, 0.2591284513473511, 0.2601949870586395),
        (0.14506031572818756, 0.13994529843330383, 0.1276693195104599),
        (0.1636597216129303, 0.1636597216129303, 0.1636597216129303),
        (0.1688646823167801, 0.07597383111715317, 0.04383210837841034),
        (0.06069392338395119, 0.04061736911535263, 0.0303945429623127),
        (0.1602524220943451, 0.18222476541996002, 0.15336430072784424),
        (0.30409011244773865, 0.30411985516548157, 0.28656429052352905),
        (0.1337062269449234, 0.1337062269449234, 0.1337062269449234),
        (0.12076705694198608, 0.13341768085956573, 0.12879984080791473),
        (0.22920562326908112, 0.21501320600509644, 0.19536510109901428),
        (0.20621345937252045, 0.14321941137313843, 0.11695228517055511),
        (0.08259467780590057, 0.06751163303852081, 0.04756828024983406),
        (0.32304847240448, 0.32304847240448, 0.32304847240448)
    ]

    # Factories for the pipeline pieces; each call builds fresh transform objects and
    # indexes the per-class tables only when the chosen pipeline actually needs them.
    def _mean_std_normalize():
        return transforms.Normalize(mean[normal_class], std[normal_class])

    def _lcn_steps():
        lcn_min, lcn_max = min_max_l1[normal_class]
        return [
            transforms.Lambda(lambda x: local_contrast_normalization(x, scale='l1')),
            transforms.Normalize(lcn_min, [ma - mi for ma, mi in zip(lcn_max, lcn_min)]),
        ]

    def _train_img_gt_augmentation():
        # joint (image + ground-truth map) augmentation: random crop or resize to the net input size
        return MultiCompose([
            transforms.RandomChoice([
                transforms.RandomCrop(self.shape[-1], padding=0),
                transforms.Resize(self.shape[-1], Image.NEAREST),
            ]),
            transforms.ToTensor(),
        ])

    def _test_img_gt_resize():
        return MultiCompose([transforms.Resize(self.shape[-1], Image.NEAREST), transforms.ToTensor()])

    def _photometric_augmentation_steps():
        # random color jitter plus, with probability 1/2, additive gaussian noise scaled by the image std
        return [
            transforms.ToPILImage(),
            transforms.RandomChoice([
                transforms.ColorJitter(0.04, 0.04, 0.04, 0.04),
                transforms.ColorJitter(0.005, 0.0005, 0.0005, 0.0005),
            ]),
            transforms.ToTensor(),
            transforms.Lambda(
                lambda x: (x + torch.randn_like(x).mul(np.random.randint(0, 2)).mul(x.std()).mul(0.1)).clamp(0, 1)
            ),
        ]

    # different types of preprocessing pipelines, 'lcn' is for using LCN, 'aug{X}' for augmentations
    img_gt_transform, img_gt_test_transform = None, None
    extra_all_transforms = []  # joint transforms applied before the online supervisor (currently none)
    if preproc == 'lcn':
        assert self.raw_shape == self.shape, 'in case of no augmentation, raw shape needs to fit net input shape'
        img_gt_transform = img_gt_test_transform = MultiCompose([transforms.ToTensor()])
        test_transform = transform = transforms.Compose(_lcn_steps())
    elif preproc in ('', None, 'default', 'none'):
        assert self.raw_shape == self.shape, 'in case of no augmentation, raw shape needs to fit net input shape'
        img_gt_transform = img_gt_test_transform = MultiCompose([transforms.ToTensor()])
        test_transform = transform = transforms.Compose([_mean_std_normalize()])
    elif preproc in ('aug1',):
        img_gt_transform = _train_img_gt_augmentation()
        img_gt_test_transform = _test_img_gt_resize()
        test_transform = transforms.Compose([_mean_std_normalize()])
        transform = transforms.Compose(_photometric_augmentation_steps() + [_mean_std_normalize()])
    elif preproc in ('lcnaug1',):
        img_gt_transform = _train_img_gt_augmentation()
        img_gt_test_transform = _test_img_gt_resize()
        test_transform = transforms.Compose(_lcn_steps())
        transform = transforms.Compose(_photometric_augmentation_steps() + _lcn_steps())
    else:
        raise ValueError('Preprocessing pipeline {} is not known.'.format(preproc))

    target_transform = transforms.Lambda(
        lambda x: self.anomalous_label if x in self.outlier_classes else self.nominal_label
    )

    if online_supervision:
        # order: target_transform -> all_transform -> img_gt transform -> transform
        assert supervise_mode not in ['supervised'], 'supervised mode works only offline'
        train_all_transform = MultiCompose([
            *extra_all_transforms,
            OnlineSupervisor(self, supervise_mode, noise_mode, oe_limit),
        ])
    else:
        train_all_transform = MultiCompose(list(extra_all_transforms)) if extra_all_transforms else None

    train_set = MvTec(
        root=self.root, split='train', download=True,
        target_transform=target_transform, all_transform=train_all_transform,
        img_gt_transform=img_gt_transform, transform=transform,
        shape=self.raw_shape, normal_classes=self.normal_classes,
        nominal_label=self.nominal_label, anomalous_label=self.anomalous_label,
        enlarge=ADMvTec.enlarge
    )
    if online_supervision:
        # anomalies are injected on the fly, so training only keeps samples of the normal class
        self._train_set = GTSubset(
            train_set, get_target_label_idx(train_set.targets.clone().data.cpu().numpy(), self.normal_classes)
        )

    test_set = MvTec(
        root=self.root, split='test_anomaly_label_target', download=True,
        target_transform=transforms.Lambda(
            lambda x: self.anomalous_label if x != MvTec.normal_anomaly_label_idx else self.nominal_label
        ),
        img_gt_transform=img_gt_test_transform, transform=test_transform,
        shape=self.raw_shape, normal_classes=self.normal_classes,
        nominal_label=self.nominal_label, anomalous_label=self.anomalous_label,
        enlarge=False
    )
    test_idx_normal = get_target_label_idx(test_set.targets.clone().data.cpu().numpy(), self.normal_classes)
    self._test_set = GTSubset(test_set, test_idx_normal)

    if not online_supervision:
        # offline: generate the artificial anomalies once; this sets self._train_set
        self._generate_artificial_anomalies_train_set(supervise_mode, noise_mode, oe_limit, train_set, normal_class)
def __init__(self, size: torch.Size, clsses: List[int], root: str = None, limit_var: int = np.inf,
             limit_per_anomaly=True, download=True, logger: Logger = None, gt=False, remove_nominal=True):
    """
    Outlier Exposure dataset for MVTec-AD. Considers only a part of the classes.

    :param size: size of the samples in n x c x h x w, samples will be resized to h x w (h must equal w).
        If fewer than n samples remain after filtering (``len(self) < n``), a NotImplementedError is raised;
        enlarging the dataset by repetition is not implemented. Exactly n images are extracted per iteration
        of the data_loader, so for online supervision n should be set to 1 because only one sample is
        extracted at a time.
    :param clsses: the classes that are to be considered, i.e. all other classes are dismissed.
    :param root: root directory where data is found or is to be downloaded to.
    :param limit_var: limits the number of different samples, i.e. randomly chooses limit_var many samples
        from all available ones to be the training data. Defaults to ``np.inf`` (no limit); None disables
        the limit as well.
    :param limit_per_anomaly: whether limit_var limits the number of different samples per type of defect
        or overall.
    :param download: whether to download the data if it is not found in root.
    :param logger: logger.
    :param gt: whether ground-truth maps are to be included in the data.
    :param remove_nominal: whether nominal samples are to be excluded from the data.
    :raises NotImplementedError: if fewer than ``size[0]`` samples are available.
    """
    assert len(size) == 4 and size[2] == size[3]
    assert size[1] in [1, 3]
    self.root = root
    self.logger = logger
    self.size = size
    self.use_gt = gt
    self.clsses = clsses
    super().__init__(root, 'test', download=download, shape=size[1:], logger=logger)

    self.img_gt_transform = MultiCompose([transforms.Resize((size[2], size[2])), transforms.ToTensor()])
    self.picks = get_target_label_idx(self.targets, self.clsses)

    # NOTE: .flatten() instead of .squeeze(): nonzero() yields an (N, 1) tensor, and squeezing a
    # single match gives a 0-d tensor whose .tolist() is a bare int, which set() cannot consume.
    if remove_nominal:
        anomalous_idx = (self.anomaly_labels != self.normal_anomaly_label_idx).nonzero().flatten().tolist()
        self.picks = sorted(set.intersection(set(self.picks), set(anomalous_idx)))

    if limit_per_anomaly and limit_var is not None:
        picked = set(self.picks)
        new_picks = []
        for l in set(self.anomaly_labels.tolist()):
            label_idx = (self.anomaly_labels == l).nonzero().flatten().tolist()
            linclsses = list(set.intersection(picked, set(label_idx)))
            if len(linclsses) == 0:
                continue
            if limit_var < len(linclsses):
                new_picks.extend(np.random.choice(linclsses, size=limit_var, replace=False))
            else:
                self.logprint(
                    'OEMvTec shall be limited to {} samples per anomaly label, '
                    'but MvTec anomaly label {} contains only {} samples, thus using all.'
                    .format(limit_var, self.anomaly_label_strings[l], len(linclsses)), fps=False
                )
                new_picks.extend(linclsses)
        self.picks = sorted(new_picks)
    else:
        if limit_var is not None and limit_var < len(self):
            self.picks = np.random.choice(self.picks, size=limit_var, replace=False)
        if limit_var is not None and limit_var > len(self):
            self.logprint(
                'OEMvTec shall be limited to {} samples, but MvTec contains only {} samples, thus using all.'
                .format(limit_var, len(self))
            )

    if len(self) < size[0]:
        raise NotImplementedError(
            'OEMvTec contains only {} samples, but {} are requested per batch; '
            'enlarging the dataset is not implemented.'.format(len(self), size[0])
        )