Ejemplo n.º 1
0
    def __init__(self, opt):
        """Create the eight networks and, in training mode, every loss criterion.

        Args:
            opt: parsed option namespace. Fields read here: ``isTrain``,
                ``gan_mode``, ``no_vgg_loss``, ``use_vae``,
                ``no_orient_loss`` and ``unpairTrain``.
        """
        super().__init__()
        self.opt = opt
        # Choose the CUDA or CPU tensor constructors based on use_gpu().
        self.FloatTensor = (torch.cuda if self.use_gpu() else torch).FloatTensor
        self.ByteTensor = (torch.cuda if self.use_gpu() else torch).ByteTensor

        built_nets = self.initialize_networks(opt)
        (self.netG, self.netD, self.netE, self.netIG,
         self.netFE, self.netB, self.netD2, self.netSIG) = built_nets

        # Loss criteria are only constructed when training.
        if not opt.isTrain:
            return

        self.criterionGAN = networks.GANLoss(
            opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
        self.criterionFeat = torch.nn.L1Loss()
        self.criterionGANFeat = networks.GANFeatLoss(self.opt)
        if not opt.no_vgg_loss:
            self.criterionVGG = networks.VGGLoss(self.opt)
        if opt.use_vae:
            self.KLDLoss = networks.KLDLoss()
        if not opt.no_orient_loss:
            self.criterionOrient = networks.L1OLoss(self.opt)

        self.criterionStyleContent = networks.StyleContentLoss(opt)
        # L1 loss on the RGB background region.
        self.criterionBackground = networks.RGBBackgroundL1Loss()

        self.criterionRGBL1 = nn.L1Loss()
        self.criterionRGBL2 = nn.MSELoss()
        self.criterionLabL1 = networks.LabColorLoss(opt)

        if opt.unpairTrain:
            self.criterionHairAvgLab = networks.HairAvgLabLoss(opt)
Ejemplo n.º 2
0
    def __init__(self, opt):
        """Set up the rotation model: placeholder buffers, networks and losses.

        Args:
            opt: option namespace; reads ``checkpoints_dir``, ``name``,
                ``batchSize``, ``crop_size``, ``isTrain``, ``gan_mode``,
                ``no_vgg_loss`` and ``use_vae``.
        """
        super().__init__()
        self.opt = opt
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        self.FloatTensor = (torch.cuda if self.use_gpu() else torch).FloatTensor
        self.ByteTensor = (torch.cuda if self.use_gpu() else torch).ByteTensor

        # Two independent zero tensors of shape
        # (batchSize, 3, crop_size, crop_size).
        buffer_shape = (opt.batchSize, 3, opt.crop_size, opt.crop_size)
        self.real_image = torch.zeros(buffer_shape)
        self.input_semantics = torch.zeros(buffer_shape)

        built_nets = self.initialize_networks(opt)
        self.netG, self.netD, self.netE, self.netD_rotate = built_nets

        # Loss criteria are only constructed when training.
        if not opt.isTrain:
            return

        self.criterionGAN = networks.GANLoss(
            opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
        self.criterionFeat = torch.nn.L1Loss()
        if not opt.no_vgg_loss:
            self.criterionVGG = networks.VGGLoss(self.opt)
        if opt.use_vae:
            self.KLDLoss = networks.KLDLoss()
Ejemplo n.º 3
0
    def __init__(self, opt):
        """Create tensor constructors, networks, the AMP flag and training losses.

        Args:
            opt: option namespace; reads ``isTrain``, ``use_amp``,
                ``gan_mode``, ``no_vgg_loss``, ``use_vae`` and ``gpu_ids``.
        """
        super().__init__()
        self.opt = opt
        use_gpu = self.use_gpu()
        self.FloatTensor = torch.cuda.FloatTensor if use_gpu else torch.FloatTensor

        # PyTorch 1.2 introduced bool tensors as the mask type. Compare the
        # (major, minor) pair numerically: checking only the minor component
        # picks uint8 masks on torch 2.0/2.1 and a non-existent BoolTensor
        # on 0.x releases.
        major, minor = (int(part) for part in torch.__version__.split('.')[:2])
        if (major, minor) >= (1, 2):
            self.ByteTensor = torch.cuda.BoolTensor if use_gpu else torch.BoolTensor
        else:
            self.ByteTensor = torch.cuda.ByteTensor if use_gpu else torch.ByteTensor

        self.netG, self.netD, self.netE = self.initialize_networks(opt)

        # Mixed precision only when the module-level AMP flag is set, the user
        # asked for it, and we are training.
        self.amp = bool(AMP and opt.use_amp and opt.isTrain)

        # set loss functions
        if opt.isTrain:
            self.criterionGAN = networks.GANLoss(opt.gan_mode,
                                                 tensor=self.FloatTensor,
                                                 opt=self.opt)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
            if opt.use_vae:
                self.KLDLoss = networks.KLDLoss()
    def __init__(self, opt):
        """Build the networks and the optional training losses chosen by ``opt``.

        Args:
            opt: option namespace; reads ``isTrain``, ``gan_mode``,
                ``no_vgg_loss``, ``booru_loss``, ``use_vae``, ``L2_loss``,
                ``L1_loss``, ``hsv_tv``, ``high_sv`` and ``gpu_ids``.
        """
        super().__init__()
        self.opt = opt
        self.FloatTensor = (torch.cuda if self.use_gpu() else torch).FloatTensor
        self.ByteTensor = (torch.cuda if self.use_gpu() else torch).ByteTensor

        built_nets = self.initialize_networks(opt)
        self.netG, self.netD, self.netE, self.netF = built_nets

        # Everything below is only needed while training.
        if not opt.isTrain:
            return

        self.criterionGAN = networks.GANLoss(
            opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
        self.criterionFeat = torch.nn.L1Loss()
        if not opt.no_vgg_loss:
            self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
        # The Booru loss is only built when the VGG loss is switched off.
        if opt.booru_loss and opt.no_vgg_loss:
            self.criterionBooru = networks.BooruLoss(self.opt.gpu_ids)
        if opt.use_vae:
            self.KLDLoss = networks.KLDLoss()
        if opt.L2_loss:
            self.L2Loss = torch.nn.MSELoss()
        if opt.L1_loss:
            self.L1Loss = torch.nn.L1Loss()
        if opt.hsv_tv:
            self.hsvTVLoss = networks.HSVTVLoss()
        if opt.high_sv:
            self.HighSVLoss = networks.HighSVLoss()
Ejemplo n.º 5
0
    def __init__(self, opt):
        """Create networks and, when training, background ids and losses.

        Args:
            opt: option namespace; reads ``isTrain``, ``dataset_mode``,
                ``gan_mode``, ``no_vgg_loss``, ``use_vae`` and ``gpu_ids``.
        """
        super().__init__()
        self.opt = opt
        self.FloatTensor = (torch.cuda if self.use_gpu() else torch).FloatTensor
        self.ByteTensor = (torch.cuda if self.use_gpu() else torch).ByteTensor

        self.netG, self.netD, self.netE = self.initialize_networks(opt)

        if not opt.isTrain:
            return

        # Label ids treated as background for each supported dataset. For any
        # other dataset_mode, background_list is left unset.
        background_by_mode = {
            'cityscapes': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, 16,
                           21, 23],
            'coco': [114, 115, 116, 117, 118, 124, 126, 127, 128, 135, 136,
                     144, 145, 147, 148, 149, 151, 154, 155, 157, 158, 159,
                     160, 161, 162, 164, 169, 170, 171, 172, 173, 174, 175,
                     176, 177, 178, 182],
            'ade20k': [0, 1, 2, 3, 4, 5, 6, 9, 11, 13, 16, 17, 21, 25, 26, 29,
                       32, 46, 48, 51, 52, 54, 60, 61, 68, 69, 81, 91, 94, 101,
                       128],
        }
        mode = self.opt.dataset_mode
        if mode in background_by_mode:
            self.background_list = background_by_mode[mode]

        self.criterionGAN = networks.GANLoss(
            opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
        self.criterionFeat = torch.nn.L1Loss()
        if not opt.no_vgg_loss:
            self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
        if opt.use_vae:
            self.KLDLoss = networks.KLDLoss()
Ejemplo n.º 6
0
 def set_losses(self):
     """Create the loss criteria used for training; does nothing otherwise.

     Reads ``isTrain``, ``gan_mode``, ``no_vgg_loss``, ``use_vae`` and
     ``gpu_ids`` from ``self.opt``.
     """
     opt = self.opt
     if not opt.isTrain:
         return
     self.criterionGAN = networks.GANLoss(
         opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
     self.criterionFeat = torch.nn.L1Loss()
     if not opt.no_vgg_loss:
         self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
     if opt.use_vae:
         self.KLDLoss = networks.KLDLoss()
    def __init__(self, opt):
        """Create the G/D/E networks and, in training mode, the loss criteria.

        Args:
            opt: option namespace; reads ``isTrain``, ``gan_mode``,
                ``no_vgg_loss``, ``use_vae`` and ``gpu_ids``.
        """
        super().__init__()
        self.opt = opt
        self.FloatTensor = (torch.cuda if self.use_gpu() else torch).FloatTensor
        self.ByteTensor = (torch.cuda if self.use_gpu() else torch).ByteTensor

        net_g, net_d, net_e = self.initialize_networks(opt)
        self.netG, self.netD, self.netE = net_g, net_d, net_e

        # Loss criteria are only constructed when training.
        if not opt.isTrain:
            return

        self.criterionGAN = networks.GANLoss(
            opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
        self.criterionFeat = torch.nn.L1Loss()
        if not opt.no_vgg_loss:
            self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
        if opt.use_vae:
            self.KLDLoss = networks.KLDLoss()
Ejemplo n.º 8
0
    def __init__(self, opt):
        """Create networks and, when training, losses plus a segmentation model.

        Args:
            opt: option namespace; reads ``isTrain``, ``gan_mode``,
                ``no_vgg_loss``, ``use_vae``, ``gpu_ids`` and, optionally,
                ``segmentation_model_path``.
        """
        super().__init__()
        self.opt = opt
        self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() \
            else torch.FloatTensor
        self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() \
            else torch.ByteTensor

        self.netG, self.netD, self.netE = self.initialize_networks(opt)

        # set loss functions
        if opt.isTrain:
            self.criterionGAN = networks.GANLoss(opt.gan_mode,
                                                 tensor=self.FloatTensor,
                                                 opt=self.opt)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
            if opt.use_vae:
                self.KLDLoss = networks.KLDLoss()

            # The checkpoint location used to be hard-coded to one user's
            # machine. It can now be overridden with an optional
            # ``opt.segmentation_model_path``; without it the original path is
            # still used, so existing setups keep working.
            # NOTE(review): torch.load unpickles the whole object, so the file
            # must come from a trusted source.
            default_path = ('/home/qasima/venv_spade/SPADE/checkpoints/isic_fold_1/'
                            'model_epochs100_percent100_isic_256')
            seg_path = getattr(opt, 'segmentation_model_path', None) or default_path
            self.segmentation_model = torch.load(seg_path)
Ejemplo n.º 9
0
    def __init__(self, opt):
        """Create the networks, the Canny edge net and optional auxiliary nets.

        Args:
            opt: option namespace. Besides the usual loss options
                (``isTrain``, ``gan_mode``, ``no_vgg_loss``, ``use_vae``,
                ``gpu_ids``), this reads the Canny thresholds (``sigma``,
                ``high_threshold``, ``low_threshold``, ``robust_threshold``),
                the classifier options (``cls``, ``cls_model``, ``layers``,
                ``widen_factor``, ``droprate``, ``load``) and the edge-net
                options (``cnn_edge``, ``blur_edge``, ``cnnedge_load``,
                ``dataset``, ``start_epoch``).

        Raises:
            ValueError: ``opt.cls`` is set with an unsupported ``cls_model``.
            RuntimeError: ``gpu_ids`` are given but CUDA is unavailable.
            Exception: a requested checkpoint cannot be found.
        """
        super().__init__()
        self.opt = opt
        self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() \
            else torch.FloatTensor
        self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() \
            else torch.ByteTensor

        (self.netG, self.netD, self.netE,
         self.netG2, self.netD2, self.netE2) = self.initialize_networks(opt)

        self.canny_net = backward_canny.Canny_Net(opt.sigma,
                                                  opt.high_threshold,
                                                  opt.low_threshold,
                                                  opt.robust_threshold)
        if self.use_gpu():
            self.canny_net.cuda()

        # set loss functions
        if opt.isTrain:
            self.criterionGAN = networks.GANLoss(opt.gan_mode,
                                                 tensor=self.FloatTensor,
                                                 opt=self.opt)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.opt.gpu_ids)
            if opt.use_vae:
                self.KLDLoss = networks.KLDLoss()

        if opt.cls:
            from adv import pgd, wrn  # noqa: F401  (pgd kept for import side effects)

            # Explicit raise instead of `assert False`: asserts are stripped
            # under `python -O`, which would leave self.net undefined.
            if opt.cls_model != 'wrn':
                raise ValueError(opt.cls_model + ' is not supported.')
            self.net = wrn.WideResNet(opt.layers,
                                      10,
                                      opt.widen_factor,
                                      dropRate=opt.droprate)

            if len(opt.gpu_ids) > 0:
                if not torch.cuda.is_available():
                    raise RuntimeError('gpu_ids were given but CUDA is not available')
                self.net.cuda()

            # Restore model if desired
            if opt.load != '':
                self.net = self._restore_wrapped_net(self.net, opt.load, opt)

        if opt.cnn_edge:
            from adv import zip_wrn
            if opt.blur_edge:
                self.edge_net = zip_wrn.BlurZipNet()
            else:
                self.edge_net = zip_wrn.ZipNet()
            if len(opt.gpu_ids) > 0:
                if not torch.cuda.is_available():
                    raise RuntimeError('gpu_ids were given but CUDA is not available')
                self.edge_net.cuda()

            if opt.cnnedge_load != '':
                # NOTE(review): the fallback checkpoint name reuses
                # opt.cls_model, as the original code did — confirm this is
                # intended for the edge network.
                self.edge_net = self._restore_wrapped_net(
                    self.edge_net, opt.cnnedge_load, opt)

    @staticmethod
    def _restore_wrapped_net(net, load_path, opt):
        """Wrap ``net`` in ``IdentityMapping`` and load weights into it.

        ``load_path`` is either a checkpoint file or a directory. For a
        directory, the file
        ``<dataset><cls_model>_epoch_<start_epoch>.pt`` inside it is used.

        Returns:
            The wrapped network with the weights loaded.

        Raises:
            Exception: neither ``load_path`` nor the derived per-epoch
                checkpoint file exists.
        """
        net = IdentityMapping(net)
        if os.path.isfile(load_path):
            net.load_state_dict(torch.load(load_path))
            print('Appointed Model Restored!')
            return net

        model_name = os.path.join(
            load_path, opt.dataset + opt.cls_model + '_epoch_' +
            str(opt.start_epoch) + '.pt')
        if not os.path.isfile(model_name):
            raise Exception("Could not resume")
        net.load_state_dict(torch.load(model_name))
        print('Model restored! Epoch:', opt.start_epoch)
        return net