Ejemplo n.º 1
0
    def test_params(self):
        """Check the layer sizes built for every supported quality level.

        For qualities 1-5 the first dimension of ``g_a.0.weight`` must be 128
        and that of ``g_a.6.weight`` must be 192; for qualities 6-8 the
        expected sizes are 192 and 320 respectively.
        """
        # (quality levels, expected size of g_a.0.weight, expected size of g_a.6.weight)
        expectations = (
            (range(1, 6), 128, 192),
            (range(6, 9), 192, 320),
        )
        for qualities, first_size, last_size in expectations:
            for quality in qualities:
                model = bmshj2018_hyperprior(quality, metric="mse")
                assert isinstance(model, ScaleHyperprior)
                weights = model.state_dict()
                assert weights["g_a.0.weight"].size(0) == first_size
                assert weights["g_a.6.weight"].size(0) == last_size
Ejemplo n.º 2
0
def forward(img_path):
    """Run the image at ``img_path`` through the pretrained quality-2 model.

    The network is put in evaluation mode and moved to the module-level
    ``device``; its parameter count is printed. The image is opened as RGB,
    handed to ``_encode`` and the result is handed to ``_decode``. The code
    shown here does not return anything.
    """
    net = bmshj2018_hyperprior(quality=2, pretrained=True)
    net = net.eval().to(device)
    print(f'Parameters: {sum(p.numel() for p in net.parameters())}')

    rgb_image = Image.open(img_path).convert('RGB')

    # Both helpers are noted by the original author as returning a torch tensor.
    latent = _encode(rgb_image, net)
    reconstruction = _decode(latent, rgb_image, net)
Ejemplo n.º 3
0
    def test_invalid_params(self):
        """Each rejected quality/metric combination must raise ``ValueError``."""
        # (positional arguments, keyword arguments) for every rejected call
        rejected_calls = (
            ((-1,), {}),
            ((10,), {}),
            ((10,), {"metric": "ssim"}),
            ((1,), {"metric": "ssim"}),
        )
        for positional, keywords in rejected_calls:
            with pytest.raises(ValueError):
                bmshj2018_hyperprior(*positional, **keywords)
Ejemplo n.º 4
0
    def test_update(self):
        """A forced ``update`` must leave the stored CDF tables nearly unchanged.

        ``update()`` and ``update(force=False)`` must both return a falsy
        value on the pretrained model, ``update(force=True)`` must return a
        truthy one, and afterwards ``_cdf_length``, ``_offset`` and
        ``_quantized_cdf`` may differ from their previous values by at most 2
        element-wise.
        """
        # Start from a pretrained model in evaluation mode.
        model = bmshj2018_hyperprior(quality=1, pretrained=True).eval()
        for update_kwargs in ({}, {"force": False}):
            assert not model.update(**update_kwargs)

        # Remember the tables as they are before the forced update.
        before_quantized_cdf = model.gaussian_conditional._quantized_cdf
        before_offset = model.gaussian_conditional._offset
        before_cdf_length = model.gaussian_conditional._cdf_length
        assert model.update(force=True)

        def within_two(current, reference):
            deviation = (current - reference).abs()
            return (deviation <= 2).all()

        after = model.gaussian_conditional
        assert within_two(after._cdf_length, before_cdf_length)
        assert within_two(after._offset, before_offset)
        assert within_two(after._quantized_cdf, before_quantized_cdf)
Ejemplo n.º 5
0
def _build_model(args):
    """Instantiate the network selected by ``args.model``.

    Raises:
        ValueError: if ``args.model`` is not one of the supported names.
    """
    if args.model == "AE":
        return AutoEncoder()
    if args.model == "scale_bmshj2018_factorized":
        return Scale_FactorizedPrior(N=192, M=320, scale=args.scale)
    if args.model == "bmshj2018_hyperprior":
        return bmshj2018_hyperprior(quality=args.quality)
    if args.model == "mbt2018_mean":
        return mbt2018_mean(quality=args.quality)
    if args.model == "mbt2018":
        return mbt2018(quality=args.quality)
    if args.model == "cheng2020_anchor":
        return cheng2020_anchor(quality=args.quality)
    raise ValueError(f"Unknown model: {args.model!r}")


def _build_criterion(args):
    """Instantiate the loss selected by ``args.loss``.

    Raises:
        ValueError: if ``args.loss`` is neither ``"avg"`` nor ``"ID"``.
    """
    if args.loss == "avg":
        return RateDistortionLoss(lmbda=args.lmbda)
    if args.loss == "ID":
        return InformationDistillationRateDistortionLoss(lmbda=args.lmbda)
    raise ValueError(f"Unknown loss: {args.loss!r}")


def main(argv):
    """Train a compression model, optionally resuming from a checkpoint.

    Command-line arguments are parsed from ``argv``. When the ``resume`` flag
    is set, the checkpoint at ``cmd_args.ckpt`` is loaded and the arguments
    stored inside it replace the command-line ones; model, optimizer states,
    iteration counter and starting epoch are restored from it as well.
    Otherwise training starts at epoch 0 with the command-line arguments.

    Args:
        argv: argument list handed to ``parse_args``.

    Raises:
        ValueError: if the selected model or loss name is not supported.
            Previously an unknown name left ``net`` / ``criterion`` unbound
            and surfaced later as an ``UnboundLocalError``.
    """
    cmd_args = parse_args(argv)
    # ``resume`` is treated as a boolean flag and evaluated once.
    resuming = bool(cmd_args.resume)

    if resuming:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        ckpt = load_checkpoint(cmd_args.ckpt, device)
        # The arguments saved in the checkpoint take over from here on.
        args = ckpt["args"]
        save_dirs = prepare_save(model=args.model,
                                 dataset=args.dataname,
                                 quality=args.quality,
                                 ckpt=cmd_args.ckpt)
    else:
        args = cmd_args
        save_dirs = prepare_save(model=args.model,
                                 dataset=args.dataname,
                                 quality=args.quality,
                                 args=args)

    # NOTE: the keyword ``test_inteval`` is spelled exactly as in the original
    # call; it has to match Logger's signature, which is defined elsewhere.
    logger = Logger(log_interval=100,
                    test_inteval=100,
                    save_dirs=save_dirs,
                    max_iter=args.iterations)
    train_writer = SummaryWriter(
        os.path.join(save_dirs["tensorboard_runs"], "train"))
    test_writer = SummaryWriter(
        os.path.join(save_dirs["tensorboard_runs"], "test"))

    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    # Training images are randomly cropped; test images are used as they are.
    train_transforms = transforms.Compose(
        [transforms.RandomCrop(args.patch_size),
         transforms.ToTensor()])

    test_transforms = transforms.Compose([transforms.ToTensor()])

    train_dataset = ImageFolder(args.dataset,
                                split="train",
                                transform=train_transforms)
    test_dataset = ImageFolder(args.dataset,
                               split="test",
                               transform=test_transforms)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True,
                                  pin_memory=True,
                                  drop_last=True)

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.test_batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=True,
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"

    net = _build_model(args).to(device)
    # The main and auxiliary parameters are optimized separately, each with
    # its own learning rate.
    optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)
    aux_optimizer = optim.Adam(net.aux_parameters(), lr=args.aux_learning_rate)
    criterion = _build_criterion(args)

    # NOTE(review): ``best_loss`` is not consumed by the code visible in this
    # function; it is kept so the checkpoint is read exactly as before.
    best_loss = 1e10
    if resuming:
        logger.iteration = ckpt["iteration"]
        best_loss = ckpt["loss"]
        net.load_state_dict(ckpt["state_dict"])
        optimizer.load_state_dict(ckpt["optimizer"])
        aux_optimizer.load_state_dict(ckpt["aux_optimizer"])
        start_epoch = ckpt["epoch"]
    else:
        start_epoch = 0

    for epoch in range(start_epoch, args.epochs):
        train_one_epoch(net,
                        criterion,
                        train_dataloader,
                        optimizer,
                        aux_optimizer,
                        epoch,
                        args.clip_max_norm,
                        test_dataloader,
                        args,
                        writer=train_writer,
                        test_writer=test_writer,
                        logger=logger)
Ejemplo n.º 6
0
def load_compress_ai_model(config):
    """Return the pretrained quality-2 hyperprior model, ready for inference.

    The model is put in evaluation mode and moved to CUDA when available,
    otherwise to the CPU. ``config`` is accepted but not read by this
    function.
    """
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    model = bmshj2018_hyperprior(quality=2, pretrained=True)
    model = model.eval().to(device)
    print('loaded compressai model')
    return model
Ejemplo n.º 7
0
# In[1]:

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
metric = 'mse'  # only pre-trained models for mse are available for now
quality = 1  # lower quality -> lower bit-rate (use lower quality to clearly see visual differences in the notebook)

# ## Load some pretrained models

# In[2]:

# Every network is switched to evaluation mode and moved to `device`.
networks = {
    'bmshj2018-factorized':
    bmshj2018_factorized(quality=quality, pretrained=True).eval().to(device),
    'bmshj2018-hyperprior':
    bmshj2018_hyperprior(quality=quality, pretrained=True).eval().to(device),
    'mbt2018-mean':
    mbt2018_mean(quality=quality, pretrained=True).eval().to(device)
}

# ## Inference

# ### Load input data

# In[7]:

# Plain assignment: the previous chained form also set a stray `oimg`
# attribute on `Image` as a side effect of a paste typo.
img = Image.open('./assets/stmalo_fracape.png').convert('RGB')
# Convert the image to a tensor and add a leading dimension of size 1.
x = transforms.ToTensor()(img).unsqueeze(0)

# In[8]: