Esempio n. 1
0
def main(args):
    """Run single-image SRNTT inference and report PSNR/SSIM against ground truth.

    Loads the low-resolution input and reference images with ``load_image``,
    extracts VGG19 features at ``TARGET_LAYERS``, swaps reference textures
    into the input feature maps, runs the SRNTT generator restored from
    ``args.weight``, prints PSNR/SSIM against the high-resolution image and
    writes the clamped result to ``./out.png``.

    Args:
        args: namespace with ``input`` and ``ref`` (forwarded to
            ``load_image``), ``weight`` (path to the SRNTT state_dict) and
            ``use_weights`` (forwarded to ``SRNTT``).
    """
    imgs = load_image(args.input, args.ref)

    vgg = VGG(model_type='vgg19').to(device)
    swapper = Swapper().to(device)

    # Pure inference: without no_grad every VGG activation would be retained
    # for a backward pass that never happens.
    with torch.no_grad():
        map_in = vgg(imgs['bic'].to(device), TARGET_LAYERS)
        map_ref = vgg(imgs['ref'].to(device), TARGET_LAYERS)
        map_ref_blur = vgg(imgs['ref_blur'].to(device), TARGET_LAYERS)

    with torch.no_grad(), timer('Feature swapping'):
        maps, weights, _ = swapper(map_in, map_ref, map_ref_blur)

    model = SRNTT(use_weights=args.use_weights).to(device)
    # map_location lets a checkpoint saved on GPU be restored on any device.
    model.load_state_dict(torch.load(args.weight, map_location=device))
    # Inference mode, matching how the training script validates the generator.
    model.eval()

    img_hr = imgs['hr'].to(device)
    img_lr = imgs['lr'].to(device)
    # Add a leading batch dimension of 1 to every swapped feature map.
    maps = {
        k: torch.tensor(v).unsqueeze(0).to(device)
        for k, v in maps.items()
    }
    # Reshape the weight map to (1, 1, *weights.shape).
    weights = torch.tensor(weights).reshape(1, 1, *weights.shape).to(device)

    with torch.no_grad(), timer('Inference'):
        _, img_sr = model(img_lr, maps, weights)

    # Metrics and the saved image are computed on [0, 1]-clamped tensors.
    sr = img_sr.clamp(0, 1)
    hr = img_hr.clamp(0, 1)
    psnr = PSNR()(sr, hr).item()
    ssim = SSIM()(sr, hr).item()
    print(f'[Result] PSNR:{psnr:.2f}, SSIM:{ssim:.4f}')

    save_image(sr, './out.png')
Esempio n. 2
0
    def __init__(self, args):
        """Build the model, losses/metrics, optimizer and loaders for one DDP rank.

        Args:
            args: run configuration. Fields read here: ``logger``, ``rank``,
                ``gpus``, ``split``, ``optimizer``, ``learning_rate``,
                ``batch_size``, ``num_workers``, ``resume``,
                ``checkepoch_range``, ``path``, ``interval``, ``vid_length``
                and ``dataset``.

        Raises:
            ValueError: if ``args.split == 'train'`` and ``args.optimizer`` is
                not one of ``'adamax'``, ``'adam'`` or ``'sgd'``.
        """
        self.args = args

        args.logger.info('Initializing trainer')
        self.model = get_model(args)
        params_cnt = count_parameters(self.model)
        args.logger.info("params "+str(params_cnt))
        # Bind this process to its GPU before moving the model and wrapping it in DDP.
        torch.cuda.set_device(args.rank)
        self.model.cuda(args.rank)
        self.model = torch.nn.parallel.DistributedDataParallel(self.model,
                device_ids=[args.rank])
        train_dataset, val_dataset = get_dataset(args)
        # args.batch_size is split across args.gpus ranks (integer division).
        per_rank_batch = args.batch_size // args.gpus

        if args.split == 'train':
            # train loss
            self.RGBLoss = RGBLoss(args, sharp=False)
            self.SegLoss = nn.CrossEntropyLoss()
            self.RGBLoss.cuda(args.rank)
            self.SegLoss.cuda(args.rank)

            if args.optimizer == "adamax":
                self.optimizer = torch.optim.Adamax(self.model.parameters(), lr=args.learning_rate)
            elif args.optimizer == "adam":
                self.optimizer = torch.optim.Adam(self.model.parameters(), lr=args.learning_rate)
            elif args.optimizer == "sgd":
                self.optimizer = torch.optim.SGD(self.model.parameters(), lr=args.learning_rate, momentum=0.9)
            else:
                # Previously this fell through and left self.optimizer unset,
                # failing much later with an AttributeError.
                raise ValueError(
                    f"unsupported optimizer {args.optimizer!r}; "
                    "expected 'adamax', 'adam' or 'sgd'")

            # DistributedSampler shards the data per rank; DataLoader's own
            # shuffle must stay False when a sampler is supplied.
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
            self.train_loader = torch.utils.data.DataLoader(
                train_dataset, batch_size=per_rank_batch, shuffle=False,
                num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)

        else:
            # val criteria
            self.L1Loss = nn.L1Loss().cuda(args.rank)
            self.PSNRLoss = PSNR().cuda(args.rank)
            self.SSIMLoss = SSIM().cuda(args.rank)
            self.IoULoss = IoU().cuda(args.rank)
            self.VGGCosLoss = VGGCosineLoss().cuda(args.rank)

            val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
            self.val_loader = torch.utils.data.DataLoader(
                val_dataset, batch_size=per_rank_batch, shuffle=False,
                num_workers=args.num_workers, pin_memory=True, sampler=val_sampler)

        torch.backends.cudnn.benchmark = True
        self.global_step = 0
        self.epoch = 1
        # Evaluation without an explicit epoch range restores a checkpoint too.
        if args.resume or (args.split != 'train' and not args.checkepoch_range):
            self.load_checkpoint()

        # Only rank 0 writes TensorBoard logs.
        if args.rank == 0:
            writer_name = (f'{args.path}/{args.split}_int_{int(args.interval)}'
                           f'_len_{args.vid_length}_{args.dataset}_logs')
            self.writer = SummaryWriter(writer_name)

        self.stand_heat_map = self.create_stand_heatmap()
    def __init__(self, args):
        """Prepare a generator/discriminator trainer on the GPU ``args.rank``.

        Builds the model with ``get_model`` and wraps it in
        ``DistributedDataParallel``. In training mode (``args.val`` falsy) it
        creates the RGB, segmentation, GAN-map and GAN-feature losses, the
        ``optG``/``optD`` optimizers and ``self.train_loader``. In validation
        mode it creates the L1/PSNR/SSIM/IoU/VGG-cosine metrics and
        ``self.val_loader``. Rank 0 also opens a ``SummaryWriter`` whose
        directory depends on ``args.val``, ``args.load_G`` and
        ``args.interval``.

        When ``args.load_G`` is set, the parameters of ``netG`` are frozen and
        state is restored with ``load_temp()`` instead of
        ``load_checkpoint()``.

        Note: ``self.optG`` is only created when ``args.optG == 'adamax'`` and
        ``self.optD`` only for ``'sgd'``/``'adamax'``. Other values leave the
        attribute unset.
        """
        self.args = args
        rank = args.rank

        args.logger.info('Initializing trainer')

        self.model = get_model(args)
        if args.load_G:
            # Frozen generator: its parameters no longer require gradients.
            for param in self.model.netG.parameters():
                param.requires_grad = False
        torch.cuda.set_device(rank)
        self.model.cuda(rank)
        self.model = torch.nn.parallel.DistributedDataParallel(
            self.model, device_ids=[rank])

        train_dataset, val_dataset = get_dataset(args)

        def make_loader(dataset):
            # Each rank reads its own shard through DistributedSampler, so the
            # DataLoader itself must not shuffle.
            sampler = torch.utils.data.distributed.DistributedSampler(dataset)
            return torch.utils.data.DataLoader(
                dataset,
                sampler=sampler,
                batch_size=args.batch_size // args.gpus,
                shuffle=False,
                num_workers=args.num_workers,
                pin_memory=True)

        if args.val:
            # Evaluation metrics.
            self.L1Loss = nn.L1Loss().cuda(rank)
            self.PSNRLoss = PSNR().cuda(rank)
            self.SSIMLoss = SSIM().cuda(rank)
            self.IoULoss = IoU().cuda(rank)
            self.VGGCosLoss = VGGCosineLoss().cuda(rank)
            self.val_loader = make_loader(val_dataset)
        else:
            # Training losses.
            self.RGBLoss = RGBLoss(args).cuda(rank)
            self.SegLoss = nn.CrossEntropyLoss().cuda(rank)
            self.GANLoss = GANMapLoss().cuda(rank)
            self.GANFeatLoss = nn.L1Loss().cuda(rank)

            # Unwrap DDP to reach the netG / netD submodules.
            inner = self.model.module
            if args.optG == 'adamax':
                self.optG = torch.optim.Adamax(
                    inner.netG.parameters(), lr=args.lr_G)
            if args.optD == 'sgd':
                self.optD = torch.optim.SGD(
                    inner.netD.parameters(), lr=args.lr_D, momentum=0.9)
            elif args.optD == 'adamax':
                self.optD = torch.optim.Adamax(
                    inner.netD.parameters(), lr=args.lr_D)

            self.train_loader = make_loader(train_dataset)

        torch.backends.cudnn.benchmark = True
        self.global_step = 0
        self.epoch = 1
        if args.load_G:
            self.load_temp()
        elif args.resume or (args.val and not args.checkepoch_range):
            self.load_checkpoint()

        # Only rank 0 writes TensorBoard logs.
        if rank == 0:
            if args.val:
                # Keyed by (load_G, interval == 2).
                val_dirs = {
                    (False, True): '/val_logs',
                    (False, False): '/val_int_1_logs',
                    (True, True): '/dis_val_logs',
                    (True, False): '/dis_val_int_1_logs',
                }
                log_dir = args.path + val_dirs[(bool(args.load_G), bool(args.interval == 2))]
            elif args.load_G:
                log_dir = args.path + '/dis_{}_logs'.format(args.session)
            else:
                log_dir = args.path + '/logs'
            self.writer = SummaryWriter(log_dir)
        self.heatmap = self.create_stand_heatmap()
Esempio n. 4
0
def _batch_to_device(batch):
    """Unpack one train/val loader batch and move its tensors to ``device``.

    Args:
        batch: dict with ``'img_hr'``, ``'img_lr'`` and ``'weights'`` tensors
            and a ``'maps'`` dict of tensors.

    Returns:
        ``(img_hr, img_lr, maps, weights)`` with every tensor on ``device``.
    """
    img_hr = batch['img_hr'].to(device)
    img_lr = batch['img_lr'].to(device)
    maps = {k: v.to(device) for k, v in batch['maps'].items()}
    weights = batch['weights'].to(device)
    return img_hr, img_lr, maps, weights


def main(args):
    """Train SRNTT: optional L1 pre-training, then adversarial training with validation.

    Splits the ``map/*.npz`` files under ``args.dataroot`` 90/10 into train
    and validation sets. If ``args.netG_pre`` is None, the generator is first
    pre-trained with the L1 reconstruction loss for ``args.n_epochs_init``
    epochs. Otherwise generator weights (and, if ``args.netD_pre`` is set,
    discriminator weights) are loaded instead. Training then runs for
    ``args.n_epochs`` epochs with reconstruction, perceptual, adversarial
    (with gradient penalty) and texture losses, validating PSNR/SSIM after
    every epoch. TensorBoard logs and per-epoch checkpoints go to
    ``writer.log_dir`` (``runs/{args.pid}`` when ``args.pid`` is set).

    Raises:
        ValueError: if ``args.netD`` is not ``'image'`` or ``'patch'``.
    """
    init_seeds(seed=args.seed)

    # split data
    files = [f.stem for f in Path(args.dataroot).glob('map/*.npz')]
    train_files, val_files = train_test_split(files, test_size=0.1)

    # define dataloaders
    train_set = ReferenceDataset(train_files, args.dataroot)
    val_set = ReferenceDatasetEval(val_files, args.dataroot)
    train_loader = DataLoader(train_set,
                              args.batch_size,
                              shuffle=True,
                              num_workers=4)
    # drop_last=True: a val split smaller than one batch yields zero batches.
    val_loader = DataLoader(val_set, args.batch_size, drop_last=True)

    # define networks
    netG = SRNTT(args.ngf, args.n_blocks, args.use_weights).to(device)
    netG.content_extractor.load_state_dict(torch.load(args.init_weight))
    if args.netD == 'image':
        netD = ImageDiscriminator(args.ndf).to(device)
    elif args.netD == 'patch':
        netD = Discriminator(args.ndf).to(device)
    else:
        # Previously fell through and failed later with NameError on netD.
        raise ValueError(
            f"unsupported netD {args.netD!r}; expected 'image' or 'patch'")

    # define criteria
    criterion_rec = nn.L1Loss().to(device)
    criterion_per = PerceptualLoss().to(device)
    criterion_adv = AdversarialLoss().to(device)
    criterion_tex = TextureLoss(args.use_weights).to(device)

    # metrics
    criterion_psnr = PSNR(max_val=1., mode='Y')
    criterion_ssim = SSIM(window_size=11)

    # define optimizers
    optimizer_G = optim.Adam(netG.parameters(), args.lr)
    optimizer_D = optim.Adam(netD.parameters(), args.lr)

    # Schedulers are stepped once per iteration: LR drops 10x halfway through.
    scheduler_G = StepLR(optimizer_G,
                         int(args.n_epochs * len(train_loader) / 2), 0.1)
    scheduler_D = StepLR(optimizer_D,
                         int(args.n_epochs * len(train_loader) / 2), 0.1)

    # for tensorboard
    writer = SummaryWriter(log_dir=f'runs/{args.pid}' if args.pid else None)

    if args.netG_pre is None:
        # --- pretrain the generator with the reconstruction loss only ---
        step = 0
        for epoch in range(1, args.n_epochs_init + 1):
            for i, batch in enumerate(train_loader, 1):
                img_hr, img_lr, maps, weights = _batch_to_device(batch)

                _, img_sr = netG(img_lr, maps, weights)
                # train G
                optimizer_G.zero_grad()
                g_loss = criterion_rec(img_sr, img_hr)
                g_loss.backward()
                optimizer_G.step()
                # logging
                writer.add_scalar('pre/g_loss', g_loss.item(), step)
                if step % args.display_freq == 0:
                    writer.add_images('pre/img_lr', img_lr.clamp(0, 1), step)
                    writer.add_images('pre/img_hr', img_hr.clamp(0, 1), step)
                    writer.add_images('pre/img_sr', img_sr.clamp(0, 1), step)

                log_txt = [
                    f'[Pre][Epoch{epoch}][{i}/{len(train_loader)}]',
                    f'G Loss: {g_loss.item()}'
                ]
                print(' '.join(log_txt))

                step += 1

                if args.debug:
                    break

            out_path = Path(writer.log_dir) / f'netG_pre{epoch:03}.pth'
            torch.save(netG.state_dict(), out_path)

    else:  # omit pre-training
        netG.load_state_dict(torch.load(args.netG_pre))
        if args.netD_pre:
            netD.load_state_dict(torch.load(args.netD_pre))

    # --- train with all losses ---
    step = 0
    for epoch in range(1, args.n_epochs + 1):
        # training loop
        netG.train()
        netD.train()
        for i, batch in enumerate(train_loader, 1):
            img_hr, img_lr, maps, weights = _batch_to_device(batch)

            _, img_sr = netG(img_lr, maps, weights)
            # train D: only discriminator parameters receive gradients.
            optimizer_D.zero_grad()
            for p in netD.parameters():
                p.requires_grad = True
            for p in netG.parameters():
                p.requires_grad = False

            # compute WGAN loss
            d_out_real = netD(img_hr)
            d_loss_real = criterion_adv(d_out_real, True)
            d_out_fake = netD(img_sr.detach())
            d_loss_fake = criterion_adv(d_out_fake, False)
            d_loss = d_loss_real + d_loss_fake

            # gradient penalty
            gradient_penalty = compute_gp(netD, img_hr.data, img_sr.data)
            d_loss += 10 * gradient_penalty

            d_loss.backward()
            optimizer_D.step()
            # train G: only generator parameters receive gradients.
            optimizer_G.zero_grad()
            for p in netD.parameters():
                p.requires_grad = False
            for p in netG.parameters():
                p.requires_grad = True

            # compute all losses
            loss_rec = criterion_rec(img_sr, img_hr)
            loss_per = criterion_per(img_sr, img_hr)
            loss_adv = criterion_adv(netD(img_sr), True)
            loss_tex = criterion_tex(img_sr, maps, weights)

            # optimize with the weighted sum of the generator losses
            g_loss = (loss_rec * args.lambda_rec + loss_per * args.lambda_per +
                      loss_adv * args.lambda_adv + loss_tex * args.lambda_tex)
            g_loss.backward()
            optimizer_G.step()
            # logging
            writer.add_scalar('train/g_loss', g_loss.item(), step)
            writer.add_scalar('train/loss_rec', loss_rec.item(), step)
            writer.add_scalar('train/loss_per', loss_per.item(), step)
            writer.add_scalar('train/loss_tex', loss_tex.item(), step)
            writer.add_scalar('train/loss_adv', loss_adv.item(), step)
            writer.add_scalar('train/d_loss', d_loss.item(), step)
            writer.add_scalar('train/d_real', d_loss_real.item(), step)
            writer.add_scalar('train/d_fake', d_loss_fake.item(), step)
            if step % args.display_freq == 0:
                writer.add_images('train/img_lr', img_lr, step)
                writer.add_images('train/img_hr', img_hr, step)
                writer.add_images('train/img_sr', img_sr.clamp(0, 1), step)

            log_txt = [
                f'[Train][Epoch{epoch}][{i}/{len(train_loader)}]',
                f'G Loss: {g_loss.item()}, D Loss: {d_loss.item()}'
            ]
            print(' '.join(log_txt))

            scheduler_G.step()
            scheduler_D.step()

            step += 1

            if args.debug:
                break

        # validation loop
        netG.eval()
        netD.eval()
        val_psnr, val_ssim = 0, 0
        n_val_batches = 0
        # The context manager closes the progress bar even on an early break.
        with tqdm(total=len(val_loader)) as tbar:
            for batch in val_loader:
                img_hr, img_lr, maps, weights = _batch_to_device(batch)

                with torch.no_grad():
                    _, img_sr = netG(img_lr, maps, weights)
                    img_sr = img_sr.clamp(0, 1)
                    val_psnr += criterion_psnr(img_hr, img_sr).item()
                    val_ssim += criterion_ssim(img_hr, img_sr).item()

                n_val_batches += 1
                tbar.update(1)

                if args.debug:
                    break

        # Average over batches actually processed; this also avoids a
        # ZeroDivisionError when drop_last leaves the val loader empty.
        if n_val_batches:
            val_psnr /= n_val_batches
            val_ssim /= n_val_batches
        else:
            print('[Val] warning: validation loader produced no batches')

        writer.add_scalar('val/psnr', val_psnr, epoch)
        writer.add_scalar('val/ssim', val_ssim, epoch)

        print(f'[Val][Epoch{epoch}] PSNR:{val_psnr:.4f}, SSIM:{val_ssim:.4f}')

        netG_path = Path(writer.log_dir) / f'netG_{epoch:03}.pth'
        netD_path = Path(writer.log_dir) / f'netD_{epoch:03}.pth'
        torch.save(netG.state_dict(), netG_path)
        torch.save(netD.state_dict(), netD_path)
Esempio n. 5
0
if __name__ == "__main__":
    # Evaluate a trained SRNTT generator on the CUFED5 dataset.
    # NOTE(review): this snippet appears truncated. The per-reference loop
    # below computes `map_ref`, but the code that consumes it (and fills
    # `table` / advances `tbar`) is not visible here.
    parser = argparse.ArgumentParser()
    # Path to the SRNTT state_dict to evaluate.
    parser.add_argument('--weight', '-w', type=str, required=True)
    # Forwarded to SRNTT(use_weights=...); presumably must match how the
    # checkpoint was trained. Confirm.
    parser.add_argument('--use_weights', action='store_true')
    args = parser.parse_args()

    # NOTE(review): dataset root is a hard-coded, machine-specific path.
    dataset = CUFED5Dataset('/home/ubuntu/srntt-pytorch/data/CUFED5')
    # Default DataLoader batch_size is 1; the `[0]` indexing below relies on that.
    dataloader = DataLoader(dataset)

    vgg = VGG(model_type='vgg19').to(device)
    swapper = Swapper().to(device)
    model = SRNTT(use_weights=args.use_weights).to(device)
    model.load_state_dict(torch.load(args.weight))

    criterion_psnr = PSNR()

    table = []
    tbar = tqdm(total=len(dataloader))
    for batch_idx, batch in enumerate(dataloader):
        with torch.no_grad():
            img_hr = batch['img_hr'].to(device)
            img_lr = batch['img_lr'].to(device)
            img_in_up = batch['img_in_up'].to(device)

            # VGG features of the upsampled input at TARGET_LAYERS.
            map_in = vgg(img_in_up, TARGET_LAYERS)

            # Row label: first '_'-separated token of this sample's filename.
            row = [batch['filename'][0].split('_')[0]]
            # Iterates over 7 reference slots (count is hard-coded; presumably
            # the number of references per CUFED5 input. Confirm).
            for ref_idx in range(7):
                ref = batch['ref'][ref_idx]
                map_ref = vgg(ref['ref'].to(device), TARGET_LAYERS)
Esempio n. 6
0
    def __init__(self, args):
        """Set up a coarse/refine generator + discriminator trainer on one DDP rank.

        Builds the model with ``get_model`` (which must expose
        ``coarse_model``, ``refine_model`` and ``discriminator`` submodules, as
        accessed below). It freezes ``coarse_model`` when ``args.lock_coarse``
        is set and wraps the model in ``DistributedDataParallel``. Training mode
        (``args.val`` falsy) creates the losses, ``self.optG`` over the
        generator stages, ``self.optD`` over the discriminator and
        ``self.train_loader``. Validation mode creates the metrics and
        ``self.val_loader``. Rank 0 opens a ``SummaryWriter``.

        Raises:
            ValueError: in training mode, if ``args.optimizer`` is not
                ``'adamax'``, ``'adam'`` or ``'sgd'``.
        """
        self.args = args

        args.logger.info('Initializing trainer')
        self.model = get_model(args)
        if self.args.lock_coarse:
            # Freeze the coarse stage: its parameters stop receiving gradients.
            for p in self.model.coarse_model.parameters():
                p.requires_grad = False
        torch.cuda.set_device(args.rank)
        self.model.cuda(args.rank)
        self.model = torch.nn.parallel.DistributedDataParallel(self.model,
                device_ids=[args.rank])
        train_dataset, val_dataset = get_dataset(args)

        if not args.val:
            # train loss
            self.coarse_RGBLoss = RGBLoss(args, sharp=False)
            self.refine_RGBLoss = RGBLoss(args, sharp=False, refine=True)
            self.SegLoss = nn.CrossEntropyLoss()
            self.GANLoss = GANLoss(tensor=torch.FloatTensor)

            self.coarse_RGBLoss.cuda(args.rank)
            self.refine_RGBLoss.cuda(args.rank)
            self.SegLoss.cuda(args.rank)
            self.GANLoss.cuda(args.rank)

            module = self.model.module
            # optG must cover only the generator stages, for every optimizer
            # choice. Previously adam/sgd received self.model.parameters(),
            # which also handed the discriminator's weights (already owned by
            # optD) to the generator optimizer.
            gen_params = (list(module.coarse_model.parameters())
                          + list(module.refine_model.parameters()))
            if args.optimizer == "adamax":
                self.optG = torch.optim.Adamax(gen_params, lr=args.learning_rate)
            elif args.optimizer == "adam":
                self.optG = torch.optim.Adam(gen_params, lr=args.learning_rate)
            elif args.optimizer == "sgd":
                self.optG = torch.optim.SGD(gen_params, lr=args.learning_rate, momentum=0.9)
            else:
                raise ValueError(
                    f"unsupported optimizer {args.optimizer!r}; "
                    "expected 'adamax', 'adam' or 'sgd'")

            self.optD = torch.optim.SGD(module.discriminator.parameters(), lr=args.learning_rate, momentum=0.9)

            # DistributedSampler shards data per rank; the DataLoader's own
            # shuffle must stay False when a sampler is supplied.
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
            self.train_loader = torch.utils.data.DataLoader(
                train_dataset, batch_size=args.batch_size//args.gpus, shuffle=False,
                num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)

        else:
            # val criteria
            self.L1Loss = nn.L1Loss().cuda(args.rank)
            self.PSNRLoss = PSNR().cuda(args.rank)
            self.SSIMLoss = SSIM().cuda(args.rank)
            self.IoULoss = IoU().cuda(args.rank)
            self.VGGCosLoss = VGGCosineLoss().cuda(args.rank)

            val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
            self.val_loader = torch.utils.data.DataLoader(
                val_dataset, batch_size=args.batch_size//args.gpus, shuffle=False,
                num_workers=args.num_workers, pin_memory=True, sampler=val_sampler)

        torch.backends.cudnn.benchmark = True
        self.global_step = 0
        self.epoch = 1
        # Validation without an explicit epoch range restores a checkpoint too.
        if args.resume or (args.val and not args.checkepoch_range):
            self.load_checkpoint()

        # Only rank 0 writes TensorBoard logs.
        if args.rank == 0:
            if args.val:
                self.writer = SummaryWriter(args.path+'/val_logs') if args.interval == 2 else\
                                SummaryWriter(args.path+'/val_int_1_logs')
            else:
                self.writer = SummaryWriter(args.path+'/logs')
        self.heatmap = self.create_stand_heatmap()