Example #1
0
def run(opt, device):
    """Colorize the test split with a trained generator and save the results.

    Loads ``<opt.Folder>/final_g.pt`` onto ``cuda:<opt.GPUid>``, runs it over
    the ``'test'`` split of ``opt.Dataset`` one sample at a time, shows the
    running mean PSNR / SSIM / RMSE in the progress bar, and writes the
    generator output (``C_<n>.png``) and the network input (``T_<n>.png``)
    into ``<opt.Folder>/test``, where ``<n>`` is the number of samples seen
    so far.

    Args:
        opt: parsed options; reads ``Dataset``, ``threads``, ``Folder`` and
            ``GPUid``.
        device: torch device each batch is moved to before inference.
    """
    print('===> Loading datasets')
    test_set = DatasetFromFolder(opt.Dataset, set='test')
    testing_data_loader = DataLoader(dataset=test_set,
                                     num_workers=opt.threads,
                                     batch_size=1,
                                     shuffle=False)

    print('===> Building models')
    net_g = torch.load(opt.Folder + '/final_g.pt',
                       map_location="cuda:%i" % opt.GPUid)
    net_g.eval()

    def save_png(img, path):
        # Scale [0, 1] -> [0, 255] with rounding (mutates ``img`` in place),
        # clamp, convert CHW -> HWC uint8 on the CPU and write it with PIL.
        pixels = img.mul_(255).add_(0.5).clamp_(0, 255)
        pixels = pixels.permute(1, 2, 0).to('cpu', torch.uint8).numpy()
        Image.fromarray(pixels).save(path)

    out_dir = opt.Folder + '/test'
    test_bar = tqdm(testing_data_loader, desc='Colorizing thermal')
    running_results = {'batch_sizes': 0, 'psnr': 0, 'rmse': 0, 'ssim': 0}

    for data, target in test_bar:
        running_results['batch_sizes'] += data.size(0)
        data = data.to(device)
        target = target.to(device)

        with torch.no_grad():
            colorized = net_g(data)

        mse, batch_ssim, psnr = validation_methods(colorized, target)
        running_results['ssim'] += batch_ssim
        running_results['psnr'] += psnr
        running_results['rmse'] += math.sqrt(mse)

        # Running means over all samples processed so far.
        seen = running_results['batch_sizes']
        means = tuple(running_results[key] / seen
                      for key in ('psnr', 'ssim', 'rmse'))
        test_bar.set_description(
            desc='PSNR: %.6f  SSIM: %.6f   RMSE: %.6f' % means)

        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        # Only the first sample of each batch is written (batch_size is 1).
        save_png(colorized[0], join(out_dir, 'C_%i.png' % seen))
        save_png(data[0], join(out_dir, 'T_%i.png' % seen))
def test(args, device):
    """Evaluate a saved LAB-output model on the test split and save RGB results.

    Loads ``epochs/<args.Folder>/final.pt`` onto ``cuda:<args.GPUid>``. Each
    prediction is denormalized with ``denormalize_tensorLAB``, converted to
    RGB with ``lab2rgb``, clipped to [0, 1], and scored against the dataset's
    RGB image. The running mean PSNR / SSIM / RMSE is shown in the progress
    bar, and the result is written as ``C_<fname[1]>.png`` under
    ``<args.Folder>/test/<fname[0]>``.

    Args:
        args: parsed options; reads ``Folder``, ``GPUid`` and ``Dataset``.
        device: torch device the input batch is moved to.
    """
    G = torch.load('epochs/' + args.Folder + '/final.pt',
                   map_location="cuda:%i" % args.GPUid)
    G.eval()

    test_loader = DataLoader(dataset=DatasetFromFolder(args.Dataset,
                                                       set='test'),
                             num_workers=5,
                             batch_size=1,
                             shuffle=False)
    test_bar = tqdm(test_loader, desc='Colorizing thermal')

    running_results = {'batch_sizes': 0, 'psnr': 0, 'rmse': 0, 'ssim': 0}
    # Each item yields (input, <unused>, RGB reference, fname); fname[0] is
    # used as an output sub-directory and fname[1] as the output file stem.
    for data, _, RGB, fname in test_bar:
        batch_size = data.size(0)
        running_results['batch_sizes'] += batch_size

        data = data.to(device)
        with torch.no_grad():
            colorized_img = G(data)
            colorized_img = denormalize_tensorLAB(colorized_img.data.cpu())
            # Only the first sample is converted; the loader uses batch_size=1.
            colorized_img = lab2rgb(colorized_img[0].permute(1, 2, 0).numpy())

            colorized_img = np.clip(colorized_img, 0, 1)
            # HWC numpy -> 1xCxHxW float tensor, matching the RGB reference.
            colorized_img = torch.FloatTensor(colorized_img).permute(
                2, 0, 1).unsqueeze(0)

        mse, batch_ssim, psnr = validation_methods(colorized_img, RGB)
        running_results['ssim'] += batch_ssim
        running_results['psnr'] += psnr
        running_results['rmse'] += math.sqrt(mse.item())

        test_bar.set_description(
            desc='PSNR: %.6f  SSIM: %.6f   RMSE: %.6f' %
            (running_results['psnr'] / running_results['batch_sizes'],
             running_results['ssim'] / running_results['batch_sizes'],
             running_results['rmse'] / running_results['batch_sizes']))

        # Normalize Windows separators in the sub-directory component.
        dic = args.Folder + '/test/' + fname[0][0].replace('\\', '/')
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(dic, exist_ok=True)

        ndarr = colorized_img[0].mul_(255).add_(0.5).clamp_(0, 255).permute(
            1, 2, 0).to('cpu', torch.uint8).numpy()
        im = Image.fromarray(ndarr)
        im.save(join(dic, 'C_' + fname[1][0] + '.png'))
def test():
    """Evaluate the saved generator on the test split and save the outputs.

    Reads its options from the module-level ``parser``. Loads
    ``epochs/<opt.Folder>/final.pt`` onto ``cuda:<opt.GPUid>``. Each input is
    inverted (``1 - data``). The result is the Gaussian-blurred generator
    output plus 3x the input's high-frequency residual (input minus its blur),
    clamped to [0, 1]. That result is scored against the target, with the
    running mean PSNR / SSIM / RMSE shown in the progress bar, and written as
    ``C_<fname[1]>.png`` under ``epochs/<opt.Folder>/test/<fname[0]>``.
    """
    opt = parser.parse_args()
    G = torch.load('epochs/' + opt.Folder + '/final.pt', map_location="cuda:%i" % opt.GPUid)
    Gauss3 = GaussianLayer(layers=3, k=25, sigma=12).cuda(opt.GPUid)
    G.eval()

    ds = opt.Dataset
    test_loader = DataLoader(dataset=DatasetFromFolder(ds, set='test'), num_workers=5, batch_size=1, shuffle=False)
    test_bar = tqdm(test_loader, desc='Colorizing thermal')

    running_results = {'batch_sizes': 0, 'psnr': 0, 'rmse': 0, 'ssim': 0}
    for data, target, fname in test_bar:
        batch_size = data.size(0)
        running_results['batch_sizes'] += batch_size

        data = 1 - data
        # Plain tensors: torch.autograd.Variable has been a no-op wrapper since 0.4.
        data, target = data.cuda(opt.GPUid), target.cuda(opt.GPUid)
        with torch.no_grad():
            data_LF = Gauss3(data)
            data_HF = data - data_LF
            colorized = G(data)
            Gcolorized = Gauss3(colorized)

            # Clamp before scoring, as the validation in train() does for the
            # same expression, so the metrics describe the image actually saved.
            colorized_img = torch.clamp(Gcolorized + 3 * data_HF, 0, 1)

        mse, batch_ssim, psnr = validation_methods(colorized_img, target)
        running_results['ssim'] += batch_ssim
        running_results['psnr'] += psnr
        running_results['rmse'] += math.sqrt(mse.item())

        test_bar.set_description(desc='PSNR: %.6f  SSIM: %.6f   RMSE: %.6f'
                                      % (running_results['psnr'] / running_results['batch_sizes'],
                                         running_results['ssim'] / running_results['batch_sizes'],
                                         running_results['rmse'] / running_results['batch_sizes']))

        # Normalize Windows separators in the sub-directory component.
        dic = 'epochs/' + opt.Folder + '/test/' + fname[0][0].replace('\\', '/')
        os.makedirs(dic, exist_ok=True)

        ndarr = colorized_img[0].mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
        im = Image.fromarray(ndarr)
        im.save(join(dic, 'C_' + fname[1][0] + '.png'))
def train(args, device):
    """Train the colorization model and validate it after every epoch.

    The model output is treated as LAB: channel 0 is trained with the SSIM
    criterion, and channels 1-2 with L1. For validation, predictions are
    denormalized, converted to RGB with ``lab2rgb``, clamped to [0, 1], and
    scored against the RGB reference. The whole model is saved to
    ``<args.Folder>/final.pt`` every epoch. Every 10th epoch, an
    input/target/output grid per validation sample is written to
    ``<args.Folder>/imgs``.

    Args:
        args: parsed options; reads ``Dataset``, ``learning_rate``,
            ``num_epochs`` and ``Folder``.
        device: torch device the model and training batches are moved to.
    """
    G = Model().to(device)
    weights_init(G, he=False)
    criterion = SSIM(window_size=[4, 16])
    optimizer = optim.Adam(G.parameters(), lr=args.learning_rate)
    train_loader = DataLoader(dataset=DatasetFromFolder(args.Dataset,
                                                        set='train'),
                              num_workers=5,
                              batch_size=8,
                              shuffle=True)
    val_loader = DataLoader(dataset=DatasetFromFolder(args.Dataset,
                                                      set='valid'),
                            num_workers=5,
                            batch_size=1,
                            shuffle=False)
    img_dir = args.Folder + '/imgs'

    totalpsnr = []
    totalLoss = []
    for epoch in range(args.num_epochs):
        train_bar = tqdm(train_loader, desc='Colorizing thermal')
        running_results = {
            'batch_sizes': 0,
            'loss': 0,
            'L1loss': 0,
            'SSIMloss': 0
        }
        G.train()
        for data, target, _ in train_bar:
            batch_size = data.size(0)
            running_results['batch_sizes'] += batch_size
            data, target = data.to(device), target.to(device)

            colorized = G(data)

            # SSIM on channel 0, L1 on channels 1-2.
            loss1 = criterion(colorized[:, 0:1], target[:, 0:1])
            loss2 = l1_loss(colorized[:, 1:3], target[:, 1:3])
            loss = loss1 + loss2

            # The optimizer owns all of G's parameters, so this single call
            # replaces the former redundant G.zero_grad().
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_results['loss'] += loss.item() * batch_size
            running_results['L1loss'] += loss2.item() * batch_size
            running_results['SSIMloss'] += loss1.item() * batch_size
            train_bar.set_description(
                desc='[%d/%d] L1_Loss: %.6f  SSIM: %.6f' %
                (epoch, args.num_epochs,
                 running_results['L1loss'] / running_results['batch_sizes'],
                 running_results['SSIMloss'] / running_results['batch_sizes']))

        totalLoss.append(
            [running_results['loss'] / running_results['batch_sizes']])
        torch.save(G, args.Folder + '/final.pt')

        # Validation runs after every epoch.
        G.eval()
        val_bar = tqdm(val_loader)
        save_imgs = (epoch + 1) % 10 == 0

        valing_results = {'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
        for c, (data, _, target) in enumerate(val_bar, start=1):
            batch_size = data.size(0)
            valing_results['batch_sizes'] += batch_size
            data = data.to(device)

            with torch.no_grad():
                colorized = G(data)
                colorized = denormalize_tensorLAB(colorized)
                # Only the first sample is converted (val batch_size is 1).
                colorized = lab2rgb(colorized[0].permute(
                    1, 2, 0).data.cpu().numpy())[np.newaxis, :]
                colorized = torch.FloatTensor(colorized).permute(
                    0, 3, 1, 2)
                colorized = torch.clamp(colorized, 0, 1)

            _, batch_ssim, psnr = validation_methods(colorized, target)
            valing_results['ssim'] += batch_ssim
            valing_results['psnr'] += psnr

            val_bar.set_description(
                desc='[Colorizing Thermal] psnr: %.2f dB ssim: %.2f ' %
                (valing_results['psnr'] / valing_results['batch_sizes'],
                 valing_results['ssim'] / valing_results['batch_sizes']))

            if save_imgs:
                os.makedirs(img_dir, exist_ok=True)
                # One row per sample: input | target | prediction.
                val_images = torch.stack([
                    data.squeeze().cpu(),
                    target.squeeze().cpu(),
                    colorized.squeeze().cpu()
                ])
                image = utils.make_grid(val_images, nrow=3, padding=5)
                utils.save_image(image,
                                 args.Folder +
                                 '/imgs/epoch_%i_%i_.png' %
                                 (epoch + 1, c),
                                 padding=5)

        totalpsnr.append(valing_results['psnr'] /
                         valing_results['batch_sizes'])
def train():
    """Train the generator with pixel L1 plus a Gaussian low-frequency MSE loss.

    All settings come from the module-level ``parser``. Inputs are inverted
    (``1 - data``) before use. The training loss is
    ``L1(out, target) + 10 * MSE(blur(out), blur(target))``.

    After every epoch the model is validated on the ``'valid'`` split:

    * ``*_G`` metrics compare the blurred output with the blurred target.
    * ``*_F`` metrics compare the blurred output plus 3x the input's
      high-frequency residual, clamped to [0, 1], with the unblurred target.

    Losses and metrics are written with ``SaveData`` under
    ``epochs/<opt.Folder>``. The whole model is saved to ``final.pt`` every
    epoch. Every 50 epochs, sample grids and a resumable checkpoint are saved.
    """
    opt = parser.parse_args()
    G = Model()
    # Build the blur layer on the CPU and move it together with G below. It
    # used to be created with a hard-coded ``.cuda(1)``, which fails on
    # CPU-only and single-GPU machines before opt.GPUid is ever consulted.
    Gauss3 = GaussianLayer(layers=3, k=25, sigma=12)
    print('# generator parameters:', sum(param.numel() for param in G.parameters()))

    weights_init(G)
    l1_loss = nn.L1Loss()
    mse = nn.MSELoss()
    if torch.cuda.is_available():
        G.cuda(opt.GPUid)
        Gauss3.cuda(opt.GPUid)

    g_optimizer = optim.Adam(G.parameters(), lr=opt.lr)
    totalpsnr = []
    totalLoss = []

    if opt.Checkpoint != 'default':
        # NOTE(review): the checkpoint file is hard-coded to epoch 300
        # regardless of opt.Checkpoint's value -- confirm this is intended.
        checkpoint = torch.load(os.getcwd() + '/epochs/' + opt.Folder + '/netG_training_epoch_300.pt')
        G.load_state_dict(checkpoint['model_state_dict'])
        g_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epochInit = checkpoint['epoch']
        totalLoss = checkpoint['loss']
        print('Model Loaded...')
    else:
        epochInit = 0

    ds = opt.Dataset
    train_set = DatasetFromFolder(ds, set='train')
    val_set = DatasetFromFolder(ds, set='valid')
    train_loader = DataLoader(dataset=train_set, num_workers=5, batch_size=opt.patchSize, shuffle=True)
    val_loader = DataLoader(dataset=val_set, num_workers=5, batch_size=1, shuffle=False)
    img_dir = 'epochs/' + opt.Folder + '/imgs'

    for epoch in range(epochInit, opt.num_epochs):
        train_bar = tqdm(train_loader, desc='Colorizing thermal')
        running_results = {'batch_sizes': 0, 'Tloss': 0, 'loss': 0, 'Gloss': 0}

        learning_rate = set_lr(opt, epoch, g_optimizer)
        for data, target in train_bar:
            data = 1 - data
            batch_size = data.size(0)
            running_results['batch_sizes'] += batch_size
            if torch.cuda.is_available():
                data, target = data.cuda(opt.GPUid), target.cuda(opt.GPUid)

            Gtarget = Gauss3(target, 12)

            # Re-enter training mode each step; validation leaves G in eval().
            G.train()
            G.zero_grad()

            colorized_img = G(data)
            Gcolorized_img = Gauss3(colorized_img, 12)
            loss = l1_loss(colorized_img, target)
            Gloss = mse(Gcolorized_img, Gtarget)

            total_loss = loss + 10 * Gloss
            total_loss.backward()
            g_optimizer.step()

            running_results['loss'] += loss.item() * batch_size
            running_results['Gloss'] += Gloss.item() * batch_size
            running_results['Tloss'] += total_loss.item() * batch_size

            train_bar.set_description(desc='[%d/%d] lr: %.e, Loss: %.6f , GLoss: %.6f'
                                           % (epoch, opt.num_epochs, learning_rate,
                                              running_results['loss'] / running_results['batch_sizes'],
                                              running_results['Gloss'] / running_results['batch_sizes']))

        totalLoss.append([running_results['Tloss'] / running_results['batch_sizes']])

        SaveData(totalLoss, 'epochs/' + opt.Folder, '/loss.pkl')

        # Validation runs after every epoch.
        G.eval()
        val_bar = tqdm(val_loader)
        save_imgs = (epoch + 1) % 50 == 0 and epoch > 0

        valing_results = {'psnr_F': 0, 'ssim_F': 0, 'psnr_G': 0, 'ssim_G': 0, 'batch_sizes': 0}
        c = 0
        for data, target in val_bar:
            c += 1
            # Keep the unblurred target; `target` itself is blurred below.
            rgb = target.clone()
            batch_size = data.size(0)
            valing_results['batch_sizes'] += batch_size
            data = 1 - data

            if torch.cuda.is_available():
                data = data.cuda(opt.GPUid)
                target = target.cuda(opt.GPUid)
                rgb = rgb.cuda(opt.GPUid)

            with torch.no_grad():
                target = Gauss3(target)
                data_LF = Gauss3(data)
                data_HF = data - data_LF
                colorized = G(data)
                Gcolorized = Gauss3(colorized)

            # Re-inject the input's high-frequency detail into the blurred output.
            psnr_colorized = Gcolorized + 3 * data_HF

            _, batch_ssim, psnr = validation_methods(Gcolorized, target)
            valing_results['ssim_G'] += batch_ssim
            valing_results['psnr_G'] += psnr

            _, batch_ssim, psnr = validation_methods(torch.clamp(psnr_colorized, 0, 1), rgb)
            valing_results['ssim_F'] += batch_ssim
            valing_results['psnr_F'] += psnr

            val_bar.set_description(
                desc='[Colorizing Thermal] psnr_F: %.2f dB ssim_F: %.2f '
                     'psnr_G: %.2f dB ssim_G: %.2f ' % (
                    valing_results['psnr_F'] / valing_results['batch_sizes'],
                    valing_results['ssim_F'] / valing_results['batch_sizes'],
                    valing_results['psnr_G'] / valing_results['batch_sizes'],
                    valing_results['ssim_G'] / valing_results['batch_sizes']))

            if save_imgs:
                val_images = []

                target = target + 3 * data_HF
                Gcolorized = Gcolorized + 3 * data_HF
                target = torch.clamp(target, 0, 1)
                Gcolorized = torch.clamp(Gcolorized, 0, 1)

                val_images.extend(
                    [data.squeeze().cpu(), rgb.squeeze().cpu(), target.squeeze().cpu(), Gcolorized.squeeze().cpu()])

                val_images = torch.stack(val_images)
                val_images = torch.chunk(val_images, val_images.size(0) // 4)
                val_save_bar = tqdm(val_images, desc='[saving training results]')

                # The image directory was never created before, so save_image
                # failed on a fresh run folder.
                os.makedirs(img_dir, exist_ok=True)
                for image in val_save_bar:
                    image = utils.make_grid(image, nrow=4, padding=5)
                    utils.save_image(image, img_dir + '/epoch_%i_%i_.png' % (epoch + 1, c),
                                     padding=5)

        totalpsnr.append([valing_results['psnr_F'] / valing_results['batch_sizes'],
                          valing_results['ssim_F'] / valing_results['batch_sizes'],
                          valing_results['psnr_G'] / valing_results['batch_sizes'],
                          valing_results['ssim_G'] / valing_results['batch_sizes']])

        SaveData(totalpsnr, 'epochs/' + opt.Folder, '/psnr.pkl')
        torch.save(G, 'epochs/' + opt.Folder + '/final.pt')

        torch.cuda.empty_cache()

        if (epoch + 1) % 50 == 0 and epoch > 0:
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': G.state_dict(),
                'optimizer_state_dict': g_optimizer.state_dict(),
                'loss': totalLoss,
            }, os.getcwd() + '/epochs/' + opt.Folder + '/netG_training_epoch_%d.pt' % (epoch + 1))
Example #6
0
def run(opt, device):
    """Train the thermal-colorization GAN and validate the generator every epoch.

    The generator loss is the sum of four terms:

    * the adversarial term, weighted by ``opt.lamba``;
    * the L1 content term, weighted by ``opt.lambc``;
    * the TV term, weighted inside ``TVLoss(opt.lambt)``;
    * the perceptual term, weighted by ``opt.lambp``. This is the L1 distance
      between VGG16 features (via ``FeatureExtractor``) of the real and
      generated images.

    Both networks are saved to ``<opt.Folder>/final_{d,g}.pt`` after every
    epoch. Every 5th epoch, input/target/output grids are written to
    ``<opt.Folder>/imgs``.

    Args:
        opt: parsed options; reads ``Dataset``, ``threads``, ``batch_size``,
            ``lambt``, ``lrg``, ``lrd``, ``num_epochs``, ``lamba``,
            ``lambc``, ``lambp`` and ``Folder``.
        device: torch device for both networks and all batches.
    """
    import os  # needed to create the validation image directory

    print('===> Loading datasets')
    training_data_loader = DataLoader(dataset=DatasetFromFolder(opt.Dataset,
                                                                set='train'),
                                      num_workers=opt.threads,
                                      batch_size=opt.batch_size,
                                      shuffle=True)
    val_data_loader = DataLoader(dataset=DatasetFromFolder(opt.Dataset,
                                                           set='valid'),
                                 num_workers=opt.threads,
                                 batch_size=1,
                                 shuffle=False)

    print('===> Building models')
    net_g = Generator().to(device)
    net_d = Discriminator().to(device)
    weights_init(net_g)
    weights_init(net_d)

    criterionGAN = GANLoss(use_lsgan=False).to(device)
    criterionL1 = nn.L1Loss()
    tvLoss = TVLoss(opt.lambt)
    optimizer_g = optim.Adam(net_g.parameters(),
                             lr=opt.lrg,
                             betas=(0.5, 0.999))
    optimizer_d = optim.Adam(net_d.parameters(),
                             lr=opt.lrd,
                             betas=(0.5, 0.999))
    Fex = FeatureExtractor(
        torchvision.models.vgg16(pretrained=True)).to(device)

    for epoch in range(opt.num_epochs):
        # Validation below switches net_g to eval(); without this, every epoch
        # after the first trained the generator in eval mode.
        net_g.train()
        train_bar = tqdm(training_data_loader, desc='Colorizing thermal')
        running_results = {'batch_sizes': 0, 'loss_d': 0, 'loss_g': 0}

        for data, target in train_bar:
            batch_size = data.size(0)
            running_results['batch_sizes'] += batch_size

            real_a, real_b = data.to(device), target.to(device)
            fake_b = net_g(real_a)

            ######################
            # Update D network
            ######################

            optimizer_d.zero_grad()
            # Detach so the discriminator step does not backprop into net_g.
            pred_fake = net_d(real_a.detach(), fake_b.detach())
            loss_d_fake = criterionGAN(pred_fake, False)

            pred_real = net_d(real_a, real_b)
            loss_d_real = criterionGAN(pred_real, True)

            loss_d = (loss_d_fake + loss_d_real) * 0.5
            loss_d.backward()
            optimizer_d.step()

            ######################
            # Update G network
            ######################

            optimizer_g.zero_grad()
            pred_fake = net_d(real_a, fake_b)

            loss_g_adv = criterionGAN(pred_fake, True) * opt.lamba
            loss_g_content = criterionL1(fake_b, real_b) * opt.lambc
            loss_g_tv = tvLoss(fake_b)

            real_b_Fexs = Fex(normalize_batch(real_b))
            fake_b_Fexs = Fex(normalize_batch(fake_b))
            # Accumulate into the perceptual term so opt.lambp actually weights
            # it; it was previously added, unweighted, to loss_g_content.
            loss_g_percept = 0
            for r, f in zip(real_b_Fexs, fake_b_Fexs):
                loss_g_percept += criterionL1(r, f)

            loss_g = loss_g_adv + loss_g_content + loss_g_tv + loss_g_percept * opt.lambp
            loss_g.backward()
            optimizer_g.step()

            running_results['loss_d'] += loss_d.item() * batch_size
            running_results['loss_g'] += loss_g.item() * batch_size

            train_bar.set_description(
                desc='[%d/%d] lrg: %.e, lrd: %.e, loss_d: %.6f , loss_g: %.6f'
                % (epoch, opt.num_epochs, opt.lrg, opt.lrd,
                   running_results['loss_d'] / running_results['batch_sizes'],
                   running_results['loss_g'] / running_results['batch_sizes']))

        torch.save(net_d, opt.Folder + '/final_d.pt')
        torch.save(net_g, opt.Folder + '/final_g.pt')

        # Validation runs after every epoch.
        net_g.eval()
        val_bar = tqdm(val_data_loader)
        save_imgs = (epoch + 1) % 5 == 0 and epoch > 0
        valing_results = {
            'psnr': 0,
            'ssim': 0,
            'rmse': 0,
            'batch_sizes': 0
        }

        for data, target in val_bar:
            batch_size = data.size(0)
            valing_results['batch_sizes'] += batch_size
            data, target = data.to(device), target.to(device)

            with torch.no_grad():
                colorized = net_g(data)

            mse, batch_ssim, psnr = validation_methods(colorized, target)
            valing_results['ssim'] += batch_ssim
            valing_results['psnr'] += psnr
            # float() first: np.sqrt cannot consume a CUDA tensor directly,
            # and float() also accepts plain Python / numpy scalars.
            valing_results['rmse'] += np.sqrt(float(mse))

            val_bar.set_description(
                desc=
                '[Colorizing Thermal] psnr: %.2f dB ssim: %.2f rmse: %.2f'
                % (valing_results['psnr'] / valing_results['batch_sizes'],
                   valing_results['ssim'] / valing_results['batch_sizes'],
                   valing_results['rmse'] / valing_results['batch_sizes']))

            if save_imgs:
                colorized = torch.clamp(colorized, 0, 1)
                val_images = torch.cat(
                    (data.cpu(), target.cpu(), colorized.cpu()), dim=0)
                image = utils.make_grid(val_images, nrow=3, padding=5)
                os.makedirs(opt.Folder + '/imgs', exist_ok=True)
                utils.save_image(image,
                                 opt.Folder + '/imgs/epoch_%i_%i_.png' %
                                 (epoch, valing_results['batch_sizes']),
                                 padding=5)