Ejemplo n.º 1
0
 def _generate_artificial_anomalies_train_set(self, supervise_mode: str, noise_mode: str, oe_limit: int,
                                              train_set: Dataset, nom_class: int):
     """
     This method generates offline artificial anomalies,
     i.e. it generates them once at the start of the training and adds them to the training set.
     It creates a balanced dataset, thus sampling as many anomalies as there are nominal samples.
     This is way faster than online generation, but lacks diversity (hence usually weaker performance).
     The resulting training set is stored in ``self._train_set``.
     :param supervise_mode: the type of generated artificial anomalies.
         unsupervised: no anomalies, returns a subset of the original dataset containing only nominal samples.
         other: other classes, i.e. all the true anomalies!
         noise: pure noise images (can also be outlier exposure based).
         malformed_normal: add noise to nominal samples to create malformed nominal anomalies.
         malformed_normal_gt: like malformed_normal, but also creates artificial ground-truth maps
             that mark pixels anomalous where the difference between the original nominal sample
             and the malformed one is greater than a low threshold.
     :param noise_mode: the type of noise used, see :mod:`fcdd.datasets.noise_mode`.
     :param oe_limit: the number of different outlier exposure samples used in case of outlier exposure based noise.
     :param train_set: the training set that is to be extended with artificial anomalies.
     :param nom_class: the class considered nominal
     :raises NotImplementedError: if ``supervise_mode`` is unknown. The mode is validated before any
         targets are read or noise is generated, so an invalid mode fails fast.
     :return: None; the result is assigned to ``self._train_set``.
     """
     if supervise_mode not in ('unsupervised', 'other', 'noise', 'malformed_normal', 'malformed_normal_gt'):
         raise NotImplementedError('Supervise mode {} unknown.'.format(supervise_mode))
     # only these modes synthesize artificial anomalies; 'unsupervised' and 'other' use the data as is
     generates_anomalies = supervise_mode not in ('unsupervised', 'other')

     if isinstance(train_set.targets, torch.Tensor):
         dataset_targets = train_set.targets.clone().data.cpu().numpy()
     else:  # e.g. imagenet
         dataset_targets = np.asarray(train_set.targets)
     train_idx_normal = get_target_label_idx(dataset_targets, self.normal_classes)

     generated_noise = norm = None
     if generates_anomalies:
         self.logprint('Generating artificial anomalies...')
         # index the nominal samples once and reuse them for both the noise shape and the anomaly generation
         norm = train_set.data[train_idx_normal]
         generated_noise = self._generate_noise(noise_mode, norm.shape, oe_limit, self.root)

     if supervise_mode == 'other':
         self._train_set = train_set
     elif supervise_mode == 'unsupervised':
         # keep only nominal samples; keep ground-truth maps available if the dataset provides them
         if isinstance(train_set, GTMapADDataset):
             self._train_set = GTSubset(train_set, train_idx_normal)
         else:
             self._train_set = Subset(train_set, train_idx_normal)
     elif supervise_mode == 'noise':
         self._train_set = apply_noise(self.outlier_classes, generated_noise, norm, nom_class, train_set)
     elif supervise_mode == 'malformed_normal':
         self._train_set = apply_malformed_normal(self.outlier_classes, generated_noise, norm, nom_class, train_set)
     else:  # 'malformed_normal_gt' -- the only remaining mode after the validation above
         train_set, gtmaps = apply_malformed_normal(
             self.outlier_classes, generated_noise, norm, nom_class, train_set, gt=True
         )
         self._train_set = GTMapADDatasetExtension(train_set, gtmaps)

     if generates_anomalies:
         self.logprint('Artificial anomalies generated.')
Ejemplo n.º 2
0
    def __init__(self,
                 root: str,
                 normal_class: int,
                 preproc: str,
                 nominal_label: int,
                 supervise_mode: str,
                 noise_mode: str,
                 oe_limit: int,
                 online_supervision: bool,
                 logger: Logger = None):
        """
        AD dataset for ImageNet. Following Hendrycks et al. (https://arxiv.org/abs/1812.04606) this AD dataset
        is limited to 30 of the 1000 classes of Imagenet (see :attr:`ADImageNet.ad_classes`).
        :param root: root directory where data is found or is to be downloaded to
        :param normal_class: the class considered nominal (index into the 30 AD classes)
        :param preproc: the kind of preprocessing pipeline
        :param nominal_label: the label that marks nominal samples in training. The scores in the heatmaps always
            rate label 1, thus usually the nominal label is 0, s.t. the scores are anomaly scores.
        :param supervise_mode: the type of generated artificial anomalies; only 'unsupervised', 'other'
            and 'noise' are supported here.
            See :meth:`fcdd.datasets.bases.TorchvisionDataset._generate_artificial_anomalies_train_set`.
        :param noise_mode: the type of noise used, see :mod:`fcdd.datasets.noise_mode`.
        :param oe_limit: limits the number of different anomalies in case of Outlier Exposure (defined in noise_mode)
        :param online_supervision: whether to sample anomalies online in each epoch,
            or offline before training (same for all epochs in this case); must be True for ImageNet.
        :param logger: logger
        """
        assert online_supervision, 'ImageNet artificial anomaly generation needs to be online'
        assert supervise_mode in ['unsupervised', 'other', 'noise'], \
            'Supervise mode "{}" is not supported for ImageNet. In particular, "malformed_normal" is not supported ' \
            'because nominal images are loaded only if not replaced by some artificial anomaly ' \
            '(to speedup data preprocessing).'.format(supervise_mode)

        root = pt.join(root, self.base_folder)
        super().__init__(root, logger=logger)

        self.n_classes = 2  # 0: normal, 1: outlier
        self.shape = (3, 224, 224)
        self.raw_shape = (3, 256, 256)
        self.normal_classes = tuple([normal_class])
        self.outlier_classes = list(range(0, 30))
        self.outlier_classes.remove(normal_class)
        assert nominal_label in [0, 1]
        self.nominal_label = nominal_label
        self.anomalous_label = 1 if self.nominal_label == 0 else 0

        if self.nominal_label != 0:
            self.logprint(
                'Swapping labels, i.e. anomalies are 0 and nominals are 1.')

        # per-channel mean and std, shared by all classes (unlike per-class statistics used elsewhere)
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)

        # different types of preprocessing pipelines, here just choose whether to use augmentations
        if preproc in ['', None, 'default', 'none']:
            test_transform = transform = transforms.Compose([
                transforms.Resize((self.shape[-2], self.shape[-1])),
                transforms.ToTensor(),
                # the statistics are per channel, not per class: indexing them by normal_class
                # would normalize with a single wrong scalar (or fail for normal_class >= 3)
                transforms.Normalize(mean, std)
            ])
        elif preproc in ['aug1']:
            test_transform = transforms.Compose([
                transforms.Resize(self.raw_shape[-1]),
                transforms.CenterCrop(self.shape[-1]),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])
            transform = transforms.Compose([
                transforms.Resize(self.raw_shape[-1]),
                transforms.ColorJitter(brightness=0.01,
                                       contrast=0.01,
                                       saturation=0.01,
                                       hue=0.01),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(self.shape[-1]),
                transforms.ToTensor(),
                transforms.Lambda(lambda x: x + 0.001 * torch.randn_like(x)),
                transforms.Normalize(mean, std)
            ])
        else:
            raise ValueError(
                'Preprocessing pipeline {} is not known.'.format(preproc))

        target_transform = transforms.Lambda(
            lambda x: self.anomalous_label
            if x in self.outlier_classes else self.nominal_label)
        if supervise_mode not in ['unsupervised', 'other']:
            all_transform = OnlineSupervisor(self, supervise_mode, noise_mode,
                                             oe_limit)
        else:
            all_transform = None

        train_set = MyImageNet(root=self.root,
                               split='train',
                               normal_classes=self.normal_classes,
                               transform=transform,
                               target_transform=target_transform,
                               all_transform=all_transform,
                               logger=logger)
        self.train_ad_classes_idx = train_set.get_class_idx(self.ad_classes)
        train_set.targets = [  # t = nan if not in ad_classes else give id in order of ad_classes
            self.train_ad_classes_idx.index(t)
            if t in self.train_ad_classes_idx else np.nan
            for t in train_set.targets
        ]
        self._generate_artificial_anomalies_train_set(
            'unsupervised',
            noise_mode,
            oe_limit,
            train_set,
            normal_class,  # gets rid of true anomalous samples
        )
        self._test_set = MyImageNet(root=self.root,
                                    split='val',
                                    normal_classes=self.normal_classes,
                                    transform=test_transform,
                                    target_transform=target_transform,
                                    logger=logger)
        self.test_ad_classes_idx = self._test_set.get_class_idx(
            self.ad_classes)
        self._test_set.targets = [  # t = nan if not in ad_classes else give id in order of ad_classes
            self.test_ad_classes_idx.index(t)
            if t in self.test_ad_classes_idx else np.nan
            for t in self._test_set.targets
        ]
        # keep only test samples that belong to one of the AD classes
        self._test_set = Subset(
            self._test_set,
            get_target_label_idx(np.asarray(self._test_set.targets),
                                 list(range(len(self.ad_classes)))))
        self._test_set.fixed_random_order = MyImageNet.fixed_random_order
Ejemplo n.º 3
0
    def __init__(self,
                 root: str,
                 normal_class: int,
                 preproc: str,
                 nominal_label: int,
                 supervise_mode: str,
                 noise_mode: str,
                 oe_limit: int,
                 online_supervision: bool,
                 logger: Logger = None):
        """
        General-purpose anomaly-detection dataset for custom image data organized in class folders.
        Two layouts are supported, selected by :attr:`ovr`:
        (1) one-vs-rest (ovr): one class is nominal and is tested against all remaining classes, which are
        treated as anomalous;
        (2) general: every class folder holds a folder of normal images and a folder of anomalous images.

        Expected layout for (1):
        root/custom/train/dog/xxx.png
        root/custom/train/dog/xxy.png
        root/custom/train/dog/xxz.png

        root/custom/train/cat/123.png
        root/custom/train/cat/nsdf3.png
        root/custom/train/cat/asd932_.png

        Expected layout for (2):
        root/custom/train/hazelnut/normal/xxx.png
        root/custom/train/hazelnut/normal/xxy.png
        root/custom/train/hazelnut/normal/xxz.png
        root/custom/train/hazelnut/anomalous/xxa.png    -- may be used during training for a semi-supervised setting

        root/custom/train/screw/normal/123.png
        root/custom/train/screw/normal/nsdf3.png
        root/custom/train/screw/anomalous/asd932_.png   -- may be used during training for a semi-supervised setting

        The test set follows the same layout with "train" replaced by "test".

        :param root: directory that contains the data.
        :param normal_class: index of the class treated as nominal.
        :param preproc: name of the preprocessing pipeline ('aug1', or '', None, 'default', 'none').
        :param nominal_label: label given to nominal training samples (0 or 1). Heatmap scores always rate
            label 1, so a nominal label of 0 turns the scores into anomaly scores.
        :param supervise_mode: kind of artificial anomalies to generate.
            See :meth:`fcdd.datasets.bases.TorchvisionDataset._generate_artificial_anomalies_train_set`.
        :param noise_mode: kind of noise to use, see :mod:`fcdd.datasets.noise_mode`.
        :param oe_limit: upper bound on the number of distinct anomalies for Outlier-Exposure noise modes.
        :param online_supervision: whether anomalies are sampled anew in every epoch (must be True here)
            instead of once before training.
        :param logger: logger.
        """
        assert online_supervision, 'Artificial anomaly generation for custom datasets needs to be online'
        self.trainpath = pt.join(root, self.base_folder, 'train')
        self.testpath = pt.join(root, self.base_folder, 'test')
        super().__init__(root, logger=logger)

        self.n_classes = 2  # 0: normal, 1: outlier
        self.raw_shape = (3, 248, 248)
        # channels x height x width of a sample after preprocessing
        self.shape = (3, 224, 224)
        self.normal_classes = (normal_class,)
        class_count = len(extract_custom_classes(root))
        self.outlier_classes = list(range(class_count))
        self.outlier_classes.remove(normal_class)
        assert nominal_label in [0, 1]
        self.nominal_label = nominal_label
        self.anomalous_label = 1 if self.nominal_label == 0 else 0

        # normalization statistics, obtained via extract_mean_std from the train folder and the nominal class
        self.mean, self.std = self.extract_mean_std(self.trainpath, normal_class)

        if preproc in ['', None, 'default', 'none']:
            # no augmentation: resize straight to the network input size; train and test share the pipeline
            transform = transforms.Compose([
                transforms.Resize((self.shape[-2], self.shape[-1])),
                transforms.ToTensor(),
                transforms.Normalize(self.mean, self.std),
            ])
            test_transform = transform
        elif preproc in ['aug1']:
            # test: resize to the raw size, then center-crop to the network input size
            test_transform = transforms.Compose([
                transforms.Resize(self.raw_shape[-1]),
                transforms.CenterCrop(self.shape[-1]),
                transforms.ToTensor(),
                transforms.Normalize(self.mean, self.std),
            ])
            # train: slight colour jitter, horizontal flips, random crops and a tiny additive Gaussian noise
            transform = transforms.Compose([
                transforms.Resize(self.raw_shape[-1]),
                transforms.ColorJitter(brightness=0.01, contrast=0.01, saturation=0.01, hue=0.01),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(self.shape[-1]),
                transforms.ToTensor(),
                transforms.Lambda(lambda x: x + 0.001 * torch.randn_like(x)),
                transforms.Normalize(self.mean, self.std),
            ])
        # further augmentation pipelines can be added here
        else:
            raise ValueError(
                'Preprocessing pipeline {} is not known.'.format(preproc))

        self.target_transform = transforms.Lambda(
            lambda x: self.anomalous_label if x in self.outlier_classes else self.nominal_label
        )
        if supervise_mode in ['unsupervised', 'other']:
            self.all_transform = None
        else:
            self.all_transform = OnlineSupervisor(self, supervise_mode, noise_mode, oe_limit)

        self._train_set = ImageFolderDataset(
            self.trainpath, supervise_mode, self.raw_shape, self.ovr,
            self.nominal_label, self.anomalous_label,
            normal_classes=self.normal_classes,
            transform=transform,
            target_transform=self.target_transform,
            all_transform=self.all_transform,
        )
        if supervise_mode == 'other':  # (semi)-supervised setting
            self.balance_dataset()
        else:
            # keep only samples that are labeled nominal AND belong to a normal class
            is_nominal = np.asarray(self._train_set.anomaly_labels) == self.nominal_label
            in_normal_class = np.isin(self._train_set.targets, self.normal_classes)
            keep_idx = np.argwhere(is_nominal * in_normal_class).flatten().tolist()
            self._train_set = Subset(self._train_set, keep_idx)

        self._test_set = ImageFolderDataset(
            self.testpath, supervise_mode, self.raw_shape, self.ovr,
            self.nominal_label, self.anomalous_label,
            normal_classes=self.normal_classes,
            transform=test_transform,
            target_transform=self.target_transform,
        )
        if not self.ovr:
            # general layout: restrict the test set to samples whose target is a normal class
            normal_idx = get_target_label_idx(self._test_set.targets, np.asarray(self.normal_classes))
            self._test_set = Subset(self._test_set, normal_idx)
Ejemplo n.º 4
0
    def __init__(self, root: str, normal_class: int, preproc: str, nominal_label: int,
                 supervise_mode: str, noise_mode: str, oe_limit: int, online_supervision: bool,
                 logger: Logger = None, raw_shape: int = 240):
        """
        Anomaly-detection dataset for MVTec-AD.
        When no MVTec data is found in the root directory, it is downloaded and stored as torch tensors
        resized to ``raw_shape``, which speeds up data loading at the start of training.
        :param root: directory that contains the data or receives the download
        :param normal_class: index of the class treated as nominal (0-14)
        :param preproc: name of the preprocessing pipeline: 'lcn', 'aug1', 'lcnaug1',
            or '', None, 'default', 'none' for neither LCN nor augmentation
        :param nominal_label: label given to nominal training samples (0 or 1). Heatmap scores always rate
            label 1, so a nominal label of 0 turns the scores into anomaly scores.
        :param supervise_mode: kind of artificial anomalies to generate.
            See :meth:`fcdd.datasets.bases.TorchvisionDataset._generate_artificial_anomalies_train_set`.
        :param noise_mode: kind of noise to use, see :mod:`fcdd.datasets.noise_mode`.
        :param oe_limit: upper bound on the number of distinct anomalies for Outlier-Exposure noise modes
        :param online_supervision: if True, anomalies are sampled anew in every epoch; otherwise they are
            generated once before training and reused for all epochs.
        :param logger: logger
        :param raw_shape: height and width of the raw MVTec images before the preprocessing pipeline.
        """
        super().__init__(root, logger=logger)

        self.n_classes = 2  # 0: normal, 1: outlier
        self.shape = (3, 224, 224)
        self.raw_shape = (3, raw_shape, raw_shape)
        self.normal_classes = (normal_class,)
        self.outlier_classes = list(range(15))
        self.outlier_classes.remove(normal_class)
        assert nominal_label in [0, 1], 'GT maps are required to be binary!'
        self.nominal_label = nominal_label
        self.anomalous_label = 1 if self.nominal_label == 0 else 0

        # per-class (min, max) of the images after L1 local contrast normalization (gcn l1 norm)
        min_max_l1 = [
            [(-1.3336724042892456, -1.3107913732528687, -1.2445921897888184),
             (1.3779616355895996, 1.3779616355895996, 1.3779616355895996)],
            [(-2.2404820919036865, -2.3387579917907715, -2.2896201610565186),
             (4.573435306549072, 4.573435306549072, 4.573435306549072)],
            [(-3.184587001800537, -3.164201259613037, -3.1392977237701416),
             (1.6995097398757935, 1.6011602878570557, 1.5209171772003174)],
            [(-3.0334954261779785, -2.958242416381836, -2.7701096534729004),
             (6.503103256225586, 5.875098705291748, 5.814228057861328)],
            [(-3.100773334503174, -3.100773334503174, -3.100773334503174),
             (4.27892541885376, 4.27892541885376, 4.27892541885376)],
            [(-3.6565306186676025, -3.507692813873291, -2.7635035514831543),
             (18.966819763183594, 21.64590072631836, 26.408710479736328)],
            [(-1.5192601680755615, -2.2068002223968506, -2.3948357105255127),
             (11.564697265625, 10.976534843444824, 10.378695487976074)],
            [(-1.3207964897155762, -1.2889339923858643, -1.148416519165039),
             (6.854909896850586, 6.854909896850586, 6.854909896850586)],
            [(-0.9883341193199158, -0.9822461605072021, -0.9288841485977173),
             (2.290637969970703, 2.4007883071899414, 2.3044068813323975)],
            [(-7.236185073852539, -7.236185073852539, -7.236185073852539),
             (3.3777384757995605, 3.3777384757995605, 3.3777384757995605)],
            [(-3.2036616802215576, -3.221003532409668, -3.305514335632324),
             (7.022546768188477, 6.115569114685059, 6.310940742492676)],
            [(-0.8915618658065796, -0.8669204115867615, -0.8002046346664429),
             (4.4255571365356445, 4.642300128936768, 4.305730819702148)],
            [(-1.9086798429489136, -2.0004451274871826, -1.929288387298584),
             (5.463134765625, 5.463134765625, 5.463134765625)],
            [(-2.9547364711761475, -3.17536997795105, -3.143850803375244),
             (5.305514812469482, 4.535006523132324, 3.3618252277374268)],
            [(-1.2906527519226074, -1.2906527519226074, -1.2906527519226074),
             (2.515115737915039, 2.515115737915039, 2.515115737915039)]
        ]

        # per-class mean and std of the original images
        mean = [
            (0.53453129529953, 0.5307118892669678, 0.5491130352020264),
            (0.326835036277771, 0.41494372487068176, 0.46718254685401917),
            (0.6953922510147095, 0.6663950085639954, 0.6533040404319763),
            (0.36377236247062683, 0.35087138414382935, 0.35671544075012207),
            (0.4484519958496094, 0.4484519958496094, 0.4484519958496094),
            (0.2390524297952652, 0.17620408535003662, 0.17206747829914093),
            (0.3919542133808136, 0.2631213963031769, 0.22006843984127045),
            (0.21368788182735443, 0.23478130996227264, 0.24079132080078125),
            (0.30240726470947266, 0.3029524087905884, 0.32861486077308655),
            (0.7099748849868774, 0.7099748849868774, 0.7099748849868774),
            (0.4567880630493164, 0.4711957275867462, 0.4482630491256714),
            (0.19987481832504272, 0.18578395247459412, 0.19361256062984467),
            (0.38699793815612793, 0.276934415102005, 0.24219433963298798),
            (0.6718143820762634, 0.47696375846862793, 0.35050269961357117),
            (0.4014520049095154, 0.4014520049095154, 0.4014520049095154)
        ]
        std = [
            (0.3667600452899933, 0.3666728734970093, 0.34991779923439026),
            (0.15321789681911469, 0.21510766446590424, 0.23905669152736664),
            (0.23858436942100525, 0.2591284513473511, 0.2601949870586395),
            (0.14506031572818756, 0.13994529843330383, 0.1276693195104599),
            (0.1636597216129303, 0.1636597216129303, 0.1636597216129303),
            (0.1688646823167801, 0.07597383111715317, 0.04383210837841034),
            (0.06069392338395119, 0.04061736911535263, 0.0303945429623127),
            (0.1602524220943451, 0.18222476541996002, 0.15336430072784424),
            (0.30409011244773865, 0.30411985516548157, 0.28656429052352905),
            (0.1337062269449234, 0.1337062269449234, 0.1337062269449234),
            (0.12076705694198608, 0.13341768085956573, 0.12879984080791473),
            (0.22920562326908112, 0.21501320600509644, 0.19536510109901428),
            (0.20621345937252045, 0.14321941137313843, 0.11695228517055511),
            (0.08259467780590057, 0.06751163303852081, 0.04756828024983406),
            (0.32304847240448, 0.32304847240448, 0.32304847240448)
        ]

        class_mean, class_std = mean[normal_class], std[normal_class]
        lcn_min, lcn_max = min_max_l1[normal_class]

        # Small factories: each call returns fresh transform objects, one per use site.
        def make_lcn():
            # L1 local contrast normalization
            return transforms.Lambda(lambda x: local_contrast_normalization(x, scale='l1'))

        def make_lcn_rescale():
            # subtract the class minimum and divide by the class range (max - min)
            return transforms.Normalize(lcn_min, [hi - lo for hi, lo in zip(lcn_max, lcn_min)])

        def make_class_normalize():
            return transforms.Normalize(class_mean, class_std)

        def make_image_augmentations():
            # random colour jitter and, when np.random.randint(0, 2) yields 1, Gaussian noise scaled by the image std
            return [
                transforms.ToPILImage(),
                transforms.RandomChoice([
                    transforms.ColorJitter(0.04, 0.04, 0.04, 0.04),
                    transforms.ColorJitter(0.005, 0.0005, 0.0005, 0.0005),
                ]),
                transforms.ToTensor(),
                transforms.Lambda(
                    lambda x: (x + torch.randn_like(x).mul(np.random.randint(0, 2)).mul(x.std()).mul(0.1)).clamp(0, 1)
                ),
            ]

        def make_joint_train_transform():
            # applied to image and ground-truth map together: random crop or nearest resize to the net input size
            return MultiCompose([
                transforms.RandomChoice(
                    [transforms.RandomCrop(self.shape[-1], padding=0), transforms.Resize(self.shape[-1], Image.NEAREST)]
                ),
                transforms.ToTensor(),
            ])

        def make_joint_test_transform():
            return MultiCompose([transforms.Resize(self.shape[-1], Image.NEAREST), transforms.ToTensor()])

        # pipelines: 'lcn' applies LCN, 'aug1' augmentations, 'lcnaug1' both; the default ones apply neither
        if preproc == 'lcn':
            assert self.raw_shape == self.shape, 'in case of no augmentation, raw shape needs to fit net input shape'
            img_gt_transform = img_gt_test_transform = MultiCompose([transforms.ToTensor()])
            transform = transforms.Compose([make_lcn(), make_lcn_rescale()])
            test_transform = transform
        elif preproc in ['', None, 'default', 'none']:
            assert self.raw_shape == self.shape, 'in case of no augmentation, raw shape needs to fit net input shape'
            img_gt_transform = img_gt_test_transform = MultiCompose([transforms.ToTensor()])
            transform = transforms.Compose([make_class_normalize()])
            test_transform = transform
        elif preproc in ['aug1']:
            img_gt_transform = make_joint_train_transform()
            img_gt_test_transform = make_joint_test_transform()
            test_transform = transforms.Compose([make_class_normalize()])
            transform = transforms.Compose(make_image_augmentations() + [make_class_normalize()])
        elif preproc in ['lcnaug1']:
            img_gt_transform = make_joint_train_transform()
            img_gt_test_transform = make_joint_test_transform()
            test_transform = transforms.Compose([make_lcn(), make_lcn_rescale()])
            transform = transforms.Compose(make_image_augmentations() + [make_lcn(), make_lcn_rescale()])
        else:
            raise ValueError('Preprocessing pipeline {} is not known.'.format(preproc))

        target_transform = transforms.Lambda(
            lambda x: self.anomalous_label if x in self.outlier_classes else self.nominal_label
        )
        base_all_transform = []  # transforms applied to whole samples; none besides the OnlineSupervisor

        def make_train_set(all_transform):
            return MvTec(
                root=self.root, split='train', download=True,
                target_transform=target_transform, all_transform=all_transform,
                img_gt_transform=img_gt_transform, transform=transform, shape=self.raw_shape,
                normal_classes=self.normal_classes,
                nominal_label=self.nominal_label, anomalous_label=self.anomalous_label,
                enlarge=ADMvTec.enlarge
            )

        def make_test_set():
            # test targets are MVTec anomaly-label indices, mapped to the nominal / anomalous label
            return MvTec(
                root=self.root, split='test_anomaly_label_target', download=True,
                target_transform=transforms.Lambda(
                    lambda x: self.anomalous_label if x != MvTec.normal_anomaly_label_idx else self.nominal_label
                ),
                img_gt_transform=img_gt_test_transform, transform=test_transform, shape=self.raw_shape,
                normal_classes=self.normal_classes,
                nominal_label=self.nominal_label, anomalous_label=self.anomalous_label,
                enlarge=False
            )

        def restrict_to_normal_classes(dataset):
            normal_idx = get_target_label_idx(dataset.targets.clone().data.cpu().numpy(), self.normal_classes)
            return GTSubset(dataset, normal_idx)

        if online_supervision:
            # order: target_transform -> all_transform -> img_gt transform -> transform
            assert supervise_mode not in ['supervised'], 'supervised mode works only offline'
            online_all_transform = MultiCompose([
                *base_all_transform,
                OnlineSupervisor(self, supervise_mode, noise_mode, oe_limit),
            ])
            train_set = make_train_set(online_all_transform)
            self._train_set = restrict_to_normal_classes(train_set)
            self._test_set = restrict_to_normal_classes(make_test_set())
        else:
            offline_all_transform = MultiCompose([*base_all_transform]) if len(base_all_transform) > 0 else None
            train_set = make_train_set(offline_all_transform)
            self._test_set = restrict_to_normal_classes(make_test_set())
            self._generate_artificial_anomalies_train_set(supervise_mode, noise_mode, oe_limit, train_set, normal_class)
Ejemplo n.º 5
0
    def __init__(self,
                 size: torch.Size,
                 clsses: List[int],
                 root: str = None,
                 limit_var: int = np.inf,
                 limit_per_anomaly=True,
                 download=True,
                 logger: Logger = None,
                 gt=False,
                 remove_nominal=True):
        """
        Outlier Exposure dataset for MVTec-AD. Considers only a part of the classes.
        :param size: size of the samples in n x c x h x w, samples will be resized to h x w.
            Exactly n images are extracted per iteration of the data_loader.
            For online supervision n should be set to 1 because only one sample is extracted at a time.
            This constructor raises NotImplementedError if, after filtering and limiting, ``len(self)`` is
            smaller than n. NOTE(review): enlarging the dataset by repetitions to fit n was documented
            previously; confirm against __len__/__getitem__ whether that is handled elsewhere.
        :param clsses: the classes that are to be considered, i.e. all other classes are dismissed.
        :param root: root directory where data is found or is to be downloaded to.
        :param limit_var: limits the number of different samples, i.e. randomly chooses limit_var many samples
            from all available ones to be the training data. Defaults to infinity; ``None`` also disables it.
        :param limit_per_anomaly: whether limit_var limits the number of different samples per type
            of defection or overall.
        :param download: whether to download the data if it is not found in root.
        :param logger: logger.
        :param gt: whether ground-truth maps are to be included in the data.
        :param remove_nominal: whether nominal samples are to be excluded from the data.
        :raises NotImplementedError: if fewer than ``size[0]`` samples remain.
        """
        assert len(size) == 4 and size[2] == size[3]
        assert size[1] in [1, 3]
        self.root = root
        self.logger = logger
        self.size = size
        self.use_gt = gt
        self.clsses = clsses
        super().__init__(root,
                         'test',
                         download=download,
                         shape=size[1:],
                         logger=logger)

        self.img_gt_transform = MultiCompose(
            [transforms.Resize((size[2], size[2])),
             transforms.ToTensor()])
        self.picks = get_target_label_idx(self.targets, self.clsses)
        if remove_nominal:
            # flatten() instead of squeeze(): with exactly one match squeeze() yields a 0-d tensor whose
            # tolist() is a plain int, which set() cannot iterate
            anomalous_idx = (self.anomaly_labels != self.normal_anomaly_label_idx).nonzero().flatten().tolist()
            self.picks = sorted(set(self.picks) & set(anomalous_idx))
        if limit_per_anomaly and limit_var is not None:
            picks_set = set(self.picks)  # loop-invariant: self.picks is only replaced after the loop
            new_picks = []
            for label in set(self.anomaly_labels.tolist()):
                label_idx = (self.anomaly_labels == label).nonzero().flatten().tolist()
                linclsses = list(picks_set & set(label_idx))
                if len(linclsses) == 0:
                    continue
                if limit_var < len(linclsses):
                    new_picks.extend(
                        np.random.choice(linclsses,
                                         size=limit_var,
                                         replace=False))
                else:
                    self.logprint(
                        'OEMvTec shall be limited to {} samples per anomaly label, '
                        'but MvTec anomaly label {} contains only {} samples, thus using all.'
                        .format(limit_var, self.anomaly_label_strings[label],
                                len(linclsses)),
                        fps=False)
                    new_picks.extend(linclsses)
            self.picks = sorted(new_picks)
        else:
            if limit_var is not None and limit_var < len(self):
                self.picks = np.random.choice(self.picks,
                                              size=limit_var,
                                              replace=False)
            if limit_var is not None and limit_var > len(self):
                self.logprint(
                    'OEMvTec shall be limited to {} samples, but MvTec contains only {} samples, thus using all.'
                    .format(limit_var, len(self)))
        if len(self) < size[0]:
            raise NotImplementedError(
                'OEMvTec provides only {} samples, but {} are required per iteration (size[0]).'
                .format(len(self), size[0]))