def process(self, batch_x, quality=None, return_entropy=False):
    """Compress a batch of images (NHW3, RGB) with a given JPEG quality factor.

    :param batch_x: batch of images (np.ndarray or a tensor with a .numpy() method)
    :param quality: integer JPEG quality level; falls back to ``self.quality`` when None
    :param return_entropy: if True, also return the entropy of the latent representation
        (np.nan when no differentiable model is configured)
    :return: compressed batch; with a model: ``y`` or ``(y, X, entropy)``,
        without a model: ``batch_y`` or ``(batch_y, np.nan)``
    :raises ValueError: if the resolved quality is not a valid JPEG quality
    """
    quality = self.quality if quality is None else quality
    quality = int(quality)

    if not is_valid_quality(quality):
        raise ValueError('Invalid or unspecified JPEG quality!')

    if self._model is None:
        # Fall back to the libjpeg codec; it needs a plain numpy array.
        if not isinstance(batch_x, np.ndarray):
            batch_x = batch_x.numpy()
        # compress_batch returns (images, bytes); only the images are needed here
        batch_y = jpeg_helpers.compress_batch(batch_x, quality)[0]
        if return_entropy:
            # No latent representation available -> entropy is undefined
            return batch_y, np.nan
        return batch_y

    # Differentiable JPEG model: temporarily swap in the quantization tables
    # for the requested quality, and always restore the originals afterwards
    # (even if the forward pass raises) so the model is never left corrupted.
    swap_tables = quality != self.quality
    if swap_tables:
        old_q_luma = self._model._q_mtx_luma
        old_q_chroma = self._model._q_mtx_chroma
        self._model._q_mtx_luma = jpeg_qtable(quality, 0)
        self._model._q_mtx_chroma = jpeg_qtable(quality, 1)

    try:
        y, X = self._model(batch_x)
    finally:
        if swap_tables:
            self._model._q_mtx_luma = old_q_luma
            self._model._q_mtx_chroma = old_q_chroma

    if return_entropy:
        # TODO This currently takes too much memory
        entropy = tf_helpers.entropy(X, self._model.quantization.codebook, v=5, gamma=5)[0]
        return y, X, entropy

    return y
def get_jpeg_df(directory,
                write_files=False,
                effective_bytes=True,
                force_calc=False):
    """
    Compute and return (as a Pandas DataFrame) the rate-distortion curve for JPEG.

    The result is cached as 'jpeg.csv' in the source directory; if that file
    exists (and ``force_calc`` is False), the cached DataFrame is loaded and
    returned instead of being recomputed. Files are saved as JPEG using imageio.

    :param directory: directory with input images (discovered via loading.discover_images)
    :param write_files: if True, save each decompressed image as PNG next to its source
    :param effective_bytes: if True, count only image data bytes (no JPEG headers)
    :param force_calc: if True, recompute even when a cached CSV exists
    :return: DataFrame with columns image_id, filename, codec, quality, ssim,
             psnr, msssim, msssim_db, bytes, bpp
    """

    files, _ = loading.discover_images(directory, n_images=-1, v_images=0)
    batch_x = loading.load_images(files, directory, load='y')
    batch_x = batch_x['y'].astype(np.float32) / (2**8 - 1)

    # Quality sweep for the JPEG trade-off curve: 95 down to 10 in steps of 5
    quality_levels = np.arange(95, 5, -5)
    df_jpeg_path = os.path.join(directory, 'jpeg.csv')

    if os.path.isfile(df_jpeg_path) and not force_calc:
        logger.info('Restoring JPEG stats from {}'.format(df_jpeg_path))
        return pd.read_csv(df_jpeg_path, index_col=False)

    columns = [
        'image_id', 'filename', 'codec', 'quality', 'ssim', 'psnr',
        'msssim', 'msssim_db', 'bytes', 'bpp'
    ]
    # Collect row dicts and build the DataFrame once at the end:
    # per-row DataFrame.append() is O(n^2) and was removed in pandas 2.0.
    rows = []

    with tqdm.tqdm(total=len(files) * len(quality_levels),
                   ncols=120,
                   desc='JPEG') as pbar:

        for image_id, filename in enumerate(files):

            # Read the original image
            image = batch_x[image_id]

            for q in quality_levels:

                # Compress images and get effective bytes (only image data - no headers)
                image_compressed, image_bytes = jpeg_helpers.compress_batch(
                    image, q, effective=effective_bytes)

                if write_files:
                    image_dir = os.path.join(directory,
                                             os.path.splitext(filename)[0])
                    if not os.path.isdir(image_dir):
                        os.makedirs(image_dir)

                    image_path = os.path.join(image_dir,
                                              'jpeg_q{:03d}.png'.format(q))
                    imageio.imwrite(image_path,
                                    (255 * image_compressed).astype(np.uint8))

                msssim_value = msssim(image, image_compressed, MAX=1).real

                rows.append({
                    'image_id': image_id,
                    'filename': filename,
                    'codec': 'jpeg',
                    'quality': q,
                    'ssim': compare_ssim(image,
                                         image_compressed,
                                         multichannel=True,
                                         data_range=1),
                    'psnr': compare_psnr(image, image_compressed,
                                         data_range=1),
                    'msssim': msssim_value,
                    # Map MS-SSIM to a dB-like scale for easier comparison
                    'msssim_db': -10 * np.log10(1 - msssim_value),
                    'bytes': image_bytes,
                    'bpp': 8 * image_bytes / image.shape[0] / image.shape[1]
                })

                pbar.set_postfix(image_id=image_id, quality=q)
                pbar.update(1)

    df = pd.DataFrame(rows, columns=columns)
    df.to_csv(df_jpeg_path, index=False)

    return df
# Example #3
def match_jpeg(model, batch_x, axes=None, match='ssim'):
    """Compress a batch with the DCN model and with a quality-matched JPEG,
    print a comparison report, and plot both results side by side.

    :param model: DCN model (must provide simulate_compression support,
        get_codebook() and compress())
    :param batch_x: input image batch (N, H, W, 3); the first image is used
        for the JPEG comparison
    :param axes: optional sequence of 6 matplotlib axes to plot into; a new
        figure is created when None
    :param match: 'ssim' to match JPEG quality by SSIM, anything else to
        match by bpp
    :return: the matplotlib figure containing the comparison plots
    """

    # Compress using DCN and get number of bytes
    batch_y, bytes_dcn = codec.simulate_compression(batch_x, model)

    ssim_dcn = compare_ssim(batch_x.squeeze(),
                            batch_y.squeeze(),
                            multichannel=True,
                            data_range=1)
    bpp_dcn = 8 * bytes_dcn / np.prod(batch_x.shape[1:-1])
    target = ssim_dcn if match == 'ssim' else bpp_dcn

    try:
        jpeg_quality = jpeg_helpers.match_quality(batch_x.squeeze(),
                                                  target,
                                                  match=match)
    except Exception:
        # Narrow catch: a bare `except:` would also swallow KeyboardInterrupt
        # and SystemExit. Fall back to a coarse guess when matching fails.
        if match == 'ssim':
            jpeg_quality = 95 if ssim_dcn > 0.8 else 10
        else:
            jpeg_quality = 95 if bpp_dcn > 3 else 10
        print(
            'WARNING Could not find a matching JPEG quality factor - guessing {}'
            .format(jpeg_quality))

    # Compress using JPEG
    batch_j, bytes_jpeg = jpeg_helpers.compress_batch(batch_x[0],
                                                      jpeg_quality,
                                                      effective=True)
    ssim_jpeg = compare_ssim(batch_x.squeeze(),
                             batch_j.squeeze(),
                             multichannel=True,
                             data_range=1)
    bpp_jpg = 8 * bytes_jpeg / np.prod(batch_x.shape[1:-1])

    # Get stats: empirical entropy of the quantized latent representation
    code_book = model.get_codebook()
    batch_z = model.compress(batch_x)
    counts = utils.qhist(batch_z, code_book)
    counts = counts.clip(min=1)  # avoid log2(0) for unused codebook entries
    probs = counts / counts.sum()
    entropy = -np.sum(probs * np.log2(probs))

    # Print report
    print('DCN             : {}'.format(model.model_code))
    print('Pixels          : {}x{} = {:,} px'.format(
        batch_x.shape[1], batch_x.shape[2], np.prod(batch_x.shape[1:-1])))
    print('Bitmap          : {:,} bytes'.format(np.prod(batch_x.shape)))
    print('Code-book size  : {} elements from {} to {}'.format(
        len(code_book), min(code_book), max(code_book)))
    print('Entropy         : {:.2f} bits per symbol'.format(entropy))
    print('Latent size     : {:,}'.format(np.prod(batch_z.shape)))
    print('PPF Naive       : {:,.0f} --> {:,.0f} bytes [{} bits per element]'.
          format(
              np.prod(batch_z.shape) * np.log2(len(code_book)) / 8,
              np.prod(batch_z.shape) * np.ceil(np.log2(len(code_book))) / 8,
              np.ceil(np.log2(len(code_book)))))
    print('PPF Theoretical : {:,.0f} bytes ({:.2f} bpp)'.format(
        np.prod(batch_z.shape) * entropy / 8,
        np.prod(batch_z.shape) * entropy / np.prod(batch_x.shape[1:-1])))
    print('FSE Coded       : {:,} bytes ({:.2f} bpp) --> ssim: {:.3f}'.format(
        bytes_dcn, bpp_dcn, ssim_dcn))
    print(
        'JPEG (Q={:2d})     : {:,} bytes ({:0.2f} bpp) --> ssim: {:.3f} // effective size disregarding JPEG headers'
        .format(jpeg_quality, bytes_jpeg, bpp_jpg, ssim_jpeg))

    # Plot results
    if axes is None:
        fig, axes = plotting.sub(6, ncols=3)
        fig.set_size_inches(12, 10)
        fig.tight_layout()
    else:
        fig = axes[0].figure

    # Plot full-resolution
    plotting.quickshow(batch_x,
                       'Original ({0}x{0})'.format(batch_x.shape[1]),
                       axes=axes[0])
    plotting.quickshow(batch_y,
                       'DCN ssim:{:.2f} bpp:{:.2f}'.format(ssim_dcn, bpp_dcn),
                       axes=axes[1])
    plotting.quickshow(batch_j,
                       'JPEG {} ssim:{:.2f} bpp:{:.2f}'.format(
                           jpeg_quality, ssim_jpeg, bpp_jpg),
                       axes=axes[2])

    # Plot zoom (center crop, at least 64 px)
    crop_size = max([64, batch_x.shape[1] // 4])
    plotting.quickshow(utils.crop_middle(batch_x, crop_size),
                       'Original crop ({0}x{0})'.format(crop_size),
                       axes=axes[3])
    plotting.quickshow(utils.crop_middle(batch_y, crop_size),
                       'DCN crop ({0}x{0})'.format(crop_size),
                       axes=axes[4])
    plotting.quickshow(utils.crop_middle(batch_j, crop_size),
                       'JPEG crop ({0}x{0})'.format(crop_size),
                       axes=axes[5])

    return fig