Exemplo n.º 1
0
def main(argv):
    """Train an ``AutoEncoder`` with a rate-distortion loss.

    Parses ``argv`` with ``parse_args``, builds train/test loaders from
    ``args.dataset`` (random crops for training, center crops for testing),
    then runs ``train_epoch`` followed by ``test_epoch`` for ``args.epochs``
    epochs. When ``args.save`` is set, ``save_checkpoint`` is called after
    every epoch with the model/optimizer state and a flag telling whether
    this epoch reached the lowest test loss seen so far.

    Args:
        argv: Command-line arguments forwarded to ``parse_args``.
    """
    args = parse_args(argv)

    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    # Resolve the device before building the loaders: pinned host memory only
    # helps host-to-GPU copies, so pin only when training on CUDA (PyTorch
    # warns when pin_memory=True and no accelerator is available).
    device = 'cuda' if args.cuda and torch.cuda.is_available() else 'cpu'
    pin_memory = device == 'cuda'

    train_transforms = transforms.Compose(
        [transforms.RandomCrop(args.patch_size),
         transforms.ToTensor()])

    test_transforms = transforms.Compose(
        [transforms.CenterCrop(args.patch_size),
         transforms.ToTensor()])

    train_dataset = ImageFolder(args.dataset,
                                split='train',
                                transform=train_transforms)
    test_dataset = ImageFolder(args.dataset,
                               split='test',
                               transform=test_transforms)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True,
                                  pin_memory=pin_memory)

    test_dataloader = DataLoader(test_dataset,
                                 batch_size=args.test_batch_size,
                                 num_workers=args.num_workers,
                                 shuffle=False,
                                 pin_memory=pin_memory)

    net = AutoEncoder()
    net = net.to(device)
    optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)
    # Separate optimizer for the auxiliary parameters the model exposes.
    aux_optimizer = optim.Adam(net.aux_parameters(), lr=args.aux_learning_rate)
    criterion = RateDistortionLoss(lmbda=args.lmbda)

    # +inf guarantees the first evaluated epoch is recorded as the best.
    best_loss = float('inf')
    for epoch in range(args.epochs):
        train_epoch(epoch, train_dataloader, net, criterion, optimizer,
                    aux_optimizer)

        loss = test_epoch(epoch, test_dataloader, net, criterion)

        is_best = loss < best_loss
        best_loss = min(loss, best_loss)
        if args.save:
            save_checkpoint(
                {
                    # Number of completed epochs.
                    'epoch': epoch + 1,
                    'state_dict': net.state_dict(),
                    'loss': loss,
                    'optimizer': optimizer.state_dict(),
                    'aux_optimizer': aux_optimizer.state_dict(),
                }, is_best)
Exemplo n.º 2
0
    def test_init_ok(self, tmpdir):
        """Empty ``train`` and ``test`` split folders yield zero-length datasets."""
        splits = ("train", "test")
        for split in splits:
            tmpdir.mkdir(split)

        datasets = [ImageFolder(tmpdir, split=split) for split in splits]

        for dataset in datasets:
            assert len(dataset) == 0
Exemplo n.º 3
0
    def test_count_ok(self, tmpdir):
        """Every file in the ``train`` folder counts toward the dataset length."""
        train_dir = tmpdir.mkdir("train")
        for name in ("img1.jpg", "img2.jpg", "img3.jpg"):
            (train_dir / name).write("")

        dataset = ImageFolder(tmpdir, split="train")

        assert len(dataset) == 3
Exemplo n.º 4
0
    def test_count_ok(self, tmpdir):
        """Three placeholder images under ``train`` give a dataset of length 3."""
        train_dir = tmpdir.mkdir('train')
        for idx in range(1, 4):
            train_dir.join(f'img{idx}.jpg').write('')

        assert len(ImageFolder(tmpdir, split='train')) == 3
Exemplo n.º 5
0
    def test_load_transforms(self, tmpdir):
        """A transform passed to ``ImageFolder`` is applied to the loaded item.

        A 128x128 center crop followed by ``ToTensor`` is expected to yield a
        ``(3, 128, 128)`` tensor.
        """
        tmpdir.mkdir("train")
        save_fake_image((tmpdir / "train" / "img0.jpeg").strpath)

        transform = transforms.Compose([
            transforms.CenterCrop((128, 128)),
            transforms.ToTensor(),
        ])
        train_dataset = ImageFolder(tmpdir, split="train", transform=transform)

        # Decode/transform the sample once and check both properties on it.
        sample = train_dataset[0]
        assert isinstance(sample, torch.Tensor)
        assert sample.size() == (3, 128, 128)
Exemplo n.º 6
0
def main(argv):
    """Train the model selected by ``args.model`` with a rate-distortion loss.

    Builds random-crop / center-crop train and test loaders from
    ``args.dataset`` and instantiates ``models[args.model](quality=3)``. The
    model is wrapped in ``CustomDataParallel`` when ``args.cuda`` is set and
    more than one GPU is visible. It is trained for ``args.epochs`` epochs,
    with a ``ReduceLROnPlateau`` scheduler stepped on the test loss.

    If ``args.checkpoint`` is given, the model, optimizer, auxiliary
    optimizer and scheduler states are restored from it, and training resumes
    after the saved epoch. With ``args.save``, a checkpoint is written every
    epoch.

    Args:
        argv: Command-line arguments forwarded to ``parse_args``.
    """
    args = parse_args(argv)

    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    train_transforms = transforms.Compose(
        [transforms.RandomCrop(args.patch_size), transforms.ToTensor()]
    )

    test_transforms = transforms.Compose(
        [transforms.CenterCrop(args.patch_size), transforms.ToTensor()]
    )

    train_dataset = ImageFolder(args.dataset, split="train", transform=train_transforms)
    test_dataset = ImageFolder(args.dataset, split="test", transform=test_transforms)

    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"

    # Pinned host memory only speeds up host-to-GPU copies.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        pin_memory=(device == "cuda"),
    )

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.test_batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=(device == "cuda"),
    )

    net = models[args.model](quality=3)
    net = net.to(device)

    if args.cuda and torch.cuda.device_count() > 1:
        net = CustomDataParallel(net)

    optimizer, aux_optimizer = configure_optimizers(net, args)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
    criterion = RateDistortionLoss(lmbda=args.lmbda)

    last_epoch = 0
    best_loss = float("inf")
    if args.checkpoint:  # load from previous checkpoint
        print("Loading", args.checkpoint)
        checkpoint = torch.load(args.checkpoint, map_location=device)
        # Checkpoints store the index of the epoch that produced them.
        last_epoch = checkpoint["epoch"] + 1
        net.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        aux_optimizer.load_state_dict(checkpoint["aux_optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        # Seed the running best with the checkpoint's recorded test loss.
        # Otherwise the first resumed epoch is always flagged "best" and
        # overwrites the best checkpoint, even if it is worse.
        best_loss = checkpoint.get("loss", best_loss)

    for epoch in range(last_epoch, args.epochs):
        print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
        train_one_epoch(
            net,
            criterion,
            train_dataloader,
            optimizer,
            aux_optimizer,
            epoch,
            args.clip_max_norm,
        )
        loss = test_epoch(epoch, test_dataloader, net, criterion)
        lr_scheduler.step(loss)

        is_best = loss < best_loss
        best_loss = min(loss, best_loss)

        if args.save:
            save_checkpoint(
                {
                    "epoch": epoch,
                    "state_dict": net.state_dict(),
                    "loss": loss,
                    "optimizer": optimizer.state_dict(),
                    "aux_optimizer": aux_optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                },
                is_best,
            )
Exemplo n.º 7
0
def main(argv):
    """Run one evaluation pass of ``HomographyModel`` + ``HSIC`` + ``Independent_EN``.

    Steps:

    1. Parse ``argv`` with ``parse_args`` and build the train/test datasets.
    2. Select the device. CUDA device ``args.cuda`` is used when CUDA is
       available and ``args.cuda != -1``; otherwise the CPU is used.
    3. Restore each network's ``'state_dict'`` from ``homo_best.pth.tar``,
       ``checkpoint_best_loss.pth.tar`` and
       ``second_checkpoint_best_loss.pth.tar`` in the working directory, when
       those files exist.
    4. Call ``test_epoch`` once on the test split.

    Args:
        argv: Command-line arguments forwarded to ``parse_args``.
    """
    args = parse_args(argv)

    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    # No crop transforms here; ``patch_size`` is passed to ImageFolder instead
    # (presumably it crops internally -- confirm in ImageFolder).
    train_transforms = transforms.Compose([transforms.ToTensor()])

    test_transforms = transforms.Compose([transforms.ToTensor()])

    train_dataset = ImageFolder(args.dataset,
                                split='train',
                                patch_size=args.patch_size,
                                transform=train_transforms)
    test_dataset = ImageFolder(args.dataset,
                               split='test',
                               patch_size=args.patch_size,
                               transform=test_transforms)

    # NOTE(review): the training loader is never iterated in this
    # evaluation-only entry point.
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True,
                                  pin_memory=False)

    test_dataloader = DataLoader(test_dataset,
                                 batch_size=args.test_batch_size,
                                 num_workers=args.num_workers,
                                 shuffle=False,
                                 pin_memory=False)

    device = 'cuda' if (torch.cuda.is_available()
                        and args.cuda != -1) else 'cpu'
    print(device)
    if device == 'cuda':
        torch.cuda.set_device(args.cuda)
        # Query the active GPU only on CUDA runs: torch.cuda.current_device()
        # raises on CPU-only machines and CPU-only PyTorch builds.
        print('temp gpu device number:')
        print(torch.cuda.current_device())

    def load_if_present(module, path, loaded_msg, missing_msg):
        """Load the ``'state_dict'`` entry of checkpoint ``path`` into ``module`` if the file exists."""
        if not os.path.exists(path):
            print(missing_msg)
            return
        # Map every storage to CPU; modules are moved to ``device`` afterwards.
        checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
        module.load_state_dict(checkpoint['state_dict'])
        print(loaded_msg)

    # Homography model.
    nethomo = HomographyModel()
    net = HSIC(N=128, M=192, K=5)
    # Cross-quality enhancement network. A single combined GMM_together()
    # network could be used instead; calling the networks separately makes it
    # easier to test the ablation effect.
    net2 = Independent_EN()

    # Load the latest saved weights (when present) before evaluating.
    load_if_present(nethomo, 'homo_best.pth.tar',
                    "load h**o model ok", "h**o from none")
    load_if_present(net, 'checkpoint_best_loss.pth.tar',
                    "load model ok", "train from none")
    load_if_present(net2, 'second_checkpoint_best_loss.pth.tar',
                    "2load model ok", "2train from none")

    nethomo = nethomo.to(device)
    net = net.to(device)
    net2 = net2.to(device)

    print("lambda:", args.lmbda)
    criterion = RateDistortionLoss(lmbda=args.lmbda)

    # Single evaluation pass, reported as epoch 0.
    test_epoch(0, test_dataloader, nethomo, net, net2, criterion)
Exemplo n.º 8
0
def main(argv):
    """Run one evaluation pass of the ``HSIC`` + ``Independent_EN`` networks.

    Steps:

    1. Parse ``argv`` with ``parse_args`` and build the train/test datasets
       (with ``need_file_name=True``).
    2. Select the device. CUDA device ``args.cuda`` is used when CUDA is
       available and ``args.cuda != -1``; otherwise the CPU is used.
    3. Restore the networks' ``'state_dict'`` from
       ``checkpoint_best_loss.pth.tar`` and
       ``second_checkpoint_best_loss.pth.tar`` in the working directory, when
       those files exist.
    4. Call ``test_epoch`` once on the test split.

    Args:
        argv: Command-line arguments forwarded to ``parse_args``.
    """
    args = parse_args(argv)

    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    # No crop transforms here; ``patch_size`` is passed to ImageFolder instead
    # (presumably it crops internally -- confirm in ImageFolder).
    train_transforms = transforms.Compose(
        [transforms.ToTensor()])

    test_transforms = transforms.Compose(
        [transforms.ToTensor()])

    train_dataset = ImageFolder(args.dataset,
                                split='train',
                                patch_size=args.patch_size,
                                transform=train_transforms,
                                need_file_name=True)
    test_dataset = ImageFolder(args.dataset,
                               split='test',
                               patch_size=args.patch_size,
                               transform=test_transforms,
                               need_file_name=True)

    # NOTE(review): the training loader is never iterated in this
    # evaluation-only entry point.
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True,
                                  pin_memory=False)

    test_dataloader = DataLoader(test_dataset,
                                 batch_size=args.test_batch_size,
                                 num_workers=args.num_workers,
                                 shuffle=False,
                                 pin_memory=False)

    device = 'cuda' if (torch.cuda.is_available() and args.cuda != -1) else 'cpu'
    print(device)
    if device == 'cuda':
        torch.cuda.set_device(args.cuda)
        # Query the active GPU only on CUDA runs: torch.cuda.current_device()
        # raises on CPU-only machines and CPU-only PyTorch builds.
        print('temp gpu device number:')
        print(torch.cuda.current_device())

    def load_if_present(module, path, loaded_msg, missing_msg):
        """Load the ``'state_dict'`` entry of checkpoint ``path`` into ``module`` if the file exists."""
        if not os.path.exists(path):
            print(missing_msg)
            return
        # Map every storage to CPU; modules are moved to ``device`` afterwards.
        checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
        module.load_state_dict(checkpoint['state_dict'])
        print(loaded_msg)

    net = HSIC(N=128, M=192, K=5)
    net2 = Independent_EN()  # Independent enhancement network.

    # Load the latest saved weights (when present) before evaluating.
    load_if_present(net, 'checkpoint_best_loss.pth.tar',
                    "load model ok", "train from none")
    load_if_present(net2, 'second_checkpoint_best_loss.pth.tar',
                    "2load model ok", "2train from none")

    net = net.to(device)
    net2 = net2.to(device)
    print("lambda:", args.lmbda)
    criterion = RateDistortionLoss(lmbda=args.lmbda)

    # Single evaluation pass on the validation/test split, reported as epoch 0.
    test_epoch(0, test_dataloader, net, net2, criterion)
Exemplo n.º 9
0
def main(argv):
    """Train an ``AutoEncoder`` with a rate-distortion loss and gradient clipping.

    Parses ``argv`` with ``parse_args``, builds train/test loaders from
    ``args.dataset`` (random crops for training, center crops for testing),
    then runs ``train_one_epoch`` (with ``args.clip_max_norm``) followed by
    ``test_epoch`` for ``args.epochs`` epochs. When ``args.save`` is set,
    ``save_checkpoint`` is called after every epoch with the model/optimizer
    state and a flag telling whether this epoch reached the lowest test loss
    so far.

    Args:
        argv: Command-line arguments forwarded to ``parse_args``.
    """
    args = parse_args(argv)

    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    # Resolve the device before building the loaders: pinned host memory only
    # helps host-to-GPU copies, so pin only when training on CUDA (PyTorch
    # warns when pin_memory=True and no accelerator is available).
    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
    pin_memory = device == "cuda"

    train_transforms = transforms.Compose(
        [transforms.RandomCrop(args.patch_size),
         transforms.ToTensor()])

    test_transforms = transforms.Compose(
        [transforms.CenterCrop(args.patch_size),
         transforms.ToTensor()])

    train_dataset = ImageFolder(args.dataset,
                                split="train",
                                transform=train_transforms)
    test_dataset = ImageFolder(args.dataset,
                               split="test",
                               transform=test_transforms)

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        pin_memory=pin_memory,
    )

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.test_batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=pin_memory,
    )

    net = AutoEncoder()
    net = net.to(device)
    optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)
    # Separate optimizer for the auxiliary parameters the model exposes.
    aux_optimizer = optim.Adam(net.aux_parameters(), lr=args.aux_learning_rate)
    criterion = RateDistortionLoss(lmbda=args.lmbda)

    # +inf guarantees the first evaluated epoch is recorded as the best.
    best_loss = float("inf")
    for epoch in range(args.epochs):
        train_one_epoch(
            net,
            criterion,
            train_dataloader,
            optimizer,
            aux_optimizer,
            epoch,
            args.clip_max_norm,
        )

        loss = test_epoch(epoch, test_dataloader, net, criterion)

        is_best = loss < best_loss
        best_loss = min(loss, best_loss)
        if args.save:
            save_checkpoint(
                {
                    # Number of completed epochs.
                    "epoch": epoch + 1,
                    "state_dict": net.state_dict(),
                    "loss": loss,
                    "optimizer": optimizer.state_dict(),
                    "aux_optimizer": aux_optimizer.state_dict(),
                },
                is_best,
            )
Exemplo n.º 10
0
def main(argv):
    """Train, or resume training of, a learned image-compression model.

    On a fresh run the parsed command-line arguments are used directly. When
    ``cmd_args.resume`` is ``True``, the checkpoint at ``cmd_args.ckpt`` is
    loaded and its stored ``args`` replace the command-line ones. The
    model/optimizer/auxiliary-optimizer states, the logger iteration and the
    start epoch are then restored from it. Per-epoch training (including
    periodic testing and TensorBoard logging) is delegated to
    ``train_one_epoch``.

    Args:
        argv: Command-line arguments forwarded to ``parse_args``.

    Raises:
        ValueError: If ``args.model`` or ``args.loss`` is not a supported name.
    """
    cmd_args = parse_args(argv)

    # parse_args is not visible here, so the strict ``== True`` check of the
    # original is kept rather than relying on truthiness.
    resuming = cmd_args.resume == True  # noqa: E712

    if resuming:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        ckpt = load_checkpoint(cmd_args.ckpt, device)
        # Continue with the hyper-parameters stored in the checkpoint.
        args = ckpt["args"]
        save_dirs = prepare_save(model=args.model,
                                 dataset=args.dataname,
                                 quality=args.quality,
                                 ckpt=cmd_args.ckpt)
    else:
        args = cmd_args
        save_dirs = prepare_save(model=args.model,
                                 dataset=args.dataname,
                                 quality=args.quality,
                                 args=args)

    # ``test_inteval`` (sic) is kept as-is; presumably it matches Logger's
    # parameter name -- confirm before renaming.
    logger = Logger(log_interval=100,
                    test_inteval=100,
                    save_dirs=save_dirs,
                    max_iter=args.iterations)
    train_writer = SummaryWriter(
        os.path.join(save_dirs["tensorboard_runs"], "train"))
    test_writer = SummaryWriter(
        os.path.join(save_dirs["tensorboard_runs"], "test"))

    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    train_transforms = transforms.Compose(
        [transforms.RandomCrop(args.patch_size),
         transforms.ToTensor()])

    # Test images are not cropped.
    test_transforms = transforms.Compose([transforms.ToTensor()])

    train_dataset = ImageFolder(args.dataset,
                                split="train",
                                transform=train_transforms)
    test_dataset = ImageFolder(args.dataset,
                               split="test",
                               transform=test_transforms)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Pinned host memory only speeds up host-to-GPU copies.
    pin_memory = device == "cuda"

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True,
                                  pin_memory=pin_memory,
                                  drop_last=True)

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.test_batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=pin_memory,
    )

    # Lazy factories: only the selected model is built, and only it reads its
    # model-specific arguments (e.g. ``args.scale``).
    model_factories = {
        "AE": AutoEncoder,
        "scale_bmshj2018_factorized":
            lambda: Scale_FactorizedPrior(N=192, M=320, scale=args.scale),
        "bmshj2018_hyperprior":
            lambda: bmshj2018_hyperprior(quality=args.quality),
        "mbt2018_mean": lambda: mbt2018_mean(quality=args.quality),
        "mbt2018": lambda: mbt2018(quality=args.quality),
        "cheng2020_anchor": lambda: cheng2020_anchor(quality=args.quality),
    }
    if args.model not in model_factories:
        # Previously an unknown name left ``net`` unbound (UnboundLocalError).
        raise ValueError(f"Unsupported model: {args.model!r}")
    net = model_factories[args.model]()

    net = net.to(device)
    optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)
    # Separate optimizer for the auxiliary parameters the model exposes.
    aux_optimizer = optim.Adam(net.aux_parameters(), lr=args.aux_learning_rate)

    criterion_factories = {
        "avg": RateDistortionLoss,
        "ID": InformationDistillationRateDistortionLoss,
    }
    if args.loss not in criterion_factories:
        raise ValueError(f"Unsupported loss: {args.loss!r}")
    criterion = criterion_factories[args.loss](lmbda=args.lmbda)

    # NOTE(review): best_loss is not passed to train_one_epoch, so it
    # currently has no effect on training.
    best_loss = 1e10
    if resuming:
        logger.iteration = ckpt["iteration"]
        best_loss = ckpt["loss"]
        net.load_state_dict(ckpt["state_dict"])
        optimizer.load_state_dict(ckpt["optimizer"])
        aux_optimizer.load_state_dict(ckpt["aux_optimizer"])
        start_epoch = ckpt["epoch"]
    else:
        start_epoch = 0

    try:
        for epoch in range(start_epoch, args.epochs):
            train_one_epoch(net,
                            criterion,
                            train_dataloader,
                            optimizer,
                            aux_optimizer,
                            epoch,
                            args.clip_max_norm,
                            test_dataloader,
                            args,
                            writer=train_writer,
                            test_writer=test_writer,
                            logger=logger)
    finally:
        # Flush pending TensorBoard events and release the event files,
        # even if training is interrupted.
        train_writer.close()
        test_writer.close()
Exemplo n.º 11
0
    def test_load(self, tmpdir):
        """Without a transform, indexing the dataset yields a PIL image."""
        train_dir = tmpdir.mkdir("train")
        save_fake_image(train_dir.join("img0.jpeg").strpath)

        sample = ImageFolder(tmpdir, split="train")[0]
        assert isinstance(sample, Image.Image)
Exemplo n.º 12
0
 def test_invalid_dir(self, tmpdir):
     """Building an ``ImageFolder`` on an empty directory raises ``RuntimeError``."""
     pytest.raises(RuntimeError, ImageFolder, tmpdir)
Exemplo n.º 13
0
    def test_load(self, tmpdir):
        """Items of an untransformed ``train`` split are PIL images."""
        image_path = tmpdir.mkdir('train').join('img0.jpeg')
        save_fake_image(image_path.strpath)

        dataset = ImageFolder(tmpdir, split='train')
        first = dataset[0]
        assert isinstance(first, Image.Image)