def get_dcn_df(directory,
               model_directory,
               write_files=False,
               force_calc=False):
    """
    Compute and return (as Pandas DF) the rate distortion curve for the learned DCN codec.

    The result is saved as a CSV file named ``dcn-<name>.csv`` in `directory`,
    where ``<name>`` is the last non-empty path component of `model_directory`.
    If that file already exists (and `force_calc` is false), it is loaded and
    returned without loading any images or running any model.

    Args:
        directory: directory with the test images; also where the CSV cache
            and (optionally) the decoded images are written.
        model_directory: root directory searched recursively for
            ``progress.json`` files; each hit marks one model to evaluate.
        write_files: if true, every decoded image is saved as a PNG to
            ``<directory>/<image name without extension>/<model code>.png``.
        force_calc: if true, ignore an existing CSV cache and recompute.

    Returns:
        A DataFrame with one row per (model, image) pair.

    Raises:
        Any exception raised while compressing an image is logged and
        re-raised unchanged.
    """

    columns = [
        'image_id', 'filename', 'model_dir', 'codec', 'ssim', 'psnr', 'msssim',
        'msssim_db', 'entropy', 'bytes', 'bpp', 'layers', 'quantization',
        'entropy_reg', 'codebook', 'latent', 'latent_shape', 'n_features'
    ]

    # Discover available models
    model_dirs = list(Path(model_directory).glob('**/progress.json'))
    logger.info('Found {} models'.format(len(model_dirs)))

    df_path = os.path.join(
        directory, 'dcn-{}.csv'.format(
            [x for x in fsutil.split(model_directory) if len(x) > 0][-1]))

    # Cache hit: return early, before paying for image loading.
    if os.path.isfile(df_path) and not force_calc:
        logger.info('Restoring DCN stats from {}'.format(df_path))
        return pd.read_csv(df_path, index_col=False)

    # Discover test files and normalise 8-bit pixel values to [0, 1]
    files, _ = loading.discover_images(directory, n_images=-1, v_images=0)
    batch_x = loading.load_images(files, directory, load='y')
    batch_x = batch_x['y'].astype(np.float32) / (2**8 - 1)

    # Rows are collected in a plain list and turned into a DataFrame once at
    # the end: DataFrame.append was removed in pandas 2.0 and copied the whole
    # table on every call.
    rows = []

    for model_dir in model_dirs:
        logger.info('Processing model dir: {}'.format(model_dir))
        model_path = os.path.split(str(model_dir))[0]
        dcn = codec.restore(model_path, batch_x.shape[1])

        for image_id, filename in enumerate(files):
            image = batch_x[image_id]
            # Keep the leading batch dimension for the codec calls.
            image_batch = batch_x[image_id:image_id + 1]

            try:
                batch_y, image_bytes = codec.simulate_compression(
                    image_batch, dcn)
                batch_z = dcn.compress(image_batch)
                entropy = helpers.stats.entropy(batch_z, dcn.get_codebook())
            except Exception as e:
                logger.error(
                    'Error while processing {} with {} : {}'.format(
                        filename, dcn.model_code, e))
                # Bare raise keeps the original traceback intact.
                raise

            # Dump compressed images
            if write_files:
                image_dir = os.path.join(directory,
                                         os.path.splitext(filename)[0])
                os.makedirs(image_dir, exist_ok=True)

                image_path = os.path.join(
                    image_dir, dcn.model_code.replace('/', '-') + '.png')
                imageio.imwrite(image_path,
                                (255 * batch_y[0]).astype(np.uint8))

            msssim_value = msssim(image, batch_y[0], MAX=1).real

            rows.append({
                'image_id': image_id,
                'filename': filename,
                'model_dir': os.path.relpath(
                    model_path, model_directory).replace(dcn.scoped_name, ''),
                'codec': dcn.model_code,
                'ssim': compare_ssim(image,
                                     batch_y[0],
                                     multichannel=True,
                                     data_range=1),
                'psnr': compare_psnr(image, batch_y[0], data_range=1),
                'msssim': msssim_value,
                'msssim_db': -10 * np.log10(1 - msssim_value),
                'entropy': entropy,
                'bytes': image_bytes,
                # Bits per pixel over the first two image dimensions.
                'bpp': 8 * image_bytes / image.shape[0] / image.shape[1],
                'layers': dcn.n_layers if 'n_layers' in dcn._h else None,
                'quantization': '{}-{:.0f}bpf'.format(dcn._h.rounding,
                                                      dcn.latent_bpf),
                'entropy_reg': dcn.entropy_weight,
                'codebook': dcn._h.rounding,
                'latent': dcn.n_latent,
                'latent_shape': '{}x{}x{}'.format(*dcn.latent_shape[1:]),
                'n_features': dcn.latent_shape[-1],
            })

    # Passing `columns` fixes the column order and keeps the header even when
    # no rows were produced.
    df = pd.DataFrame(rows, columns=columns)
    df.to_csv(df_path, index=False)

    return df
# Example #2
# 0
def match_jpeg(model, batch_x, axes=None, match='ssim'):
    """
    Compare a DCN model against JPEG at a matched operating point.

    The image is compressed with the DCN, then a JPEG quality factor is looked
    up that matches the DCN result either in SSIM or in bits per pixel. A text
    report is printed and six panels are plotted: original / DCN / JPEG at full
    resolution, followed by a centre crop of each.

    Args:
        model: the DCN model; must provide `get_codebook()`, `compress()` and
            `model_code`.
        batch_x: input image batch. NOTE(review): the code indexes `[0]` and
            calls `.squeeze()`, so it appears to expect a single image with a
            leading batch dimension and a trailing channel dimension - confirm.
        axes: sequence of at least six axes to draw on. If None, a new figure
            with six panels in three columns is created.
        match: 'ssim' to match JPEG on SSIM; any other value matches on bpp.

    Returns:
        The figure holding the plots (the figure of `axes[0]` when `axes` is
        supplied).
    """

    # Number of pixels per image: all dimensions except batch and channels.
    n_pixels = np.prod(batch_x.shape[1:-1])

    # Compress using DCN and get number of bytes
    batch_y, bytes_dcn = codec.simulate_compression(batch_x, model)

    ssim_dcn = compare_ssim(batch_x.squeeze(),
                            batch_y.squeeze(),
                            multichannel=True,
                            data_range=1)
    bpp_dcn = 8 * bytes_dcn / n_pixels
    target = ssim_dcn if match == 'ssim' else bpp_dcn

    try:
        jpeg_quality = jpeg_helpers.match_quality(batch_x.squeeze(),
                                                  target,
                                                  match=match)
    except Exception:
        # Best-effort fallback: guess a high or low quality factor. Only
        # regular exceptions are caught, so Ctrl-C and interpreter shutdown
        # still propagate.
        if match == 'ssim':
            jpeg_quality = 95 if ssim_dcn > 0.8 else 10
        else:
            jpeg_quality = 95 if bpp_dcn > 3 else 10
        print(
            'WARNING Could not find a matching JPEG quality factor - guessing {}'
            .format(jpeg_quality))

    # Compress using JPEG
    batch_j, bytes_jpeg = jpeg_helpers.compress_batch(batch_x[0],
                                                      jpeg_quality,
                                                      effective=True)
    ssim_jpeg = compare_ssim(batch_x.squeeze(),
                             batch_j.squeeze(),
                             multichannel=True,
                             data_range=1)
    bpp_jpg = 8 * bytes_jpeg / n_pixels

    # Get stats
    code_book = model.get_codebook()
    batch_z = model.compress(batch_x)
    counts = utils.qhist(batch_z, code_book)
    # Clip empty bins to 1 so that log2 below never sees a zero probability.
    counts = counts.clip(min=1)
    probs = counts / counts.sum()
    entropy = -np.sum(probs * np.log2(probs))

    n_latent = np.prod(batch_z.shape)
    bits_per_symbol = np.log2(len(code_book))

    # Print report
    print('DCN             : {}'.format(model.model_code))
    print('Pixels          : {}x{} = {:,} px'.format(
        batch_x.shape[1], batch_x.shape[2], n_pixels))
    print('Bitmap          : {:,} bytes'.format(np.prod(batch_x.shape)))
    print('Code-book size  : {} elements from {} to {}'.format(
        len(code_book), min(code_book), max(code_book)))
    print('Entropy         : {:.2f} bits per symbol'.format(entropy))
    print('Latent size     : {:,}'.format(n_latent))
    print('PPF Naive       : {:,.0f} --> {:,.0f} bytes [{} bits per element]'.
          format(n_latent * bits_per_symbol / 8,
                 n_latent * np.ceil(bits_per_symbol) / 8,
                 np.ceil(bits_per_symbol)))
    print('PPF Theoretical : {:,.0f} bytes ({:.2f} bpp)'.format(
        n_latent * entropy / 8, n_latent * entropy / n_pixels))
    print('FSE Coded       : {:,} bytes ({:.2f} bpp) --> ssim: {:.3f}'.format(
        bytes_dcn, bpp_dcn, ssim_dcn))
    print(
        'JPEG (Q={:2d})     : {:,} bytes ({:0.2f} bpp) --> ssim: {:.3f} // effective size disregarding JPEG headers'
        .format(jpeg_quality, bytes_jpeg, bpp_jpg, ssim_jpeg))

    # Plot results
    if axes is None:
        fig, axes = plotting.sub(6, ncols=3)
        fig.set_size_inches(12, 10)
        fig.tight_layout()
    else:
        fig = axes[0].figure

    # Plot full-resolution. The title reports both spatial dimensions so it
    # is also correct for non-square inputs (identical output for square ones).
    plotting.quickshow(batch_x,
                       'Original ({}x{})'.format(batch_x.shape[1],
                                                 batch_x.shape[2]),
                       axes=axes[0])
    plotting.quickshow(batch_y,
                       'DCN ssim:{:.2f} bpp:{:.2f}'.format(ssim_dcn, bpp_dcn),
                       axes=axes[1])
    plotting.quickshow(batch_j,
                       'JPEG {} ssim:{:.2f} bpp:{:.2f}'.format(
                           jpeg_quality, ssim_jpeg, bpp_jpg),
                       axes=axes[2])

    # Plot zoom: centre crop of a quarter of the image side, at least 64 px.
    crop_size = max(64, batch_x.shape[1] // 4)
    plotting.quickshow(utils.crop_middle(batch_x, crop_size),
                       'Original crop ({0}x{0})'.format(crop_size),
                       axes=axes[3])
    plotting.quickshow(utils.crop_middle(batch_y, crop_size),
                       'DCN crop ({0}x{0})'.format(crop_size),
                       axes=axes[4])
    plotting.quickshow(utils.crop_middle(batch_j, crop_size),
                       'JPEG crop ({0}x{0})'.format(crop_size),
                       axes=axes[5])

    return fig