def initialize(self, opt):
        """Set up the supervised pose-transfer model.

        Builds the image transformer (netT), optionally a conditional
        discriminator (netD), and — in training mode — the loss functions,
        optimizers, schedulers and fake-image pool.

        Args:
            opt: option namespace; uses which_model_T ('resnet'|'unet'),
                pose_type, T_nf, norm, no_dropout, gpu_ids, init_type,
                loss_weight_gan, D_cond, D_nf, which_gan, lr, lr_D,
                beta1/beta2, pool_size, which_model.
        """
        super(SupervisedPoseTransferModel, self).initialize(opt)
        ###################################
        # define transformer
        ###################################
        if opt.which_model_T == 'resnet':
            self.netT = networks.ResnetGenerator(
                input_nc=3 + self.get_pose_dim(opt.pose_type),
                output_nc=3,
                ngf=opt.T_nf,
                norm_layer=networks.get_norm_layer(opt.norm),
                use_dropout=not opt.no_dropout,
                n_blocks=9,
                gpu_ids=opt.gpu_ids)
        elif opt.which_model_T == 'unet':
            self.netT = networks.UnetGenerator_v2(
                input_nc=3 + self.get_pose_dim(opt.pose_type),
                output_nc=3,
                num_downs=8,
                ngf=opt.T_nf,
                norm_layer=networks.get_norm_layer(opt.norm),
                use_dropout=not opt.no_dropout,
                gpu_ids=opt.gpu_ids)
        else:
            raise NotImplementedError()

        if opt.gpu_ids:
            self.netT.cuda()
        networks.init_weights(self.netT, init_type=opt.init_type)
        ###################################
        # define discriminator
        ###################################
        # GAN loss is only active during training and when its weight is > 0
        self.use_GAN = self.is_train and opt.loss_weight_gan > 0
        if self.use_GAN:
            self.netD = networks.define_D_from_params(
                # conditional D sees the image concatenated with the pose map
                input_nc=(3 + self.get_pose_dim(opt.pose_type))
                if opt.D_cond else 3,
                ndf=opt.D_nf,
                which_model_netD='n_layers',
                n_layers_D=3,
                norm=opt.norm,
                which_gan=opt.which_gan,
                init_type=opt.init_type,
                gpu_ids=opt.gpu_ids)
        else:
            self.netD = None
        ###################################
        # loss functions
        ###################################
        if self.is_train:
            self.loss_functions = []
            self.schedulers = []
            self.optimizers = []

            self.crit_L1 = nn.L1Loss()
            self.crit_vgg = networks.VGGLoss_v2(self.gpu_ids)
            # self.crit_vgg_old = networks.VGGLoss(self.gpu_ids)
            # PSNR/SSIM are evaluation metrics only; not in loss_functions
            self.crit_psnr = networks.PSNR()
            self.crit_ssim = networks.SSIM()
            self.loss_functions += [self.crit_L1, self.crit_vgg]
            self.optim = torch.optim.Adam(self.netT.parameters(),
                                          lr=opt.lr,
                                          betas=(opt.beta1, opt.beta2))
            self.optimizers += [self.optim]

            if self.use_GAN:
                self.crit_GAN = networks.GANLoss(
                    use_lsgan=opt.which_gan == 'lsgan', tensor=self.Tensor)
                self.optim_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr_D,
                                                betas=(opt.beta1, opt.beta2))
                # bugfix: was appending self.use_GAN (a bool) instead of
                # the GAN criterion itself
                self.loss_functions.append(self.crit_GAN)
                self.optimizers.append(self.optim_D)
            # todo: add pose loss
            for optim in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optim, opt))

            self.fake_pool = ImagePool(opt.pool_size)

        ###################################
        # load trained model
        ###################################
        if not self.is_train:
            # NOTE(review): sibling models load with opt.which_epoch here;
            # confirm opt.which_model is the intended checkpoint label
            self.load_network(self.netT, 'netT', opt.which_model)
    def initialize(self, opt):
        """Build the multimodal designer GAN (v2).

        Creates the shape/edge/color encoders, the optional feature fusion
        (FTN) networks, generator and discriminator; loads pretrained
        weights as requested; and in training mode prepares losses,
        optimizers and schedulers.

        Args:
            opt: option namespace controlling architecture
                (which_model_netG, use_edge, use_color, ftn_model, *_nof,
                ftn_nblocks, ftn_ndowns, feat_size_lr, norm, init_type, ...)
                and training (continue_train, which_epoch, which_model_init,
                which_gan, lr, lr_D, lr_FTN, beta1/beta2, pool_size, ...).
        """
        super(MultimodalDesignerGAN_V2, self).initialize(opt)
        ###################################
        # define networks
        ###################################
        # registry of all sub-networks, used for bulk weight loading and
        # for building optimizer parameter groups
        self.modules = {}
        # shape branch (the unet generator encodes shape internally, so no
        # separate shape encoder is created in that case)
        if opt.which_model_netG != 'unet':
            self.shape_encoder = networks.define_image_encoder(opt, 'shape')
            self.modules['shape_encoder'] = self.shape_encoder
        else:
            self.shape_encoder = None
        # edge branch
        if opt.use_edge:
            self.edge_encoder = networks.define_image_encoder(opt, 'edge')
            self.modules['edge_encoder'] = self.edge_encoder
        else:
            # bugfix: was "self.encoder_edge = None" (transposed name),
            # which left self.edge_encoder undefined when use_edge is False
            self.edge_encoder = None
        # color branch
        if opt.use_color:
            self.color_encoder = networks.define_image_encoder(opt, 'color')
            self.modules['color_encoder'] = self.color_encoder
        else:
            self.color_encoder = None

        # fusion model
        if opt.ftn_model == 'none':
            # shape_feat, edge_feat and color_feat will be simply upsampled
            # to the same size (size of shape_feat) and concatenated
            pass
        elif opt.ftn_model == 'concat':
            assert opt.use_edge or opt.use_color
            if opt.use_edge:
                self.edge_trans_net = networks.define_feature_fusion_network(
                    name='FeatureConcatNetwork',
                    feat_nc=opt.edge_nof,
                    guide_nc=opt.shape_nof,
                    nblocks=opt.ftn_nblocks,
                    norm=opt.norm,
                    gpu_ids=self.gpu_ids,
                    init_type=opt.init_type)
                self.modules['edge_trans_net'] = self.edge_trans_net
            if opt.use_color:
                self.color_trans_net = networks.define_feature_fusion_network(
                    name='FeatureConcatNetwork',
                    feat_nc=opt.color_nof,
                    guide_nc=opt.shape_nof,
                    nblocks=opt.ftn_nblocks,
                    norm=opt.norm,
                    gpu_ids=self.gpu_ids,
                    init_type=opt.init_type)
                self.modules['color_trans_net'] = self.color_trans_net
        elif opt.ftn_model == 'reduce':
            assert opt.use_edge or opt.use_color
            if opt.use_edge:
                self.edge_trans_net = networks.define_feature_fusion_network(
                    name='FeatureReduceNetwork',
                    feat_nc=opt.edge_nof,
                    guide_nc=opt.shape_nof,
                    nblocks=opt.ftn_nblocks,
                    ndowns=opt.ftn_ndowns,
                    norm=opt.norm,
                    gpu_ids=self.gpu_ids,
                    init_type=opt.init_type)
                self.modules['edge_trans_net'] = self.edge_trans_net
            if opt.use_color:
                self.color_trans_net = networks.define_feature_fusion_network(
                    name='FeatureReduceNetwork',
                    feat_nc=opt.color_nof,
                    guide_nc=opt.shape_nof,
                    nblocks=opt.ftn_nblocks,
                    ndowns=opt.ftn_ndowns,
                    norm=opt.norm,
                    gpu_ids=self.gpu_ids,
                    init_type=opt.init_type)
                self.modules['color_trans_net'] = self.color_trans_net

        elif opt.ftn_model == 'trans':
            assert opt.use_edge or opt.use_color
            if opt.use_edge:
                self.edge_trans_net = networks.define_feature_fusion_network(
                    name='FeatureTransformNetwork',
                    feat_nc=opt.edge_nof,
                    guide_nc=opt.shape_nof,
                    nblocks=opt.ftn_nblocks,
                    feat_size=opt.feat_size_lr,
                    norm=opt.norm,
                    gpu_ids=self.gpu_ids,
                    init_type=opt.init_type)
                self.modules['edge_trans_net'] = self.edge_trans_net
            if opt.use_color:
                self.color_trans_net = networks.define_feature_fusion_network(
                    name='FeatureTransformNetwork',
                    feat_nc=opt.color_nof,
                    guide_nc=opt.shape_nof,
                    nblocks=opt.ftn_nblocks,
                    feat_size=opt.feat_size_lr,
                    norm=opt.norm,
                    gpu_ids=self.gpu_ids,
                    init_type=opt.init_type)
                self.modules['color_trans_net'] = self.color_trans_net

        # netG
        self.netG = networks.define_generator(opt)
        self.modules['netG'] = self.netG

        # netD (training only)
        if self.is_train:
            self.netD = networks.define_D(opt)
            self.modules['netD'] = self.netD

        ###################################
        # load weights
        ###################################
        # note: .items() replaces the Python-2-only .iteritems()
        if self.is_train:
            if opt.continue_train:
                for label, net in self.modules.items():
                    self.load_network(net, label, opt.which_epoch)
            else:
                if opt.which_model_init != 'none':
                    # load pretrained entire model
                    for label, net in self.modules.items():
                        self.load_network(net,
                                          label,
                                          'latest',
                                          opt.which_model_init,
                                          forced=False)
                else:
                    # load pretrained encoders individually
                    if opt.which_model_netG != 'unet' and opt.pretrain_shape:
                        self.load_network(self.shape_encoder, 'shape_encoder',
                                          'latest',
                                          opt.which_model_init_shape_encoder)
                    if opt.use_edge and opt.pretrain_edge:
                        self.load_network(self.edge_encoder, 'edge_encoder',
                                          'latest',
                                          opt.which_model_init_edge_encoder)
                    if opt.use_color and opt.pretrain_color:
                        self.load_network(self.color_encoder, 'color_encoder',
                                          'latest',
                                          opt.which_model_init_color_encoder)
        else:
            # test mode: load everything except the discriminator
            for label, net in self.modules.items():
                if label != 'netD':
                    self.load_network(net, label, opt.which_epoch)

        ###################################
        # prepare for training
        ###################################
        if self.is_train:
            self.fake_pool = ImagePool(opt.pool_size)
            ###################################
            # define loss functions
            ###################################
            self.loss_functions = []
            if opt.which_gan in {'dcgan', 'lsgan'}:
                self.crit_GAN = networks.GANLoss(
                    use_lsgan=opt.which_gan == 'lsgan', tensor=self.Tensor)
                self.loss_functions.append(self.crit_GAN)
            else:
                # WGAN loss will be calculated in self.backward_D_wgangp and self.backward_G
                self.crit_GAN = None

            self.crit_L1 = nn.L1Loss()
            self.loss_functions.append(self.crit_L1)

            if self.opt.loss_weight_vgg > 0:
                self.crit_vgg = networks.VGGLoss(self.gpu_ids)
                self.loss_functions.append(self.crit_vgg)

            if self.opt.G_output_seg:
                self.crit_CE = nn.CrossEntropyLoss()
                self.loss_functions.append(self.crit_CE)

            self.crit_psnr = networks.SmoothLoss(networks.PSNR())
            self.loss_functions.append(self.crit_psnr)
            ###################################
            # create optimizers
            ###################################
            self.schedulers = []
            self.optimizers = []

            # G optimizer: generator plus all image encoders that exist
            G_module_list = [
                'shape_encoder', 'edge_encoder', 'color_encoder', 'netG'
            ]
            G_param_groups = [{
                'params': self.modules[m].parameters()
            } for m in G_module_list if m in self.modules]
            self.optim_G = torch.optim.Adam(G_param_groups,
                                            lr=opt.lr,
                                            betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optim_G)
            # D optimizer
            self.optim_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr_D,
                                            betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optim_D)
            # feature transfer network optimizer (separate learning rate)
            FTN_module_list = ['edge_trans_net', 'color_trans_net']
            FTN_param_groups = [{
                'params': self.modules[m].parameters()
            } for m in FTN_module_list if m in self.modules]
            if len(FTN_param_groups) > 0:
                self.optim_FTN = torch.optim.Adam(FTN_param_groups,
                                                  lr=opt.lr_FTN,
                                                  betas=(0.9, 0.999))
                self.optimizers.append(self.optim_FTN)
            else:
                self.optim_FTN = None
            # schedulers
            for optim in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optim, opt))
# Example #3
    def initialize(self, opt):
        """Construct the variational-unet pose transfer model.

        Builds the variational U-Net transformer, an optional conditional
        discriminator, the evaluation metrics, and — when training — the
        losses, optimizers, schedulers and fake-image pool. Also restores
        weights (and optimizer state when resuming training).
        """
        super(VUnetPoseTransferModel, self).initialize(opt)
        ###################################
        # define transformer
        ###################################
        self.netT = networks.VariationalUnet(
            input_nc_dec=self.get_pose_dim(opt.pose_type),
            input_nc_enc=self.get_appearance_dim(opt.appearance_type),
            output_nc=self.get_output_dim(opt.output_type),
            nf=opt.vunet_nf,
            max_nf=opt.vunet_max_nf,
            input_size=opt.fine_size,
            n_latent_scales=opt.vunet_n_latent_scales,
            bottleneck_factor=opt.vunet_bottleneck_factor,
            box_factor=opt.vunet_box_factor,
            n_residual_blocks=2,
            norm_layer=networks.get_norm_layer(opt.norm),
            activation=nn.ReLU(False),
            use_dropout=False,
            gpu_ids=opt.gpu_ids,
            output_tanh=False,
        )
        if opt.gpu_ids:
            self.netT.cuda()
        networks.init_weights(self.netT, init_type=opt.init_type)
        ###################################
        # define discriminator
        ###################################
        self.use_GAN = self.is_train and opt.loss_weight_gan > 0
        if not self.use_GAN:
            self.netD = None
        else:
            cond_nc = 3 + self.get_pose_dim(opt.pose_type) if opt.D_cond else 3
            self.netD = networks.define_D_from_params(
                input_nc=cond_nc,
                ndf=opt.D_nf,
                which_model_netD='n_layers',
                n_layers_D=opt.D_n_layer,
                norm=opt.norm,
                which_gan=opt.which_gan,
                init_type=opt.init_type,
                gpu_ids=opt.gpu_ids)
        ###################################
        # loss functions
        ###################################
        # PSNR/SSIM are always available (used as evaluation metrics)
        self.crit_psnr = networks.PSNR()
        self.crit_ssim = networks.SSIM()

        if self.is_train:
            self.optimizers = []
            self.crit_vgg = networks.VGGLoss_v2(self.gpu_ids,
                                                opt.content_layer_weight,
                                                opt.style_layer_weight,
                                                opt.shifted_style)
            # self.crit_vgg_old = networks.VGGLoss(self.gpu_ids)
            self.optim = torch.optim.Adam(self.netT.parameters(),
                                          lr=opt.lr,
                                          betas=(opt.beta1, opt.beta2),
                                          weight_decay=opt.weight_decay)
            self.optimizers.append(self.optim)

            if self.use_GAN:
                self.crit_GAN = networks.GANLoss(
                    use_lsgan=opt.which_gan == 'lsgan', tensor=self.Tensor)
                self.optim_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr_D,
                                                betas=(opt.beta1, opt.beta2))
                self.optimizers.append(self.optim_D)
            # todo: add pose loss
            self.fake_pool = ImagePool(opt.pool_size)

        ###################################
        # load trained model
        ###################################
        if not self.is_train:
            self.load_network(self.netT, 'netT', opt.which_epoch)
        elif opt.continue_train:
            # resume: restore network weights and optimizer state
            self.load_network(self.netT, 'netT', opt.which_epoch)
            self.load_optim(self.optim, 'optim', opt.which_epoch)
            if self.use_GAN:
                self.load_network(self.netD, 'netD', opt.which_epoch)
                self.load_optim(self.optim_D, 'optim_D', opt.which_epoch)
        ###################################
        # schedulers
        ###################################
        if self.is_train:
            self.schedulers = [
                networks.get_scheduler(o, opt) for o in self.optimizers
            ]
    def initialize(self, opt):
        """Build the multimodal designer GAN.

        Creates the generator, the edge/color/attribute encoders, the main
        and auxiliary discriminators; loads pretrained weights; and in
        training mode prepares losses, optimizers and schedulers. Also
        precomputes the std -> imagenet color transform constants.
        """
        super(MultimodalDesignerGAN, self).initialize(opt)
        ###################################
        # load/define networks
        ###################################

        # basic G
        self.netG = networks.define_G(opt)

        # encoders (attr_encoder is kept separate: it is pretrained and
        # not optimized by optim_G, see below)
        self.encoders = {}
        if opt.use_edge:
            self.edge_encoder = networks.define_image_encoder(opt, 'edge')
            self.encoders['edge_encoder'] = self.edge_encoder
        if opt.use_color:
            self.color_encoder = networks.define_image_encoder(opt, 'color')
            self.encoders['color_encoder'] = self.color_encoder
        if opt.use_attr:
            self.attr_encoder, self.opt_AE = network_loader.load_attribute_encoder_net(
                id=opt.which_model_AE, gpu_ids=opt.gpu_ids)

        # basic D and auxiliary Ds
        if self.is_train:
            # basic D
            self.netD = networks.define_D(opt)
            # auxiliary Ds (each judges a modality feature + image pair)
            self.auxiliaryDs = {}
            if opt.use_edge_D:
                assert opt.use_edge
                self.netD_edge = networks.define_D_from_params(
                    input_nc=opt.edge_nof + 3,
                    ndf=opt.ndf,
                    which_model_netD=opt.which_model_netD,
                    n_layers_D=opt.n_layers_D,
                    norm=opt.norm,
                    which_gan='dcgan',
                    init_type=opt.init_type,
                    gpu_ids=opt.gpu_ids)
                self.auxiliaryDs['D_edge'] = self.netD_edge
            if opt.use_color_D:
                assert opt.use_color
                self.netD_color = networks.define_D_from_params(
                    input_nc=opt.color_nof + 3,
                    ndf=opt.ndf,
                    which_model_netD=opt.which_model_netD,
                    n_layers_D=opt.n_layers_D,
                    norm=opt.norm,
                    which_gan='dcgan',
                    init_type=opt.init_type,
                    gpu_ids=opt.gpu_ids)
                self.auxiliaryDs['D_color'] = self.netD_color
            if opt.use_attr_D:
                assert opt.use_attr
                attr_nof = opt.n_attr_feat if opt.attr_cond_type in {
                    'feat', 'feat_map'
                } else opt.n_attr
                self.netD_attr = networks.define_D_from_params(
                    input_nc=attr_nof + 3,
                    ndf=opt.ndf,
                    which_model_netD=opt.which_model_netD,
                    n_layers_D=opt.n_layers_D,
                    norm=opt.norm,
                    which_gan='dcgan',
                    init_type=opt.init_type,
                    gpu_ids=opt.gpu_ids)
                self.auxiliaryDs['D_attr'] = self.netD_attr
            # load weights
            # note: .items() replaces the Python-2-only .iteritems()
            if not opt.continue_train:
                if opt.which_model_init != 'none':
                    self.load_network(self.netG, 'G', 'latest',
                                      opt.which_model_init)
                    self.load_network(self.netD, 'D', 'latest',
                                      opt.which_model_init)
                    for l, net in self.encoders.items():
                        self.load_network(net, l, 'latest',
                                          opt.which_model_init)
                    for l, net in self.auxiliaryDs.items():
                        self.load_network(net, l, 'latest',
                                          opt.which_model_init)
            else:
                self.load_network(self.netG, 'G', opt.which_epoch)
                self.load_network(self.netD, 'D', opt.which_epoch)
                for l, net in self.encoders.items():
                    self.load_network(net, l, opt.which_epoch)
                for l, net in self.auxiliaryDs.items():
                    self.load_network(net, l, opt.which_epoch)
        else:
            self.load_network(self.netG, 'G', opt.which_epoch)
            for l, net in self.encoders.items():
                self.load_network(net, l, opt.which_epoch)

        if self.is_train:
            self.fake_pool = ImagePool(opt.pool_size)
            ###################################
            # define loss functions and loss buffers
            ###################################
            self.loss_functions = []
            if opt.which_gan in {'dcgan', 'lsgan'}:
                self.crit_GAN = networks.GANLoss(
                    use_lsgan=opt.which_gan == 'lsgan', tensor=self.Tensor)
                # bugfix: append only when a criterion exists; previously
                # None was appended to loss_functions in the WGAN branch
                self.loss_functions.append(self.crit_GAN)
            else:
                # WGAN loss will be calculated in self.backward_D_wgangp and self.backward_G
                self.crit_GAN = None

            self.crit_L1 = nn.L1Loss()
            self.loss_functions.append(self.crit_L1)

            if self.opt.loss_weight_vgg > 0:
                self.crit_vgg = networks.VGGLoss(self.gpu_ids)
                self.loss_functions.append(self.crit_vgg)

            self.crit_psnr = networks.SmoothLoss(networks.PSNR())
            self.loss_functions.append(self.crit_psnr)
            ###################################
            # create optimizers
            ###################################
            self.schedulers = []
            self.optimizers = []

            # optim_G will optimize parameters of netG and all image encoders (except attr_encoder)
            G_param_groups = [{'params': self.netG.parameters()}]
            for l, net in self.encoders.items():
                G_param_groups.append({'params': net.parameters()})
            self.optim_G = torch.optim.Adam(G_param_groups,
                                            lr=opt.lr,
                                            betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optim_G)
            # optim_D will optimize parameters of netD
            self.optim_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr_D,
                                            betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optim_D)
            # optim_D_aux will optimize parameters of auxiliaryDs
            if len(self.auxiliaryDs) > 0:
                aux_D_param_groups = [{
                    'params': net.parameters()
                } for net in self.auxiliaryDs.values()]
                self.optim_D_aux = torch.optim.Adam(aux_D_param_groups,
                                                    lr=opt.lr_D,
                                                    betas=(opt.beta1,
                                                           opt.beta2))
                self.optimizers.append(self.optim_D_aux)
            for optim in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optim, opt))

        # color transformation from std to imagenet
        # img_imagenet = img_std * a + b
        self.trans_std_to_imagenet = {
            'a':
            Variable(self.Tensor([0.5 / 0.229, 0.5 / 0.224, 0.5 / 0.225]),
                     requires_grad=False).view(3, 1, 1),
            'b':
            Variable(self.Tensor([(0.5 - 0.485) / 0.229, (0.5 - 0.456) / 0.224,
                                  (0.5 - 0.406) / 0.225]),
                     requires_grad=False).view(3, 1, 1)
        }
    def initialize(self, opt):
        """Set up the pose transfer model.

        Builds the generator (unet or dual_unet), an optional external
        pixel warper, an optional discriminator, and an optical flow
        network; then creates losses/optimizers and loads checkpoints
        according to the train/test/resume mode.

        Raises:
            NotImplementedError: if opt.which_model_G is not a supported
                generator type.
        """
        super(PoseTransferModel, self).initialize(opt)
        ###################################
        # define generator
        ###################################
        if opt.which_model_G == 'unet':
            self.netG = networks.UnetGenerator(
                input_nc=self.get_tensor_dim('+'.join(
                    [opt.G_appearance_type, opt.G_pose_type])),
                output_nc=3,
                nf=opt.G_nf,
                max_nf=opt.G_max_nf,
                num_scales=opt.G_n_scale,
                n_residual_blocks=2,
                norm=opt.G_norm,
                activation=nn.LeakyReLU(0.1)
                if opt.G_activation == 'leaky_relu' else nn.ReLU(),
                use_dropout=opt.use_dropout,
                gpu_ids=opt.gpu_ids)
        elif opt.which_model_G == 'dual_unet':
            self.netG = networks.DualUnetGenerator(
                pose_nc=self.get_tensor_dim(opt.G_pose_type),
                appearance_nc=self.get_tensor_dim(opt.G_appearance_type),
                output_nc=3,
                aux_output_nc=[],
                nf=opt.G_nf,
                max_nf=opt.G_max_nf,
                num_scales=opt.G_n_scale,
                num_warp_scales=opt.G_n_warp_scale,
                n_residual_blocks=2,
                norm=opt.G_norm,
                vis_mode=opt.G_vis_mode,
                activation=nn.LeakyReLU(0.1)
                if opt.G_activation == 'leaky_relu' else nn.ReLU(),
                use_dropout=opt.use_dropout,
                no_end_norm=opt.G_no_end_norm,
                gpu_ids=opt.gpu_ids,
            )
        else:
            # bugfix: without this branch an unsupported which_model_G left
            # self.netG undefined and failed later at self.netG.cuda()
            raise NotImplementedError()
        if opt.gpu_ids:
            self.netG.cuda()
        networks.init_weights(self.netG, init_type=opt.init_type)
        ###################################
        # define external pixel warper
        ###################################
        if opt.G_pix_warp:
            pix_warp_n_scale = opt.G_n_scale
            self.netPW = networks.UnetGenerator_MultiOutput(
                input_nc=self.get_tensor_dim(opt.G_pix_warp_input_type),
                output_nc=[1],  # only use one output branch (weight mask)
                nf=32,
                max_nf=128,
                num_scales=pix_warp_n_scale,
                n_residual_blocks=2,
                norm=opt.G_norm,
                activation=nn.ReLU(False),
                use_dropout=False,
                gpu_ids=opt.gpu_ids)
            if opt.gpu_ids:
                self.netPW.cuda()
            networks.init_weights(self.netPW, init_type=opt.init_type)
        ###################################
        # define discriminator
        ###################################
        # GAN loss only in training mode with a positive weight
        self.use_gan = self.is_train and self.opt.loss_weight_gan > 0
        if self.use_gan:
            self.netD = networks.NLayerDiscriminator(
                input_nc=self.get_tensor_dim(opt.D_input_type_real),
                ndf=opt.D_nf,
                n_layers=opt.D_n_layers,
                use_sigmoid=(opt.gan_type == 'dcgan'),
                output_bias=True,
                gpu_ids=opt.gpu_ids,
            )
            if opt.gpu_ids:
                self.netD.cuda()
            networks.init_weights(self.netD, init_type=opt.init_type)
        ###################################
        # load optical flow model
        ###################################
        if opt.flow_on_the_fly:
            self.netF = load_flow_network(opt.pretrained_flow_id,
                                          opt.pretrained_flow_epoch,
                                          opt.gpu_ids)
            self.netF.eval()  # flow net is frozen; inference only
            if opt.gpu_ids:
                self.netF.cuda()
        ###################################
        # loss and optimizers
        ###################################
        # PSNR/SSIM are evaluation metrics, available in all modes
        self.crit_psnr = networks.PSNR()
        self.crit_ssim = networks.SSIM()

        if self.is_train:
            self.crit_vgg = networks.VGGLoss(
                opt.gpu_ids,
                shifted_style=opt.shifted_style_loss,
                content_weights=opt.vgg_content_weights)
            if opt.G_pix_warp:
                # only optimize netPW (netG is treated as pretrained/frozen)
                self.optim = torch.optim.Adam(self.netPW.parameters(),
                                              lr=opt.lr,
                                              betas=(opt.beta1, opt.beta2),
                                              weight_decay=opt.weight_decay)
            else:
                self.optim = torch.optim.Adam(self.netG.parameters(),
                                              lr=opt.lr,
                                              betas=(opt.beta1, opt.beta2),
                                              weight_decay=opt.weight_decay)
            self.optimizers = [self.optim]
            if self.use_gan:
                self.crit_gan = networks.GANLoss(
                    use_lsgan=(opt.gan_type == 'lsgan'))
                if self.gpu_ids:
                    self.crit_gan.cuda()
                self.optim_D = torch.optim.Adam(
                    self.netD.parameters(),
                    lr=opt.lr_D,
                    betas=(opt.beta1, opt.beta2),
                    weight_decay=opt.weight_decay_D)
                self.optimizers += [self.optim_D]

        ###################################
        # load trained model
        ###################################
        if not self.is_train:
            # load trained model for testing
            self.load_network(self.netG, 'netG', opt.which_epoch)
            if opt.G_pix_warp:
                self.load_network(self.netPW, 'netPW', opt.which_epoch)
        elif opt.pretrained_G_id is not None:
            # load pretrained network
            self.load_network(self.netG, 'netG', opt.pretrained_G_epoch,
                              opt.pretrained_G_id)
        elif opt.resume_train:
            # resume training: restore networks and optimizer state
            self.load_network(self.netG, 'netG', opt.which_epoch)
            self.load_optim(self.optim, 'optim', opt.which_epoch)
            if self.use_gan:
                self.load_network(self.netD, 'netD', opt.which_epoch)
                self.load_optim(self.optim_D, 'optim_D', opt.which_epoch)
            if opt.G_pix_warp:
                self.load_network(self.netPW, 'netPW', opt.which_epoch)
        ###################################
        # schedulers
        ###################################
        if self.is_train:
            self.schedulers = []
            for optim in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optim, opt))
    def initialize(self, opt):
        """Build the two-stage pose-transfer model.

        Constructs, in order:
          1. the pretrained stage-1 (coarse) network via ``_create_stage_1_net``;
          2. the stage-2 (refine) encoder ``netT_s2e`` (patch-embed / patch /
             seg-embed variant, chosen by ``opt.which_model_s2e``) and decoder
             ``netT_s2d`` (resnet / unet / rpresnet, chosen by
             ``opt.which_model_s2d``);
          3. an optional conditional discriminator ``netD`` when training with
             a GAN loss (``opt.loss_weight_gan > 0``);
          4. loss criteria, optimizers, checkpoint init/loading, and LR
             schedulers (training only).

        Args:
            opt: parsed options namespace; fields referenced here include
                which_model_s2e/s2d, s2e_*/s2d_* network sizes, GAN and
                optimizer hyperparameters, and checkpoint flags
                (continue_train, which_epoch, train_s1).
        """
        super(TwoStagePoseTransferModel, self).initialize(opt)
        ###################################
        # load pretrained stage-1 (coarse) network
        ###################################
        self._create_stage_1_net(opt)
        ###################################
        # define stage-2 (refine) network
        ###################################
        # local patch encoder
        if opt.which_model_s2e == 'patch_embed':
            # Encodes local image patches (one per keypoint index) into
            # s2e_nof-channel feature maps.
            self.netT_s2e = networks.LocalPatchEncoder(
                n_patch=len(opt.patch_indices),
                input_nc=3,
                output_nc=opt.s2e_nof,
                nf=opt.s2e_nf,
                max_nf=opt.s2e_max_nf,
                input_size=opt.patch_size,
                bottleneck_factor=opt.s2e_bottleneck_factor,
                n_residual_blocks=2,
                norm_layer=networks.get_norm_layer(opt.norm),
                # NOTE(review): here an nn.ReLU *instance* is passed, while the
                # 'seg_embed' branch and the decoders below pass the nn.ReLU
                # *class* — confirm networks.* accepts both forms.
                activation=nn.ReLU(False),
                use_dropout=False,
                gpu_ids=opt.gpu_ids,
            )
            # Channel count the encoder contributes to the decoder input.
            s2e_nof = opt.s2e_nof
        elif opt.which_model_s2e == 'patch':
            # Directly pastes raw RGB patches back onto the image canvas,
            # so the encoder output keeps 3 channels.
            self.netT_s2e = networks.LocalPatchRearranger(
                n_patch=len(opt.patch_indices),
                image_size=opt.fine_size,
            )
            s2e_nof = 3
        elif opt.which_model_s2e == 'seg_embed':
            self.netT_s2e = networks.SegmentRegionEncoder(
                # NOTE(review): reads self.opt.seg_nc while the rest of this
                # method uses the local `opt` — presumably the same object
                # after super().initialize(opt); verify.
                seg_nc=self.opt.seg_nc,
                input_nc=3,
                output_nc=opt.s2e_nof,
                # NOTE(review): suspected typo — this is the stage-2 *encoder*
                # but it is sized with the decoder width opt.s2d_nf (the
                # 'patch_embed' branch uses opt.s2e_nf). Confirm before
                # changing: existing checkpoints depend on this width.
                nf=opt.s2d_nf,
                input_size=opt.fine_size,
                n_blocks=3,
                norm_layer=networks.get_norm_layer(opt.norm),
                activation=nn.ReLU,
                use_dropout=False,
                grid_level=opt.s2e_grid_level,
                gpu_ids=opt.gpu_ids,
            )
            # Region embedding is concatenated with a grid_level-channel
            # coordinate encoding, hence the extra channels.
            s2e_nof = opt.s2e_nof + opt.s2e_grid_level
        else:
            raise NotImplementedError()
        if opt.gpu_ids:
            self.netT_s2e.cuda()

        # decoder
        # All decoder variants take the coarse RGB output (3ch) concatenated
        # with the encoder features (s2e_nof ch) and emit a refined RGB image.
        # output_tanh=False: the raw (un-squashed) output is used downstream.
        if self.opt.which_model_s2d == 'resnet':
            self.netT_s2d = networks.ResnetGenerator(
                input_nc=3 + s2e_nof,
                output_nc=3,
                ngf=opt.s2d_nf,
                norm_layer=networks.get_norm_layer(opt.norm),
                activation=nn.ReLU,
                use_dropout=False,
                n_blocks=opt.s2d_nblocks,
                gpu_ids=opt.gpu_ids,
                output_tanh=False,
            )
        elif self.opt.which_model_s2d == 'unet':
            self.netT_s2d = networks.UnetGenerator_v2(
                input_nc=3 + s2e_nof,
                output_nc=3,
                num_downs=8,
                ngf=opt.s2d_nf,
                # Cap channel growth at 8x the base width.
                max_nf=opt.s2d_nf * 2**3,
                norm_layer=networks.get_norm_layer(opt.norm),
                use_dropout=False,
                gpu_ids=opt.gpu_ids,
                output_tanh=False,
            )
        elif self.opt.which_model_s2d == 'rpresnet':
            self.netT_s2d = networks.RegionPropagationResnetGenerator(
                input_nc=3 + s2e_nof,
                output_nc=3,
                ngf=opt.s2d_nf,
                norm_layer=networks.get_norm_layer(opt.norm),
                activation=nn.ReLU,
                use_dropout=False,
                nblocks=opt.s2d_nblocks,
                gpu_ids=opt.gpu_ids,
                output_tanh=False)
        else:
            raise NotImplementedError()
        if opt.gpu_ids:
            self.netT_s2d.cuda()
        ###################################
        # define discriminator
        ###################################
        # GAN loss (and hence netD) is only needed during training and only
        # when its loss weight is positive.
        self.use_GAN = self.is_train and opt.loss_weight_gan > 0
        if self.use_GAN:
            self.netD = networks.define_D_from_params(
                # Conditional D sees the image concatenated with the pose
                # representation; unconditional D sees the 3-channel image.
                # (The ternary binds over the whole sum: (3 + pose_dim) or 3.)
                input_nc=3 +
                self.get_pose_dim(opt.pose_type) if opt.D_cond else 3,
                ndf=opt.D_nf,
                which_model_netD='n_layers',
                n_layers_D=opt.D_n_layer,
                norm=opt.norm,
                which_gan=opt.which_gan,
                init_type=opt.init_type,
                gpu_ids=opt.gpu_ids)
        else:
            self.netD = None
        ###################################
        # loss functions
        ###################################
        # Image-quality metrics are needed in both train and test mode.
        self.crit_psnr = networks.PSNR()
        self.crit_ssim = networks.SSIM()

        if self.is_train:
            self.optimizers = []
            # Perceptual (VGG content/style) loss.
            self.crit_vgg = networks.VGGLoss_v2(self.gpu_ids,
                                                opt.content_layer_weight,
                                                opt.style_layer_weight,
                                                opt.shifted_style)

            # One optimizer jointly updates the stage-2 encoder and decoder.
            self.optim = torch.optim.Adam([{
                'params': self.netT_s2e.parameters()
            }, {
                'params': self.netT_s2d.parameters()
            }],
                                          lr=opt.lr,
                                          betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optim)

            # Optionally fine-tune the stage-1 network with its own LR.
            if opt.train_s1:
                self.optim_s1 = torch.optim.Adam(self.netT_s1.parameters(),
                                                 lr=opt.lr_s1,
                                                 betas=(opt.beta1, opt.beta2))
                self.optimizers.append(self.optim_s1)

            if self.use_GAN:
                self.crit_GAN = networks.GANLoss(
                    use_lsgan=opt.which_gan == 'lsgan', tensor=self.Tensor)
                self.optim_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr_D,
                                                betas=(opt.beta1, opt.beta2))
                self.optimizers.append(self.optim_D)
                # Buffer of past fake samples for discriminator updates.
                self.fake_pool = ImagePool(opt.pool_size)
        ###################################
        # init/load model
        ###################################
        if self.is_train:
            if not opt.continue_train:
                # Fresh run: load the pretrained stage-1 weights (saved under
                # the name 'netT' in the stage-1 experiment self.opt_s1.id)
                # and randomly initialize the stage-2 nets and D.
                self.load_network(self.netT_s1, 'netT', 'latest',
                                  self.opt_s1.id)
                networks.init_weights(self.netT_s2e, init_type=opt.init_type)
                networks.init_weights(self.netT_s2d, init_type=opt.init_type)
                if self.use_GAN:
                    networks.init_weights(self.netD, init_type=opt.init_type)
            else:
                # Resume: all nets (including stage-1, now saved as 'netT_s1')
                # and optimizer states come from this experiment's checkpoint.
                self.load_network(self.netT_s1, 'netT_s1', opt.which_epoch)
                self.load_network(self.netT_s2e, 'netT_s2e', opt.which_epoch)
                self.load_network(self.netT_s2d, 'netT_s2d', opt.which_epoch)
                self.load_optim(self.optim, 'optim', opt.which_epoch)
                if self.use_GAN:
                    self.load_network(self.netD, 'netD', opt.which_epoch)
                    self.load_optim(self.optim_D, 'optim_D', opt.which_epoch)
        else:
            # Test mode: load generator weights only.
            self.load_network(self.netT_s1, 'netT_s1', opt.which_epoch)
            self.load_network(self.netT_s2e, 'netT_s2e', opt.which_epoch)
            self.load_network(self.netT_s2d, 'netT_s2d', opt.which_epoch)
        ###################################
        # schedulers
        ###################################
        if self.is_train:
            # One LR scheduler per optimizer created above.
            self.schedulers = []
            for optim in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optim, opt))