Example #1
    def initialize(self, opt):
        super(FeatureSpatialTransformer, self).initialize(opt)
        ###################################
        # load/define networks
        ###################################
        self.net = networks.define_feat_spatial_transformer(opt)
        self.netAE = None

        if opt.continue_train or not self.is_train:
            self.load_network(self.net, 'FeatST', epoch_label=opt.which_epoch)

        if self.is_train:
            ###################################
            # load attribute encoder
            ###################################
            self.netAE, self.opt_AE = load_attribute_encoder_net(id=opt.which_model_AE, gpu_ids=opt.gpu_ids)

            ###################################
            # define loss functions and loss buffers
            ###################################
            self.crit_L1 = networks.SmoothLoss(nn.L1Loss())
            self.crit_attr = networks.SmoothLoss(nn.BCELoss())

            self.loss_functions = []
            self.loss_functions.append(self.crit_L1)
            self.loss_functions.append(self.crit_attr)

            ###################################
            # create optimizers
            ###################################
            self.schedulers = []
            self.optimizers = []

            self.optim = torch.optim.Adam(self.net.parameters(),
                lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay)

            self.optimizers.append(self.optim)

            for optim in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optim, opt))

        # color transformation from std to imagenet
        # img_imagenet = img_std * a + b
        self.trans_std_to_imagenet = {
            'a': Variable(self.Tensor([0.5/0.229, 0.5/0.224, 0.5/0.225]), requires_grad=False).view(3, 1, 1),
            'b': Variable(self.Tensor([(0.5-0.485)/0.229, (0.5-0.456)/0.224, (0.5-0.406)/0.225]), requires_grad=False).view(3, 1, 1)
        }
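A quick way to see where the 'a'/'b' constants above come from: the model works on images normalized with mean 0.5 / std 0.5, while ImageNet-pretrained networks expect mean (0.485, 0.456, 0.406) / std (0.229, 0.224, 0.225). Substituting img = img_std * 0.5 + 0.5 into (img - mean) / std gives img_imagenet = img_std * (0.5 / std) + (0.5 - mean) / std, which is exactly the affine map stored in trans_std_to_imagenet. A minimal standalone sketch verifying this numerically (not part of the class):

import torch

mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
a = 0.5 / std                      # matches the 'a' entry above
b = (0.5 - mean) / std             # matches the 'b' entry above

img = torch.rand(3, 8, 8)          # image in [0, 1]
img_std = (img - 0.5) / 0.5        # "std" normalization used by the model
assert torch.allclose((img - mean) / std, img_std * a + b, atol=1e-6)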
Example #2
    def initialize(self, opt):
        super(EncoderDecoderFramework_V2, self).initialize(opt)
        ###################################
        # define encoder
        ###################################
        self.encoder = networks.define_encoder_v2(opt)
        ###################################
        # define decoder
        ###################################
        self.decoder = networks.define_decoder_v2(opt)
        ###################################
        # guide encoder
        ###################################
        if opt.use_guide_encoder:
            self.guide_encoder = networks.load_encoder_v2(
                opt, opt.which_model_guide)
            self.guide_encoder.eval()
            for p in self.guide_encoder.parameters():
                p.requires_grad = False
        ###################################
        # loss functions
        ###################################
        self.loss_functions = []
        self.schedulers = []
        self.crit_image = networks.SmoothLoss(nn.L1Loss())
        self.crit_seg = networks.SmoothLoss(nn.CrossEntropyLoss())
        self.crit_edge = networks.SmoothLoss(nn.BCELoss())
        self.loss_functions += [self.crit_image, self.crit_seg, self.crit_edge]

        self.optim = torch.optim.Adam([{
            'params': self.encoder.parameters()
        }, {
            'params': self.decoder.parameters()
        }],
                                      lr=opt.lr,
                                      betas=(opt.beta1, opt.beta2))
        self.optimizers = [self.optim]
        for optim in self.optimizers:
            self.schedulers.append(networks.get_scheduler(optim, opt))
Example #3
    def initialize(self, opt):
        super(MultimodalDesignerGAN_V2, self).initialize(opt)
        ###################################
        # define networks
        ###################################
        self.modules = {}
        # shape branch
        if opt.which_model_netG != 'unet':
            self.shape_encoder = networks.define_image_encoder(opt, 'shape')
            self.modules['shape_encoder'] = self.shape_encoder
        else:
            self.shape_encoder = None
        # edge branch
        if opt.use_edge:
            self.edge_encoder = networks.define_image_encoder(opt, 'edge')
            self.modules['edge_encoder'] = self.edge_encoder
        else:
            self.edge_encoder = None
        # color branch
        if opt.use_color:
            self.color_encoder = networks.define_image_encoder(opt, 'color')
            self.modules['color_encoder'] = self.color_encoder
        else:
            self.color_encoder = None

        # fusion model
        if opt.ftn_model == 'none':
            # shape_feat, edge_feat and color_feat will simply be upsampled to the same size (that of shape_feat) and concatenated
            pass
        elif opt.ftn_model == 'concat':
            assert opt.use_edge or opt.use_color
            if opt.use_edge:
                self.edge_trans_net = networks.define_feature_fusion_network(
                    name='FeatureConcatNetwork',
                    feat_nc=opt.edge_nof,
                    guide_nc=opt.shape_nof,
                    nblocks=opt.ftn_nblocks,
                    norm=opt.norm,
                    gpu_ids=self.gpu_ids,
                    init_type=opt.init_type)
                self.modules['edge_trans_net'] = self.edge_trans_net
            if opt.use_color:
                self.color_trans_net = networks.define_feature_fusion_network(
                    name='FeatureConcatNetwork',
                    feat_nc=opt.color_nof,
                    guide_nc=opt.shape_nof,
                    nblocks=opt.ftn_nblocks,
                    norm=opt.norm,
                    gpu_ids=self.gpu_ids,
                    init_type=opt.init_type)
                self.modules['color_trans_net'] = self.color_trans_net
        elif opt.ftn_model == 'reduce':
            assert opt.use_edge or opt.use_color
            if opt.use_edge:
                self.edge_trans_net = networks.define_feature_fusion_network(
                    name='FeatureReduceNetwork',
                    feat_nc=opt.edge_nof,
                    guide_nc=opt.shape_nof,
                    nblocks=opt.ftn_nblocks,
                    ndowns=opt.ftn_ndowns,
                    norm=opt.norm,
                    gpu_ids=self.gpu_ids,
                    init_type=opt.init_type)
                self.modules['edge_trans_net'] = self.edge_trans_net
            if opt.use_color:
                self.color_trans_net = networks.define_feature_fusion_network(
                    name='FeatureReduceNetwork',
                    feat_nc=opt.color_nof,
                    guide_nc=opt.shape_nof,
                    nblocks=opt.ftn_nblocks,
                    ndowns=opt.ftn_ndowns,
                    norm=opt.norm,
                    gpu_ids=self.gpu_ids,
                    init_type=opt.init_type)
                self.modules['color_trans_net'] = self.color_trans_net

        elif opt.ftn_model == 'trans':
            assert opt.use_edge or opt.use_color
            if opt.use_edge:
                self.edge_trans_net = networks.define_feature_fusion_network(
                    name='FeatureTransformNetwork',
                    feat_nc=opt.edge_nof,
                    guide_nc=opt.shape_nof,
                    nblocks=opt.ftn_nblocks,
                    feat_size=opt.feat_size_lr,
                    norm=opt.norm,
                    gpu_ids=self.gpu_ids,
                    init_type=opt.init_type)
                self.modules['edge_trans_net'] = self.edge_trans_net
            if opt.use_color:
                self.color_trans_net = networks.define_feature_fusion_network(
                    name='FeatureTransformNetwork',
                    feat_nc=opt.color_nof,
                    guide_nc=opt.shape_nof,
                    nblocks=opt.ftn_nblocks,
                    feat_size=opt.feat_size_lr,
                    norm=opt.norm,
                    gpu_ids=self.gpu_ids,
                    init_type=opt.init_type)
                self.modules['color_trans_net'] = self.color_trans_net

        # netG
        self.netG = networks.define_generator(opt)
        self.modules['netG'] = self.netG

        # netD
        if self.is_train:
            self.netD = networks.define_D(opt)
            self.modules['netD'] = self.netD

        ###################################
        # load weights
        ###################################
        if self.is_train:
            if opt.continue_train:
                for label, net in self.modules.items():
                    self.load_network(net, label, opt.which_epoch)
            else:
                if opt.which_model_init != 'none':
                    # load pretrained entire model
                    for label, net in self.modules.items():
                        self.load_network(net,
                                          label,
                                          'latest',
                                          opt.which_model_init,
                                          forced=False)
                else:
                    # load pretrained encoder
                    if opt.which_model_netG != 'unet' and opt.pretrain_shape:
                        self.load_network(self.shape_encoder, 'shape_encoder',
                                          'latest',
                                          opt.which_model_init_shape_encoder)
                    if opt.use_edge and opt.pretrain_edge:
                        self.load_network(self.edge_encoder, 'edge_encoder',
                                          'latest',
                                          opt.which_model_init_edge_encoder)
                    if opt.use_color and opt.pretrain_color:
                        self.load_network(self.color_encoder, 'color_encoder',
                                          'latest',
                                          opt.which_model_init_color_encoder)
        else:
            for label, net in self.modules.items():
                if label != 'netD':
                    self.load_network(net, label, opt.which_epoch)

        ###################################
        # prepare for training
        ###################################
        if self.is_train:
            self.fake_pool = ImagePool(opt.pool_size)
            ###################################
            # define loss functions
            ###################################
            self.loss_functions = []
            if opt.which_gan in {'dcgan', 'lsgan'}:
                self.crit_GAN = networks.GANLoss(
                    use_lsgan=opt.which_gan == 'lsgan', tensor=self.Tensor)
                self.loss_functions.append(self.crit_GAN)
            else:
                # WGAN loss will be calculated in self.backward_D_wgangp and self.backward_G
                self.crit_GAN = None

            self.crit_L1 = nn.L1Loss()
            self.loss_functions.append(self.crit_L1)

            if self.opt.loss_weight_vgg > 0:
                self.crit_vgg = networks.VGGLoss(self.gpu_ids)
                self.loss_functions.append(self.crit_vgg)

            if self.opt.G_output_seg:
                self.crit_CE = nn.CrossEntropyLoss()
                self.loss_functions.append(self.crit_CE)

            self.crit_psnr = networks.SmoothLoss(networks.PSNR())
            self.loss_functions.append(self.crit_psnr)
            ###################################
            # create optimizers
            ###################################
            self.schedulers = []
            self.optimizers = []

            # G optimizer
            G_module_list = [
                'shape_encoder', 'edge_encoder', 'color_encoder', 'netG'
            ]
            G_param_groups = [{
                'params': self.modules[m].parameters()
            } for m in G_module_list if m in self.modules]
            self.optim_G = torch.optim.Adam(G_param_groups,
                                            lr=opt.lr,
                                            betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optim_G)
            # D optimizer
            self.optim_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr_D,
                                            betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optim_D)
            # feature transfer network optimizer
            FTN_module_list = ['edge_trans_net', 'color_trans_net']
            FTN_param_groups = [{
                'params': self.modules[m].parameters()
            } for m in FTN_module_list if m in self.modules]
            if len(FTN_param_groups) > 0:
                self.optim_FTN = torch.optim.Adam(FTN_param_groups,
                                                  lr=opt.lr_FTN,
                                                  betas=(0.9, 0.999))
                self.optimizers.append(self.optim_FTN)
            else:
                self.optim_FTN = None
            # schedulers
            for optim in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optim, opt))
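When opt.which_gan is neither 'dcgan' nor 'lsgan', crit_GAN is left as None and, per the comment above, the adversarial term is computed in self.backward_D_wgangp and self.backward_G, which are not part of this listing. For orientation, a hedged sketch of what a WGAN-GP critic update typically looks like (the function names and signatures here are illustrative assumptions, not the project's actual methods):

import torch
from torch import autograd

def gradient_penalty(netD, real, fake, lambda_gp=10.0):
    # Penalize deviation of the critic's gradient norm from 1 on interpolated samples.
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    d_out = netD(interp)
    grads = autograd.grad(outputs=d_out, inputs=interp,
                          grad_outputs=torch.ones_like(d_out),
                          create_graph=True, retain_graph=True)[0]
    grads = grads.view(grads.size(0), -1)
    return lambda_gp * ((grads.norm(2, dim=1) - 1) ** 2).mean()

def backward_D_wgangp_sketch(netD, real, fake):
    # Critic loss: E[D(fake)] - E[D(real)] + gradient penalty.
    loss_D = netD(fake.detach()).mean() - netD(real).mean() \
             + gradient_penalty(netD, real, fake.detach())
    loss_D.backward()
    return loss_D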
Example #4
    def initialize(self, opt):
        super(EncoderDecoderFramework, self).initialize(opt)
        ###################################
        # load/define networks
        ###################################
        if opt.use_shape:
            self.encoder_type = 'shape'
            self.encoder_name = 'shape_encoder'
            self.decoder_name = 'decoder'
        elif opt.use_edge:
            self.encoder_type = 'edge'
            self.encoder_name = 'edge_encoder'
            self.decoder_name = 'decoder'
        elif opt.use_color:
            self.encoder_type = 'color'
            self.encoder_name = 'color_encoder'
            self.decoder_name = 'decoder'
        else:
            raise ValueError(
                'at least one of use_shape, use_edge or use_color must be set')

        # encoder
        self.encoder = networks.define_image_encoder(opt, self.encoder_type)

        # decoder
        if self.encoder_type == 'shape':
            ndowns = opt.shape_ndowns
            nf = opt.shape_nf
            nof = opt.shape_nof
            output_nc = 7
            output_activation = None
            assert not opt.decode_guided
        elif self.encoder_type == 'edge':
            ndowns = opt.edge_ndowns
            nf = opt.edge_nf
            nof = opt.edge_nof
            output_nc = 1
            output_activation = None
        elif self.encoder_type == 'color':
            ndowns = opt.color_ndowns
            nf = opt.color_nf
            nof = opt.color_nof
            output_nc = 3
            output_activation = nn.Tanh

        if opt.encoder_type in {'normal', 'st'}:
            self.feat_size = 256 // 2**(opt.edge_ndowns)
            self.mid_feat_size = self.feat_size
        else:
            self.feat_size = 1
            self.mid_feat_size = 8

        self.use_concat_net = False
        if opt.decode_guided:
            if self.feat_size > 1:
                self.decoder = networks.define_image_decoder_from_params(
                    input_nc=nof + opt.shape_nc,
                    output_nc=output_nc,
                    nf=nf,
                    num_ups=ndowns,
                    norm=opt.norm,
                    output_activation=output_activation,
                    gpu_ids=opt.gpu_ids,
                    init_type=opt.init_type)
            else:
                self.decoder = networks.define_image_decoder_from_params(
                    input_nc=nof,
                    output_nc=output_nc,
                    nf=nf,
                    num_ups=5,
                    norm=opt.norm,
                    output_activation=output_activation,
                    gpu_ids=opt.gpu_ids,
                    init_type=opt.init_type)
                self.concat_net = networks.FeatureConcatNetwork(
                    feat_nc=nof,
                    guide_nc=opt.shape_nc,
                    nblocks=3,
                    norm=opt.norm,
                    gpu_ids=opt.gpu_ids)
                if len(self.gpu_ids) > 0:
                    self.concat_net.cuda()
                networks.init_weights(self.concat_net, opt.init_type)
                self.use_concat_net = True
                print('encoder_decoder contains a feature_concat_network!')
        else:
            if self.feat_size > 1:
                self.decoder = networks.define_image_decoder_from_params(
                    input_nc=nof,
                    output_nc=output_nc,
                    nf=nf,
                    num_ups=ndowns,
                    norm=opt.norm,
                    output_activation=output_activation,
                    gpu_ids=opt.gpu_ids,
                    init_type=opt.init_type)
            else:
                self.decoder = networks.define_image_decoder_from_params(
                    input_nc=nof,
                    output_nc=output_nc,
                    nf=nf,
                    num_ups=8,
                    norm=opt.norm,
                    output_activation=output_activation,
                    gpu_ids=opt.gpu_ids,
                    init_type=opt.init_type)

        if not self.is_train or self.opt.continue_train:
            self.load_network(self.encoder, self.encoder_name, opt.which_epoch)
            self.load_network(self.decoder, self.decoder_name, opt.which_epoch)
            if self.use_concat_net:
                self.load_network(self.concat_net, 'concat_net',
                                  opt.which_epoch)

        # loss functions
        self.loss_functions = []
        self.schedulers = []
        self.crit_L1 = networks.SmoothLoss(nn.L1Loss())
        self.crit_CE = networks.SmoothLoss(nn.CrossEntropyLoss())
        self.loss_functions += [self.crit_L1, self.crit_CE]

        self.optim = torch.optim.Adam([{
            'params': self.encoder.parameters()
        }, {
            'params': self.decoder.parameters()
        }],
                                      lr=opt.lr,
                                      betas=(opt.beta1, opt.beta2))
        self.optimizers = [self.optim]
        for optim in self.optimizers:
            self.schedulers.append(networks.get_scheduler(optim, opt))
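The hard-coded 256 above implies 256x256 inputs; each downsampling layer halves the spatial resolution, so the bottleneck feature map is 256 / 2**ndowns on a side. A tiny worked example of that arithmetic (values chosen for illustration):

for ndowns in (5, 6, 7, 8):
    print(ndowns, 256 // 2 ** ndowns)
# 5 -> 8, 6 -> 4, 7 -> 2, 8 -> 1: with a 1x1 bottleneck the plain decoder above needs
# num_ups=8 to get back to 256, while the guided branch decodes from the 8x8
# mid_feat_size with num_ups=5 (8 * 2**5 = 256).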
Example #5
    def initialize(self, opt):
        super(AttributeEncoder, self).initialize(opt)

        # define tensors
        self.input['img'] = self.Tensor(opt.batch_size, opt.input_nc,
                                        opt.fine_size, opt.fine_size)
        self.input['label'] = self.Tensor(opt.batch_size, opt.n_attr)

        # load/define networks
        self.net = networks.define_attr_encoder_net(opt)

        if not self.is_train or opt.continue_train:
            self.load_network(self.net,
                              network_label='AE',
                              epoch_label=opt.which_epoch)

        self.schedulers = []
        self.optimizers = []
        self.loss_functions = []

        # define loss functions
        # attribute
        if opt.loss_type == 'bce':
            self.crit_attr = networks.SmoothLoss(
                torch.nn.BCELoss(size_average=not opt.no_size_avg))
        elif opt.loss_type == 'wbce':
            attr_entry = io.load_json(os.path.join(opt.data_root,
                                                   opt.fn_entry))
            pos_rate = self.Tensor([att['pos_rate'] for att in attr_entry])
            pos_rate.clamp_(min=0.01, max=0.99)
            self.crit_attr = networks.SmoothLoss(
                networks.WeightedBCELoss(pos_rate=pos_rate,
                                         class_norm=opt.wbce_class_norm,
                                         size_average=not opt.no_size_avg))
        else:
            raise ValueError('attribute loss type "%s" is not defined' %
                             opt.loss_type)
        self.loss_functions.append(self.crit_attr)

        # joint task
        if opt.joint_cat:
            self.crit_cat = networks.SmoothLoss(torch.nn.CrossEntropyLoss())
            self.loss_functions.append(self.crit_cat)

        # initialize optimizers
        if opt.is_train:
            if opt.optim == 'adam':
                self.optim_attr = torch.optim.Adam(
                    self.net.parameters(),
                    lr=opt.lr,
                    betas=(opt.beta1, 0.999),
                    weight_decay=opt.weight_decay)
            elif opt.optim == 'sgd':
                self.optim_attr = torch.optim.SGD(
                    self.net.parameters(),
                    lr=opt.lr,
                    momentum=0.9,
                    weight_decay=opt.weight_decay)
            self.optimizers.append(self.optim_attr)

            for optim in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optim, opt))
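The 'wbce' branch above loads per-attribute positive rates, clamps them to [0.01, 0.99] and hands them to networks.WeightedBCELoss. A hedged sketch of what such a pos_rate-weighted BCE typically computes (an illustrative assumption; the project's WeightedBCELoss and its class_norm option may differ in detail):

import torch

def weighted_bce_sketch(pred, target, pos_rate, class_norm=True, eps=1e-8):
    # Rare positive labels get weight ~1/pos_rate, negatives ~1/(1 - pos_rate).
    pos_rate = pos_rate.clamp(min=0.01, max=0.99)
    w_pos, w_neg = 1.0 / pos_rate, 1.0 / (1.0 - pos_rate)
    if class_norm:
        # Rescale so each attribute contributes a comparable total weight.
        norm = (w_pos + w_neg) / 2.0
        w_pos, w_neg = w_pos / norm, w_neg / norm
    loss = -(w_pos * target * torch.log(pred + eps)
             + w_neg * (1 - target) * torch.log(1 - pred + eps))
    return loss.mean()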
Example #6
    def initialize(self, opt):
        super(MultimodalDesignerGAN, self).initialize(opt)
        ###################################
        # load/define networks
        ###################################

        # basic G
        self.netG = networks.define_G(opt)

        # encoders
        self.encoders = {}
        if opt.use_edge:
            self.edge_encoder = networks.define_image_encoder(opt, 'edge')
            self.encoders['edge_encoder'] = self.edge_encoder
        if opt.use_color:
            self.color_encoder = networks.define_image_encoder(opt, 'color')
            self.encoders['color_encoder'] = self.color_encoder
        if opt.use_attr:
            self.attr_encoder, self.opt_AE = network_loader.load_attribute_encoder_net(
                id=opt.which_model_AE, gpu_ids=opt.gpu_ids)

        # basic D and auxiliary Ds
        if self.is_train:
            # basic D
            self.netD = networks.define_D(opt)
            # auxiliary Ds
            self.auxiliaryDs = {}
            if opt.use_edge_D:
                assert opt.use_edge
                self.netD_edge = networks.define_D_from_params(
                    input_nc=opt.edge_nof + 3,
                    ndf=opt.ndf,
                    which_model_netD=opt.which_model_netD,
                    n_layers_D=opt.n_layers_D,
                    norm=opt.norm,
                    which_gan='dcgan',
                    init_type=opt.init_type,
                    gpu_ids=opt.gpu_ids)
                self.auxiliaryDs['D_edge'] = self.netD_edge
            if opt.use_color_D:
                assert opt.use_color
                self.netD_color = networks.define_D_from_params(
                    input_nc=opt.color_nof + 3,
                    ndf=opt.ndf,
                    which_model_netD=opt.which_model_netD,
                    n_layers_D=opt.n_layers_D,
                    norm=opt.norm,
                    which_gan='dcgan',
                    init_type=opt.init_type,
                    gpu_ids=opt.gpu_ids)
                self.auxiliaryDs['D_color'] = self.netD_color
            if opt.use_attr_D:
                assert opt.use_attr
                attr_nof = opt.n_attr_feat if opt.attr_cond_type in {
                    'feat', 'feat_map'
                } else opt.n_attr
                self.netD_attr = networks.define_D_from_params(
                    input_nc=attr_nof + 3,
                    ndf=opt.ndf,
                    which_model_netD=opt.which_model_netD,
                    n_layers_D=opt.n_layers_D,
                    norm=opt.norm,
                    which_gan='dcgan',
                    init_type=opt.init_type,
                    gpu_ids=opt.gpu_ids)
                self.auxiliaryDs['D_attr'] = self.netD_attr
            # load weights
            if not opt.continue_train:
                if opt.which_model_init != 'none':
                    self.load_network(self.netG, 'G', 'latest',
                                      opt.which_model_init)
                    self.load_network(self.netD, 'D', 'latest',
                                      opt.which_model_init)
                    for l, net in self.encoders.items():
                        self.load_network(net, l, 'latest',
                                          opt.which_model_init)
                    for l, net in self.auxiliaryDs.items():
                        self.load_network(net, l, 'latest',
                                          opt.which_model_init)
            else:
                self.load_network(self.netG, 'G', opt.which_epoch)
                self.load_network(self.netD, 'D', opt.which_epoch)
                for l, net in self.encoders.items():
                    self.load_network(net, l, opt.which_epoch)
                for l, net in self.auxiliaryDs.items():
                    self.load_network(net, l, opt.which_epoch)
        else:
            self.load_network(self.netG, 'G', opt.which_epoch)
            for l, net in self.encoders.items():
                self.load_network(net, l, opt.which_epoch)

        if self.is_train:
            self.fake_pool = ImagePool(opt.pool_size)
            ###################################
            # define loss functions and loss buffers
            ###################################
            self.loss_functions = []
            if opt.which_gan in {'dcgan', 'lsgan'}:
                self.crit_GAN = networks.GANLoss(
                    use_lsgan=opt.which_gan == 'lsgan', tensor=self.Tensor)
                self.loss_functions.append(self.crit_GAN)
            else:
                # WGAN loss will be calculated in self.backward_D_wgangp and self.backward_G
                self.crit_GAN = None

            self.crit_L1 = nn.L1Loss()
            self.loss_functions.append(self.crit_L1)

            if self.opt.loss_weight_vgg > 0:
                self.crit_vgg = networks.VGGLoss(self.gpu_ids)
                self.loss_functions.append(self.crit_vgg)

            self.crit_psnr = networks.SmoothLoss(networks.PSNR())
            self.loss_functions.append(self.crit_psnr)
            ###################################
            # create optimizers
            ###################################
            self.schedulers = []
            self.optimizers = []

            # optim_G will optimize parameters of netG and all image encoders (except attr_encoder)
            G_param_groups = [{'params': self.netG.parameters()}]
            for l, net in self.encoders.items():
                G_param_groups.append({'params': net.parameters()})
            self.optim_G = torch.optim.Adam(G_param_groups,
                                            lr=opt.lr,
                                            betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optim_G)
            # optim_D will optimize parameters of netD
            self.optim_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=opt.lr_D,
                                            betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optim_D)
            # optim_D_aux will optimize parameters of auxiliaryDs
            if len(self.auxiliaryDs) > 0:
                aux_D_param_groups = [{
                    'params': net.parameters()
                } for net in self.auxiliaryDs.values()]
                self.optim_D_aux = torch.optim.Adam(aux_D_param_groups,
                                                    lr=opt.lr_D,
                                                    betas=(opt.beta1,
                                                           opt.beta2))
                self.optimizers.append(self.optim_D_aux)
            for optim in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optim, opt))

        # color transformation from std to imagenet
        # img_imagenet = img_std * a + b
        self.trans_std_to_imagenet = {
            'a':
            Variable(self.Tensor([0.5 / 0.229, 0.5 / 0.224, 0.5 / 0.225]),
                     requires_grad=False).view(3, 1, 1),
            'b':
            Variable(self.Tensor([(0.5 - 0.485) / 0.229, (0.5 - 0.456) / 0.224,
                                  (0.5 - 0.406) / 0.225]),
                     requires_grad=False).view(3, 1, 1)
        }
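Both GAN examples create self.fake_pool = ImagePool(opt.pool_size) before training. A minimal sketch of the image-buffer idea behind such a pool, modeled on the common pix2pix/CycleGAN helper (an assumption; this project's ImagePool may differ):

import random
import torch

class ImagePoolSketch:
    """Buffer of previously generated images; with probability 0.5 the discriminator
    is fed an old fake instead of the newest one, which reduces oscillation."""

    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return images
        out = []
        for img in images:
            img = img.unsqueeze(0)
            if len(self.images) < self.pool_size:
                # Buffer not full yet: store and return the new fake.
                self.images.append(img)
                out.append(img)
            elif random.random() > 0.5:
                # Swap the new fake for a randomly chosen old one.
                idx = random.randrange(len(self.images))
                out.append(self.images[idx].clone())
                self.images[idx] = img
            else:
                out.append(img)
        return torch.cat(out, 0)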