Ejemplo n.º 1
0
def validate_dcn(dcn, data, save_dir=False, epoch=0, show_ref=False):
    """
    Computes validation metrics for a compression model (DCN). (If not a DCN, the function returns immediately).
    If requested, plot compressed images to a JPEG image.

    :param dcn: the DCN model
    :param data: the dataset (instance of IPDataset)
    :param save_dir: path to the directory where figures should be generated; a falsy value (the default
                     False, None or '') disables plotting
    :param epoch: epoch counter to be appended to the output filename
    :param show_ref: whether to show only the compressed image or also the input image as reference
    :return: tuple of lists with per-image measurements of (ssims, psnrs, losses, entropies), or None
             when `dcn` is not a DCN
    """

    if not isinstance(dcn, DCN):
        return

    n_images = data.count_validation

    # Compute latent representations and compressed output
    batch_x = data.next_validation_batch(0, n_images)
    if isinstance(batch_x, tuple):
        # The dataset may return a tuple of arrays; the last one is the image batch to compress
        batch_x = batch_x[-1]
    batch_z = dcn.compress(batch_x, direct=True)
    batch_y = dcn.process(batch_x, direct=True)

    # Compute quality measures and entropy statistics
    codebook = dcn.get_codebook()

    ssims = [compare_ssim(batch_x[b], batch_y[b], multichannel=True, data_range=1) for b in range(n_images)]
    psnrs = [compare_psnr(batch_x[b], batch_y[b], data_range=1) for b in range(n_images)]
    # The loss is evaluated on 1-element batches to obtain a per-image value
    losses = [dcn.sess.run(dcn.loss, feed_dict={dcn.x: batch_x[b:b + 1]}) for b in range(n_images)]
    entropies = [utils.entropy(batch_z[b], codebook) for b in range(n_images)]

    # If requested, plot a figure with input/output pairs. A truthiness test (rather than `is not None`)
    # makes the default save_dir=False actually disable plotting.
    if save_dir and n_images > 0:
        # Subplot grid dimensions must be ints (np.ceil returns a float)
        images_x = int(min(n_images, 5 if show_ref else 10))
        images_y = int(np.ceil(n_images / images_x))
        fig = plt.figure(figsize=(20, 20 / images_x * images_y * (0.5 if show_ref else 1)))

        try:
            for b in range(n_images):
                ax = fig.add_subplot(images_y, images_x, b + 1)
                plotting.quickshow(
                    np.concatenate((batch_x[b], batch_y[b]), axis=1) if show_ref else batch_y[b],
                    '{:.1f} / {:.2f}'.format(psnrs[b], ssims[b]),
                    axes=ax
                )

            os.makedirs(save_dir, exist_ok=True)
            fig.savefig('{}/dcn_validation_{:05d}.jpg'.format(save_dir, epoch), bbox_inches='tight', dpi=100, quality=90)
        finally:
            # Always release the figure, even if plotting or saving fails
            plt.close(fig)

    return ssims, psnrs, losses, entropies
Ejemplo n.º 2
0
def test_output(image, jpeg_quality=50, rounding_approximation=None):
    """
    Visually compares the differentiable JPEG model (DJPG) with a real JPEG codec round-trip.

    For every image, three panels are shown side by side: the input, the dJPEG model output and the
    image obtained by writing it to a JPEG file and reading it back, all at the same quality factor.

    :param image: a single image; presumably an HxWxC array in the 0-255 range (it is divided by 255
                  before the model and cast to uint8 before JPEG encoding) - TODO confirm
    :param jpeg_quality: JPEG quality factor used for both the model and the codec
    :param rounding_approximation: rounding approximation mode passed to the DJPG model
    """
    import os
    import tempfile

    jpg = DJPG(rounding_approximation=rounding_approximation, rounding_approximation_steps=5)
    print(jpg)

    batch_x = np.expand_dims(image, 0)
    batch_y = jpg.process(batch_x / 255, jpeg_quality)

    n_images = batch_x.shape[0]

    # Round-trip every image through the real JPEG codec. A private temporary directory is used instead
    # of hard-coded /tmp paths: it works on every platform and no stray files are left behind.
    batch_j = np.zeros_like(batch_x)
    with tempfile.TemporaryDirectory() as tmp_dir:
        for n in range(n_images):
            filename = os.path.join(tmp_dir, 'patch_{}.jpg'.format(n))
            io.imwrite(filename, (batch_x[n].squeeze()).astype(np.uint8), quality=jpeg_quality)
            batch_j[n] = io.imread(filename)

    # Each batch is normalized by its own peak magnitude for display
    max_x = np.max(np.abs(batch_x))
    max_y = np.max(np.abs(batch_y))
    max_j = np.max(np.abs(batch_j))

    for n in range(n_images):
        plt.subplot(n_images, 3, 1 + n * 3)
        plotting.quickshow(batch_x[n].squeeze() / max_x, 'Input')

        plt.subplot(n_images, 3, 2 + n * 3)
        plotting.quickshow(batch_y[n].squeeze() / max_y, 'dJPEG Model')

        plt.subplot(n_images, 3, 3 + n * 3)
        plotting.quickshow(batch_j[n].squeeze() / max_j, 'libJPG Codec')

    plt.show()
Ejemplo n.º 3
0
def match_jpeg(model, batch_x, axes=None, match='ssim'):
    """
    Compares a DCN codec with standard JPEG at a matched operating point.

    The image is compressed with the DCN. A JPEG quality factor is then chosen so that JPEG matches
    either the DCN's SSIM (match='ssim') or its bit-rate in bits per pixel (any other value). If no
    matching factor can be found, a high (95) or low (10) quality is guessed and a warning is printed.
    A textual report is printed and a six-panel figure is drawn: full-resolution images and central
    crops of the original, the DCN output and the JPEG output.

    :param model: the DCN model
    :param batch_x: image batch; assumed to hold a single image of shape (1, H, W, C) with values in
                    [0, 1] (SSIM uses data_range=1 and only batch_x[0] is JPEG-compressed) - TODO confirm
    :param axes: optional sequence of at least 6 matplotlib axes to draw into; a new figure is created if None
    :param match: 'ssim' to match image quality, anything else to match bit-rate
    :return: the matplotlib figure
    """

    n_pixels = np.prod(batch_x.shape[1:-1])

    # Compress using DCN and get number of bytes
    batch_y, bytes_dcn = codec.simulate_compression(batch_x, model)

    ssim_dcn = compare_ssim(batch_x.squeeze(),
                            batch_y.squeeze(),
                            multichannel=True,
                            data_range=1)
    bpp_dcn = 8 * bytes_dcn / n_pixels
    target = ssim_dcn if match == 'ssim' else bpp_dcn

    try:
        jpeg_quality = jpeg_helpers.match_quality(batch_x.squeeze(),
                                                  target,
                                                  match=match)
    except Exception:
        # Best-effort fallback: guess a high or low quality factor. `Exception` (not a bare except)
        # keeps KeyboardInterrupt / SystemExit propagating.
        if match == 'ssim':
            jpeg_quality = 95 if ssim_dcn > 0.8 else 10
        else:
            jpeg_quality = 95 if bpp_dcn > 3 else 10
        print(
            'WARNING Could not find a matching JPEG quality factor - guessing {}'
            .format(jpeg_quality))

    # Compress using JPEG
    batch_j, bytes_jpeg = jpeg_helpers.compress_batch(batch_x[0],
                                                      jpeg_quality,
                                                      effective=True)
    ssim_jpeg = compare_ssim(batch_x.squeeze(),
                             batch_j.squeeze(),
                             multichannel=True,
                             data_range=1)
    bpp_jpg = 8 * bytes_jpeg / n_pixels

    # Get stats: empirical entropy (bits per symbol) of the quantized latent representation;
    # counts are clipped at 1 so that log2 is always defined
    code_book = model.get_codebook()
    batch_z = model.compress(batch_x)
    counts = utils.qhist(batch_z, code_book)
    counts = counts.clip(min=1)
    probs = counts / counts.sum()
    entropy = -np.sum(probs * np.log2(probs))

    n_latent = np.prod(batch_z.shape)
    bits_per_symbol = np.log2(len(code_book))

    # Print report
    print('DCN             : {}'.format(model.model_code))
    print('Pixels          : {}x{} = {:,} px'.format(
        batch_x.shape[1], batch_x.shape[2], n_pixels))
    print('Bitmap          : {:,} bytes'.format(np.prod(batch_x.shape)))
    print('Code-book size  : {} elements from {} to {}'.format(
        len(code_book), min(code_book), max(code_book)))
    print('Entropy         : {:.2f} bits per symbol'.format(entropy))
    print('Latent size     : {:,}'.format(n_latent))
    print('PPF Naive       : {:,.0f} --> {:,.0f} bytes [{} bits per element]'.
          format(
              n_latent * bits_per_symbol / 8,
              n_latent * np.ceil(bits_per_symbol) / 8,
              np.ceil(bits_per_symbol)))
    print('PPF Theoretical : {:,.0f} bytes ({:.2f} bpp)'.format(
        n_latent * entropy / 8,
        n_latent * entropy / n_pixels))
    print('FSE Coded       : {:,} bytes ({:.2f} bpp) --> ssim: {:.3f}'.format(
        bytes_dcn, bpp_dcn, ssim_dcn))
    print(
        'JPEG (Q={:2d})     : {:,} bytes ({:0.2f} bpp) --> ssim: {:.3f} // effective size disregarding JPEG headers'
        .format(jpeg_quality, bytes_jpeg, bpp_jpg, ssim_jpeg))

    # Plot results
    if axes is None:
        fig, axes = plotting.sub(6, ncols=3)
        fig.set_size_inches(12, 10)
        fig.tight_layout()
    else:
        fig = axes[0].figure

    # Plot full-resolution; the title reports both dimensions (same order as the "Pixels" line),
    # so non-square images are labeled correctly
    plotting.quickshow(batch_x,
                       'Original ({}x{})'.format(batch_x.shape[1], batch_x.shape[2]),
                       axes=axes[0])
    plotting.quickshow(batch_y,
                       'DCN ssim:{:.2f} bpp:{:.2f}'.format(ssim_dcn, bpp_dcn),
                       axes=axes[1])
    plotting.quickshow(batch_j,
                       'JPEG {} ssim:{:.2f} bpp:{:.2f}'.format(
                           jpeg_quality, ssim_jpeg, bpp_jpg),
                       axes=axes[2])

    # Plot zoom: a central crop of at least 64 px, or a quarter of the image height if larger
    crop_size = max(64, batch_x.shape[1] // 4)
    plotting.quickshow(utils.crop_middle(batch_x, crop_size),
                       'Original crop ({0}x{0})'.format(crop_size),
                       axes=axes[3])
    plotting.quickshow(utils.crop_middle(batch_y, crop_size),
                       'DCN crop ({0}x{0})'.format(crop_size),
                       axes=axes[4])
    plotting.quickshow(utils.crop_middle(batch_j, crop_size),
                       'JPEG crop ({0}x{0})'.format(crop_size),
                       axes=axes[5])

    return fig
Ejemplo n.º 4
0
def show_example(model, batch_x):
    """
    Plots the latent-value histogram and sample reconstructions of a compression model.

    The top panel compares the empirical histogram of the latent representation with the model's
    TF histogram estimate (based on soft quantization). Each curve is annotated with its entropy in
    bits. The bottom panel shows thumbnails of the inputs (sorted by decreasing variance) followed by
    their reconstructions, titled with the mean SSIM.

    :param model: the compression model (must provide compress, decompress, get_codebook and
                  get_tf_histogram)
    :param batch_x: batch of images; assumed to be (N, H, W, C) with values in [0, 1] (variance is
                    taken over axes 1-3 and thumbnails are scaled by 255) - TODO confirm
    :return: the matplotlib figure
    """

    def entropy_bits(probs):
        # Shannon entropy in bits; zero-probability bins contribute 0 (0 * log 0 := 0) instead of NaN
        probs = np.asarray(probs)
        probs = probs[probs > 0]
        return -np.sum(probs * np.log2(probs))

    # Compress and decompress model
    batch_z = model.compress(batch_x)
    batch_y = model.decompress(batch_z)

    # Get empirical histogram of the latent representation
    codebook = model.get_codebook()

    # Bin centers every 0.1 from one unit below to one unit above the codebook range. Boundaries are
    # midpoints between consecutive centers, so the two outermost centers do not get a bin.
    qmin = np.floor(np.min(codebook))
    qmax = np.ceil(np.max(codebook))

    bin_centers = np.arange(qmin - 1, qmax + 1, 0.1)
    bin_boundaries = np.convolve(bin_centers, [0.5, 0.5], mode='valid')
    bin_centers = bin_centers[1:-1]

    hist_emp = np.histogram(batch_z.reshape((-1, )),
                            bins=bin_boundaries,
                            density=True)[0]
    # Floor empty bins at a tiny value, then renormalize into a probability distribution
    hist_emp = np.maximum(hist_emp, 1e-9)
    hist_emp = hist_emp / hist_emp.sum()

    # Get TF histogram estimate based on soft quantization
    hist = model.get_tf_histogram(batch_x)

    # Entropy
    entropy = entropy_bits(hist)
    entropy_emp = entropy_bits(hist_emp)

    fig, axes = plotting.sub(2, ncols=1)
    fig.set_size_inches(12, 10)

    axes[0].plot(bin_centers, hist_emp / hist_emp.max(), 'r-')
    axes[0].plot(codebook, hist / hist.max(), '-bo')

    axes[0].legend([
        'Empirical H={:.2f}'.format(entropy_emp),
        'TF estimate (soft) H={:.2f}'.format(entropy)
    ])
    axes[0].set_ylabel('normalized frequency')
    axes[0].set_xlabel('latent values')

    # Thumbnails: inputs ordered by decreasing variance, followed by their reconstructions
    indices = np.argsort(np.var(batch_x, axis=(1, 2, 3)))[::-1]
    thumbs_pairs_few = np.concatenate((batch_x[indices], batch_y[indices]),
                                      axis=0)
    thumbs_few = (255 * plotting.thumbnails(
        thumbs_pairs_few, n_cols=len(batch_x))).astype(np.uint8)

    # data_range=1 matches the other SSIM computations; without it skimage infers a range of 2 for
    # float images, which inflates SSIM
    ssim_values = [
        compare_ssim(batch_x[i], batch_y[i], multichannel=True, data_range=1)
        for i in range(len(batch_x))
    ]

    plotting.quickshow(thumbs_few,
                       'Sample reconstructions, ssim={:.3f}'.format(
                           np.mean(ssim_values)),
                       axes=axes[1])

    fig.tight_layout()
    return fig
Ejemplo n.º 5
0
def compare_images_ab_ref(img_ref, img_a, img_b, labels=None):
    """
    Draws a 3x3 comparison of two candidate images (A, B) against a common target (T).

    Panels:
    - the target;
    - both candidates, titled with PSNR [dB] / SSIM with respect to the target;
    - the absolute differences T-A, T-B and A-B, titled with their mean absolute value;
    - the spectra of the T-A and T-B differences;
    - the absolute difference of the spectra of B and A.

    :param img_ref: target image (singleton dimensions are squeezed away)
    :param img_a: first candidate image
    :param img_b: second candidate image
    :param labels: optional list of 3 labels for (T, A, B); defaults to ['target', '', '']
    :return: the matplotlib figure
    """
    from helpers import plotting
    from skimage.measure import compare_psnr, compare_ssim

    if not labels:
        labels = ['target', '', '']

    first = img_a.squeeze()
    second = img_b.squeeze()
    reference = img_ref.squeeze()

    def quality_title(tag, name, candidate):
        # "(tag) name: PSNR dB / SSIM", both measured against the target
        psnr_value = compare_psnr(reference, candidate, data_range=1.0)
        ssim_value = compare_ssim(reference, candidate, data_range=1.0, multichannel=True)
        return '({}) {}: {:.1f} dB / {:.3f}'.format(tag, name, psnr_value, ssim_value)

    fig, axes = plotting.sub(9)

    plotting.quickshow(reference, '(T) {}'.format(labels[0]), axes=axes[0])
    plotting.quickshow(first, quality_title('A', labels[1], first), axes=axes[1])
    plotting.quickshow(second, quality_title('B', labels[2], second), axes=axes[3])

    # Absolute difference images: the mean is taken before nm() rescales them for display
    normalized_diffs = []
    diff_means = []
    for minuend, subtrahend in ((first, reference), (second, reference), (second, first)):
        delta = np.abs(minuend - subtrahend)
        diff_means.append(delta.mean())
        normalized_diffs.append(nm(delta, 0.1))
    diff_ta, diff_tb, diff_between = normalized_diffs

    diff_panels = (
        (diff_ta, 'T - A: mean abs {:.3f}'.format(diff_means[0]), axes[2]),
        (diff_tb, 'T - B: mean abs {:.3f}'.format(diff_means[1]), axes[6]),
        (diff_between, 'A - B: mean abs {:.3f}'.format(diff_means[2]), axes[4]),
    )
    for panel_image, panel_title, panel_axes in diff_panels:
        plotting.quickshow(panel_image, panel_title, axes=panel_axes)

    # Spectra of the (normalized) difference images, plus the difference of the two images' spectra
    spectrum_ta = fft_log_norm(diff_ta)
    spectrum_tb = fft_log_norm(diff_tb)
    spectrum_between = nm(np.abs(fft_log_norm(second) - fft_log_norm(first)), 0.01)

    spectrum_panels = (
        (spectrum_ta, 'FFT(T - A)', axes[5]),
        (spectrum_tb, 'FFT(T - B)', axes[7]),
        (spectrum_between, 'FFT(A) - FFT(B)', axes[8]),
    )
    for panel_image, panel_title, panel_axes in spectrum_panels:
        plotting.quickshow(panel_image, panel_title, axes=panel_axes)

    return fig
Ejemplo n.º 6
0
def validate_nip(model, data, save_dir=False, epoch=0, show_ref=False, loss_type='L2'):
    """
    Develops image patches using the given NIP and returns standard image quality measures.
    If requested, resulting patches are visualized as thumbnails and saved to a directory.

    :param model: the NIP model
    :param data: the dataset (instance of IPDataset)
    :param save_dir: path to the directory where figures should be generated; a falsy value (the default
                     False, None or '') disables plotting
    :param epoch: epoch counter to be appended to the output filename
    :param show_ref: whether to show only the developed image or also the GT target
    :param loss_type: L1 or L2
    :return: tuple of lists with per-image measurements of (ssims, psnrs, losss)
    :raises ValueError: if loss_type is neither 'L1' nor 'L2'
    """

    # Resolve the loss up-front so that an invalid value fails before any processing or plotting
    if loss_type == 'L2':
        def loss_fn(reference, developed):
            return np.mean(np.power(reference - developed, 2.0))
    elif loss_type == 'L1':
        def loss_fn(reference, developed):
            return np.mean(np.abs(reference - developed))
    else:
        raise ValueError('Invalid loss! Use either L1 or L2.')

    n_images = data.count_validation

    ssims = []
    psnrs = []
    losss = []

    # If requested, plot a figure with output/target pairs. A truthiness test (rather than `is not None`)
    # makes the default save_dir=False actually disable plotting.
    plot = bool(save_dir) and n_images > 0
    fig = None
    if plot:
        # Subplot grid dimensions must be ints (np.ceil returns a float)
        images_x = int(min(n_images, 5 if show_ref else 10))
        images_y = int(np.ceil(n_images / images_x))
        fig = plt.figure(figsize=(20, 20 / images_x * images_y * (0.5 if show_ref else 1)))

    try:
        for b in range(n_images):
            example_x, example_y = data.next_validation_batch(b, 1)
            developed = model.process(example_x).clip(0, 1).squeeze()
            reference = example_y.squeeze()

            # Compute stats
            ssim = compare_ssim(reference, developed, multichannel=True, data_range=1)
            psnr = compare_psnr(reference, developed, data_range=1)

            ssims.append(ssim)
            psnrs.append(psnr)
            losss.append(loss_fn(reference, developed))

            # Add images to the plot
            if plot:
                ax = fig.add_subplot(images_y, images_x, b + 1)
                plotting.quickshow(
                    np.concatenate((reference, developed), axis=1) if show_ref else developed,
                    '{:.1f} dB / {:.2f}'.format(psnr, ssim),
                    axes=ax
                )

        if plot:
            os.makedirs(save_dir, exist_ok=True)
            fig.savefig('{}/nip_validation_{:05d}.jpg'.format(save_dir, epoch), bbox_inches='tight', dpi=100, quality=90)
    finally:
        # Always release the figure, even if processing, plotting or saving fails
        if fig is not None:
            plt.close(fig)

    return ssims, psnrs, losss