Beispiel #1
0
def get_psnr(x: SymbolTensor, y: SymbolTensor):
    """
    Peak signal-to-noise ratio between two symbol tensors, averaged over the batch.

    Both inputs are cast to float *before* they are subtracted. Doing the subtraction and the
    squaring in an integer dtype can wrap around (e.g. for uint8, a difference of 16 squares
    to 256, which wraps to 0), which would silently corrupt the MSE.

    Notes: the original author reports this was tested to be the same as tf.image.psnr.

    :param x: symbols; `x.get()` must return a 4D tensor, since the MSE is reduced over dims (1, 2, 3)
    :param y: symbols with the same `L` as `x`
    :return: scalar tensor, PSNR in dB with peak value `L - 1`. A sample with zero MSE yields inf.
    """
    assert x.L == y.L
    max_val = x.L - 1
    # Cast first, subtract second: keeps integer inputs from over-/underflowing.
    diff = x.get().float() - y.get().float()
    # NOTE: that's how tf.image.psnr does the mean, too: MSE over spatial, PSNR over batch
    mse = diff.pow(2).mean((1, 2, 3))
    assert len(mse.shape) == 1, mse.shape
    return 10. * torch.log10((max_val**2) / mse).mean()
Beispiel #2
0
def new_bottleneck_summary(s: SymbolTensor):
    """
    Grayscale bottleneck representation: Expects the actual bottleneck symbols.

    The symbols are divided by `s.L`, split channel-wise via `vis.grid.prep_for_grid`, and the
    resulting per-channel images are tiled with `torchvision.utils.make_grid` (5 per row).

    :param s: NCHW
    :return: [0, 1] image
    """
    s_raw, L = s.get(), s.L
    assert s_raw.dim() == 4, s_raw.shape
    # Scale the symbols by the number of levels; detached so the summary holds no graph.
    s_raw = s_raw.detach().float().div(L)
    grid = vis.grid.prep_for_grid(s_raw, channelwise=True)
    # One grid entry is expected per channel.
    assert len(grid) == s_raw.shape[1], (len(grid), s_raw.shape)
    # `all(...)` is required here: asserting the bare list would pass for any non-empty grid.
    assert all(g.max() <= 1 for g in grid), [g.max() for g in grid]
    # Report the dtype of the checked element; `grid` is indexed like a sequence.
    assert grid[0].dtype == torch.float32, grid[0].dtype
    return torchvision.utils.make_grid(grid, nrow=5)
    def _encode(self, pin, pout) -> EncodeOut:
        """
        Encode the image at `pin`.

        Two files are written: a BPG file at `self._path_for_bpg(pout)`, and `pout` itself, which
        receives whatever `self.encode_rgb` emits followed by `_MAGIC_VALUE_SEP`.

        :param pin: path of the input image, opened with PIL
        :param pout: path of the output file; must not exist yet
        :return: EncodeOut(img, actual_bpsp, None), where img is int64 1CHW and actual_bpsp is the
        size of both output files in bits, divided by the number of elements of img
        :raises NotImplementedError: if `self.compare_with_theory` is set (raised only after both
        files have been written)
        """
        assert not os.path.isfile(pout)

        # int64 1CHW pe.DEVICE tensor
        img = self.pil_to_1CHW_long(Image.open(pin))
        assert len(img.shape) == 4 and tuple(img.shape[:2]) == (1, 3), img.shape

        blueprint = self.blueprint
        timer = self.times

        # Ground truth: symbols -> normalized
        ground_truth = SymbolTensor(img, L=256).to_norm()

        # Pick the BPG quality parameter, from the classifier if there is one.
        if blueprint.clf is None:
            bpg_q = 12  # TODO
        else:
            with timer.run('Q-Classifier'):
                bpg_q = blueprint.clf.get_q(ground_truth.get())

        with timer.run('BPG'):
            bpg_path = self._path_for_bpg(pout)
            bpg_bpp = self._encode_bpg(pin, bpg_path, bpg_q)
            # The decoded BPG image is the lossy input handed to the network below.
            lossy = self._decode_bpg(bpg_path)

        with timer.run('[-] encode forwardpass'):
            bpp_tensor = torch.tensor([bpg_bpp], device=pe.DEVICE)
            net_out = blueprint.forward_lossy(lossy, bpp_tensor)
            # Per the original author's note, norm -> sym conversion happens in here.
            enhancement = EnhancementOut(net_out, ground_truth, lossy)

        if self.compare_with_theory:
            with timer.run('[-] get loss'):
                # Only needed by the comparison report, which is not implemented yet.
                loss_out = blueprint.losses(
                    enhancement,
                    num_subpixels_before_pad=np.prod(img.shape),
                    base_bpp=bpg_bpp)

        rgb_loss = blueprint.losses.loss_dmol_rgb
        with open(pout, 'wb') as fout, timer.prefix_scope('RGB'):
            # Bytes used by different scales; only the RGB scale is encoded here.
            entropy_coding_bytes = [
                self.encode_rgb(rgb_loss, enhancement, fout)
            ]
            fout.write(_MAGIC_VALUE_SEP)

        # Actual rate: bits of both files on disk per element of img.
        subpixel_count = np.prod(img.shape)
        total_bytes = os.path.getsize(pout) + os.path.getsize(bpg_path)
        actual_bpsp = total_bytes * 8 / subpixel_count

        if self.compare_with_theory:
            # TODO: build a report comparing the theoretical bitrates
            # (loss_out.nonrecursive_bpsps), the assumed ones (entropy_coding_bytes * 8 /
            # subpixel_count, plus the overhead in percent) and the actual one (actual_bpsp).
            raise NotImplementedError
        return EncodeOut(img, actual_bpsp, None)