Esempio n. 1
0
    def initialize_networks(self, opt):
        """Create the sub-networks needed for the selected training mode.

        Modes:
          * ``opt.train_recognition`` -- only the visual network V.
          * ``opt.train_sync``       -- audio-sync network plus encoder E.
          * otherwise                -- full pipeline G/A/A_sync/E/V and,
            when training, the discriminator D.

        Returns:
            Tuple ``(netG, netD, netA, netA_sync, netV, netE)``; networks
            not used by the selected mode remain ``None``.
        """
        netG = None
        netD = None
        netE = None
        netV = None
        netA = None
        netA_sync = None
        if opt.train_recognition:
            netV = networks.define_V(opt)
        elif opt.train_sync:
            netA_sync = networks.define_A_sync(opt) if opt.use_audio else None
            netE = networks.define_E(opt)
        else:

            netG = networks.define_G(opt)
            # Audio-identity network A needs both audio and audio-id flags;
            # the sync network A_sync only needs audio.
            netA = networks.define_A(
                opt) if opt.use_audio and opt.use_audio_id else None
            netA_sync = networks.define_A_sync(opt) if opt.use_audio else None
            netE = networks.define_E(opt)
            netV = networks.define_V(opt)

            if opt.isTrain:
                netD = networks.define_D(opt)

        if not opt.isTrain or opt.continue_train:
            # NOTE(review): G/V/E are loaded unconditionally even in the
            # recognition/sync modes where some of them are still None here --
            # presumably self.load_network tolerates a None net; confirm.
            self.load_network(netG, 'G', opt.which_epoch)
            self.load_network(netV, 'V', opt.which_epoch)
            self.load_network(netE, 'E', opt.which_epoch)
            if opt.use_audio:
                if opt.use_audio_id:
                    self.load_network(netA, 'A', opt.which_epoch)
                self.load_network(netA_sync, 'A_sync', opt.which_epoch)

            if opt.isTrain and not opt.noload_D:
                self.load_network(netD, 'D', opt.which_epoch)
                # self.load_network(netD_rotate, 'D_rotate', opt.which_epoch, pretrained_path)

        else:
            if self.opt.pretrain:
                # The 'fan' encoder exposes its own load_pretrain hook.
                if opt.netE == 'fan':
                    netE.load_pretrain()
                netV.load_pretrain()
            if opt.load_separately:
                # Per-network checkpoint files instead of a joint checkpoint.
                netG = self.load_separately(netG, 'G', opt)
                netA = self.load_separately(
                    netA, 'A',
                    opt) if opt.use_audio and opt.use_audio_id else None
                netA_sync = self.load_separately(
                    netA_sync, 'A_sync', opt) if opt.use_audio else None
                netV = self.load_separately(netV, 'V', opt)
                netE = self.load_separately(netE, 'E', opt)
                if not opt.noload_D:
                    netD = self.load_separately(netD, 'D', opt)
        return netG, netD, netA, netA_sync, netV, netE
Esempio n. 2
0
    def initialize_networks(self, opt, end2end=False, triple=False):
        """Build G/D/E plus the optional triple-stage pair G_1/D_1.

        Saved weights are restored when testing or when resuming training;
        the triple-stage networks use ``opt.which_triple_epoch``.
        """
        netG_1 = None
        netD_1 = None
        if opt.end2endtri:
            # Extra generator/discriminator pair for the triple stage.
            netG_1 = networks.define_G(opt, triple)
            netD_1 = networks.define_D(opt, triple)

        netG = networks.define_G(opt)
        netD = networks.define_D(opt) if opt.isTrain else None
        netE = networks.define_E(opt) if opt.use_vae else None

        resume = (not opt.isTrain) or opt.continue_train
        if resume:
            if opt.end2endtri:
                netG_1 = util.load_network(netG_1, 'G',
                                           opt.which_triple_epoch, opt,
                                           triple)
            netG = util.load_network(netG, 'G', opt.which_epoch, opt)
            if opt.isTrain:  # and end2end:
                netD = util.load_network(netD, 'D', opt.which_epoch, opt)
                if opt.end2endtri:
                    netD_1 = util.load_network(netD_1, 'D',
                                               opt.which_triple_epoch, opt,
                                               triple)
            if opt.use_vae:
                netE = util.load_network(netE, 'E', opt.which_epoch, opt)

        return netG, netD, netE, netG_1, netD_1
    def initialize_networks(self, opt):
        """Create G, a coarse discriminator D, a fine discriminator D_fine
        and an optional encoder E.

        During training ``opt.label_nc`` is temporarily decremented so the
        coarse discriminator is built with one fewer label channel, then
        restored (+1) before building the fine discriminator.

        Returns:
            Tuple ``(netG, netD, netE, netD_fine)``; the discriminators are
            ``None`` outside of training.
        """
        netG = networks.define_G(opt)
        if opt.isTrain:
            # Coarse discriminator sees one label channel less.
            opt.label_nc = opt.label_nc - 1
            netD = networks.define_D(opt)
        else:
            # BUGFIX: this branch previously assigned netD_fine instead of
            # netD, leaving netD bound only by the load section below.
            netD = None

        netE = networks.define_E(opt) if opt.use_vae else None
        if opt.isTrain:
            # Restore label_nc before building the fine discriminator.
            opt.label_nc = (opt.label_nc + 1)
            netD_fine = networks.define_D(opt)
        else:
            netD_fine = None

        if not opt.isTrain or opt.continue_train:
            netG = util.load_network(netG, 'G', opt.which_epoch, opt)
            if opt.isTrain:
                netD = util.load_network(netD, 'D', opt.which_epoch, opt)
                # NOTE(review): netD_fine is restored from the same 'D'
                # checkpoint label as netD -- confirm this is intended.
                netD_fine = util.load_network(netD_fine, 'D', opt.which_epoch,
                                              opt)
            else:
                netD = None
                netD_fine = None
            if opt.use_vae:
                netE = util.load_network(netE, 'E', opt.which_epoch, opt)

        return netG, netD, netE, netD_fine
Esempio n. 4
0
    def initialize_networks(self, opt):
        """Create G, the two discriminators (D and D_rotate) and optional E.

        Weights come either from the standard checkpoint (test / resume) or,
        for a fresh run with ``opt.load_separately``, from per-network files.
        """
        netG = networks.define_G(opt)
        if opt.isTrain:
            netD = networks.define_D(opt)
            netD_rotate = networks.define_D(opt)
        else:
            netD = None
            netD_rotate = None
        netE = networks.define_E(opt) if opt.use_vae else None

        pretrained_path = ''
        resume = not opt.isTrain or opt.continue_train
        if resume:
            self.load_network(netG, 'G', opt.which_epoch, pretrained_path)
            if opt.isTrain and not opt.noload_D:
                self.load_network(netD, 'D', opt.which_epoch, pretrained_path)
                self.load_network(netD_rotate, 'D_rotate', opt.which_epoch,
                                  pretrained_path)
            if opt.use_vae:
                self.load_network(netE, 'E', opt.which_epoch, pretrained_path)
        elif opt.load_separately:
            netG = self.load_separately(netG, 'G', opt)
            if not opt.noload_D:
                netD = self.load_separately(netD, 'D', opt)
                netD_rotate = self.load_separately(netD_rotate, 'D_rotate',
                                                   opt)
            if opt.use_vae:
                netE = self.load_separately(netE, 'E', opt)

        return netG, netD, netE, netD_rotate
Esempio n. 5
0
    def initialize_networks(self, opt):
        """Instantiate G (always), D (training only) and E (VAE only), then
        reload G/D weights when testing or resuming training."""
        netG = networks.define_G(opt)
        netD = None
        netE = None
        if opt.isTrain:
            netD = networks.define_D(opt)
        if opt.use_vae:
            netE = networks.define_E(opt)

        if opt.continue_train or not opt.isTrain:
            netG = util.load_network(netG, 'G', opt.which_epoch, opt)
            if opt.isTrain:
                netD = util.load_network(netD, 'D', opt.which_epoch, opt)

        return netG, netD, netE
Esempio n. 6
0
    def initialize_networks(self, opt):
        """Create two complete G/D/E network sets and load their weights.

        The "2" set is built first with the unmodified channel counts and
        restored via ``util.load_network2``. When ``opt.edge_cat`` is set,
        ``opt.label_nc``/``opt.semantic_nc`` are decremented before building
        the primary set and restored (+1) before returning, so the two sets
        are constructed with different input widths.
        """
        netG2 = networks.define_G(opt)
        netD2 = networks.define_D(opt) if opt.isTrain else None
        netE2 = networks.define_E(opt) if opt.use_vae else None

        if not opt.isTrain or opt.continue_train:
            netG2 = util.load_network2(netG2, 'G', opt.which_epoch, opt)
            if opt.isTrain:
                netD2 = util.load_network2(netD2, 'D', opt.which_epoch, opt)
            if opt.use_vae:
                netE2 = util.load_network2(netE2, 'E', opt.which_epoch, opt)
        elif opt.use_vae and opt.pretrain_vae:
            # Fresh run that still starts from a pretrained VAE encoder.
            netE2 = util.load_network2(netE2, 'E', opt.which_epoch, opt)

        if opt.edge_cat:
            # Drop one (edge) channel for the primary network set.
            opt.label_nc -= 1
            opt.semantic_nc -= 1

        netG = networks.define_G(opt)
        netD = networks.define_D(opt) if opt.isTrain else None
        netE = networks.define_E(opt) if opt.use_vae else None

        if not opt.isTrain or opt.continue_train:
            netG = util.load_network(netG, 'G', opt.which_epoch, opt)
            if opt.isTrain:
                netD = util.load_network(netD, 'D', opt.which_epoch, opt)
            if opt.use_vae:
                netE = util.load_network(netE, 'E', opt.which_epoch, opt)
        elif opt.use_vae and opt.pretrain_vae:
            netE = util.load_network(netE, 'E', opt.which_epoch, opt)
            print('Load fixed netE.')

        if opt.edge_cat:
            # Restore the channel counts mutated above.
            opt.label_nc += 1
            opt.semantic_nc += 1

        return netG, netD, netE, netG2, netD2, netE2
Esempio n. 7
0
    def initialize_networks(self, opt):
        """Create G, E and (when training with a discriminator enabled) D,
        restoring checkpoints for test / resume / manipulation runs."""
        print(opt.isTrain)
        netG = networks.define_G(opt)
        if opt.isTrain and not opt.no_disc:
            netD = networks.define_D(opt)
        else:
            netD = None
        netE = networks.define_E(opt)

        if not opt.isTrain or opt.continue_train or opt.manipulation:
            netG = util.load_network(netG, 'G', opt.which_epoch, opt)
            if opt.isTrain or opt.needs_D:
                # NOTE(review): with needs_D at test time netD is still None
                # here -- presumably load_network handles that; confirm.
                netD = util.load_network(netD, 'D', opt.which_epoch, opt)
                print('network D loaded')

        return netG, netD, netE
Esempio n. 8
0
    def initialize_networks(self, opt):
        """Instantiate G plus, in training mode, D and (with use_vae) E,
        and restore their weights when testing or resuming."""
        netG = networks.define_G(opt)
        netD = networks.define_D(opt) if opt.isTrain else None
        # Unlike most variants, E is only built while training.
        build_encoder = opt.use_vae and opt.isTrain
        netE = networks.define_E(opt) if build_encoder else None

        restore = opt.continue_train or not opt.isTrain
        if restore:
            netG = util.load_network(netG, 'G', opt.which_epoch, opt)
            if opt.isTrain:
                netD = util.load_network(netD, 'D', opt.which_epoch, opt)
                if opt.use_vae:
                    netE = util.load_network(netE, 'E', opt.which_epoch, opt)

        return netG, netD, netE
Esempio n. 9
0
    def initialize_networks(self, opt):
        """Build the generator, the discriminator (training only) and the
        VAE encoder (use_vae only); reload saved weights for test/resume."""
        netG = networks.define_G(opt)
        if opt.isTrain:
            netD = networks.define_D(opt)
        else:
            netD = None
        if opt.use_vae:
            netE = networks.define_E(opt)
        else:
            netE = None

        if opt.continue_train or not opt.isTrain:
            netG = util.load_network(netG, 'G', opt.which_epoch, opt)
            if opt.isTrain:
                netD = util.load_network(netD, 'D', opt.which_epoch, opt)
            if opt.use_vae:
                netE = util.load_network(netE, 'E', opt.which_epoch, opt)

        return netG, netD, netE
Esempio n. 10
0
    def initialize_networks(self, opt):
        """Create the full set of networks for this model.

        Always builds G; conditionally builds D (train), D2 (unpaired
        train), E (VAE), IG (orient inpainting), SIG (stroke orient
        inpainting), FE (instance feature encoder) and B (blender).
        Checkpoints are restored only if the generator checkpoint file
        exists on disk; the inpainting networks are always loaded via their
        own helpers and switched to eval mode.

        Returns:
            Tuple ``(netG, netD, netE, netIG, netFE, netB, netD2, netSIG)``.
        """
        netG = networks.define_G(opt)
        netD = networks.define_D(opt) if opt.isTrain else None
        netD2 = networks.define_D(
            opt) if opt.isTrain and opt.unpairTrain else None
        netE = networks.define_E(
            opt) if opt.use_vae else None  # this is for original spade network
        netIG = networks.define_IG(
            opt
        ) if opt.use_ig else None  # this is the orient inpainting network
        netSIG = networks.define_SIG(
            opt
        ) if opt.use_stroke else None  # this is the stroke orient inpainting network
        netFE = networks.define_FE(
            opt
        ) if opt.use_instance_feat else None  # this is the feat encoder from pix2pixHD
        netB = networks.define_B(opt) if opt.use_blender else None

        if not opt.isTrain or opt.continue_train:
            # Only restore when the generator's .pth file actually exists.
            save_filename = '%s_net_%s.pth' % (opt.which_epoch, 'G')
            save_dir = os.path.join(opt.checkpoints_dir, opt.name)
            G_path = os.path.join(save_dir, save_filename)
            if os.path.exists(G_path):

                netG = util.load_network(netG, 'G', opt.which_epoch, opt)
                if opt.fix_netG:
                    # Freeze the generator when it should not be updated.
                    netG.eval()
                if opt.use_blender:
                    netB = util.load_blend_network(netB, 'B', opt.which_epoch,
                                                   opt)
                if opt.isTrain:
                    netD = util.load_network(netD, 'D', opt.which_epoch, opt)
                    if opt.unpairTrain:
                        netD2 = util.load_network(netD2, 'D', opt.which_epoch,
                                                  opt)
                if opt.use_vae:
                    netE = util.load_network(netE, 'E', opt.which_epoch, opt)
        if opt.use_ig:
            netIG = util.load_inpainting_network(netIG, opt)
            netIG.eval()
        if opt.use_stroke:
            netSIG = util.load_sinpainting_network(netSIG, opt)
            netSIG.eval()

        return netG, netD, netE, netIG, netFE, netB, netD2, netSIG
Esempio n. 11
0
    def initialize_networks(self, opt):
        """Build G, two discriminators (D1 over 2 input channels, D2 over
        7) and the optional encoder E; reload weights for test/resume.

        NOTE(review): ``opt.input_nc`` is left at 7 after this call -- a
        side effect on the shared options object; confirm callers expect it.
        """
        netG = networks.define_G(opt)
        # Each discriminator is defined under a different input_nc setting.
        opt.input_nc = 2
        netD1 = networks.define_D(opt) if opt.isTrain else None
        opt.input_nc = 7
        netD2 = networks.define_D(opt) if opt.isTrain else None
        netE = networks.define_E(opt) if opt.use_vae else None

        restore = opt.continue_train or not opt.isTrain
        if restore:
            netG = util.load_network(netG, 'G', opt.which_epoch, opt)
            if opt.isTrain:
                netD1 = util.load_network(netD1, 'D1', opt.which_epoch, opt)
                netD2 = util.load_network(netD2, 'D2', opt.which_epoch, opt)
            if opt.use_vae:
                netE = util.load_network(netE, 'E', opt.which_epoch, opt)

        return netG, netD1, netD2, netE
Esempio n. 12
0
    def initialize_networks(self, opt):
        """Create the G/D/E networks and restore their checkpoints -- with
        one log line per restored network -- when testing or resuming."""
        netG = networks.define_G(opt)
        netD = None
        netE = None
        if opt.isTrain:
            netD = networks.define_D(opt)
        if opt.use_vae:
            netE = networks.define_E(opt)

        if opt.continue_train or not opt.isTrain:
            netG = util.load_network(netG, 'G', opt.which_epoch, opt)
            print('net_G successfully loaded from {}.'.format(opt.which_epoch))
            if opt.isTrain:
                netD = util.load_network(netD, 'D', opt.which_epoch, opt)
                print('net_D successfully loaded from {}.'.format(
                    opt.which_epoch))
            if opt.use_vae:
                netE = util.load_network(netE, 'E', opt.which_epoch, opt)
                print('net_E successfully loaded from {}.'.format(
                    opt.which_epoch))

        return netG, netD, netE
Esempio n. 13
0
    def __init__(self, opt):
        """Build the LR-image estimator model.

        Creates network E (wrapped in DistributedDataParallel or
        DataParallel), restores weights via ``self.load()``, and -- when
        training -- sets up the pixel loss, an Adam optimizer over the
        trainable parameters of E, and MultiStepLR schedulers.

        Args:
            opt: nested options dict (``opt['train']``, ``opt['datasets']``,
                ``opt['network_E']``, ``opt['dist']``, ``opt['scale']``).
        """
        super(LRimgestimator_Model, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        self.train_opt = train_opt
        self.kernel_size = opt['datasets']['train']['kernel_size']
        self.patch_size = opt['datasets']['train']['patch_size']
        self.batch_size = opt['datasets']['train']['batch_size']

        # define networks and load pretrained models
        self.scale = opt['scale']
        self.model_name = opt['network_E']['which_model_E']
        self.mode = opt['network_E']['mode']

        self.netE = networks.define_E(opt).to(self.device)
        if opt['dist']:
            self.netE = DistributedDataParallel(
                self.netE, device_ids=[torch.cuda.current_device()])
        else:
            self.netE = DataParallel(self.netE)
        self.load()

        # loss
        # Pixel criterion for E; any other loss_ftn value disables it.
        if train_opt['loss_ftn'] == 'l1':
            self.MyLoss = nn.L1Loss(reduction='mean').to(self.device)
        elif train_opt['loss_ftn'] == 'l2':
            self.MyLoss = nn.MSELoss(reduction='mean').to(self.device)
        else:
            self.MyLoss = None

        if self.is_train:
            self.netE.train()

            # optimizers
            self.optimizers = []
            # Falsy weight_decay_R (missing/0/None) falls back to 0.
            wd_R = train_opt['weight_decay_R'] if train_opt[
                'weight_decay_R'] else 0
            optim_params = []
            for k, v in self.netE.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    print('WARNING: params [%s] will not optimize.' % k)
            # NOTE(review): learning rate comes from 'lr_C', not 'lr_E' --
            # confirm against the option files.
            self.optimizer_E = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_C'],
                                                weight_decay=wd_R)
            print('Weight_decay:%f' % wd_R)
            self.optimizers.append(self.optimizer_E)

            # schedulers
            self.schedulers = []
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                                                                    train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()
Esempio n. 14
0
# Evaluation-time setup: keep dataset order fixed (no shuffling).
opt.serial_batches = True  # no shuffle

# create dataset
dataset = create_dataset(opt)
model = create_model(opt)
model.setup(opt)
model.eval()
print('Loading model %s' % opt.model)

######
# Standalone encoder (presumably a satellite-image encoder, 'sate' --
# confirm): built from SateOption settings, restored from its latest
# checkpoint and frozen for inference.
sateOpt = SateOption()
sateE = networks.define_E(sateOpt.output_nc,
                          sateOpt.nz,
                          sateOpt.nef,
                          netE=sateOpt.netE,
                          norm=sateOpt.norm,
                          nl=sateOpt.nl,
                          init_type=sateOpt.init_type,
                          init_gain=sateOpt.init_gain,
                          gpu_ids=sateOpt.gpu_ids,
                          vaeLike=sateOpt.use_vae)
sateCheckpoint = torch.load('sate_encoder/sate_encoder_latest.pth')
sateE.load_state_dict(sateCheckpoint['model_state_dict'])
sateE.eval()
# Map input images to [-1, 1] per channel.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
######
web_dir = os.path.join(opt.results_dir,
Esempio n. 15
0
    def __init__(self, opts, input_dim, output_dim, lambda_ms=None):
        """Set up the BicycleGAN-with-AdaIN model.

        Builds the AdaIN generator, two multi-scale discriminators
        (training only), the style encoder E, and -- when training -- the
        GAN/L1/latent losses plus one Adam optimizer per network.

        Args:
            opts: options namespace (phase, gpu_ids, nz, gan_mode, lr,
                bicycleE, ...).
            input_dim: channel count of the generator input.
            output_dim: channel count of generated / discriminated images.
            lambda_ms: mode-seeking loss weight; ``None`` means 0.
        """
        super(BicycleGANAdaIN, self).__init__()
        self.isTrain = (opts.phase == 'train')
        self.gpu_ids = opts.gpu_ids
        self.device = torch.device('cuda:{}'.format(
            self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        self.nz = opts.nz

        # generator
        self.netG = networks.init_net(Gen(input_dim,
                                          output_dim,
                                          style_dim=opts.nz),
                                      init_type='xavier',
                                      init_gain=0.02,
                                      gpu_ids=self.gpu_ids)

        # discriminator
        # Two identically-configured discriminators (presumably the two
        # BicycleGAN branches -- confirm), each with num_Ds=2 scales.
        if self.isTrain:
            self.netD = networks.define_D(output_dim,
                                          64,
                                          netD='basic_256_multi',
                                          norm='instance',
                                          num_Ds=2,
                                          gpu_ids=self.gpu_ids)
            self.netD2 = networks.define_D(output_dim,
                                           64,
                                           netD='basic_256_multi',
                                           norm='instance',
                                           num_Ds=2,
                                           gpu_ids=self.gpu_ids)

        # encoder
        self.netE = networks.define_E(output_dim,
                                      opts.nz,
                                      64,
                                      netE=opts.bicycleE,
                                      norm='instance',
                                      vaeLike=True,
                                      gpu_ids=self.gpu_ids)

        # loss and optimizer and scheduler
        if self.isTrain:
            self.criterionGAN = networks.GANLoss(gan_mode=opts.gan_mode).to(
                self.device)
            self.criterionL1 = torch.nn.L1Loss()
            self.criterionZ = torch.nn.L1Loss()
            self.lambda_ms = 0. if lambda_ms is None else lambda_ms
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opts.lr,
                                                betas=(0.5, 0.999))
            self.optimizer_E = torch.optim.Adam(self.netE.parameters(),
                                                lr=opts.lr,
                                                betas=(0.5, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opts.lr,
                                                betas=(0.5, 0.999))
            self.optimizer_D2 = torch.optim.Adam(self.netD2.parameters(),
                                                 lr=opts.lr,
                                                 betas=(0.5, 0.999))
            self.optimizers = [
                self.optimizer_G, self.optimizer_E, self.optimizer_D,
                self.optimizer_D2
            ]
Esempio n. 16
0
        self.lr = 0.0002
        self.beta1 = 0.5
        self.device = torch.device('cuda:{}'.format(
            self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        self.batch_size = 4
        self.epoch = 5
        return


if __name__ == '__main__':
    # Standalone driver for encoder E: build it from the option() defaults
    # with an L1 criterion and an Adam optimizer.
    opt = option()
    netE = networks.define_E(opt.output_nc,
                             opt.nz,
                             opt.nef,
                             netE=opt.netE,
                             norm=opt.norm,
                             nl=opt.nl,
                             init_type=opt.init_type,
                             init_gain=opt.init_gain,
                             gpu_ids=opt.gpu_ids,
                             vaeLike=opt.use_vae)
    criterionL1 = torch.nn.L1Loss()
    optimizer = torch.optim.Adam(netE.parameters(),
                                 lr=opt.lr,
                                 betas=(opt.beta1, 0.999))

    # Map input images to [-1, 1] per channel.
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

    # Dict saved via np.save (presumably precomputed z codes per sample --
    # confirm against the producer script).
    d = np.load('encoded_z_aug_expand.npy', allow_pickle=True).item()
Esempio n. 17
0
    def __init__(self, opt):
        """Set up the SREF-GAN model.

        Builds G (and D when training), then -- in training mode -- the
        pixel / feature / aesthetic / GAN losses, the optional WGAN-GP
        gradient penalty, per-network Adam optimizers and MultiStepLR
        schedulers.

        Args:
            opt: nested options dict; training settings live in
                ``opt['train']``.
        """
        super(SREFGANModel, self).__init__(opt)
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)  # D
            self.netG.train()
            self.netD.train()
        self.load()  # load G and D if needed

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)

            # network E
            # Optional aesthetic loss backed by a separate network E.
            if train_opt['aesthetic_criterion'] == "include":
                self.cri_aes = True
                self.netE = networks.define_E(opt).to(self.device)
                self.l_aes_w = train_opt['aesthetic_weight']
            else:
                self.cri_aes = None

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(
                    self.device)
                # NOTE(review): config key is spelled 'gp_weigth' -- verify
                # it matches the option files before "fixing" the typo.
                self.l_gp_w = train_opt['gp_weigth']

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                        train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()
        # print network
        self.print_network()
Esempio n. 18
0
    def __init__(self, args):
        """Set up the model: checkpoint directory, attention / convolution /
        VGG / encoder / generator networks, MSE and L1 criteria, one Adam
        optimizer per network, and lambda-decay LR schedulers.

        Args:
            args: parsed command-line arguments (used for ``continue_train``
                and kept on ``self.args``).
        """

        self.gpu_ids=[0]
        self.isTrain = True
        
        self.checkpoints_dir = './checkpoints'
        self.which_epoch = 'latest' # which epoch to load? set to latest to use latest cached model
        self.args = args
        # self.name = 'G_GAN_%s_lambdar_%s_lambdas_%s_alpha_%s' % (self.args.lambda_d, self.args.lambda_r, self.args.lambda_s, self.args.alpha)
        self.name = 'Res_convolution_Gram'
        expr_dir = os.path.join(self.checkpoints_dir, self.name)
        if not os.path.exists(expr_dir):
            os.makedirs(expr_dir)

        self.save_dir = os.path.join(self.checkpoints_dir, self.name)
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['G_style', 'G_content']
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        self.model_names = ['C', 'E', 'G']
        self.visual_names = ['A', 'B', 'C', 'R']
        self.input_nc = 3
        self.output_nc = 3
        self.ndf = 64 #number of filters in the first layer of discriminator
        self.ngf = 64

        use_sigmoid = False
        # define networks

        # Channel / kernel / spatial attention modules.
        self.netCA = networks.define_channel_attention(self.gpu_ids)

        self.netKA = networks.define_kernel_attention(self.gpu_ids)

        self.netSA = networks.define_spatial_attention(self.gpu_ids)

        self.netC = networks.define_Convolution(self.gpu_ids)

        self.netKC = networks.define_K_Convolution(self.gpu_ids)

        # VGG feature extractor (presumably for the Gram/style losses --
        # confirm against the loss code).
        self.netVGG = networks.define_VGG()

        self.netE = networks.define_E(self.input_nc, self.ngf, self.gpu_ids)

        self.netG = networks.define_G(self.input_nc, self.output_nc, self.ngf, self.gpu_ids)


        self.criterionMSE = torch.nn.MSELoss()

        self.criterionL1 = torch.nn.L1Loss()


        # initialize optimizers

        self.schedulers = []
        self.optimizers = []

        # One Adam optimizer per network, all with the same hyperparameters.
        self.optimizer_CA = torch.optim.Adam(self.netCA.parameters(),
                                            lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_KA = torch.optim.Adam(self.netKA.parameters(),
                                             lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_SA = torch.optim.Adam(self.netSA.parameters(),
                                             lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_C = torch.optim.Adam(self.netC.parameters(),
                                            lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_KC = torch.optim.Adam(self.netKC.parameters(),
                                            lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_E = torch.optim.Adam(self.netE.parameters(),
                                            lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                            lr=0.0002, betas=(0.5, 0.999))

        self.optimizers.append(self.optimizer_CA)
        self.optimizers.append(self.optimizer_KA)
        self.optimizers.append(self.optimizer_SA)
        self.optimizers.append(self.optimizer_C)
        self.optimizers.append(self.optimizer_KC)
        self.optimizers.append(self.optimizer_E)
        self.optimizers.append(self.optimizer_G)

        for optimizer in self.optimizers:
            self.schedulers.append(networks.get_scheduler(optimizer, lr_policy='lambda', epoch_count=1, niter=100, niter_decay=100, lr_decay_iters=50))

        if not self.isTrain or args.continue_train:
            self.load_networks(self.which_epoch)

        self.print_networks()
Esempio n. 19
0
    def initialize_networks(self, opt):
        """Instantiate the generator, plus the discriminator when training
        and the VAE encoder when use_vae is set; no checkpoint loading."""
        netG = networks.define_G(opt)
        if opt.isTrain:
            netD = networks.define_D(opt)
        else:
            netD = None
        if opt.use_vae:
            netE = networks.define_E(opt)
        else:
            netE = None

        return netG, netD, netE