Example #1
    def init_models(self):
        # Networks
        self.G_AB = networks.define_G(
            input_nc=self.config.input_nc,
            output_nc=self.config.output_nc,
            ngf=self.config.g_conv_dim,
            which_model_netG=self.config.which_model_netG,
            norm='batch',
            init_type='normal',
            gpu_ids=self.gpu_ids)
        self.D_B = networks.define_D(
            input_nc=self.config.input_nc,
            ndf=self.config.d_conv_dim,
            which_model_netD=self.config.which_model_netD,
            n_layers_D=3,
            norm='instance',
            use_sigmoid=True,
            init_type='normal',
            gpu_ids=self.gpu_ids,
            image_size=self.config.image_size)

        # Optimisers
        self.G_optim = optim.Adam(self.G_AB.parameters(),
                                  lr=self.config.lr,
                                  betas=(self.config.beta1, self.config.beta2))
        self.D_optim = optim.Adam(self.D_B.parameters(),
                                  lr=self.config.lr,
                                  betas=(self.config.beta1, self.config.beta2))
        self.optimizers = [self.G_optim, self.D_optim]

        # Schedulers
        self.schedulers = []
        for optimizer in self.optimizers:
            self.schedulers.append(
                networks.get_scheduler(optimizer, self.config))
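All of the snippets on this page delegate learning-rate scheduling to a get_scheduler(optimizer, opt) helper. For reference, a minimal sketch of what that helper typically looks like, modeled on the pytorch-CycleGAN-and-pix2pix version; the option names (lr_policy, epoch_count, niter, niter_decay, lr_decay_iters) are assumptions inferred from flags used elsewhere on this page, not taken from any single project.

# Sketch only: option names are assumptions, not any one project's API.
from torch.optim import lr_scheduler

def get_scheduler(optimizer, opt):
    if opt.lr_policy == 'lambda':
        # constant lr for opt.niter epochs, then linear decay to zero
        # over the following opt.niter_decay epochs
        def lambda_rule(epoch):
            return 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
        return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        return lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        return lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                              factor=0.2, threshold=0.01, patience=5)
    raise NotImplementedError('lr policy [%s] is not implemented' % opt.lr_policy)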
Example #2
    def init_models(self):
        """ Models: G_UM, G_MU, D_M, D_U """
        # Networks
        self.G_UM = networks.define_G(input_nc=1, output_nc=1, ngf=self.config.g_conv_dim,
                                      which_model_netG=self.config.which_model_netG, norm='batch', init_type='normal',
                                      gpu_ids=self.gpu_ids)
        self.G_MU = networks.define_G(input_nc=1, output_nc=1, ngf=self.config.g_conv_dim,
                                      which_model_netG=self.config.which_model_netG, norm='batch', init_type='normal',
                                      gpu_ids=self.gpu_ids)
        self.D_M = networks.define_D(input_nc=1, ndf=self.config.d_conv_dim,
                                     which_model_netD=self.config.which_model_netD,
                                     n_layers_D=3, norm='instance', use_sigmoid=True, init_type='normal',
                                     gpu_ids=self.gpu_ids)
        self.D_U = networks.define_D(input_nc=1, ndf=self.config.d_conv_dim,
                                     which_model_netD=self.config.which_model_netD,
                                     n_layers_D=3, norm='instance', use_sigmoid=True, init_type='normal',
                                     gpu_ids=self.gpu_ids)

        # Optimisers
        # single optimiser for both generators
        self.G_optim = optim.Adam(itertools.chain(self.G_UM.parameters(), self.G_MU.parameters()),
                                  lr=self.config.lr, betas=(self.config.beta1, self.config.beta2))
        self.D_M_optim = optim.Adam(self.D_M.parameters(),
                                    lr=self.config.lr, betas=(self.config.beta1, self.config.beta2))
        self.D_U_optim = optim.Adam(self.D_U.parameters(),
                                    lr=self.config.lr, betas=(self.config.beta1, self.config.beta2))
        self.optimizers = [self.G_optim, self.D_M_optim, self.D_U_optim]

        # Schedulers
        self.schedulers = []
        for optimizer in self.optimizers:
            self.schedulers.append(networks.get_scheduler(optimizer, self.config))
Example #3
 def set_scheduler(self, opts, last_ep=0):
   self.disA_sch = networks.get_scheduler(self.disA_opt, opts, last_ep)
   self.disB_sch = networks.get_scheduler(self.disB_opt, opts, last_ep)
   self.disA2_sch = networks.get_scheduler(self.disA2_opt, opts, last_ep)
   self.disB2_sch = networks.get_scheduler(self.disB2_opt, opts, last_ep)
   self.disContent_sch = networks.get_scheduler(self.disContent_opt, opts, last_ep)
   self.enc_c_sch = networks.get_scheduler(self.enc_c_opt, opts, last_ep)
   self.enc_a_sch = networks.get_scheduler(self.enc_a_opt, opts, last_ep)
   self.gen_sch = networks.get_scheduler(self.gen_opt, opts, last_ep)
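Unlike the earlier examples, this variant threads a last_ep argument through so that schedulers can resume mid-decay after reloading a checkpoint. A minimal sketch of a resume-aware get_scheduler under that assumption; the n_ep/n_ep_decay option names are placeholders, not from the source.

# Sketch: last_ep is forwarded as PyTorch's last_epoch so a reloaded
# optimizer resumes mid-schedule. n_ep and n_ep_decay are placeholder
# option names. Resuming with last_ep > 0 requires the optimizer's
# param groups to carry an 'initial_lr' entry.
from torch.optim import lr_scheduler

def get_scheduler(optimizer, opts, last_ep=0):
    def lambda_rule(ep):
        return 1.0 - max(0, ep - opts.n_ep_decay) / float(opts.n_ep - opts.n_ep_decay + 1)
    return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule,
                                 last_epoch=last_ep - 1)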
Example #4
    def setup(self, opt, parser=None):
        if self.isTrain:
            self.schedulers = [
                networks.get_scheduler(optimizer, opt)
                for optimizer in self.optimizers
            ]

        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.which_epoch)
        self.print_networks(opt.verbose)
Example #5
File: model.py Project: cvvsu/RMEP
    def setup(self):
        r"""Load and print networks; create schedulers

        The experiment flags are read from self.opt (an Option instance;
        needs to be a subclass of BaseOptions).
        """
        self.scheduler = networks.get_scheduler(self.optimizer, self.opt)
        if not self.isTrain or self.opt.continue_train:
            self.load_networks(self.opt.epoch)
        self.print_networks(self.opt.verbose)
Example #6
    def initialize(self, opt):
        super(PoseParsingModel, self).initialize(opt)
        ###################################
        # create model
        ###################################
        if opt.which_model_PP == 'resnet':
            self.netPP = networks.ResnetGenerator(
                input_nc = self.get_data_dim(opt.pp_input_type),
                output_nc = self.get_data_dim(opt.pp_pose_type),
                ngf = opt.pp_nf,
                norm_layer = networks.get_norm_layer(opt.norm),
                activation = nn.ReLU,
                use_dropout = False,
                n_blocks = opt.pp_nblocks,
                gpu_ids = opt.gpu_ids,
                output_tanh = False,
                )
        elif opt.which_model_PP == 'unet':
            self.netPP = networks.UnetGenerator_v2(
                input_nc = self.get_data_dim(opt.pp_input_type),
                output_nc = self.get_data_dim(opt.pp_pose_type),
                num_downs = 8,
                ngf = opt.pp_nf,
                max_nf = opt.pp_nf*(2**3),
                norm_layer = networks.get_norm_layer(opt.norm),
                use_dropout = False,
                gpu_ids = opt.gpu_ids,
                output_tanh = False,
                )
        else:
            raise NotImplementedError()

        if opt.gpu_ids:
            self.netPP.cuda()
        ###################################
        # init/load model
        ###################################
        if self.is_train and (not opt.continue_train):
            networks.init_weights(self.netPP, init_type=opt.init_type)
        else:
            self.load_network(self.netPP, 'netPP', opt.which_epoch)
        ###################################
        # optimizers and schedulers
        ###################################
        if self.is_train:
            self.optim = torch.optim.Adam(self.netPP.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizers = [self.optim]

            self.schedulers = []
            for optim in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optim, opt))
Example #7
 def setup(self, opt):
     """Load and print networks; create schedulers
     Parameters:
         opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
     """
     if self.isTrain:
         self.schedulers = [
             networks.get_scheduler(optimizer, opt)
             for optimizer in self.optimizers
         ]
     if not self.isTrain or opt.continue_train:
         load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
         self.load_networks(load_suffix)
     self.print_networks(opt.verbose)
Example #8
    def initialize(self, opt):
        super(FeatureSpatialTransformer, self).initialize(opt)
        ###################################
        # load/define networks
        ###################################
        self.net = networks.define_feat_spatial_transformer(opt)
        self.netAE = None

        if opt.continue_train or not self.is_train:
            self.load_network(self.net, 'FeatST', epoch_label = opt.which_epoch)

        if self.is_train:
            ###################################
            # load attribute encoder
            ###################################
            self.netAE, self.opt_AE = load_attribute_encoder_net(id=opt.which_model_AE, gpu_ids=opt.gpu_ids)

            ###################################
            # define loss functions and loss buffers
            ###################################
            self.crit_L1 = networks.SmoothLoss(nn.L1Loss())
            self.crit_attr = networks.SmoothLoss(nn.BCELoss())

            self.loss_functions = []
            self.loss_functions.append(self.crit_L1)
            self.loss_functions.append(self.crit_attr)

            ###################################
            # create optimizers
            ###################################
            self.schedulers = []
            self.optimizers = []

            self.optim = torch.optim.Adam(self.net.parameters(), 
                lr = opt.lr, betas = (opt.beta1, opt.beta2), weight_decay = opt.weight_decay)

            self.optimizers.append(self.optim)

            for optim in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optim, opt))

        # color transformation from std to imagenet
        # img_imagenet = img_std * a + b
        self.trans_std_to_imagenet = {
            'a': Variable(self.Tensor([0.5/0.229, 0.5/0.224, 0.5/0.225]), requires_grad = False).view(3,1,1),
            'b': Variable(self.Tensor([(0.5-0.485)/0.229, (0.5-0.456)/0.224, (0.5-0.406)/0.225]), requires_grad = False).view(3,1,1)
        }
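The a/b constants above fold two normalizations into a single affine map: a tensor normalized with mean 0.5 / std 0.5 is remapped to ImageNet statistics, since ((x * 0.5 + 0.5) - mean) / std = x * (0.5 / std) + (0.5 - mean) / std. A quick stand-alone check of those constants (only torch is assumed):

# Verify a = 0.5/std and b = (0.5-mean)/std reproduce the explicit
# two-step renormalization used above.
import torch

mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
a, b = 0.5 / std, (0.5 - mean) / std

img_std = torch.rand(3, 8, 8) * 2 - 1  # as if normalized with mean=std=0.5
img_imagenet = ((img_std * 0.5 + 0.5) - mean) / std
assert torch.allclose(img_std * a + b, img_imagenet, atol=1e-6)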
Example #9
    def setup(self, task_index):
        """Load and print networks; create schedulers

        Parameters:
            task_index -- index of the current task; the experiment flags are read from self.opt
        """
        self.task_index = task_index
        logging.info("BaseModel set up")
        if self.isTrain:
            self.schedulers = [
                get_scheduler(optimizer, self.opt)
                for optimizer in self.optimizers
            ]

        if not self.isTrain or self.opt.continue_train:
            self.load_networks(self.opt.load_taskindex, self.opt.load_step,
                               self.opt.load_epoch)

        if self.opt.amp:
            self.scaler = GradScaler()
        self.print_networks()
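For context, a generic sketch of how a GradScaler like the one created in setup() is used in a mixed-precision training step; net, criterion, and optimizer here are placeholders rather than names from the example above.

# Generic AMP step: placeholders only, not this project's training loop.
import torch
from torch.cuda.amp import GradScaler, autocast

def amp_step(net, criterion, optimizer, scaler, x, target):
    optimizer.zero_grad()
    with autocast():                 # run the forward pass in mixed precision
        loss = criterion(net(x), target)
    scaler.scale(loss).backward()    # scale loss to avoid fp16 underflow
    scaler.step(optimizer)           # unscales gradients, then steps
    scaler.update()                  # adapt the scale factor for next step
    return loss.item()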
Example #10
    def initialize(self, opt):
        super(EncoderDecoderFramework_V2, self).initialize(opt)
        ###################################
        # define encoder
        ###################################
        self.encoder = networks.define_encoder_v2(opt)
        ###################################
        # define decoder
        ###################################
        self.decoder = networks.define_decoder_v2(opt)
        ###################################
        # guide encoder
        ###################################
        if opt.use_guide_encoder:
            self.guide_encoder = networks.load_encoder_v2(
                opt, opt.which_model_guide)
            self.guide_encoder.eval()
            for p in self.guide_encoder.parameters():
                p.requires_grad = False
        ###################################
        # loss functions
        ###################################
        self.loss_functions = []
        self.schedulers = []
        self.crit_image = networks.SmoothLoss(nn.L1Loss())
        self.crit_seg = networks.SmoothLoss(nn.CrossEntropyLoss())
        self.crit_edge = networks.SmoothLoss(nn.BCELoss())
        self.loss_functions += [self.crit_image, self.crit_seg, self.crit_edge]

        self.optim = torch.optim.Adam(
            [{'params': self.encoder.parameters()},
             {'params': self.decoder.parameters()}],
            lr=opt.lr,
            betas=(opt.beta1, opt.beta2))
        self.optimizers = [self.optim]
        for optim in self.optimizers:
            self.schedulers.append(networks.get_scheduler(optim, opt))
Example #11
# note: the opening of this call was truncated in the source; the leading
# arguments below follow the common pix2pix-pytorch define_G signature and
# are an assumption
net_g = define_G(opt.input_nc, opt.output_nc, opt.ngf, 'batch', False,
                 'normal', 0.02, gpu_id=device)
net_d = define_D(opt.input_nc + opt.output_nc, opt.ndf, 'basic', gpu_id=device)

criterionGAN = GANLoss().to(device)
criterionL1 = nn.L1Loss().to(device)
criterionMSE = nn.MSELoss().to(device)

# setup optimizer
optimizer_g = optim.Adam(net_g.parameters(),
                         lr=opt.lr,
                         betas=(opt.beta1, 0.999))
optimizer_d = optim.Adam(net_d.parameters(),
                         lr=opt.lr,
                         betas=(opt.beta1, 0.999))
net_g_scheduler = get_scheduler(optimizer_g, opt)
net_d_scheduler = get_scheduler(optimizer_d, opt)
face_descriptor = FaceDescriptor().to(device)
face_landmarks = FaceLandmarks().to(device)

for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    # train
    for iteration, (X, _) in enumerate(training_data_loader, 1):
        # forward
        X = X.to(device)
        FL = face_landmarks(X)  # 68 landmark heatmaps of size 64x64
        FL_ = FL.view(FL.size(0), 68, 64 * 64)
        FL_max, _ = FL_.max(dim=2)  # per-heatmap peak value
        FL_max = FL_max.view(FL_max.size(0), FL_max.size(1), 1)
        # keep only the peak pixel of each heatmap, then merge the 68
        # channels into a single binary landmark mask
        FL = (FL_ >= FL_max).float().sum(dim=1).view(FL.size(0), 1, 64, 64)
        FL[FL > 0] = 1  # binarize overlapping peaks
Example #12
    def initialize(self, opt):
        super(EncoderDecoderFramework, self).initialize(opt)
        ###################################
        # load/define networks
        ###################################
        if opt.use_shape:
            self.encoder_type = 'shape'
            self.encoder_name = 'shape_encoder'
            self.decoder_name = 'decoder'
        elif opt.use_edge:
            self.encoder_type = 'edge'
            self.encoder_name = 'edge_encoder'
            self.decoder_name = 'decoder'
        elif opt.use_color:
            self.encoder_type = 'color'
            self.encoder_name = 'color_encoder'
            self.decoder_name = 'decoder'
        else:
            raise ValueError(
                'one of use_shape, use_edge or use_color must be set')

        # encoder
        self.encoder = networks.define_image_encoder(opt, self.encoder_type)

        # decoder
        if self.encoder_type == 'shape':
            ndowns = opt.shape_ndowns
            nf = opt.shape_nf
            nof = opt.shape_nof
            output_nc = 7
            output_activation = None
            assert opt.decode_guided == False
        elif self.encoder_type == 'edge':
            ndowns = opt.edge_ndowns
            nf = opt.edge_nf
            nof = opt.edge_nof
            output_nc = 1
            output_activation = None
        elif self.encoder_type == 'color':
            ndowns = opt.color_ndowns
            nf = opt.color_nf
            nof = opt.color_nof
            output_nc = 3
            output_activation = nn.Tanh

        if opt.encoder_type in {'normal', 'st'}:
            self.feat_size = 256 // 2**(opt.edge_ndowns)
            self.mid_feat_size = self.feat_size
        else:
            self.feat_size = 1
            self.mid_feat_size = 8

        self.use_concat_net = False
        if opt.decode_guided:
            if self.feat_size > 1:
                self.decoder = networks.define_image_decoder_from_params(
                    input_nc=nof + opt.shape_nc,
                    output_nc=output_nc,
                    nf=nf,
                    num_ups=ndowns,
                    norm=opt.norm,
                    output_activation=output_activation,
                    gpu_ids=opt.gpu_ids,
                    init_type=opt.init_type)
            else:
                self.decoder = networks.define_image_decoder_from_params(
                    input_nc=nof,
                    output_nc=output_nc,
                    nf=nf,
                    num_ups=5,
                    norm=opt.norm,
                    output_activation=output_activation,
                    gpu_ids=opt.gpu_ids,
                    init_type=opt.init_type)
                self.concat_net = networks.FeatureConcatNetwork(
                    feat_nc=nof,
                    guide_nc=opt.shape_nc,
                    nblocks=3,
                    norm=opt.norm,
                    gpu_ids=opt.gpu_ids)
                if len(self.gpu_ids) > 0:
                    self.concat_net.cuda()
                networks.init_weights(self.concat_net, opt.init_type)
                self.use_concat_net = True
                print('encoder_decoder contains a feature_concat_network!')
        else:
            if self.feat_size > 1:
                self.decoder = networks.define_image_decoder_from_params(
                    input_nc=nof,
                    output_nc=output_nc,
                    nf=nf,
                    num_ups=ndowns,
                    norm=opt.norm,
                    output_activation=output_activation,
                    gpu_ids=opt.gpu_ids,
                    init_type=opt.init_type)
            else:
                self.decoder = networks.define_image_decoder_from_params(
                    input_nc=nof,
                    output_nc=output_nc,
                    nf=nf,
                    num_ups=8,
                    norm=opt.norm,
                    output_activation=output_activation,
                    gpu_ids=opt.gpu_ids,
                    init_type=opt.init_type)

        if not self.is_train or self.opt.continue_train:
            self.load_network(self.encoder, self.encoder_name, opt.which_epoch)
            self.load_network(self.decoder, self.decoder_name, opt.which_epoch)
            if self.use_concat_net:
                self.load_network(self.concat_net, 'concat_net',
                                  opt.which_epoch)

        # loss functions
        self.loss_functions = []
        self.schedulers = []
        self.crit_L1 = networks.SmoothLoss(nn.L1Loss())
        self.crit_CE = networks.SmoothLoss(nn.CrossEntropyLoss())
        self.loss_functions += [self.crit_L1, self.crit_CE]

        self.optim = torch.optim.Adam(
            [{'params': self.encoder.parameters()},
             {'params': self.decoder.parameters()}],
            lr=opt.lr,
            betas=(opt.beta1, opt.beta2))
        self.optimizers = [self.optim]
        for optim in self.optimizers:
            self.schedulers.append(networks.get_scheduler(optim, opt))
Example #13
    def initialize(self, opt):
        super(DesignerGAN, self).initialize(opt)
        ###################################
        # define data tensors
        ###################################
        # self.input['img'] = self.Tensor()
        # self.input['img_attr'] = self.Tensor()
        # self.input['lm_map'] = self.Tensor()
        # self.input['seg_mask'] = self.Tensor()
        # self.input['attr_label'] = self.Tensor()
        # self.input['id'] = []

        ###################################
        # load/define networks
        ###################################

        # Todo modify networks.define_G
        # 1. add specified generator networks

        self.netG = networks.define_G(opt)
        self.netAE, self.opt_AE = network_loader.load_attribute_encoder_net(
            id=opt.which_model_AE, gpu_ids=opt.gpu_ids)
        if opt.which_model_FeatST != 'none':
            self.netFeatST, self.opt_FeatST = network_loader.load_feature_spatial_transformer_net(
                id=opt.which_model_FeatST, gpu_ids=opt.gpu_ids)
            self.use_FeatST = True
            # assert self.opt_FeatST.shape_encode == self.opt.shape_encode, 'GAN model and FeatST model has different shape encode mode'
            # assert self.opt_FeatST.input_mask_mode == self.opt.input_mask_mode, 'GAN model and FeatST model has different segmentation input mode'
        else:
            self.use_FeatST = False

        if self.is_train:
            self.netD = networks.define_D(opt)
            if opt.which_model_init_netG != 'none' and not opt.continue_train:
                self.load_network(self.netG, 'G', 'latest',
                                  opt.which_model_init_netG)

        if not self.is_train or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.is_train:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.is_train:
            self.fake_pool = ImagePool(opt.pool_size)

            ###################################
            # define loss functions and loss buffers
            ###################################
            if opt.which_gan in {'dcgan', 'lsgan'}:
                self.crit_GAN = networks.GANLoss(
                    use_lsgan=opt.which_gan == 'lsgan', tensor=self.Tensor)
            else:
                # WGAN loss will be calculated in self.backward_D_wgangp and self.backward_G
                self.crit_GAN = None
            self.crit_L1 = nn.L1Loss()
            self.crit_attr = nn.BCELoss()

            self.loss_functions = []
            self.loss_functions.append(self.crit_GAN)
            self.loss_functions.append(self.crit_L1)
            self.loss_functions.append(self.crit_attr)

            if self.opt.loss_weight_vgg > 0:
                self.crit_vgg = networks.VGGLoss(self.gpu_ids)
                self.loss_functions.append(self.crit_vgg)

            ###################################
            # create optimizers
            ###################################
            self.schedulers = []
            self.optimizers = []

            self.optim_G = torch.optim.Adam(self.netG.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, opt.beta2))
            self.optim_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr,
                                            betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optim_G)
            self.optimizers.append(self.optim_D)

            for optim in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optim, opt))

        # color transformation from std to imagenet
        # img_imagenet = img_std * a + b
        self.trans_std_to_imagenet = {
            'a':
            Variable(self.Tensor([0.5 / 0.229, 0.5 / 0.224, 0.5 / 0.225]),
                     requires_grad=False).view(3, 1, 1),
            'b':
            Variable(self.Tensor([(0.5 - 0.485) / 0.229, (0.5 - 0.456) / 0.224,
                                  (0.5 - 0.406) / 0.225]),
                     requires_grad=False).view(3, 1, 1)
        }
Example #14
    def initialize(self, opt):
        super(VUnetPoseTransferModel, self).initialize(opt)
        ###################################
        # define transformer
        ###################################
        self.netT = networks.VariationalUnet(
            input_nc_dec = self.get_pose_dim(opt.pose_type),
            input_nc_enc = self.get_appearance_dim(opt.appearance_type),
            output_nc = self.get_output_dim(opt.output_type),
            nf = opt.vunet_nf,
            max_nf = opt.vunet_max_nf,
            input_size = opt.fine_size,
            n_latent_scales = opt.vunet_n_latent_scales,
            bottleneck_factor = opt.vunet_bottleneck_factor,
            box_factor = opt.vunet_box_factor,
            n_residual_blocks = 2,
            norm_layer = networks.get_norm_layer(opt.norm),
            activation = nn.ReLU(False),
            use_dropout = False,
            gpu_ids = opt.gpu_ids,
            output_tanh = False,
            )
        if opt.gpu_ids:
            self.netT.cuda()
        networks.init_weights(self.netT, init_type=opt.init_type)
        ###################################
        # define discriminator
        ###################################
        self.use_GAN = self.is_train and opt.loss_weight_gan > 0
        if self.use_GAN:
            self.netD = networks.define_D_from_params(
                input_nc=3+self.get_pose_dim(opt.pose_type) if opt.D_cond else 3,
                ndf=opt.D_nf,
                which_model_netD='n_layers',
                n_layers_D=opt.D_n_layer,
                norm=opt.norm,
                which_gan=opt.which_gan,
                init_type=opt.init_type,
                gpu_ids=opt.gpu_ids)
        else:
            self.netD = None
        ###################################
        # loss functions
        ###################################
        self.crit_psnr = networks.PSNR()
        self.crit_ssim = networks.SSIM()

        if self.is_train:
            self.optimizers = []
            self.crit_vgg = networks.VGGLoss_v2(self.gpu_ids, opt.content_layer_weight, opt.style_layer_weight, opt.shifted_style)
            # self.crit_vgg_old = networks.VGGLoss(self.gpu_ids)
            self.optim = torch.optim.Adam(self.netT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay)
            self.optimizers += [self.optim]

            if self.use_GAN:
                self.crit_GAN = networks.GANLoss(use_lsgan=opt.which_gan=='lsgan', tensor=self.Tensor)
                self.optim_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr_D, betas=(opt.beta1, opt.beta2))
                self.optimizers.append(self.optim_D)
            # todo: add pose loss
            self.fake_pool = ImagePool(opt.pool_size)

        ###################################
        # load trained model
        ###################################
        if not self.is_train:
            self.load_network(self.netT, 'netT', opt.which_epoch)
        elif opt.continue_train:
            self.load_network(self.netT, 'netT', opt.which_epoch)
            self.load_optim(self.optim, 'optim', opt.which_epoch)
            if self.use_GAN:
                self.load_network(self.netD, 'netD', opt.which_epoch)
                self.load_optim(self.optim_D, 'optim_D', opt.which_epoch)
        ###################################
        # schedulers
        ###################################
        if self.is_train:
            self.schedulers = []
            for optim in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optim, opt))
Example #15
    def initialize(self, cfg):
        self.cfg = cfg

        ## set devices
        if cfg['GPU_IDS']:
            assert torch.cuda.is_available()
            self.device = torch.device('cuda:{}'.format(cfg['GPU_IDS'][0]))
            torch.backends.cudnn.benchmark = True
            print('Using %d GPUs' % len(cfg['GPU_IDS']))
        else:
            self.device = torch.device('cpu')

        # define network
        if cfg['ARCHI'] == 'alexnet':
            self.netB = networks.netB_alexnet()
            self.netH = networks.netH_alexnet()
            if self.cfg['USE_DA'] and self.cfg['TRAIN']:
                self.netD = networks.netD_alexnet(self.cfg['DA_LAYER'])
        elif cfg['ARCHI'] == 'vgg16':
            # vgg16 is not supported yet; the lines below are unreachable
            raise NotImplementedError
            self.netB = networks.netB_vgg16()
            self.netH = networks.netH_vgg16()
            if self.cfg['USE_DA'] and self.cfg['TRAIN']:
                self.netD = networks.netD_vgg16(self.cfg['DA_LAYER'])
        elif 'resnet' in cfg['ARCHI']:
            self.netB = networks.netB_resnet34()
            self.netH = networks.netH_resnet34()
            if self.cfg['USE_DA'] and self.cfg['TRAIN']:
                self.netD = networks.netD_resnet(self.cfg['DA_LAYER'])
        else:
            raise ValueError('Un-supported network')

        ## initialize network param.
        self.netB = networks.init_net(self.netB, cfg['GPU_IDS'], 'xavier')
        self.netH = networks.init_net(self.netH, cfg['GPU_IDS'], 'xavier')

        if self.cfg['USE_DA'] and self.cfg['TRAIN']:
            self.netD = networks.init_net(self.netD, cfg['GPU_IDS'], 'xavier')

        # loss, optimizer, and scheduler
        if cfg['TRAIN']:
            self.total_steps = 0
            ## Output path
            self.save_dir = os.path.join(
                cfg['OUTPUT_PATH'], cfg['ARCHI'],
                datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
            if not os.path.isdir(self.save_dir):
                os.makedirs(self.save_dir)
            # self.logger = Logger(self.save_dir)

            ## model names
            self.model_names = ['netB', 'netH']
            ## loss
            self.criterionGAN = networks.GANLoss().to(self.device)
            self.criterionDepth1 = torch.nn.MSELoss().to(self.device)
            self.criterionNorm = torch.nn.CosineEmbeddingLoss().to(self.device)
            self.criterionEdge = torch.nn.BCELoss().to(self.device)

            ## optimizers
            self.lr = cfg['LR']
            self.optimizers = []
            self.optimizer_B = torch.optim.Adam(self.netB.parameters(),
                                                lr=cfg['LR'],
                                                betas=(cfg['BETA1'],
                                                       cfg['BETA2']))
            self.optimizer_H = torch.optim.Adam(self.netH.parameters(),
                                                lr=cfg['LR'],
                                                betas=(cfg['BETA1'],
                                                       cfg['BETA2']))
            self.optimizers.append(self.optimizer_B)
            self.optimizers.append(self.optimizer_H)
            if cfg['USE_DA']:
                self.real_pool = ImagePool(cfg['POOL_SIZE'])
                self.syn_pool = ImagePool(cfg['POOL_SIZE'])
                self.model_names.append('netD')
                ## use SGD for discriminator
                self.optimizer_D = torch.optim.SGD(
                    self.netD.parameters(),
                    lr=cfg['LR'],
                    momentum=cfg['MOMENTUM'],
                    weight_decay=cfg['WEIGHT_DECAY'])
                self.optimizers.append(self.optimizer_D)
            ## LR scheduler
            self.schedulers = [
                networks.get_scheduler(optimizer, cfg)
                for optimizer in self.optimizers
            ]
        else:
            ## testing
            self.model_names = ['netB', 'netH']
            self.criterionDepth1 = torch.nn.MSELoss().to(self.device)
            self.criterionNorm = torch.nn.CosineEmbeddingLoss(
                reduction='none').to(self.device)

        self.load_dir = os.path.join(cfg['CKPT_PATH'])
        self.criterionNorm_eval = torch.nn.CosineEmbeddingLoss(
            reduction='none').to(self.device)

        if cfg['TEST'] or cfg['RESUME']:
            self.load_networks(cfg['EPOCH_LOAD'])
Example #16
    def __init__(self, opt):
        """Initialize the CycleGAN class.

                Parameters:
                    opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
                """

        self.opt = opt
        self.isTrain = opt.isTrain
        self.device = opt.device
        self.model_save_dir = opt.model_dir
        self.loss_names = []
        self.model_names = []
        self.optimizers = []
        self.image_paths = []

        self.epoch = 0
        self.num_epochs = opt.nr_epochs

        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'D_B', 'G_B', 'cycle_B']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>

        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']

        # define networks (both Generators and discriminators)
        # The naming is different from those used in the paper.
        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.Generator(opt.input_nc, opt.output_nc, opt.ngf,
                                         opt.n_layers_G).to(self.device)
        self.netG_B = networks.Generator(opt.output_nc, opt.input_nc, opt.ngf,
                                         opt.n_layers_G).to(self.device)

        if self.isTrain:  # define discriminators
            self.netD_A = networks.NLayerDiscriminator(
                opt.output_nc, opt.ndf, opt.n_layers_D).to(self.device)
            self.netD_B = networks.NLayerDiscriminator(
                opt.input_nc, opt.ndf, opt.n_layers_D).to(self.device)
            # self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
            #                                 opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
            # self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
            #                                 opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            self.fake_A_pool = networks.MotionPool(
                opt.pool_size
            )  # create image buffer to store previously generated images
            self.fake_B_pool = networks.MotionPool(
                opt.pool_size
            )  # create image buffer to store previously generated images
            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(
                self.device)  # define GAN loss.
            self.criterionCycle = torch.nn.L1Loss()
            # self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(
                self.netD_A.parameters(), self.netD_B.parameters()),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

            self.schedulers = [
                networks.get_scheduler(optimizer, opt)
                for optimizer in self.optimizers
            ]
Example #17
def train(opt):
    #### device
    device = torch.device('cuda:{}'.format(opt.gpu_id)
                          if opt.gpu_id >= 0 else 'cpu')

    #### dataset
    data_loader = UnAlignedDataLoader()
    data_loader.initialize(opt)
    data_set = data_loader.load_data()
    print("The number of training images = %d." % len(data_set))

    #### initialize models
    ## declaration
    E_a2Zb = Encoder(input_nc=opt.input_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type,
                     use_dropout=not opt.no_dropout,
                     n_blocks=9)
    G_Zb2b = Decoder(output_nc=opt.output_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type)
    T_Zb2Za = LatentTranslator(n_channels=256,
                               norm_type=opt.norm_type,
                               use_dropout=not opt.no_dropout)
    D_b = Discriminator(input_nc=opt.input_nc,
                        ndf=opt.ndf,
                        n_layers=opt.n_layers,
                        norm_type=opt.norm_type)

    E_b2Za = Encoder(input_nc=opt.input_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type,
                     use_dropout=not opt.no_dropout,
                     n_blocks=9)
    G_Za2a = Decoder(output_nc=opt.output_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type)
    T_Za2Zb = LatentTranslator(n_channels=256,
                               norm_type=opt.norm_type,
                               use_dropout=not opt.no_dropout)
    D_a = Discriminator(input_nc=opt.input_nc,
                        ndf=opt.ndf,
                        n_layers=opt.n_layers,
                        norm_type=opt.norm_type)

    ## initialization
    E_a2Zb = init_net(E_a2Zb, init_type=opt.init_type).to(device)
    G_Zb2b = init_net(G_Zb2b, init_type=opt.init_type).to(device)
    T_Zb2Za = init_net(T_Zb2Za, init_type=opt.init_type).to(device)
    D_b = init_net(D_b, init_type=opt.init_type).to(device)

    E_b2Za = init_net(E_b2Za, init_type=opt.init_type).to(device)
    G_Za2a = init_net(G_Za2a, init_type=opt.init_type).to(device)
    T_Za2Zb = init_net(T_Za2Zb, init_type=opt.init_type).to(device)
    D_a = init_net(D_a, init_type=opt.init_type).to(device)
    print(
        "+------------------------------------------------------+\nFinish initializing networks."
    )

    #### optimizer and criterion
    ## criterion
    criterionGAN = GANLoss(opt.gan_mode).to(device)
    criterionZId = nn.L1Loss()
    criterionIdt = nn.L1Loss()
    criterionCTC = nn.L1Loss()
    criterionZCyc = nn.L1Loss()

    ## optimizer
    optimizer_G = torch.optim.Adam(itertools.chain(E_a2Zb.parameters(),
                                                   G_Zb2b.parameters(),
                                                   T_Zb2Za.parameters(),
                                                   E_b2Za.parameters(),
                                                   G_Za2a.parameters(),
                                                   T_Za2Zb.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.beta1, opt.beta2))
    optimizer_D = torch.optim.Adam(itertools.chain(D_a.parameters(),
                                                   D_b.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.beta1, opt.beta2))

    ## scheduler
    scheduler = [
        get_scheduler(optimizer_G, opt),
        get_scheduler(optimizer_D, opt)
    ]

    print(
        "+------------------------------------------------------+\nFinish initializing the optimizers and criterions."
    )

    #### global variables
    checkpoints_pth = os.path.join(opt.checkpoints, opt.name)
    if not os.path.exists(checkpoints_pth):
        os.mkdir(checkpoints_pth)
        os.mkdir(os.path.join(checkpoints_pth, 'images'))
    record_fh = open(os.path.join(checkpoints_pth, 'records.txt'),
                     'w',
                     encoding='utf-8')
    loss_names = [
        'GAN_A', 'Adv_A', 'Idt_A', 'CTC_A', 'ZId_A', 'ZCyc_A', 'GAN_B',
        'Adv_B', 'Idt_B', 'CTC_B', 'ZId_B', 'ZCyc_B'
    ]

    fake_A_pool = ImagePool(
        opt.pool_size
    )  # create image buffer to store previously generated images
    fake_B_pool = ImagePool(
        opt.pool_size
    )  # create image buffer to store previously generated images

    print(
        "+------------------------------------------------------+\nFinish preparing the other works."
    )
    print(
        "+------------------------------------------------------+\nNow training is beginning .."
    )
    #### training
    cur_iter = 0
    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()  # timer for entire epoch

        for i, data in enumerate(data_set):
            ## setup inputs
            real_A = data['A'].to(device)
            real_B = data['B'].to(device)

            ## forward
            # image cycle / GAN
            latent_B = E_a2Zb(real_A)  #-> a -> Zb     : E_a2b(a)
            fake_B = G_Zb2b(latent_B)  #-> Zb -> b'    : G_b(E_a2b(a))
            latent_A = E_b2Za(real_B)  #-> b -> Za     : E_b2a(b)
            fake_A = G_Za2a(latent_A)  #-> Za -> a'    : G_a(E_b2a(b))

            # Idt
            '''
            rec_A = G_Za2a(E_b2Za(fake_B))          #-> b' -> Za' -> rec_a  : G_a(E_b2a(fake_b))
            rec_B = G_Zb2b(E_a2Zb(fake_A))          #-> a' -> Zb' -> rec_b  : G_b(E_a2b(fake_a))
            '''
            idt_latent_A = E_b2Za(real_A)  #-> a -> Za        : E_b2a(a)
            idt_A = G_Za2a(idt_latent_A)  #-> Za -> idt_a    : G_a(E_b2a(a))
            idt_latent_B = E_a2Zb(real_B)  #-> b -> Zb        : E_a2b(b)
            idt_B = G_Zb2b(idt_latent_B)  #-> Zb -> idt_b    : G_b(E_a2b(b))

            # ZIdt
            T_latent_A = T_Zb2Za(latent_B)  #-> Zb -> Za''  : T_b2a(E_a2b(a))
            T_rec_A = G_Za2a(
                T_latent_A)  #-> Za'' -> a'' : G_a(T_b2a(E_a2b(a)))
            T_latent_B = T_Za2Zb(latent_A)  #-> Za -> Zb''  : T_a2b(E_b2a(b))
            T_rec_B = G_Zb2b(
                T_latent_B)  #-> Zb'' -> b'' : G_b(T_a2b(E_b2a(b)))

            # CTC
            T_idt_latent_B = T_Za2Zb(idt_latent_A)  #-> a -> T_a2b(E_b2a(a))
            T_idt_latent_A = T_Zb2Za(idt_latent_B)  #-> b -> T_b2a(E_a2b(b))

            # ZCyc
            TT_latent_B = T_Za2Zb(T_latent_A)  #-> T_a2b(T_b2a(E_a2b(a)))
            TT_latent_A = T_Zb2Za(T_latent_B)  #-> T_b2a(T_a2b(E_b2a(b)))

            ### optimize parameters
            ## Generator updating
            set_requires_grad(
                [D_b, D_a],
                False)  #-> set Discriminator to require no gradient
            optimizer_G.zero_grad()
            # GAN loss
            loss_G_A = criterionGAN(D_b(fake_B), True)
            loss_G_B = criterionGAN(D_a(fake_A), True)
            loss_GAN = loss_G_A + loss_G_B
            # Idt loss
            loss_idt_A = criterionIdt(idt_A, real_A)
            loss_idt_B = criterionIdt(idt_B, real_B)
            loss_Idt = loss_idt_A + loss_idt_B
            # Latent cross-identity loss
            loss_Zid_A = criterionZId(T_rec_A, real_A)
            loss_Zid_B = criterionZId(T_rec_B, real_B)
            loss_Zid = loss_Zid_A + loss_Zid_B
            # Latent cross-translation consistency
            loss_CTC_A = criterionCTC(T_idt_latent_A, latent_A)
            loss_CTC_B = criterionCTC(T_idt_latent_B, latent_B)
            loss_CTC = loss_CTC_B + loss_CTC_A
            # Latent cycle consistency
            loss_ZCyc_A = criterionZCyc(TT_latent_A, latent_A)
            loss_ZCyc_B = criterionZCyc(TT_latent_B, latent_B)
            loss_ZCyc = loss_ZCyc_B + loss_ZCyc_A

            loss_G = opt.lambda_gan * loss_GAN + opt.lambda_idt * loss_Idt + opt.lambda_zid * loss_Zid + opt.lambda_ctc * loss_CTC + opt.lambda_zcyc * loss_ZCyc

            # backward and gradient updating
            loss_G.backward()
            optimizer_G.step()

            ## Discriminator updating
            set_requires_grad([D_b, D_a],
                              True)  # -> set Discriminator to require gradient
            optimizer_D.zero_grad()

            # backward D_b
            fake_B_ = fake_B_pool.query(fake_B)
            #-> real_B, fake_B
            pred_real_B = D_b(real_B)
            loss_D_real_B = criterionGAN(pred_real_B, True)

            pred_fake_B = D_b(fake_B_)
            loss_D_fake_B = criterionGAN(pred_fake_B, False)

            loss_D_B = (loss_D_real_B + loss_D_fake_B) * 0.5
            loss_D_B.backward()

            # backward D_a
            fake_A_ = fake_A_pool.query(fake_A)
            #-> real_A, fake_A
            pred_real_A = D_a(real_A)
            loss_D_real_A = criterionGAN(pred_real_A, True)

            pred_fake_A = D_a(fake_A_)
            loss_D_fake_A = criterionGAN(pred_fake_A, False)

            loss_D_A = (loss_D_real_A + loss_D_fake_A) * 0.5
            loss_D_A.backward()

            # update the gradients
            optimizer_D.step()

            ### validate here, both qualitively and quantitatively
            ## record the losses
            if cur_iter % opt.log_freq == 0:
                # loss_names = ['GAN_A', 'Adv_A', 'Idt_A', 'CTC_A', 'ZId_A', 'ZCyc_A', 'GAN_B', 'Adv_B', 'Idt_B', 'CTC_B', 'ZId_B', 'ZCyc_B']
                losses = [
                    loss_G_A.item(),
                    loss_D_A.item(),
                    loss_idt_A.item(),
                    loss_CTC_A.item(),
                    loss_Zid_A.item(),
                    loss_ZCyc_A.item(),
                    loss_G_B.item(),
                    loss_D_B.item(),
                    loss_idt_B.item(),
                    loss_CTC_B.item(),
                    loss_Zid_B.item(),
                    loss_ZCyc_B.item()
                ]
                # record
                line = ''
                for loss in losses:
                    line += '{} '.format(loss)
                record_fh.write(line[:-1] + '\n')
                # print out
                print('Epoch: %3d/%3d Iter: %9d --------------------------+' %
                      (epoch, opt.niter + opt.niter_decay, i))
                field_names = loss_names[:len(loss_names) // 2]
                table = PrettyTable(field_names=field_names)
                for l_n in field_names:
                    table.align[l_n] = 'c'  # PrettyTable accepts 'l', 'c' or 'r'
                table.add_row(losses[:len(field_names)])
                print(table.get_string())

                field_names = loss_names[len(loss_names) // 2:]
                table = PrettyTable(field_names=field_names)
                for l_n in field_names:
                    table.align[l_n] = 'c'
                table.add_row(losses[-len(field_names):])
                print(table.get_string())

            ## visualize
            if cur_iter % opt.vis_freq == 0:
                if opt.gpu_id >= 0:
                    real_A = real_A.cpu().data
                    real_B = real_B.cpu().data
                    fake_A = fake_A.cpu().data
                    fake_B = fake_B.cpu().data
                    idt_A = idt_A.cpu().data
                    idt_B = idt_B.cpu().data
                    T_rec_A = T_rec_A.cpu().data
                    T_rec_B = T_rec_B.cpu().data

                plt.subplot(241), plt.title('real_A'), plt.imshow(
                    tensor2image_RGB(real_A[0, ...]))
                plt.subplot(242), plt.title('fake_B'), plt.imshow(
                    tensor2image_RGB(fake_B[0, ...]))
                plt.subplot(243), plt.title('idt_A'), plt.imshow(
                    tensor2image_RGB(idt_A[0, ...]))
                plt.subplot(244), plt.title('L_idt_A'), plt.imshow(
                    tensor2image_RGB(T_rec_A[0, ...]))

                plt.subplot(245), plt.title('real_B'), plt.imshow(
                    tensor2image_RGB(real_B[0, ...]))
                plt.subplot(246), plt.title('fake_A'), plt.imshow(
                    tensor2image_RGB(fake_A[0, ...]))
                plt.subplot(247), plt.title('idt_B'), plt.imshow(
                    tensor2image_RGB(idt_B[0, ...]))
                plt.subplot(248), plt.title('L_idt_B'), plt.imshow(
                    tensor2image_RGB(T_rec_B[0, ...]))

                plt.savefig(
                    os.path.join(checkpoints_pth, 'images',
                                 '%03d_%09d.jpg' % (epoch, i)))

            cur_iter += 1
            #break #-> debug

        ## one epoch is finished; update the learning rate
        update_learning_rate(schedulers=scheduler,
                             opt=opt,
                             optimizer=optimizer_D)
        ## save the model
        if epoch % opt.ckp_freq == 0:
            #-> save models
            # torch.save(model.state_dict(), PATH)
            #-> load in models
            # model.load_state_dict(torch.load(PATH))
            # model.eval()
            if opt.gpu_id >= 0:
                E_a2Zb = E_a2Zb.cpu()
                G_Zb2b = G_Zb2b.cpu()
                T_Zb2Za = T_Zb2Za.cpu()
                D_b = D_b.cpu()

                E_b2Za = E_b2Za.cpu()
                G_Za2a = G_Za2a.cpu()
                T_Za2Zb = T_Za2Zb.cpu()
                D_a = D_a.cpu()
                '''
                torch.save( E_a2Zb.cpu().state_dict(), os.path.join(checkpoints_pth, 'epoch_%3d-E_a2b.pth' % epoch))
                torch.save( G_Zb2b.cpu().state_dict(), os.path.join(checkpoints_pth,   'epoch_%3d-G_b.pth' % epoch))
                torch.save(T_Zb2Za.cpu().state_dict(), os.path.join(checkpoints_pth, 'epoch_%3d-T_b2a.pth' % epoch))
                torch.save(    D_b.cpu().state_dict(), os.path.join(checkpoints_pth,   'epoch_%3d-D_b.pth' % epoch))

                torch.save( E_b2Za.cpu().state_dict(), os.path.join(checkpoints_pth, 'epoch_%3d-E_b2a.pth' % epoch))
                torch.save( G_Za2a.cpu().state_dict(), os.path.join(checkpoints_pth,   'epoch_%3d-G_a.pth' % epoch))
                torch.save(T_Za2Zb.cpu().state_dict(), os.path.join(checkpoints_pth, 'epoch_%3d-T_a2b.pth' % epoch))
                torch.save(    D_a.cpu().state_dict(), os.path.join(checkpoints_pth,   'epoch_%3d-D_a.pth' % epoch))
                '''
            torch.save(
                E_a2Zb.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-E_a2b.pth' % epoch))
            torch.save(
                G_Zb2b.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-G_b.pth' % epoch))
            torch.save(
                T_Zb2Za.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-T_b2a.pth' % epoch))
            torch.save(
                D_b.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-D_b.pth' % epoch))

            torch.save(
                E_b2Za.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-E_b2a.pth' % epoch))
            torch.save(
                G_Za2a.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-G_a.pth' % epoch))
            torch.save(
                T_Za2Zb.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-T_a2b.pth' % epoch))
            torch.save(
                D_a.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-D_a.pth' % epoch))
            if opt.gpu_id >= 0:
                E_a2Zb = E_a2Zb.to(device)
                G_Zb2b = G_Zb2b.to(device)
                T_Zb2Za = T_Zb2Za.to(device)
                D_b = D_b.to(device)

                E_b2Za = E_b2Za.to(device)
                G_Za2a = G_Za2a.to(device)
                T_Za2Zb = T_Za2Zb.to(device)
                D_a = D_a.to(device)
            print("+Successfully saving models in epoch: %3d.-------------+" %
                  epoch)
        #break #-> debug
    record_fh.close()
    print("≧◔◡◔≦ Congratulation! Finishing the training!")
Example #18
    def set_scheduler(self, opts, last_ep=0):
        self.subspace_sch = networks.get_scheduler(self.subspace_opt, opts,
                                                   last_ep)
        #self.MI_sch = networks.get_scheduler(self.MI_opt, opts, last_ep)
        self.disA_sch = networks.get_scheduler(self.disA_opt, opts, last_ep)
        self.disB_sch = networks.get_scheduler(self.disB_opt, opts, last_ep)
        self.disA_attr_sch = networks.get_scheduler(self.disA_attr_opt, opts,
                                                    last_ep)
        self.disB_attr_sch = networks.get_scheduler(self.disB_attr_opt, opts,
                                                    last_ep)
        self.enc_c_sch = networks.get_scheduler(self.enc_c_opt, opts, last_ep)
        self.enc_a_sch = networks.get_scheduler(self.enc_a_opt, opts, last_ep)
        self.gen_sch = networks.get_scheduler(self.gen_opt, opts, last_ep)
        self.gen_attr_sch = networks.get_scheduler(self.gen_attr_opt, opts,
                                                   last_ep)

        self.subspace_pre_sch = networks.get_scheduler(self.subspace_pre_opt,
                                                       opts, last_ep)
        #self.MI_pre_sch = networks.get_scheduler(self.MI_pre_opt, opts, last_ep)
        self.enc_c_pre_sch = networks.get_scheduler(self.enc_c_pre_opt, opts,
                                                    last_ep)
        self.enc_a_pre_sch = networks.get_scheduler(self.enc_a_pre_opt, opts,
                                                    last_ep)
        self.gen_pre_sch = networks.get_scheduler(self.gen_pre_opt, opts,
                                                  last_ep)
Example #19
    def initialize(self, opt):
        super(SupervisedPoseTransferModel, self).initialize(opt)
        ###################################
        # define transformer
        ###################################
        if opt.which_model_T == 'resnet':
            self.netT = networks.ResnetGenerator(
                input_nc=3 + self.get_pose_dim(opt.pose_type),
                output_nc=3,
                ngf=opt.T_nf,
                norm_layer=networks.get_norm_layer(opt.norm),
                use_dropout=not opt.no_dropout,
                n_blocks=9,
                gpu_ids=opt.gpu_ids)
        elif opt.which_model_T == 'unet':
            self.netT = networks.UnetGenerator_v2(
                input_nc=3 + self.get_pose_dim(opt.pose_type),
                output_nc=3,
                num_downs=8,
                ngf=opt.T_nf,
                norm_layer=networks.get_norm_layer(opt.norm),
                use_dropout=not opt.no_dropout,
                gpu_ids=opt.gpu_ids)
        else:
            raise NotImplementedError()

        if opt.gpu_ids:
            self.netT.cuda()
        networks.init_weights(self.netT, init_type=opt.init_type)
        ###################################
        # define discriminator
        ###################################
        self.use_GAN = self.is_train and opt.loss_weight_gan > 0
        if self.use_GAN:
            self.netD = networks.define_D_from_params(
                input_nc=3 +
                self.get_pose_dim(opt.pose_type) if opt.D_cond else 3,
                ndf=opt.D_nf,
                which_model_netD='n_layers',
                n_layers_D=3,
                norm=opt.norm,
                which_gan=opt.which_gan,
                init_type=opt.init_type,
                gpu_ids=opt.gpu_ids)
        else:
            self.netD = None
        ###################################
        # loss functions
        ###################################
        if self.is_train:
            self.loss_functions = []
            self.schedulers = []
            self.optimizers = []

            self.crit_L1 = nn.L1Loss()
            self.crit_vgg = networks.VGGLoss_v2(self.gpu_ids)
            # self.crit_vgg_old = networks.VGGLoss(self.gpu_ids)
            self.crit_psnr = networks.PSNR()
            self.crit_ssim = networks.SSIM()
            self.loss_functions += [self.crit_L1, self.crit_vgg]
            self.optim = torch.optim.Adam(self.netT.parameters(),
                                          lr=opt.lr,
                                          betas=(opt.beta1, opt.beta2))
            self.optimizers += [self.optim]

            if self.use_GAN:
                self.crit_GAN = networks.GANLoss(
                    use_lsgan=opt.which_gan == 'lsgan', tensor=self.Tensor)
                self.optim_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr_D,
                                                betas=(opt.beta1, opt.beta2))
                self.loss_functions.append(self.crit_GAN)
                self.optimizers.append(self.optim_D)
            # todo: add pose loss
            for optim in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optim, opt))

            self.fake_pool = ImagePool(opt.pool_size)

        ###################################
        # load trained model
        ###################################
        if not self.is_train:
            self.load_network(self.netT, 'netT', opt.which_model)
Example #20
    def initialize(self, opt):
        super(FlowRegressionModel, self).initialize(opt)
        ###################################
        # define flow networks
        ###################################
        if opt.which_model == 'unet':
            self.netF = networks.FlowUnet(
                input_nc=self.get_input_dim(opt.input_type1) +
                self.get_input_dim(opt.input_type2),
                nf=opt.nf,
                start_scale=opt.start_scale,
                num_scale=opt.num_scale,
                norm=opt.norm,
                gpu_ids=opt.gpu_ids,
            )
        elif opt.which_model == 'unet_v2':
            self.netF = networks.FlowUnet_v2(
                input_nc=self.get_input_dim(opt.input_type1) +
                self.get_input_dim(opt.input_type2),
                nf=opt.nf,
                max_nf=opt.max_nf,
                start_scale=opt.start_scale,
                num_scales=opt.num_scale,
                norm=opt.norm,
                gpu_ids=opt.gpu_ids,
            )
        else:
            raise NotImplementedError()
        if opt.gpu_ids:
            self.netF.cuda()
        networks.init_weights(self.netF, init_type=opt.init_type)
        ###################################
        # loss and optimizers
        ###################################
        self.crit_flow = networks.MultiScaleFlowLoss(
            start_scale=opt.start_scale,
            num_scale=opt.num_scale,
            loss_type=opt.flow_loss_type)
        # visibility classes: 0-visible, 1-invisible, 2-background
        self.crit_vis = nn.CrossEntropyLoss()
        if opt.use_ss_flow_loss:
            self.crit_flow_ss = networks.SS_FlowLoss(loss_type='l1')
        if self.is_train:
            self.optimizers = []
            self.optim = torch.optim.Adam(self.netF.parameters(),
                                          lr=opt.lr,
                                          betas=(opt.beta1, opt.beta2),
                                          weight_decay=opt.weight_decay)
            self.optimizers.append(self.optim)

        ###################################
        # load trained model
        ###################################
        if not self.is_train:
            # load trained model for test
            print('load pretrained model')
            self.load_network(self.netF, 'netF', opt.which_epoch)
        elif opt.resume_train:
            # resume training
            print('resume training')
            self.load_network(self.netF, 'netF', opt.last_epoch)
            self.load_optim(self.optim, 'optim', opt.last_epoch)
        ###################################
        # schedulers
        ###################################
        if self.is_train:
            self.schedulers = []
            for optim in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optim, opt))
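
Populating self.schedulers only pays off if something steps them; models in this style conventionally do so once per epoch. A minimal sketch of that hook (the method name and log format are assumptions):

    # Hedged sketch: advance every scheduler once per epoch and report the
    # current rate. The method name and print format are assumptions.
    def update_learning_rate(self):
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)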
Example #21
0
# define loss
criterionL1 = nn.L1Loss().to(device)
criterionL2 = nn.MSELoss().to(device)
criterionMSE = nn.MSELoss().to(device)  # same as criterionL2, kept under both names
criterionSSIM = SSIM(data_range=255, size_average=True, channel=3)
criterionMSSSIM1 = MS_SSIM(data_range=255, size_average=True, channel=1)
criterionMSSSIM3 = MS_SSIM(data_range=255, size_average=True, channel=3)

# setup optimizer
optimizer_i = optim.Adam(net_i.parameters(),
                         lr=opt.lr,
                         betas=(opt.beta1, 0.999))
optimizer_r = optim.Adam(net_r.parameters(),
                         lr=opt.lr,
                         betas=(opt.beta1, 0.999))
net_i_scheduler = get_scheduler(optimizer_i, opt)
net_r_scheduler = get_scheduler(optimizer_r, opt)

loss_i_list = []
loss_r_list = []
for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    loss_i_per_epoch_list = []
    loss_r_per_epoch_list = []
    # train
    for iteration, batch in enumerate(training_data_loader, 1):
        # forward
        I_low, I_high, R_low, R_high, target = [
            t.to(device) for t in batch[:5]
        ]
        I_high_rec = net_i(I_low)
        R_high_rec = net_r(R_low)
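
Example #21 is truncated before its loss computation. For context only: the SSIM/MS-SSIM criteria it builds (from pytorch-msssim) return a similarity in [0, 1], so as losses they are typically inverted. A hedged usage sketch, not the example's actual continuation, assuming the I_* tensors are single-channel images scaled to [0, 255] to match data_range=255:

# Hedged usage sketch (not the truncated example's actual continuation):
# MS-SSIM measures similarity, so 1 - MS-SSIM serves as the loss term.
optimizer_i.zero_grad()
loss_i = criterionL1(I_high_rec, I_high) + \
         (1.0 - criterionMSSSIM1(I_high_rec, I_high))
loss_i.backward()
optimizer_i.step()
loss_i_per_epoch_list.append(loss_i.item())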
Example #22
0
    def __init__(self, opt=None):
        '''
        :param opt: option namespace; if None, default_option() is used
        '''
        super(cycleGAN, self).__init__()
        if opt is None:
            # no options supplied: fall back to the built-in defaults
            opt = self.default_option()
        self.opt = opt
        nic = opt.channels
        noc = opt.out_channels
        # model & loss names
        self.model_names = ['GenA', 'GenB', 'DisA', 'DisB']
        self.loss_names = [
            'D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B'
        ]
        # define Generator
        """
        ResnetGenerator(nic, out_channels, num_residual_blocks, ngf)
        """
        self.GenA = networks.ResnetGenerator(nic, noc, opt.n_residual_blocks)
        self.GenB = networks.ResnetGenerator(nic, noc, opt.n_residual_blocks)

        # define Discriminator
        """
        Discriminator(nic)
        """
        self.DisA = networks.Discriminator(nic, opt.image_size)
        self.DisB = networks.Discriminator(nic, opt.image_size)

        # criterion define loss function
        """
        GAN Loss
        Cycle-Consistency Loss
        Identity Loss
        """
        self.criterion_GAN = nn.MSELoss()
        self.criterion_Cycle = nn.L1Loss()
        self.criterion_idt = nn.L1Loss()
        self.optimizers = []

        # define optimizer
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.GenA.parameters(), self.GenB.parameters()),
                                            lr=opt.lr,
                                            betas=(opt.b1, 0.999))
        self.optimizer_D = torch.optim.Adam(itertools.chain(
            self.DisA.parameters(), self.DisB.parameters()),
                                            lr=opt.lr,
                                            betas=(opt.b1, 0.999))
        self.optimizers.append(self.optimizer_G)
        self.optimizers.append(self.optimizer_D)

        # bookkeeping: output directory and device
        self.save_dir = opt.save_dir
        self.device = opt.device

        step_size = 100
        self.schedulers = [
            networks.get_scheduler(optimizer, step_size)
            for optimizer in self.optimizers
        ]
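
The criteria built in this constructor feed the standard CycleGAN objective. A minimal sketch of the generator-side update such a class might implement; the method itself, the lambda_* weights, and the pairing of DisA/DisB with domains A/B are all assumptions:

    # Hedged sketch of a generator update using the criteria defined above.
    # lambda weights and the DisA/DisB domain pairing are assumptions.
    def backward_G(self, real_A, real_B, lambda_cycle=10.0, lambda_idt=5.0):
        fake_B = self.GenA(real_A)   # G_A: A -> B
        fake_A = self.GenB(real_B)   # G_B: B -> A
        # adversarial terms: each generator tries to look "real" to its D
        pred_B = self.DisB(fake_B)
        pred_A = self.DisA(fake_A)
        loss_gan = (self.criterion_GAN(pred_B, torch.ones_like(pred_B)) +
                    self.criterion_GAN(pred_A, torch.ones_like(pred_A)))
        # cycle consistency: A -> B -> A and B -> A -> B should reconstruct
        loss_cycle = (self.criterion_Cycle(self.GenB(fake_B), real_A) +
                      self.criterion_Cycle(self.GenA(fake_A), real_B)) * lambda_cycle
        # identity: mapping an image already in the target domain is a no-op
        loss_idt = (self.criterion_idt(self.GenA(real_B), real_B) +
                    self.criterion_idt(self.GenB(real_A), real_A)) * lambda_idt
        loss_G = loss_gan + loss_cycle + loss_idt
        loss_G.backward()
        return loss_G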
Example #23
0
    def initialize(self, opt):
        super(MultimodalDesignerGAN_V2, self).initialize(opt)
        ###################################
        # define networks
        ###################################
        self.modules = {}
        # shape branch
        if opt.which_model_netG != 'unet':
            self.shape_encoder = networks.define_image_encoder(opt, 'shape')
            self.modules['shape_encoder'] = self.shape_encoder
        else:
            self.shape_encoder = None
        # edge branch
        if opt.use_edge:
            self.edge_encoder = networks.define_image_encoder(opt, 'edge')
            self.modules['edge_encoder'] = self.edge_encoder
        else:
            self.edge_encoder = None
        # color branch
        if opt.use_color:
            self.color_encoder = networks.define_image_encoder(opt, 'color')
            self.modules['color_encoder'] = self.color_encoder
        else:
            self.color_encoder = None

        # fusion model
        if opt.ftn_model == 'none':
            # shape_feat, edge_feat and color_feat will simply be upsampled to the same size (that of shape_feat) and concatenated
            pass
        elif opt.ftn_model == 'concat':
            assert opt.use_edge or opt.use_color
            if opt.use_edge:
                self.edge_trans_net = networks.define_feature_fusion_network(
                    name='FeatureConcatNetwork',
                    feat_nc=opt.edge_nof,
                    guide_nc=opt.shape_nof,
                    nblocks=opt.ftn_nblocks,
                    norm=opt.norm,
                    gpu_ids=self.gpu_ids,
                    init_type=opt.init_type)
                self.modules['edge_trans_net'] = self.edge_trans_net
            if opt.use_color:
                self.color_trans_net = networks.define_feature_fusion_network(
                    name='FeatureConcatNetwork',
                    feat_nc=opt.color_nof,
                    guide_nc=opt.shape_nof,
                    nblocks=opt.ftn_nblocks,
                    norm=opt.norm,
                    gpu_ids=self.gpu_ids,
                    init_type=opt.init_type)
                self.modules['color_trans_net'] = self.color_trans_net
        elif opt.ftn_model == 'reduce':
            assert opt.use_edge or opt.use_color
            if opt.use_edge:
                self.edge_trans_net = networks.define_feature_fusion_network(
                    name='FeatureReduceNetwork',
                    feat_nc=opt.edge_nof,
                    guide_nc=opt.shape_nof,
                    nblocks=opt.ftn_nblocks,
                    ndowns=opt.ftn_ndowns,
                    norm=opt.norm,
                    gpu_ids=self.gpu_ids,
                    init_type=opt.init_type)
                self.modules['edge_trans_net'] = self.edge_trans_net
            if opt.use_color:
                self.color_trans_net = networks.define_feature_fusion_network(
                    name='FeatureReduceNetwork',
                    feat_nc=opt.color_nof,
                    guide_nc=opt.shape_nof,
                    nblocks=opt.ftn_nblocks,
                    ndowns=opt.ftn_ndowns,
                    norm=opt.norm,
                    gpu_ids=self.gpu_ids,
                    init_type=opt.init_type)
                self.modules['color_trans_net'] = self.color_trans_net

        elif opt.ftn_model == 'trans':
            assert opt.use_edge or opt.use_color
            if opt.use_edge:
                self.edge_trans_net = networks.define_feature_fusion_network(
                    name='FeatureTransformNetwork',
                    feat_nc=opt.edge_nof,
                    guide_nc=opt.shape_nof,
                    nblocks=opt.ftn_nblocks,
                    feat_size=opt.feat_size_lr,
                    norm=opt.norm,
                    gpu_ids=self.gpu_ids,
                    init_type=opt.init_type)
                self.modules['edge_trans_net'] = self.edge_trans_net
            if opt.use_color:
                self.color_trans_net = networks.define_feature_fusion_network(
                    name='FeatureTransformNetwork',
                    feat_nc=opt.color_nof,
                    guide_nc=opt.shape_nof,
                    nblocks=opt.ftn_nblocks,
                    feat_size=opt.feat_size_lr,
                    norm=opt.norm,
                    gpu_ids=self.gpu_ids,
                    init_type=opt.init_type)
                self.modules['color_trans_net'] = self.color_trans_net

        # netG
        self.netG = networks.define_generator(opt)
        self.modules['netG'] = self.netG

        # netD
        if self.is_train:
            self.netD = networks.define_D(opt)
            self.modules['netD'] = self.netD

        ###################################
        # load weights
        ###################################
        if self.is_train:
            if opt.continue_train:
                for label, net in self.modules.items():
                    self.load_network(net, label, opt.which_epoch)
            else:
                if opt.which_model_init != 'none':
                    # load pretrained entire model
                    for label, net in self.modules.items():
                        self.load_network(net,
                                          label,
                                          'latest',
                                          opt.which_model_init,
                                          forced=False)
                else:
                    # load pretrained encoder
                    if opt.which_model_netG != 'unet' and opt.pretrain_shape:
                        self.load_network(self.shape_encoder, 'shape_encoder',
                                          'latest',
                                          opt.which_model_init_shape_encoder)
                    if opt.use_edge and opt.pretrain_edge:
                        self.load_network(self.edge_encoder, 'edge_encoder',
                                          'latest',
                                          opt.which_model_init_edge_encoder)
                    if opt.use_color and opt.pretrain_color:
                        self.load_network(self.color_encoder, 'color_encoder',
                                          'latest',
                                          opt.which_model_init_color_encoder)
        else:
            for label, net in self.modules.items():
                if label != 'netD':
                    self.load_network(net, label, opt.which_epoch)

        ###################################
        # prepare for training
        ###################################
        if self.is_train:
            self.fake_pool = ImagePool(opt.pool_size)
            ###################################
            # define loss functions
            ###################################
            self.loss_functions = []
            if opt.which_gan in {'dcgan', 'lsgan'}:
                self.crit_GAN = networks.GANLoss(
                    use_lsgan=opt.which_gan == 'lsgan', tensor=self.Tensor)
                self.loss_functions.append(self.crit_GAN)
            else:
                # WGAN loss will be calculated in self.backward_D_wgangp and self.backward_G
                self.crit_GAN = None

            self.crit_L1 = nn.L1Loss()
            self.loss_functions.append(self.crit_L1)

            if self.opt.loss_weight_vgg > 0:
                self.crit_vgg = networks.VGGLoss(self.gpu_ids)
                self.loss_functions.append(self.crit_vgg)

            if self.opt.G_output_seg:
                self.crit_CE = nn.CrossEntropyLoss()
                self.loss_functions.append(self.crit_CE)

            self.crit_psnr = networks.SmoothLoss(networks.PSNR())
            self.loss_functions.append(self.crit_psnr)
            ###################################
            # create optimizers
            ###################################
            self.schedulers = []
            self.optimizers = []

            # G optimizer
            G_module_list = [
                'shape_encoder', 'edge_encoder', 'color_encoder', 'netG'
            ]
            G_param_groups = [{
                'params': self.modules[m].parameters()
            } for m in G_module_list if m in self.modules]
            self.optim_G = torch.optim.Adam(G_param_groups,
                                            lr=opt.lr,
                                            betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optim_G)
            # D optimizer
            self.optim_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr_D,
                                            betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optim_D)
            # feature transfer network optimizer
            FTN_module_list = ['edge_trans_net', 'color_trans_net']
            FTN_param_groups = [{
                'params': self.modules[m].parameters()
            } for m in FTN_module_list if m in self.modules]
            if len(FTN_param_groups) > 0:
                self.optim_FTN = torch.optim.Adam(FTN_param_groups,
                                                  lr=opt.lr_FTN,
                                                  betas=(0.9, 0.999))
                self.optimizers.append(self.optim_FTN)
            else:
                self.optim_FTN = None
            # schedulers
            for optim in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optim, opt))
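
Several models above keep an ImagePool(opt.pool_size) and query it when training the discriminator, so D sees a mix of fresh and historical fakes instead of overfitting to the latest generator output. A condensed sketch of the CycleGAN-style pool; the actual class in these repositories may differ in detail:

import random
import torch

class ImagePool():
    # Condensed, hedged sketch of the CycleGAN-style image buffer.
    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.num_imgs = 0
        self.images = []

    def query(self, images):
        if self.pool_size == 0:                 # pool disabled: pass through
            return images
        return_images = []
        for image in images:
            image = torch.unsqueeze(image.data, 0)
            if self.num_imgs < self.pool_size:  # still filling the buffer
                self.num_imgs += 1
                self.images.append(image)
                return_images.append(image)
            elif random.uniform(0, 1) > 0.5:    # 50%: return a stored fake
                idx = random.randint(0, self.pool_size - 1)
                tmp = self.images[idx].clone()
                self.images[idx] = image
                return_images.append(tmp)
            else:                               # 50%: return the new fake
                return_images.append(image)
        return torch.cat(return_images, 0)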
Example #24
0
    def initialize(self, opt):
        super(AttributeEncoder, self).initialize(opt)

        # define tensors
        self.input['img'] = self.Tensor(opt.batch_size, opt.input_nc,
                                        opt.fine_size, opt.fine_size)
        self.input['label'] = self.Tensor(opt.batch_size, opt.n_attr)

        # load/define networks
        self.net = networks.define_attr_encoder_net(opt)

        if not self.is_train or opt.continue_train:
            self.load_network(self.net,
                              network_label='AE',
                              epoch_label=opt.which_epoch)

        self.schedulers = []
        self.optimizers = []
        self.loss_functions = []

        # define loss functions
        # attribute
        if opt.loss_type == 'bce':
            self.crit_attr = networks.SmoothLoss(
                torch.nn.BCELoss(size_average=not opt.no_size_avg))
        elif opt.loss_type == 'wbce':
            attr_entry = io.load_json(os.path.join(opt.data_root,
                                                   opt.fn_entry))
            pos_rate = self.Tensor([att['pos_rate'] for att in attr_entry])
            pos_rate.clamp_(min=0.01, max=0.99)
            self.crit_attr = networks.SmoothLoss(
                networks.WeightedBCELoss(pos_rate=pos_rate,
                                         class_norm=opt.wbce_class_norm,
                                         size_average=not opt.no_size_avg))
        else:
            raise ValueError('attribute loss type "%s" is not defined' %
                             opt.loss_type)
        self.loss_functions.append(self.crit_attr)

        # joint task
        if opt.joint_cat:
            self.crit_cat = networks.SmoothLoss(torch.nn.CrossEntropyLoss())
            self.loss_functions.append(self.crit_cat)

        # initialize optimizers
        if opt.is_train:
            if opt.optim == 'adam':
                self.optim_attr = torch.optim.Adam(
                    self.net.parameters(),
                    lr=opt.lr,
                    betas=(opt.beta1, 0.999),
                    weight_decay=opt.weight_decay)
            elif opt.optim == 'sgd':
                self.optim_attr = torch.optim.SGD(
                    self.net.parameters(),
                    lr=opt.lr,
                    momentum=0.9,
                    weight_decay=opt.weight_decay)
            else:
                raise ValueError('optimizer type "%s" is not defined' %
                                 opt.optim)
            self.optimizers.append(self.optim_attr)

            for optim in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optim, opt))
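
networks.WeightedBCELoss above is project-specific and its internals are not shown. Purely as an assumption about what a pos-rate-weighted BCE could look like:

import torch
import torch.nn as nn

class WeightedBCELoss(nn.Module):
    # Hedged sketch only; the actual networks.WeightedBCELoss may differ.
    # Rare positive attributes receive proportionally larger weights.
    def __init__(self, pos_rate, class_norm=True, size_average=True):
        super(WeightedBCELoss, self).__init__()
        self.w_pos = 1.0 / pos_rate
        self.w_neg = 1.0 / (1.0 - pos_rate)
        if class_norm:
            # normalise so that w_pos + w_neg == 2 for every attribute
            s = self.w_pos + self.w_neg
            self.w_pos, self.w_neg = 2.0 * self.w_pos / s, 2.0 * self.w_neg / s
        self.size_average = size_average

    def forward(self, input, target):
        eps = 1e-8
        loss = -(self.w_pos * target * torch.log(input + eps) +
                 self.w_neg * (1.0 - target) * torch.log(1.0 - input + eps))
        return loss.mean() if self.size_average else loss.sum()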
Example #25
0
    def initialize(self, opt):
        super(EncoderDecoderFramework_DFN, self).initialize(opt)
        ###################################
        # define encoder
        ###################################
        self.encoder = networks.define_encoder_v2(opt)
        ###################################
        # define decoder
        ###################################
        self.decoder = networks.define_decoder_v2(opt)
        ###################################
        # guide encoder
        ###################################        
        self.guide_encoder, self.opt_guide = networks.load_encoder_v2(opt, opt.which_model_guide)
        self.guide_encoder.eval()
        for p in self.guide_encoder.parameters():
            p.requires_grad = False
        ###################################
        # DFN Modules
        ###################################
        if self.opt.use_dfn:
            self.dfn = networks.define_DFN_from_params(
                nf=opt.nof,
                ng=self.opt_guide.nof,
                nmid=opt.dfn_nmid,
                feat_size=opt.feat_size,
                local_size=opt.dfn_local_size,
                nblocks=opt.dfn_nblocks,
                norm=opt.norm,
                gpu_ids=opt.gpu_ids,
                init_type=opt.init_type)
        else:
            self.dfn = None
        ###################################
        # Discriminator
        ###################################
        self.use_GAN = self.is_train and opt.loss_weight_gan > 0
        if self.is_train:
            if self.use_GAN:
                # if not self.opt.D_cond:
                #     input_nc = self.decoder.output_nc
                # else:
                #     input_nc = self.decoder.output_nc + self.encoder.input_nc
                if self.opt.gan_level == 'image':
                    input_nc = self.decoder.output_nc
                elif self.opt.gan_level == 'feature':
                    input_nc = self.opt.nof
                if self.opt.D_cond:
                    if self.opt.D_cond_type == 'cond':
                        input_nc += self.encoder.input_nc
                    elif self.opt.D_cond_type == 'pair':
                        input_nc += input_nc
                self.netD = networks.define_D_from_params(
                    input_nc=input_nc,
                    ndf=64,
                    which_model_netD='n_layers',
                    n_layers_D=3,
                    norm=opt.norm,
                    which_gan=opt.which_gan,
                    init_type=opt.init_type,
                    gpu_ids=opt.gpu_ids)
            else:
                self.netD = None
        ###################################
        # loss functions
        ###################################
        if self.is_train:
            self.loss_functions = []
            self.schedulers = []
            self.crit_image = nn.L1Loss()
            self.crit_seg = nn.CrossEntropyLoss()
            self.crit_edge = nn.BCELoss()
            self.loss_functions += [self.crit_image, self.crit_seg, self.crit_edge]
            if self.opt.use_dfn:
                self.optim = torch.optim.Adam(
                    [{'params': self.encoder.parameters()},
                     {'params': self.decoder.parameters()},
                     {'params': self.dfn.parameters()}],
                    lr=opt.lr,
                    betas=(opt.beta1, opt.beta2))
            else:
                self.optim = torch.optim.Adam(
                    [{'params': self.encoder.parameters()},
                     {'params': self.decoder.parameters()}],
                    lr=opt.lr,
                    betas=(opt.beta1, opt.beta2))
            self.optimizers = [self.optim]
            # GAN loss and optimizers
            if self.use_GAN:
                self.crit_GAN = networks.GANLoss(
                    use_lsgan=opt.which_gan == 'lsgan', tensor=self.Tensor)
                self.optim_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr_D,
                                                betas=(0.5, 0.999))
                self.loss_functions += [self.crit_GAN]
                self.optimizers += [self.optim_D]

            for optim in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optim, opt))

        ###################################
        # load trained model
        ###################################
        if not self.is_train:
            self.load_network(self.encoder, 'encoder', opt.which_epoch)
            self.load_network(self.decoder, 'decoder', opt.which_epoch)
            if opt.use_dfn:
                self.load_network(self.dfn, 'dfn', opt.which_epoch)
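
networks.GANLoss(use_lsgan=..., tensor=...) appears in most of these examples. The pix2pix-style implementation expands a real/fake label into a target tensor and applies MSE (LSGAN) or BCE (DCGAN); a sketch along those lines, whose details may differ from any particular repository's version:

import torch
import torch.nn as nn

class GANLoss(nn.Module):
    # Hedged sketch in the pix2pix style: MSELoss for LSGAN, BCELoss otherwise.
    def __init__(self, use_lsgan=True, target_real_label=1.0,
                 target_fake_label=0.0, tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.Tensor = tensor
        self.loss = nn.MSELoss() if use_lsgan else nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        label = self.real_label if target_is_real else self.fake_label
        return self.Tensor(input.size()).fill_(label)

    def __call__(self, input, target_is_real):
        return self.loss(input, self.get_target_tensor(input, target_is_real))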
Example #26
0
    def initialize(self, opt):
        super(PoseTransferModel, self).initialize(opt)
        ###################################
        # define generator
        ###################################
        if opt.which_model_G == 'unet':
            self.netG = networks.UnetGenerator(
                input_nc=self.get_tensor_dim('+'.join(
                    [opt.G_appearance_type, opt.G_pose_type])),
                output_nc=3,
                nf=opt.G_nf,
                max_nf=opt.G_max_nf,
                num_scales=opt.G_n_scale,
                n_residual_blocks=2,
                norm=opt.G_norm,
                activation=nn.LeakyReLU(0.1)
                if opt.G_activation == 'leaky_relu' else nn.ReLU(),
                use_dropout=opt.use_dropout,
                gpu_ids=opt.gpu_ids)
        elif opt.which_model_G == 'dual_unet':
            self.netG = networks.DualUnetGenerator(
                pose_nc=self.get_tensor_dim(opt.G_pose_type),
                appearance_nc=self.get_tensor_dim(opt.G_appearance_type),
                output_nc=3,
                aux_output_nc=[],
                nf=opt.G_nf,
                max_nf=opt.G_max_nf,
                num_scales=opt.G_n_scale,
                num_warp_scales=opt.G_n_warp_scale,
                n_residual_blocks=2,
                norm=opt.G_norm,
                vis_mode=opt.G_vis_mode,
                activation=nn.LeakyReLU(0.1)
                if opt.G_activation == 'leaky_relu' else nn.ReLU(),
                use_dropout=opt.use_dropout,
                no_end_norm=opt.G_no_end_norm,
                gpu_ids=opt.gpu_ids,
            )
        else:
            raise NotImplementedError()
        if opt.gpu_ids:
            self.netG.cuda()
        networks.init_weights(self.netG, init_type=opt.init_type)
        ###################################
        # define external pixel warper
        ###################################
        if opt.G_pix_warp:
            pix_warp_n_scale = opt.G_n_scale
            self.netPW = networks.UnetGenerator_MultiOutput(
                input_nc=self.get_tensor_dim(opt.G_pix_warp_input_type),
                output_nc=[1],  # only use one output branch (weight mask)
                nf=32,
                max_nf=128,
                num_scales=pix_warp_n_scale,
                n_residual_blocks=2,
                norm=opt.G_norm,
                activation=nn.ReLU(False),
                use_dropout=False,
                gpu_ids=opt.gpu_ids)
            if opt.gpu_ids:
                self.netPW.cuda()
            networks.init_weights(self.netPW, init_type=opt.init_type)
        ###################################
        # define discriminator
        ###################################
        self.use_gan = self.is_train and self.opt.loss_weight_gan > 0
        if self.use_gan:
            self.netD = networks.NLayerDiscriminator(
                input_nc=self.get_tensor_dim(opt.D_input_type_real),
                ndf=opt.D_nf,
                n_layers=opt.D_n_layers,
                use_sigmoid=(opt.gan_type == 'dcgan'),
                output_bias=True,
                gpu_ids=opt.gpu_ids,
            )
            if opt.gpu_ids:
                self.netD.cuda()
            networks.init_weights(self.netD, init_type=opt.init_type)
        ###################################
        # load optical flow model
        ###################################
        if opt.flow_on_the_fly:
            self.netF = load_flow_network(opt.pretrained_flow_id,
                                          opt.pretrained_flow_epoch,
                                          opt.gpu_ids)
            self.netF.eval()
            if opt.gpu_ids:
                self.netF.cuda()
        ###################################
        # loss and optimizers
        ###################################
        self.crit_psnr = networks.PSNR()
        self.crit_ssim = networks.SSIM()

        if self.is_train:
            self.crit_vgg = networks.VGGLoss(
                opt.gpu_ids,
                shifted_style=opt.shifted_style_loss,
                content_weights=opt.vgg_content_weights)
            if opt.G_pix_warp:
                # only optimize netPW
                self.optim = torch.optim.Adam(self.netPW.parameters(),
                                              lr=opt.lr,
                                              betas=(opt.beta1, opt.beta2),
                                              weight_decay=opt.weight_decay)
            else:
                self.optim = torch.optim.Adam(self.netG.parameters(),
                                              lr=opt.lr,
                                              betas=(opt.beta1, opt.beta2),
                                              weight_decay=opt.weight_decay)
            self.optimizers = [self.optim]
            if self.use_gan:
                self.crit_gan = networks.GANLoss(
                    use_lsgan=(opt.gan_type == 'lsgan'))
                if self.gpu_ids:
                    self.crit_gan.cuda()
                self.optim_D = torch.optim.Adam(
                    self.netD.parameters(),
                    lr=opt.lr_D,
                    betas=(opt.beta1, opt.beta2),
                    weight_decay=opt.weight_decay_D)
                self.optimizers += [self.optim_D]

        ###################################
        # load trained model
        ###################################
        if not self.is_train:
            # load trained model for testing
            self.load_network(self.netG, 'netG', opt.which_epoch)
            if opt.G_pix_warp:
                self.load_network(self.netPW, 'netPW', opt.which_epoch)
        elif opt.pretrained_G_id is not None:
            # load pretrained network
            self.load_network(self.netG, 'netG', opt.pretrained_G_epoch,
                              opt.pretrained_G_id)
        elif opt.resume_train:
            # resume training
            self.load_network(self.netG, 'netG', opt.which_epoch)
            self.load_optim(self.optim, 'optim', opt.which_epoch)
            if self.use_gan:
                self.load_network(self.netD, 'netD', opt.which_epoch)
                self.load_optim(self.optim_D, 'optim_D', opt.which_epoch)
            if opt.G_pix_warp:
                self.load_network(self.netPW, 'netPW', opt.which_epoch)
        ###################################
        # schedulers
        ###################################
        if self.is_train:
            self.schedulers = []
            for optim in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optim, opt))
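
The resume branch above restores both network weights and optimizer state via load_network / load_optim. The mechanics reduce to state-dict round trips; a minimal sketch, in which the helper names and the file layout are assumptions:

import os
import torch

# Hedged sketch of the save/load pattern behind load_network / load_optim.
# The '<epoch>_..._<label>.pth' layout is an assumption, not any repo's scheme.
def save_checkpoint(save_dir, label, epoch, net, optim=None):
    torch.save(net.state_dict(),
               os.path.join(save_dir, '%s_net_%s.pth' % (epoch, label)))
    if optim is not None:
        torch.save(optim.state_dict(),
                   os.path.join(save_dir, '%s_optim_%s.pth' % (epoch, label)))

def load_checkpoint(save_dir, label, epoch, net, optim=None):
    net.load_state_dict(
        torch.load(os.path.join(save_dir, '%s_net_%s.pth' % (epoch, label))))
    if optim is not None:
        optim.load_state_dict(
            torch.load(os.path.join(save_dir, '%s_optim_%s.pth' % (epoch, label))))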
Example #27
0
    def __init__(self, p):

        super(CycleGAN, self).__init__(p)
        nb = p.batchSize
        size = p.cropSize

        # load/define models
        # The naming convention differs from the one used in the paper:
        # code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)

        self.netG_A = networks.define_G(p.input_nc, p.output_nc, p.ngf,
                                        p.which_model_netG, p.norm,
                                        not p.no_dropout, p.init_type,
                                        self.gpu_ids)
        self.netG_B = networks.define_G(p.output_nc, p.input_nc, p.ngf,
                                        p.which_model_netG, p.norm,
                                        not p.no_dropout, p.init_type,
                                        self.gpu_ids)

        if self.isTrain:
            use_sigmoid = p.no_lsgan
            self.netD_A = networks.define_D(p.output_nc, p.ndf,
                                            p.which_model_netD, p.n_layers_D,
                                            p.norm, use_sigmoid, p.init_type,
                                            self.gpu_ids)
            self.netD_B = networks.define_D(p.input_nc, p.ndf,
                                            p.which_model_netD, p.n_layers_D,
                                            p.norm, use_sigmoid, p.init_type,
                                            self.gpu_ids)

        if not self.isTrain or p.continue_train:
            which_epoch = p.which_epoch
            self.load_model(self.netG_A, 'G_A', which_epoch)
            self.load_model(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_model(self.netD_A, 'D_A', which_epoch)
                self.load_model(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.old_lr = p.lr
            self.fake_A_pool = ImagePool(p.pool_size)
            self.fake_B_pool = ImagePool(p.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not p.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()

            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=p.lr,
                                                betas=(p.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
                                                  lr=p.lr,
                                                  betas=(p.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
                                                  lr=p.lr,
                                                  betas=(p.beta1, 0.999))

            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, p))
Example #28
0
def main():

	print(f"epoch: {opt.niter+opt.niter_decay}")
	print(f"cuda: {opt.cuda}")
	print(f"dataset: {opt.dataset}")
	print(f"output: {opt.output_path}")

	if opt.cuda and not torch.cuda.is_available():
		raise Exception("No GPU found, please run without --cuda")

	cudnn.benchmark = True

	torch.manual_seed(opt.seed)
	if opt.cuda:
		torch.cuda.manual_seed(opt.seed)

	print('Loading datasets')
	train_set = get_training_set(root_path + opt.dataset, opt.direction)
	test_set = get_test_set(root_path + opt.dataset, opt.direction)

	training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True)
	testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.test_batch_size, shuffle=False)

	device = torch.device("cuda:0" if opt.cuda else "cpu")

	print('Building models')
	net_g = define_G(opt.input_nc, opt.output_nc, opt.g_ch, len(class_name_array), 'batch', False, 'normal', 0.02, gpu_id=device)
	net_d = define_D(opt.input_nc + opt.output_nc, opt.d_ch, len(class_name_array), 'basic', gpu_id=device)

	criterionGAN = GANLoss().to(device)
	criterionL1 = nn.L1Loss().to(device)
	criterionMSE = nn.MSELoss().to(device)

	# setup optimizer
	optimizer_g = optim.Adam(net_g.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
	optimizer_d = optim.Adam(net_d.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
	net_g_scheduler = get_scheduler(optimizer_g, opt)
	net_d_scheduler = get_scheduler(optimizer_d, opt)

	start_time = time.time()

	#save loss
	G_loss_array = []
	D_loss_array = []
	epoch_array = []

	for epoch in tqdm(range(opt.epoch_count, opt.niter + opt.niter_decay + 1), desc="Epoch"):
		# train
		loss_g_sum = 0
		loss_d_sum = 0
		for iteration, batch in enumerate(tqdm(training_data_loader, desc="Batch"), 1):
			# forward
			real_a, real_b, class_label, _ = batch[0].to(device), batch[1].to(device), batch[2].to(device), batch[3][0]
			fake_b = net_g(real_a, class_label)

			######################
			# (1) Update D network
			######################

			optimizer_d.zero_grad()
			
			# train with fake
			if opt.padding:
				real_a_for_d = padding(real_a)
				real_b_for_d = padding(real_b)
				fake_b_for_d = padding(fake_b)
			else:
				real_a_for_d = real_a
				real_b_for_d = real_b
				fake_b_for_d = fake_b
			
			fake_ab = torch.cat((real_a_for_d, fake_b_for_d), 1)
			pred_fake = net_d.forward(fake_ab.detach(), class_label)
			loss_d_fake = criterionGAN(pred_fake, False)

			# train with real
			real_ab = torch.cat((real_a_for_d, real_b_for_d), 1)
			pred_real = net_d.forward(real_ab, class_label)
			loss_d_real = criterionGAN(pred_real, True)
			
			# Combined D loss
			loss_d = (loss_d_fake + loss_d_real) * 0.5

			loss_d.backward()
		   
			optimizer_d.step()

			######################
			# (2) Update G network
			######################

			optimizer_g.zero_grad()

			# First, G(A) should fake the discriminator
			fake_ab = torch.cat((real_a_for_d, fake_b_for_d), 1)
			pred_fake = net_d.forward(fake_ab, class_label)
			loss_g_gan = criterionGAN(pred_fake, True)

			# Second, G(A) = B
			loss_g_l1 = criterionL1(fake_b, real_b) * opt.lamb
			
			loss_g = loss_g_gan + loss_g_l1
			
			loss_g.backward()

			optimizer_g.step()
			loss_d_sum += loss_d.item()
			loss_g_sum += loss_g.item()

		update_learning_rate(net_g_scheduler, optimizer_g)
		update_learning_rate(net_d_scheduler, optimizer_d)
		
		# test
		avg_psnr = 0
		dst = Image.new('RGB', (512*4, 256*4))
		n = 0
		for batch in tqdm(testing_data_loader, desc="Batch"):
			input, target, class_label, _ = batch[0].to(device), batch[1].to(device), batch[2].to(device), batch[3][0]

			prediction = net_g(input, class_label)
			mse = criterionMSE(prediction, target)
			psnr = 10 * log10(1 / mse.item())
			avg_psnr += psnr
			
			n += 1
			if n <= 16:
				#make test preview
				out_img = prediction.detach().squeeze(0).cpu()
				image_numpy = out_img.float().numpy()
				image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
				image_numpy = image_numpy.clip(0, 255)
				image_numpy = image_numpy.astype(np.uint8)
				image_pil = Image.fromarray(image_numpy)
				dst.paste(image_pil, ((n-1)%4*512, (n-1)//4*256))
				
		if not os.path.exists("results"):
			os.mkdir("results")
		if not os.path.exists(os.path.join("results", opt.output_path)):
			os.mkdir(os.path.join("results", opt.output_path))
		dst.save(f"results/{opt.output_path}/epoch{epoch}_test_preview.jpg")
		
		epoch_array += [epoch]
		G_loss_array += [loss_g_sum/len(training_data_loader)]
		D_loss_array += [loss_d_sum/len(training_data_loader)]
		
		if opt.graph_save_while_training and len(epoch_array) > 1:
			output_graph(epoch_array, G_loss_array, D_loss_array, False)
		
		#checkpoint
		if epoch % opt.save_interval == 0:
			if not os.path.exists("checkpoint"):
				os.mkdir("checkpoint")
			if not os.path.exists(os.path.join("checkpoint", opt.output_path)):
				os.mkdir(os.path.join("checkpoint", opt.output_path))
			net_g_model_out_path = "checkpoint/{}/netG_model_epoch_{}.pth".format(opt.output_path, epoch)
			net_d_model_out_path = "checkpoint/{}/netD_model_epoch_{}.pth".format(opt.output_path, epoch)
			torch.save(net_g, net_g_model_out_path)
			torch.save(net_d, net_d_model_out_path)

	#save the latest net
	if not os.path.exists("checkpoint"):
		os.mkdir("checkpoint")
	if not os.path.exists(os.path.join("checkpoint", opt.output_path)):
		os.mkdir(os.path.join("checkpoint", opt.output_path))
	net_g_model_out_path = "checkpoint/{}/netG_model_epoch_{}.pth".format(opt.output_path, opt.niter + opt.niter_decay)
	net_d_model_out_path = "checkpoint/{}/netD_model_epoch_{}.pth".format(opt.output_path, opt.niter + opt.niter_decay)
	torch.save(net_g, net_g_model_out_path)
	torch.save(net_d, net_d_model_out_path)
	print("\nCheckpoint saved to {}".format("checkpoint/" + opt.output_path))

	# output loss graph
	output_graph(epoch_array, G_loss_array, D_loss_array)

	# finish training
	now_time = time.time()
	t = now_time - start_time
	print(f"Training time: {t/60:.1f}m")
Example #29
0
    def initialize(self, opt):
        super(MultimodalDesignerGAN, self).initialize(opt)
        ###################################
        # load/define networks
        ###################################

        # basic G
        self.netG = networks.define_G(opt)

        # encoders
        self.encoders = {}
        if opt.use_edge:
            self.edge_encoder = networks.define_image_encoder(opt, 'edge')
            self.encoders['edge_encoder'] = self.edge_encoder
        if opt.use_color:
            self.color_encoder = networks.define_image_encoder(opt, 'color')
            self.encoders['color_encoder'] = self.color_encoder
        if opt.use_attr:
            self.attr_encoder, self.opt_AE = network_loader.load_attribute_encoder_net(
                id=opt.which_model_AE, gpu_ids=opt.gpu_ids)

        # basic D and auxiliary Ds
        if self.is_train:
            # basic D
            self.netD = networks.define_D(opt)
            # auxiliary Ds
            self.auxiliaryDs = {}
            if opt.use_edge_D:
                assert opt.use_edge
                self.netD_edge = networks.define_D_from_params(
                    input_nc=opt.edge_nof + 3,
                    ndf=opt.ndf,
                    which_model_netD=opt.which_model_netD,
                    n_layers_D=opt.n_layers_D,
                    norm=opt.norm,
                    which_gan='dcgan',
                    init_type=opt.init_type,
                    gpu_ids=opt.gpu_ids)
                self.auxiliaryDs['D_edge'] = self.netD_edge
            if opt.use_color_D:
                assert opt.use_color
                self.netD_color = networks.define_D_from_params(
                    input_nc=opt.color_nof + 3,
                    ndf=opt.ndf,
                    which_model_netD=opt.which_model_netD,
                    n_layers_D=opt.n_layers_D,
                    norm=opt.norm,
                    which_gan='dcgan',
                    init_type=opt.init_type,
                    gpu_ids=opt.gpu_ids)
                self.auxiliaryDs['D_color'] = self.netD_color
            if opt.use_attr_D:
                assert opt.use_attr
                attr_nof = (opt.n_attr_feat
                            if opt.attr_cond_type in {'feat', 'feat_map'}
                            else opt.n_attr)
                self.netD_attr = networks.define_D_from_params(
                    input_nc=attr_nof + 3,
                    ndf=opt.ndf,
                    which_model_netD=opt.which_model_netD,
                    n_layers_D=opt.n_layers_D,
                    norm=opt.norm,
                    which_gan='dcgan',
                    init_type=opt.init_type,
                    gpu_ids=opt.gpu_ids)
                self.auxiliaryDs['D_attr'] = self.netD_attr
            # load weights
            if not opt.continue_train:
                if opt.which_model_init != 'none':
                    self.load_network(self.netG, 'G', 'latest',
                                      opt.which_model_init)
                    self.load_network(self.netD, 'D', 'latest',
                                      opt.which_model_init)
                    for l, net in self.encoders.items():
                        self.load_network(net, l, 'latest',
                                          opt.which_model_init)
                    for l, net in self.auxiliaryDs.items():
                        self.load_network(net, l, 'latest',
                                          opt.which_model_init)
            else:
                self.load_network(self.netG, 'G', opt.which_epoch)
                self.load_network(self.netD, 'D', opt.which_epoch)
                for l, net in self.encoders.items():
                    self.load_network(net, l, opt.which_epoch)
                for l, net in self.auxiliaryDs.items():
                    self.load_network(net, l, opt.which_epoch)
        else:
            self.load_network(self.netG, 'G', opt.which_epoch)
            for l, net in self.encoders.items():
                self.load_network(net, l, opt.which_epoch)

        if self.is_train:
            self.fake_pool = ImagePool(opt.pool_size)
            ###################################
            # define loss functions and loss buffers
            ###################################
            self.loss_functions = []
            if opt.which_gan in {'dcgan', 'lsgan'}:
                self.crit_GAN = networks.GANLoss(
                    use_lsgan=opt.which_gan == 'lsgan', tensor=self.Tensor)
            else:
                # WGAN loss will be calculated in self.backward_D_wgangp and self.backward_G
                self.crit_GAN = None

            self.loss_functions.append(self.crit_GAN)

            self.crit_L1 = nn.L1Loss()
            self.loss_functions.append(self.crit_L1)

            if self.opt.loss_weight_vgg > 0:
                self.crit_vgg = networks.VGGLoss(self.gpu_ids)
                self.loss_functions.append(self.crit_vgg)

            self.crit_psnr = networks.SmoothLoss(networks.PSNR())
            self.loss_functions.append(self.crit_psnr)
            ###################################
            # create optimizers
            ###################################
            self.schedulers = []
            self.optimizers = []

            # optim_G will optimize parameters of netG and all image encoders (except attr_encoder)
            G_param_groups = [{'params': self.netG.parameters()}]
            for l, net in self.encoders.items():
                G_param_groups.append({'params': net.parameters()})
            self.optim_G = torch.optim.Adam(G_param_groups,
                                            lr=opt.lr,
                                            betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optim_G)
            # optim_D will optimize parameters of netD
            self.optim_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr_D,
                                            betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optim_D)
            # optim_D_aux will optimize parameters of auxiliaryDs
            if len(self.auxiliaryDs) > 0:
                aux_D_param_groups = [{
                    'params': net.parameters()
                } for net in self.auxiliaryDs.values()]
                self.optim_D_aux = torch.optim.Adam(aux_D_param_groups,
                                                    lr=opt.lr_D,
                                                    betas=(opt.beta1,
                                                           opt.beta2))
                self.optimizers.append(self.optim_D_aux)
            for optim in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optim, opt))

        # color transformation from std to imagenet
        # img_imagenet = img_std * a + b
        self.trans_std_to_imagenet = {
            'a': Variable(self.Tensor([0.5 / 0.229, 0.5 / 0.224, 0.5 / 0.225]),
                          requires_grad=False).view(3, 1, 1),
            'b': Variable(self.Tensor([(0.5 - 0.485) / 0.229,
                                       (0.5 - 0.456) / 0.224,
                                       (0.5 - 0.406) / 0.225]),
                          requires_grad=False).view(3, 1, 1)
        }
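
The a/b constants follow from composing the two normalisations: for x in [0, 1], img_std = (x - 0.5) / 0.5 and img_imagenet = (x - mean) / std, hence img_imagenet = img_std * (0.5 / std) + (0.5 - mean) / std, which is exactly 'a' and 'b' above. A quick, illustrative check:

# Illustrative sanity check of the a/b constants above (not part of the model).
import torch
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
x = torch.rand(3, 8, 8)                 # an image with values in [0, 1]
img_std = (x - 0.5) / 0.5
a, b = 0.5 / std, (0.5 - mean) / std
assert torch.allclose(img_std * a + b, (x - mean) / std, atol=1e-6)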
Example #30
0
    def initialize(self, opt):
        super(TwoStagePoseTransferModel, self).initialize(opt)
        ###################################
        # load pretrained stage-1 (coarse) network
        ###################################
        self._create_stage_1_net(opt)
        ###################################
        # define stage-2 (refine) network
        ###################################
        # local patch encoder
        if opt.which_model_s2e == 'patch_embed':
            self.netT_s2e = networks.LocalPatchEncoder(
                n_patch=len(opt.patch_indices),
                input_nc=3,
                output_nc=opt.s2e_nof,
                nf=opt.s2e_nf,
                max_nf=opt.s2e_max_nf,
                input_size=opt.patch_size,
                bottleneck_factor=opt.s2e_bottleneck_factor,
                n_residual_blocks=2,
                norm_layer=networks.get_norm_layer(opt.norm),
                activation=nn.ReLU(False),
                use_dropout=False,
                gpu_ids=opt.gpu_ids,
            )
            s2e_nof = opt.s2e_nof
        elif opt.which_model_s2e == 'patch':
            self.netT_s2e = networks.LocalPatchRearranger(
                n_patch=len(opt.patch_indices),
                image_size=opt.fine_size,
            )
            s2e_nof = 3
        elif opt.which_model_s2e == 'seg_embed':
            self.netT_s2e = networks.SegmentRegionEncoder(
                seg_nc=self.opt.seg_nc,
                input_nc=3,
                output_nc=opt.s2e_nof,
                nf=opt.s2d_nf,
                input_size=opt.fine_size,
                n_blocks=3,
                norm_layer=networks.get_norm_layer(opt.norm),
                activation=nn.ReLU,
                use_dropout=False,
                grid_level=opt.s2e_grid_level,
                gpu_ids=opt.gpu_ids,
            )
            s2e_nof = opt.s2e_nof + opt.s2e_grid_level
        else:
            raise NotImplementedError()
        if opt.gpu_ids:
            self.netT_s2e.cuda()

        # decoder
        if self.opt.which_model_s2d == 'resnet':
            self.netT_s2d = networks.ResnetGenerator(
                input_nc=3 + s2e_nof,
                output_nc=3,
                ngf=opt.s2d_nf,
                norm_layer=networks.get_norm_layer(opt.norm),
                activation=nn.ReLU,
                use_dropout=False,
                n_blocks=opt.s2d_nblocks,
                gpu_ids=opt.gpu_ids,
                output_tanh=False,
            )
        elif self.opt.which_model_s2d == 'unet':
            self.netT_s2d = networks.UnetGenerator_v2(
                input_nc=3 + s2e_nof,
                output_nc=3,
                num_downs=8,
                ngf=opt.s2d_nf,
                max_nf=opt.s2d_nf * 2**3,
                norm_layer=networks.get_norm_layer(opt.norm),
                use_dropout=False,
                gpu_ids=opt.gpu_ids,
                output_tanh=False,
            )
        elif self.opt.which_model_s2d == 'rpresnet':
            self.netT_s2d = networks.RegionPropagationResnetGenerator(
                input_nc=3 + s2e_nof,
                output_nc=3,
                ngf=opt.s2d_nf,
                norm_layer=networks.get_norm_layer(opt.norm),
                activation=nn.ReLU,
                use_dropout=False,
                nblocks=opt.s2d_nblocks,
                gpu_ids=opt.gpu_ids,
                output_tanh=False)
        else:
            raise NotImplementedError()
        if opt.gpu_ids:
            self.netT_s2d.cuda()
        ###################################
        # define discriminator
        ###################################
        self.use_GAN = self.is_train and opt.loss_weight_gan > 0
        if self.use_GAN:
            self.netD = networks.define_D_from_params(
                input_nc=((3 + self.get_pose_dim(opt.pose_type))
                          if opt.D_cond else 3),
                ndf=opt.D_nf,
                which_model_netD='n_layers',
                n_layers_D=opt.D_n_layer,
                norm=opt.norm,
                which_gan=opt.which_gan,
                init_type=opt.init_type,
                gpu_ids=opt.gpu_ids)
        else:
            self.netD = None
        ###################################
        # loss functions
        ###################################
        self.crit_psnr = networks.PSNR()
        self.crit_ssim = networks.SSIM()

        if self.is_train:
            self.optimizers = []
            self.crit_vgg = networks.VGGLoss_v2(self.gpu_ids,
                                                opt.content_layer_weight,
                                                opt.style_layer_weight,
                                                opt.shifted_style)

            self.optim = torch.optim.Adam([{
                'params': self.netT_s2e.parameters()
            }, {
                'params': self.netT_s2d.parameters()
            }],
                                          lr=opt.lr,
                                          betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optim)

            if opt.train_s1:
                self.optim_s1 = torch.optim.Adam(self.netT_s1.parameters(),
                                                 lr=opt.lr_s1,
                                                 betas=(opt.beta1, opt.beta2))
                self.optimizers.append(self.optim_s1)

            if self.use_GAN:
                self.crit_GAN = networks.GANLoss(
                    use_lsgan=opt.which_gan == 'lsgan', tensor=self.Tensor)
                self.optim_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr_D,
                                                betas=(opt.beta1, opt.beta2))
                self.optimizers.append(self.optim_D)
                self.fake_pool = ImagePool(opt.pool_size)
        ###################################
        # init/load model
        ###################################
        if self.is_train:
            if not opt.continue_train:
                self.load_network(self.netT_s1, 'netT', 'latest',
                                  self.opt_s1.id)
                networks.init_weights(self.netT_s2e, init_type=opt.init_type)
                networks.init_weights(self.netT_s2d, init_type=opt.init_type)
                if self.use_GAN:
                    networks.init_weights(self.netD, init_type=opt.init_type)
            else:
                self.load_network(self.netT_s1, 'netT_s1', opt.which_epoch)
                self.load_network(self.netT_s2e, 'netT_s2e', opt.which_epoch)
                self.load_network(self.netT_s2d, 'netT_s2d', opt.which_epoch)
                self.load_optim(self.optim, 'optim', opt.which_epoch)
                if self.use_GAN:
                    self.load_network(self.netD, 'netD', opt.which_epoch)
                    self.load_optim(self.optim_D, 'optim_D', opt.which_epoch)
        else:
            self.load_network(self.netT_s1, 'netT_s1', opt.which_epoch)
            self.load_network(self.netT_s2e, 'netT_s2e', opt.which_epoch)
            self.load_network(self.netT_s2d, 'netT_s2d', opt.which_epoch)
        ###################################
        # schedulers
        ###################################
        if self.is_train:
            self.schedulers = []
            for optim in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optim, opt))
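
Finally, networks.PSNR and networks.SSIM serve as evaluation criteria in several of the models above. A minimal PSNR module consistent with the 10 * log10(1 / MSE) computation seen in the pix2pix-style main() earlier; the [0, 1] input range is an assumption:

import torch
import torch.nn as nn
import torch.nn.functional as F

class PSNR(nn.Module):
    # Hedged sketch: PSNR for image tensors in [0, 1]; matches the
    # 10 * log10(1 / MSE) formula used in the training loop above.
    def forward(self, pred, target):
        mse = F.mse_loss(pred, target)
        return 10.0 * torch.log10(1.0 / mse)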