def main():
    p = argparse.ArgumentParser()
    p.add_argument('raw')
    p.add_argument('compressed', nargs='*')
    flags = p.parse_args()
    # Accept either a single "raw;compressed" argument or two positional arguments.
    if ';' in flags.raw:
        flags.raw, flags.compressed = flags.raw.split(';')
    else:
        flags.compressed = flags.compressed[0]

    raw = cached_listdir_imgs(flags.raw, discard_shitty=False)
    compressed = cached_listdir_imgs(flags.compressed, discard_shitty=False)
    print(raw, compressed)

    print('Average bpp', np.mean([Compressor.bpp_from_compressed_file(p) for p in compressed.ps]))

    r = ResidualDataset(compressed, raw, 'diff', 256, True, output_device=pe.DEVICE)

    # for i in range(len(r)):
    #     print(r[i]['residual'].unique())
    # exit(1)
    d = DataLoader(r, batch_size=10, shuffle=False, num_workers=2)
    mi, ma = None, None
    for b in d:
        res = b['raw']
        # print(res.unique())
        if mi is None:
            mi = res.min()
            ma = res.max()
        else:
            mi = min(res.min(), mi)
            ma = max(res.max(), ma)
        print(mi, ma)
def make_clf_training_set(training_set_dir):
    V = cached_listdir_imgs(training_set_dir,
                            min_size=512,
                            discard_shitty=False)
    clf_training_set_filenames = get_clf_training_set_filenames()

    # Make sure we have them all
    ps = set(map(os.path.basename, V.ps))
    missing = set(clf_training_set_filenames) - ps
    if missing:
        print(f'ERROR: Not all files found, missing {missing}!')
        sys.exit(1)

    # Create the subset folder
    out_dir = training_set_dir.rstrip(os.path.sep) + '_subset_clf'
    print(f'Creating {out_dir}...')
    os.makedirs(out_dir, exist_ok=True)

    print_every = max(len(clf_training_set_filenames) // 20,
                      1)  # Update every 5%
    for i, filename in enumerate(clf_training_set_filenames):
        if i > 0 and i % print_every == 0:
            percent = i / len(clf_training_set_filenames) * 100
            print(f'Update: {percent:.1f}% copied')
        in_p = os.path.join(training_set_dir, filename)
        out_p = os.path.join(out_dir, filename)
        if not os.path.isfile(out_p):
            shutil.copy(in_p, out_p)

    print('Caching files...')
    make_cache_fast(out_dir)

    print(f'\nSubfolder created at {out_dir}. Now run:\n'
          f'bash prep_bpg_ds.sh A11_17 {out_dir}')
def make_bin_pkl(img_folder, num_bins, overwrite=False):
    if isinstance(img_folder, list):
        img_folder = img_folder[0]
    pkl_out = get_default_pkl_p(img_folder, num_bins)
    if os.path.isfile(pkl_out):
        return pkl_out

    assert 'train_oi' in img_folder, img_folder  # other folders are currently not supported because of the min_size and discard_shitty flags
    ps = cached_listdir_imgs.cached_listdir_imgs(img_folder,
                                                 min_size=None,
                                                 discard_shitty=True).ps

    ps_bpsps = sorted(
        ((p, Compressor.bpp_from_compressed_file(p) / 3) for p in ps),
        key=lambda xy: xy[1])  # sort by bpsp
    bpsps = [bpsp for _, bpsp in ps_bpsps]

    # border     b0    b1     ...   bk         k+1 borders
    # bin_idx  0    1      2  ... k    k+1  => k+2 bins
    #
    # For num_bins bins we need num_bins - 1 borders;
    # the first border sits after 1/num_bins-th of the data.

    # num_bins + 1 linspace points, so that after dropping the two endpoints we
    # are left with num_bins - 1 evenly spaced borders *within* the data.
    bin_borders_x = np.linspace(0, len(bpsps) - 1, num_bins + 1, dtype=int)
    # Throw away the two boundary points.
    bin_borders_x = bin_borders_x[1:-1]
    bin_borders_b = [bpsps[x] for x in bin_borders_x]

    with open(pkl_out, 'wb') as f:
        print('Saving', bin_borders_b, '\n->', pkl_out)
        pickle.dump(bin_borders_b, f)

    return pkl_out
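# Usage sketch (hypothetical folder name and bpsp value, not from the original
# code): the pickle written above stores the num_bins - 1 bpsp borders, and
# np.digitize maps a bpsp value to its bin index in [0, num_bins - 1].
import pickle
import numpy as np

def _bin_index_for_bpsp(bpsp, pkl_p):
    with open(pkl_p, 'rb') as f:
        borders = pickle.load(f)  # sorted list of length num_bins - 1
    return int(np.digitize(bpsp, borders))

# e.g. _bin_index_for_bpsp(0.42, make_bin_pkl('data/train_oi_bpg', num_bins=10))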
def _get(img_dir, clf_ckpt_p):
    """Predict a Q value for every image in img_dir with the classifier loaded
    from clf_ckpt_p and write the results to OPTIMAL_QS_TXT as a (fn, q) CSV."""
    out_file = os.path.join(img_dir, OPTIMAL_QS_TXT)

    clf = load_classifier(clf_ckpt_p)

    t = timer.TimeAccumulator()
    opt = {}
    for i, p in enumerate(
            cached_listdir_imgs(img_dir, discard_shitty=False).ps):
        with t.execute():
            img = torch.from_numpy(np.array(Image.open(p))).to(
                pe.DEVICE, non_blocking=True).permute(2, 0, 1)
            assert img.shape[0] == 3
            img = img.unsqueeze(0)
            img = SymbolTensor(img.long(), L=256).to_norm().get()
            q = clf.get_q(img)
            opt[os.path.splitext(os.path.basename(p))[0]] = q
        if i > 0 and i % 10 == 0:
            print(i, t.mean_time_spent())
    with open(out_file, 'w', newline='') as csvfile:
        w = csv.writer(csvfile, quoting=csv.QUOTE_MINIMAL)
        w.writerow(['fn', 'q'])
        for filename, q in sorted(opt.items()):
            w.writerow([filename, q])
    print('Created', out_file)
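# Sketch of reading the CSV written above back into a {filename: q} dict; this
# is an illustrative assumption (not the repo's get_optimal_qs.read) and reuses
# the OPTIMAL_QS_TXT constant from _get above.
import csv
import os

def _read_optimal_qs(img_dir):
    with open(os.path.join(img_dir, OPTIMAL_QS_TXT), 'r', newline='') as csvfile:
        rows = list(csv.reader(csvfile))[1:]  # skip the ['fn', 'q'] header
        return {fn: q for fn, q in rows}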
    def _get_ds_val(self, imgs_dir_val, crop=False):
        # ---
        # ds_syn = global_config.get('ds_syn', None)
        # if ds_syn:
        #     ds_val = self._get_syn(ds_syn, truncate)
        #     return ds_val
        # ---
        if self.config_dl.is_residual_dataset:
            return get_residual_dataset(
                self.config_dl.imgs_dir_val,
                random_transforms=False,
                random_scale=False,
                crop_size=crop,
                mode='both' if self.style == 'enhancement' else 'diff',
                discard_shitty=self.config_dl.discard_shitty_val)
        else:
            assert self.style != 'enhancement', 'style == enhancement -> expected residual dataset'

        img_to_tensor_t = shared.get_test_dataset_transform(crop)

        ds = images_loader.IndexImagesDataset(
            images=cached_listdir_imgs(imgs_dir_val,
                                       min_size=self.config_dl.crop_size,
                                       discard_shitty=True),
            to_tensor_transform=img_to_tensor_t)

        return ds
def get_residual_dataset_jpg(raw_images_dir, random_transforms: bool,
                             random_scale, crop_size: int, max_imgs=None,
                             discard_shitty=True,
                             #filter_min_size=None, top_only=None,
                             #is_training=False, sort=False
                             ):
    imgs = cached_listdir_imgs(raw_images_dir, min_size=512, discard_shitty=discard_shitty).sort()
    return JPGDataset(imgs,
                      random_crops=crop_size if random_transforms else None,
                      random_flips=random_transforms,
                      random_scale=random_scale,
                      center_crops=crop_size if not random_transforms else None,
                      max_imgs=max_imgs)
    def process_all_in(self, input_dir, filter_imgs_dir):
        images_dl = iter_images(input_dir)  # generator of paths

        # files this job should compress
        files_of_job = [p for _, p in job_enumerate(images_dl)]
        # files that were already compressed by somebody (e.g. this job in an earlier run)
        processed_already = self.images_cleaned | self.images_discarded
        # resulting files to be compressed
        files_of_job = [
            p for p in files_of_job if get_fn(p) not in processed_already
        ]

        if filter_imgs_dir:
            ps_orig = cached_listdir_imgs(filter_imgs_dir,
                                          discard_shitty=True).ps
            fns_to_use = set(map(get_fn, ps_orig))
            print('Filtering with', len(fns_to_use), 'filenames. Before:',
                  len(files_of_job))
            files_of_job = [p for p in files_of_job if get_fn(p) in fns_to_use]
            print('Filtered, now', len(files_of_job))

        N = len(files_of_job)
        if N == 0:
            print('Everything processed / nothing to process.')
            return

        num_process = 2 if NUM_TASKS > 1 else int(
            os.environ.get('MAX_PROCESS', 16))
        print(
            f'Processing {N} images using {num_process} processes in {NUM_TASKS} tasks...'
        )

        start = time.time()
        predicted_time = None
        with multiprocessing.Pool(processes=num_process) as pool:
            for i, clean in enumerate(
                    pool.imap_unordered(self.process, files_of_job)):
                if i > 0 and i % 100 == 0:
                    time_per_img = (time.time() - start) / (i + 1)
                    time_remaining = time_per_img * (N - i)
                    if not predicted_time:
                        predicted_time = time_remaining
                    print(
                        f'\r{time_per_img:.2e} s/img | {i / N * 100:.1f}% | {time_remaining / 60:.1f} min remaining',
                        end='',
                        flush=True)
        if predicted_time:
            print(
                f'Actual time: {(time.time() - start) / 60:.1f} // predicted {predicted_time / 60:.1f}'
            )
def create(indir, num_tars):
    ds = cached_listdir_imgs.cached_listdir_imgs(indir,
                                                 min_size=None,
                                                 discard_shitty=False)
    paths = ds.ps
    paths_per_tar = iterable_ext.chunks(paths, num_chunks=num_tars)

    print(f'Packing {len(paths)} images into {num_tars} tar files...')
    with multiprocessing.Pool(processes=num_tars) as pool:
        completed = 0
        for out_p in pool.imap_unordered(_pack_as_tar,
                                         enumerate(paths_per_tar)):
            completed += 1
            print(f'Done: {out_p} ({completed}/{num_tars} done)')
    def __init__(self, images_dir, optimal_qs_csv, to_tensor_transform):
        images = cached_listdir_imgs(images_dir, discard_shitty=False)
        self.optimal_qs = QHistory.read_q_history(optimal_qs_csv)
        assert self.optimal_qs, optimal_qs_csv
        print('Read optimal Q for', len(self.optimal_qs), 'images.')
        missing = {
            p
            for p in images.ps
            if os.path.splitext(os.path.basename(p))[0] not in self.optimal_qs
        }
        if missing:
            print(f'Missing files in {optimal_qs_csv}: {len(missing)}')
        image_ps = list(set(images.ps) - missing)
        assert len(image_ps) > 0, (images_dir, optimal_qs_csv)
        print(f'--> Using {len(image_ps)} images!')
        self.images_ds = IndexImagesDataset(Images(image_ps, images.id),
                                            to_tensor_transform)
        self.id = f'{images.id}_Q{os.path.basename(optimal_qs_csv)}'
    def get_ds_train(self):
        """
        Dataset must return dicts with {'idx', 'raw'}, where 'raw' is 3HW uint8
        """
        if self.config_dl.is_residual_dataset:
            return get_residual_dataset(
                self.config_dl.imgs_dir_train,
                random_transforms=True,
                random_scale=self.config_dl.random_scale,
                crop_size=self.config_dl.crop_size,
                mode='both' if self.style == 'enhancement' else 'diff',
                discard_shitty=self.config_dl.discard_shitty_train,
                filter_min_size=self.config_dl.train_filter_min_size,
                top_only=global_config.get('top_only', None),
                is_training=True)
        else:
            assert self.style != 'enhancement', 'style == enhancement -> expected residual dataset'

        to_tensor_transform = transforms.Compose([
            transforms.RandomCrop(self.config_dl.crop_size),
            transforms.RandomHorizontalFlip(),
            images_loader.IndexImagesDataset.to_tensor_uint8_transform()
        ])

        if global_config.get('ycbcr', False):
            print('Adding ->YCbCr')
            t = transforms.Lambda(lambda pil_img: pil_img.convert('YCbCr'))
            to_tensor_transform.transforms.insert(-2, t)

        ds_syn = global_config.get('ds_syn', None)
        if ds_syn:
            ds_train = self._get_syn(ds_syn, 30 * 10000)
        else:
            ds_train = images_loader.IndexImagesDataset(
                images=cached_listdir_imgs(
                    self.config_dl.imgs_dir_train,
                    min_size=self.config_dl.crop_size,
                    discard_shitty=self.config_dl.discard_shitty_train),
                to_tensor_transform=to_tensor_transform)
        return ds_train
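# Minimal sketch (standalone and hypothetical, not part of the original code) of
# a dataset satisfying the contract in the docstring of get_ds_train above:
# every item is a dict with an 'idx' and a 'raw' 3HW uint8 tensor.
import torch
from torch.utils.data import Dataset

class _ToyRawDataset(Dataset):
    def __init__(self, num_items=4, crop_size=256):
        self.num_items, self.crop_size = num_items, crop_size

    def __len__(self):
        return self.num_items

    def __getitem__(self, idx):
        raw = torch.randint(0, 256, (3, self.crop_size, self.crop_size),
                            dtype=torch.uint8)
        return {'idx': idx, 'raw': raw}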
def compress(compressor: Compressor, indir, discard_shitty):
    ds = cached_listdir_imgs.cached_listdir_imgs(indir,
                                                 min_size=None,
                                                 discard_shitty=discard_shitty)
    imgs = ds.ps
    compressor.optimal_qs = get_optimal_qs.read(indir)
    if compressor.optimal_qs:
        print('Optimal Qs:', len(compressor.optimal_qs))
    assert len(imgs) > 0, f'no matches for {indir}'
    num_imgs_to_process = len(imgs) // task_array.NUM_TASKS

    images_of_job = [p for _, p in task_array.job_enumerate(imgs)]
    N_orig = len(images_of_job)
    images_of_job = [
        p for p in images_of_job if os.path.splitext(os.path.basename(p))[0]
        not in compressor.files_that_exist
    ]
    N = len(images_of_job)

    start = time.time()
    num_process = 2 if task_array.NUM_TASKS > 1 else int(
        os.environ.get('MAX_PROCESS', 16))
    print(
        f'{task_array.JOB_ID}:',
        f'Compressing {N}/{N_orig} images ({ds.id}) using {num_process} processes',
        f'in {task_array.NUM_TASKS} tasks...')

    with multiprocessing.Pool(processes=num_process) as pool:
        for i, _ in enumerate(
                pool.imap_unordered(compressor.compress, images_of_job)):
            if i > 0 and i % 5 == 0:
                time_per_img = (time.time() - start) / (i + 1)
                time_remaining = time_per_img * (N - i)
                print(
                    f'\r{time_per_img*num_process:.2e} s/img | '
                    f'{i / N * 100:.1f}% | {time_remaining / 60:.1f} min remaining',
                    end='',
                    flush=True)
def get_residual_dataset(imgs_dir, random_transforms: bool,
                         random_scale, crop_size: int, mode: str, max_imgs=None,
                         discard_shitty=True, filter_min_size=None, top_only=None,
                         is_training=False, sort=False):
    if top_only:
        assert top_only < 1
    multiple_ds = False
    if isinstance(imgs_dir, dict):
        assert 'raw' in imgs_dir and 'compressed' in imgs_dir, imgs_dir.keys()
        raw_p, compressed_p = imgs_dir['raw'], imgs_dir['compressed']
        multiple_ds = isinstance(imgs_dir['raw'], list)
        if multiple_ds:
            assert len(raw_p) == len(compressed_p)
    elif ';' in imgs_dir:
        raw_p, compressed_p = imgs_dir.split(';')
    else:
        raise ValueError(f'Invalid imgs_dir, expected a dict or a '
                         f'"raw;compressed" string, got {imgs_dir}')

    # works fine if p_ is a list
    get_imgs = lambda p_: cached_listdir_imgs(
            p_, min_size=filter_min_size or crop_size, discard_shitty=discard_shitty)

    if compressed_p == 'JPG':
        print('*** Using JPG...')
        imgs = get_imgs(raw_p)
        return JPGDataset(imgs,
                          random_crops=crop_size if random_transforms else None,
                          random_flips=random_transforms,
                          random_scale=random_scale,
                          center_crops=crop_size if not random_transforms else None,
                          max_imgs=max_imgs)

    if is_training and global_config.get('filter_imgs', False):
        assert not multiple_ds
        print('*** filtering', imgs_dir)
        filter_imgs = global_config['filter_imgs']
        if not isinstance(filter_imgs, int):
            filter_imgs = 680
        print(filter_imgs)
        get_imgs = lambda p_: cached_listdir_imgs_max(p_, max_size=filter_imgs, discard_shitty=True)

    raw, compressed = map(get_imgs, (raw_p, compressed_p))

    if top_only:
        sorted_imgs = sorted((Compressor.bpp_from_compressed_file(p), p) for p in compressed.ps)
        top_only_imgs = sorted_imgs[-int(top_only * len(sorted_imgs)):]
        top_only_ps = [p for _, p in top_only_imgs]
        compressed = Images(top_only_ps, compressed.id + f'_top{top_only:.2f}')
        print(f'*** Using {len(top_only_ps)} of {len(sorted_imgs)} images only')

    if sort:
        print('Sorting...')
        raw = raw.sort()
        compressed = compressed.sort()

    return ResidualDataset(compressed, raw,
                           mode=mode,
                           random_crops=crop_size if random_transforms else None,
                           random_flips=random_transforms,
                           random_scale=random_scale,
                           center_crops=crop_size if not random_transforms else None,
                           max_imgs=max_imgs)
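# Usage sketch (hypothetical paths). imgs_dir is either a
# {'raw': ..., 'compressed': ...} dict or a single "raw;compressed" string;
# passing 'JPG' as the compressed part returns a JPGDataset instead.
ds_train = get_residual_dataset(
    {'raw': 'data/train_raw', 'compressed': 'data/train_bpg'},
    random_transforms=True, random_scale=False,
    crop_size=256, mode='diff', is_training=True)
ds_val = get_residual_dataset(
    'data/val_raw;data/val_bpg',
    random_transforms=False, random_scale=False,
    crop_size=256, mode='both')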