Ejemplo n.º 1
0
    def __init__(self, opt, testing=False):
        """Build the networks, optimizers and loss criteria of the model.

        Args:
            opt: option namespace. Besides the per-network hyper-parameters it
                is read for ``lr``, ``beta1``, ``no_lsgan``, ``nlatent``,
                ``enc_A_B`` and, unless ``testing`` is true, ``expr_dir``.
                ``opt.use_sigmoid`` is assigned here (copied from
                ``opt.no_lsgan``), so the caller's object is mutated.
            testing: when false, a textual dump of every network is written
                to ``<opt.expr_dir>/nets.txt``.
        """
        # Starting learning rate; presumably used later for lr decay (confirm).
        self.old_lr = opt.lr
        opt.use_sigmoid = opt.no_lsgan
        self.opt = opt

        gpus = opt.gpu_ids

        # NOTE: the construction order below is kept unchanged on purpose --
        # the network constructors presumably draw initial weights from the
        # global RNG, so reordering them could change the initial weights.

        # Keyword arguments common to both generators.
        gen_common = dict(ngf=opt.ngf,
                          which_model_netG=opt.which_model_netG,
                          norm=opt.norm,
                          use_dropout=opt.use_dropout,
                          gpu_ids=gpus)
        self.netG_A_B = networks.define_stochastic_G(nlatent=opt.nlatent,
                                                     input_nc=opt.input_nc,
                                                     output_nc=opt.output_nc,
                                                     **gen_common)
        self.netG_B_A = networks.define_G(input_nc=opt.output_nc,
                                          output_nc=opt.input_nc,
                                          **gen_common)

        # Encoder input channels: B's channels, plus A's when opt.enc_A_B is
        # set (presumably A and B are concatenated in forward -- confirm).
        enc_channels = opt.output_nc + (opt.input_nc if opt.enc_A_B else 0)
        self.netE_B = networks.define_E(nlatent=opt.nlatent,
                                        input_nc=enc_channels,
                                        nef=opt.nef, norm='batch', gpu_ids=gpus)

        # Keyword arguments common to the two image discriminators.
        disc_common = dict(which_model_netD=opt.which_model_netD,
                           norm=opt.norm,
                           use_sigmoid=opt.use_sigmoid,
                           gpu_ids=gpus)
        # netD_A uses a fixed width of 32 regardless of opt.ndf.
        self.netD_A = networks.define_D_A(input_nc=opt.input_nc, ndf=32,
                                          **disc_common)
        self.netD_B = networks.define_D_B(input_nc=opt.output_nc, ndf=opt.ndf,
                                          **disc_common)
        self.netD_z_B = networks.define_LAT_D(nlatent=opt.nlatent, ndf=opt.ndf,
                                              use_sigmoid=opt.use_sigmoid,
                                              gpu_ids=gpus)

        # Adam everywhere; discriminators train at one fifth of opt.lr.
        betas = (opt.beta1, 0.999)
        disc_lr = opt.lr / 5.
        self.optimizer_G_A = torch.optim.Adam(self.netG_B_A.parameters(),
                                              lr=opt.lr, betas=betas)
        # G_A_B and the encoder E_B are updated by the same optimizer.
        g_b_params = itertools.chain(self.netG_A_B.parameters(),
                                     self.netE_B.parameters())
        self.optimizer_G_B = torch.optim.Adam(g_b_params, lr=opt.lr, betas=betas)
        self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                              lr=disc_lr, betas=betas)
        d_b_params = itertools.chain(self.netD_B.parameters(),
                                     self.netD_z_B.parameters())
        self.optimizer_D_B = torch.optim.Adam(d_b_params, lr=disc_lr, betas=betas)

        self.criterionGAN = functools.partial(criterion_GAN,
                                              use_sigmoid=opt.use_sigmoid)
        self.criterionCycle = F.l1_loss

        if testing:
            return

        # Dump every network (in a fixed order) for later inspection.
        summary_order = (self.netG_A_B, self.netG_B_A, self.netD_A,
                         self.netD_B, self.netD_z_B, self.netE_B)
        with open("%s/nets.txt" % opt.expr_dir, 'w') as nets_f:
            for net in summary_order:
                networks.print_network(net, nets_f)
Ejemplo n.º 2
0
    def __init__(self, hyperparameters):
        """Create the networks, losses, fixed style codes and optimizers.

        Args:
            hyperparameters: mapping read for ``lr``, ``gen`` -> ``style_dim``,
                ``beta1``, ``beta2``, ``weight_decay``, whatever
                ``get_scheduler`` consumes and, optionally, ``vgg_w`` together
                with ``vgg_model_path``.

        Requires CUDA: the GAN criteria are built with ``torch.cuda.FloatTensor``
        and the fixed style codes are moved to the GPU via ``.cuda()``.
        """
        super(myMUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        self.style_dim = hyperparameters['gen']['style_dim']
        style_dim = self.style_dim

        # Domains a and b use identically configured encoders, generators and
        # discriminators; only the instances differ. The construction order
        # is kept unchanged because constructors presumably consume the
        # global RNG for weight initialisation.
        enc_cfg = dict(input_nc=3, output_nc=style_dim, ndf=64)
        gen_cfg = dict(input_nc=3, output_nc=3, nz=style_dim, ngf=64)
        dis_cfg = dict(input_nc=3, ndf=64, norm='instance', num_Ds=2)
        self.enc_a = networks.define_E(**enc_cfg)  # encoder, domain a
        self.enc_b = networks.define_E(**enc_cfg)  # encoder, domain b
        self.gen_a = networks.define_G(**gen_cfg)  # generator, domain a
        self.gen_b = networks.define_G(**gen_cfg)  # generator, domain b
        self.dis_a = networks.define_D(**dis_cfg)  # discriminator, domain a
        self.dis_b = networks.define_D(**dis_cfg)  # discriminator, domain b
        self.netVGGF = networks.define_VGGF()
        vgg_features = self.netVGGF
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # Loss functions. The VGG-feature based ones all share netVGGF.
        cuda_float = torch.cuda.FloatTensor
        self.criterionGAN = networks.GANLoss(mse_loss=True, tensor=cuda_float)
        self.wGANloss = networks.wGANLoss(tensor=cuda_float)
        self.criterionL1 = torch.nn.L1Loss()
        self.criterionL2 = networks.L2Loss()
        self.criterionZ = torch.nn.L1Loss()  # L1 between codes
        self.criterionC = networks.ContentLoss(vgg_features=vgg_features)
        self.criterionS = networks.StyleLoss(vgg_features=vgg_features)
        self.criterionC_l = networks.ContentLoss(vgg_features=vgg_features)  # local
        self.criterionS_l = networks.StyleLoss(vgg_features=vgg_features)  # local
        # The attribute keeps its historical spelling ("Hisogram") because
        # other code may refer to it by that name.
        self.criterionHisogram = networks.HistogramLoss(vgg_features=vgg_features)
        self.Feature_map_im = networks.Feature_map_im(vgg_features=vgg_features)

        # Fixed random style codes (8 each) used when sampling.
        self.s_a = torch.randn(8, style_dim, 1, 1).cuda()
        self.s_b = torch.randn(8, style_dim, 1, 1).cuda()

        # Optimizers: only parameters with requires_grad are handed to Adam.
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        # NOTE(review): enc_a / enc_b parameters are given to neither
        # optimizer here; unless another optimizer is built elsewhere the
        # encoders are never updated -- confirm this is intended.
        trainable_dis = [p for net in (self.dis_a, self.dis_b)
                         for p in net.parameters() if p.requires_grad]
        trainable_gen = [p for net in (self.gen_a, self.gen_b)
                         for p in net.parameters() if p.requires_grad]
        adam_kwargs = dict(lr=lr,
                           betas=(beta1, beta2),
                           weight_decay=hyperparameters['weight_decay'])
        self.dis_opt = torch.optim.Adam(trainable_dis, **adam_kwargs)
        self.gen_opt = torch.optim.Adam(trainable_gen, **adam_kwargs)
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Frozen VGG16, loaded only when a positive 'vgg_w' is configured.
        if 'vgg_w' in hyperparameters and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for weight in self.vgg.parameters():
                weight.requires_grad = False