Example #1

import argparse
import os
import pickle

import numpy as np
from torch.utils.data import DataLoader

# Project-local modules and helpers (cached_listdir_imgs, cached_listdir_imgs_max,
# Compressor, ResidualDataset, JPGDataset, Images, get_default_pkl_p,
# global_config, pe, NO_ERRORS) are assumed importable from the surrounding
# codebase.
def make_bin_pkl(img_folder, num_bins, overwrite=False):
    if isinstance(img_folder, list):
        img_folder = img_folder[0]
    pkl_out = get_default_pkl_p(img_folder, num_bins)
    if os.path.isfile(pkl_out) and not overwrite:
        return pkl_out

    assert 'train_oi' in img_folder, img_folder  # currently only train_oi is supported, because of the min_size and discard_shitty flags
    ps = cached_listdir_imgs.cached_listdir_imgs(img_folder,
                                                 min_size=None,
                                                 discard_shitty=True).ps

    ps_bpsps = sorted(
        ((p, Compressor.bpp_from_compressed_file(p) / 3) for p in ps),
        key=lambda xy: xy[1])  # sort by bpsp
    bpsps = [bpsp for _, bpsp in ps_bpsps]

    # border     b0    b1     ...   bk         k+1 borders
    # bin_idx  0    1      2  ... k    k+1  => k+2 bins
    #
    # For num_bins bins, we need num_bins - 1 borders;
    # the first border sits after 1/num_bins-th of the data.

    # num_bins + 1 linspace points so that, after dropping the two endpoints,
    # we get num_bins - 1 evenly spaced borders *within* the data.
    bin_borders_x = np.linspace(0, len(bpsps) - 1, num_bins + 1, dtype=int)
    # throw away the boundaries
    bin_borders_x = bin_borders_x[1:-1]
    bin_borders_b = [bpsps[x] for x in bin_borders_x]
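    # Worked example (illustrative numbers): num_bins = 4, len(bpsps) = 9
    # -> linspace gives [0, 2, 4, 6, 8]; dropping the endpoints leaves
    #    indices [2, 4, 6], i.e. 3 borders defining 4 bins.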

    with open(pkl_out, 'wb') as f:
        print('Saving', bin_borders_b, '\n->', pkl_out)
        pickle.dump(bin_borders_b, f)

    return pkl_out
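

# Hypothetical helper (not part of the original code): maps a compressed file
# to its bin index using the borders pickled by make_bin_pkl above.
def lookup_bin_idx(pkl_p, compressed_p):
    import bisect
    with open(pkl_p, 'rb') as f:
        borders = pickle.load(f)  # the num_bins - 1 interior bpsp borders
    bpsp = Compressor.bpp_from_compressed_file(compressed_p) / 3
    # Values below the first border map to 0, above the last to num_bins - 1.
    return bisect.bisect_right(borders, bpsp)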


def main():
    p = argparse.ArgumentParser()
    p.add_argument('raw')
    p.add_argument('compressed', nargs='*')
    flags = p.parse_args()
    if ';' in flags.raw:
        flags.raw, flags.compressed = flags.raw.split(';')
    else:
        flags.compressed = flags.compressed[0]

    raw = cached_listdir_imgs(flags.raw, discard_shitty=False)
    compressed = cached_listdir_imgs(flags.compressed, discard_shitty=False)
    print(raw, compressed)

    print('Average bpp', np.mean([Compressor.bpp_from_compressed_file(p) for p in compressed.ps]))

    r = ResidualDataset(compressed, raw, 'diff', 256, True, output_device=pe.DEVICE)

    # for i in range(len(r)):
    #     print(r[i]['residual'].unique())
    # exit(1)
    d = DataLoader(r, batch_size=10, shuffle=False, num_workers=2)
    mi, ma = None, None
    for b in d:
        res = b['raw']
        # print(res.unique())
        if mi is None:
            mi = res.min()
            ma = res.max()
        else:
            mi = min(res.min(), mi)
            ma = max(res.max(), ma)
        print(mi, ma)
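

if __name__ == '__main__':
    main()


# The following is a method of ResidualDataset (inferred from its use of
# ResidualDataset._convert and self._cache below).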
    def _get_pils(self, idx):
        if self._cache and idx in self._cache:
            yield self._cache[idx]
            return

        compressed_p = self.compressed_images.ps[idx]
        bpp = Compressor.bpp_from_compressed_file(compressed_p)

        raw_p = self.compressed_to_raw[os.path.basename(compressed_p)]

        compressed, f1 = self._read_img(compressed_p)
        raw, f2 = self._read_img(raw_p)
        if not NO_ERRORS:
            assert compressed.size == raw.size, f'Error for {compressed_p}, {raw_p}; {compressed.size, raw.size}'

        if self._cache is not None and idx not in self._cache:
            with self._cache_lock:
                if idx not in self._cache:
                    print('Caching', idx)
                    compressed, raw = map(ResidualDataset._convert, (compressed, raw))
                    self._cache[idx] = compressed, raw, bpp

        yield compressed, raw, bpp

        # Only reached once the consumer resumes the generator, i.e. the
        # underlying files stay open while the yielded images are in use.
        f1.close()
        f2.close()
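
    # Intended consumption (sketch; the loop below is illustrative): iterate
    # the generator to completion so the close() calls above actually run:
    #
    #     for compressed, raw, bpp in self._get_pils(idx):
    #         ...  # the files backing the images are still open here
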
def get_residual_dataset(imgs_dir, random_transforms: bool,
                         random_scale, crop_size: int, mode: str, max_imgs=None,
                         discard_shitty=True, filter_min_size=None, top_only=None,
                         is_training=False, sort=False):
    if top_only:
        assert top_only < 1
    multiple_ds = False
    if isinstance(imgs_dir, dict):
        assert 'raw' in imgs_dir and 'compressed' in imgs_dir, imgs_dir.keys()
        raw_p, compressed_p = imgs_dir['raw'], imgs_dir['compressed']
        multiple_ds = isinstance(imgs_dir['raw'], list)
        if multiple_ds:
            assert len(raw_p) == len(compressed_p)
    elif ';' in imgs_dir:
        raw_p, compressed_p = imgs_dir.split(';')
    else:
        raise ValueError(f'Invalid imgs_dir, should be a dict or a string containing ";", got {imgs_dir}')

    # works fine if p_ is a list
    get_imgs = lambda p_: cached_listdir_imgs(
            p_, min_size=filter_min_size or crop_size, discard_shitty=discard_shitty)

    if compressed_p == 'JPG':
        print('*** Using JPG...')
        imgs = get_imgs(raw_p)
        return JPGDataset(imgs,
                          random_crops=crop_size if random_transforms else None,
                          random_flips=random_transforms,
                          random_scale=random_scale,
                          center_crops=crop_size if not random_transforms else None,
                          max_imgs=max_imgs)

    if is_training and global_config.get('filter_imgs', False):
        assert not multiple_ds
        print('*** filtering', imgs_dir)
        filter_imgs = global_config['filter_imgs']
        # Note: bool is a subclass of int, so a bare `filter_imgs: True` in the
        # config must be caught explicitly before falling back to the default.
        if isinstance(filter_imgs, bool) or not isinstance(filter_imgs, int):
            filter_imgs = 680
        print('*** filter_imgs =', filter_imgs)
        get_imgs = lambda p_: cached_listdir_imgs_max(p_, max_size=filter_imgs, discard_shitty=True)

    raw, compressed = map(get_imgs, (raw_p, compressed_p))

    if top_only:
        sorted_imgs = sorted((Compressor.bpp_from_compressed_file(p), p) for p in compressed.ps)
        top_only_imgs = sorted_imgs[-int(top_only * len(sorted_imgs)):]
        top_only_ps = [p for _, p in top_only_imgs]
        compressed = Images(top_only_ps, compressed.id + f'_top{top_only:.2f}')
        print(f'*** Using {len(top_only_ps)} of {len(sorted_imgs)} images only')

    if sort:
        print('Sorting...')
        raw = raw.sort()
        compressed = compressed.sort()

    return ResidualDataset(compressed, raw,
                           mode=mode,
                           random_crops=crop_size if random_transforms else None,
                           random_flips=random_transforms,
                           random_scale=random_scale,
                           center_crops=crop_size if not random_transforms else None,
                           max_imgs=max_imgs)
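

# Hypothetical call sites (paths and values are illustrative), matching the
# two accepted forms of imgs_dir:
#
#     ds = get_residual_dataset({'raw': '/data/raw', 'compressed': '/data/bpg'},
#                               random_transforms=True, random_scale=None,
#                               crop_size=256, mode='diff', is_training=True)
#     ds = get_residual_dataset('/data/raw;/data/bpg', random_transforms=False,
#                               random_scale=None, crop_size=256, mode='diff')

# The following is a method of a multi-quality (Q) dataset wrapper whose class
# is not shown; self.datasets is assumed to map each quality Q to a dataset.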
    def get_bpg_bpsps(self, filenames_Qs):
        # Dict Q -> (raw_fn -> compressed_fn)
        q_to_d = {Q: self.datasets[Q].get_raw_to_compressed_dict()
                  for Q in self.datasets.keys()}
        return [Compressor.bpp_from_compressed_file(q_to_d[Q][fn]) / 3
                for fn, Q in filenames_Qs]
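
    # Usage sketch (filenames and Qs are illustrative): pair each raw filename
    # with the BPG quality Q it was compressed at to get per-subpixel rates:
    #
    #     bpsps = self.get_bpg_bpsps([('img_0001.png', 12), ('img_0002.png', 14)])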