def test_pretrained(self):
    # Test that we can load the correct pretrained models from the URLs.
    # Qualities 1-4 use N = M = 192; qualities 5-8 use N = 192, M = 320.
    for i in range(1, 5):
        net = mbt2018(i, metric="mse", pretrained=True)
        assert net.state_dict()["g_a.0.weight"].size(0) == 192
        assert net.state_dict()["g_a.6.weight"].size(0) == 192

    for i in range(5, 9):
        net = mbt2018(i, metric="mse", pretrained=True)
        assert net.state_dict()["g_a.0.weight"].size(0) == 192
        assert net.state_dict()["g_a.6.weight"].size(0) == 320
def test_ok(self):
    for i in range(1, 5):
        net = mbt2018(i, metric="mse")
        assert isinstance(net, JointAutoregressiveHierarchicalPriors)
        assert net.state_dict()["g_a.0.weight"].size(0) == 192
        assert net.state_dict()["g_a.6.weight"].size(0) == 192

    for i in range(5, 9):
        net = mbt2018(i, metric="mse")
        assert isinstance(net, JointAutoregressiveHierarchicalPriors)
        assert net.state_dict()["g_a.0.weight"].size(0) == 192
        assert net.state_dict()["g_a.6.weight"].size(0) == 320
def test_invalid_params(self):
    with pytest.raises(ValueError):
        mbt2018(-1)

    with pytest.raises(ValueError):
        mbt2018(10)

    with pytest.raises(ValueError):
        mbt2018(10, metric="ssim")

    with pytest.raises(ValueError):
        mbt2018(1, metric="ssim")
def main(argv):
    cmd_args = parse_args(argv)

    if cmd_args.resume:
        # Resuming: restore the original training arguments from the checkpoint.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        ckpt = load_checkpoint(cmd_args.ckpt, device)
        args = ckpt["args"]
        save_dirs = prepare_save(model=args.model,
                                 dataset=args.dataname,
                                 quality=args.quality,
                                 ckpt=cmd_args.ckpt)
    else:
        args = cmd_args
        save_dirs = prepare_save(model=args.model,
                                 dataset=args.dataname,
                                 quality=args.quality,
                                 args=args)

    logger = Logger(log_interval=100,
                    test_inteval=100,
                    save_dirs=save_dirs,
                    max_iter=args.iterations)
    train_writer = SummaryWriter(
        os.path.join(save_dirs["tensorboard_runs"], "train"))
    test_writer = SummaryWriter(
        os.path.join(save_dirs["tensorboard_runs"], "test"))

    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    train_transforms = transforms.Compose(
        [transforms.RandomCrop(args.patch_size), transforms.ToTensor()])
    test_transforms = transforms.Compose([transforms.ToTensor()])

    train_dataset = ImageFolder(args.dataset,
                                split="train",
                                transform=train_transforms)
    test_dataset = ImageFolder(args.dataset,
                               split="test",
                               transform=test_transforms)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True,
                                  pin_memory=True,
                                  drop_last=True)
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.test_batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=True,
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Select the model architecture.
    if args.model == "AE":
        net = AutoEncoder()
    elif args.model == "scale_bmshj2018_factorized":
        net = Scale_FactorizedPrior(N=192, M=320, scale=args.scale)
    elif args.model == "bmshj2018_hyperprior":
        net = bmshj2018_hyperprior(quality=args.quality)
    elif args.model == "mbt2018_mean":
        net = mbt2018_mean(quality=args.quality)
    elif args.model == "mbt2018":
        net = mbt2018(quality=args.quality)
    elif args.model == "cheng2020_anchor":
        net = cheng2020_anchor(quality=args.quality)
    net = net.to(device)

    # Main optimizer for the network parameters, auxiliary optimizer for the
    # entropy bottleneck parameters.
    optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)
    aux_optimizer = optim.Adam(net.aux_parameters(), lr=args.aux_learning_rate)

    if args.loss == "avg":
        criterion = RateDistortionLoss(lmbda=args.lmbda)
    elif args.loss == "ID":
        criterion = InformationDistillationRateDistortionLoss(lmbda=args.lmbda)

    best_loss = 1e10
    if cmd_args.resume:
        # Restore training state (iteration counter, best loss, model and
        # optimizer states) from the checkpoint.
        logger.iteration = ckpt["iteration"]
        best_loss = ckpt["loss"]
        net.load_state_dict(ckpt["state_dict"])
        optimizer.load_state_dict(ckpt["optimizer"])
        aux_optimizer.load_state_dict(ckpt["aux_optimizer"])
        start_epoch = ckpt["epoch"]
    else:
        start_epoch = 0

    for epoch in range(start_epoch, args.epochs):
        train_one_epoch(net,
                        criterion,
                        train_dataloader,
                        optimizer,
                        aux_optimizer,
                        epoch,
                        args.clip_max_norm,
                        test_dataloader,
                        args,
                        writer=train_writer,
                        test_writer=test_writer,
                        logger=logger)
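
# A minimal entry-point sketch, assuming this training script is meant to be
# run directly from the command line; the `sys` import is an assumption if it
# is not already present at the top of the file.
if __name__ == "__main__":
    import sys

    main(sys.argv[1:])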