Example #1
0
def compare_batches(batch_a, batch_b, labels=None, fig=None, figwidth=4, nrows=3, transpose=False):
    """
    Plot two batches of images against each other, one column per image pair.

    Row layout (flat axes index ``i + n_images * row``):
      0. image from batch A
      1. (if ``nrows > 1``) image from batch B, labelled with ``metrics.psnr`` / ``metrics.ssim`` w.r.t. A
      2. (if ``nrows > 2``) normalized signed difference A - B
      3. (if ``nrows > 3``) normalized absolute difference |A - B|

    :param batch_a: first batch of images
    :param batch_b: second batch of images; only the first ``min(len(batch_a), len(batch_b))`` pairs are plotted
    :param labels: optional two-element list ``[label_a, label_b]``
    :param fig: optional existing figure, forwarded to ``plots.sub``
    :param figwidth: forwarded to ``plots.sub``
    :param nrows: number of rows of the grid (see layout above)
    :param transpose: forwarded to ``plots.sub``
    :return: the figure
    """
    n_images = min(len(batch_a), len(batch_b))

    labels = labels or ['', '']

    fig, axes = plots.sub(n_images * nrows, figwidth, ncols=n_images, fig=fig, transpose=transpose)

    for i, (img_a, img_b) in enumerate(zip(batch_a, batch_b)):

        label_a = f'(A) {labels[0]}'
        plots.image(img_a, label_a, axes=axes[i + n_images * 0])

        # With a single row there is no axes slot reserved for batch B
        if nrows > 1:
            psnr = metrics.psnr(img_a, img_b)
            ssim = metrics.ssim(img_a, img_b)
            label_b = f'(B) {labels[1]}: {psnr:.1f} dB / {ssim:.3f}'
            plots.image(img_b, label_b, axes=axes[i + n_images * 1])

        # Signed difference taken in the same direction as its 'A - B' label
        diff_ab = img_a - img_b
        diff_abs = np.abs(diff_ab)
        # Both labels report "mean abs": use the mean of |A - B| (a signed mean would
        # let positive and negative errors cancel out)
        diff_mean_abs = diff_abs.mean()

        if nrows > 2:
            plots.image(image.normalize(diff_ab), f'A - B: mean abs {diff_mean_abs:.3f}', axes=axes[i + n_images * 2])

        if nrows > 3:
            plots.image(image.normalize(diff_abs, 0.1), f'|A - B|: mean abs {diff_mean_abs:.3f}', axes=axes[i + n_images * 3])

    return fig
Example #2
0
def validate_dcn(dcn, data, save_dir=False, epoch=0, show_ref=False):
    """
    Computes validation metrics for a compression model (DCN). If `dcn` is not a DCN, returns None immediately.
    If requested, plots compressed images to a JPEG file.

    :param dcn: the DCN model
    :param data: the dataset (instance of Dataset)
    :param save_dir: path to the directory where figures should be generated; any falsy value
                     (False - the default -, None, '') disables plotting
    :param epoch: epoch counter to be appended to the output filename
    :param show_ref: whether to show only the compressed image or also the input image as reference
    :return: dict with keys 'ssim' and 'psnr' (means over the validation batch), 'loss' and 'entropy'
             (as reported by the model, converted to float); None if `dcn` is not a DCN
    """

    if not isinstance(dcn, DCN):
        return None

    # Compute latent representations and compressed output
    batch_x = data.next_validation_batch(0, data.count_validation)
    if isinstance(batch_x, tuple):
        # If the dataset returns a tuple of batches, its last element is used as the DCN input
        batch_x = batch_x[-1]
    batch_y, entropy = dcn.process(batch_x, return_entropy=True)
    entropy = float(entropy.numpy())
    # Convert once; reused for the metrics and for plotting
    batch_y_np = batch_y.numpy()

    ssim = metrics.ssim(batch_x, batch_y_np).tolist()
    psnr = metrics.psnr(batch_x, batch_y_np).tolist()

    loss = float(dcn.loss(batch_x, batch_y, entropy).numpy())

    # If requested, plot a figure with input/output pairs. A truthiness test is required because
    # the default value (False) must disable plotting.
    if save_dir:
        images_x = int(min(data.count_validation, 5 if show_ref else 10))
        # add_subplot expects integer grid dimensions
        images_y = int(np.ceil(data.count_validation / images_x))
        fig = Figure(figsize=(20, 20 / images_x * images_y * (0.5 if show_ref else 1)))

        for b in range(data.count_validation):
            ax = fig.add_subplot(images_y, images_x, b + 1)
            thumb = np.concatenate((batch_x[b], batch_y_np[b]), axis=1) if show_ref else batch_y_np[b]
            plots.image(thumb, '{:.1f} / {:.2f}'.format(psnr[b], ssim[b]), axes=ax)

        os.makedirs(save_dir, exist_ok=True)

        # JPEG quality is forwarded to Pillow; the plain `quality=` keyword of savefig was
        # deprecated in Matplotlib 3.3 and later removed
        fig.savefig('{}/dcn_validation_{:05d}.jpg'.format(save_dir, epoch),
                    bbox_inches='tight',
                    dpi=100,
                    pil_kwargs={'quality': 90})
        del fig

    return {
        'ssim': float(np.mean(ssim)),
        'psnr': float(np.mean(psnr)),
        'loss': loss,
        'entropy': entropy
    }
Example #3
0
def validate(model,
             data,
             out_directory,
             savefig=False,
             epoch=0,
             show_ref=False,
             loss_metric='L2'):
    """
    Develop every validation example with `model` and collect per-image quality measures.

    :param model: the model; `model.process(x).numpy()` is used to develop each example
    :param data: the dataset; `count_validation`, `next_validation_batch(b, 1)`, `data['validation']['y']`
                 and `files['validation']` are used
    :param out_directory: directory for the validation figure (used only if `savefig`)
    :param savefig: whether to save a thumbnail figure of the developed images
    :param epoch: epoch counter appended to the figure filename
    :param show_ref: whether to show the reference next to the developed image
    :param loss_metric: 'L2' (`metrics.mse` on images scaled by 255), 'L1' (`metrics.mae` on images
                        scaled by 255) or 'SSIM' (255 * (1 - SSIM))
    :return: tuple (ssims, psnrs, losses, developed_out): three per-image lists and a float32 array
             holding the developed (clipped to [0, 1]) outputs
    :raises ValueError: for an unsupported `loss_metric`
    """
    # 'MS-SSIM' is not implemented below, so it is rejected here instead of failing
    # midway through validation
    if loss_metric not in ('L2', 'L1', 'SSIM'):
        raise ValueError('Unsupported loss ({})!'.format(loss_metric))

    ssims, psnrs, losss = [], [], []

    if savefig:
        images_x = int(min(data.count_validation, 10 if show_ref else 20))
        # add_subplot expects integer grid dimensions
        images_y = int(np.ceil(data.count_validation / images_x))
        fig = Figure(figsize=(40, 1.1 * 40 / images_x * images_y * (0.5 if show_ref else 1)))
        # Validation examples per source file; maps an example index to its file name for the title
        examples_per_file = data.count_validation / len(data.files['validation'])

    developed_out = np.zeros_like(data['validation']['y'], dtype=np.float32)

    for b in range(data.count_validation):

        # Fetch the next example and develop the RGB image
        example_x, example_y = data.next_validation_batch(b, 1)
        developed = model.process(example_x).numpy().clip(0, 1)
        developed_out[b, :, :, :] = developed
        developed = developed.squeeze()
        reference = example_y.squeeze()

        # Compute loss & quality metrics
        ssim = float(metrics.ssim(reference, developed))
        psnr = float(metrics.psnr(reference, developed))

        if loss_metric == 'L2':
            loss = metrics.mse(255 * reference, 255 * developed)
        elif loss_metric == 'L1':
            loss = metrics.mae(255 * reference, 255 * developed)
        else:  # 'SSIM' (validated above) - reuse the already computed value
            loss = 255 * (1 - ssim)

        ssims.append(ssim)
        psnrs.append(psnr)
        losss.append(loss)

        if savefig:
            ax = fig.add_subplot(images_y, images_x, b + 1)
            if show_ref:
                ax.imshow(np.concatenate((reference, developed), axis=1))
            else:
                ax.imshow(developed)
            ax.set_xticks([])
            ax.set_yticks([])
            label_index = int(b // examples_per_file)
            ax.set_title('{} : {:.1f} dB / {:.2f}'.format(
                data.files['validation'][label_index], psnr, ssim),
                         fontsize=6)

    if savefig:
        os.makedirs(out_directory, exist_ok=True)
        fig.savefig(os.path.join(out_directory,
                                 'validation_{:05d}.jpg'.format(epoch)),
                    bbox_inches='tight',
                    dpi=150)
        del fig

    return ssims, psnrs, losss, developed_out
Example #4
0
def match_jpeg(model, batch_x, axes=None, match='ssim'):
    """
    Compare a DCN codec with JPEG at a matched operating point, print a report and plot the results.

    The input is compressed with the DCN (`codec.simulate_compression`); then `jpeg_helpers.match_quality`
    searches a JPEG quality factor matching the DCN's SSIM (``match='ssim'``) or, otherwise, its bits per
    pixel. If the search fails, a quality of 95 or 10 is guessed. A 2x3 grid shows the original, DCN and
    JPEG images (full and centre crop).

    :param model: the DCN model
    :param batch_x: input batch; shape[1], shape[2] and shape[1:-1] are treated as spatial dimensions
                    (presumably (1, H, W, C) - only batch_x[0] is compressed with JPEG; confirm with callers)
    :param axes: optional sequence of at least 6 axes; a new figure is created if None
    :param match: 'ssim' to match quality, anything else to match bpp
    :return: the figure containing the plots
    """
    # Pixels per image; denominator of every bpp figure
    n_pixels = np.prod(batch_x.shape[1:-1])

    # Compress using DCN and get number of bytes
    batch_y, bytes_dcn = codec.simulate_compression(batch_x, model)

    ssim_dcn = metrics.ssim(batch_x.squeeze(), batch_y.squeeze()).mean()
    bpp_dcn = 8 * bytes_dcn / n_pixels
    target = ssim_dcn if match == 'ssim' else bpp_dcn

    try:
        jpeg_quality = jpeg_helpers.match_quality(batch_x.squeeze(),
                                                  target,
                                                  match=match)
    except Exception:
        # Best-effort fallback: guess a high or a low quality factor. Catching Exception (not a bare
        # except) still lets KeyboardInterrupt / SystemExit propagate.
        if match == 'ssim':
            jpeg_quality = 95 if ssim_dcn > 0.8 else 10
        else:
            jpeg_quality = 95 if bpp_dcn > 3 else 10
        print(
            'WARNING Could not find a matching JPEG quality factor - guessing {}'
            .format(jpeg_quality))

    # Compress using JPEG
    batch_j, bytes_jpeg = jpeg_helpers.compress_batch(batch_x[0],
                                                      jpeg_quality,
                                                      effective=True)
    ssim_jpeg = metrics.ssim(batch_x.squeeze(), batch_j.squeeze()).mean()
    bpp_jpg = 8 * bytes_jpeg / n_pixels

    # Get stats: entropy of the latent histogram (counts clipped to 1 so log2 never sees 0)
    code_book = model.get_codebook()
    batch_z = model.compress(batch_x).numpy()
    counts = helpers.stats.hist(batch_z, code_book)
    counts = counts.clip(min=1)
    probs = counts / counts.sum()
    entropy = -np.sum(probs * np.log2(probs))

    n_latent = np.prod(batch_z.shape)
    bits_per_symbol = np.log2(len(code_book))

    # Print report
    print('DCN             : {}'.format(model.model_code))
    print('Pixels          : {}x{} = {:,} px'.format(
        batch_x.shape[1], batch_x.shape[2], n_pixels))
    print('Bitmap          : {:,} bytes'.format(np.prod(batch_x.shape)))
    print('Code-book size  : {} elements from {} to {}'.format(
        len(code_book), min(code_book), max(code_book)))
    print('Entropy         : {:.2f} bits per symbol'.format(entropy))
    print('Latent size     : {:,}'.format(n_latent))
    print('PPF Naive       : {:,.0f} --> {:,.0f} bytes [{} bits per element]'.
          format(
              n_latent * bits_per_symbol / 8,
              n_latent * np.ceil(bits_per_symbol) / 8,
              np.ceil(bits_per_symbol)))
    print('PPF Theoretical : {:,.0f} bytes ({:.2f} bpp)'.format(
        n_latent * entropy / 8,
        n_latent * entropy / n_pixels))
    print('FSE Coded       : {:,} bytes ({:.2f} bpp) --> ssim: {:.3f}'.format(
        bytes_dcn, bpp_dcn, ssim_dcn))
    print(
        'JPEG (Q={:2d})     : {:,} bytes ({:0.2f} bpp) --> ssim: {:.3f} // effective size disregarding JPEG headers'
        .format(jpeg_quality, bytes_jpeg, bpp_jpg, ssim_jpeg))

    # Plot results
    if axes is None:
        fig, axes = plots.sub(6, ncols=3)
        fig.set_size_inches(12, 10)
        fig.tight_layout()
    else:
        fig = axes[0].figure

    # Plot full-resolution
    plots.image(batch_x,
                'Original ({0}x{0})'.format(batch_x.shape[1]),
                axes=axes[0])
    plots.image(batch_y,
                'DCN ssim:{:.2f} bpp:{:.2f}'.format(ssim_dcn, bpp_dcn),
                axes=axes[1])
    plots.image(batch_j,
                'JPEG {} ssim:{:.2f} bpp:{:.2f}'.format(
                    jpeg_quality, ssim_jpeg, bpp_jpg),
                axes=axes[2])

    # Plot zoom: centre crop of at least 64 px (a quarter of the height for larger images)
    crop_size = max([64, batch_x.shape[1] // 4])
    plots.image(helpers.image.crop_middle(batch_x, crop_size),
                'Original crop ({0}x{0})'.format(crop_size),
                axes=axes[3])
    plots.image(helpers.image.crop_middle(batch_y, crop_size),
                'DCN crop ({0}x{0})'.format(crop_size),
                axes=axes[4])
    plots.image(helpers.image.crop_middle(batch_j, crop_size),
                'JPEG crop ({0}x{0})'.format(crop_size),
                axes=axes[5])

    return fig
Example #5
0
def show_example(model, batch_x):
    """
    Plot latent-value histograms and sample reconstructions for a compression model.

    Top axes: empirical histogram of the latent representation (0.1-wide bins) vs. the histogram
    returned by ``helpers.stats.hist`` over the codebook, each with its entropy in the legend.
    Bottom axes: thumbnails of the inputs and their reconstructions, ordered by decreasing
    per-image variance, titled with the mean SSIM.

    :param model: model providing compress / decompress / get_codebook
    :param batch_x: batch of images; variance is taken over axes (1, 2, 3), so a 4-D array is expected
    :return: the figure
    """
    # Compress and decompress model
    batch_z = model.compress(batch_x).numpy()
    batch_y = model.decompress(batch_z).numpy()

    # Get empirical histogram of the latent representation
    codebook = model.get_codebook()

    # min/max so the histogram range does not rely on the codebook being sorted
    qmin = np.floor(np.min(codebook))
    qmax = np.ceil(np.max(codebook))

    bin_centers = np.arange(qmin - 1, qmax + 1, 0.1)
    # Boundaries lie half-way between consecutive centres; the outermost centres are dropped
    bin_boundaries = np.convolve(bin_centers, [0.5, 0.5], mode='valid')
    bin_centers = bin_centers[1:-1]

    hist_emp = np.histogram(batch_z.reshape((-1, )),
                            bins=bin_boundaries,
                            density=True)[0]
    hist_emp = np.maximum(hist_emp, 1e-9)
    hist_emp = hist_emp / hist_emp.sum()

    # Get TF histogram estimate based on soft quantization
    hist = np.asarray(helpers.stats.hist(batch_z, codebook))
    hist = hist / hist.sum()

    # Entropy; empty bins contribute 0 (0 * log2(0) := 0) instead of turning the sum into NaN
    nonzero = hist[hist > 0]
    entropy = -np.sum(nonzero * np.log2(nonzero))
    entropy_emp = -np.sum(hist_emp * np.log2(hist_emp))

    fig, axes = plots.sub(2, ncols=1)
    fig.set_size_inches(12, 10)

    axes[0].plot(bin_centers, hist_emp / hist_emp.max(), 'r-')
    axes[0].plot(codebook, hist / hist.max(), '-bo')

    axes[0].legend([
        'Empirical H={:.2f}'.format(entropy_emp),
        'TF estimate (soft) H={:.2f}'.format(entropy)
    ])
    axes[0].set_ylabel('normalized frequency')
    axes[0].set_xlabel('latent values')

    # Thumbnails: inputs (first row) and reconstructions (second row), most textured first
    indices = np.argsort(np.var(batch_x, axis=(1, 2, 3)))[::-1]
    thumbs_pairs_few = np.concatenate((batch_x[indices], batch_y[indices]),
                                      axis=0)
    thumbs_few = (
        255 * plots.thumbnails(thumbs_pairs_few, ncols=len(batch_x))).astype(
            np.uint8)

    ssim_values = [
        metrics.ssim(batch_x[i], batch_y[i]).mean()
        for i in range(len(batch_x))
    ]

    plots.image(thumbs_few,
                'Sample reconstructions, ssim={:.3f}'.format(
                    np.mean(ssim_values)),
                axes=axes[1])

    fig.tight_layout()
    return fig
Example #6
0
def compare_images_ab_ref(img_ref, img_a, img_b, labels=None, extras=False, fig=None):
    """
    Compare two images A and B against a target/reference image T.

    Without `extras`, three axes show T, A and B (A and B labelled with PSNR / SSIM w.r.t. T).
    With `extras`, a 3x3 grid additionally shows normalized absolute differences |T - A|, |T - B|,
    |A - B|, ``fft_log_norm`` spectra of the two normalized difference images and the absolute
    difference of the spectra of B and A. The axes of A and B are linked to T so they pan/zoom together.

    :param img_ref: target image T
    :param img_a: first candidate image A
    :param img_b: second candidate image B
    :param labels: optional three-element list [label_T, label_A, label_B]
    :param extras: whether to plot difference images and spectra
    :param fig: optional existing figure, forwarded to ``plots.sub``
    :return: the figure
    """

    def _link_axes(ax, ref):
        # Make `ax` pan/zoom together with `ref`. Axes.sharex/sharey (Matplotlib >= 3.3) are the public
        # API; joining the Grouper returned by get_shared_x_axes() was deprecated in 3.6 and later removed.
        if hasattr(ax, 'sharex'):
            for share in (ax.sharex, ax.sharey):
                try:
                    share(ref)
                except ValueError:
                    # The axis is already shared with some other axes; leave it as is
                    pass
        else:
            ax.get_shared_x_axes().join(ref, ax)
            ax.get_shared_y_axes().join(ref, ax)

    labels = labels or ['target', '', '']

    img_a = img_a.squeeze()
    img_b = img_b.squeeze()
    img_ref = img_ref.squeeze()

    fig, axes = plots.sub(9 if extras else 3, ncols=3, fig=fig)
    # Index of the axes used for image B (first cell of the second row in the 3x3 layout)
    j = 3 if extras else 2

    plots.image(img_ref, '(T) {}'.format(labels[0]), axes=axes[0])

    label_a = '(A) {}: {:.1f} dB / {:.3f}'.format(labels[1], metrics.psnr(img_ref, img_a), metrics.ssim(img_ref, img_a))
    plots.image(img_a, label_a, axes=axes[1])

    label_b = '(B) {}: {:.1f} dB / {:.3f}'.format(labels[2], metrics.psnr(img_ref, img_b), metrics.ssim(img_ref, img_b))
    plots.image(img_b, label_b, axes=axes[j])

    # Allow image axes to zoom together
    _link_axes(axes[1], axes[0])
    _link_axes(axes[j], axes[0])

    if not extras:
        return fig

    # Compute and plot difference images
    diff_a = np.abs(img_a - img_ref)
    diff_a_mean = diff_a.mean()
    diff_a = image.normalize(diff_a, 0.1)

    diff_b = np.abs(img_b - img_ref)
    diff_b_mean = diff_b.mean()
    diff_b = image.normalize(diff_b, 0.1)

    diff_ab = np.abs(img_b - img_a)
    diff_ab_mean = diff_ab.mean()
    diff_ab = image.normalize(diff_ab, 0.1)

    plots.image(diff_a, 'T - A: mean abs {:.3f}'.format(diff_a_mean), axes=axes[2])
    plots.image(diff_b, 'T - B: mean abs {:.3f}'.format(diff_b_mean), axes=axes[6])
    plots.image(diff_ab, 'A - B: mean abs {:.3f}'.format(diff_ab_mean), axes=axes[4])

    # Compute and plot spectra (of the normalized difference images computed above)
    fft_a = fft_log_norm(diff_a)
    fft_b = fft_log_norm(diff_b)

    fft_ab = image.normalize(np.abs(fft_log_norm(img_b) - fft_log_norm(img_a)), 0.01)
    plots.image(fft_a, 'FFT(T - A)', axes=axes[5])
    plots.image(fft_b, 'FFT(T - B)', axes=axes[7])
    plots.image(fft_ab, 'FFT(A) - FFT(B)', axes=axes[8])

    return fig
Example #7
0
def develop_image(pipeline,
                  camera=None,
                  batch=None,
                  image=None,
                  patch_size=0,
                  patches=2,
                  root_dir='./data',
                  pipeline_args=None):
    """
    Display a patch developed by a neural imaging pipeline.

    Input comes either from a single RAW file (`image`) or from a batch of the camera's training data
    (`batch`); `image` takes precedence when both are given. The developed output is shown together
    with the target RGB image(s), with mean PSNR / SSIM in the title.

    :param pipeline: a directory with a training log to restore the model from, or the name of a class
                     in `models.pipelines`
    :param camera: camera name; required to locate model weights (when `pipeline` is a class name) and
                   to load a batch of training data
    :param batch: number of images to load from the camera's training data (used if `image` is None)
    :param image: path to a RAW image
    :param patch_size: if > 0, size of the central crop to display (4 to 2048); 0 keeps the full output
    :param patches: forwarded as `val_n_patches` to the dataset (batch mode only)
    :param root_dir: data root; models are looked up in `models/nip/<camera>`, training data in
                     `raw/training_data/<camera>`
    :param pipeline_args: optional keyword arguments for the pipeline constructor
    :raises ValueError: for an unknown/missing camera, missing input or an invalid patch size
    :raises FileNotFoundError: if no saved model can be found for the pipeline
    """

    if camera is not None:
        supported_cameras = fsutil.listdir(
            os.path.join(root_dir, 'models', 'nip'), '.*')
        if camera not in supported_cameras:
            raise ValueError(
                'Camera data not found ({})! Available cameras: {}'.format(
                    camera, ', '.join(supported_cameras)))
        root_dirname = os.path.join(root_dir, 'models', 'nip', camera)
        data_dirname = os.path.join(root_dir, 'raw', 'training_data', camera)

    if patch_size != 0 and (patch_size < 4 or patch_size > 2048):
        raise ValueError('Patch size seems to be invalid!')

    # Fail early with a clear message instead of a NameError further down
    if image is None and batch is None:
        raise ValueError('Either a RAW image or a batch size must be specified!')

    # The camera is needed for the model directory (class-name pipelines) and for batch loading
    if camera is None and (image is None or not os.path.isdir(pipeline)):
        raise ValueError('A camera name is required to locate the model weights / training data!')

    # Lazy imports to minimize delay for invalid command line parameters
    import numpy as np
    import imageio as io
    import matplotlib.pyplot as plt
    import tensorflow as tf
    from models import pipelines

    # Construct the NIP model ---------------------------------------------------------------------

    if os.path.isdir(pipeline):
        # Restore a NIP model from a training log
        model = tfmodel.restore(pipeline, pipelines)
    else:
        # Construct the NIP model from class name (and optional arguments)
        if pipeline_args is None:
            model = getattr(pipelines, pipeline)()
        else:
            model = getattr(pipelines, pipeline)(**pipeline_args)

        loaded_model = False
        candidate_dirs = [
            os.path.join(root_dirname, model.model_code),
            os.path.join(root_dirname)
        ]
        for candidate in candidate_dirs:
            if os.path.isdir(candidate):
                model.load_model(candidate)
                loaded_model = True
                break

        if not loaded_model:
            raise FileNotFoundError(
                f'Could not find the corresponding model: {candidate_dirs}')

    # Load image(s) -------------------------------------------------------------------------------

    if image is None:
        # Batch mode (`batch` is guaranteed to be set by the check above)
        print('Loading a batch of {} images'.format(batch))
        data = dataset.Dataset(data_dirname,
                               n_images=0,
                               v_images=batch,
                               val_rgb_patch_size=patch_size or 256,
                               val_n_patches=patches)
        sample_x, sample_y = data.next_validation_batch(
            0, data.count_validation)

        with open('config/cameras.json') as f:
            cameras = json.load(f)
            cfa, srgb = cameras[camera]['cfa'], np.array(
                cameras[camera]['srgb'])

    else:
        print('Loading a RAW image {}'.format(image))
        sample_x, cfa, srgb, _ = raw.unpack(image, expand=True)
        sample_y = raw.process(image, brightness=None, expand=True)

    if isinstance(model, pipelines.ClassicISP):
        print('Configuring ISP to CFA: {} & sRGB {}'.format(
            cfa,
            srgb.round(2).tolist()))
        model.set_cfa_pattern(cfa)
        model.set_srgb_conversion(srgb)

    sample_Y = model.process(sample_x).numpy()

    if patch_size > 0:
        # Central crop of both the target and the developed image
        xx = (sample_y.shape[2] - patch_size) // 2
        yy = (sample_y.shape[1] - patch_size) // 2
        sample_y = sample_y[:, yy:yy + patch_size, xx:xx + patch_size, :]
        sample_Y = sample_Y[:, yy:yy + patch_size, xx:xx + patch_size, :]

    psnrs = metrics.psnr(sample_y, sample_Y)
    ssims = metrics.ssim(sample_y, sample_Y)

    print('sample x: {}'.format(sample_x.shape))
    print('sample y: {}'.format(sample_y.shape))
    print('sample Y: {}'.format(sample_Y.shape))

    # Plot images ---------------------------------------------------------------------------------
    if len(sample_y) > 1:
        sample_y = plots.thumbnails(sample_y, batch, True)
        sample_Y = plots.thumbnails(sample_Y, batch, True)
    else:
        sample_y = sample_y.squeeze()
        sample_Y = sample_Y.squeeze()

    print('thumbnails: {}'.format(sample_y.shape))

    # Wide images are stacked vertically, tall ones side by side
    ncols = 1 if sample_y.shape[1] > sample_y.shape[0] else 2
    nrows = 2 if ncols == 1 else 1
    _, axes = plt.subplots(nrows, ncols)

    plots.image(sample_Y,
                '{}, PSNR={:.1f} dB, SSIM={:.2f} : {{}}'.format(
                    model.model_code, float(psnrs.mean()),
                    float(ssims.mean())),
                axes=axes[0])
    plots.image(sample_y, 'Target RGB images () : {}', axes=axes[1])

    plt.show()
    plt.close()
Example #8
0
 def get_ssim(q):
     """
     Return the SSIM between `image` and its version compressed at quality `q`, minus `target`.

     Closure: `image`, `subsampling`, `target` and `compress_batch` are taken from the enclosing
     scope, which is not visible here. The zero crossing presumably marks the quality factor that
     matches the target SSIM (e.g. for a root finder) - confirm with the enclosing function.
     """
     compressed = compress_batch(image, q, subsampling=subsampling)[0]
     return metrics.ssim(image, compressed.squeeze()) - target
Example #9
0
def validate_nip(model,
                 data,
                 save_dir=False,
                 epoch=0,
                 show_ref=False,
                 loss_type='L2'):
    """
    Develops image patches using the given NIP and returns standard image quality measures.
    If requested, resulting patches are visualized as thumbnails and saved to a directory.

    :param model: the NIP model
    :param data: the dataset (instance of Dataset)
    :param save_dir: path to the directory where figures should be generated; any falsy value
                     (False - the default -, None, '') disables plotting
    :param epoch: epoch counter to be appended to the output filename
    :param show_ref: whether to show only the developed image or also the GT target
    :param loss_type: 'L1' (mean absolute error) or 'L2' (mean squared error), on the [0, 1] scale
    :return: tuple of lists with per-image measurements of (ssims, psnrs, losss)
    :raises ValueError: for an unsupported `loss_type`
    """
    # Validate up front instead of failing on the first example
    if loss_type not in ('L1', 'L2'):
        raise ValueError('Invalid loss! Use either L1 or L2.')

    ssims = []
    psnrs = []
    losss = []

    # If requested, plot a figure with output/target pairs. A truthiness test is required because
    # the default value (False) must disable plotting.
    if save_dir:
        images_x = int(min(data.count_validation, 5 if show_ref else 10))
        # add_subplot expects integer grid dimensions
        images_y = int(np.ceil(data.count_validation / images_x))
        fig = Figure(figsize=(20, 20 / images_x * images_y * (0.5 if show_ref else 1)))

    for b in range(data.count_validation):
        example_x, example_y = data.next_validation_batch(b, 1)
        developed = model.process(example_x).numpy().clip(0, 1).squeeze()
        reference = example_y.squeeze()

        # Compute stats
        ssim = metrics.ssim(reference, developed).mean()
        psnr = metrics.psnr(reference, developed).mean()

        if loss_type == 'L2':
            loss = np.mean(np.power(reference - developed, 2.0))
        else:  # 'L1' (validated above)
            loss = np.mean(np.abs(reference - developed))

        ssims.append(ssim)
        psnrs.append(psnr)
        losss.append(loss)

        # Add images to the plot
        if save_dir:
            ax = fig.add_subplot(images_y, images_x, b + 1)
            thumb = np.concatenate((reference, developed), axis=1) if show_ref else developed
            plots.image(thumb, '{:.1f} dB / {:.2f}'.format(psnr, ssim), axes=ax)

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        # JPEG quality is forwarded to Pillow; the plain `quality=` keyword of savefig was
        # deprecated in Matplotlib 3.3 and later removed
        fig.savefig('{}/nip_validation_{:05d}.jpg'.format(save_dir, epoch),
                    bbox_inches='tight',
                    dpi=100,
                    pil_kwargs={'quality': 90})
        del fig

    return ssims, psnrs, losss