Beispiel #1
0
    def download_url(url, root, filename=None, md5=None):
        """Download a file from a url and place it in root.

        Nothing is downloaded when ``check_integrity`` already accepts the
        target file. If the download raises and the URL starts with ``https``,
        the download is retried once over plain ``http``.

        Args:
            url (str): URL to download file from
            root (str): Directory to place downloaded file in
            filename (str, optional): Name to save the file under. If None, use the basename of the URL
            md5 (str, optional): Forwarded as the second argument of
                ``check_integrity`` (presumably the expected MD5 checksum --
                confirm against ``check_integrity``). The default ``None``
                keeps the previous behaviour.

        Raises:
            RuntimeError: If ``check_integrity`` rejects the file after the download.
            urllib.error.URLError, IOError: If the download fails for a
                non-https URL, or if the http retry fails as well.
        """
        from six.moves import urllib

        root = os.path.expanduser(root)
        if not filename:
            filename = os.path.basename(url)
        fpath = os.path.join(root, filename)

        os.makedirs(root, exist_ok=True)

        def retrieve(source):
            """Fetch ``source`` into ``fpath`` while driving a tqdm progress bar."""
            pbar = tqdm(total=None)

            def bar_update(count, block_size, total_size):
                # The total is only known once urlretrieve reports it.
                if pbar.total is None and total_size:
                    pbar.total = total_size
                progress_bytes = count * block_size
                # tqdm expects increments, urlretrieve reports absolute progress.
                pbar.update(progress_bytes - pbar.n)

            try:
                urllib.request.urlretrieve(source, fpath, reporthook=bar_update)
            finally:
                # Always release the bar, also when the download raises, so a
                # failed attempt does not leave a dangling bar behind.
                pbar.close()

        # check if file is already present locally
        if check_integrity(fpath, md5):
            return

        try:
            print('Downloading ' + url + ' to ' + fpath)
            retrieve(url)
        except (urllib.error.URLError, IOError):
            if not url.startswith('https'):
                # Bare raise keeps the original traceback intact.
                raise
            # Only rewrite the first occurrence so that an 'https:' appearing
            # later in the URL is left untouched.
            url = url.replace('https:', 'http:', 1)
            print('Failed download. Trying https -> http instead.'
                  ' Downloading ' + url + ' to ' + fpath)
            retrieve(url)

        # check integrity of downloaded file
        if not check_integrity(fpath, md5):
            raise RuntimeError("File not found or corrupted.")
Beispiel #2
0
    def parse_archives(self):
        """Parse whichever archives are still missing under ``self.root``.

        The devkit archive is parsed when ``check_integrity`` rejects the meta
        file. The archive of the current split ('train' or 'val') is parsed
        when ``self.split_folder`` is not an existing directory; any other
        split value is ignored.
        """
        meta_path = os.path.join(self.root, META_FILE)
        if not check_integrity(meta_path):
            parse_devkit_archive(self.root)

        # Nothing left to do once the split folder is already in place.
        if os.path.isdir(self.split_folder):
            return

        split = self.split
        if split == 'train':
            parse_train_archive(self.root)
        elif split == 'val':
            parse_val_archive(self.root)
Beispiel #3
0
    def download_and_extract_archive(url, download_root, extract_root=None, filename=None,
                                     md5=None, remove_finished=False):
        """Download an archive unless it is already present, then extract it.

        Args:
            url (str): URL to download the archive from.
            download_root (str): Directory the archive is stored in; ``~`` is expanded.
            extract_root (str, optional): Directory to extract to. If None, the
                (expanded) ``download_root`` is used.
            filename (str, optional): Name of the archive file. If empty or
                None, the basename of ``url`` is used.
            md5 (str, optional): Passed to ``check_integrity`` and
                ``download_url`` (presumably the expected MD5 checksum --
                confirm against those helpers).
            remove_finished (bool): Forwarded to ``MvTec.extract_archive``.
        """
        download_root = os.path.expanduser(download_root)
        if extract_root is None:
            extract_root = download_root
        if not filename:
            filename = os.path.basename(url)
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(download_root, exist_ok=True)

        archive = os.path.join(download_root, filename)
        # Honour the caller's md5 for the cached file as well; otherwise a stale
        # or corrupted archive would be accepted without re-downloading.
        if not check_integrity(archive, md5):
            download_url(url, download_root, filename, md5)

        print("Extracting {} to {}".format(archive, extract_root))
        MvTec.extract_archive(archive, extract_root, remove_finished)
Beispiel #4
0
    def download(self, verbose=True, shape=None, cls=None):
        """Download the archive, convert its images to tensors and save them.

        Nothing is done (apart from printing a message) when
        ``check_integrity`` accepts the target file. Otherwise the archive is
        extracted into a temporary directory, every image below
        ``<label>/test`` and ``<label>/train`` is converted with
        ``self.img_to_torch`` and the stacked tensors are written with
        ``torch.save``. The temporary directory is removed afterwards, also
        when processing fails.

        Args:
            verbose (bool): Print a line for each processed label.
            shape: Passed to ``self.img_to_torch``. If not None the result is
                stored in ``self.data_file``, otherwise in
                ``self.orig_data_file(cls)``.
            cls (int, optional): Index into ``self.labels``. If given, only
                that label is processed.

        Raises:
            AssertionError: If both ``shape`` and ``cls`` are None.
        """
        assert shape is not None or cls is not None, 'original shape requires a class'
        target_file = self.data_file if shape is not None else self.orig_data_file(cls)
        if check_integrity(target_file):
            print('Files already downloaded.')
            return

        tmp_dir = tempfile.mkdtemp()
        try:
            self.download_and_extract_archive(
                self.url, os.path.join(self.root, self.base_folder), extract_root=tmp_dir,
            )
            train_data, train_labels = [], []
            test_data, test_labels, test_maps, test_anomaly_labels = [], [], [], []
            # Maps anomaly label string -> integer id; the normal label is pre-registered.
            albl_idmap = {self.normal_anomaly_label: self.normal_anomaly_label_idx}

            for lbl_idx, lbl in enumerate(self.labels if cls is None else [self.labels[cls]]):
                if verbose:
                    print('Processing data for label {}...'.format(lbl))

                test_dir = os.path.join(tmp_dir, lbl, 'test')
                for anomaly_label in sorted(os.listdir(test_dir)):
                    for img_name in sorted(os.listdir(os.path.join(test_dir, anomaly_label))):
                        with open(os.path.join(test_dir, anomaly_label, img_name), 'rb') as f:
                            sample = Image.open(f)
                            # Convert while the file is still open.
                            sample = self.img_to_torch(sample, shape)
                        if anomaly_label != self.normal_anomaly_label:
                            mask_name = self.convert_img_name_to_mask_name(img_name)
                            mask_path = os.path.join(tmp_dir, lbl, 'ground_truth', anomaly_label, mask_name)
                            with open(mask_path, 'rb') as f:
                                mask = Image.open(f)
                                mask = self.img_to_torch(mask, shape)
                        else:
                            # Normal samples get an all-zero map.
                            mask = torch.zeros_like(sample)
                        test_data.append(sample)
                        test_labels.append(cls if cls is not None else lbl_idx)
                        test_maps.append(mask)
                        if anomaly_label not in albl_idmap:
                            albl_idmap[anomaly_label] = len(albl_idmap)
                        test_anomaly_labels.append(albl_idmap[anomaly_label])

                train_dir = os.path.join(tmp_dir, lbl, 'train')
                for anomaly_label in sorted(os.listdir(train_dir)):
                    for img_name in sorted(os.listdir(os.path.join(train_dir, anomaly_label))):
                        with open(os.path.join(train_dir, anomaly_label, img_name), 'rb') as f:
                            sample = Image.open(f)
                            sample = self.img_to_torch(sample, shape)
                        train_data.append(sample)
                        # NOTE(review): test labels use `cls` when it is given,
                        # train labels always use lbl_idx (0 in that case) --
                        # confirm this asymmetry is intended.
                        train_labels.append(lbl_idx)

            # Anomaly label strings ordered by their integer id.
            anomaly_labels = list(zip(*sorted(albl_idmap.items(), key=lambda kv: kv[1])))[0]
            train_data = torch.stack(train_data)
            train_labels = torch.IntTensor(train_labels)
            test_data = torch.stack(test_data)
            test_labels = torch.IntTensor(test_labels)
            test_maps = torch.stack(test_maps)[:, 0, :, :]  # r=g=b -> grayscale
            test_anomaly_labels = torch.IntTensor(test_anomaly_labels)
            torch.save(
                {'train_data': train_data, 'train_labels': train_labels,
                 'test_data': test_data, 'test_labels': test_labels,
                 'test_maps': test_maps, 'test_anomaly_labels': test_anomaly_labels,
                 'anomaly_label_strings': anomaly_labels},
                target_file
            )
        finally:
            # cleanup temp directory, even if downloading or processing failed.
            # Permissions are relaxed first, presumably because extracted
            # entries may not be deletable otherwise -- TODO confirm.
            for dirpath, _, filenames in os.walk(tmp_dir):
                os.chmod(dirpath, 0o755)
                for filename in filenames:
                    os.chmod(os.path.join(dirpath, filename), 0o755)
            shutil.rmtree(tmp_dir)