Example #1
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        if self.isTrain:
            if self.train_phase == 'generator':
                self.model_names = ['G']
                self.loss_names = ['G_I_L1','G_I_L2','SSIM', 'PSNR', 'vgg']
            else:
                self.model_names = ['G', 'D']
                self.loss_names = ['G_GAN', 'G_I_L1', 'G_I_L2', 'D_GAN_fake', 'D_GAN_real', 'SSIM', 'PSNR', 'vgg']
        else:  # during test time, only load Gs
            self.model_names = ['G']
            self.loss_names = ['SSIM', 'PSNR']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        self.visual_names = ['real_A', 'fake_B', 'real_B']

        self.netG = networks.define_G(self.opt, opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        self.criterionL1 = torch.nn.L1Loss()
        self.criterionMSE = torch.nn.MSELoss()
        self.ssim_loss = pytorch_msssim.SSIM(val_range=1)
        if opt.use_vgg:
            self.perceptual = losses.PerceptualLoss()
            self.perceptual.initialize(self.criterionMSE)
        if self.isTrain:
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)


        if self.isTrain and self.train_phase == 'together':
            self.no_wgan = opt.no_wgan
            self.no_wgan_gp = opt.no_wgan_gp
            if not self.no_wgan_gp:
                self.disc_step = opt.disc_step
            else:
                self.disc_step = 1
            self.disc_model = opt.disc_model
            use_sigmoid = opt.no_lsgan

            if opt.disc_model == 'pix2pix':
                self.netD = networks.define_D(opt,opt.input_nc + opt.output_nc, opt.ndf,
                                              opt.which_model_netD,
                                              opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
                                              self.gpu_ids)
            elif opt.disc_model == 'traditional':
                self.netD = networks.define_D(opt,opt.output_nc, opt.ndf, opt.which_model_netD,
                                                opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
                                                self.gpu_ids)

            self.loss_wgan_gp = opt.loss_wgan_gp
            self.fake_pool = ImagePool(opt.pool_size)
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, use_l1=not opt.no_l1gan).to(self.device)
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_D)
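
The snippet above only registers loss names and constructs the modules; the actual loss computation lives elsewhere in the model. A hedged sketch of how the declared G-side terms might be evaluated (this helper method is an assumption, not part of the original; tensor names follow visual_names):

    def compute_generator_losses(self):
        # Assumes real_B / fake_B are scaled to [0, 1], matching SSIM(val_range=1).
        self.loss_G_I_L1 = self.criterionL1(self.fake_B, self.real_B)
        self.loss_G_I_L2 = self.criterionMSE(self.fake_B, self.real_B)
        self.loss_SSIM = 1 - self.ssim_loss(self.fake_B, self.real_B)
        self.loss_PSNR = 10 * torch.log10(1.0 / self.loss_G_I_L2)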
Example #2
def compute_loss(predicts, targets, weights, w_idx=0):
    # targets = Variable(targets.data.clone(), requires_grad=False)
    mse = nn.MSELoss()
    ssim = pytorch_msssim.SSIM()

    loss_mse = mse(predicts, targets)
    loss_ssim = 1 - ssim(predicts, targets)  # SSIM is a similarity; invert it for a loss

    loss = loss_mse + weights[w_idx] * loss_ssim
    return loss
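
A quick smoke test of compute_loss on random tensors (the shapes, weights, and w_idx here are arbitrary assumptions; inputs should lie in the range the SSIM module expects):

import torch
import torch.nn as nn
import pytorch_msssim

predicts = torch.rand(4, 3, 64, 64)  # values in [0, 1]
targets = torch.rand(4, 3, 64, 64)
loss = compute_loss(predicts, targets, weights=[0.5, 1.0], w_idx=1)
print(loss.item())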
Example #3
    def __init__(self, args):
        super(Loss, self).__init__()
        print('Preparing loss function:')

        self.loss = []
        self.loss_module = nn.ModuleList()
        for loss in args.loss.split('+'):
            weight, loss_type = loss.split('*')
            if loss_type == 'MSE':
                loss_function = nn.MSELoss()
            elif loss_type == 'Huber':
                loss_function = HuberLoss(delta=.5)
            elif loss_type == 'L1':
                loss_function = nn.L1Loss()
            elif loss_type.find('VGG') >= 0:
                loss_function = VGG(loss_type[3:])
            elif loss_type == 'SSIM':
                loss_function = pytorch_msssim.SSIM(val_range=1.)
            elif loss_type.find('GAN') >= 0:
                loss_function = Adversarial(args, loss_type)

            self.loss.append({
                'type': loss_type,
                'weight': float(weight),
                'function': loss_function
            })
            if loss_type.find('GAN') >= 0:
                self.loss.append({
                    'type': 'DIS',
                    'weight': 1,
                    'function': None
                })

        if len(self.loss) > 1:
            self.loss.append({'type': 'Total', 'weight': 0, 'function': None})

        for l in self.loss:
            if l['function'] is not None:
                print('{:.3f} * {}'.format(l['weight'], l['type']))
                self.loss_module.append(l['function'])

        device = torch.device('cuda' if args.cuda else 'cpu')
        self.loss_module.to(device)
        #if args.precision == 'half': self.loss_module.half()
        if args.cuda:  # and args.n_GPUs > 1:
            self.loss_module = nn.DataParallel(self.loss_module)
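
The constructor parses a spec of 'weight*TYPE' terms joined by '+'. A minimal usage sketch (the SimpleNamespace is an assumption standing in for the project's real argparse namespace; only the attributes this constructor reads are provided):

from types import SimpleNamespace

args = SimpleNamespace(loss='1*MSE+0.1*SSIM', cuda=False)
loss_module = Loss(args)  # prints '1.000 * MSE' and '0.100 * SSIM'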
Example #4
# Create Model
model = autoencoder.ConvolutionalAE(
    max_filters=max_filters,
    num_layers=num_layers,
    input_image_dimensions=image_size,
    latent_dim=latent_dim,
    small_conv=small_conv,
)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

ssim_module = None
if use_ssim_loss:
    ssim_module = pytorch_msssim.SSIM(data_range=1.0,
                                      win_size=11,
                                      win_sigma=1.5,
                                      K=(0.01, 0.03))

################################################################################
################################### Training ###################################
################################################################################

# Train
all_samples = []
all_train_loss = []
all_val_loss = []

# Get an initial "epoch 0" sample
model.eval()
with torch.no_grad():
    epoch_sample = model(sample.to(device))
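
The training loop that follows is truncated in this excerpt; a minimal sketch of the usual pattern for consuming ssim_module when use_ssim_loss is set (batch is an assumed name; inputs and reconstructions are assumed to be scaled to [0, 1] to match data_range=1.0):

import torch.nn.functional as F

reconstructed = model(batch.to(device))
if ssim_module is not None:
    # SSIM is a similarity in [0, 1]; 1 - SSIM is the quantity to minimize.
    loss = 1 - ssim_module(reconstructed, batch.to(device))
else:
    loss = F.mse_loss(reconstructed, batch.to(device))
optimizer.zero_grad()
loss.backward()
optimizer.step()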
Example #5
def train(args):

    print('Number of GPUs available: ' + str(torch.cuda.device_count()))
    model = nn.DataParallel(CAEP(num_resblocks).cuda())
    print('Done Setup Model.')

    dataset = BSDS500Crop128(args.dataset_path)
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=args.shuffle,
                            num_workers=args.num_workers)
    testset = Kodak(args.testset_path)
    testloader = DataLoader(testset,
                            batch_size=len(testset),
                            num_workers=args.num_workers)
    print(
        f"Done Setup Training DataLoader: {len(dataloader)} batches of size {args.batch_size}"
    )
    print(f"Done Setup Testing DataLoader: {len(testset)} Images")

    MSE = nn.MSELoss()
    SSIM = pytorch_msssim.SSIM().cuda()
    MSSSIM = pytorch_msssim.MSSSIM().cuda()

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=1e-10)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=10,
        verbose=True,
    )

    writer = SummaryWriter(log_dir=f'TBXLog/{args.exp_name}')

    # ADMM variables
    Z = torch.zeros(16, 32, 32).cuda()
    U = torch.zeros(16, 32, 32).cuda()
    Z.requires_grad = False
    U.requires_grad = False

    if args.load != '':
        pretrained_state_dict = torch.load(f"./chkpt/{args.load}/model.state")
        current_state_dict = model.state_dict()
        current_state_dict.update(pretrained_state_dict)
        model.load_state_dict(current_state_dict)
        # Z = torch.load(f"./chkpt/{args.load}/Z.state")
        # U = torch.load(f"./chkpt/{args.load}/U.state")
        if args.load == args.exp_name:
            optimizer.load_state_dict(
                torch.load(f"./chkpt/{args.load}/opt.state"))
            scheduler.load_state_dict(
                torch.load(f"./chkpt/{args.load}/lr.state"))
        print('Model Params Loaded.')

    model.train()

    for ei in range(args.res_epoch + 1, args.res_epoch + args.num_epochs + 1):
        # train
        train_loss = 0
        train_ssim = 0
        train_msssim = 0
        train_psnr = 0
        train_penalty = 0
        train_bpp = 0
        avg_c = torch.zeros(16, 32, 32).cuda()
        avg_c.requires_grad = False

        for bi, crop in enumerate(dataloader):
            x = crop.cuda()
            y, c = model(x)

            psnr = compute_psnr(x, y)
            mse = MSE(y, x)
            ssim = SSIM(x, y)
            msssim = MSSSIM(x, y)

            mix = 1000 * (1 - msssim) + 1000 * (1 - ssim) + 1e4 * mse + (45 - psnr)
            # ADMM Step 1
            penalty = rho / 2 * torch.norm(c - Z + U, 2)
            bpp = compute_bpp(c, x.shape[0], 'crop', save=False)

            avg_c += torch.mean(c.detach() /
                                (len(dataloader) * args.admm_every),
                                dim=0)

            loss = mix + penalty
            if ei == 1 and args.load != args.exp_name:
                loss = 1e5 * mse  # warm up

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print(
                '[%3d/%3d][%5d/%5d] Loss: %f, SSIM: %f, MSSSIM: %f, PSNR: %f, Norm of Code: %f, BPP: %.2f'
                % (ei, args.num_epochs + args.res_epoch, bi, len(dataloader),
                   loss, ssim, msssim, psnr, penalty, bpp))
            writer.add_scalar('batch_train/loss', loss,
                              ei * len(dataloader) + bi)
            writer.add_scalar('batch_train/ssim', ssim,
                              ei * len(dataloader) + bi)
            writer.add_scalar('batch_train/msssim', msssim,
                              ei * len(dataloader) + bi)
            writer.add_scalar('batch_train/psnr', psnr,
                              ei * len(dataloader) + bi)
            writer.add_scalar('batch_train/norm', penalty,
                              ei * len(dataloader) + bi)
            writer.add_scalar('batch_train/bpp', bpp,
                              ei * len(dataloader) + bi)

            train_loss += loss.item() / len(dataloader)
            train_ssim += ssim.item() / len(dataloader)
            train_msssim += msssim.item() / len(dataloader)
            train_psnr += psnr.item() / len(dataloader)
            train_penalty += penalty.item() / len(dataloader)
            train_bpp += bpp / len(dataloader)

        writer.add_scalar('epoch_train/loss', train_loss, ei)
        writer.add_scalar('epoch_train/ssim', train_ssim, ei)
        writer.add_scalar('epoch_train/msssim', train_msssim, ei)
        writer.add_scalar('epoch_train/psnr', train_psnr, ei)
        writer.add_scalar('epoch_train/norm', train_penalty, ei)
        writer.add_scalar('epoch_train/bpp', train_bpp, ei)

        if ei % args.admm_every == args.admm_every - 1:
            # ADMM Step 2
            mask = (torch.Tensor(
                np.argsort((avg_c + U).data.cpu().numpy(), axis=None)) >= int(
                    (1 - pruning_ratio) * 16 * 32 * 32)).view(16, 32, 32).cuda()
            Z = (avg_c + U).masked_fill_(mask, value=0)
            # ADMM Step 3
            U += avg_c - Z

        # test
        model.eval()
        val_loss = 0
        val_ssim = 0
        val_msssim = 0
        val_psnr = 0
        val_penalty = 0
        val_bpp = 0
        for bi, (img, patches, _) in enumerate(testloader):
            avg_loss = 0
            avg_ssim = 0
            avg_msssim = 0
            avg_psnr = 0
            avg_penalty = 0
            avg_bpp = 0
            for i in range(6):
                for j in range(4):
                    x = torch.Tensor(patches[:, i, j, :, :, :]).cuda()
                    y, c = model(x)

                    psnr = compute_psnr(x, y)
                    mse = MSE(y, x)
                    ssim = SSIM(x, y)
                    msssim = MSSSIM(x, y)

                    mix = 1000 * (1 - msssim) + 1000 * (1 - ssim) + 1e4 * mse + (45 - psnr)

                    penalty = rho / 2 * torch.norm(c - Z + U, 2)
                    bpp = compute_bpp(c,
                                      x.shape[0],
                                      f'Kodak_patches_{i}_{j}',
                                      save=True)
                    loss = mix + penalty

                    avg_loss += loss.item() / 24
                    avg_ssim += ssim.item() / 24
                    avg_msssim += msssim.item() / 24
                    avg_psnr += psnr.item() / 24
                    avg_penalty += penalty.item() / 24
                    avg_bpp += bpp / 24

            save_kodak_img(model, img, 0, patches, writer, ei)
            save_kodak_img(model, img, 10, patches, writer, ei)
            save_kodak_img(model, img, 20, patches, writer, ei)

            val_loss += avg_loss
            val_ssim += avg_ssim
            val_msssim += avg_msssim
            val_psnr += avg_psnr
            val_penalty += avg_penalty
            val_bpp += avg_bpp
        print(
            '*Kodak: [%3d/%3d] Loss: %f, SSIM: %f, MSSSIM: %f, Norm of Code: %f, BPP: %.2f'
            % (ei, args.num_epochs + args.res_epoch, val_loss, val_ssim,
               val_msssim, val_penalty, val_bpp))

        # bz = call('tar -jcvf ./code/code.tar.bz ./code', shell=True)
        # total_code_size = os.stat('./code/code.tar.bz').st_size
        # total_bpp = total_code_size * 8 / 24 / 768 / 512

        writer.add_scalar('test/loss', val_loss, ei)
        writer.add_scalar('test/ssim', val_ssim, ei)
        writer.add_scalar('test/msssim', val_msssim, ei)
        writer.add_scalar('test/psnr', val_psnr, ei)
        writer.add_scalar('test/norm', val_penalty, ei)
        writer.add_scalar('test/bpp', val_bpp, ei)
        # writer.add_scalar('test/total_bpp', total_bpp, ei)
        model.train()

        scheduler.step(train_loss)

        # save model
        if ei % args.save_every == args.save_every - 1:
            torch.save(model.state_dict(),
                       f"./chkpt/{args.exp_name}/model.state")
            torch.save(optimizer.state_dict(),
                       f"./chkpt/{args.exp_name}/opt.state")
            torch.save(scheduler.state_dict(),
                       f"./chkpt/{args.exp_name}/lr.state")
            torch.save(Z, f"./chkpt/{args.exp_name}/Z.state")
            torch.save(U, f"./chkpt/{args.exp_name}/U.state")

    writer.close()
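
The argsort-based mask in ADMM Step 2 above is hard to read. The sketch below restates what it appears to be doing, keeping only the largest (1 - pruning_ratio) fraction of entries of (avg_c + U) and zeroing the rest, using torch.topk; this is an interpretation of the intent, not the original code:

v = avg_c + U
k = int((1 - pruning_ratio) * v.numel())  # number of entries to keep
flat = v.flatten()
Z = torch.zeros_like(flat)
idx = flat.topk(k).indices
Z[idx] = flat[idx]
Z = Z.view_as(v)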
Example #6
         multichannel=True,
         gaussian_weights=True,
     ),
     'piq.ssim':
     piq.ssim,
     'kornia.SSIM-halfloss':
     kornia.SSIM(
         window_size=11,
         reduction='mean',
     ),
     'piq.SSIM-loss':
     piq.SSIMLoss(),
     'IQA.SSIM-loss':
     IQA.SSIM(),
     'vainf.SSIM':
     vainf.SSIM(data_range=1.),
     'piqa.SSIM':
     piqa.SSIM(),
 }),
 'MS-SSIM': (2, {
     'piq.ms_ssim': piq.multi_scale_ssim,
     'piq.MS_SSIM-loss': piq.MultiScaleSSIMLoss(),
     'IQA.MS_SSIM-loss': IQA.MS_SSIM(),
     'vainf.MS_SSIM': vainf.MS_SSIM(data_range=1.),
     'piqa.MS_SSIM': piqa.MS_SSIM(),
 }),
 'LPIPS': (
     2,
     {
         'piq.LPIPS': piq.LPIPS(),
         # 'IQA.LPIPS': IQA.LPIPSvgg(),
Example #7
import torch
import torch.nn as nn
import numpy as np
from model import CAEP
from utils import Kodak, GeneralDS, compute_bpp
from torchvision.utils import make_grid, save_image
from torch.utils.data import DataLoader
import pytorch_msssim
import time
from torch.utils.data.sampler import SequentialSampler

SSIM = pytorch_msssim.SSIM().cuda()

print('Number of GPUs available: ' + str(torch.cuda.device_count()))

model_before = nn.DataParallel(CAEP(15, 16).cuda())
pretrained_state_dict = torch.load(
    "./chkpt/bpp_10_before_pruning/model.state")
current_state_dict = model_before.state_dict()
current_state_dict.update(pretrained_state_dict)
model_before.load_state_dict(current_state_dict)
model_before.eval()
print('Done Setup Model_before_pruning.')

model_after = nn.DataParallel(CAEP(15, 16).cuda())
pretrained_state_dict = torch.load("./chkpt/bpp_05/model.state")
#print(pretrained_state_dict)
current_state_dict = model_after.state_dict()
current_state_dict.update(pretrained_state_dict)
model_after.load_state_dict(current_state_dict)
model_after.eval()
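
With both checkpoints loaded, a natural next step (not shown in this excerpt) is to compare the two models. A hedged sketch of a mean-SSIM comparison on Kodak, assuming the dataset path and that it yields (img, patches, _) triples as in the training script above:

testset = Kodak('./data/Kodak')  # path is an assumption
testloader = DataLoader(testset, batch_size=1, num_workers=1)

with torch.no_grad():
    for name, m in [('before pruning', model_before), ('after pruning', model_after)]:
        total, n = 0.0, 0
        for img, patches, _ in testloader:
            x = patches[:, 0, 0].cuda()  # one 128x128 patch per image
            y, c = m(x)
            total += SSIM(x, y).item()
            n += 1
        print(name, 'mean SSIM:', total / n)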
Example #8
 def __init__(self, **kwargs):
     super().__init__(**kwargs)
     self.ssim = ssim.SSIM(data_range=1)
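
A sketch of how such a mixin might use the metric during evaluation (the surrounding class and the batch format are assumptions):

 def validation_step(self, batch, batch_idx):
     # Hypothetical: batch is (inputs, targets), both scaled to [0, 1]
     # to match data_range=1 above.
     x, y = batch
     return self.ssim(self(x), y)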
Example #9
def main():

    cudnn.benchmark = True
    # Dataset
    train_data = DatasetFromHdf5('./Data/train_Material_.h5')
    print(len(train_data))
    val_data = DatasetFromHdf5('./Data/valid_Material_.h5')
    print(len(val_data))

    # Data Loader (Input Pipeline)
    train_data_loader = DataLoader(dataset=train_data,
                                   num_workers=1,
                                   batch_size=64,
                                   shuffle=True,
                                   pin_memory=True)
    val_loader = DataLoader(dataset=val_data,
                            num_workers=1,
                            batch_size=1,
                            shuffle=False,
                            pin_memory=True)

    # Model

    model = resblock(conv_bn_relu_res_block, 10, 25, 25)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    if torch.cuda.is_available():
        model.cuda()

    # Parameters, Loss and Optimizer
    start_epoch = 0
    end_epoch = 100
    init_lr = 0.0001
    iteration = 0
    record_test_loss = 1000
    # criterion_RRMSE = torch.nn.L1Loss()
    criterion_RRMSE = rrmse_loss
    criterion_Angle = Angle_Loss
    criterion_MSE = torch.nn.MSELoss()
    criterion_SSIM = pytorch_msssim.SSIM()
    # criterion_Div = Divergence_Loss
    criterion_Div = torch.nn.KLDivLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=init_lr,
                                 betas=(0.9, 0.999),
                                 eps=1e-08,
                                 weight_decay=0.01)

    model_path = './models/'
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    loss_csv = open(os.path.join(model_path, 'loss_material.csv'), 'w+')

    log_dir = os.path.join(model_path, 'train_material.log')
    logger = initialize_logger(log_dir)

    # Resume
    resume_file = ''
    if resume_file:
        if os.path.isfile(resume_file):
            print("=> loading checkpoint '{}'".format(resume_file))
            checkpoint = torch.load(resume_file)
            start_epoch = checkpoint['epoch']
            iteration = checkpoint['iter']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])

    for epoch in range(start_epoch + 1, end_epoch):

        start_time = time.time()
        train_loss, iteration, lr = train(train_data_loader, model,
                                          criterion_MSE, criterion_RRMSE,
                                          criterion_Angle, criterion_SSIM,
                                          criterion_Div, optimizer, iteration,
                                          init_lr, end_epoch, epoch)
        test_loss, loss_angle, loss_reconstruct, loss_SSIM, loss_Div = validate(
            val_loader, model, criterion_MSE, criterion_RRMSE, criterion_Angle,
            criterion_SSIM, criterion_Div)

        # xxx_loss = validate_save(val_loader, model, criterion_MSE, criterion_RRMSE, epoch)

        save_checkpoint_material(model_path, epoch, iteration, model,
                                 optimizer)

        # print loss
        end_time = time.time()
        epoch_time = end_time - start_time
        print(
            "Epoch [%d], Iter[%d], Time: %.9f, learning rate: %.9f, Train Loss: %.9f, Test Loss: %.9f, Angle Loss: %.9f, Recon Loss: %.9f, SSIM Loss: %.9f, Div Loss: %.9f"
            % (epoch, iteration, epoch_time, lr, train_loss, test_loss,
               loss_angle, loss_reconstruct, loss_SSIM, loss_Div))

        # save loss
        record_loss(loss_csv, epoch, iteration, epoch_time, lr, train_loss,
                    test_loss)
        logger.info(
            "Epoch [%d], Iter[%d], Time: %.9f, learning rate: %.9f, Train Loss: %.9f, Test Loss: %.9f, Angle Loss: %.9f, Recon Loss: %.9f, SSIM Loss: %.9f, Div Loss: %.9f"
            % (epoch, iteration, epoch_time, lr, train_loss, test_loss,
               loss_angle, loss_reconstruct, loss_SSIM, loss_Div))
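
train and validate are defined elsewhere; a hedged sketch of how criterion_SSIM typically enters the combined objective inside them, following the 1 - SSIM convention used throughout these examples (the pairing with criterion_RRMSE and the 0.1 weight are assumptions):

# Inside a hypothetical training step:
output = model(images)
loss_ssim = 1 - criterion_SSIM(output, labels)
loss = criterion_RRMSE(output, labels) + 0.1 * loss_ssim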
Example #10
################################################################################
##################################### Model ####################################
################################################################################

# Create Enhancer
model = cnn_enhancer.ImageEnhancerCNN(input_channels, num_filters, num_layers,
                                      use_4by4_conv)
model.to(device)
print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

ssim_module = None
if use_ssim_loss:
    ssim_module = pytorch_msssim.SSIM(data_range=1.0,
                                      win_size=11,
                                      win_sigma=1.5,
                                      K=(0.01, 0.03))
################################################################################
################################### Training ###################################
################################################################################

# Train
all_train_loss = []
all_val_loss = []

for epoch in range(epochs):
    train_loss = 0
    val_loss = 0

    # Training Loop
    model.train()
Example #11
npImg1 = cv2.imread("einstein.png")

img1 = torch.from_numpy(np.rollaxis(npImg1, 2)).float().unsqueeze(0) / 255.0
img2 = torch.rand(img1.size())

if torch.cuda.is_available():
    img1 = img1.cuda()
    img2 = img2.cuda()

# Variable is deprecated; set requires_grad on the tensors directly
img1.requires_grad = False
img2.requires_grad = True

# Functional: pytorch_msssim.msssim(img1, img2, window_size = 11, size_average = True)
msssim_value = pytorch_msssim.msssim(img1, img2)
print("Initial msssim:", msssim_value)

# Module: pytorch_msssim.SSIM(window_size = 11, size_average = True)
msssim_loss = pytorch_msssim.SSIM()

optimizer = optim.Adam([img2], lr=0.01)
msssim_out = msssim_loss(img1, img2)
'''
while msssim_value < 0.95:
    optimizer.zero_grad()
    msssim_out = -msssim_loss(img1, img2)
    msssim_value = -msssim_out.item()
    print(msssim_value)
    msssim_out.backward()
    optimizer.step()
'''
Example #12
    def forward(self, flabel, image, ds_image, infer=False):
        # Encode Inputs
        input_flabel, real_image, ds_image = self.encode_input(flabel, image, ds_image)  

        # input to G: downsampled | compact image
        if self.opt.comp_type == 'compG':  # use compG
            if not self.opt.no_seg:  # with seg
                compG_input = torch.cat((input_flabel, real_image), dim=1)
            else:  # no seg
                compG_input = real_image
            comp_image = self.compG.forward(compG_input)

            ### tensor-level bilinear
            upsample = torch.nn.Upsample(scale_factor=self.opt.alpha, mode='bilinear')
            up_image = upsample(comp_image)

        else:  # use bicubic downsampling (ds)
            up_image = ds_image

        if not self.opt.no_seg:  # seg
            if self.opt.comp_type != 'none':  # seg, ds | comp_image
                input_fconcat = torch.cat((input_flabel, up_image), dim=1)
            else:  # no ds, but seg
                input_fconcat = input_flabel
        else:  # no seg
            input_flabel = None
            if self.opt.comp_type != 'none':  # ds (ds | comp_image)
                input_fconcat = up_image

        # add compact image, so that G tries to find the best residual
        res = self.netG.forward(input_fconcat)
        fake_image_f = res + up_image

        # Fake Detection and Loss        
        pred_fake_pool_f = self.discriminate(input_flabel, fake_image_f, use_pool=True)
        loss_D_fake = self.criterionGAN(pred_fake_pool_f, False)

        # Real Detection and Loss                
        pred_real_f = self.discriminate(input_flabel, real_image)
        loss_D_real = self.criterionGAN(pred_real_f, True)

        # GAN loss (Fake Passability Loss)
        if input_flabel is not None:
          inputD_concat = torch.cat((input_flabel, fake_image_f), dim=1)
        else:
          inputD_concat = fake_image_f        
        pred_fake_f = self.netD.forward(inputD_concat)
        loss_G_GAN = self.criterionGAN(pred_fake_f, True)
        
        # GAN feature matching loss
        loss_G_GAN_Feat = 0
        if not self.opt.no_ganFeat_loss:
            feat_weights = 4.0 / (self.opt.n_layers_D + 1)
            D_weights = 1.0 / self.opt.num_D
            for i in range(self.opt.num_D):
                for j in range(len(pred_fake_f[i])-1):
                    loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionFeat(pred_fake_f[i][j], pred_real_f[i][j].detach()) * self.opt.lambda_feat
                   
        # VGG feature matching loss
        loss_G_VGG = 0
        if not self.opt.no_vgg_loss:
            loss_G_VGG = (self.criterionVGG(fake_image_f, real_image)) * self.opt.lambda_feat            
            
        # l1 loss between x and x'
        loss_G_DIS = 0
        criterionDIS = torch.nn.L1Loss()
        loss_G_DIS = criterionDIS(fake_image_f, real_image) * self.opt.lambda_feat * 2
                
        # SSIM loss (negated: maximizing the SSIM similarity minimizes -SSIM)
        ssim_loss = pytorch_msssim.SSIM()
        loss_G_SSIM = -ssim_loss(real_image, fake_image_f)

        # Only return the fake_B image if necessary to save BW
        return [[loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_G_DIS, loss_G_SSIM,
                 loss_D_real, loss_D_fake],
                None if not infer else real_image,
                None if not infer else input_flabel,
                None if not infer else fake_image_f,
                None if not infer else res,
                None if not infer else comp_image,
                None if not infer else up_image]
Example #13
import pandas as pd
import numpy as np
import torch
from GAN_hessian_compute import hessian_compute
# from hessian_analysis_tools import scan_hess_npz, plot_spectra, average_H, compute_hess_corr, plot_consistency_example
# from hessian_axis_visualize import vis_eigen_explore, vis_eigen_action, vis_eigen_action_row, vis_eigen_explore_row
from GAN_utils import loadStyleGAN2, StyleGAN2_wrapper, loadBigGAN, BigGAN_wrapper, loadPGGAN, PGGAN_wrapper
import matplotlib.pylab as plt
import matplotlib
#%%
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import lpips
ImDist = lpips.LPIPS(net="squeeze").cuda()
# SSIM
import pytorch_msssim

# Note: SSIM is a similarity score (the higher, the more similar), so to
# conform to the distance convention we use 1 - SSIM as a proxy for distance.
D = pytorch_msssim.SSIM()
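
Following that comment's convention, a small helper (an assumption, not part of the original) that turns the similarity score into a distance:

def ssim_dist(im1, im2):
    # SSIM is a similarity; 1 - SSIM behaves like a distance.
    return 1 - D(im1, im2)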


# L2 / MSE
def MSE(im1, im2):
    return (im1 - im2).pow(2).mean(dim=[1, 2, 3])


# L1 / MAE
# def L1(im1, im2):
#     return (im1 - im2).abs().mean(dim=[1,2,3])
# Note L1 is less proper for this task, as it's farther away from a distance square function.
#%% Utility functions to quantify relationship between 2 eigvals or 2 Hessians.
def spectra_cmp(eigvals1, eigvals2, show=True):
    cc = np.corrcoef((eigvals1), (eigvals2))[0, 1]
    logcc = np.corrcoef(np.log10(np.abs(eigvals1) + 1E-8),
Example #14
def run(config, dataset_path, IPF_path, image_list, idx, fileText, args):

    # Create a directory for saving adversarial images
    adv_path_no = 'Results/{}/{}_{}_{}/'.format(args.method,
                                                str(args.strength),
                                                args.dataset, args.adv_model)
    if not os.path.isdir(adv_path_no):
        os.makedirs(adv_path_no)

    # Structure loss function
    l2_loss = nn.MSELoss()
    ssim_loss = pytorch_msssim.SSIM()

    # Using GPU
    if config.GPU >= 0:
        with torch.cuda.device(config.GPU):
            config.model.cuda()
            l2_loss.cuda()
            ssim_loss.cuda()

    # Setup optimizer
    optimizer = optim.Adam(config.model.parameters(), lr=config.LR)

    # Load the classifier for attacking
    if args.adv_model == 'resnet18':
        classifier = models.resnet18(pretrained=True)
    elif args.adv_model == 'resnet50':
        classifier = models.resnet50(pretrained=True)
    elif args.adv_model == 'alexnet':
        classifier = models.alexnet(pretrained=True)

    classifier.cuda()
    classifier.eval()

    # Freeze the parameters of the classifier under attack to not be updated
    for param in classifier.parameters():
        param.requires_grad = False

    # The name of the chosen image
    img_name = image_list[idx].split('/')[-1]

    # Load and Pre-processing the clean and filtered image
    x = processImage(dataset_path, img_name)
    gt_enh = processImage(IPF_path, img_name)

    # Compute the residual perturbation
    gt_noise = gt_enh - x

    # Perform inference on the clean image
    class_x, prob_class_x, prob_x, logit_x, semantic_vec, super_x = PreidictLabel(
        x, classifier)

    maxIters = 3000

    for it in tqdm(range(maxIters)):

        with autograd.detect_anomaly():

            noise = config.forward(x, gt_noise, config)

            # Enhance adversarial image
            enh = (x + noise).clamp(min=0, max=1)

            # Perform inference on the generated adversarial image
            class_enh, prob_class_enh, prob_enh, logit_enh, _, super_adv = PreidictLabel(
                enh, classifier)

            # Computing structure and semantic adversarial losses
            loss0 = l2_loss(noise, gt_noise)
            loss1 = 1 - ssim_loss((noise + 1) / 2., (gt_noise + 1) / 2.)
            loss2 = AdvLoss(logit_enh,
                            class_x,
                            semantic_vec,
                            is_targeted=False,
                            num_classes=1000)

            # Normalized MSE
            loss3 = loss0.cpu().data.numpy().item(0) / l2_loss(
                gt_noise,
                torch.zeros(1, 3, 224, 224).cuda())

            loss = loss0 + 0.01 * loss1 + loss2

            # backward
            optimizer.zero_grad()
            loss.backward()
            if config.clip is not None:
                torch.nn.utils.clip_grad_norm_(config.model.parameters(),
                                               config.clip)
            optimizer.step()

            # Save the generated adversarial image
            cv2.imwrite('{}{}'.format(adv_path_no, img_name),
                        recreate_image(enh))

            adv_img = processImage(adv_path_no, img_name)
            class_adv, _, _, _, _, super_adv = PreidictLabel(
                adv_img, classifier)

            if args.method == 'Nonlinear_Detail':
                if (super_x != super_adv and class_x != class_adv
                        and loss3 < 0.04 and it > 2500):
                    break
            elif args.method == 'Log':
                if (super_x != super_adv and class_x != class_adv
                        and loss3 < 0.003 and it > 2500):
                    break
            elif args.method == 'Linear_Detail':
                if (super_x != super_adv and class_x != class_adv
                        and loss3 < 0.04 and it > 2500):
                    break
            elif args.method == 'Gamma':
                if (super_x != super_adv and class_x != class_adv
                        and loss3 < 0.0005 and it > 2500):
                    break

            #print(img_name, it+1, super_x, super_adv, class_x.cpu().data.numpy().item(0), class_enh.cpu().data.numpy().item(0), class_adv.cpu().data.numpy().item(0),  loss0.cpu().data.numpy().item(0), loss3.cpu().data.numpy().item(0), loss1.cpu().data.numpy().item(0), loss2.cpu().data.numpy().item(0))

    text = '{}\tItrs:{}\tSemantic labels, Clean:{}\t Adversarial:{}\t Categorical labels, Clean:{}\t Adversarial:{}\t L_2 loss:{:.5f}\t SSIM loss:{:.5f}\t Adv loss:{:.5f}\n'.format(
        img_name, it + 1, super_x, super_adv, class_x, class_adv,
        loss0.cpu().data.numpy().item(0),
        loss1.cpu().data.numpy().item(0),
        loss2.cpu().data.numpy().item(0))

    fileText.write(text)
    return adv_img