Esempio n. 1
0
def _decode(inputpath, coder, show, device, output=None):
    """Decode a compressed bitstream back into an image or video.

    Reads the stream header to recover the model/metric/quality used at
    encode time, loads the matching pretrained network on *device*, and
    dispatches to the image or video decoder. Optionally displays the
    result (for video, only the last frame).
    """
    decoders = {
        CodecType.IMAGE_CODEC: decode_image,
        CodecType.VIDEO_CODEC: decode_video,
    }

    compressai.set_entropy_coder(coder)

    t_total = time.time()
    with Path(inputpath).open("rb") as f:
        # Header layout: model/metric/quality bytes, then the original
        # spatial size and bit depth needed to reconstruct the output.
        model, metric, quality = parse_header(read_uchars(f, 2))
        original_size = read_uints(f, 2)
        original_bitdepth = read_uchars(f, 1)[0]

        t_load = time.time()
        net = (models[model](quality=quality, metric=metric,
                             pretrained=True).to(device).eval())
        if model in image_models:
            codec_type = CodecType.IMAGE_CODEC
        else:
            codec_type = CodecType.VIDEO_CODEC
        load_time = time.time() - t_load

        print(f"Model: {model:s}, metric: {metric:s}, quality: {quality:d}")

        stream_info = CodecInfo(None, original_size, original_bitdepth, net,
                                device)
        out = decoders[codec_type](f, stream_info, output)

    dec_time = time.time() - t_total
    print(f"Decoded in {dec_time:.2f}s (model loading: {load_time:.2f}s)")

    if show:
        # For video, only the last frame is shown
        show_image(out["img"])
Esempio n. 2
0
def main(argv):
    """Evaluate one pretrained model at each requested quality level and
    print the aggregated metrics as a JSON report on stdout."""
    args = setup_args().parse_args(argv)

    compressai.set_entropy_coder(args.entropy_coder)

    results = defaultdict(list)
    for quality in args.qualities:
        # Progress indicator goes to stderr; stdout carries the JSON.
        sys.stderr.write(f'\r{args.model} | quality: {quality:d}')
        sys.stderr.flush()
        net = models[args.model](quality=quality,
                                 metric=args.metric,
                                 pretrained=True).eval()
        per_quality = run_model(net, args.dataset, args.entropy_estimation)
        for name, value in per_quality.items():
            results[name].append(value)
    sys.stderr.write('\n')
    sys.stderr.flush()

    if args.entropy_estimation:
        description = 'entropy estimation'
    else:
        description = args.entropy_coder
    output = {
        'name': args.model,
        'description': f'Inference ({description})',
        'results': results,
    }

    print(json.dumps(output, indent=2))
Esempio n. 3
0
def _encode(input, num_of_frames, model, metric, quality, coder, device,
            output):
    """Encode an image or video file with a pretrained model.

    Args:
        input: path of the source file to compress.
        num_of_frames: number of frames to encode (video codecs).
        model: model name, a key into ``models``.
        metric: metric the pretrained weights were optimized for.
        quality: quality level of the pretrained weights.
        coder: entropy coder name.
        device: device string/object the network is moved to.
        output: destination path for the bitstream.

    Raises:
        FileNotFoundError: if *input* does not exist.
    """
    encode_func = {
        CodecType.IMAGE_CODEC: encode_image,
        CodecType.VIDEO_CODEC: encode_video,
    }

    # Fail fast: validate the input path BEFORE the slow pretrained-model
    # load (which may download weights). The original checked it after.
    if not Path(input).is_file():
        raise FileNotFoundError(f"{input} does not exist")

    compressai.set_entropy_coder(coder)
    enc_start = time.time()

    start = time.time()
    model_info = models[model]
    net = model_info(quality=quality, metric=metric,
                     pretrained=True).to(device).eval()
    codec_type = (CodecType.IMAGE_CODEC
                  if model in image_models else CodecType.VIDEO_CODEC)

    codec_header_info = get_header(model, metric, quality, num_of_frames,
                                   codec_type)
    load_time = time.time() - start

    codec_info = CodecInfo(codec_header_info, None, None, net, device)
    out = encode_func[codec_type](input, codec_info, output)

    enc_time = time.time() - enc_start

    print(f"{out['bpp']:.3f} bpp |"
          f" Encoded in {enc_time:.2f}s (model loading: {load_time:.2f}s)")
Esempio n. 4
0
def _encode(image, model, metric, quality, coder, output):
    """Compress a single image with a pretrained model and serialize the
    bitstream (header, sizes, length-prefixed latent strings) to *output*."""
    compressai.set_entropy_coder(coder)
    t0 = time.time()

    img = load_image(image)

    t_load = time.time()
    net = models[model](quality=quality, metric=metric, pretrained=True).eval()
    load_time = time.time() - t_load

    x = img2torch(img)
    height, width = x.size(2), x.size(3)
    # Pad to a multiple of 64: the networks use at most 6 stride-2 layers.
    x = pad(x, 64)

    with torch.no_grad():
        out = net.compress(x)

    shape = out['shape']
    header = get_header(model, metric, quality)

    with Path(output).open('wb') as f:
        write_uchars(f, header)
        # Original image size, then latent shape and number of strings.
        write_uints(f, (height, width))
        write_uints(f, (shape[0], shape[1], len(out['strings'])))
        # Each latent string is written length-prefixed.
        for s in out['strings']:
            write_uints(f, (len(s[0]), ))
            write_bytes(f, s[0])

    enc_time = time.time() - t0
    num_bytes = filesize(output)
    bpp = float(num_bytes) * 8 / (img.size[0] * img.size[1])
    print(f'{bpp:.3f} bpp |'
          f' Encoded in {enc_time:.2f}s (model loading: {load_time:.2f}s)')
Esempio n. 5
0
def _decode(inputpath, coder, show, output=None):
    """Decompress a single-image bitstream; optionally display and/or
    save the reconstructed image."""
    compressai.set_entropy_coder(coder)

    t0 = time.time()
    with Path(inputpath).open('rb') as f:
        model, metric, quality = parse_header(read_uchars(f, 2))
        original_size = read_uints(f, 2)
        shape = read_uints(f, 2)
        n_strings = read_uints(f, 1)[0]
        # Each latent string was written length-prefixed by the encoder.
        strings = [[read_bytes(f, read_uints(f, 1)[0])]
                   for _ in range(n_strings)]

    print(f'Model: {model:s}, metric: {metric:s}, quality: {quality:d}')
    t_load = time.time()
    net = models[model](quality=quality, metric=metric, pretrained=True).eval()
    load_time = time.time() - t_load

    with torch.no_grad():
        out = net.decompress(strings, shape)

    # Undo the encoder-side padding before converting back to an image.
    img = torch2img(crop(out['x_hat'], original_size))
    dec_time = time.time() - t0
    print(f'Decoded in {dec_time:.2f}s (model loading: {load_time:.2f}s)')

    if show:
        show_image(img)
    if output is not None:
        img.save(output)
Esempio n. 6
0
def main(argv):
    """Evaluate models on an image dataset and print a JSON report.

    Models come from the pretrained zoo (one run per quality) or from
    checkpoint files (one run per path), selected by ``args.source``.
    Progress goes to stderr; the JSON report goes to stdout.
    """
    parser = setup_args()
    args = parser.parse_args(argv)

    if not args.source:
        print("Error: missing 'checkpoint' or 'pretrained' source.",
              file=sys.stderr)
        parser.print_help()
        sys.exit(1)

    filepaths = collect_images(args.dataset)
    if len(filepaths) == 0:
        print("Error: no images found in directory.", file=sys.stderr)
        sys.exit(1)

    compressai.set_entropy_coder(args.entropy_coder)

    if args.source == "pretrained":
        runs = sorted(args.qualities)
        opts = (args.architecture, args.metric)
        load_func = load_pretrained
        log_fmt = "\rEvaluating {0} | {run:d}"
    elif args.source == "checkpoint":
        runs = args.paths
        opts = (args.architecture, )
        load_func = load_checkpoint
        log_fmt = "\rEvaluating {run:s}"
    else:
        # Guard the fall-through: previously an unexpected source left
        # `runs`/`opts`/`load_func` unbound and crashed with a NameError.
        raise ValueError(f"Invalid source: {args.source}")

    results = defaultdict(list)
    for run in runs:
        if args.verbose:
            sys.stderr.write(log_fmt.format(*opts, run=run))
            sys.stderr.flush()
        model = load_func(*opts, run)
        if args.cuda and torch.cuda.is_available():
            model = model.to("cuda")
        metrics = eval_model(model, filepaths, args.entropy_estimation,
                             args.half)
        for k, v in metrics.items():
            results[k].append(v)

    if args.verbose:
        sys.stderr.write("\n")
        sys.stderr.flush()

    description = ("entropy estimation"
                   if args.entropy_estimation else args.entropy_coder)
    output = {
        "name": args.architecture,
        "description": f"Inference ({description})",
        "results": results,
    }

    print(json.dumps(output, indent=2))
Esempio n. 7
0
def main(argv):
    """Evaluate models on an image dataset and print a JSON report.

    Models come from the pretrained zoo (one run per quality) or from
    checkpoint files (one run per path), selected by ``args.source``.
    Progress goes to stderr; the JSON report goes to stdout.
    """
    args = setup_args().parse_args(argv)

    filepaths = collect_images(args.dataset)
    if len(filepaths) == 0:
        # Error text belongs on stderr; stdout is reserved for the JSON
        # report that downstream tools consume.
        print("No images found in directory.", file=sys.stderr)
        sys.exit(1)

    compressai.set_entropy_coder(args.entropy_coder)

    if args.source == "pretrained":
        runs = sorted(args.qualities)
        opts = (args.arch, args.metric)
        load_func = load_pretrained
        log_fmt = "\rEvaluating {0} | {run:d}"
    elif args.source == "checkpoint":
        runs = args.paths
        opts = (args.arch, )
        load_func = load_checkpoint
        log_fmt = "\rEvaluating {run:s}"
    else:
        # Guard the fall-through: previously an unexpected source left
        # `runs`/`opts`/`load_func` unbound and crashed with a NameError.
        raise ValueError(f"Invalid source: {args.source}")

    results = defaultdict(list)
    for run in runs:
        if args.verbose:
            sys.stderr.write(log_fmt.format(*opts, run=run))
            sys.stderr.flush()
        model = load_func(*opts, run)
        metrics = eval_model(model, filepaths, args.entropy_estimation)
        for k, v in metrics.items():
            results[k].append(v)

    if args.verbose:
        sys.stderr.write("\n")
        sys.stderr.flush()

    description = ("entropy estimation"
                   if args.entropy_estimation else args.entropy_coder)
    output = {
        "name": args.arch,
        "description": f"Inference ({description})",
        "results": results,
    }

    print(json.dumps(output, indent=2))
Esempio n. 8
0
def test_set_entropy_coder():
    """A known coder is accepted; an unknown one raises ValueError."""
    valid_coder = "ans"
    bogus_coder = "cabac"

    compressai.set_entropy_coder(valid_coder)

    with pytest.raises(ValueError):
        compressai.set_entropy_coder(bogus_coder)
Esempio n. 9
0
def test(argv):
    """Drive encode/decode/verify passes of the RPLVC video codec over the
    configured test dataset(s), printing timing and rate-distortion stats.

    CLI flags select the phases: ``args.encode`` compresses every video and
    reports mean PSNR / MS-SSIM / bpp, ``args.decode`` decompresses the
    produced ``*lvc.bin`` bitstreams, ``args.check`` byte-compares encoder
    vs decoder reconstructions.
    """
    args = parse_args(argv)
    # Restrict visible GPUs before any CUDA context is created.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # set entropy coder
    coder = "ans"
    compressai.set_entropy_coder(coder)

    # load models
    from compressai.models import RPLVC
    net = RPLVC(N=128, Nf=128).cpu().eval()
    if torch.cuda.is_available():
        net = net.cuda()
    # Checkpoint file name is derived from the CLI model id:
    # ./pretrained/checkpoint<model>.pth.tar
    PATH = "./pretrained/checkpoint" + args.model + ".pth.tar"
    QP = args.qp  # quantization parameter forwarded to compress_video (BPG coding enabled below)
    gop = args.gop  # group-of-pictures length forwarded to (de)compression
    net.load_state_dict(torch.load(PATH, map_location='cpu')['state_dict'])

    #for dataset in ['ClassD','ClassC','ClassE','ClassB']:
    for dataset in ['ClassD']:
        video_root_path = os.path.join("./datasets/", dataset)
        # Per-checkpoint output dirs, keyed by the checkpoint base name:
        # bitstreams, encoder-side recs, decoder-side recs.
        bit_path = os.path.join('./bits/',
                                PATH.split('/')[-1].split('.')[0], dataset)
        rec_path = os.path.join('./recs/',
                                PATH.split('/')[-1].split('.')[0], dataset)
        rec_dec_path = os.path.join('./recs_dec/',
                                    PATH.split('/')[-1].split('.')[0], dataset)
        if args.encode:
            t_start = time.time()
            os.system(" ".join(
                ["mkdir", "-p", bit_path, rec_path, rec_dec_path]))
            #os.system(" ".join(["rm", bit_path+"/*", rec_path+"/*"]))
            psnr_all_list = []
            msssim_all_list = []
            bpp_all_list = []
            for video in os.listdir(video_root_path):
                # NOTE(review): hard-coded frame budget; presumably every
                # test sequence has >= 100 frames — confirm against datasets.
                coded_frame_num = 100
                psnr_list, msssim_list, bpp_resi_list, bpp_mv_list, bpp_sum_list = compress_video(
                    net,
                    video,
                    video_root_path,
                    coded_frame_num,
                    gop,
                    bit_path,
                    rec_path,
                    bpg_coding=True,
                    QP=QP,
                    vtm_coding=False,
                    verbose=args.verbose)
                assert len(psnr_list) == coded_frame_num
                # Aggregate per-video means; bpp uses bpp_sum_list
                # (presumably residual + motion combined — verify in
                # compress_video).
                psnr_all_list.append(mean(psnr_list))
                msssim_all_list.append(mean(msssim_list))
                bpp_all_list.append(mean(bpp_sum_list))
            enc_time = time.time() - t_start
            print(dataset)
            print(f" Encoded in {enc_time:.2f}s, hat mode |"
                  f" psnr {mean(psnr_all_list):.4f} |"
                  f" ms-ssim {mean(msssim_all_list):.4f} |"
                  f" bpp {mean(bpp_all_list):.4f}\n")
        if args.decode:
            t_start = time.time()
            #os.system(" ".join(["rm", rec_dec_path+"/*"]))
            for bit in os.listdir(bit_path):
                # Only the codec's own bitstreams; skip any auxiliary files
                # (e.g. BPG intra streams) present in bit_path.
                if "lvc.bin" in bit:
                    coded_frame_num = 100
                    decompress_video(net,
                                     coded_frame_num,
                                     gop,
                                     bit,
                                     bit_path,
                                     rec_dec_path,
                                     bpg_decoding=True,
                                     vtm_decoding=False,
                                     verbose=args.verbose)
            dec_time = time.time() - t_start
            print("Summary:")
            print(f" Decoded in {dec_time:.2f}s, hat mode\n")
        if args.check:
            print("Checking " + dataset)
            # Encoder and decoder reconstructions must match bit-exactly;
            # `cmp` returns 0 only when the files are identical.
            for recs in os.listdir(rec_path):
                enc_rec = os.path.join(rec_path, recs)
                dec_rec = os.path.join(rec_dec_path, recs)
                assert os.system(" ".join(["cmp", enc_rec,
                                           dec_rec])) == 0, "MISMATCH!!!"
            print("  Check Pass!")
Esempio n. 10
0
def test_set_entropy_coder():
    """A known coder is accepted; an unknown one raises ValueError."""
    valid_coder = 'ans'
    bogus_coder = 'cabac'

    compressai.set_entropy_coder(valid_coder)

    with pytest.raises(ValueError):
        compressai.set_entropy_coder(bogus_coder)