Example #1
0
    def initialize_networks(self, opt):
        """Create the sub-networks selected by ``opt`` and restore saved weights.

        Each optional network is ``None`` unless its flag is set:
        D needs ``opt.isTrain``; D2 needs ``opt.isTrain`` and
        ``opt.unpairTrain``; E needs ``opt.use_vae``; IG needs ``opt.use_ig``;
        SIG needs ``opt.use_stroke``; FE needs ``opt.use_instance_feat``;
        B needs ``opt.use_blender``.

        At test time, or when resuming training (``opt.continue_train``),
        G -- plus B, D, D2 and E when they are enabled -- are restored for
        ``opt.which_epoch``, but only if
        ``<checkpoints_dir>/<name>/<which_epoch>_net_G.pth`` exists; if it does
        not, the freshly created networks are returned without loading.
        IG and SIG, when enabled, are always loaded from their own checkpoints
        and switched to eval mode. FE is never loaded here.

        Returns:
            tuple: (netG, netD, netE, netIG, netFE, netB, netD2, netSIG)
        """
        # Construction order is kept as G, D, D2, E, IG, SIG, FE, B.
        # NOTE(review): presumably the define_* helpers draw random initial
        # weights, so keep this order stable.
        netG = networks.define_G(opt)
        netD, netD2 = None, None
        if opt.isTrain:
            netD = networks.define_D(opt)
            if opt.unpairTrain:
                netD2 = networks.define_D(opt)
        # E: encoder of the original SPADE network.
        netE = networks.define_E(opt) if opt.use_vae else None
        # IG: orientation inpainting network.
        netIG = networks.define_IG(opt) if opt.use_ig else None
        # SIG: stroke orientation inpainting network.
        netSIG = networks.define_SIG(opt) if opt.use_stroke else None
        # FE: instance feature encoder from pix2pixHD.
        netFE = networks.define_FE(opt) if opt.use_instance_feat else None
        # B: blender network.
        netB = networks.define_B(opt) if opt.use_blender else None

        # Restore only at test time or when resuming, and only when a
        # generator checkpoint for the requested epoch is on disk.
        restore = not opt.isTrain or opt.continue_train
        if restore:
            g_checkpoint = os.path.join(
                opt.checkpoints_dir, opt.name,
                '%s_net_%s.pth' % (opt.which_epoch, 'G'))
            restore = os.path.exists(g_checkpoint)

        if restore:
            epoch = opt.which_epoch
            netG = util.load_network(netG, 'G', epoch, opt)
            if opt.fix_netG:
                netG.eval()
            if opt.use_blender:
                netB = util.load_blend_network(netB, 'B', epoch, opt)
            if opt.isTrain:
                netD = util.load_network(netD, 'D', epoch, opt)
                if opt.unpairTrain:
                    # NOTE(review): D2 is restored from the 'D' checkpoint
                    # rather than a separate 'D2' one -- confirm intended.
                    netD2 = util.load_network(netD2, 'D', epoch, opt)
            if opt.use_vae:
                netE = util.load_network(netE, 'E', epoch, opt)

        # The inpainting networks are loaded from their own checkpoints
        # (independent of the epoch restore above) and put in eval mode.
        if opt.use_ig:
            netIG = util.load_inpainting_network(netIG, opt)
            netIG.eval()
        if opt.use_stroke:
            netSIG = util.load_sinpainting_network(netSIG, opt)
            netSIG.eval()

        return (netG, netD, netE, netIG, netFE, netB, netD2, netSIG)
Example #2
0
    def __init__(self, config):
        """Build the VAE, generator, training-only networks and loss modules.

        Args:
            config: options object. Attributes read here include
                ``batch_size``, ``dataroot``, ``resize_or_crop``, ``isTrain``,
                ``label_nc``, ``input_nc``, ``vae_path``, the generator /
                discriminator hyper-parameters, ``lsgan``, ``pool_size`` and
                ``gpu_ids``.

        The discriminator (``netD``) and blender (``netB``) are only created
        when ``config.isTrain`` is truthy.
        """
        super().__init__()
        self.cfg = config
        # Relative weights of the generator loss terms, normalised to sum to 1.
        # dtype is pinned to float so the in-place division is always valid.
        self.loss_g_weights = np.array(
            [1.0, 1.0, 10, 10, 10, 10], dtype=np.float64
        )
        self.loss_g_weights /= self.loss_g_weights.sum()

        self.batch_size = self.cfg.batch_size
        self.dataset_path = self.cfg.dataroot
        self.num_workers = 32
        # Original (misspelled) attribute name kept as an alias in case other
        # code reads it.
        self.num_wrokers = self.num_workers

        # Enable cudnn autotuning except when training at full resolution
        # (no resize/crop), where it causes OOM.
        if self.cfg.resize_or_crop != "none" or not self.cfg.isTrain:
            torch.backends.cudnn.benchmark = True

        # Channels of the input mask: label maps when label_nc is set,
        # otherwise the raw input channels.
        mask_channels = (
            self.cfg.label_nc if self.cfg.label_nc != 0 else self.cfg.input_nc
        )

        # vae network
        self.netVAE = networks.define_VAE(input_nc=mask_channels)
        # Deserialize onto the CPU: load_state_dict copies the tensors into the
        # module's own parameters, so this avoids failing (or allocating GPU
        # memory) on hosts that lack the device the checkpoint was saved from.
        vae_checkpoint = torch.load(self.cfg.vae_path, map_location="cpu")
        self.netVAE.load_state_dict(vae_checkpoint["vae"])
        self.vae_lambda = 2.5

        # generator network
        self.netG = networks.define_G(
            mask_channels,
            self.cfg.output_nc,  # image channels
            self.cfg.ngf,  # gen filters in first conv layer
            self.cfg.netG,  # global or local
            self.cfg.n_downsample_global,  # num of downsampling layers in netG
            self.cfg.n_blocks_global,  # num of residual blocks
            self.cfg.n_local_enhancers,  # ignored
            self.cfg.n_blocks_local,  # ignored
            self.cfg.norm,  # instance normalization or batch normalization
        )
        # discriminator and blender networks (training only)
        if self.cfg.isTrain:
            # Sigmoid output only when the least-squares GAN loss is disabled.
            use_sigmoid = not self.cfg.lsgan
            # D sees the mask concatenated with an image.
            netD_input_nc = mask_channels + self.cfg.output_nc
            self.netD = networks.define_D(
                netD_input_nc,
                self.cfg.ndf,  # filters in first conv layer
                self.cfg.n_layers_D,
                self.cfg.norm,
                use_sigmoid,
                self.cfg.num_D,
                getIntermFeat=self.cfg.ganFeat_loss,
            )
            # B takes two images stacked along the channel axis.
            netB_input_nc = self.cfg.output_nc * 2
            self.netB = networks.define_B(
                netB_input_nc, self.cfg.output_nc, 32, 3, 3, self.cfg.norm
            )

        # loss functions
        self.use_pool = self.cfg.pool_size > 0
        if self.use_pool:
            self.fake_pool = ImagePool(self.cfg.pool_size)

        self.criterionGAN = networks.GANLoss(use_lsgan=self.cfg.lsgan)
        self.criterionFeat = torch.nn.L1Loss()
        self.criterionVGG = networks.VGGLoss(self.cfg.gpu_ids)
Example #3
0
    def __init__(self, opt):
        """Build the AFLGAN generator (G), discriminator (D) and blender (B).

        When ``self.is_train`` is set, also build the losses, one Adam
        optimizer per network (appended to ``self.optimizers`` in the order
        G, D, B) and a MultiStepLR scheduler for every optimizer in
        ``self.optimizers``.

        Args:
            opt: option mapping; ``opt['train']`` is read unconditionally.

        Raises:
            NotImplementedError: if a pixel/feature loss with positive weight
                uses a criterion other than ``'l1'``/``'l2'``, or if
                ``lr_scheme`` is not ``'MultiStepLR'``.
        """
        super(AFLGANModel, self).__init__(opt)
        train_opt = opt['train']
        self.state_opt = self.opt['train']['which_state']

        def make_criterion(loss_type):
            """Return an L1 ('l1') or MSE ('l2') loss module on self.device."""
            if loss_type == 'l1':
                return nn.L1Loss().to(self.device)
            if loss_type == 'l2':
                return nn.MSELoss().to(self.device)
            # '{}' rather than '{:s}' so a missing (None) type still raises
            # this error instead of a TypeError from str.format.
            raise NotImplementedError(
                'Loss type [{}] not recognized.'.format(loss_type))

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)  # G
        self.netD = networks.define_D(opt).to(self.device)  # D
        self.netB = networks.define_B(opt).to(self.device)  # B
        if self.is_train:
            self.netG.train()
            self.netD.train()
            self.netB.train()
        self.load()  # load pretrained weights if configured (see self.load)

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                self.cri_pix = make_criterion(train_opt['pixel_criterion'])
                self.l_pix_w = train_opt['pixel_weight']
            else:
                print('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                self.cri_fea = make_criterion(train_opt['feature_criterion'])
                self.l_fea_w = train_opt['feature_weight']
            else:
                print('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea is not None:  # load VGG perceptual loss
                self.netF = networks.define_F(opt, use_bn=False).to(self.device)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters are for WGAN; a falsy value
            # (e.g. None) falls back to 1 and 0 respectively.
            self.D_update_ratio = train_opt['D_update_ratio'] or 1
            self.D_init_iters = train_opt['D_init_iters'] or 0

            if train_opt['gan_type'] == 'wgan-gp':
                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
                # gradient penalty loss
                self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
                # The weight was originally read only under the misspelled key
                # 'gp_weigth'; accept the correct spelling too, preferring it.
                if 'gp_weight' in train_opt:
                    self.l_gp_w = train_opt['gp_weight']
                else:
                    self.l_gp_w = train_opt['gp_weigth']

            # optimizers
            # G: only parameters with requires_grad=True are optimized, so
            # parts of the model can be frozen.
            wd_G = train_opt['weight_decay_G'] or 0
            optim_params = []
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    print('WARNING: params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(
                optim_params, lr=train_opt['lr_G'], weight_decay=wd_G,
                betas=(train_opt['beta1_G'], 0.999))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] or 0
            self.optimizer_D = torch.optim.Adam(
                self.netD.parameters(), lr=train_opt['lr_D'], weight_decay=wd_D,
                betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_D)
            # B
            # NOTE(review): B reuses D's lr / beta1 / weight decay -- confirm
            # this is intended.
            self.optimizer_B = torch.optim.Adam(
                self.netB.parameters(), lr=train_opt['lr_D'], weight_decay=wd_D,
                betas=(train_opt['beta1_D'], 0.999))
            self.optimizers.append(self.optimizer_B)

            # schedulers
            if train_opt['lr_scheme'] != 'MultiStepLR':
                raise NotImplementedError(
                    'Learning rate scheme [{}] not supported; only MultiStepLR '
                    'is implemented.'.format(train_opt['lr_scheme']))
            for optimizer in self.optimizers:
                self.schedulers.append(lr_scheduler.MultiStepLR(
                    optimizer, train_opt['lr_steps'], train_opt['lr_gamma']))

            self.log_dict = OrderedDict()

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')