Example #1
0
    def __init__(self, hyperparameters):
        """Set up MUNIT networks, fixed sampling noise, Adam optimizers and step-decay schedulers."""
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        gen_cfg = hyperparameters['gen']
        dis_cfg = hyperparameters['dis']

        # One AdaIN auto-encoder and one multi-scale discriminator per domain.
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], gen_cfg)
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], gen_cfg)
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], dis_cfg)
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], dis_cfg)
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = gen_cfg['style_dim']

        # Style codes drawn once so display samples stay comparable across training.
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Adam over trainable parameters only; betas are fixed constants here.
        beta1, beta2 = 0.5, 0.999
        weight_decay = hyperparameters['weight_decay']
        dis_params = [p for net in (self.dis_a, self.dis_b)
                      for p in net.parameters() if p.requires_grad]
        gen_params = [p for net in (self.gen_a, self.gen_b)
                      for p in net.parameters() if p.requires_grad]
        self.dis_opt = torch.optim.Adam(dis_params, lr=lr, betas=(beta1, beta2),
                                        weight_decay=weight_decay)
        self.gen_opt = torch.optim.Adam(gen_params, lr=lr, betas=(beta1, beta2),
                                        weight_decay=weight_decay)

        # Step-decay learning-rate schedules for both optimizers.
        step_size = hyperparameters['step_size']
        gamma = hyperparameters['gamma']
        self.dis_scheduler = lr_scheduler.StepLR(self.dis_opt, step_size=step_size,
                                                 gamma=gamma, last_epoch=-1)
        self.gen_scheduler = lr_scheduler.StepLR(self.gen_opt, step_size=step_size,
                                                 gamma=gamma, last_epoch=-1)

        # Weight init: configured scheme for this module, gaussian for discriminators.
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
Example #2
0
    def __init__(self, hyperparameters):
        """Initialize the MUNIT trainer: networks, loss criteria, fixed style noise, optimizers."""
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        gen_cfg = hyperparameters['gen']
        dis_cfg = hyperparameters['dis']

        # Per-domain auto-encoders and multi-scale discriminators.
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], gen_cfg)
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], gen_cfg)
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], dis_cfg)
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], dis_cfg)
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = gen_cfg['style_dim']

        # Adversarial and feature-matching criteria.
        self.criterionGAN = GANLoss(dis_cfg['gan_type']).cuda()
        self.featureLoss = nn.MSELoss(reduction='mean')

        # Style noise drawn once so displayed samples are comparable over time.
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Separate Adam optimizers for discriminators and generators.
        betas = (hyperparameters['beta1'], hyperparameters['beta2'])
        weight_decay = hyperparameters['weight_decay']
        dis_params = [p for net in (self.dis_a, self.dis_b)
                      for p in net.parameters() if p.requires_grad]
        gen_params = [p for net in (self.gen_a, self.gen_b)
                      for p in net.parameters() if p.requires_grad]
        self.dis_opt = torch.optim.Adam(dis_params, lr=lr, betas=betas,
                                        weight_decay=weight_decay)
        self.gen_opt = torch.optim.Adam(gen_params, lr=lr, betas=betas,
                                        weight_decay=weight_decay)
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Weight init: configured scheme for this module, gaussian for discriminators.
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
Example #3
0
    def __init__(self, hyperparameters):
        """Create the shared content encoder plus per-domain generators and discriminators."""
        super(DSMAP_Trainer, self).__init__()
        gen_cfg = hyperparameters['gen']
        dis_cfg = hyperparameters['dis']

        # Content encoder shared by both domains; mid_downsample defaults to 1.
        mid_downsample = gen_cfg.get('mid_downsample', 1)
        self.content_enc = ContentEncoder_share(
            gen_cfg['n_downsample'],
            mid_downsample,
            gen_cfg['n_res'],
            hyperparameters['input_dim_a'],
            gen_cfg['dim'],
            'in',
            gen_cfg['activ'],
            pad_type=gen_cfg['pad_type'])

        self.style_dim = gen_cfg['style_dim']

        # Domain-specific AdaIN decoders built on top of the shared encoder.
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'],
                              self.content_enc, 'a', gen_cfg)
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'],
                              self.content_enc, 'b', gen_cfg)

        # Multi-scale discriminators, given the encoder's output channel width.
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'],
                                self.content_enc.output_dim, dis_cfg)
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'],
                                self.content_enc.output_dim, dis_cfg)
Example #4
0
    def __init__(self, hyperparameters, device):
        """Build MUNIT networks on `device`, plus optimizers, schedulers and an optional frozen VGG."""
        super(MUNIT_Trainer, self).__init__()
        self.device = device
        lr = hyperparameters['lr']
        gen_cfg = hyperparameters['gen']
        dis_cfg = hyperparameters['dis']

        # Per-domain auto-encoders and multi-scale discriminators.
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], gen_cfg)
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], gen_cfg)
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], dis_cfg)
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], dis_cfg)
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = gen_cfg['style_dim']

        # Fixed style noise for display sampling, placed on the training device.
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).to(self.device)
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).to(self.device)

        # Adam optimizers over trainable parameters only.
        betas = (hyperparameters['beta1'], hyperparameters['beta2'])
        weight_decay = hyperparameters['weight_decay']
        dis_params = [p for net in (self.dis_a, self.dis_b)
                      for p in net.parameters() if p.requires_grad]
        gen_params = [p for net in (self.gen_a, self.gen_b)
                      for p in net.parameters() if p.requires_grad]
        self.dis_opt = torch.optim.Adam(dis_params, lr=lr, betas=betas,
                                        weight_decay=weight_decay)
        self.gen_opt = torch.optim.Adam(gen_params, lr=lr, betas=betas,
                                        weight_decay=weight_decay)
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Weight init: configured scheme for this module, gaussian for discriminators.
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Frozen VGG16 for the perceptual loss, only when that loss is enabled.
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False
Example #5
0
    def __init__(self, hyperparameters):
        """Set up ACL-GAN: two translators, three discriminators, fixed noise, optimizers."""
        super(aclgan_Trainer, self).__init__()
        lr = hyperparameters['lr']
        gen_cfg = hyperparameters['gen']
        dis_cfg = hyperparameters['dis']
        dim_a = hyperparameters['input_dim_a']

        # Both translators take domain-A-shaped inputs; dis_2 sees domain-B shapes.
        self.gen_AB = AdaINGen(dim_a, gen_cfg)
        self.gen_BA = AdaINGen(dim_a, gen_cfg)
        self.dis_A = MsImageDis(dim_a, dis_cfg)
        self.dis_B = MsImageDis(dim_a, dis_cfg)
        self.dis_2 = MsImageDis(hyperparameters['input_dim_b'], dis_cfg)
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = gen_cfg['style_dim']

        # Three noise codes drawn once and reused when rendering display samples.
        display_size = int(hyperparameters['display_size'])
        self.z_1 = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.z_2 = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.z_3 = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Adam optimizers over trainable parameters only.
        betas = (hyperparameters['beta1'], hyperparameters['beta2'])
        weight_decay = hyperparameters['weight_decay']
        dis_params = [p for net in (self.dis_A, self.dis_B, self.dis_2)
                      for p in net.parameters() if p.requires_grad]
        gen_params = [p for net in (self.gen_AB, self.gen_BA)
                      for p in net.parameters() if p.requires_grad]
        self.dis_opt = torch.optim.Adam(dis_params, lr=lr, betas=betas,
                                        weight_decay=weight_decay)
        self.gen_opt = torch.optim.Adam(gen_params, lr=lr, betas=betas,
                                        weight_decay=weight_decay)
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Loss-mixing hyperparameters.
        self.alpha = hyperparameters['alpha']
        self.focus_lam = hyperparameters['focus_loss']

        # Weight init: configured scheme for this module, gaussian for discriminators.
        self.apply(weights_init(hyperparameters['init']))
        self.dis_A.apply(weights_init('gaussian'))
        self.dis_B.apply(weights_init('gaussian'))
        self.dis_2.apply(weights_init('gaussian'))

        # Frozen VGG16 for the perceptual loss when enabled.
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False
Example #6
0
    def __init__(self, hyperparameters):
        """Construct MUNIT networks, fixed sampling noise, optimizers and LR schedulers."""
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks: one AdaIN auto-encoder and one multi-scale
        # discriminator per domain.
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        # Instance normalization without learnable affine parameters.
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        # Fix the noise used in sampling: s_a / s_b are two style codes of shape
        # (display_size, style_dim, 1, 1) drawn once, so display samples stay
        # comparable across training.
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers.
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        # Parameters of the two discriminators.
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        # Parameters of the two generators.
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        # Adam over trainable parameters only.
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        # Learning-rate schedules.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization: configured scheme for this module,
        # gaussian init for the discriminators.
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load a frozen VGG model if the perceptual loss is enabled.
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False
Example #7
0
 def __init__(self, hyperparameters):
     """Create MUNIT generators and discriminators plus shared instance normalization."""
     super(MUNIT, self).__init__()
     # NOTE(review): lr is read here but not used in this visible section;
     # kept so a missing 'lr' key still fails fast.
     lr = hyperparameters['lr']
     gen_cfg = hyperparameters['gen']
     dis_cfg = hyperparameters['dis']

     # One AdaIN auto-encoder and one multi-scale discriminator per domain.
     self.gen_a = AdaINGen(hyperparameters['input_dim_a'], gen_cfg)
     self.gen_b = AdaINGen(hyperparameters['input_dim_b'], gen_cfg)
     self.dis_a = MsImageDis(hyperparameters['input_dim_a'], dis_cfg)
     self.dis_b = MsImageDis(hyperparameters['input_dim_b'], dis_cfg)
     self.instancenorm = nn.InstanceNorm2d(512, affine=False)
     self.style_dim = gen_cfg['style_dim']
Example #8
0
    def build_model(self):
        """Instantiate generator and discriminator with Adam optimizers, then move them to the device."""
        self.G = AdaINGen(self.label_dim, self.gen_var)
        self.D = Discriminator(self.label_dim, self.dis_var)

        # Both optimizers share the same learning rate and beta pair.
        betas = [self.beta1, self.beta2]
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.lr, betas)
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.lr, betas)

        self.G.to(self.device)
        self.D.to(self.device)
Example #9
0
    def __init__(self, hyperparameters, opts):
        """Build MUNIT networks plus a segmentor, with their optimizers, schedulers and losses."""
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        self.opts = opts
        gen_cfg = hyperparameters['gen']
        dis_cfg = hyperparameters['dis']

        # Translation networks and a 2-class segmentor on domain-B channels.
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], gen_cfg)
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], gen_cfg)
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], dis_cfg)
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], dis_cfg)
        self.seg = segmentor(num_classes=2, channels=hyperparameters['input_dim_b'],
                             hyperpars=hyperparameters['seg'])

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = gen_cfg['style_dim']

        # Fixed style noise for display sampling.
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Adam for the GAN parts, SGD with momentum for the segmentor.
        betas = (hyperparameters['beta1'], hyperparameters['beta2'])
        weight_decay = hyperparameters['weight_decay']
        dis_params = [p for net in (self.dis_a, self.dis_b)
                      for p in net.parameters() if p.requires_grad]
        gen_params = [p for net in (self.gen_a, self.gen_b)
                      for p in net.parameters() if p.requires_grad]
        self.dis_opt = torch.optim.Adam(dis_params, lr=lr, betas=betas,
                                        weight_decay=weight_decay)
        self.gen_opt = torch.optim.Adam(gen_params, lr=lr, betas=betas,
                                        weight_decay=weight_decay)
        self.seg_opt = torch.optim.SGD(self.seg.parameters(), lr=0.01,
                                       momentum=0.9, weight_decay=1e-4)

        # GAN schedules follow the configured policy; the segmentor LR stays constant.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters['lr_policy'],
                                           hyperparameters=hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters['lr_policy'],
                                           hyperparameters=hyperparameters)
        self.seg_scheduler = get_scheduler(self.seg_opt, 'constant', hyperparameters=None)

        # Weight init: configured scheme for this module, gaussian for discriminators.
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Dice loss for segmentation, skipping the configured ignore label.
        self.criterion_seg = DiceLoss(ignore_index=hyperparameters['seg']['ignore_index'])
Example #10
0
    def __init__(self, hyperparameters):
        """Single-domain MUNIT setup: generator and discriminator for domain b only."""
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        gen_cfg = hyperparameters['gen']

        # Domain-b networks; the discriminator also receives the image size.
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], gen_cfg)
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'],
                                hyperparameters['new_size'],
                                hyperparameters['dis'])
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = gen_cfg['style_dim']

        # Regularization / GAN configuration pulled from the hyperparameters.
        self.reg_param = hyperparameters['reg_param']
        self.beta_step = hyperparameters['beta_step']
        self.target_kl = hyperparameters['target_kl']
        self.gan_type = hyperparameters['gan_type']

        # Adam optimizers over trainable parameters only.
        betas = (hyperparameters['beta1'], hyperparameters['beta2'])
        weight_decay = hyperparameters['weight_decay']
        dis_params = [p for p in self.dis_b.parameters() if p.requires_grad]
        gen_params = [p for p in self.gen_b.parameters() if p.requires_grad]
        self.dis_opt = torch.optim.Adam(dis_params, lr=lr, betas=betas,
                                        weight_decay=weight_decay)
        self.gen_opt = torch.optim.Adam(gen_params, lr=lr, betas=betas,
                                        weight_decay=weight_decay)
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Weight init applied to the two concrete networks only.
        self.gen_b.apply(weights_init(hyperparameters['init']))
        self.dis_b.apply(weights_init('gaussian'))

        # Structural-similarity loss.
        self.ssim_loss = pytorch_ssim.SSIM()
Example #11
0
    def __init__(self, hyperparameters):
        """Build AGUIT networks (generator, image discriminator, content discriminator)
        together with their Adam optimizers and LR schedulers.
        """
        super(AGUIT_Trainer, self).__init__()
        lr = hyperparameters['lr']

        # Initiate the networks.
        self.noise_dim = hyperparameters['gen']['noise_dim']
        self.attr_dim = len(hyperparameters['gen']['selected_attrs'])
        self.gen = AdaINGen(hyperparameters['input_dim'],
                            hyperparameters['gen'])
        self.dis = MsImageDis(hyperparameters['input_dim'], self.attr_dim,
                              hyperparameters['dis'])
        # Content discriminator operates on the encoder bottleneck width:
        # base dim doubled once per downsampling step.
        self.dis_content = ContentDis(
            hyperparameters['gen']['dim'] *
            (2 ** hyperparameters['gen']['n_downsample']), self.attr_dim)

        # Setup the optimizers over trainable parameters only.
        # (Removed a dead `display_size` local that was computed but never used.)
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis.parameters()) + list(
            self.dis_content.parameters())
        gen_params = list(self.gen.parameters())

        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization: configured scheme for this module,
        # gaussian init for both discriminators.
        self.apply(weights_init(hyperparameters['init']))
        self.dis.apply(weights_init('gaussian'))
        self.dis_content.apply(weights_init('gaussian'))
Example #12
0
    def __init__(self, param):
        """Supervised intrinsic-decomposition trainer: one AdaIN auto-encoder with Adam."""
        super(SupIntrinsicTrainer, self).__init__()
        lr = param['lr']

        # Auto-encoder; output width is input_dim_b doubled.
        # NOTE(review): 'input_dim_b + input_dim_b' may be intentional (two
        # b-sized outputs) — confirm it is not meant to be input_dim_a + input_dim_b.
        self.model = AdaINGen(param['input_dim_a'],
                              param['input_dim_b'] + param['input_dim_b'],
                              param['gen'])

        # Adam over trainable parameters only.
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        self.gen_opt = torch.optim.Adam(trainable,
                                        lr=lr,
                                        betas=(param['beta1'], param['beta2']),
                                        weight_decay=param['weight_decay'])
        self.gen_scheduler = get_scheduler(self.gen_opt, param)

        # Weight initialization and best-metric tracker (lower is better).
        self.apply(weights_init(param['init']))
        self.best_result = float('inf')
Example #13
0
    def __init__(self, hyperparameters, gpu_ids=[0]):
        """DG-Net trainer: shared structure auto-encoder, appearance (ID) encoder,
        shared discriminator, optional distillation teachers, optimizers and schedulers.

        NOTE(review): the mutable default `gpu_ids=[0]` is kept for interface
        compatibility; it is never mutated in this method.
        """
        super(DGNet_Trainer, self).__init__()
        lr_g = hyperparameters['lr_g']  # generator learning rate
        lr_d = hyperparameters['lr_d']  # discriminator learning rate
        ID_class = hyperparameters['ID_class']
        if not 'apex' in hyperparameters.keys():
            hyperparameters['apex'] = False
        self.fp16 = hyperparameters['apex']

        # Structure encoder/decoder (Es): gen_a.encode() / gen_a.decode().
        # We do not need to manually set fp16 in the network for the new apex,
        # so fp16=False here. Both domains share the same network object.
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'], fp16=False)
        self.gen_b = self.gen_a  # auto-encoder for domain b (shared)

        # Appearance (ID) encoder; ID_stride is its pooling stride (default 2).
        if not 'ID_stride' in hyperparameters.keys():
            hyperparameters['ID_stride'] = 2
        if hyperparameters['ID_style'] == 'PCB':
            self.id_a = PCB(ID_class)
        elif hyperparameters['ID_style'] == 'AB':
            self.id_a = ft_netAB(ID_class, stride=hyperparameters['ID_stride'],
                                 norm=hyperparameters['norm_id'], pool=hyperparameters['pool'])
        else:
            self.id_a = ft_net(ID_class, norm=hyperparameters['norm_id'],
                               pool=hyperparameters['pool'])  # returns 2048-d features
        # Both domains share the same appearance encoder object.
        self.id_b = self.id_a

        # Multi-scale discriminator, shared between the two domains.
        self.dis_a = MsImageDis(3, hyperparameters['dis'], fp16=False)
        self.dis_b = self.dis_a  # discriminator for domain b (shared)

        # Load teacher models for knowledge distillation, if configured.
        if hyperparameters['teacher'] != "":
            teacher_name = hyperparameters['teacher']
            print(teacher_name)
            teacher_names = teacher_name.split(',')
            teacher_model = nn.ModuleList()
            for teacher_name in teacher_names:
                config_tmp = load_config(teacher_name)
                stride = config_tmp['stride'] if 'stride' in config_tmp else 2
                model_tmp = ft_net(ID_class, stride=stride)
                teacher_model_tmp = load_network(model_tmp, teacher_name)
                teacher_model_tmp.model.fc = nn.Sequential()  # remove the original ImageNet fc layer
                teacher_model_tmp = teacher_model_tmp.cuda()
                if self.fp16:
                    teacher_model_tmp = amp.initialize(teacher_model_tmp, opt_level="O1")
                teacher_model.append(teacher_model_tmp.cuda().eval())
            self.teacher_model = teacher_model
            # Optionally keep batch-norm layers in training mode inside teachers.
            if hyperparameters['train_bn']:
                self.teacher_model = self.teacher_model.apply(train_bn)

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # RGB -> single-channel transform fed to the structure encoder.
        if hyperparameters['single'] == 'edge':
            self.single = to_edge
        else:
            self.single = to_gray(False)

        # Random erasing augmentation with probability erasing_p (default 0).
        if not 'erasing_p' in hyperparameters.keys():
            self.erasing_p = 0
        else:
            self.erasing_p = hyperparameters['erasing_p']
        self.single_re = RandomErasing(probability=self.erasing_p, mean=[0.0, 0.0, 0.0])

        if not 'T_w' in hyperparameters.keys():
            hyperparameters['T_w'] = 1

        # Adam optimizers; since networks are shared across domains, only the
        # *_a parameter lists are needed.
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters())
        gen_params = list(self.gen_a.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr_d, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr_g, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])

        # ID optimizer: SGD with 10x learning rate on the classifier heads
        # relative to the backbone.
        if hyperparameters['ID_style'] == 'PCB':
            ignored_params = (list(map(id, self.id_a.classifier0.parameters()))
                              + list(map(id, self.id_a.classifier1.parameters()))
                              + list(map(id, self.id_a.classifier2.parameters()))
                              + list(map(id, self.id_a.classifier3.parameters())))
            base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD([
                 {'params': base_params, 'lr': lr2},
                 {'params': self.id_a.classifier0.parameters(), 'lr': lr2 * 10},
                 {'params': self.id_a.classifier1.parameters(), 'lr': lr2 * 10},
                 {'params': self.id_a.classifier2.parameters(), 'lr': lr2 * 10},
                 {'params': self.id_a.classifier3.parameters(), 'lr': lr2 * 10}
            ], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True)
        elif hyperparameters['ID_style'] == 'AB':
            ignored_params = (list(map(id, self.id_a.classifier1.parameters()))
                              + list(map(id, self.id_a.classifier2.parameters())))
            base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD([
                 {'params': base_params, 'lr': lr2},
                 {'params': self.id_a.classifier1.parameters(), 'lr': lr2 * 10},
                 {'params': self.id_a.classifier2.parameters(), 'lr': lr2 * 10}
            ], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True)
        else:
            ignored_params = list(map(id, self.id_a.classifier.parameters()))
            base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD([
                 {'params': base_params, 'lr': lr2},
                 {'params': self.id_a.classifier.parameters(), 'lr': lr2 * 10}
            ], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True)

        # Learning-rate schedules; the ID schedule uses its own decay factor.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.id_scheduler = get_scheduler(self.id_opt, hyperparameters)
        self.id_scheduler.gamma = hyperparameters['gamma2']

        # ID losses. `reduction='sum'` replaces the deprecated
        # `size_average=False` (identical behavior, non-deprecated API).
        self.id_criterion = nn.CrossEntropyLoss()
        self.criterion_teacher = nn.KLDivLoss(reduction='sum')

        # Load a frozen VGG model if the perceptual loss is enabled.
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        # fp16 (apex) initialization: move the shared networks to GPU and wrap
        # each model/optimizer pair. Saves memory during training.
        if self.fp16:
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.id_a = self.id_a.cuda()

            self.gen_b = self.gen_a
            self.dis_b = self.dis_a
            self.id_b = self.id_a

            self.gen_a, self.gen_opt = amp.initialize(self.gen_a, self.gen_opt, opt_level="O1")
            self.dis_a, self.dis_opt = amp.initialize(self.dis_a, self.dis_opt, opt_level="O1")
            self.id_a, self.id_opt = amp.initialize(self.id_a, self.id_opt, opt_level="O1")
Example #14
0
    def __init__(self, hyperparameters):
        """Initialize the MUNIT trainer.

        Builds the generator(s) (per-domain or shared depending on
        ``gen_state``), the two discriminators, their Adam optimizers and
        LR schedulers, and optionally loads frozen auxiliary networks
        (VGG for perceptual loss, a semantic-segmentation model, and a
        domain classifier for domain-adversarial training).

        Args:
            hyperparameters: configuration dict (keys read below).

        Raises:
            ValueError: if ``gen_state`` is neither 0 nor 1.
        """
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters["lr"]
        self.gen_state = hyperparameters["gen_state"]
        self.guided = hyperparameters["guided"]
        self.newsize = hyperparameters["crop_image_height"]
        self.semantic_w = hyperparameters["semantic_w"] > 0

        self.recon_mask = hyperparameters["recon_mask"] == 1
        self.check_alignment = hyperparameters["check_alignment"] == 1

        self.full_adaptation = hyperparameters["full_adaptation"] == 1
        self.dann_scheduler = None

        # Domain-adversarial classification is enabled only when the
        # corresponding loss weight is present and positive.
        self.domain_classif = hyperparameters.get("domain_adv_w", 0) > 0

        if self.gen_state == 0:
            # Two separate auto-encoders, one per domain.
            self.gen_a = AdaINGen(
                hyperparameters["input_dim_a"],
                hyperparameters["gen"])  # auto-encoder for domain a
            self.gen_b = AdaINGen(
                hyperparameters["input_dim_b"],
                hyperparameters["gen"])  # auto-encoder for domain b
        elif self.gen_state == 1:
            # Single double-headed auto-encoder shared by both domains.
            self.gen = AdaINGen_double(hyperparameters["input_dim_a"],
                                       hyperparameters["gen"])
        else:
            # Fail fast instead of crashing later with a NameError on
            # gen_params.
            raise ValueError(
                "self.gen_state unknown value: {}".format(self.gen_state))

        self.dis_a = MsImageDis(
            hyperparameters["input_dim_a"],
            hyperparameters["dis"])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters["input_dim_b"],
            hyperparameters["dis"])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters["gen"]["style_dim"]

        # Fix the noise used in sampling so display images stay
        # comparable across training iterations.
        display_size = int(hyperparameters["display_size"])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers (only trainable parameters are handed over).
        beta1 = hyperparameters["beta1"]
        beta2 = hyperparameters["beta2"]
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        if self.gen_state == 0:
            gen_params = list(self.gen_a.parameters()) + list(
                self.gen_b.parameters())
        else:  # gen_state == 1 (validated above)
            gen_params = list(self.gen.parameters())

        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters["weight_decay"],
        )
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters["weight_decay"],
        )
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization.
        self.apply(weights_init(hyperparameters["init"]))
        self.dis_a.apply(weights_init("gaussian"))
        self.dis_b.apply(weights_init("gaussian"))

        # Load frozen VGG model if the perceptual loss is enabled.
        if hyperparameters.get("vgg_w", 0) > 0:
            self.vgg = load_vgg16(hyperparameters["vgg_model_path"] +
                                  "/models")
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        # Load frozen semantic-segmentation model if needed.
        if hyperparameters.get("semantic_w", 0) > 0:
            self.segmentation_model = load_segmentation_model(
                hyperparameters["semantic_ckpt_path"])
            self.segmentation_model.eval()
            for param in self.segmentation_model.parameters():
                param.requires_grad = False

        # Load domain classifier (plus its optimizer/scheduler) if needed.
        if hyperparameters.get("domain_adv_w", 0) > 0:
            self.domain_classifier = domainClassifier(256)
            dann_params = list(self.domain_classifier.parameters())
            self.dann_opt = torch.optim.Adam(
                [p for p in dann_params if p.requires_grad],
                lr=lr,
                betas=(beta1, beta2),
                weight_decay=hyperparameters["weight_decay"],
            )
            self.domain_classifier.apply(weights_init("gaussian"))
            self.dann_scheduler = get_scheduler(self.dann_opt, hyperparameters)
Пример #15
0
    def __init__(self, hyperparameters, resume_epoch=-1, snapshot_dir=None):
        """Initialize the multi-dataset MUNIT trainer.

        A single generator/discriminator pair is shared across all
        datasets; the dataset identity is appended to the input as one
        one-hot channel plane per dataset.

        Args:
            hyperparameters: configuration dict.
            resume_epoch: epoch to resume from, or -1 to start fresh.
            snapshot_dir: checkpoint directory (used when resuming).
        """
        super(MUNIT_Trainer, self).__init__()

        lr = hyperparameters['lr']

        # Initiate the networks. Input channels = image channels plus one
        # one-hot plane per dataset.
        self.gen = AdaINGen(
            hyperparameters['input_dim'] + hyperparameters['n_datasets'],
            hyperparameters['gen'],
            hyperparameters['n_datasets'])  # Auto-encoder.
        self.dis = MsImageDis(
            hyperparameters['input_dim'] + hyperparameters['n_datasets'],
            hyperparameters['dis'])  # Discriminator.

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        self.beta1 = hyperparameters['beta1']
        self.beta2 = hyperparameters['beta2']
        self.weight_decay = hyperparameters['weight_decay']

        # Supervised UNet head (pretrained weights are presumably loaded
        # via resume() below — confirm against the checkpoint format).
        self.sup = UNet(input_channels=hyperparameters['input_dim'],
                        num_classes=2).cuda()

        # Fix the noise used in sampling (8 display images).
        self.s_a = torch.randn(8, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(8, self.style_dim, 1, 1).cuda()

        # Setup the optimizers. The UNet is trained jointly with the
        # generator. (The previously shadowing local beta1/beta2 copies
        # were dead code and have been removed; self.beta1/self.beta2 are
        # what the optimizers actually use.)
        dis_params = list(self.dis.parameters())
        gen_params = list(self.gen.parameters()) + list(self.sup.parameters())

        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(self.beta1, self.beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(self.beta1, self.beta2),
            weight_decay=hyperparameters['weight_decay'])

        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization.
        self.apply(weights_init(hyperparameters['init']))
        self.dis.apply(weights_init('gaussian'))

        # Preset one-hot encoding planes: one_hot_img[i] marks dataset i at
        # image resolution (256x256), one_hot_c at code resolution (64x64).
        self.one_hot_img = torch.zeros(hyperparameters['n_datasets'],
                                       hyperparameters['batch_size'],
                                       hyperparameters['n_datasets'], 256,
                                       256).cuda()
        self.one_hot_c = torch.zeros(hyperparameters['n_datasets'],
                                     hyperparameters['batch_size'],
                                     hyperparameters['n_datasets'], 64,
                                     64).cuda()

        for i in range(hyperparameters['n_datasets']):
            self.one_hot_img[i, :, i, :, :].fill_(1)
            self.one_hot_c[i, :, i, :, :].fill_(1)

        if resume_epoch != -1:
            self.resume(snapshot_dir, hyperparameters)
Пример #16
0
    def __init__(self, hyperparameters):
        """Initialize the MUNIT trainer with an auxiliary content classifier.

        Builds per-domain auto-encoders and discriminators, a content
        classifier over the generator's content features, and three Adam
        optimizers (dis / gen / classifier) with schedulers. Discriminator
        parameters whose names contain "_Q" are routed to the *generator*
        optimizer instead of the discriminator one.
        """
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b

        # Classifier over content codes (input width = gen feature dim).
        self.content_classifier = ContentClassifier(
            hyperparameters['gen']['dim'], hyperparameters)

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        # Number of continuous latent codes of the discriminator config.
        self.num_con_c = hyperparameters['dis']['num_con_c']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        # dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())

        dis_named_params = list(self.dis_a.named_parameters()) + list(
            self.dis_b.named_parameters())
        # gen_named_params = list(self.gen_a.named_parameters()) + list(self.gen_b.named_parameters())

        ### modifying list params
        # Split discriminator parameters by name: "_Q" layers (presumably
        # an InfoGAN-style Q head — criterionQ_con below supports that
        # reading; TODO confirm against MsImageDis) are optimized together
        # with the generator, everything else with the discriminator.
        dis_params = list()
        # gen_params = list()
        for name, param in dis_named_params:
            if "_Q" in name:
                # print('%s --> gen_params' % name)
                gen_params.append(param)
            else:
                dis_params.append(param)

        content_classifier_params = list(self.content_classifier.parameters())

        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.cla_opt = torch.optim.Adam(
            [p for p in content_classifier_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.cla_scheduler = get_scheduler(self.cla_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        self.content_classifier.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        self.gan_type = hyperparameters['dis']['gan_type']
        # NLL of a Gaussian — loss for the continuous latent codes.
        self.criterionQ_con = NormalNLLLoss()

        self.criterion_content_classifier = nn.CrossEntropyLoss()

        # self.batch_size = hyperparameters['batch_size']
        self.batch_size_val = hyperparameters['batch_size_val']
Пример #17
0
    def __init__(self, hyperparameters, gpu_ids=[0]):
        """Initialize the DG-Net trainer.

        Builds the structure encoder/decoder (Es + G, ``gen_a``), the
        appearance encoder (Ea, ``id_a``), a multi-scale discriminator
        (``dis_a``), optional teacher models for knowledge distillation,
        and all optimizers and LR schedulers.  The ``*_b`` networks are
        aliases of the ``*_a`` ones (weights shared across the two image
        streams).

        Args:
            hyperparameters: configuration dict.
            gpu_ids: kept for interface compatibility; not used in this
                method (NOTE: a mutable default, but it is never mutated).
        """
        super(DGNet_Trainer, self).__init__()
        # Learning rates for the generator and the discriminator.
        lr_g = hyperparameters['lr_g']
        lr_d = hyperparameters['lr_d']

        # Number of identity classes in the *training* set.
        ID_class = hyperparameters['ID_class']

        # Whether to train with apex mixed precision (fp16).
        if 'apex' not in hyperparameters:
            hyperparameters['apex'] = False
        self.fp16 = hyperparameters['apex']

        # Initiate the networks.
        # We do not need to manually set fp16 in the network for the new
        # apex, so fp16=False here; amp.initialize() handles precision below.
        # gen_a bundles the structure encoder Es and the decoder G
        # (gen_a.encode() / gen_a.decode()).
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'],
                              hyperparameters['gen'],
                              fp16=False)  # auto-encoder for domain a
        self.gen_b = self.gen_a  # auto-encoder for domain b (shared weights)

        # Appearance encoder Ea. 'ID_stride' is the stride of its pooling.
        if 'ID_stride' not in hyperparameters:
            hyperparameters['ID_stride'] = 2

        # Choose the Ea backbone; 'AB' (the default in the paper) returns
        # two ID predictions plus the appearance code.
        if hyperparameters['ID_style'] == 'PCB':
            self.id_a = PCB(ID_class)
        elif hyperparameters['ID_style'] == 'AB':
            self.id_a = ft_netAB(ID_class,
                                 stride=hyperparameters['ID_stride'],
                                 norm=hyperparameters['norm_id'],
                                 pool=hyperparameters['pool'])
        else:
            self.id_a = ft_net(ID_class,
                               norm=hyperparameters['norm_id'],
                               pool=hyperparameters['pool'])  # return 2048 now

        self.id_b = self.id_a  # shared appearance encoder for stream b

        # Multi-scale discriminator: predicts real/fake at several
        # downscaled resolutions and sums the losses.
        self.dis_a = MsImageDis(3, hyperparameters['dis'],
                                fp16=False)  # discriminator for domain a
        self.dis_b = self.dis_a  # discriminator for domain b (shared)

        # Load teacher models for distillation. 'teacher' is a
        # comma-separated list of names, each resolved to a config plus a
        # checkpoint (e.g. models/best/opts.yaml).
        if hyperparameters['teacher'] != "":
            teacher_name = hyperparameters['teacher']
            print(teacher_name)
            teacher_names = teacher_name.split(',')
            teacher_model = nn.ModuleList()
            teacher_count = 0

            for teacher_name in teacher_names:
                config_tmp = load_config(teacher_name)
                if 'stride' in config_tmp:
                    stride = config_tmp['stride']
                else:
                    stride = 2

                # Teacher backbone is a plain ResNet-50 ID classifier.
                model_tmp = ft_net(ID_class, stride=stride)
                teacher_model_tmp = load_network(model_tmp, teacher_name)
                teacher_model_tmp.model.fc = nn.Sequential(
                )  # remove the original fc layer in ImageNet
                teacher_model_tmp = teacher_model_tmp.cuda()
                if self.fp16:
                    teacher_model_tmp = amp.initialize(teacher_model_tmp,
                                                       opt_level="O1")
                teacher_model.append(teacher_model_tmp.cuda().eval())
                teacher_count += 1
            self.teacher_model = teacher_model
            # Optionally keep the teachers' BatchNorm layers in train mode.
            if hyperparameters['train_bn']:
                self.teacher_model = self.teacher_model.apply(train_bn)

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # RGB to one channel: Es consumes a gray (or edge) image.
        if hyperparameters['single'] == 'edge':
            self.single = to_edge
        else:
            self.single = to_gray(False)

        # Random erasing augmentation; 'erasing_p' is the probability.
        if 'erasing_p' not in hyperparameters:
            self.erasing_p = 0
        else:
            self.erasing_p = hyperparameters['erasing_p']
        self.single_re = RandomErasing(probability=self.erasing_p,
                                       mean=[0.0, 0.0, 0.0])
        # 'T_w' weights the primary feature learning (teacher) loss.
        if 'T_w' not in hyperparameters:
            hyperparameters['T_w'] = 1

        # Setup the optimizers: Adam for Es/G and for D.
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(
            self.dis_a.parameters())  #+ list(self.dis_b.parameters())
        gen_params = list(
            self.gen_a.parameters())  #+ list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr_d,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr_g,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])

        # SGD for the appearance encoder Ea; classifier heads train at
        # 10x the base learning rate.
        if hyperparameters['ID_style'] == 'PCB':
            ignored_params = (
                list(map(id, self.id_a.classifier0.parameters())) +
                list(map(id, self.id_a.classifier1.parameters())) +
                list(map(id, self.id_a.classifier2.parameters())) +
                list(map(id, self.id_a.classifier3.parameters())))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD(
                [{
                    'params': base_params,
                    'lr': lr2
                }, {
                    'params': self.id_a.classifier0.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier1.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier2.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier3.parameters(),
                    'lr': lr2 * 10
                }],
                weight_decay=hyperparameters['weight_decay'],
                momentum=0.9,
                nesterov=True)
        elif hyperparameters['ID_style'] == 'AB':
            # Default path ('AB'): two classifier heads.
            ignored_params = (
                list(map(id, self.id_a.classifier1.parameters())) +
                list(map(id, self.id_a.classifier2.parameters())))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD(
                [{
                    'params': base_params,
                    'lr': lr2
                }, {
                    'params': self.id_a.classifier1.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier2.parameters(),
                    'lr': lr2 * 10
                }],
                weight_decay=hyperparameters['weight_decay'],
                momentum=0.9,
                nesterov=True)
        else:
            ignored_params = list(map(id, self.id_a.classifier.parameters()))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD(
                [{
                    'params': base_params,
                    'lr': lr2
                }, {
                    'params': self.id_a.classifier.parameters(),
                    'lr': lr2 * 10
                }],
                weight_decay=hyperparameters['weight_decay'],
                momentum=0.9,
                nesterov=True)

        # Learning-rate schedulers.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.id_scheduler = get_scheduler(self.id_opt, hyperparameters)
        self.id_scheduler.gamma = hyperparameters['gamma2']

        # ID loss: cross-entropy.
        self.id_criterion = nn.CrossEntropyLoss()
        # Teacher (distillation) loss: KL divergence, summed over elements.
        # reduction='sum' replaces the deprecated size_average=False and is
        # numerically equivalent.
        self.criterion_teacher = nn.KLDivLoss(reduction='sum')

        # Load VGG model if needed.
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        # Mixed precision: move networks to GPU and wrap them (and their
        # optimizers) with apex amp.
        if self.fp16:
            # Name the FP16_Optimizer instance to replace the existing optimizer
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.id_a = self.id_a.cuda()

            self.gen_b = self.gen_a
            self.dis_b = self.dis_a
            self.id_b = self.id_a

            self.gen_a, self.gen_opt = amp.initialize(self.gen_a,
                                                      self.gen_opt,
                                                      opt_level="O1")
            self.dis_a, self.dis_opt = amp.initialize(self.dis_a,
                                                      self.dis_opt,
                                                      opt_level="O1")
            self.id_a, self.id_opt = amp.initialize(self.id_a,
                                                    self.id_opt,
                                                    opt_level="O1")
Пример #18
0
    def __init__(self, hyperparameters):
        """Set up the MUNIT networks, fixed display styles, optimizers and
        schedulers from the configuration dict."""
        super(MUNIT_Trainer, self).__init__()
        learning_rate = hyperparameters['lr']

        # Per-domain auto-encoders and multi-scale discriminators.
        # input_dim_a / input_dim_b are the image channel counts (3 for
        # RGB); 'gen' and 'dis' are architecture sub-configs from the yaml.
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # One fixed random style vector per displayed image so the
        # visualization stays comparable across iterations.
        n_display = int(hyperparameters['display_size'])
        self.s_a = torch.randn(n_display, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(n_display, self.style_dim, 1, 1).cuda()

        # One Adam optimizer for the discriminators and one for the
        # generators; only trainable parameters are handed over.
        b1 = hyperparameters['beta1']
        b2 = hyperparameters['beta2']
        decay = hyperparameters['weight_decay']

        def trainable(params):
            return [p for p in params if p.requires_grad]

        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(trainable(dis_params),
                                        lr=learning_rate,
                                        betas=(b1, b2),
                                        weight_decay=decay)
        self.gen_opt = torch.optim.Adam(trainable(gen_params),
                                        lr=learning_rate,
                                        betas=(b1, b2),
                                        weight_decay=decay)
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Weight initialization. Note that apply() recurses over every
        # sub-module; the discriminators are re-initialized with the
        # gaussian scheme after the global init.
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Frozen VGG-16 for the (optional) perceptual loss; skipped when
        # vgg_w is absent or zero.
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for weight in self.vgg.parameters():
                weight.requires_grad = False
Пример #19
0
    def __init__(self, hyperparameters):
        """Initialize the attribute-conditioned MUNIT trainer.

        When a domain has labelled attributes (``label_a`` / ``label_b``
        > 0), the trailing channels of its fixed sampling style carry a
        one-hot attribute code cycling over the attributes; otherwise the
        style is fully random.
        """
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        # NOTE: the attribute names keep the original (misspelled) public
        # spelling 'attibute' for compatibility with the rest of the class.
        self.a_attibute = hyperparameters['label_a']
        self.b_attibute = hyperparameters['label_b']

        # Fix the noise used in sampling. The per-domain construction was
        # previously duplicated verbatim; factored into one helper.
        display_size = int(hyperparameters['display_size'])

        def fixed_style(n_attr):
            # Random style; if n_attr > 0, the last n_attr channels are a
            # one-hot attribute label, cycling i % n_attr over the batch.
            if n_attr == 0:
                return torch.randn(display_size, self.style_dim, 1, 1).cuda()
            noise = torch.randn(display_size, self.style_dim - n_attr, 1,
                                1).cuda()
            idx = torch.tensor([i % n_attr for i in range(display_size)],
                               dtype=torch.long).reshape((display_size, 1))
            one_hot = torch.zeros(display_size, n_attr,
                                  dtype=torch.float32).scatter_(1, idx, 1)
            one_hot = one_hot.reshape(display_size, n_attr, 1, 1).cuda()
            return torch.cat([noise, one_hot], 1)

        self.s_a = fixed_style(self.a_attibute)
        self.s_b = fixed_style(self.b_attibute)

        # Setup the optimizers (only trainable parameters are handed over).
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load frozen VGG model if the perceptual loss is enabled
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False
Пример #20
0
    def __init__(self, hyperparameters, gpu_ids=[0]):
        super(DGNet_Trainer, self).__init__()
        # 从配置文件获取生成模型和鉴别模型的学习率
        lr_g = hyperparameters['lr_g']
        lr_d = hyperparameters['lr_d']

        # ID 类别
        ID_class = hyperparameters['ID_class']

        # 是否设置使用fp16,
        if not 'apex' in hyperparameters.keys():
            hyperparameters['apex'] = False
        self.fp16 = hyperparameters['apex']
        # Initiate the networks
        # We do not need to manually set fp16 in the network for the new apex. So here I set fp16=False.
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'],
                              hyperparameters['gen'],
                              fp16=False)  # auto-encoder for domain a
        self.gen_b = self.gen_a  # auto-encoder for domain b
        '''
        ft_netAB :   Ea
        '''
        # ID_stride: 外观编码器池化层的stride
        if not 'ID_stride' in hyperparameters.keys():
            hyperparameters['ID_stride'] = 2
        # id_a : 外观编码器  ->  Ea
        if hyperparameters['ID_style'] == 'PCB':
            self.id_a = PCB(ID_class)
        elif hyperparameters['ID_style'] == 'AB':
            self.id_a = ft_netAB(ID_class,
                                 stride=hyperparameters['ID_stride'],
                                 norm=hyperparameters['norm_id'],
                                 pool=hyperparameters['pool'])
        else:
            self.id_a = ft_net(ID_class,
                               norm=hyperparameters['norm_id'],
                               pool=hyperparameters['pool'])  # return 2048 now

        self.id_b = self.id_a  # 对图片b的操作与图片a的操作一致

        # 判别器,使用的是一个多尺寸的判别器,就是对图片进行几次缩放,并且对每次缩放都会预测,计算总的损失
        # 经过网络3个缩放,,分别为:[batch_size, 1, 64, 32],[batch_size, 1, 32, 16],[batch_size, 1, 16, 8]
        self.dis_a = MsImageDis(3, hyperparameters['dis'],
                                fp16=False)  # discriminator for domain a
        self.dis_b = self.dis_a  # discriminator for domain b

        # load teachers
        if hyperparameters['teacher'] != "":
            teacher_name = hyperparameters['teacher']
            print(teacher_name)

            # 加载多个老师模型
            teacher_names = teacher_name.split(',')
            # 构建老师模型
            teacher_model = nn.ModuleList()  # 初始化为空,接下来开始填充
            teacher_count = 0
            for teacher_name in teacher_names:
                config_tmp = load_config(teacher_name)

                # 池化层的stride
                if 'stride' in config_tmp:
                    stride = config_tmp['stride']
                else:
                    stride = 2

                # 开始搭建网络
                model_tmp = ft_net(ID_class, stride=stride)
                teacher_model_tmp = load_network(model_tmp, teacher_name)
                teacher_model_tmp.model.fc = nn.Sequential(
                )  # remove the original fc layer in ImageNet
                teacher_model_tmp = teacher_model_tmp.cuda()
                # teacher_model_tmp,[3, 224, 224]

                # 使用fp16
                if self.fp16:
                    teacher_model_tmp = amp.initialize(teacher_model_tmp,
                                                       opt_level="O1")
                teacher_model.append(teacher_model_tmp.cuda().eval(
                ))  # 第一个填充为 teacher_model_tmp.cuda().eval()
                teacher_count += 1
            self.teacher_model = teacher_model

            # 是否使用batchnorm
            if hyperparameters['train_bn']:
                self.teacher_model = self.teacher_model.apply(train_bn)
        # 实例正则化
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # RGB to one channel
        # 因为Es 需要使用灰度图, 所以single 用来将图片转化为灰度图
        if hyperparameters['single'] == 'edge':
            self.single = to_edge
        else:
            self.single = to_gray(False)

        # Random Erasing when training
        # arasing_p 随机擦除的概率
        if not 'erasing_p' in hyperparameters.keys():
            self.erasing_p = 0
        else:
            self.erasing_p = hyperparameters['erasing_p']
        # 对图片中的某一随机区域进行擦除,具体:将该区域的像素值设置为均值
        self.single_re = RandomErasing(probability=self.erasing_p,
                                       mean=[0.0, 0.0, 0.0])

        if not 'T_w' in hyperparameters.keys():
            hyperparameters['T_w'] = 1
        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(
            self.dis_a.parameters())  #+ list(self.dis_b.parameters())
        gen_params = list(
            self.gen_a.parameters())  #+ list(self.gen_b.parameters())

        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr_d,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr_g,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        # id params
        # 修改 id_a模型中分类器的学习率
        if hyperparameters['ID_style'] == 'PCB':
            ignored_params = (
                list(map(id, self.id_a.classifier0.parameters())) +
                list(map(id, self.id_a.classifier1.parameters())) +
                list(map(id, self.id_a.classifier2.parameters())) +
                list(map(id, self.id_a.classifier3.parameters())))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD(
                [{
                    'params': base_params,
                    'lr': lr2
                }, {
                    'params': self.id_a.classifier0.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier1.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier2.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier3.parameters(),
                    'lr': lr2 * 10
                }],
                weight_decay=hyperparameters['weight_decay'],
                momentum=0.9,
                nesterov=True)
        elif hyperparameters['ID_style'] == 'AB':
            ignored_params = (
                list(map(id, self.id_a.classifier1.parameters())) +
                list(map(id, self.id_a.classifier2.parameters())))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD(
                [{
                    'params': base_params,
                    'lr': lr2
                }, {
                    'params': self.id_a.classifier1.parameters(),
                    'lr': lr2 * 10
                }, {
                    'params': self.id_a.classifier2.parameters(),
                    'lr': lr2 * 10
                }],
                weight_decay=hyperparameters['weight_decay'],
                momentum=0.9,
                nesterov=True)
        else:
            ignored_params = list(map(id, self.id_a.classifier.parameters()))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD(
                [{
                    'params': base_params,
                    'lr': lr2
                }, {
                    'params': self.id_a.classifier.parameters(),
                    'lr': lr2 * 10
                }],
                weight_decay=hyperparameters['weight_decay'],
                momentum=0.9,
                nesterov=True)
        # 生成器和判别器中的优化策略(学习率的更新策略)
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.id_scheduler = get_scheduler(self.id_opt, hyperparameters)
        self.id_scheduler.gamma = hyperparameters['gamma2']

        #ID Loss
        self.id_criterion = nn.CrossEntropyLoss()
        self.criterion_teacher = nn.KLDivLoss(
            size_average=False)  # 生成主要特征: Lprim
        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        # save memory
        # 保存当前的模型,是为了提高计算效率
        if self.fp16:
            # Name the FP16_Optimizer instance to replace the existing optimizer
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.id_a = self.id_a.cuda()

            self.gen_b = self.gen_a
            self.dis_b = self.dis_a
            self.id_b = self.id_a

            self.gen_a, self.gen_opt = amp.initialize(self.gen_a,
                                                      self.gen_opt,
                                                      opt_level="O1")
            self.dis_a, self.dis_opt = amp.initialize(self.dis_a,
                                                      self.dis_opt,
                                                      opt_level="O1")
            self.id_a, self.id_opt = amp.initialize(self.id_a,
                                                    self.id_opt,
                                                    opt_level="O1")
Пример #21
0
    def __init__(self, param):
        """Build the unsupervised intrinsic-decomposition trainer.

        Creates three AdaIN auto-encoders (image I, reflectance R, shading S),
        two multi-scale discriminators (R and S), optional style
        split/merge mapping modules, Adam optimizers with schedulers, and
        applies weight initialization.

        Args:
            param: dict of hyperparameters; keys accessed below
                (``lr``, ``gen``, ``dis``, ``input_dim_*``, ``beta1``/``beta2``,
                ``weight_decay``, ``init``, ``bias_shift``, ``display_size``,
                and the optional ``ablation_study`` sub-dict).
        """
        super(UnsupIntrinsicTrainer, self).__init__()
        lr = param['lr']
        # Initiate the networks
        self.gen_i = AdaINGen(param['input_dim_a'], param['input_dim_a'],
                              param['gen'])  # auto-encoder for domain I
        self.gen_r = AdaINGen(param['input_dim_b'], param['input_dim_b'],
                              param['gen'])  # auto-encoder for domain R
        self.gen_s = AdaINGen(param['input_dim_c'], param['input_dim_c'],
                              param['gen'])  # auto-encoder for domain S
        self.dis_r = MsImageDis(param['input_dim_b'],
                                param['dis'])  # discriminator for domain R
        self.dis_s = MsImageDis(param['input_dim_c'],
                                param['dis'])  # discriminator for domain S
        gp = param['gen']
        # Ablation-study switches: the full model enables all three.
        self.with_mapping = True
        self.use_phy_loss = True
        self.use_content_loss = True
        if 'ablation_study' in param:
            ablation = param['ablation_study']
            if 'with_mapping' in ablation:
                # Non-zero keeps the style-mapping modules enabled.
                self.with_mapping = ablation['with_mapping'] != 0
            if 'wo_phy_loss' in ablation:
                # "wo" = without: zero means keep the physical loss.
                self.use_phy_loss = ablation['wo_phy_loss'] == 0
            if 'wo_content_loss' in ablation:
                # "wo" = without: zero means keep the content loss.
                self.use_content_loss = ablation['wo_content_loss'] == 0

        if self.with_mapping:
            # Modules mapping between the I style and the (R, S) style pair.
            self.fea_s = IntrinsicSplitor(gp['style_dim'], gp['mlp_dim'],
                                          gp['n_layer'],
                                          gp['activ'])  # split style for I
            self.fea_m = IntrinsicMerger(gp['style_dim'], gp['mlp_dim'],
                                         gp['n_layer'],
                                         gp['activ'])  # merge style for R, S
        self.bias_shift = param['bias_shift']
        self.instance_norm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = param['gen']['style_dim']

        # Fix the noise used in sampling; shift R and S apart by +/-1.
        display_size = int(param['display_size'])
        self.s_r = torch.randn(display_size, self.style_dim, 1, 1).cuda() + 1.
        self.s_s = torch.randn(display_size, self.style_dim, 1, 1).cuda() - 1.

        # Setup the optimizers
        beta1 = param['beta1']
        beta2 = param['beta2']
        dis_params = list(self.dis_r.parameters()) + list(
            self.dis_s.parameters())
        if self.with_mapping:
            gen_params = list(self.gen_i.parameters()) + list(self.gen_r.parameters()) + \
                         list(self.gen_s.parameters()) + \
                         list(self.fea_s.parameters()) + list(self.fea_m.parameters())
        else:
            gen_params = list(self.gen_i.parameters()) + list(
                self.gen_r.parameters()) + list(self.gen_s.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=param['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=param['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, param)
        self.gen_scheduler = get_scheduler(self.gen_opt, param)

        # Network weight initialization (discriminators use gaussian init).
        self.apply(weights_init(param['init']))
        self.dis_r.apply(weights_init('gaussian'))
        self.dis_s.apply(weights_init('gaussian'))
        self.best_result = float('inf')
        self.reflectance_loss = LocalAlbedoSmoothnessLoss(param)
Пример #22
0
    def __init__(self, hyperparameters):
        """Construct the MUNIT trainer.

        Builds per-domain AdaIN auto-encoders, image and style-pair
        discriminators, fixed sampling noise, Adam optimizers with
        schedulers, losses, and an optional frozen VGG for perceptual loss.

        Args:
            hyperparameters: configuration dict (see keys accessed below).
        """
        super(MUNIT_Trainer, self).__init__()
        base_lr = hyperparameters['lr']
        gen_cfg = hyperparameters['gen']
        dis_cfg = hyperparameters['dis']
        dim_a = hyperparameters['input_dim_a']
        dim_b = hyperparameters['input_dim_b']
        # Auto-encoders, one per image domain.
        self.gen_a = AdaINGen(dim_a, gen_cfg)
        self.gen_b = AdaINGen(dim_b, gen_cfg)
        # Multi-scale discriminators on single images...
        self.dis_a = MsImageDis(dim_a, dis_cfg)
        self.dis_b = MsImageDis(dim_b, dis_cfg)
        # ...and on channel-stacked image pairs (style discriminators).
        self.dis_sa = MsImageDis(dim_a * 2, dis_cfg)
        self.dis_sb = MsImageDis(dim_b * 2, dis_cfg)
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = gen_cfg['style_dim']

        # Fixed noise reused whenever display samples are generated.
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Optimizers: one Adam instance per parameter group, shared settings.
        betas = (hyperparameters['beta1'], hyperparameters['beta2'])
        decay = hyperparameters['weight_decay']

        def build_adam(modules):
            # Collect trainable parameters of the given modules, in order.
            params = []
            for module in modules:
                params.extend(p for p in module.parameters()
                              if p.requires_grad)
            return torch.optim.Adam(params,
                                    lr=base_lr,
                                    betas=betas,
                                    weight_decay=decay)

        self.dis_opt = build_adam([self.dis_a, self.dis_b])
        self.dis_style_opt = build_adam([self.dis_sa, self.dis_sb])
        self.gen_opt = build_adam([self.gen_a, self.gen_b])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.dis_style_scheduler = get_scheduler(self.dis_style_opt,
                                                 hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization; discriminators use gaussian init.
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_sa.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
        self.dis_sb.apply(weights_init('gaussian'))
        if gen_cfg['CE_method'] == 'vgg':
            self.gen_a.content_init()
            self.gen_b.content_init()
        self.criterion = nn.L1Loss().cuda()
        self.triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2).cuda()
        self.kld = nn.KLDivLoss()
        self.contextual_loss = ContextualLoss()

        # Load a frozen VGG only when the perceptual-loss weight is active.
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False
Пример #23
0
    def __init__(self, hyperparameters):
        """Construct the ERGAN trainer.

        Builds two AdaIN auto-encoders sharing one content encoder, two
        multi-scale discriminators, Adam optimizers (separate learning rates
        for G and D) with schedulers, applies weight initialization, wraps
        models with apex amp when fp16 is requested, and optionally loads a
        frozen VGG for perceptual loss.

        Args:
            hyperparameters: configuration dict (see keys accessed below).
        """
        super(ERGAN_Trainer, self).__init__()
        lr_G = hyperparameters['lr_G']
        lr_D = hyperparameters['lr_D']
        print(lr_D, lr_G)
        self.fp16 = hyperparameters['fp16']
        gen_cfg = hyperparameters['gen']
        dis_cfg = hyperparameters['dis']
        # Auto-encoders for domains a and b.
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], gen_cfg)
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], gen_cfg)
        # Both generators share a single content encoder (weight sharing).
        self.gen_b.enc_content = self.gen_a.enc_content
        # Multi-scale discriminators, one per domain.
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], dis_cfg)
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], dis_cfg)
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = gen_cfg['style_dim']
        # Scale factor relative to the 224-pixel reference size.
        self.a = gen_cfg['new_size'] / 224
        # Fixed noise reused whenever display samples are generated.
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        # Optimizers: distinct learning rates for D and G, shared betas/decay.
        betas = (hyperparameters['beta1'], hyperparameters['beta2'])
        decay = hyperparameters['weight_decay']

        def trainable(*modules):
            # Collect trainable parameters of the given modules, in order.
            collected = []
            for module in modules:
                collected.extend(p for p in module.parameters()
                                 if p.requires_grad)
            return collected

        self.dis_opt = torch.optim.Adam(trainable(self.dis_a, self.dis_b),
                                        lr=lr_D,
                                        betas=betas,
                                        weight_decay=decay)
        self.gen_opt = torch.optim.Adam(trainable(self.gen_a, self.gen_b),
                                        lr=lr_G,
                                        betas=betas,
                                        weight_decay=decay)

        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization; discriminators use gaussian init.
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        if self.fp16:
            # Move to GPU first, then let apex amp patch models + optimizers.
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.gen_b = self.gen_b.cuda()
            self.dis_b = self.dis_b.cuda()
            self.gen_a, self.gen_opt = amp.initialize(self.gen_a,
                                                      self.gen_opt,
                                                      opt_level="O1")
            self.dis_a, self.dis_opt = amp.initialize(self.dis_a,
                                                      self.dis_opt,
                                                      opt_level="O1")

        # Load a frozen VGG only when the perceptual-loss weight is active.
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False