Code Example #1
File: PSIARRE.py Project: tamwaiban/CISR_PSI
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        SR_args = {
            'scale': opt.sr_factor,
            'n_feats': opt.n_feats,
            'n_resblocks': opt.n_resblocks,
            'n_colors': opt.n_colors,
            'main_model': opt.main_model,
            'recur_step': opt.recur_step,
            'res_scale': opt.res_scale,
            'device': self.device,
            'n_resgroups1': opt.n_resgroups1,
            'n_resgroups2': opt.n_resgroups2,
            'rgb_range': opt.rgb_range
        }

        self.sr_factor = opt.sr_factor
        self.model = PSIARRENet(SR_args)
        self.model = nn.DataParallel(self.model, device_ids=opt.gpu_ids)
        self.model.to(self.device)
        self.criterion = nn.L1Loss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=opt.lr)
        self.scheduler = lr_scheduler.StepLR(self.optimizer,
                                             step_size=opt.lr_decay,
                                             gamma=0.5)
        self.tiny = (opt.n_resgroups1 == 2)
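For orientation, here is a minimal training-loop sketch (not part of the repository) showing how the pieces wired up in this initialize method (model, L1 criterion, Adam optimizer, StepLR scheduler) are typically driven; the loader and device are assumed to come from the caller:

def train_one_epoch(model, criterion, optimizer, scheduler, loader, device):
    # One pass over the data: L1 loss between the SR output and the HR target.
    model.train()
    for lr_img, hr_img in loader:
        lr_img, hr_img = lr_img.to(device), hr_img.to(device)
        optimizer.zero_grad()
        loss = criterion(model(lr_img), hr_img)
        loss.backward()
        optimizer.step()
    scheduler.step()  # StepLR halves the lr (gamma=0.5) every opt.lr_decay epochs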
Code Example #2
    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        args = {'scale': opt.sr_factor, 'n_feats': 256, 'n_resblocks': 32, 'res_scale': 0.1, 'n_colors': 3,
                'rgb_range': opt.rgb_range}
        self.model = EDSR(args).to(self.device)
        self.criterion = nn.L1Loss()
        self.model = nn.DataParallel(self.model, opt.gpu_ids)
        # self.criterion = self.criterion.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=opt.lr)
        self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size=500, gamma=0.5)
Code Example #3
    def initialize(self, opt, epoch=0):
        BaseModel.initialize(self, opt)
        torch.backends.cudnn.benchmark = True

        # define losses
        self.lossCollector = LossCollector()
        self.lossCollector.initialize(opt)

        # define networks
        self.define_networks(epoch)

        # load networks
        self.load_networks()
Code Example #4
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        kernel_size = 3
        n_colors = opt.n_colors

        self.model = NonLocalModule(kernel_size, n_colors, self.device)
        self.model = nn.DataParallel(self.model, opt.gpu_ids)
        self.model.to(self.device)
        self.criterion = nn.L1Loss().to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=opt.lr)
        self.scheduler = lr_scheduler.StepLR(self.optimizer,
                                             step_size=opt.lr_decay,
                                             gamma=0.5)
Code Example #5
File: RDN.py Project: luohongming/SR_baseline
    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        args = {
            'scale': opt.sr_factor,
            'G0': 64,
            'RDNkSize': 3,
            'RDNconfig': 'B',
            'n_colors': 3
        }

        self.model = RDN(args).to(self.device)
        self.criterion = nn.MSELoss()
        self.model = nn.DataParallel(self.model, opt.gpu_ids)
        # self.criterion = self.criterion.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=opt.lr)
        self.scheduler = lr_scheduler.StepLR(self.optimizer,
                                             step_size=50,
                                             gamma=0.5)
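A quick standalone check (assumed values, plain PyTorch) of what StepLR with step_size=50 and gamma=0.5 does to the learning rate:

import torch
from torch import nn, optim
from torch.optim import lr_scheduler

param = nn.Parameter(torch.zeros(1))
optimizer = optim.Adam([param], lr=1e-4)
scheduler = lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
for _ in range(150):
    optimizer.step()      # normally preceded by a backward(); omitted here
    scheduler.step()
print(optimizer.param_groups[0]['lr'])  # ~1.25e-05: halved at epochs 50, 100, 150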
Code Example #6
    def initialize(self, opt, epoch=0):
        BaseModel.initialize(self, opt)
        torch.backends.cudnn.benchmark = True

        # define losses
        self.lossCollector = LossCollector()
        self.lossCollector.initialize(opt)

        # Face network
        self.refine_face = hasattr(opt, 'refine_face') and opt.refine_face
        self.faceRefiner = None
        if self.refine_face or self.add_face_D:
            self.faceRefiner = FaceRefineModel()
            self.faceRefiner.initialize(opt, self.add_face_D, self.refine_face)

        # define networks
        self.define_networks(epoch)

        # load networks
        self.load_networks()
Code Example #7
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain

        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.schedulers = []
            self.optimizers = []
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')
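To make the two-optimizer setup concrete, here is a self-contained sketch of one conditional-GAN update with toy stand-in networks (plain PyTorch; the repository's networks module and image pool are not used, and the L1 weight is a placeholder):

import torch
import torch.nn as nn

G = nn.Conv2d(3, 3, 3, padding=1)   # stand-in for netG
D = nn.Conv2d(6, 1, 3, padding=1)   # stand-in for netD; sees (A, B) concatenated
criterionGAN = nn.MSELoss()         # LSGAN-style objective (use_lsgan=True)
criterionL1 = nn.L1Loss()
optimizer_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))

real_A, real_B = torch.randn(1, 3, 32, 32), torch.randn(1, 3, 32, 32)
fake_B = G(real_A)

# D step: push real pairs toward 1, fake pairs (detached) toward 0.
pred_real = D(torch.cat([real_A, real_B], 1))
pred_fake = D(torch.cat([real_A, fake_B.detach()], 1))
loss_D = 0.5 * (criterionGAN(pred_real, torch.ones_like(pred_real)) +
                criterionGAN(pred_fake, torch.zeros_like(pred_fake)))
optimizer_D.zero_grad(); loss_D.backward(); optimizer_D.step()

# G step: fool D while staying close to the target under L1.
pred_fake = D(torch.cat([real_A, fake_B], 1))
loss_G = (criterionGAN(pred_fake, torch.ones_like(pred_fake))
          + 100.0 * criterionL1(fake_B, real_B))
optimizer_G.zero_grad(); loss_G.backward(); optimizer_G.step()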
Code Example #8
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        # set networks
        self.initialize_networks(opt)

        # set loss functions
        self.initialize_loss(opt)

        # set optimizer
        self.initialize_optimizer(opt)

        self.initialize_other(opt)

        self.model_dict = {
            'netG': {
                'model': self.netG.module if self.use_gpu else self.netG,
                'optimizer': self.optimizer_G
            },
            'netD': {
                'model': self.netD.module if self.use_gpu else self.netD,
                'optimizer': self.optimizer_D
            }
        }
        self.opt = opt
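Pairing each network with its optimizer in model_dict makes checkpointing symmetric. A minimal sketch (hypothetical helpers, not from the repository) of saving and restoring such a dict:

import torch

def save_checkpoint(model_dict, path):
    # One file holding every network's weights plus its optimizer state.
    state = {name: {'model': entry['model'].state_dict(),
                    'optimizer': entry['optimizer'].state_dict()}
             for name, entry in model_dict.items()}
    torch.save(state, path)

def load_checkpoint(model_dict, path):
    state = torch.load(path, map_location='cpu')
    for name, entry in model_dict.items():
        entry['model'].load_state_dict(state[name]['model'])
        entry['optimizer'].load_state_dict(state[name]['optimizer'])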
Code Example #9
    def __init__(self, opt, dataset=None):
        BaseModel.initialize(self, opt)
        self.model_rgb = Deeplab_VGG(self.opt.label_nc, self.opt.depthconv)
        self.model_HHA = Deeplab_VGG(self.opt.label_nc, self.opt.depthconv)

        self.model = nn.Sequential(*[self.model_rgb, self.model_HHA])

        if self.opt.isTrain:
            self.criterionSeg = torch.nn.CrossEntropyLoss(
                ignore_index=255).cuda()
            # self.optimizer = torch.optim.SGD(
            #     [
            #         {'params': self.model_rgb.Scale.get_1x_lr_params_NOscale(), 'lr': self.opt.lr},
            #         {'params': self.model_rgb.Scale.get_10x_lr_params(), 'lr': 10 * self.opt.lr},
            #         {'params': self.model_rgb.Scale.get_2x_lr_params_NOscale(), 'lr': 2 * self.opt.lr,
            #          'weight_decay': 0.},
            #         {'params': self.model_rgb.Scale.get_20x_lr_params(), 'lr': 20 * self.opt.lr, 'weight_decay': 0.},
            #         {'params': self.model_HHA.Scale.get_1x_lr_params_NOscale(), 'lr': self.opt.lr},
            #         {'params': self.model_HHA.Scale.get_10x_lr_params(), 'lr': 10 * self.opt.lr},
            #         {'params': self.model_HHA.Scale.get_2x_lr_params_NOscale(), 'lr': 2 * self.opt.lr,
            #          'weight_decay': 0.},
            #         {'params': self.model_HHA.Scale.get_20x_lr_params(), 'lr': 20 * self.opt.lr, 'weight_decay': 0.}
            #     ],
            #     lr=self.opt.lr, momentum=self.opt.momentum, weight_decay=self.opt.wd)
            params_rgb = list(self.model_rgb.Scale.parameters())
            params_HHA = list(self.model_HHA.Scale.parameters())
            self.optimizer = torch.optim.SGD(params_rgb + params_HHA,
                                             lr=self.opt.lr,
                                             momentum=self.opt.momentum,
                                             weight_decay=self.opt.wd)
            #
            # self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.opt.lr, momentum=self.opt.momentum, weight_decay=self.opt.wd)

            self.old_lr = self.opt.lr
            self.averageloss = []
            # copy scripts
            self.model_path = './models'  #os.path.dirname(os.path.realpath(__file__))
            self.data_path = './data'  #os.path.dirname(os.path.realpath(__file__))
            shutil.copyfile(os.path.join(self.model_path, 'Deeplab_HHA.py'),
                            os.path.join(self.model_dir, 'Deeplab.py'))
            shutil.copyfile(os.path.join(self.model_path, 'VGG_Deeplab.py'),
                            os.path.join(self.model_dir, 'VGG_Deeplab.py'))
            shutil.copyfile(os.path.join(self.model_path, 'model_utils.py'),
                            os.path.join(self.model_dir, 'model_utils.py'))
            shutil.copyfile(os.path.join(self.data_path, dataset.datafile),
                            os.path.join(self.model_dir, dataset.datafile))
            shutil.copyfile(os.path.join(self.data_path, 'base_dataset.py'),
                            os.path.join(self.model_dir, 'base_dataset.py'))

            self.writer = SummaryWriter(self.tensorborad_dir)
            self.counter = 0

        if not self.isTrain or self.opt.continue_train:
            pretrained_path = ''  # if not self.isTrain else opt.load_pretrain

            if self.opt.pretrained_model != '' or (
                    self.opt.pretrained_model_HHA != ''
                    and self.opt.pretrained_model_rgb != ''):
                if self.opt.pretrained_model_HHA != '' and self.opt.pretrained_model_rgb != '':
                    self.load_pretrained_network(self.model_rgb,
                                                 self.opt.pretrained_model_rgb,
                                                 self.opt.which_epoch_rgb,
                                                 False)
                    self.load_pretrained_network(self.model_HHA,
                                                 self.opt.pretrained_model_HHA,
                                                 self.opt.which_epoch_HHA,
                                                 False)
                else:
                    self.load_pretrained_network(self.model_rgb,
                                                 self.opt.pretrained_model,
                                                 self.opt.which_epoch, False)
                    self.load_pretrained_network(self.model_HHA,
                                                 self.opt.pretrained_model,
                                                 self.opt.which_epoch, False)
                print(
                    "successfully loaded from pretrained model with given path!"
                )
            else:
                self.load()
                print("successfully loaded from pretrained model 0!")

        self.model_rgb.cuda()
        self.model_HHA.cuda()
        self.normweightgrad = 0.
Code Example #10
    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        # define losses
        self.define_losses()
        self.tD = 1
Code Example #11
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.count = 0
        input_depth = opt.input_nc
        output_depth = opt.output_nc
        self.net_shared = skip(input_depth,
                               num_channels_down=[64, 128, 256, 256, 256],
                               num_channels_up=[64, 128, 256, 256, 256],
                               num_channels_skip=[4, 4, 4, 4, 4],
                               upsample_mode=[
                                   'nearest', 'nearest', 'bilinear',
                                   'bilinear', 'bilinear'
                               ],
                               need_sigmoid=True,
                               need_bias=True,
                               pad='reflection')
        self.netDec_a = ResNet_decoders(opt.ngf, output_depth)
        self.netDec_b = ResNet_decoders(opt.ngf, output_depth)

        self.net_input = self.get_noise(input_depth, 'noise',
                                        (self.opt.fineSize, self.opt.fineSize))
        self.net_input_saved = self.net_input.detach().clone()
        self.noise = self.net_input.detach().clone()

        use_sigmoid = opt.no_lsgan
        self.netD_b = networks.define_D(opt.output_nc, opt.ndf,
                                        opt.which_model_netD, opt.n_layers_D,
                                        opt.norm, use_sigmoid, opt.init_type,
                                        self.gpu_ids)

        if not opt.dont_load_pretrained_autoencoder:
            which_epoch = opt.which_epoch
            self.load_network(self.netDec_b, 'Dec_b', which_epoch)
            self.load_network(self.net_shared, 'Net_shared', which_epoch)
            self.load_network(self.netD_b, 'D', which_epoch)

        if len(self.gpu_ids) > 0:
            dtype = torch.cuda.FloatTensor
            self.net_input = self.net_input.type(dtype).detach()
            self.net_shared = self.net_shared.type(dtype)
            self.netDec_a = self.netDec_a.type(dtype)
            self.netDec_b = self.netDec_b.type(dtype)
            self.netD_b = self.netD_b.type(dtype)

        self.fake_A_pool = ImagePool(opt.pool_size)
        self.fake_B_pool = ImagePool(opt.pool_size)

        # define loss functions
        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                             tensor=self.Tensor)
        self.mse = torch.nn.MSELoss()

        # initialize optimizers
        self.optimizer_Net = torch.optim.Adam(
            itertools.chain(self.net_shared.parameters(),
                            self.netDec_a.parameters()),
            lr=0.007,
            betas=(opt.beta1,
                   0.999))  # skip 0.01   # OST 0.001 # skip large 0.007
        self.optimizer_Dec_b = torch.optim.Adam(
            self.netDec_b.parameters(), lr=0.000007,
            betas=(opt.beta1,
                   0.999))  # OST 0.000007 skip 0.00002 skip large 0.000007
        self.optimizer_D_b = torch.optim.Adam(self.netD_b.parameters(),
                                              lr=0.0002,
                                              betas=(opt.beta1, 0.999))

        self.optimizers = []
        self.schedulers = []
        self.optimizers.append(self.optimizer_Net)
        self.optimizers.append(self.optimizer_Dec_b)
        self.optimizers.append(self.optimizer_D_b)
        for optimizer in self.optimizers:
            self.schedulers.append(networks.get_scheduler(optimizer, opt))
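Note how optimizer_Net above updates two modules jointly via itertools.chain. A standalone illustration of the same pattern with toy modules (assumed sizes):

import itertools
import torch
import torch.nn as nn

enc = nn.Linear(8, 4)   # stand-in for net_shared
dec = nn.Linear(4, 8)   # stand-in for netDec_a
optimizer = torch.optim.Adam(
    itertools.chain(enc.parameters(), dec.parameters()), lr=1e-3)

x = torch.randn(2, 8)
loss = ((dec(enc(x)) - x) ** 2).mean()  # toy reconstruction objective
optimizer.zero_grad(); loss.backward(); optimizer.step()  # both modules step together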
Code Example #12
    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none': # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc

        ##### define networks
        # Generator network
        netG_input_nc = input_nc + opt.output_nc
        if not opt.no_instance:
            netG_input_nc += 1

        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
                                      opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
                                      opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = 4*opt.output_nc
            #netD_input_nc = input_nc + opt.output_nc
            if not opt.no_instance:
                netD_input_nc += 1
            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid,
                                          opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)

        # Face discriminator network
        if self.isTrain and opt.face:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = 2*opt.output_nc
            if not opt.no_instance:
                netD_input_nc += 1
            self.netDface = networks.define_D_face(netD_input_nc, opt.ndf,
                    opt.n_layers_D, opt.norm, use_sigmoid, 1, not
                    opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)

        #Face residual network
        if opt.face:
            self.faceGen = networks.define_G(opt.output_nc*2, opt.output_nc, 64, 'global',
                                  n_downsample_global=3, n_blocks_global=5, n_local_enhancers=0,
                                  n_blocks_local=0, norm=opt.norm, gpu_ids=self.gpu_ids)

        print('---------- Networks initialized -------------')

        # load networks
        if (not self.isTrain or opt.continue_train or opt.load_pretrain):
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
                if opt.face:
                    self.load_network(self.netDface, 'Dface', opt.which_epoch, pretrained_path)
            if opt.face:
                self.load_network(self.faceGen, 'Gface', opt.which_epoch, pretrained_path)

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.gpu_ids)
            if opt.use_l1:
                self.criterionL1 = torch.nn.L1Loss()

            # Loss names
            self.loss_names = ['G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake', 'G_GANface', 'D_realface', 'D_fakeface']

            # initialize optimizers
            # optimizer G
            if opt.niter_fix_global > 0:
                import sys
                if sys.version_info >= (3,0):
                    finetune_list = set()
                else:
                    from sets import Set
                    finetune_list = Set()

                params_dict = dict(self.netG.named_parameters())
                params = []
                for key, value in params_dict.items():       
                    if key.startswith('model' + str(opt.n_local_enhancers)):                    
                        params += [value]
                        finetune_list.add(key.split('.')[0])  
                print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
                print('The layers that are finetuned are ', sorted(finetune_list))                         
            else:
                params = list(self.netG.parameters())

            if opt.face:
                params = list(self.faceGen.parameters())
            else:
                if opt.niter_fix_main == 0:
                    params += list(self.netG.parameters())

            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

            # optimizer D
            if opt.niter > 0 and opt.face:
                print('------------- Only training the face discriminator network (for %d epochs) ------------' % opt.niter)
                params = list(self.netDface.parameters())
            else:
                if opt.face:
                    params = list(self.netD.parameters()) + list(self.netDface.parameters())
                else:
                    params = list(self.netD.parameters())

            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
Code Example #13
    def __init__(self, opt, dataset=None, encoder='VGG'):
        BaseModel.initialize(self, opt)
        self.encoder = encoder
        if encoder == 'VGG':
            self.model = Deeplab_VGG(self.opt.label_nc, self.opt.depthconv)

        if self.opt.isTrain:
            self.criterionSeg = torch.nn.CrossEntropyLoss(
                ignore_index=255).cuda()
            # self.criterionSeg = torch.nn.CrossEntropyLoss(ignore_index=255).cuda()
            # self.criterionSeg = nn.NLLLoss2d(ignore_index=255)#.cuda()

            if encoder == 'VGG':
                self.optimizer = torch.optim.SGD(
                    [{
                        'params': self.model.Scale.get_1x_lr_params_NOscale(),
                        'lr': self.opt.lr
                    }, {
                        'params': self.model.Scale.get_10x_lr_params(),
                        'lr': self.opt.lr
                    }, {
                        'params': self.model.Scale.get_2x_lr_params_NOscale(),
                        'lr': self.opt.lr,
                        'weight_decay': 0.
                    }, {
                        'params': self.model.Scale.get_20x_lr_params(),
                        'lr': self.opt.lr,
                        'weight_decay': 0.
                    }],
                    lr=self.opt.lr,
                    momentum=self.opt.momentum,
                    weight_decay=self.opt.wd)

            # self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.opt.lr, momentum=self.opt.momentum, weight_decay=self.opt.wd)

            self.old_lr = self.opt.lr
            self.averageloss = []
            # copy scripts
            self.model_path = './models'  #os.path.dirname(os.path.realpath(__file__))
            self.data_path = './data'  #os.path.dirname(os.path.realpath(__file__))
            shutil.copyfile(os.path.join(self.model_path, 'Deeplab.py'),
                            os.path.join(self.model_dir, 'Deeplab.py'))

            if encoder == 'VGG':
                shutil.copyfile(
                    os.path.join(self.model_path, 'VGG_Deeplab.py'),
                    os.path.join(self.model_dir, 'VGG_Deeplab.py'))
            shutil.copyfile(os.path.join(self.model_path, 'model_utils.py'),
                            os.path.join(self.model_dir, 'model_utils.py'))
            shutil.copyfile(os.path.join(self.data_path, dataset.datafile),
                            os.path.join(self.model_dir, dataset.datafile))
            shutil.copyfile(os.path.join(self.data_path, 'base_dataset.py'),
                            os.path.join(self.model_dir, 'base_dataset.py'))

            self.writer = SummaryWriter(self.tensorborad_dir)
            self.counter = 0

        if not self.isTrain or self.opt.continue_train:
            if self.opt.pretrained_model != '':
                self.load_pretrained_network(self.model,
                                             self.opt.pretrained_model,
                                             self.opt.which_epoch,
                                             strict=False)
                print(
                    "Successfully loaded from pretrained model with given path!"
                )
            else:
                self.load()
                print("Successfully loaded model, continue training....!")

        self.model.cuda()
        self.normweightgrad = 0.
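The parameter-group form of torch.optim.SGD used here lets each subset of weights carry its own lr and weight_decay, as the group method names (get_1x_lr_params_NOscale, get_10x_lr_params, ...) suggest. A standalone sketch of the pattern (toy module, assumed values):

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
base_lr = 1e-3
optimizer = torch.optim.SGD(
    [{'params': net[0].parameters(), 'lr': base_lr},        # backbone at 1x
     {'params': net[1].parameters(), 'lr': 10 * base_lr,    # head at 10x,
      'weight_decay': 0.}],                                 # without decay
    lr=base_lr, momentum=0.9, weight_decay=5e-4)            # defaults for unset keys
for group in optimizer.param_groups:
    print(group['lr'], group['weight_decay'])   # 0.001 0.0005 / 0.01 0.0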
Code Example #14
    def initialize(self):  # , opt
        BaseModel.initialize(self)  # , opt

        batchSize = 32
        fineSize = 256
        input_nc = 3
        output_nc = 3
        vgg = 0
        skip = 0.8
        ngf = 64
        pool_size = 50
        norm = 'instance'
        lr = 0.0001
        no_dropout = True
        no_lsgan = True
        continue_train = True
        use_wgan = 0.0
        use_mse = True
        beta1 = 0.5
        which_model_netG = 'unet-256'  # generator architecture (default per the note below)
        new_lr = True
        niter_decay = 100
        l1 = 10.0

        # batch size
        nb = batchSize
        # image size
        size = fineSize
        #self.opt = opt
        self.input_A = self.Tensor(nb, input_nc, size, size)
        self.input_B = self.Tensor(nb, output_nc, size, size)
        self.input_img = self.Tensor(nb, input_nc, size, size)
        self.input_A_gray = self.Tensor(nb, 1, size, size)

        # Default 0, use perceptrual loss
        if vgg > 0:
            self.vgg_loss = networks.PerceptualLoss()
            self.vgg_loss.cuda()
            self.vgg = networks.load_vgg16("./model")
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False
        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        # default=0.8, help='B = net.forward(A) + skip*A'
        skip = skip > 0

        # which_model_netG,  default = 'unet-256', selects model to use for netG
        # ngf, default = 64, of gen filters in first conv layer'
        # norm, default = 'instance', instance normalization or batch normalization
        # no_dropout, default = 'True', no dropout for the generator
        self.netG_A = networks.define_G(input_nc,
                                        output_nc,
                                        ngf,
                                        which_model_netG,
                                        norm,
                                        not no_dropout,
                                        self.gpu_ids,
                                        skip=skip)

        if not self.isTrain or continue_train:
            #which epoch to load
            which_epoch = 'latest'
            self.load_network(self.netG_A, 'G_A', which_epoch)

        # --pool_size', default=50, help='the size of image buffer that stores previously generated images'
        # lr, default=0.0001
        if self.isTrain:
            self.old_lr = lr
            self.fake_A_pool = ImagePool(pool_size)
            self.fake_B_pool = ImagePool(pool_size)
            # define loss functions
            if use_wgan:
                self.criterionGAN = networks.DiscLossWGANGP()
            else:
                # no_lsgan = True
                self.criterionGAN = networks.GANLoss(use_lsgan=not no_lsgan,
                                                     tensor=self.Tensor)
            if use_mse:
                self.criterionCycle = torch.nn.MSELoss()
            else:
                self.criterionCycle = torch.nn.L1Loss()
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(self.netG_A.parameters(),
                                                lr=lr,
                                                betas=(beta1, 0.999))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        if self.isTrain:
            self.netG_A.train()
        else:
            self.netG_A.eval()
        print('-----------------------------------------------')