Ejemplo n.º 1
0
    def __init__(self, opt):
        """Set up the rotation model's networks, placeholder tensors and losses.

        Args:
            opt: option namespace. Read here: ``checkpoints_dir``, ``name``,
                ``batchSize``, ``crop_size``, ``isTrain``, ``gan_mode``,
                ``no_vgg_loss`` and ``use_vae``.
        """
        super(RotateModel, self).__init__()
        self.opt = opt
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

        # Tensor constructors for the device reported by use_gpu().
        if self.use_gpu():
            self.FloatTensor = torch.cuda.FloatTensor
            self.ByteTensor = torch.cuda.ByteTensor
        else:
            self.FloatTensor = torch.FloatTensor
            self.ByteTensor = torch.ByteTensor

        # Zero-filled buffers of shape (batchSize, 3, crop_size, crop_size).
        buffer_shape = (opt.batchSize, 3, opt.crop_size, opt.crop_size)
        self.real_image = torch.zeros(*buffer_shape)
        self.input_semantics = torch.zeros(*buffer_shape)

        (self.netG, self.netD, self.netE,
         self.netD_rotate) = self.initialize_networks(opt)

        # Loss criteria are only needed for training.
        if not opt.isTrain:
            return

        self.criterionGAN = networks.GANLoss(
            opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
        self.criterionFeat = torch.nn.L1Loss()
        if not opt.no_vgg_loss:
            self.criterionVGG = networks.VGGLoss(self.opt)
        if opt.use_vae:
            self.KLDLoss = networks.KLDLoss()
Ejemplo n.º 2
0
    def __init__(self, opt):
        """Create the eight-network model and, when training, its criteria.

        Args:
            opt: option namespace. ``initialize_networks`` must return eight
                networks here (G, D, E, IG, FE, B, D2, SIG); loss objects are
                created only when ``opt.isTrain`` is set.
        """
        super().__init__()
        self.opt = opt
        if self.use_gpu():
            self.FloatTensor = torch.cuda.FloatTensor
            self.ByteTensor = torch.cuda.ByteTensor
        else:
            self.FloatTensor = torch.FloatTensor
            self.ByteTensor = torch.ByteTensor

        nets = self.initialize_networks(opt)
        (self.netG, self.netD, self.netE, self.netIG,
         self.netFE, self.netB, self.netD2, self.netSIG) = nets

        if not opt.isTrain:
            return

        # Adversarial and feature-matching criteria.
        self.criterionGAN = networks.GANLoss(
            opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
        self.criterionFeat = torch.nn.L1Loss()
        self.criterionGANFeat = networks.GANFeatLoss(self.opt)

        # Criteria toggled by command-line flags.
        if not opt.no_vgg_loss:
            self.criterionVGG = networks.VGGLoss(self.opt)
        if opt.use_vae:
            self.KLDLoss = networks.KLDLoss()
        if not opt.no_orient_loss:
            self.criterionOrient = networks.L1OLoss(self.opt)

        self.criterionStyleContent = networks.StyleContentLoss(opt)
        # Loss on the RGB background.
        self.criterionBackground = networks.RGBBackgroundL1Loss()

        # Pixel-space colour criteria.
        self.criterionRGBL1 = nn.L1Loss()
        self.criterionRGBL2 = nn.MSELoss()
        self.criterionLabL1 = networks.LabColorLoss(opt)

        if opt.unpairTrain:
            self.criterionHairAvgLab = networks.HairAvgLabLoss(opt)

    def __init__(self, opt):
        """Create networks (G, D, E, F) and the flag-selected training losses.

        Args:
            opt: option namespace; loss objects are created only when
                ``opt.isTrain`` is set, each gated by its own flag.
        """
        super().__init__()
        self.opt = opt
        if self.use_gpu():
            self.FloatTensor = torch.cuda.FloatTensor
            self.ByteTensor = torch.cuda.ByteTensor
        else:
            self.FloatTensor = torch.FloatTensor
            self.ByteTensor = torch.ByteTensor

        self.netG, self.netD, self.netE, self.netF = \
            self.initialize_networks(opt)

        if not opt.isTrain:
            return

        self.criterionGAN = networks.GANLoss(
            opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
        self.criterionFeat = torch.nn.L1Loss()
        if not opt.no_vgg_loss:
            self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
        # The Booru loss is only built when the VGG loss is disabled.
        if opt.booru_loss and opt.no_vgg_loss:
            self.criterionBooru = networks.BooruLoss(self.opt.gpu_ids)
        if opt.use_vae:
            self.KLDLoss = networks.KLDLoss()
        if opt.L2_loss:
            self.L2Loss = torch.nn.MSELoss()
        if opt.L1_loss:
            self.L1Loss = torch.nn.L1Loss()
        if opt.hsv_tv:
            self.hsvTVLoss = networks.HSVTVLoss()
        if opt.high_sv:
            self.HighSVLoss = networks.HighSVLoss()
Ejemplo n.º 4
0
    def __init__(self, opt):
        """Create the networks, training losses and optionally a BERT encoder.

        Args:
            opt: option namespace. ``two_step_model`` adds a mask loss,
                ``no_vgg_loss`` disables the VGG loss, and ``embed_captions``
                without ``use_precomputed_captions`` loads a pretrained
                ``bert-base-uncased`` model in eval mode.
        """
        super().__init__()
        self.opt = opt
        if self.use_gpu():
            self.FloatTensor = torch.cuda.FloatTensor
            self.ByteTensor = torch.cuda.ByteTensor
        else:
            self.FloatTensor = torch.FloatTensor
            self.ByteTensor = torch.ByteTensor

        self.netGbg, self.netG, self.netD = self.initialize_networks(opt)

        if opt.isTrain:
            self.criterionGAN = networks.GANLoss(
                opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
            self.criterionFeat = torch.nn.L1Loss()
            if opt.two_step_model:
                # Mask loss computed on logits (BCE with built-in sigmoid).
                self.criterionMask = torch.nn.BCEWithLogitsLoss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)

        needs_bert = opt.embed_captions and not opt.use_precomputed_captions
        if needs_bert:
            print('Loading pretrained BERT model...')
            # Imported lazily so transformers is only required when used.
            from transformers import BertModel
            pretrained = BertModel.from_pretrained(
                'bert-base-uncased', output_hidden_states=True)
            self.bert = pretrained.eval()
            if self.use_gpu():
                self.bert = self.bert.cuda()
Ejemplo n.º 5
0
    def __init__(self, opt):
        """Create G/D/E networks, the AMP flag and (when training) the losses.

        Args:
            opt: option namespace. ``use_amp`` enables mixed precision only
                together with the module-level ``AMP`` flag and ``isTrain``.
        """
        super().__init__()
        self.opt = opt
        self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() \
            else torch.FloatTensor
        # ``self.ByteTensor`` holds BoolTensor on torch >= 1.2 (where bool
        # tensors were introduced) and ByteTensor on older releases; the
        # attribute keeps its historical name for callers. Compare
        # (major, minor) together: checking the minor number alone wrongly
        # treats e.g. torch 2.0 / 2.1 as older than 1.2.
        version_parts = torch.__version__.split('.')
        torch_major_minor = (int(version_parts[0]), int(version_parts[1]))
        if torch_major_minor >= (1, 2):
            self.ByteTensor = torch.cuda.BoolTensor if self.use_gpu() \
                else torch.BoolTensor
        else:
            self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() \
                else torch.ByteTensor

        self.netG, self.netD, self.netE = self.initialize_networks(opt)

        # Mixed precision needs the module-level AMP flag, the user option
        # and training mode all at once.
        self.amp = bool(AMP and opt.use_amp and opt.isTrain)

        # set loss functions
        if opt.isTrain:
            self.criterionGAN = networks.GANLoss(opt.gan_mode,
                                                 tensor=self.FloatTensor,
                                                 opt=self.opt)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
            if opt.use_vae:
                self.KLDLoss = networks.KLDLoss()
Ejemplo n.º 6
0
    def __init__(self, opt):
        """Create G/D/E networks and, when training, the loss criteria.

        Args:
            opt: option namespace; criteria are built only when
                ``opt.isTrain`` is set. Style-feature, latent-style and Gram
                losses are added only when their ``lambda_*`` weight is > 0.
        """
        super().__init__()
        self.opt = opt
        if self.use_gpu():
            self.FloatTensor = torch.cuda.FloatTensor
            self.ByteTensor = torch.cuda.ByteTensor
        else:
            self.FloatTensor = torch.FloatTensor
            self.ByteTensor = torch.ByteTensor

        self.netG, self.netD, self.netE = self.initialize_networks(opt)

        if not opt.isTrain:
            return

        self.criterionGAN = networks.GANLoss(
            opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
        self.criterionFeat = nn.L1Loss()
        self.criterionL1 = nn.L1Loss()
        self.criterionL2 = nn.MSELoss()
        # Stored as a callable attribute (not instantiated here).
        self.criterionOpenEDS = MSECalculator.calculate_mse_for_tensors
        if not opt.no_vgg_loss:
            self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
        if opt.lambda_style_feat > 0:
            # MSE on the style feature maps.
            self.criterion_style_feat = nn.MSELoss()
        if opt.lambda_style_w > 0:
            # MSE on the latent style code.
            self.criterion_style_w = nn.MSELoss()
        if opt.lambda_gram > 0:
            self.criterion_gram = networks.StyleLoss()
        self.reset_loss_log()
Ejemplo n.º 7
0
    def define_losses(self):
        """Create loss criteria and loss-name lists for training/fine-tuning.

        Does nothing unless ``self.isTrain`` or ``self.opt.finetune`` is set.
        """
        opt = self.opt
        if not (self.isTrain or opt.finetune):
            return

        # Image pool constructed with size 0.
        self.fake_pool = ImagePool(0)
        self.old_lr = opt.lr

        self.criterionGAN = networks.GANLoss(
            opt.gan_mode, tensor=self.Tensor, opt=opt)
        self.criterionFeat = torch.nn.L1Loss()
        self.criterionFlow = networks.MaskedL1Loss()
        if not opt.no_vgg_loss:
            self.criterionVGG = networks.VGGLoss(opt, self.gpu_ids)

        # Loss names, so individual terms can be reported separately.
        generator_names = [
            'G_GAN', 'G_GAN_Feat', 'G_VGG',
            'Gf_GAN', 'Gf_GAN_feat',
            'GT_GAN', 'GT_GAN_Feat',
            'F_Flow', 'F_Warp', 'W',
        ]
        discriminator_names = [
            'D_real', 'D_fake',
            'Df_real', 'Df_fake',
            'DT_real', 'DT_fake',
        ]
        self.loss_names_G = generator_names
        self.loss_names_D = discriminator_names
        self.loss_names = generator_names + discriminator_names
Ejemplo n.º 8
0
    def __init__(self, opt):
        """Create G/D/E networks, per-dataset background ids and losses.

        Args:
            opt: option namespace. When ``opt.isTrain`` is set, a
                ``background_list`` of label ids is chosen for the
                'cityscapes', 'coco' and 'ade20k' dataset modes (other modes
                leave it unset) and the loss criteria are created.
        """
        super().__init__()
        self.opt = opt
        if self.use_gpu():
            self.FloatTensor = torch.cuda.FloatTensor
            self.ByteTensor = torch.cuda.ByteTensor
        else:
            self.FloatTensor = torch.FloatTensor
            self.ByteTensor = torch.ByteTensor

        self.netG, self.netD, self.netE = self.initialize_networks(opt)

        if not opt.isTrain:
            return

        # Label ids treated as background, keyed by dataset mode.
        background_by_dataset = {
            'cityscapes': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, 16,
                           21, 23],
            'coco': [114, 115, 116, 117, 118, 124, 126, 127, 128, 135, 136,
                     144, 145, 147, 148, 149, 151, 154, 155, 157, 158, 159,
                     160, 161, 162, 164, 169, 170, 171, 172, 173, 174, 175,
                     176, 177, 178, 182],
            'ade20k': [0, 1, 2, 3, 4, 5, 6, 9, 11, 13, 16, 17, 21, 25, 26, 29,
                       32, 46, 48, 51, 52, 54, 60, 61, 68, 69, 81, 91, 94, 101,
                       128],
        }
        dataset_mode = self.opt.dataset_mode
        if dataset_mode in background_by_dataset:
            self.background_list = background_by_dataset[dataset_mode]

        self.criterionGAN = networks.GANLoss(
            opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
        self.criterionFeat = torch.nn.L1Loss()
        if not opt.no_vgg_loss:
            self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
        if opt.use_vae:
            self.KLDLoss = networks.KLDLoss()
Ejemplo n.º 9
0
    def __init__(self, opt):
        """Create the networks plus optional perturbation / text modules.

        Args:
            opt: option namespace.
                - ``perturbation`` builds a ``PerturbationNet`` with pixel and
                  VGG criteria.
                - ``manipulation`` loads a pickled vocabulary and a text
                  encoder checkpoint.
                - Otherwise, when ``isTrain`` is set, each loss is created
                  according to its own flag.

        Note:
            This model always uses CUDA tensor types and moves its extra
            modules to the GPU unconditionally.
        """
        super().__init__()
        self.opt = opt
        self.FloatTensor = torch.cuda.FloatTensor
        self.ByteTensor = torch.cuda.ByteTensor
        self.perturbation = opt.perturbation

        self.netG, self.netD, self.netE = self.initialize_networks(opt)

        self.generator = opt.netG

        self.noise_range = opt.noise_range
        if self.perturbation:
            self.netP = PerturbationNet(opt)
            self.netP.cuda()
        if self.opt.manipulation:
            # Use a context manager so the vocabulary file is closed instead
            # of leaking the handle.
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # vocab files from trusted sources.
            with open(opt.vocab_path, 'rb') as vocab_file:
                self.vocab = pickle.load(vocab_file)
            self.txt_enc = EncoderText(len(self.vocab), 300, 1024, 1)
            # Text-encoder weights are taken from checkpoint['model'][1],
            # loaded onto the CPU first and then moved to the GPU.
            self.txt_enc.load_state_dict(
                torch.load(opt.vse_enc_path, map_location='cpu')['model'][1])
            self.txt_enc.eval().cuda()

        # set loss functions
        if self.perturbation:
            self.criterionPix = torch.nn.L1Loss()
            self.criterionVGG = networks.VGGLoss(self.opt.gpu)
        elif opt.isTrain:
            self.loss_l1pix = opt.l1pix_loss
            self.loss_gan = not opt.no_disc
            self.loss_ganfeat = not opt.no_ganFeat_loss
            self.loss_vgg = not opt.no_vgg_loss

            if self.loss_l1pix:
                self.criterionPix = torch.nn.L1Loss()
            if self.loss_gan:
                self.criterionGAN = networks.GANLoss(opt.gan_mode,
                                                     tensor=self.FloatTensor,
                                                     opt=self.opt)
            if self.loss_ganfeat:
                self.criterionFeat = torch.nn.L1Loss()
            if self.loss_vgg:
                self.criterionVGG = networks.VGGLoss(self.opt.gpu)
Ejemplo n.º 10
0
 def set_losses(self):
     """Create the training criteria from ``self.opt`` (no-op unless isTrain)."""
     opt = self.opt
     if not opt.isTrain:
         return
     self.criterionGAN = networks.GANLoss(
         opt.gan_mode, tensor=self.FloatTensor, opt=opt)
     self.criterionFeat = torch.nn.L1Loss()
     if not opt.no_vgg_loss:
         self.criterionVGG = networks.VGGLoss(opt.gpu_ids)
     if opt.use_vae:
         self.KLDLoss = networks.KLDLoss()
Ejemplo n.º 11
0
    def __init__(self, opt):
        """Create the audio-visual model's networks and training criteria.

        Args:
            opt: option namespace. When ``isTrain`` is set:
                - classification and feature-matching criteria are always
                  built;
                - GAN/VGG(-Face) criteria are skipped when only
                  ``train_recognition`` or ``train_sync`` is requested;
                - ``softmax_contrastive`` and ``disentangle`` add their own
                  losses.
        """
        super(AvModel, self).__init__()
        self.opt = opt
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() \
            else torch.FloatTensor
        self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() \
            else torch.ByteTensor
        self.netG, self.netD, self.netA, self.netA_sync, self.netV, self.netE = \
            self.initialize_networks(opt)

        # set loss functions
        if opt.isTrain:
            self.loss_cls = CrossEntropyLoss()
            self.criterionFeat = torch.nn.L1Loss()

            if opt.softmax_contrastive:
                self.criterionSoftmaxContrastive = \
                    networks.SoftmaxContrastiveLoss()

            # Generation losses are not needed when only the recognition or
            # sync branch is being trained.
            if not (opt.train_recognition or opt.train_sync):
                self.criterionGAN = networks.GANLoss(opt.gan_mode,
                                                     tensor=self.FloatTensor,
                                                     opt=self.opt)

                if not opt.no_vgg_loss:
                    self.criterionVGG = networks.VGGLoss(self.opt)

                if opt.vgg_face:
                    # Perceptual loss computed with a VGGFace19 backbone.
                    self.VGGFace = VGGFace19(self.opt)
                    self.criterionVGGFace = networks.VGGLoss(
                        self.opt, self.VGGFace)

            if opt.disentangle:
                self.criterionLogSoftmax = networks.L2SoftmaxLoss()
Ejemplo n.º 12
0
    def __init__(self, opt):
        """Create G/D/E networks and, when training, the standard criteria.

        Args:
            opt: option namespace; ``no_vgg_loss`` and ``use_vae`` toggle the
                VGG and KL-divergence losses.
        """
        super().__init__()
        self.opt = opt
        if self.use_gpu():
            self.FloatTensor = torch.cuda.FloatTensor
            self.ByteTensor = torch.cuda.ByteTensor
        else:
            self.FloatTensor = torch.FloatTensor
            self.ByteTensor = torch.ByteTensor

        self.netG, self.netD, self.netE = self.initialize_networks(opt)

        if not opt.isTrain:
            return

        self.criterionGAN = networks.GANLoss(
            opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
        self.criterionFeat = torch.nn.L1Loss()
        if not opt.no_vgg_loss:
            self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
        if opt.use_vae:
            self.KLDLoss = networks.KLDLoss()
Ejemplo n.º 13
0
    def __init__(self, opt):
        """Create G/D/E networks, training criteria and a segmentation model.

        Args:
            opt: option namespace. When ``isTrain`` is set, the pretrained
                segmentation model is loaded from
                ``opt.segmentation_model_path`` if that option exists and is
                non-empty, otherwise from the original hard-coded path.
        """
        super().__init__()
        self.opt = opt
        self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() \
            else torch.FloatTensor
        self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() \
            else torch.ByteTensor

        self.netG, self.netD, self.netE = self.initialize_networks(opt)

        # set loss functions
        if opt.isTrain:
            self.criterionGAN = networks.GANLoss(opt.gan_mode,
                                                 tensor=self.FloatTensor,
                                                 opt=self.opt)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
            if opt.use_vae:
                self.KLDLoss = networks.KLDLoss()
            # The path used to be hard-coded to one user's home directory;
            # allow overriding it while keeping that path as the default so
            # existing setups behave the same.
            default_segmentation_path = (
                '/home/qasima/venv_spade/SPADE/checkpoints/isic_fold_1/'
                'model_epochs100_percent100_isic_256')
            segmentation_path = (getattr(opt, 'segmentation_model_path', None)
                                 or default_segmentation_path)
            # NOTE(review): torch.load unpickles a whole module object; only
            # load checkpoints from trusted sources.
            self.segmentation_model = torch.load(segmentation_path)
Ejemplo n.º 14
0
    def __init__(self, opt):
        """Create the CT/MR generators, two discriminators and training losses.

        Args:
            opt: option namespace. When ``isTrain`` is set the GAN criterion is
                always built; VGG, L1, SSIM, gradient-difference and cycle
                losses are each skipped via their ``no_*`` flag.
        """
        super().__init__()
        self.opt = opt
        if self.use_gpu():
            self.FloatTensor = torch.cuda.FloatTensor
            self.ByteTensor = torch.cuda.ByteTensor
        else:
            self.FloatTensor = torch.FloatTensor
            self.ByteTensor = torch.ByteTensor

        (self.netG_for_CT, self.netD_aligned,
         self.netG_for_MR, self.netD_unaligned) = self.initialize_networks(opt)

        if not opt.isTrain:
            return

        self.criterionGAN = networks.GANLoss(
            opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
        if not opt.no_vgg_loss:
            self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
        if not opt.no_L1_loss:
            self.criterionL1 = torch.nn.L1Loss()
        if not opt.no_SSIMLoss:
            self.criterionSSIM = networks.SSIMLoss()
        if not opt.no_GradientDifferenceLoss:
            self.criterionGDL = networks.GradientDifferenceLoss()
        if not opt.no_cycle_loss:
            # Cycle loss, measured with L1.
            self.criterionCYC = torch.nn.L1Loss()
Ejemplo n.º 15
0
    def __init__(self, config):
        """Build the VAE, generator and (when training) discriminator/B networks.

        Args:
            config: option namespace, stored as ``self.cfg``. Fields read here
                include:
                - data: ``batch_size``, ``dataroot``, ``resize_or_crop``,
                  ``label_nc``, ``input_nc``, ``output_nc``;
                - checkpoints/devices: ``vae_path``, ``gpu_ids``;
                - generator/discriminator hyper-parameters;
                - losses: ``pool_size``, ``lsgan``, ``isTrain``.
        """
        super().__init__()
        self.cfg = config
        # Generator loss-term weights, normalised so they sum to 1.
        self.loss_g_weights = np.array([1.0, 1.0, 10, 10, 10, 10])
        self.loss_g_weights /= self.loss_g_weights.sum()

        self.batch_size = self.cfg.batch_size
        self.dataset_path = self.cfg.dataroot
        self.num_workers = 32
        # Misspelled alias kept for backward compatibility with code that
        # reads ``num_wrokers``.
        self.num_wrokers = self.num_workers

        if self.cfg.resize_or_crop != "none" or self.cfg.isTrain is False:
            # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True

        # Mask input channels: label_nc if set, otherwise input_nc.
        mask_channels = (
            self.cfg.label_nc if self.cfg.label_nc != 0 else self.cfg.input_nc
        )

        # vae network
        self.netVAE = networks.define_VAE(input_nc=mask_channels)
        # map_location="cpu" lets a GPU-saved checkpoint load on CPU-only
        # hosts; load_state_dict then copies the weights into netVAE's own
        # parameters, so the network's device is unchanged.
        vae_checkpoint = torch.load(self.cfg.vae_path, map_location="cpu")
        self.netVAE.load_state_dict(vae_checkpoint["vae"])
        self.vae_lambda = 2.5

        # generator network
        self.netG = networks.define_G(
            mask_channels,
            self.cfg.output_nc,  # image channels
            self.cfg.ngf,  # gen filters in first conv layer
            self.cfg.netG,  # global or local
            self.cfg.n_downsample_global,  # num of downsampling layers in netG
            self.cfg.n_blocks_global,  # num of residual blocks
            self.cfg.n_local_enhancers,  # ignored
            self.cfg.n_blocks_local,  # ignored
            self.cfg.norm,  # instance normalization or batch normalization
        )
        # discriminator network
        if self.cfg.isTrain:
            # Sigmoid output only when LSGAN is disabled.
            use_sigmoid = self.cfg.lsgan is False
            # D sees the mask concatenated with the image.
            netD_input_nc = mask_channels + self.cfg.output_nc
            self.netD = networks.define_D(
                netD_input_nc,
                self.cfg.ndf,  # filters in first conv layer
                self.cfg.n_layers_D,
                self.cfg.norm,
                use_sigmoid,
                self.cfg.num_D,
                getIntermFeat=self.cfg.ganFeat_loss,
            )
            netB_input_nc = self.cfg.output_nc * 2
            self.netB = networks.define_B(
                netB_input_nc, self.cfg.output_nc, 32, 3, 3, self.cfg.norm
            )
        # loss functions
        self.use_pool = self.cfg.pool_size > 0
        if self.cfg.pool_size > 0:
            self.fake_pool = ImagePool(self.cfg.pool_size)

        self.criterionGAN = networks.GANLoss(use_lsgan=self.cfg.lsgan)
        self.criterionFeat = torch.nn.L1Loss()
        self.criterionVGG = networks.VGGLoss(self.cfg.gpu_ids)
Ejemplo n.º 16
0
    def __init__(self, opt):
        """Create two G/D/E network sets, a Canny module, losses and extras.

        Args:
            opt: option namespace.
                - ``isTrain`` builds the GAN/feature/VGG/KL criteria.
                - ``cls`` builds a WideResNet classifier, optionally restored
                  from ``opt.load``.
                - ``cnn_edge`` builds a ZipNet / BlurZipNet edge network,
                  optionally restored from ``opt.cnnedge_load``.

        Raises:
            Exception: "Could not resume" when a load path is given that is
                not a file and the derived
                ``<dataset><cls_model>_epoch_<start_epoch>.pt`` file inside it
                does not exist either.
        """
        super().__init__()
        self.opt = opt
        self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() \
            else torch.FloatTensor
        self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() \
            else torch.ByteTensor

        self.netG, self.netD, self.netE, self.netG2, self.netD2, self.netE2 = self.initialize_networks(
            opt)

        # Canny edge module configured from the sigma/threshold options.
        self.canny_net = backward_canny.Canny_Net(opt.sigma,
                                                  opt.high_threshold,
                                                  opt.low_threshold,
                                                  opt.robust_threshold)
        if self.use_gpu():
            self.canny_net.cuda()

        # set loss functions
        if opt.isTrain:
            self.criterionGAN = networks.GANLoss(opt.gan_mode,
                                                 tensor=self.FloatTensor,
                                                 opt=self.opt)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
            if opt.use_vae:
                self.KLDLoss = networks.KLDLoss()

        # Optional classifier network.
        if opt.cls:
            # NOTE(review): `pgd` is imported but not used in this method.
            from adv import pgd, wrn

            # Create model
            if opt.cls_model == 'wrn':
                # The second positional argument (10) is presumably the
                # number of classes; TODO confirm against wrn.WideResNet.
                self.net = wrn.WideResNet(opt.layers,
                                          10,
                                          opt.widen_factor,
                                          dropRate=opt.droprate)
            else:
                # NOTE(review): assert is stripped under `python -O`, so this
                # validation silently disappears; consider raising instead.
                assert False, opt.cls_model + ' is not supported.'

            if len(opt.gpu_ids) > 0:
                assert (torch.cuda.is_available())
                self.net.cuda()

            # Restore model if desired
            if opt.load != '':
                # Wrapped in IdentityMapping before loading, presumably so
                # the checkpoint's key names match; TODO confirm.
                self.net = IdentityMapping(self.net)
                if os.path.isfile(opt.load):
                    # opt.load points directly at a checkpoint file.
                    self.net.load_state_dict(torch.load(opt.load))
                    print('Appointed Model Restored!')
                else:
                    # Otherwise treat opt.load as a directory holding
                    # per-epoch checkpoints.
                    model_name = os.path.join(
                        opt.load, opt.dataset + opt.cls_model + '_epoch_' +
                        str(opt.start_epoch) + '.pt')
                    if os.path.isfile(model_name):
                        self.net.load_state_dict(torch.load(model_name))
                        print('Model restored! Epoch:', opt.start_epoch)
                    else:
                        raise Exception("Could not resume")

        # Optional CNN-based edge network.
        if opt.cnn_edge:
            from adv import zip_wrn
            if opt.blur_edge:
                self.edge_net = zip_wrn.BlurZipNet()
            else:
                self.edge_net = zip_wrn.ZipNet()
            if len(opt.gpu_ids) > 0:
                assert (torch.cuda.is_available())
                self.edge_net.cuda()

            if opt.cnnedge_load != '':
                self.edge_net = IdentityMapping(self.edge_net)
                if os.path.isfile(opt.cnnedge_load):
                    self.edge_net.load_state_dict(torch.load(opt.cnnedge_load))
                    print('Appointed Model Restored!')
                else:
                    # NOTE(review): the edge-net checkpoint name is built from
                    # opt.cls_model, the same as the classifier's; confirm
                    # this is intended.
                    model_name = os.path.join(
                        opt.cnnedge_load, opt.dataset + opt.cls_model +
                        '_epoch_' + str(opt.start_epoch) + '.pt')
                    if os.path.isfile(model_name):
                        self.edge_net.load_state_dict(torch.load(model_name))
                        print('Model restored! Epoch:', opt.start_epoch)
                    else:
                        raise Exception("Could not resume")