def test_round_trip():
    """Check that rgb2ycbcr and ycbcr2rgb invert each other in both orders."""
    sample = torch.rand(1, 3, 32, 32)

    # First RGB -> YCbCr -> RGB, then YCbCr -> RGB -> YCbCr.
    for forward, backward in ((rgb2ycbcr, ycbcr2rgb), (ycbcr2rgb, rgb2ycbcr)):
        restored = backward(forward(sample))
        assert torch.allclose(sample, restored, atol=1e-5)
Example #2
0
def compute_metrics_for_frame(
    org_frame: Frame,
    dec_frame: Frame,
    bitdepth: int = 8,
) -> Dict[str, Any]:
    """Compute distortion metrics between an original and a decoded frame.

    Both frames are indexed as three planes named "y", "u" and "v". Each plane
    gets two leading singleton dimensions before use (assumes 2-D planes, so
    the result is NCHW — TODO confirm against the ``Frame`` definition).

    Returns a dict with:
        - ``mse-y``, ``mse-u``, ``mse-v``: per-plane MSE on the frames' own scale;
        - ``mse-rgb``: MSE after 4:2:0 -> 4:4:4 bicubic upsampling, YCbCr -> RGB
          conversion and requantization to integers in [0, 2**bitdepth - 1];
        - ``ms-ssim-rgb``: MS-SSIM on those same requantized RGB images.

    The planes are divided by ``2**bitdepth - 1`` before color conversion, so
    they appear to be expected in [0, 2**bitdepth - 1] — TODO confirm.
    """
    peak = 2**bitdepth - 1

    def _batched(frame):
        # Add batch and channel dims to every plane.
        return tuple(plane.unsqueeze(0).unsqueeze(0) for plane in frame)

    def _requantized_rgb(planes):
        # Upsample chroma, normalize, convert to RGB, then back to integer codes.
        normalized = yuv_420_to_444(planes, mode="bicubic").true_divide(peak)  # type: ignore
        rgb = ycbcr2rgb(normalized)  # type: ignore
        return (rgb * peak).clamp(0, peak).round()

    org_planes = _batched(org_frame)
    dec_planes = _batched(dec_frame)

    # YCbCr metrics
    metrics: Dict[str, Any] = {
        f"mse-{name}": (org_planes[idx] - dec_planes[idx]).pow(2).mean()
        for idx, name in enumerate("yuv")
    }

    # RGB metrics
    org_rgb = _requantized_rgb(org_planes)
    dec_rgb = _requantized_rgb(dec_planes)
    rgb_mse = (org_rgb - dec_rgb).pow(2).mean()
    rgb_ms_ssim = ms_ssim(org_rgb, dec_rgb, data_range=peak)

    metrics["ms-ssim-rgb"] = rgb_ms_ssim
    metrics["mse-rgb"] = rgb_mse
    return metrics
Example #3
0
def convert_yuv420_rgb(frame: Tuple[np.ndarray, np.ndarray, np.ndarray],
                       device: torch.device, max_val: int) -> Tensor:
    """Convert a YUV 4:2:0 frame of numpy planes into an RGB 4:4:4 tensor.

    Args:
        frame: the three (Y, U, V) planes, with code values presumably in
            [0, max_val] (i.e. [0, 2**bitdepth - 1]).
        device: device the tensors are created on.
        max_val: normalization divisor passed to ``to_tensors``.

    Returns:
        The output of ``ycbcr2rgb`` on the bicubically upsampled 4:4:4 planes.
        The intended range is [0, 1]. This function itself does not clamp the
        result.
    """
    planes = to_tensors(frame, device=str(device), max_value=max_val)
    batched = tuple(plane.unsqueeze(0).unsqueeze(0) for plane in planes)
    yuv444 = yuv_420_to_444(batched, mode="bicubic")  # type: ignore
    return ycbcr2rgb(yuv444)  # type: ignore
Example #4
0
    def _run(self, img, quality, return_rec=False):
        """Encode and decode ``img`` with the configured encoder/decoder.

        The RGB input is converted to 8-bit planar YCbCr 4:4:4 and written to a
        temporary raw file. It is then encoded with ``self.encoder_path``
        (``--cq-level=quality``), decoded with ``self.decoder_path``, and
        compared with the input in RGB.

        Args:
            img: image passed to ``read_image`` (presumably a path or image
                object — TODO confirm).
            quality: constant-quality level, must be in [0, 63].
            return_rec: when True, also return the reconstruction built with
                ``Image.fromarray``.

        Returns:
            A dict with ``psnr``, ``ms-ssim``, ``bpp``, ``encoding_time`` and
            ``decoding_time``, or ``(dict, image)`` when ``return_rec`` is set.

        Raises:
            ValueError: if ``quality`` is outside [0, 63].
        """
        if not 0 <= quality <= 63:
            raise ValueError(f"Invalid quality value: {quality} (0,63)")

        bitdepth = 8
        max_val = 2**bitdepth - 1

        # Convert input image to planar (channel-first) YCbCr 4:4:4, 8 bits.
        arr = np.asarray(read_image(img))
        arr = arr.transpose((2, 0, 1))  # color channel first
        rgb = torch.from_numpy(arr.copy()).float() / max_val
        arr = np.clip(rgb2ycbcr(rgb).numpy(), 0, 1)
        arr = (arr * max_val).astype(np.uint8)
        height, width = arr.shape[1:]

        fd, yuv_path = mkstemp(suffix=".yuv")
        out_filepath = os.path.splitext(yuv_path)[0] + ".webm"
        try:
            # Write through mkstemp's descriptor so it is closed immediately
            # and cannot leak if a later step fails.
            with os.fdopen(fd, "wb") as f:
                f.write(arr.tobytes())

            # Encode
            cmd = [
                self.encoder_path,
                "-w",
                width,
                "-h",
                height,
                "--fps=1/1",
                "--limit=1",
                "--input-bit-depth=8",
                "--cpu-used=0",
                "--threads=1",
                "--passes=2",
                "--end-usage=q",
                "--cq-level=" + str(quality),
                "--i444",
                "--skip=0",
                "--tune=psnr",
                "--psnr",
                "--bit-depth=8",
                "-o",
                out_filepath,
                yuv_path,
            ]
            start = time.time()
            run_command(cmd)
            enc_time = time.time() - start

            # Remove the encoder input; the decoder writes to the same path.
            os.unlink(yuv_path)

            # Decode
            cmd = [
                self.decoder_path,
                out_filepath,
                "-o",
                yuv_path,
                "--rawvideo",
                "--output-bit-depth=8",
            ]
            start = time.time()
            run_command(cmd)
            dec_time = time.time() - start

            rec_arr = np.fromfile(yuv_path, dtype=np.uint8).reshape(arr.shape)
            bpp = filesize(out_filepath) * 8.0 / (height * width)
        finally:
            # Remove temporary files even when encoding/decoding failed.
            for path in (yuv_path, out_filepath):
                try:
                    os.unlink(path)
                except FileNotFoundError:
                    pass

        # Compute PSNR / MS-SSIM in RGB on [0, 1] floats.
        arr = arr.astype(np.float32) / max_val
        rec_arr = rec_arr.astype(np.float32) / max_val
        arr = ycbcr2rgb(torch.from_numpy(arr.copy())).numpy()
        rec_arr = ycbcr2rgb(torch.from_numpy(rec_arr.copy())).numpy()
        psnr_val, msssim_val = compute_metrics(arr, rec_arr, max_val=1.0)

        out = {
            "psnr": psnr_val,
            "ms-ssim": msssim_val,
            "bpp": bpp,
            "encoding_time": enc_time,
            "decoding_time": dec_time,
        }
        if return_rec:
            # YCbCr -> RGB can leave [0, 1]. Clip before the uint8 cast so
            # out-of-range values saturate instead of wrapping around.
            rec = Image.fromarray(
                (rec_arr.clip(0, 1).transpose(1, 2, 0) * 255.0).astype(np.uint8)
            )
            return out, rec
        return out
Example #5
0
    def _run(self, img, quality, return_rec=False):
        """Encode and decode ``img`` with the configured encoder/decoder.

        The input is written as a planar 8-bit 4:4:4 raw file. It is converted
        to YCbCr first unless ``self.rgb`` is set, in which case RGB is encoded
        with the GBR colour-space options. The file is then encoded with
        ``self.encoder_path`` at QP ``quality``, decoded with
        ``self.decoder_path``, and compared with the input in RGB.

        Args:
            img: image passed to ``read_image`` (presumably a path or image
                object — TODO confirm).
            quality: quantization parameter, must be in [0, 51].
            return_rec: when True, also return the reconstruction built with
                ``Image.fromarray``.

        Returns:
            A dict with ``psnr``, ``ms-ssim``, ``bpp``, ``encoding_time`` and
            ``decoding_time``, or ``(dict, image)`` when ``return_rec`` is set.

        Raises:
            ValueError: if ``quality`` is outside [0, 51].
        """
        if not 0 <= quality <= 51:
            raise ValueError(f"Invalid quality value: {quality} (0,51)")

        bitdepth = 8
        max_val = 2**bitdepth - 1

        # Convert input image to a planar (channel-first) 4:4:4 array.
        arr = np.asarray(read_image(img))
        arr = arr.transpose((2, 0, 1))  # color channel first
        if not self.rgb:
            # convert rgb content to YCbCr
            rgb = torch.from_numpy(arr.copy()).float() / max_val
            arr = np.clip(rgb2ycbcr(rgb).numpy(), 0, 1)
            arr = (arr * max_val).astype(np.uint8)
        height, width = arr.shape[1:]

        fd, yuv_path = mkstemp(suffix=".yuv")
        out_filepath = os.path.splitext(yuv_path)[0] + ".bin"
        try:
            # Write through mkstemp's descriptor so it is closed immediately
            # and cannot leak if a later step fails.
            with os.fdopen(fd, "wb") as f:
                f.write(arr.tobytes())

            # Encode
            cmd = [
                self.encoder_path,
                "-i",
                yuv_path,
                "-c",
                self.config_path,
                "-q",
                quality,
                "-o",
                "/dev/null",
                "-b",
                out_filepath,
                "-wdt",
                width,
                "-hgt",
                height,
                "-fr",
                "1",
                "-f",
                "1",
                "--InputChromaFormat=444",
                "--InputBitDepth=8",
                "--SEIDecodedPictureHash",
                "--Level=5.1",
                "--CUNoSplitIntraACT=0",
                "--ConformanceMode=1",
            ]
            if self.rgb:
                cmd += [
                    "--InputColourSpaceConvert=RGBtoGBR",
                    "--SNRInternalColourSpace=1",
                    "--OutputInternalColourSpace=0",
                ]
            start = time.time()
            run_command(cmd)
            enc_time = time.time() - start

            # Remove the encoder input; the decoder writes to the same path.
            os.unlink(yuv_path)

            # Decode
            cmd = [self.decoder_path, "-b", out_filepath, "-o", yuv_path, "-d", 8]
            if self.rgb:
                cmd.append("--OutputInternalColourSpace=GBRtoRGB")
            start = time.time()
            run_command(cmd)
            dec_time = time.time() - start

            rec_arr = np.fromfile(yuv_path, dtype=np.uint8).reshape(arr.shape)
            bpp = filesize(out_filepath) * 8.0 / (height * width)
        finally:
            # Remove temporary files even when encoding/decoding failed.
            for path in (yuv_path, out_filepath):
                try:
                    os.unlink(path)
                except FileNotFoundError:
                    pass

        # Compute PSNR / MS-SSIM in RGB on [0, 1] floats.
        arr = arr.astype(np.float32) / max_val
        rec_arr = rec_arr.astype(np.float32) / max_val
        if not self.rgb:
            arr = ycbcr2rgb(torch.from_numpy(arr.copy())).numpy()
            rec_arr = ycbcr2rgb(torch.from_numpy(rec_arr.copy())).numpy()
        psnr_val, msssim_val = compute_metrics(arr, rec_arr, max_val=1.0)

        out = {
            "psnr": psnr_val,
            "ms-ssim": msssim_val,
            "bpp": bpp,
            "encoding_time": enc_time,
            "decoding_time": dec_time,
        }
        if return_rec:
            # YCbCr -> RGB can leave [0, 1]. Clip before the uint8 cast so
            # out-of-range values saturate instead of wrapping around.
            rec = Image.fromarray(
                (rec_arr.clip(0, 1).transpose(1, 2, 0) * 255.0).astype(np.uint8)
            )
            return out, rec
        return out
Example #6
0
    def _run_impl(self, in_filepath, quality):
        """Encode and decode ``in_filepath`` and return rate/timing plus the reconstruction.

        The image is loaded with ``self._load_img`` and written as a planar
        8-bit 4:4:4 raw file. It is converted to YCbCr first unless
        ``self.rgb`` is set, in which case RGB is encoded with the GBR
        colour-space options. The file is encoded with ``self.encoder_path`` at
        QP ``quality`` and decoded with ``self.decoder_path``.

        Args:
            in_filepath: input image path given to ``self._load_img``.
            quality: quantization parameter, must be in [0, 63].

        Returns:
            ``(out, rec)``. ``out`` holds ``bpp``, ``encoding_time`` and
            ``decoding_time``. ``rec`` is the RGB reconstruction built with
            ``Image.fromarray``.

        Raises:
            ValueError: if ``quality`` is outside [0, 63].
        """
        if not 0 <= quality <= 63:
            raise ValueError(f"Invalid quality value: {quality} (0,63)")

        # Taking 8bit input for now
        bitdepth = 8
        max_val = 2**bitdepth - 1

        # Convert input image to a planar (channel-first) 4:4:4 array.
        arr = np.asarray(self._load_img(in_filepath))
        arr = arr.transpose((2, 0, 1))  # color channel first
        if not self.rgb:
            # convert rgb content to YCbCr
            rgb = torch.from_numpy(arr.copy()).float() / max_val
            arr = np.clip(rgb2ycbcr(rgb).numpy(), 0, 1)
            arr = (arr * max_val).astype(np.uint8)
        height, width = arr.shape[1:]

        fd, yuv_path = mkstemp(suffix=".yuv")
        out_filepath = os.path.splitext(yuv_path)[0] + ".bin"
        try:
            # Write through mkstemp's descriptor so it is closed immediately
            # and cannot leak if a later step fails.
            with os.fdopen(fd, "wb") as f:
                f.write(arr.tobytes())

            # Encode
            cmd = [
                self.encoder_path,
                "-i",
                yuv_path,
                "-c",
                self.config_path,
                "-q",
                quality,
                "-o",
                "/dev/null",
                "-b",
                out_filepath,
                "-wdt",
                width,
                "-hgt",
                height,
                "-fr",
                "1",
                "-f",
                "1",
                "--InputChromaFormat=444",
                "--InputBitDepth=8",
                "--ConformanceWindowMode=1",
            ]
            if self.rgb:
                cmd += [
                    "--InputColourSpaceConvert=RGBtoGBR",
                    "--SNRInternalColourSpace=1",
                    "--OutputInternalColourSpace=0",
                ]
            start = time.time()
            run_command(cmd)
            enc_time = time.time() - start

            # Remove the encoder input; the decoder writes to the same path.
            os.unlink(yuv_path)

            # Decode
            cmd = [self.decoder_path, "-b", out_filepath, "-o", yuv_path, "-d", 8]
            if self.rgb:
                cmd.append("--OutputInternalColourSpace=GBRtoRGB")
            start = time.time()
            run_command(cmd)
            dec_time = time.time() - start

            rec_arr = np.fromfile(yuv_path, dtype=np.uint8).reshape(arr.shape)
            bpp = filesize(out_filepath) * 8.0 / (height * width)
        finally:
            # Remove temporary files even when encoding/decoding failed.
            for path in (yuv_path, out_filepath):
                try:
                    os.unlink(path)
                except FileNotFoundError:
                    pass

        # Bring the reconstruction back to RGB floats in [0, 1]. The original
        # input is not needed here because no metric is computed.
        rec_arr = rec_arr.astype(np.float32) / max_val
        if not self.rgb:
            rec_arr = ycbcr2rgb(torch.from_numpy(rec_arr.copy())).numpy()

        out = {
            "bpp": bpp,
            "encoding_time": enc_time,
            "decoding_time": dec_time,
        }

        # Clip before the uint8 cast so out-of-range values saturate.
        rec = Image.fromarray(
            (rec_arr.clip(0, 1).transpose(1, 2, 0) * 255.0).astype(np.uint8)
        )
        return out, rec