Esempio n. 1
0
    def __init__(self, root, split='train', **kwargs):
        """Set up the dataset for one split and build name-based class lookups.

        :param root: dataset root directory; a leading ``~`` is expanded.
        :param split: which split to use, either ``"train"`` or ``"val"``.
        :param kwargs: forwarded unchanged to the parent class constructor.
        """
        expanded_root = os.path.expanduser(root)
        # The root must be available on the instance before the archives are
        # parsed and the meta file is read.
        self.root = expanded_root
        self.split = verify_str_arg(split, "split", ("train", "val"))

        self.parse_archives()
        meta = load_meta_file(self.root)
        wnid_to_classes = meta[0]

        # The parent is initialised on the split folder; the dataset root is
        # put back on the instance right afterwards.
        super().__init__(self.split_folder, **kwargs)
        self.root = expanded_root

        # What the parent exposes as classes are the wnids; keep them under
        # explicit names before replacing them with the looked-up class entries.
        self.wnids = self.classes
        self.wnid_to_idx = self.class_to_idx
        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]

        # Every element of a class entry maps to that entry's index
        # (presumably each entry is a tuple of synonymous names -- confirm).
        name_to_idx = {}
        for position, names in enumerate(self.classes):
            for name in names:
                name_to_idx[name] = position
        self.class_to_idx = name_to_idx
Esempio n. 2
0
    def __init__(self, root: str, split: str = 'train', target_transform: Callable = None,
                 img_gt_transform: Callable = None, transform: Callable = None, all_transform: Callable = None,
                 download=True, shape=(3, 300, 300), normal_classes=(), nominal_label=0, anomalous_label=1,
                 logger: Logger = None, enlarge: bool = False
                 ):
        """
        Load the MVTec data for one split from the prepared torch tensor file.

        If ``download`` is set, the download routine is invoked first; it is expected to fetch the raw
        data and prepare the tensors when they are not already present in the root directory.
        The tensor file holds labels, images, and ground-truth maps of a fixed size given by ``shape``.

        Order in which the transforms are applied:
        target_transform (first), all_transform (second), img_gt_transform (third), transform (last).

        :param root: directory where the data is to be found.
        :param split: one of "train", "test", or "test_anomaly_label_target".
            With "test_anomaly_label_target" the get_item method returns labels indexing the anomalous
            class rather than the object class, e.g. "1" for "large_broken" instead of 0 for "bottle".
        :param target_transform: function that takes the label and returns a transformed label.
        :param img_gt_transform: function that takes image and ground-truth map and transforms both.
            Useful for applying the same augmentation (e.g. cropping) to both, so that the
            ground-truth map still matches the image.
        :param transform: function that takes the image and returns a transformed image.
        :param all_transform: function that takes image, label, and ground-truth map and transforms them.
        :param download: whether to invoke the download routine before loading.
        :param shape: the shape (c x h x w) that images and ground-truth maps are resized to.
        :param normal_classes: all the classes that are considered nominal (usually just one).
        :param nominal_label: the label returned to mark nominal samples.
        :param anomalous_label: the label returned to mark anomalous samples.
        :param logger: logger; it is only stored on the instance here.
        :param enlarge: whether to repeat all data samples ten times, so that one epoch of the
            data loader yields ten times as many samples. This speeds up loading because MVTec-AD
            has few samples and PyTorch does additional work in between epochs.
        """
        super(MvTec, self).__init__(root, transform=transform, target_transform=target_transform)

        # Store the configuration before anything is downloaded or loaded.
        self.split = verify_str_arg(split, "split", ("train", "test", "test_anomaly_label_target"))
        self.img_gt_transform = img_gt_transform
        self.all_transform = all_transform
        self.shape = shape
        self.orig_gtmaps = None
        self.normal_classes = normal_classes
        self.nominal_label = nominal_label
        self.anom_label = anomalous_label
        self.logger = logger
        self.enlarge = enlarge

        if download:
            # Only the spatial part (h, w) of the shape is handed to the download routine.
            self.download(shape=self.shape[1:])

        print(f'Loading dataset from {self.data_file}...')
        stored = torch.load(self.data_file)
        self.anomaly_label_strings = stored['anomaly_label_strings']

        if self.split == 'train':
            # No ground-truth maps or anomaly labels are loaded for the training split.
            self.data = stored['train_data']
            self.targets = stored['train_labels']
            self.gt = None
            self.anomaly_labels = None
        else:
            # Both test variants read the same test tensors.
            self.data = stored['test_data']
            self.targets = stored['test_labels']
            self.gt = stored['test_maps']
            self.anomaly_labels = stored['test_anomaly_labels']

        if self.enlarge:
            # Repeat along the first (sample) dimension only; optional tensors stay None.
            self.data = self.data.repeat(10, 1, 1, 1)
            self.targets = self.targets.repeat(10)
            if self.gt is not None:
                self.gt = self.gt.repeat(10, 1, 1)
            if self.anomaly_labels is not None:
                self.anomaly_labels = self.anomaly_labels.repeat(10)
            if self.orig_gtmaps is not None:
                self.orig_gtmaps = self.orig_gtmaps.repeat(10, 1, 1)

        if self.nominal_label != 0:
            print('Swapping labels, i.e. anomalies are 0 and nominals are 1, same for GT maps.')
            # NOTE(review): -3 looks like a reserved placeholder value used elsewhere -- confirm.
            assert -3 not in (self.nominal_label, self.anom_label)
        print('Dataset complete.')