Example #1
0
    def test_pretrained(self):
        """Each pretrained mbt2018 checkpoint loads with the expected layer widths.

        Loading with ``pretrained=True`` fetches the weights for every quality
        level. ``g_a.0`` always has 192 output channels. ``g_a.6`` has 192 for
        qualities 1-4 and 320 for qualities 5-8.
        """
        g_a6_width = {quality: 192 for quality in range(1, 5)}
        g_a6_width.update({quality: 320 for quality in range(5, 9)})

        for quality, expected in g_a6_width.items():
            weights = mbt2018(quality, metric="mse", pretrained=True).state_dict()
            assert weights["g_a.0.weight"].size(0) == 192
            assert weights["g_a.6.weight"].size(0) == expected
Example #2
0
    def test_ok(self):
        """mbt2018 builds the right model class with the right widths for qualities 1-8."""
        for quality in range(1, 9):
            # g_a.6 widens from 192 to 320 channels starting at quality 5.
            expected_g_a6 = 192 if quality < 5 else 320

            net = mbt2018(quality, metric="mse")
            assert isinstance(net, JointAutoregressiveHierarchicalPriors)

            params = net.state_dict()
            assert params["g_a.0.weight"].size(0) == 192
            assert params["g_a.6.weight"].size(0) == expected_g_a6
Example #3
0
    def test_invalid_params(self):
        """mbt2018 rejects out-of-range qualities and unsupported metrics."""
        # Valid qualities are 1-8. Check values far outside that range, and
        # also the values immediately past each boundary (0 and 9), so an
        # off-by-one in the range check is caught.
        for quality in (-1, 0, 9, 10):
            with pytest.raises(ValueError):
                mbt2018(quality)

        # Invalid quality combined with an unsupported metric.
        with pytest.raises(ValueError):
            mbt2018(10, metric="ssim")

        # Valid quality, so the unsupported metric alone must trigger the error.
        with pytest.raises(ValueError):
            mbt2018(1, metric="ssim")
Example #4
0
def main(argv):
    """Parse arguments, build data/model/optimizers, and run the training loop.

    Args:
        argv: Command-line argument list, passed straight to ``parse_args``.

    When ``cmd_args.resume`` is truthy, the run configuration is taken from
    the checkpoint at ``cmd_args.ckpt`` (its stored ``"args"``). The iteration
    counter, best loss, model weights, both optimizer states and the starting
    epoch are also restored from it. Otherwise the parsed command-line
    arguments are used and training starts at epoch 0.

    Raises:
        ValueError: If ``args.model`` or ``args.loss`` is not one of the
            choices handled below.
    """
    cmd_args = parse_args(argv)
    resume = cmd_args.resume

    if resume:
        # Load the checkpoint first: it also carries the run configuration
        # (``args``) that everything below is built from.
        ckpt_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        ckpt = load_checkpoint(cmd_args.ckpt, ckpt_device)
        args = ckpt["args"]
        save_dirs = prepare_save(model=args.model,
                                 dataset=args.dataname,
                                 quality=args.quality,
                                 ckpt=cmd_args.ckpt)
    else:
        args = cmd_args
        save_dirs = prepare_save(model=args.model,
                                 dataset=args.dataname,
                                 quality=args.quality,
                                 args=args)

    # NOTE: "test_inteval" (sic) is passed by keyword, so it has to match
    # Logger's parameter name; rename both together if it is ever fixed.
    logger = Logger(log_interval=100,
                    test_inteval=100,
                    save_dirs=save_dirs,
                    max_iter=args.iterations)
    train_writer = SummaryWriter(
        os.path.join(save_dirs["tensorboard_runs"], "train"))
    test_writer = SummaryWriter(
        os.path.join(save_dirs["tensorboard_runs"], "test"))

    # Seed before the model is constructed below so its initialisation is
    # covered by the seed as well.
    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    train_transforms = transforms.Compose(
        [transforms.RandomCrop(args.patch_size),
         transforms.ToTensor()])

    test_transforms = transforms.Compose([transforms.ToTensor()])

    train_dataset = ImageFolder(args.dataset,
                                split="train",
                                transform=train_transforms)
    test_dataset = ImageFolder(args.dataset,
                               split="test",
                               transform=test_transforms)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True,
                                  pin_memory=True,
                                  drop_last=True)

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.test_batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=True,
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"

    if args.model == "AE":
        net = AutoEncoder()
    elif args.model == "scale_bmshj2018_factorized":
        net = Scale_FactorizedPrior(N=192, M=320, scale=args.scale)
    elif args.model == "bmshj2018_hyperprior":
        net = bmshj2018_hyperprior(quality=args.quality)
    elif args.model == "mbt2018_mean":
        net = mbt2018_mean(quality=args.quality)
    elif args.model == "mbt2018":
        net = mbt2018(quality=args.quality)
    elif args.model == "cheng2020_anchor":
        net = cheng2020_anchor(quality=args.quality)
    else:
        # Without this, an unknown name left ``net`` unbound and failed
        # later with an unhelpful UnboundLocalError.
        raise ValueError(f"Unsupported model: {args.model!r}")

    net = net.to(device)
    optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)
    # Separate optimizer for the parameters exposed via ``aux_parameters()``.
    aux_optimizer = optim.Adam(net.aux_parameters(), lr=args.aux_learning_rate)
    if args.loss == "avg":
        criterion = RateDistortionLoss(lmbda=args.lmbda)
    elif args.loss == "ID":
        criterion = InformationDistillationRateDistortionLoss(lmbda=args.lmbda)
    else:
        raise ValueError(f"Unsupported loss: {args.loss!r}")

    best_loss = 1e10
    if resume:
        logger.iteration = ckpt["iteration"]
        best_loss = ckpt["loss"]
        net.load_state_dict(ckpt["state_dict"])
        optimizer.load_state_dict(ckpt["optimizer"])
        aux_optimizer.load_state_dict(ckpt["aux_optimizer"])
        start_epoch = ckpt["epoch"]
    else:
        start_epoch = 0

    for epoch in range(start_epoch, args.epochs):
        train_one_epoch(net,
                        criterion,
                        train_dataloader,
                        optimizer,
                        aux_optimizer,
                        epoch,
                        args.clip_max_norm,
                        test_dataloader,
                        args,
                        writer=train_writer,
                        test_writer=test_writer,
                        logger=logger)