Beispiel #1
0
    def __init__(self, hyperparameters):
        """Build both MUNIT auto-encoders and discriminators with their optimizers.

        Args:
            hyperparameters: configuration mapping. Keys read here: 'lr',
                'input_dim_a', 'input_dim_b', 'gen' (must contain
                'style_dim'), 'dis', 'display_size', 'weight_decay',
                'step_size', 'gamma', 'init', and the optional Adam
                coefficients 'beta1' / 'beta2' (default 0.5 / 0.999).
        """
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # Fix the noise used in sampling: one random style code per displayed
        # image, shape (display_size, style_dim, 1, 1), moved to the GPU.
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers. The Adam betas are configurable and fall back
        # to the constants this trainer always used.
        # NOTE(review): assumes hyperparameters is a dict (supports .get).
        beta1 = hyperparameters.get('beta1', 0.5)
        beta2 = hyperparameters.get('beta2', 0.999)
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        # Only trainable parameters (requires_grad=True) are optimized.
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        # Step decay: StepLR multiplies the learning rate by 'gamma' every
        # 'step_size' scheduler steps.
        self.dis_scheduler = lr_scheduler.StepLR(self.dis_opt, step_size=hyperparameters['step_size'],
                                        gamma=hyperparameters['gamma'], last_epoch=-1)
        self.gen_scheduler = lr_scheduler.StepLR(self.gen_opt, step_size=hyperparameters['step_size'],
                                        gamma=hyperparameters['gamma'], last_epoch=-1)

        # Network weight initialization: the configured scheme on every
        # submodule, then the discriminators are re-initialized with 'gaussian'.
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
Beispiel #2
0
    def __init__(self, hyperparameters):
        """Construct the MUNIT sub-networks, loss criteria, optimizers and schedulers.

        Args:
            hyperparameters: configuration mapping; reads 'lr', 'beta1',
                'beta2', 'weight_decay', 'input_dim_a', 'input_dim_b',
                'gen' (with 'style_dim'), 'dis' (with 'gan_type'),
                'display_size' and 'init'. The whole mapping is also handed
                to get_scheduler.
        """
        super(MUNIT_Trainer, self).__init__()
        hp = hyperparameters
        gen_cfg, dis_cfg = hp['gen'], hp['dis']
        lr = hp['lr']

        # Per-domain auto-encoders (generators) and discriminators.
        self.gen_a = AdaINGen(hp['input_dim_a'], gen_cfg)
        self.gen_b = AdaINGen(hp['input_dim_b'], gen_cfg)
        self.dis_a = MsImageDis(hp['input_dim_a'], dis_cfg)
        self.dis_b = MsImageDis(hp['input_dim_b'], dis_cfg)
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = gen_cfg['style_dim']
        # Adversarial criterion chosen by the discriminator config's
        # 'gan_type', plus a mean-reduced MSE criterion stored as featureLoss.
        self.criterionGAN = GANLoss(dis_cfg['gan_type']).cuda()
        self.featureLoss = nn.MSELoss(reduction='mean')

        # Style codes kept fixed across calls, used for sampling.
        n_display = int(hp['display_size'])
        self.s_a = torch.randn(n_display, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(n_display, self.style_dim, 1, 1).cuda()

        # Adam over the trainable parameters of each network pair.
        adam_kwargs = dict(lr=lr, betas=(hp['beta1'], hp['beta2']),
                           weight_decay=hp['weight_decay'])
        dis_params = [*self.dis_a.parameters(), *self.dis_b.parameters()]
        gen_params = [*self.gen_a.parameters(), *self.gen_b.parameters()]
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad], **adam_kwargs)
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad], **adam_kwargs)
        self.dis_scheduler = get_scheduler(self.dis_opt, hp)
        self.gen_scheduler = get_scheduler(self.gen_opt, hp)

        # Configured init on every submodule, then the discriminators are
        # re-initialized with 'gaussian'.
        self.apply(weights_init(hp['init']))
        for dis in (self.dis_a, self.dis_b):
            dis.apply(weights_init('gaussian'))
Beispiel #3
0
    def __init__(self, hyperparameters):
        """Create one content encoder shared by both domains, plus per-domain nets.

        Args:
            hyperparameters: configuration mapping; reads 'input_dim_a',
                'input_dim_b', 'dis' and 'gen' (with 'n_downsample', 'n_res',
                'dim', 'activ', 'pad_type', 'style_dim' and the optional
                'mid_downsample', default 1).
        """
        super(DSMAP_Trainer, self).__init__()
        gen_cfg = hyperparameters['gen']
        dim_a = hyperparameters['input_dim_a']

        # The same content encoder instance is passed to both generators, and
        # its output_dim to both discriminators.
        self.content_enc = ContentEncoder_share(
            gen_cfg['n_downsample'],
            gen_cfg.get('mid_downsample', 1),
            gen_cfg['n_res'],
            dim_a,
            gen_cfg['dim'],
            'in',
            gen_cfg['activ'],
            pad_type=gen_cfg['pad_type'])
        self.style_dim = gen_cfg['style_dim']

        dim_b = hyperparameters['input_dim_b']
        content_dim = self.content_enc.output_dim
        dis_cfg = hyperparameters['dis']
        self.gen_a = AdaINGen(dim_a, self.content_enc, 'a', gen_cfg)  # auto-encoder for domain a
        self.gen_b = AdaINGen(dim_b, self.content_enc, 'b', gen_cfg)  # auto-encoder for domain b
        self.dis_a = MsImageDis(dim_a, content_dim, dis_cfg)  # discriminator for domain a
        self.dis_b = MsImageDis(dim_b, content_dim, dis_cfg)  # discriminator for domain b
Beispiel #4
0
    def __init__(self, hyperparameters, device):
        """Build the MUNIT networks, optimizers, schedulers and optional VGG.

        Args:
            hyperparameters: configuration mapping; reads 'lr', 'beta1',
                'beta2', 'weight_decay', 'input_dim_a', 'input_dim_b',
                'gen' (with 'style_dim'), 'dis', 'display_size', 'init',
                and optionally 'vgg_w' / 'vgg_model_path'. The whole mapping
                is also handed to get_scheduler.
            device: stored as ``self.device``; the fixed sampling styles are
                moved there with ``Tensor.to``.
        """
        super(MUNIT_Trainer, self).__init__()
        hp = hyperparameters
        lr = hp['lr']
        self.device = device

        # One auto-encoder and one discriminator per domain.
        gen_cfg, dis_cfg = hp['gen'], hp['dis']
        dim_a, dim_b = hp['input_dim_a'], hp['input_dim_b']
        self.gen_a = AdaINGen(dim_a, gen_cfg)  # auto-encoder for domain a
        self.gen_b = AdaINGen(dim_b, gen_cfg)  # auto-encoder for domain b
        self.dis_a = MsImageDis(dim_a, dis_cfg)  # discriminator for domain a
        self.dis_b = MsImageDis(dim_b, dis_cfg)  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = gen_cfg['style_dim']

        # Fixed style codes for sampling: generated by torch.randn first and
        # only then moved to the target device.
        noise_shape = (int(hp['display_size']), self.style_dim, 1, 1)
        self.s_a = torch.randn(*noise_shape).to(self.device)
        self.s_b = torch.randn(*noise_shape).to(self.device)

        # Adam over the trainable parameters of each network pair.
        adam_kwargs = dict(lr=lr, betas=(hp['beta1'], hp['beta2']),
                           weight_decay=hp['weight_decay'])
        dis_params = [*self.dis_a.parameters(), *self.dis_b.parameters()]
        gen_params = [*self.gen_a.parameters(), *self.gen_b.parameters()]
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad], **adam_kwargs)
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad], **adam_kwargs)
        self.dis_scheduler = get_scheduler(self.dis_opt, hp)
        self.gen_scheduler = get_scheduler(self.gen_opt, hp)

        # Configured init everywhere, then 'gaussian' for the discriminators.
        self.apply(weights_init(hp['init']))
        for dis in (self.dis_a, self.dis_b):
            dis.apply(weights_init('gaussian'))

        # Frozen VGG network, loaded only when hyperparameters['vgg_w'] > 0.
        if 'vgg_w' in hp and hp['vgg_w'] > 0:
            self.vgg = load_vgg16(hp['vgg_model_path'] + '/models')
            self.vgg.eval()
            for weight in self.vgg.parameters():
                weight.requires_grad = False
Beispiel #5
0
    def build_model(self):
        """Create the generator/discriminator pair and their Adam optimizers.

        Reads ``self.label_dim``, ``self.gen_var``, ``self.dis_var``,
        ``self.lr``, ``self.beta1``, ``self.beta2`` and ``self.device``.
        The networks are moved to ``self.device`` *before* the optimizers are
        constructed, as the torch.optim documentation recommends, so the
        optimizers are built over the device-resident parameters.
        """
        self.G = AdaINGen(self.label_dim, self.gen_var)
        self.D = Discriminator(self.label_dim, self.dis_var)

        self.G.to(self.device)
        self.D.to(self.device)

        betas = (self.beta1, self.beta2)
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.lr, betas)
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.lr, betas)
Beispiel #6
0
    def __init__(self, hyperparameters):
        """Build the ACL-GAN generators, three discriminators, optimizers and optional VGG.

        Args:
            hyperparameters: configuration mapping; reads 'lr', 'beta1',
                'beta2', 'weight_decay', 'input_dim_a', 'input_dim_b',
                'gen' (with 'style_dim'), 'dis', 'display_size', 'alpha',
                'focus_loss', 'init', and optionally 'vgg_w' /
                'vgg_model_path'. The whole mapping is also handed to
                get_scheduler.
        """
        super(aclgan_Trainer, self).__init__()
        hp = hyperparameters
        lr = hp['lr']
        gen_cfg, dis_cfg = hp['gen'], hp['dis']

        # NOTE(review): gen_BA, dis_A and dis_B are all sized with
        # input_dim_a; only dis_2 uses input_dim_b. Confirm this is intended.
        self.gen_AB = AdaINGen(hp['input_dim_a'], gen_cfg)  # auto-encoder for domain A
        self.gen_BA = AdaINGen(hp['input_dim_a'], gen_cfg)  # auto-encoder for domain B
        self.dis_A = MsImageDis(hp['input_dim_a'], dis_cfg)  # discriminator for domain A
        self.dis_B = MsImageDis(hp['input_dim_a'], dis_cfg)  # discriminator for domain B
        self.dis_2 = MsImageDis(hp['input_dim_b'], dis_cfg)  # discriminator 2
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = gen_cfg['style_dim']

        # Three fixed noise tensors for sampling, each
        # (display_size, style_dim, 1, 1), generated in order z_1, z_2, z_3.
        n_display = int(hp['display_size'])
        self.z_1, self.z_2, self.z_3 = (
            torch.randn(n_display, self.style_dim, 1, 1).cuda() for _ in range(3))

        # Adam over the trainable parameters of all discriminators and of
        # both generators.
        adam_kwargs = dict(lr=lr, betas=(hp['beta1'], hp['beta2']),
                           weight_decay=hp['weight_decay'])
        dis_params = [*self.dis_A.parameters(), *self.dis_B.parameters(),
                      *self.dis_2.parameters()]
        gen_params = [*self.gen_AB.parameters(), *self.gen_BA.parameters()]
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad], **adam_kwargs)
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad], **adam_kwargs)
        self.dis_scheduler = get_scheduler(self.dis_opt, hp)
        self.gen_scheduler = get_scheduler(self.gen_opt, hp)
        # Scalar weights copied from the config for use by other methods.
        self.alpha = hp['alpha']
        self.focus_lam = hp['focus_loss']

        # Configured init everywhere, then 'gaussian' for every discriminator.
        self.apply(weights_init(hp['init']))
        for dis in (self.dis_A, self.dis_B, self.dis_2):
            dis.apply(weights_init('gaussian'))

        # Frozen VGG network, loaded only when hyperparameters['vgg_w'] > 0.
        if 'vgg_w' in hp and hp['vgg_w'] > 0:
            self.vgg = load_vgg16(hp['vgg_model_path'] + '/models')
            self.vgg.eval()
            for weight in self.vgg.parameters():
                weight.requires_grad = False
Beispiel #7
0
    def __init__(self, hyperparameters):
        """Set up the MUNIT networks, optimizers, schedulers and optional VGG.

        Args:
            hyperparameters: configuration mapping; reads 'lr', 'beta1',
                'beta2', 'weight_decay', 'input_dim_a', 'input_dim_b',
                'gen' (with 'style_dim'), 'dis', 'display_size', 'init',
                and optionally 'vgg_w' / 'vgg_model_path'. The whole mapping
                is also handed to get_scheduler.
        """
        super(MUNIT_Trainer, self).__init__()
        hp = hyperparameters
        lr = hp['lr']

        # One auto-encoder (AdaINGen) and one discriminator (MsImageDis) per domain.
        self.gen_a = AdaINGen(hp['input_dim_a'], hp['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hp['input_dim_b'], hp['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hp['input_dim_a'], hp['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hp['input_dim_b'], hp['dis'])  # discriminator for domain b
        # Non-affine instance norm over 512 channels.
        # NOTE(review): presumably applied to VGG features elsewhere; verify.
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hp['gen']['style_dim']

        # Two fixed random style codes (one per domain), each of shape
        # (display_size, style_dim, 1, 1), reused whenever samples are drawn.
        display_size = int(hp['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # One Adam optimizer for both discriminators, one for both generators;
        # parameters with requires_grad=False are left out.
        adam_kwargs = dict(lr=lr, betas=(hp['beta1'], hp['beta2']),
                           weight_decay=hp['weight_decay'])
        dis_params = [*self.dis_a.parameters(), *self.dis_b.parameters()]
        gen_params = [*self.gen_a.parameters(), *self.gen_b.parameters()]
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad], **adam_kwargs)
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad], **adam_kwargs)
        # Learning-rate schedules, configured by get_scheduler.
        self.dis_scheduler = get_scheduler(self.dis_opt, hp)
        self.gen_scheduler = get_scheduler(self.gen_opt, hp)

        # nn.Module.apply calls the initializer on every submodule and on self.
        # The discriminators are then re-initialized with 'gaussian'.
        self.apply(weights_init(hp['init']))
        for dis in (self.dis_a, self.dis_b):
            dis.apply(weights_init('gaussian'))

        # Frozen VGG network, loaded only when hyperparameters['vgg_w'] > 0.
        if 'vgg_w' in hp and hp['vgg_w'] > 0:
            self.vgg = load_vgg16(hp['vgg_model_path'] + '/models')
            self.vgg.eval()
            for weight in self.vgg.parameters():
                weight.requires_grad = False
Beispiel #8
0
 def __init__(self, hyperparameters):
     """Create the per-domain MUNIT generators and discriminators.

     Args:
         hyperparameters: configuration mapping; reads 'lr', 'input_dim_a',
             'input_dim_b', 'gen' (with 'style_dim') and 'dis'.
     """
     super(MUNIT, self).__init__()
     # NOTE(review): lr is read but not used in the visible part of this method.
     lr = hyperparameters['lr']
     gen_cfg = hyperparameters['gen']
     dis_cfg = hyperparameters['dis']
     dim_a = hyperparameters['input_dim_a']
     dim_b = hyperparameters['input_dim_b']

     # One auto-encoder and one discriminator per domain.
     self.gen_a = AdaINGen(dim_a, gen_cfg)  # auto-encoder for domain a
     self.gen_b = AdaINGen(dim_b, gen_cfg)  # auto-encoder for domain b
     self.dis_a = MsImageDis(dim_a, dis_cfg)  # discriminator for domain a
     self.dis_b = MsImageDis(dim_b, dis_cfg)  # discriminator for domain b
     self.instancenorm = nn.InstanceNorm2d(512, affine=False)
     self.style_dim = gen_cfg['style_dim']
Beispiel #9
0
    def __init__(self, hyperparameters):
        """Build the domain-b generator/discriminator, optimizers, schedulers and SSIM loss.

        Args:
            hyperparameters: configuration mapping; reads 'lr', 'beta1',
                'beta2', 'weight_decay', 'input_dim_b', 'new_size', 'gen'
                (with 'style_dim'), 'dis', 'reg_param', 'beta_step',
                'target_kl', 'gan_type' and 'init'. The whole mapping is also
                handed to get_scheduler.
        """
        super(MUNIT_Trainer, self).__init__()
        hp = hyperparameters
        lr = hp['lr']
        dim_b = hp['input_dim_b']

        # Only domain b is modelled here.
        self.gen_b = AdaINGen(dim_b, hp['gen'])  # auto-encoder for domain b
        self.dis_b = MsImageDis(dim_b, hp['new_size'], hp['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hp['gen']['style_dim']

        # Scalars copied from the config as same-named attributes.
        for key in ('reg_param', 'beta_step', 'target_kl', 'gan_type'):
            setattr(self, key, hp[key])

        # Adam over the trainable parameters of each network.
        adam_kwargs = dict(lr=lr, betas=(hp['beta1'], hp['beta2']),
                           weight_decay=hp['weight_decay'])
        self.dis_opt = torch.optim.Adam(
            [p for p in self.dis_b.parameters() if p.requires_grad], **adam_kwargs)
        self.gen_opt = torch.optim.Adam(
            [p for p in self.gen_b.parameters() if p.requires_grad], **adam_kwargs)
        self.dis_scheduler = get_scheduler(self.dis_opt, hp)
        self.gen_scheduler = get_scheduler(self.gen_opt, hp)

        # Configured init for the generator, 'gaussian' for the discriminator.
        self.gen_b.apply(weights_init(hp['init']))
        self.dis_b.apply(weights_init('gaussian'))

        # SSIM Loss
        self.ssim_loss = pytorch_ssim.SSIM()
Beispiel #10
0
    def __init__(self, hyperparameters, opts):
        """Build MUNIT networks plus a segmentation network, with optimizers and schedulers.

        Args:
            hyperparameters: configuration mapping; reads 'lr', 'beta1',
                'beta2', 'weight_decay', 'lr_policy', 'input_dim_a',
                'input_dim_b', 'gen' (with 'style_dim'), 'dis',
                'display_size', 'init' and 'seg' (with 'ignore_index').
            opts: stored unchanged as ``self.opts``.
        """
        super(MUNIT_Trainer, self).__init__()
        hp = hyperparameters
        lr = hp['lr']
        self.opts = opts

        # One auto-encoder and one discriminator per domain, plus a two-class
        # segmentor with input_dim_b input channels.
        dim_a, dim_b = hp['input_dim_a'], hp['input_dim_b']
        self.gen_a = AdaINGen(dim_a, hp['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(dim_b, hp['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(dim_a, hp['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(dim_b, hp['dis'])  # discriminator for domain b
        self.seg = segmentor(num_classes=2, channels=dim_b, hyperpars=hp['seg'])

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hp['gen']['style_dim']

        # Fixed style codes reused for sampling.
        n_display = int(hp['display_size'])
        self.s_a = torch.randn(n_display, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(n_display, self.style_dim, 1, 1).cuda()

        # The GAN networks use Adam over their trainable parameters; the
        # segmentor uses SGD with fixed settings.
        adam_kwargs = dict(lr=lr, betas=(hp['beta1'], hp['beta2']),
                           weight_decay=hp['weight_decay'])
        dis_params = [*self.dis_a.parameters(), *self.dis_b.parameters()]
        gen_params = [*self.gen_a.parameters(), *self.gen_b.parameters()]
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad], **adam_kwargs)
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad], **adam_kwargs)
        self.seg_opt = torch.optim.SGD(self.seg.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

        policy = hp['lr_policy']
        self.dis_scheduler = get_scheduler(self.dis_opt, policy, hyperparameters=hp)
        self.gen_scheduler = get_scheduler(self.gen_opt, policy, hyperparameters=hp)
        self.seg_scheduler = get_scheduler(self.seg_opt, 'constant', hyperparameters=None)

        # self.apply visits every registered submodule (presumably including
        # self.seg, if it is an nn.Module). The discriminators are then
        # re-initialized with 'gaussian'.
        self.apply(weights_init(hp['init']))
        for dis in (self.dis_a, self.dis_b):
            dis.apply(weights_init('gaussian'))

        self.criterion_seg = DiceLoss(ignore_index=hp['seg']['ignore_index'])
Beispiel #11
0
    def __init__(self, hyperparameters):
        """Build the AGUIT generator, image/content discriminators, optimizers and schedulers.

        Args:
            hyperparameters: configuration mapping; reads 'lr', 'input_dim',
                'beta1', 'beta2', 'weight_decay', 'init', 'dis' and 'gen'
                (with 'noise_dim', 'selected_attrs', 'dim', 'n_downsample').
                The whole mapping is also handed to get_scheduler.
        """
        super(AGUIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        gen_cfg = hyperparameters['gen']
        # Initiate the networks
        self.noise_dim = gen_cfg['noise_dim']
        # One entry per selected attribute name.
        self.attr_dim = len(gen_cfg['selected_attrs'])
        self.gen = AdaINGen(hyperparameters['input_dim'], gen_cfg)
        self.dis = MsImageDis(hyperparameters['input_dim'], self.attr_dim,
                              hyperparameters['dis'])
        # NOTE(review): the first argument is presumably the channel count of
        # the content code ('dim' doubled once per downsampling stage);
        # confirm against ContentDis.
        self.dis_content = ContentDis(
            gen_cfg['dim'] * (2**gen_cfg['n_downsample']), self.attr_dim)

        # Setup the optimizers: one Adam over both discriminators and one
        # over the generator, restricted to trainable parameters.
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis.parameters()) + list(
            self.dis_content.parameters())
        gen_params = list(self.gen.parameters())

        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization: configured scheme everywhere, then
        # 'gaussian' re-initialization of both discriminators.
        self.apply(weights_init(hyperparameters['init']))
        self.dis.apply(weights_init('gaussian'))
        self.dis_content.apply(weights_init('gaussian'))
Beispiel #12
0
    def __init__(self, param):
        """Build the single AdaINGen model with its Adam optimizer and scheduler.

        Args:
            param: configuration mapping; reads 'lr', 'input_dim_a',
                'input_dim_b', 'gen', 'beta1', 'beta2', 'weight_decay' and
                'init'. The whole mapping is also handed to get_scheduler.
        """
        super(SupIntrinsicTrainer, self).__init__()
        lr = param['lr']
        # Second AdaINGen argument is input_dim_b added to itself.
        # NOTE(review): presumably the output channel count for two
        # input_dim_b-sized outputs; confirm against AdaINGen.
        doubled_dim_b = param['input_dim_b'] + param['input_dim_b']
        self.model = AdaINGen(param['input_dim_a'], doubled_dim_b,
                              param['gen'])  # auto-encoder

        # Adam over the trainable parameters only.
        betas = (param['beta1'], param['beta2'])
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        self.gen_opt = torch.optim.Adam(trainable, lr=lr, betas=betas,
                                        weight_decay=param['weight_decay'])
        self.gen_scheduler = get_scheduler(self.gen_opt, param)

        self.apply(weights_init(param['init']))
        # Starts at +infinity. NOTE(review): presumably tracks the lowest
        # validation metric seen so far; verify against its users.
        self.best_result = float('inf')
Beispiel #13
0
    def __init__(self, hyperparameters):
        """Set up the MUNIT networks, optimizers, schedulers and optional VGG.

        Args:
            hyperparameters: configuration mapping; reads 'lr', 'beta1',
                'beta2', 'weight_decay', 'input_dim_a', 'input_dim_b',
                'gen' (with 'style_dim'), 'dis', 'init', the optional
                'display_size' (default 8) and optionally 'vgg_w' /
                'vgg_model_path'. The whole mapping is also handed to
                get_scheduler.
        """
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # Fix the noise used in sampling: one style code per displayed image.
        # The number of rows follows 'display_size' (8 when the key is absent)
        # instead of a hard-coded 8, so larger display batches can be indexed.
        # NOTE(review): assumes hyperparameters is a dict (supports .get).
        display_size = int(hyperparameters.get('display_size', 8))
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers (trainable parameters only)
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization: configured scheme everywhere, then
        # 'gaussian' re-initialization of the discriminators.
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load a frozen VGG model only when hyperparameters['vgg_w'] > 0.
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False
Beispiel #14
0
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        """Set up the MUNIT networks, optimizers, schedulers and optional VGG.

        Args:
            hyperparameters: configuration mapping; reads 'lr', 'beta1',
                'beta2', 'weight_decay', 'input_dim_a', 'input_dim_b',
                'gen' (with 'style_dim'), 'dis', 'display_size', 'init',
                and optionally 'vgg_w' / 'vgg_model_path'. The whole mapping
                is also handed to get_scheduler.
        """
        super(MUNIT_Trainer, self).__init__()
        hp = hyperparameters
        lr = hp['lr']

        # One auto-encoder (AdaINGen) and one discriminator (MsImageDis) per domain.
        self.gen_a = AdaINGen(hp['input_dim_a'], hp['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hp['input_dim_b'], hp['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hp['input_dim_a'], hp['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hp['input_dim_b'], hp['dis'])  # discriminator for domain b
        # Non-affine instance norm over 512 channels, applied to VGG features
        # in compute_vgg_loss.
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hp['gen']['style_dim']

        # Two fixed random style codes (one per domain), each of shape
        # (display_size, style_dim, 1, 1), reused by forward() and sample().
        display_size = int(hp['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # One Adam optimizer for both discriminators, one for both generators;
        # parameters with requires_grad=False are left out.
        adam_kwargs = dict(lr=lr, betas=(hp['beta1'], hp['beta2']),
                           weight_decay=hp['weight_decay'])
        dis_params = [*self.dis_a.parameters(), *self.dis_b.parameters()]
        gen_params = [*self.gen_a.parameters(), *self.gen_b.parameters()]
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad], **adam_kwargs)
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad], **adam_kwargs)
        # Learning-rate schedules, configured by get_scheduler.
        self.dis_scheduler = get_scheduler(self.dis_opt, hp)
        self.gen_scheduler = get_scheduler(self.gen_opt, hp)

        # nn.Module.apply calls the initializer on every submodule and on self.
        # The discriminators are then re-initialized with 'gaussian'.
        self.apply(weights_init(hp['init']))
        for dis in (self.dis_a, self.dis_b):
            dis.apply(weights_init('gaussian'))

        # Frozen VGG network, loaded only when hyperparameters['vgg_w'] > 0.
        if 'vgg_w' in hp and hp['vgg_w'] > 0:
            self.vgg = load_vgg16(hp['vgg_model_path'] + '/models')
            self.vgg.eval()
            for weight in self.vgg.parameters():
                weight.requires_grad = False

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)     # here the self.s_a is random style
        s_b = Variable(self.s_b)     # here the self.s_b is random style
        # 两个auto-encoder
        c_a, s_a_fake = self.gen_a.encode(x_a)  # c_a, s_a_fake is the content and style of input x_a
        c_b, s_b_fake = self.gen_b.encode(x_b)
        # x_ba 表示的是 imgb->imga
        x_ba = self.gen_a.decode(c_b, s_a)  # combine(c_b, s_a) to generate the x_ba
        x_ab = self.gen_b.decode(c_a, s_b)  # combine(c_a, s_b) to generate the x_ab
        #  训练模式
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        # 这里只负责参数的更新操作,真正的训练操作实际上来源于AdainGan 以及 MsImgDis
        # self.gen_opt 表示的是生成器的优化器
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)

        # print('content shape:', c_b.shape)
        # print('style shape:', s_b_prime.shape)
        # decode (within domain), 这个是必要的部分,因为需要保证解码器的效果
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)   # 本质上这里传过来的还是c_a, 我们希望c_a_recon与c_a越相似越好
        # decode again (if needed)
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss  为什么要用vgg呢?也就是内容不变的约束
        # 对feature map 的 L2 loss  好好理解一下这里的域不变loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Mean squared difference of instance-normalized VGG features.

        Both images are passed through ``vgg_preprocess`` and then ``vgg``.
        Each feature map is normalized with ``self.instancenorm``, and the
        mean of the squared difference is returned.
        """
        img_fea, target_fea = (
            self.instancenorm(vgg(vgg_preprocess(x))) for x in (img, target))
        return torch.mean((img_fea - target_fea) ** 2)

    def sample(self, x_a, x_b):
        """Build reconstructions and cross-domain translations for display.

        Images are processed one at a time. For each index ``idx``:
        - ``x_a[idx]`` and ``x_b[idx]`` are reconstructed with their own
          encoded styles;
        - each is translated to the other domain twice, once with row ``idx``
          of the fixed display noise (``self.s_a`` / ``self.s_b``) and once
          with freshly drawn noise.
        The module is put in eval mode while sampling and switched to train
        mode before returning.

        NOTE(review): the fixed noise has ``display_size`` rows and ``x_b`` is
        indexed with the same ``idx``. So ``x_a`` must not have more rows than
        either of them.

        Returns:
            (x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2)
        """
        self.eval()
        fixed_a = Variable(self.s_a)
        fixed_b = Variable(self.s_b)
        n_images = x_a.size(0)
        random_a = Variable(torch.randn(n_images, self.style_dim, 1, 1).cuda())
        random_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())

        names = ('a_recon', 'b_recon', 'ba1', 'ba2', 'ab1', 'ab2')
        pieces = {name: [] for name in names}
        for idx in range(n_images):
            # unsqueeze(0) turns the single image back into a batch of one
            content_a, own_style_a = self.gen_a.encode(x_a[idx].unsqueeze(0))
            content_b, own_style_b = self.gen_b.encode(x_b[idx].unsqueeze(0))
            pieces['a_recon'].append(self.gen_a.decode(content_a, own_style_a))
            pieces['b_recon'].append(self.gen_b.decode(content_b, own_style_b))
            pieces['ba1'].append(self.gen_a.decode(content_b, fixed_a[idx].unsqueeze(0)))
            pieces['ba2'].append(self.gen_a.decode(content_b, random_a[idx].unsqueeze(0)))
            pieces['ab1'].append(self.gen_b.decode(content_a, fixed_b[idx].unsqueeze(0)))
            pieces['ab2'].append(self.gen_b.decode(content_a, random_b[idx].unsqueeze(0)))

        # stitch the per-image outputs back into batches along dim 0
        out = {name: torch.cat(pieces[name]) for name in names}
        self.train()
        return (x_a, out['a_recon'], out['ab1'], out['ab2'],
                x_b, out['b_recon'], out['ba1'], out['ba2'])

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one optimisation step of both discriminators.

        Content codes of the real batches are combined with random N(0, I)
        style codes to produce fakes: ``fake_a`` is b-content in a-style and
        ``fake_b`` is a-content in b-style. Each discriminator sees its own
        domain's fake (detached, so no gradient reaches the generators) and
        its own domain's real batch. ``calc_dis_loss`` scores those two
        inputs independently; it does not compare them.

        Args:
            x_a: real batch from domain a.
            x_b: real batch from domain b.
            hyperparameters: config dict; 'gan_w' weights both adversarial
                terms.
        """
        self.dis_opt.zero_grad()
        # random style codes (not encoded from the inputs)
        style_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        style_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())

        # only the content half of each encoding is needed here
        content_a = self.gen_a.encode(x_a)[0]
        content_b = self.gen_b.encode(x_b)[0]

        # cross-domain fakes
        fake_a = self.gen_a.decode(content_b, style_a)
        fake_b = self.gen_b.decode(content_a, style_b)

        # each discriminator judges its own domain: fake vs. real
        self.loss_dis_a = self.dis_a.calc_dis_loss(fake_a.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(fake_b.detach(), x_b)

        gan_weight = hyperparameters['gan_w']
        self.loss_dis_total = gan_weight * self.loss_dis_a + gan_weight * self.loss_dis_b
        self.loss_dis_total.backward()  # compute gradients
        self.dis_opt.step()             # apply them

    def update_learning_rate(self):
        """Advance the discriminator and generator LR schedulers by one step.

        A scheduler attribute set to ``None`` is skipped.
        """
        for scheduler in (self.dis_scheduler, self.gen_scheduler):
            if scheduler is None:
                continue
            scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks and optimizers from *checkpoint_dir*.

        Generator and discriminator weights come from the files returned by
        ``get_model_list`` for "gen" and "dis" (presumably the newest
        checkpoint of each kind; confirm in ``get_model_list``). Optimizer
        state comes from 'optimizer.pt'. Both schedulers are rebuilt by
        ``get_scheduler``, which receives the restored iteration.

        Returns:
            int: iteration parsed from the generator file name: the 8
            characters before its 3-character extension, e.g.
            'gen_00001000.pt' -> 1000.
        """
        gen_path = get_model_list(checkpoint_dir, "gen")
        print('resume model: ', gen_path)
        gen_weights = torch.load(gen_path)
        for key, net in (('a', self.gen_a), ('b', self.gen_b)):
            net.load_state_dict(gen_weights[key])
        iterations = int(gen_path[-11:-3])

        dis_weights = torch.load(get_model_list(checkpoint_dir, "dis"))
        for key, net in (('a', self.dis_a), ('b', self.dis_b)):
            net.load_state_dict(dis_weights[key])

        opt_states = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(opt_states['dis'])
        self.gen_opt.load_state_dict(opt_states['gen'])

        # rebuild schedulers at the restored iteration
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations, gpuids):
        """Write generator, discriminator and optimizer checkpoints.

        Files written into *snapshot_dir*:
        - 'gen_%08d.pt' with keys 'a' and 'b' (generator state dicts);
        - 'dis_%08d.pt' with keys 'a' and 'b' (discriminator state dicts);
        - 'optimizer.pt' with keys 'gen' and 'dis' (optimizer state dicts).
        The number in the file names is ``iterations + 1``.

        Args:
            snapshot_dir: destination directory.
            iterations: zero-based iteration counter.
            gpuids: list of GPU ids in use. With more than one id, the networks
                are expected to be wrapped (presumably ``nn.DataParallel``).
                Their inner ``.module`` is then saved so the state-dict keys
                carry no wrapper prefix.
        """
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        if len(gpuids) > 1:
            gen_a, gen_b = self.gen_a.module, self.gen_b.module
            dis_a, dis_b = self.dis_a.module, self.dis_b.module
        else:
            gen_a, gen_b = self.gen_a, self.gen_b
            dis_a, dis_b = self.dis_a, self.dis_b
        torch.save({'a': gen_a.state_dict(), 'b': gen_b.state_dict()}, gen_name)
        torch.save({'a': dis_a.state_dict(), 'b': dis_b.state_dict()}, dis_name)
        # Optimizers are never wrapped and have no `.module` attribute; the old
        # multi-GPU branch raised AttributeError here.
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
Beispiel #15
0
    def __init__(self, hyperparameters):
        """Build networks, optimizers and LR schedulers from a config dict.

        Creates two AdaIN auto-encoders (gen_a / gen_b), two multi-scale
        discriminators (dis_a / dis_b) and a ContentClassifier. Each of the
        three parameter groups gets its own Adam optimizer and scheduler.

        Args:
            hyperparameters: config dict. Keys read directly here are 'lr',
                'beta1', 'beta2', 'weight_decay', 'init', 'display_size',
                'batch_size_val', 'input_dim_a', 'input_dim_b', 'gen' (its
                'dim' and 'style_dim'), 'dis' (its 'num_con_c' and 'gan_type'),
                and optionally 'vgg_w' / 'vgg_model_path'. AdaINGen,
                MsImageDis, ContentClassifier and get_scheduler may read more.

        NOTE(review): the fixed sampling noise is moved with .cuda(), so
        construction requires a CUDA device.
        """
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b

        # classifier over content codes; built from the generator's base width
        self.content_classifier = ContentClassifier(
            hyperparameters['gen']['dim'], hyperparameters)

        # normalises 512-channel feature maps (presumably VGG features for the perceptual loss)
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        # number of continuous codes (read from the discriminator config)
        self.num_con_c = hyperparameters['dis']['num_con_c']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        # dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())

        dis_named_params = list(self.dis_a.named_parameters()) + list(
            self.dis_b.named_parameters())
        # gen_named_params = list(self.gen_a.named_parameters()) + list(self.gen_b.named_parameters())

        ### modifying list params
        # Discriminator parameters whose name contains "_Q" are handed to the
        # generator optimizer; everything else stays with dis_opt.
        dis_params = list()
        # gen_params = list()
        for name, param in dis_named_params:
            if "_Q" in name:
                # print('%s --> gen_params' % name)
                gen_params.append(param)
            else:
                dis_params.append(param)

        content_classifier_params = list(self.content_classifier.parameters())

        # three Adam optimizers sharing lr / betas / weight decay; frozen
        # parameters (requires_grad == False) are left out
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.cla_opt = torch.optim.Adam(
            [p for p in content_classifier_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.cla_scheduler = get_scheduler(self.cla_opt, hyperparameters)

        # Network weight initialization
        # Everything gets the configured scheme first; the discriminators and
        # the classifier are then re-initialised with 'gaussian'.
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        self.content_classifier.apply(weights_init('gaussian'))

        # Load VGG model if needed
        # Kept in eval mode with gradients disabled, i.e. used only as a
        # frozen feature extractor.
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        self.gan_type = hyperparameters['dis']['gan_type']
        # loss used to score the continuous codes (NormalNLLLoss is defined elsewhere)
        self.criterionQ_con = NormalNLLLoss()

        self.criterion_content_classifier = nn.CrossEntropyLoss()

        # self.batch_size = hyperparameters['batch_size']
        self.batch_size_val = hyperparameters['batch_size_val']
Beispiel #16
0
class MUNIT_Trainer(nn.Module):
    """MUNIT trainer extended with a content classifier and continuous-code heads.

    Holds two AdaIN auto-encoders (``gen_a`` / ``gen_b``), two multi-scale
    discriminators (``dis_a`` / ``dis_b``) and a ``ContentClassifier`` applied
    to content codes. Each of the three parameter groups has its own Adam
    optimizer and LR scheduler.

    Discriminator parameters whose name contains ``"_Q"`` are optimised by the
    generator optimizer (``gen_opt``), not by ``dis_opt``. They are trained
    through ``compute_info_cont_loss``, which reads the ``'mu'`` / ``'var'``
    entries of the discriminator outputs.
    """

    def __init__(self, hyperparameters):
        """Build networks, optimizers and schedulers from *hyperparameters*.

        Keys read directly here are 'lr', 'beta1', 'beta2', 'weight_decay',
        'init', 'display_size', 'batch_size_val', 'input_dim_a',
        'input_dim_b', 'gen' ('dim', 'style_dim'), 'dis' ('num_con_c',
        'gan_type') and optionally 'vgg_w' / 'vgg_model_path'. The project
        helpers may read more.

        NOTE(review): the fixed sampling noise uses .cuda(), so a CUDA device
        is required.
        """
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b

        # classifier over content codes of both domains
        self.content_classifier = ContentClassifier(
            hyperparameters['gen']['dim'], hyperparameters)

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        # the last `num_con_c` style-code channels are the continuous codes
        self.num_con_c = hyperparameters['dis']['num_con_c']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        dis_named_params = list(self.dis_a.named_parameters()) + list(
            self.dis_b.named_parameters())

        # "_Q" parameters of the discriminators go to the generator optimizer;
        # every other discriminator parameter stays with dis_opt.
        dis_params = []
        for name, param in dis_named_params:
            if "_Q" in name:
                gen_params.append(param)
            else:
                dis_params.append(param)

        content_classifier_params = list(self.content_classifier.parameters())

        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.cla_opt = torch.optim.Adam(
            [p for p in content_classifier_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.cla_scheduler = get_scheduler(self.cla_opt, hyperparameters)

        # Network weight initialization: configured scheme everywhere, then
        # 'gaussian' for the discriminators and the classifier.
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        self.content_classifier.apply(weights_init('gaussian'))

        # Load VGG model if needed (frozen: eval mode, no gradients)
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        self.gan_type = hyperparameters['dis']['gan_type']
        self.criterionQ_con = NormalNLLLoss()

        self.criterion_content_classifier = nn.CrossEntropyLoss()

        # Kept for compatibility; accuracy is now computed from the real batch size.
        self.batch_size_val = hyperparameters['batch_size_val']

    def _random_style(self, batch_size, device):
        """Draw a ``(batch_size, style_dim, 1, 1)`` N(0, I) style code on *device*.

        The noise is sampled on the CPU generator and then moved, like the old
        ``torch.randn(...).cuda()``, so seeded runs see the same values. It
        follows the input's device rather than the current CUDA device.
        """
        return torch.randn(batch_size, self.style_dim, 1, 1).to(device)

    def recon_criterion(self, input, target):
        """Mean absolute (L1) difference between *input* and *target*."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate both batches with the fixed sampling styles.

        Runs in eval mode and switches back to train mode afterwards.

        Returns:
            (x_ab, x_ba): a-content in b-style, and b-content in a-style.
        """
        self.eval()
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, self.s_a)
        x_ab = self.gen_b.decode(c_a, self.s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, sample_b, hyperparameters, sample_a_limited):
        """Run one generator step (also updates the "_Q" heads owned by gen_opt).

        Args:
            x_a: domain-a batch (labels not used).
            sample_b: ``(x_b, label_b)`` labelled domain-b batch.
            hyperparameters: config dict; reads 'gan_w', 'recon_x_w',
                'recon_s_w' and 'recon_c_w'.
            sample_a_limited: ``(x, label)`` labelled domain-a batch. It is
                the only domain-a data the classifier loss is computed on.
        """
        x_b, label_b = sample_b
        x_a_limited, label_a_limited = sample_a_limited

        self.gen_opt.zero_grad()
        s_a = self._random_style(x_a.size(0), x_a.device)
        s_b = self._random_style(x_b.size(0), x_b.device)
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)

        # reconstruction losses
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)

        # adversarial losses; the same discriminator outputs also feed the
        # continuous-code loss
        x_ba_dis_out = self.dis_a(x_ba)
        self.loss_gen_adv_a = self.compute_gen_adv_loss(x_ba_dis_out)
        x_ab_dis_out = self.dis_b(x_ab)
        self.loss_gen_adv_b = self.compute_gen_adv_loss(x_ab_dis_out)

        self.info_cont_loss_a = self.compute_info_cont_loss(s_a, x_ba_dis_out)
        self.info_cont_loss_b = self.compute_info_cont_loss(s_b, x_ab_dis_out)

        # content-classifier losses, domain b: full labelled batch
        label_predict_c_b = self.content_classifier(c_b)
        label_predict_c_b_recon = self.content_classifier(c_b_recon)

        # content-classifier losses, domain a: the limited labelled batch,
        # both directly and after translation to domain b
        c_a_limited, _ = self.gen_a.encode(x_a_limited)
        label_predict_c_a_limited = self.content_classifier(c_a_limited)

        x_ab_limited = self.gen_b.decode(c_a_limited, s_b)
        c_a_recon_limited, _ = self.gen_b.encode(x_ab_limited)
        label_predict_c_a_recon_limited = self.content_classifier(
            c_a_recon_limited)

        loss_content_classifier_c_a = self.compute_content_classifier_loss(
            label_predict_c_a_limited, label_a_limited)
        loss_content_classifier_c_a_recon = self.compute_content_classifier_loss(
            label_predict_c_a_recon_limited, label_a_limited)
        loss_content_classifier_b = self.compute_content_classifier_loss(
            label_predict_c_b, label_b)
        loss_content_classifier_c_b_recon = self.compute_content_classifier_loss(
            label_predict_c_b_recon, label_b)

        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              self.info_cont_loss_a + \
                              self.info_cont_loss_b + \
                              loss_content_classifier_c_a + \
                              loss_content_classifier_c_a_recon + \
                              loss_content_classifier_b + \
                              loss_content_classifier_c_b_recon

        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_info_cont_loss(self, style_code, outs_fake):
        """Continuous-code loss summed over discriminator outputs (weight 0.1 each).

        The last ``num_con_c`` channels of *style_code* are the target. Each
        entry of *outs_fake* supplies ``'mu'`` and ``'var'``, and the target
        and those two are scored by ``self.criterionQ_con``.
        """
        num_cont_code = self.num_con_c
        # The target is identical for every output, so build it once.
        info_noise = style_code[:, -num_cont_code:].view(
            -1, num_cont_code).squeeze()
        loss = 0
        for out_fake in outs_fake:
            loss += self.criterionQ_con(info_noise, out_fake['mu'],
                                        out_fake['var']) * 0.1
        return loss

    def compute_vgg_loss(self, vgg, img, target):
        """Mean squared difference of instance-normalised VGG features."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b):
        """Per-image reconstructions and translations for display.

        Translations use the fixed noise (``*1`` outputs) and fresh noise
        (``*2`` outputs). The module is in eval mode during sampling and in
        train mode afterwards.

        Returns:
            (x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2)
        """
        self.eval()
        s_a1 = self.s_a
        s_b1 = self.s_b
        s_a2 = self._random_style(x_a.size(0), x_a.device)
        s_b2 = self._random_style(x_b.size(0), x_b.device)
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def cla_update(self, sample_a, sample_b):
        """Run one content-classifier step on labelled batches of both domains.

        The classifier is trained on the original content codes and on the
        content codes recovered after cross-domain translation. All four
        cross-entropy terms are summed.
        """
        x_a, label_a = sample_a
        x_b, label_b = sample_b
        self.cla_opt.zero_grad()
        s_a = self._random_style(x_a.size(0), x_a.device)
        s_b = self._random_style(x_b.size(0), x_b.device)
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, _ = self.gen_a.encode(x_ba)
        c_a_recon, _ = self.gen_b.encode(x_ab)

        label_predict_c_a = self.content_classifier(c_a)
        label_predict_c_a_recon = self.content_classifier(c_a_recon)
        label_predict_c_b = self.content_classifier(c_b)
        label_predict_c_b_recon = self.content_classifier(c_b_recon)

        self.loss_content_classifier_c_a = self.compute_content_classifier_loss(
            label_predict_c_a, label_a)
        self.loss_content_classifier_c_a_recon = self.compute_content_classifier_loss(
            label_predict_c_a_recon, label_a)
        self.loss_content_classifier_b = self.compute_content_classifier_loss(
            label_predict_c_b, label_b)
        self.loss_content_classifier_c_b_recon = self.compute_content_classifier_loss(
            label_predict_c_b_recon, label_b)

        self.loss_cla_total = self.loss_content_classifier_c_a + self.loss_content_classifier_c_a_recon + \
                              self.loss_content_classifier_b + self.loss_content_classifier_c_b_recon
        self.loss_cla_total.backward()
        self.cla_opt.step()

    def cla_inference(self, test_loader_a, test_loader_b):
        """Measure content-classifier accuracy over paired test loaders.

        Sets ``accu_content_classifier_c_a``, ``_c_a_recon``, ``_c_b`` and
        ``_c_b_recon`` (each the mean of per-batch accuracies) and
        ``accu_CC_all`` (the mean of those four). Only accuracies are
        produced, so the pass runs under ``torch.no_grad()``.
        """
        accu_c_a, accu_c_a_recon = [], []
        accu_c_b, accu_c_b_recon = [], []
        with torch.no_grad():
            for samples_a_test, samples_b_test in zip(test_loader_a,
                                                      test_loader_b):
                x_a, label_a = samples_a_test[0].cuda(), samples_a_test[1].cuda()
                x_b, label_b = samples_b_test[0].cuda(), samples_b_test[1].cuda()

                s_a = self._random_style(x_a.size(0), x_a.device)
                s_b = self._random_style(x_b.size(0), x_b.device)

                # encode, translate across domains, encode again
                c_a, _ = self.gen_a.encode(x_a)
                c_b, _ = self.gen_b.encode(x_b)
                x_ba = self.gen_a.decode(c_b, s_a)
                x_ab = self.gen_b.decode(c_a, s_b)
                c_b_recon, _ = self.gen_a.encode(x_ba)
                c_a_recon, _ = self.gen_b.encode(x_ab)

                label_predict_c_a = self.content_classifier(c_a)
                label_predict_c_a_recon = self.content_classifier(c_a_recon)
                label_predict_c_b = self.content_classifier(c_b)
                label_predict_c_b_recon = self.content_classifier(c_b_recon)

                accu_c_a.append(self.compute_content_classifier_accuracy(
                    label_predict_c_a, label_a))
                accu_c_a_recon.append(self.compute_content_classifier_accuracy(
                    label_predict_c_a_recon, label_a))
                accu_c_b.append(self.compute_content_classifier_accuracy(
                    label_predict_c_b, label_b))
                accu_c_b_recon.append(self.compute_content_classifier_accuracy(
                    label_predict_c_b_recon, label_b))

        self.accu_content_classifier_c_a = self.mean_list(accu_c_a)
        self.accu_content_classifier_c_a_recon = self.mean_list(accu_c_a_recon)
        self.accu_content_classifier_c_b = self.mean_list(accu_c_b)
        self.accu_content_classifier_c_b_recon = self.mean_list(accu_c_b_recon)

        self.accu_CC_all = self.mean_list([
            self.accu_content_classifier_c_a,
            self.accu_content_classifier_c_a_recon,
            self.accu_content_classifier_c_b,
            self.accu_content_classifier_c_b_recon
        ])

    @staticmethod
    def mean_list(lst):
        """Arithmetic mean of a non-empty sequence."""
        return sum(lst) / len(lst)

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one optimisation step of both discriminators.

        The fakes are detached from the generators anyway, so the generator
        forward pass runs under ``torch.no_grad()`` and no graph is built.
        """
        self.dis_opt.zero_grad()
        s_a = self._random_style(x_a.size(0), x_a.device)
        s_b = self._random_style(x_b.size(0), x_b.device)
        with torch.no_grad():
            # encode
            c_a, _ = self.gen_a.encode(x_a)
            c_b, _ = self.gen_b.encode(x_b)
            # decode (cross domain)
            x_ba = self.gen_a.decode(c_b, s_a)
            x_ab = self.gen_b.decode(c_a, s_b)

        # D loss: each discriminator judges its own domain's fake vs. real
        x_ba_dis_out = self.dis_a(x_ba)
        x_a_dis_out = self.dis_a(x_a)
        self.loss_dis_a = self.compute_dis_loss(x_ba_dis_out, x_a_dis_out)
        x_ab_dis_out = self.dis_b(x_ab)
        x_b_dis_out = self.dis_b(x_b)
        self.loss_dis_b = self.compute_dis_loss(x_ab_dis_out, x_b_dis_out)

        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
                              hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def compute_content_classifier_loss(self, label_predict, label_true):
        """Cross-entropy between classifier logits and integer labels."""
        return self.criterion_content_classifier(label_predict, label_true)

    def compute_content_classifier_two_predictions_loss(
            self, label_predict_1, label_predict_2):
        """Mean absolute difference between two prediction tensors."""
        return torch.mean(torch.abs(label_predict_1 - label_predict_2))

    def compute_content_classifier_accuracy(self, label_predict, label_true):
        """Fraction of rows whose arg-max prediction equals the true label.

        The denominator is the actual number of labels in the batch, so a
        short final batch is scored correctly. It used to be
        ``batch_size_val``. An empty batch scores 0.0.
        """
        predicted = label_predict.argmax(1)
        total = label_true.numel()
        if total == 0:
            return 0.0
        correct = (label_true == predicted).sum().item()
        return float(correct) / float(total)

    def compute_dis_loss(self, outs_fake, outs_real):
        """Discriminator loss summed over paired (fake, real) outputs.

        Each output is a dict whose ``'output_d'`` tensor is scored according
        to ``self.gan_type`` ('lsgan' or 'nsgan').

        Raises:
            ValueError: for any other ``gan_type``.
        """
        loss = 0
        for out_fake, out_real in zip(outs_fake, outs_real):
            fake = out_fake['output_d']
            real = out_real['output_d']
            if self.gan_type == 'lsgan':
                loss += torch.mean(fake ** 2) + torch.mean((real - 1) ** 2)
            elif self.gan_type == 'nsgan':
                # targets are created on the same device/dtype as the scores
                all0 = torch.zeros_like(fake)
                all1 = torch.ones_like(real)
                loss += torch.mean(
                    F.binary_cross_entropy(torch.sigmoid(fake), all0) +
                    F.binary_cross_entropy(torch.sigmoid(real), all1))
            else:
                raise ValueError("Unsupported GAN type: {}".format(self.gan_type))
        return loss

    def compute_gen_adv_loss(self, outs_fake):
        """Generator adversarial loss summed over discriminator outputs.

        Raises:
            ValueError: for a ``gan_type`` other than 'lsgan' or 'nsgan'.
        """
        loss = 0
        for out_fake in outs_fake:
            fake = out_fake['output_d']
            if self.gan_type == 'lsgan':
                loss += torch.mean((fake - 1) ** 2)  # LSGAN
            elif self.gan_type == 'nsgan':
                all1 = torch.ones_like(fake)
                loss += torch.mean(
                    F.binary_cross_entropy(torch.sigmoid(fake), all1))
            else:
                raise ValueError("Unsupported GAN type: {}".format(self.gan_type))
        return loss

    def update_learning_rate(self):
        """Step every LR scheduler that is set: discriminator, generator, classifier.

        The classifier scheduler is created in ``__init__`` and rebuilt in
        ``resume``. It used to be left out here, so the classifier learning
        rate never followed the schedule.
        """
        for scheduler in (self.dis_scheduler, self.gen_scheduler,
                          getattr(self, 'cla_scheduler', None)):
            if scheduler is not None:
                scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks, classifier and optimizers; returns the iteration."""
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        # 'gen_%08d.pt' -> the 8 digits before the '.pt' extension
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load content classifier
        last_model_name = get_model_list(checkpoint_dir, "con_cla")
        state_dict = torch.load(last_model_name)
        self.content_classifier.load_state_dict(state_dict['con_cla'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        self.cla_opt.load_state_dict(state_dict['con_cla'])
        # Reinitialize schedulers at the restored iteration
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        self.cla_scheduler = get_scheduler(self.cla_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators, classifier and optimizers."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        con_cla_name = os.path.join(snapshot_dir,
                                    'con_cla_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')

        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save({'con_cla': self.content_classifier.state_dict()},
                   con_cla_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict(),
                'con_cla': self.cla_opt.state_dict()
            }, opt_name)
Beispiel #17
0
class aclgan_Trainer(nn.Module):
    """Trainer with two AdaIN generators, two image discriminators and one
    pair discriminator, optionally producing attention ("focus") masks.

    Networks built in ``__init__``:
      * ``gen_AB`` / ``gen_BA``: AdaINGen auto-encoders (both use
        ``input_dim_a`` channels).
      * ``dis_A`` / ``dis_B``: MsImageDis on single images.
      * ``dis_2``: MsImageDis on (x_a, fake) pairs concatenated along the
        channel axis (dim -3).

    When ``hyperparameters['focus_loss'] > 0``, decoder outputs are split with
    ``.split(3, 1)`` into 3 image channels plus mask channel(s). They are
    blended with ``focus_translation``, which assumes a single mask channel.
    """

    def __init__(self, hyperparameters):
        """Build networks, fixed display noise, optimizers and schedulers.

        Args:
            hyperparameters: configuration dict. Keys read here: 'lr',
                'input_dim_a', 'input_dim_b', 'gen' (with 'style_dim'),
                'dis', 'display_size', 'beta1', 'beta2', 'weight_decay',
                'alpha', 'focus_loss', 'init', and optionally 'vgg_w' /
                'vgg_model_path'. get_scheduler may read more.
        """
        # Must name this class: the previous code passed the undefined name
        # aclmaskpermgidtno_Trainer, which raised NameError on construction.
        super(aclgan_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_AB = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain A
        self.gen_BA = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain B
        self.dis_A = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain A
        self.dis_B = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain B
        # dis_2 sees channel-concatenated (x_a, fake) pairs, so input_dim_b is
        # presumably twice the image channel count -- TODO confirm in config.
        self.dis_2 = MsImageDis(hyperparameters['input_dim_b'],
                                hyperparameters['dis'])  # discriminator 2
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # Fixed noise (one code per displayed image), reused by sample() so
        # successive visualisations are comparable.
        display_size = int(hyperparameters['display_size'])
        self.z_1 = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.z_2 = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.z_3 = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_A.parameters()) + list(
            self.dis_B.parameters()) + list(self.dis_2.parameters())
        gen_params = list(self.gen_AB.parameters()) + list(
            self.gen_BA.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        # Scale applied to z_2 when producing x_A_fake in gen/dis updates.
        self.alpha = hyperparameters['alpha']
        # > 0 makes sample() split decoder outputs into image + focus mask.
        self.focus_lam = hyperparameters['focus_loss']

        # Network weight initialization (discriminators re-initialised with
        # 'gaussian' after the global init).
        self.apply(weights_init(hyperparameters['init']))
        self.dis_A.apply(weights_init('gaussian'))
        self.dis_B.apply(weights_init('gaussian'))
        self.dis_2.apply(weights_init('gaussian'))

        # Load a frozen VGG model only if the perceptual loss is enabled.
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        """Mean absolute (L1) error between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Run the translation pipeline once and cache the results on ``self``.

        Stores ``x_B_fake``, ``x_A_fake``, ``x_A_recon``, ``x_B_recon``,
        ``x_A2_fake``, ``X_A_A1_pair`` and ``X_A_A2_pair``. Returns None.

        NOTE(review): unlike gen_update/dis_update, this path neither scales
        z_2 by ``self.alpha`` nor splits/blends focus masks. It presumably
        only matches training when focus_loss == 0 and alpha == 1; confirm
        before relying on it.
        """
        z_1 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        z_2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        z_3 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_1, _ = self.gen_AB.encode(x_a)
        c_2, s_2 = self.gen_BA.encode(x_a)
        c_4, s_4 = self.gen_AB.encode(x_b)
        # decode
        self.x_B_fake = self.gen_AB.decode(c_1, z_1)
        self.x_A_fake = self.gen_BA.decode(c_2, z_2)
        # recon
        self.x_A_recon = self.gen_BA.decode(c_2, s_2)
        self.x_B_recon = self.gen_AB.decode(c_4, s_4)
        # encode 2
        c_3, _ = self.gen_BA.encode(self.x_B_fake)
        self.x_A2_fake = self.gen_BA.decode(c_3, z_3)

        self.X_A_A1_pair = torch.cat((x_a, self.x_A_fake), -3)
        self.X_A_A2_pair = torch.cat((x_a, self.x_A2_fake), -3)

    def focus_translation(self, x_fg, x_bg, x_focus):
        """Blend ``x_fg`` over ``x_bg`` using ``x_focus`` as a soft mask.

        ``x_focus`` is mapped linearly with (x + 1) / 2, which assumes values
        in [-1, 1]. It is then repeated to 3 channels along dim 1, so it is
        expected to have a single channel. Returns
        ``x_fg * m + x_bg * (1 - m)``.
        """
        x_map = (x_focus + 1) / 2
        x_map = x_map.repeat(1, 3, 1, 1)
        return torch.mul(x_fg, x_map) + torch.mul(x_bg, 1 - x_map)

    def gen_update(self, x_a, x_b, hyperparameters):
        """Run one optimisation step for gen_AB and gen_BA.

        Losses stored on ``self`` for logging:
          * adversarial: dis_A on x_A_fake and x_A2_fake (averaged), dis_B on
            x_B_fake, dis_2 (``calc_gen_d2_loss``) on the two (x_a, fake)
            pairs;
          * when ``focus_loss > 0``: mask-size and mask-binarisation
            penalties for the three focus masks;
          * L1 identity reconstruction of x_a (gen_BA) and x_b (gen_AB).
        """
        self.gen_opt.zero_grad()

        focus_delta = hyperparameters['focus_delta']
        focus_lambda = hyperparameters['focus_loss']
        focus_lower = hyperparameters['focus_lower']
        focus_upper = hyperparameters['focus_upper']
        focus_epsilon = hyperparameters['focus_epsilon']
        # forward: fresh noise for each of the three translations
        z_1 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        z_2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        z_3 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_1, _ = self.gen_AB.encode(x_a)
        c_2, s_2 = self.gen_BA.encode(x_a)
        c_4, s_4 = self.gen_AB.encode(x_b)
        # decode
        if focus_lambda > 0:
            x_B_fake, x_B_focus = self.gen_AB.decode(c_1, z_1).split(3, 1)
            x_A_fake, x_A_focus = self.gen_BA.decode(
                c_2, self.alpha * z_2).split(3, 1)
            x_B_fake = self.focus_translation(x_B_fake, x_a, x_B_focus)
            x_A_fake = self.focus_translation(x_A_fake, x_a, x_A_focus)
            # Identity reconstructions are not focus-blended; their mask
            # channels are discarded.
            x_A_recon, _ = self.gen_BA.decode(c_2, s_2).split(3, 1)
            x_B_recon, _ = self.gen_AB.decode(c_4, s_4).split(3, 1)
        else:
            x_B_fake = self.gen_AB.decode(c_1, z_1)
            x_A_fake = self.gen_BA.decode(c_2, self.alpha * z_2)
            # recon
            x_A_recon = self.gen_BA.decode(c_2, s_2)
            x_B_recon = self.gen_AB.decode(c_4, s_4)

        # encode 2: translate the B-fake back with fresh noise z_3
        c_3, _ = self.gen_BA.encode(x_B_fake)
        if focus_lambda > 0:
            x_A2_fake, x_A2_focus = self.gen_BA.decode(c_3, z_3).split(3, 1)
            x_A2_fake = self.focus_translation(x_A2_fake, x_B_fake, x_A2_focus)
        else:
            x_A2_fake = self.gen_BA.decode(c_3, z_3)

        # Pairs are concatenated along the channel axis (dim -3) for dis_2.
        x_A_A1_pair = torch.cat((x_a, x_A_fake), -3)
        x_A_A2_pair = torch.cat((x_a, x_A2_fake), -3)

        # GAN loss
        self.loss_gen_adv_A = (self.dis_A.calc_gen_loss(x_A_fake) +
                               self.dis_A.calc_gen_loss(x_A2_fake)) * 0.5
        self.loss_gen_adv_B = self.dis_B.calc_gen_loss(x_B_fake)
        self.loss_gen_adv_2 = self.dis_2.calc_gen_d2_loss(
            x_A_A1_pair, x_A_A2_pair)

        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_A + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_B + \
                              hyperparameters['gan_cw'] * self.loss_gen_adv_2
        if focus_lambda > 0:
            def size_loss(mask):
                # Squared penalty when the summed mask exceeds
                # N * focus_upper or falls below N * focus_lower
                # (N = number of mask elements).
                return (F.relu(torch.sum(mask - focus_upper), inplace=True) ** 2) * focus_delta + \
                    (F.relu(torch.sum(focus_lower - mask), inplace=True) ** 2) * focus_delta

            def digit_loss(mask):
                # Largest when mask values sit at 0.5, so minimising it
                # pushes them towards 0 or 1.
                return torch.sum(1 / (torch.abs(mask - 0.5) + focus_epsilon))

            # Map masks from [-1, 1] to [0, 1], as focus_translation does.
            x_B_focus = (x_B_focus + 1) / 2
            x_A_focus = (x_A_focus + 1) / 2
            x_A2_focus = (x_A2_focus + 1) / 2
            self.loss_gen_focus_B_size = size_loss(x_B_focus)
            self.loss_gen_focus_B_digit = digit_loss(x_B_focus)
            self.loss_gen_focus_A_size = size_loss(x_A_focus)
            self.loss_gen_focus_A_digit = digit_loss(x_A_focus)
            self.loss_gen_focus_A2_size = size_loss(x_A2_focus)
            self.loss_gen_focus_A2_digit = digit_loss(x_A2_focus)
            # Normalised by H * W * batch * 3.
            self.loss_gen_total += focus_lambda * (
                self.loss_gen_focus_B_size + self.loss_gen_focus_B_digit +
                self.loss_gen_focus_A_size + self.loss_gen_focus_A_digit +
                self.loss_gen_focus_A2_size + self.loss_gen_focus_A2_digit
            ) / x_a.size(2) / x_a.size(3) / x_a.size(0) / 3
        self.loss_idt_A = self.recon_criterion(x_A_recon, x_a)
        self.loss_idt_B = self.recon_criterion(x_B_recon, x_b)
        self.loss_gen_total += hyperparameters['recon_x_w'] * self.loss_idt_A + \
                              hyperparameters['recon_x_w'] * self.loss_idt_B

        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Mean squared distance between instance-normalised VGG features."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b):
        """Translate each image of the batch with the fixed display noise.

        Runs in eval mode and switches back to train mode before returning.
        Indexes z_1/z_2/z_3 per image, so the batch size must not exceed
        ``display_size``.

        Returns:
            focus_lam > 0: (x_A, x_A_fake, mask_A, x_B_fake, mask_B,
                x_A2_fake, mask_A2, x_A_recon, mask_recon)
            otherwise: (x_A, x_A_fake, x_B_fake, x_A2_fake, x_A_recon, x_B,
                x_B_recon)
        """
        self.eval()
        z_1 = Variable(self.z_1)
        z_2 = Variable(self.z_2)
        z_3 = Variable(self.z_3)
        x_A, x_B, x_A_fake, x_B_fake, x_A2_fake = [], [], [], [], []
        if self.focus_lam > 0:
            mask_A, mask_B, mask_A2, mask_recon = [], [], [], []
            x_A_recon = []
        else:
            x_A_recon, x_B_recon = [], []
        for i in range(x_a.size(0)):
            x_a_i = x_a[i].unsqueeze(0)
            x_b_i = x_b[i].unsqueeze(0)
            x_A.append(x_a_i)
            x_B.append(x_b_i)
            if self.focus_lam > 0:
                c_1, s_1 = self.gen_BA.encode(x_a_i)
                img, mask = self.gen_BA.decode(c_1, z_1[i].unsqueeze(0)).split(
                    3, 1)
                x_A_fake.append(self.focus_translation(img, x_a_i, mask))
                mask_A.append(mask)

                # The reconstruction is shown raw (not focus-blended).
                img, mask = self.gen_BA.decode(c_1, s_1).split(3, 1)
                x_A_recon.append(img)
                mask_recon.append(mask)

                c_2, _ = self.gen_AB.encode(x_a_i)
                x_b_img, mask = self.gen_AB.decode(
                    c_2, z_2[i].unsqueeze(0)).split(3, 1)
                x_b_img = self.focus_translation(x_b_img, x_a_i, mask)
                x_B_fake.append(x_b_img)
                mask_B.append(mask)

                c_3, _ = self.gen_BA.encode(x_b_img)
                img, mask = self.gen_BA.decode(c_3, z_3[i].unsqueeze(0)).split(
                    3, 1)
                x_A2_fake.append(self.focus_translation(img, x_b_img, mask))
                mask_A2.append(mask)

            else:
                c_1, s_1 = self.gen_BA.encode(x_a_i)
                x_A_fake.append(self.gen_BA.decode(c_1, z_1[i].unsqueeze(0)))
                x_A_recon.append(self.gen_BA.decode(c_1, s_1))

                c_2, _ = self.gen_AB.encode(x_a_i)
                x_B1 = self.gen_AB.decode(c_2, z_2[i].unsqueeze(0))
                x_B_fake.append(x_B1)

                c_3, _ = self.gen_BA.encode(x_B1)
                x_A2_fake.append(self.gen_BA.decode(c_3, z_3[i].unsqueeze(0)))

                # Encode only the i-th x_b image; encoding the whole batch
                # here made x_B_recon batch**2 rows long.
                c_4, s_4 = self.gen_AB.encode(x_b_i)
                x_B_recon.append(self.gen_AB.decode(c_4, s_4))

        if self.focus_lam > 0:
            x_A, x_B = torch.cat(x_A), torch.cat(x_B)
            x_A_fake, x_B_fake = torch.cat(x_A_fake), torch.cat(x_B_fake)
            mask_A, x_A2_fake = torch.cat(mask_A), torch.cat(x_A2_fake)
            mask_B, mask_recon = torch.cat(mask_B), torch.cat(mask_recon)
            mask_A2, x_A_recon = torch.cat(mask_A2), torch.cat(x_A_recon)
            self.train()
            return x_A, x_A_fake, mask_A, x_B_fake, mask_B, x_A2_fake, mask_A2, x_A_recon, mask_recon

        else:
            x_A, x_B = torch.cat(x_A), torch.cat(x_B)
            x_A_fake, x_B_fake = torch.cat(x_A_fake), torch.cat(x_B_fake)
            x_A_recon, x_A2_fake = torch.cat(x_A_recon), torch.cat(x_A2_fake)
            x_B_recon = torch.cat(x_B_recon)
            self.train()
            return x_A, x_A_fake, x_B_fake, x_A2_fake, x_A_recon, x_B, x_B_recon

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one optimisation step for dis_A, dis_B and dis_2.

        Regenerates the fakes the same way gen_update does (same noise
        layout, alpha scaling and focus blending). dis_A compares both
        A-fakes against x_a, dis_B compares x_B_fake against x_b, and dis_2
        is given the (x_a, A1) pair first and the (x_a, A2) pair second.
        """
        self.dis_opt.zero_grad()

        focus_lambda = hyperparameters['focus_loss']
        # forward
        z_1 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        z_2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        z_3 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_1, _ = self.gen_AB.encode(x_a)
        c_2, s_2 = self.gen_BA.encode(x_a)
        c_4, s_4 = self.gen_AB.encode(x_b)
        # decode
        if focus_lambda > 0:
            x_B_fake, x_B_focus = self.gen_AB.decode(c_1, z_1).split(3, 1)
            x_A_fake, x_A_focus = self.gen_BA.decode(
                c_2, self.alpha * z_2).split(3, 1)
            x_B_fake = self.focus_translation(x_B_fake, x_a, x_B_focus)
            x_A_fake = self.focus_translation(x_A_fake, x_a, x_A_focus)
        else:
            x_B_fake = self.gen_AB.decode(c_1, z_1)
            x_A_fake = self.gen_BA.decode(c_2, self.alpha * z_2)

        # encode 2
        c_3, _ = self.gen_BA.encode(x_B_fake)
        if focus_lambda > 0:
            x_A2_fake, x_A2_focus = self.gen_BA.decode(c_3, z_3).split(3, 1)
            x_A2_fake = self.focus_translation(x_A2_fake, x_B_fake, x_A2_focus)
        else:
            x_A2_fake = self.gen_BA.decode(c_3, z_3)

        x_A_A1_pair = torch.cat((x_a, x_A_fake), -3)
        x_A_A2_pair = torch.cat((x_a, x_A2_fake), -3)

        # D loss
        self.loss_dis_A = (self.dis_A.calc_dis_loss(x_A_fake, x_a) +
                           self.dis_A.calc_dis_loss(x_A2_fake, x_a)) * 0.5
        self.loss_dis_B = self.dis_B.calc_dis_loss(x_B_fake, x_b)
        self.loss_dis_2 = self.dis_2.calc_dis_loss(x_A_A1_pair, x_A_A2_pair)

        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_A + \
                                hyperparameters['gan_w'] * self.loss_dis_B + \
                                hyperparameters['gan_cw'] * self.loss_dis_2

        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Step both LR schedulers, skipping any that get_scheduler left as None."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Load the latest gen/dis/optimizer checkpoints from ``checkpoint_dir``.

        The iteration count is parsed from the 8 digits before '.pt' in the
        latest generator file name. Schedulers are rebuilt from that count.
        Returns the iteration count.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_AB.load_state_dict(state_dict['AB'])
        self.gen_BA.load_state_dict(state_dict['BA'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_A.load_state_dict(state_dict['A'])
        self.dis_B.load_state_dict(state_dict['B'])
        self.dis_2.load_state_dict(state_dict['2'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators (per-step files) and optimizers."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save(
            {
                'AB': self.gen_AB.state_dict(),
                'BA': self.gen_BA.state_dict()
            }, gen_name)
        torch.save(
            {
                'A': self.dis_A.state_dict(),
                'B': self.dis_B.state_dict(),
                '2': self.dis_2.state_dict()
            }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
Beispiel #18
0
    def __init__(self, hyperparameters, gpu_ids=[0]):
        """Build the DG-Net generator, discriminator, ID encoder, optional
        teacher models, optimizers and schedulers.

        Args:
            hyperparameters: configuration dict. Missing 'apex', 'ID_stride'
                and 'T_w' keys are filled in *in place* with False, 2 and 1,
                so later code can read them.
            gpu_ids: not used by this constructor (kept for interface
                compatibility).
        """
        super(DGNet_Trainer, self).__init__()
        lr_g = hyperparameters['lr_g']  # generator learning rate
        lr_d = hyperparameters['lr_d']  # discriminator learning rate
        ID_class = hyperparameters['ID_class']
        hyperparameters.setdefault('apex', False)
        self.fp16 = hyperparameters['apex']
        # Initiate the networks
        # We do not need to manually set fp16 in the network for the new apex. So here I set fp16=False.
        # gen_b is the same module as gen_a (shared encoder/decoder weights).
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'], fp16=False)  # auto-encoder for domain a
        self.gen_b = self.gen_a  # auto-encoder for domain b

        # ID_stride: stride passed to the 'AB' appearance encoder (default 2).
        hyperparameters.setdefault('ID_stride', 2)
        # Appearance (ID) encoder
        id_style = hyperparameters['ID_style']
        if id_style == 'PCB':
            self.id_a = PCB(ID_class)
        elif id_style == 'AB':
            self.id_a = ft_netAB(ID_class, stride=hyperparameters['ID_stride'],
                                 norm=hyperparameters['norm_id'], pool=hyperparameters['pool'])
        else:
            self.id_a = ft_net(ID_class, norm=hyperparameters['norm_id'],
                               pool=hyperparameters['pool'])  # return 2048 now
        # id_b is the same object as id_a, not a copy.
        self.id_b = self.id_a

        # Discriminator on 3-channel images; dis_b shares dis_a's module.
        self.dis_a = MsImageDis(3, hyperparameters['dis'], fp16=False)  # discriminator for domain a
        self.dis_b = self.dis_a  # discriminator for domain b

        # Load teacher models (comma-separated names in hyperparameters['teacher']).
        if hyperparameters['teacher'] != "":
            teacher_name = hyperparameters['teacher']
            print(teacher_name)
            teacher_model = nn.ModuleList()
            for name in teacher_name.split(','):
                config_tmp = load_config(name)
                stride = config_tmp['stride'] if 'stride' in config_tmp else 2
                model_tmp = ft_net(ID_class, stride=stride)
                teacher_model_tmp = load_network(model_tmp, name)
                teacher_model_tmp.model.fc = nn.Sequential()  # remove the original fc layer in ImageNet
                teacher_model_tmp = teacher_model_tmp.cuda()
                if self.fp16:
                    teacher_model_tmp = amp.initialize(teacher_model_tmp, opt_level="O1")
                teacher_model.append(teacher_model_tmp.cuda().eval())
            self.teacher_model = teacher_model
            # Optionally apply the project's train_bn hook to every teacher
            # submodule (presumably re-enables BatchNorm training -- TODO confirm).
            if hyperparameters['train_bn']:
                self.teacher_model = self.teacher_model.apply(train_bn)
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # RGB to one channel
        if hyperparameters['single'] == 'edge':
            self.single = to_edge
        else:
            self.single = to_gray(False)

        # Random Erasing when training; erasing_p is its probability (default 0).
        self.erasing_p = hyperparameters.get('erasing_p', 0)
        self.single_re = RandomErasing(probability=self.erasing_p, mean=[0.0, 0.0, 0.0])

        hyperparameters.setdefault('T_w', 1)
        # Setup the optimizers. gen_b / dis_b alias gen_a / dis_a, so only
        # one set of parameters is collected for each.
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters())
        gen_params = list(self.gen_a.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr_d, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr_g, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        # ID encoder optimizer: classifier heads train at 10x the base
        # learning rate lr2; every other id_a parameter uses lr2.
        if id_style == 'PCB':
            id_heads = [self.id_a.classifier0, self.id_a.classifier1,
                        self.id_a.classifier2, self.id_a.classifier3]
        elif id_style == 'AB':
            id_heads = [self.id_a.classifier1, self.id_a.classifier2]
        else:
            id_heads = [self.id_a.classifier]
        head_param_ids = {id(p) for head in id_heads for p in head.parameters()}
        base_params = [p for p in self.id_a.parameters() if id(p) not in head_param_ids]
        lr2 = hyperparameters['lr2']
        self.id_opt = torch.optim.SGD(
            [{'params': base_params, 'lr': lr2}] +
            [{'params': head.parameters(), 'lr': lr2 * 10} for head in id_heads],
            weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True)

        # Learning-rate schedulers for each optimizer.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.id_scheduler = get_scheduler(self.id_opt, hyperparameters)
        # get_scheduler may return None (other trainers None-check it in
        # update_learning_rate); only override gamma when a scheduler exists.
        if self.id_scheduler is not None:
            self.id_scheduler.gamma = hyperparameters['gamma2']

        # ID Loss
        self.id_criterion = nn.CrossEntropyLoss()
        # reduction='sum' is the non-deprecated equivalent of size_average=False.
        self.criterion_teacher = nn.KLDivLoss(reduction='sum')
        # Load VGG model if needed
        if 'vgg_w' in hyperparameters and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        # save memory
        if self.fp16:
            # Name the FP16_Optimizer instance to replace the existing optimizer
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.id_a = self.id_a.cuda()

            self.gen_b = self.gen_a
            self.dis_b = self.dis_a
            self.id_b = self.id_a

            self.gen_a, self.gen_opt = amp.initialize(self.gen_a, self.gen_opt, opt_level="O1")
            self.dis_a, self.dis_opt = amp.initialize(self.dis_a, self.dis_opt, opt_level="O1")
            self.id_a, self.id_opt = amp.initialize(self.id_a, self.id_opt, opt_level="O1")
Beispiel #19
0
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        '''
            input_dim_a和input_dim_b是输入图像的维度,RGB图就是3
            gen和dis是在yaml中定义的与架构相关的配置
        '''

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        '''
            为每幅显示的图像(总共16幅)配置随机的风格(维度为8)
        '''

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        '''
            这种简洁的写法值得学习:先将parameter()的list并起来,然后[p for p in params if p.requires_grad]
            这里分别为判别器参数、生成器参数各自建立一个优化器
            优化器采用Adam,算法参数为0.5和0.999
            优化器中可同时配置权重衰减,这里是1e-4
            学习率调节器默认配置为每100000步减小为0.5
        '''

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
        '''
            注:这个apply函数递归地对每个子模块应用某种函数
        '''

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False
        '''默认配置中,没有使用这个vgg网络'''

    def recon_criterion(self, input, target):
        return torch.mean(torch.abs(input - target))

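    #   Note: forward() temporarily switches the model to eval mode and then back to train mode; it is not called by gen_update, and its callers are not visible here.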
    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1,
                                   1).cuda())  #   两个随机的风格码
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)  #   prime表示是由真图解码来的风格码
        c_b, s_b_prime = self.gen_b.encode(
            x_b)  #   c码为(1,256,64,64);s码为(1,8,1,1)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)  #   (a)用内容码和风格码还原原图
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)  #   (b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(
            c_a_recon,
            s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            c_b_recon,
            s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)  #   (a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)  #   (b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)  #   (c)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)  #   (d)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)  #   (e)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)  #   (f)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0  #   (g)
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0  #   (h)
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)  #   (i)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)  #   (j)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba,
            x_b) if hyperparameters['vgg_w'] > 0 else 0  #   (k)
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab,
            x_a) if hyperparameters['vgg_w'] > 0 else 0  #   (l)
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Perceptual loss between ``img`` and ``target``.

        Both images are passed through ``vgg_preprocess`` and then ``vgg``;
        the result is the mean squared difference of the two feature maps
        after ``self.instancenorm`` has been applied to each.
        """
        prepared = [vgg_preprocess(image) for image in (img, target)]
        img_feat, target_feat = [vgg(batch) for batch in prepared]
        norm = self.instancenorm
        squared_gap = (norm(img_feat) - norm(target_feat)) ** 2
        return torch.mean(squared_gap)

    # Feed in two image batches (domains a and b) and swap their styles.
    def sample(self, x_a, x_b):
        """Translate two batches image-by-image for visualisation.

        For each index ``i``:

        - the a-image and b-image are each encoded with their own generator;
        - each is reconstructed from its own content and style code;
        - each is decoded into the other domain twice, once with the fixed
          style codes ``self.s_a`` / ``self.s_b`` and once with style codes
          freshly drawn on this call.

        Returns:
            ``(x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2)``,
            where suffix ``1`` uses the fixed styles and ``2`` the random ones.

        Inference runs in eval mode without autograd, so no graph is built
        for the many per-image decodes. Training mode is restored afterwards,
        even if decoding raises.
        """
        self.eval()
        try:
            with torch.no_grad():
                # Fixed style codes, shared across calls.
                s_a1 = self.s_a
                s_b1 = self.s_b
                # Random style codes drawn fresh for this call.
                s_a2 = torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()
                s_b2 = torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()
                x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
                for i in range(x_a.size(0)):
                    # Encode the i-th input of each domain (kept as a batch of one).
                    c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
                    c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
                    # Reconstruct each input from its own content and style.
                    x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
                    x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
                    # b content rendered in domain a with the fixed / random a-style.
                    x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
                    x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
                    # a content rendered in domain b with the fixed / random b-style.
                    x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
                    x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
                x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
                x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
                x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        finally:
            self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one optimisation step of both discriminators.

        Fresh random style codes are drawn, and each batch is encoded by its own
        generator and decoded into the other domain. The detached translations
        are the fake samples; the input batches are the real samples. The
        losses are weighted by ``hyperparameters['gan_w']``.
        """
        self.dis_opt.zero_grad()
        # Random style codes for the generated images: (batch, style_dim, 1, 1).
        style_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        style_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # Keep only the content codes; the inferred styles are not needed here.
        content_a = self.gen_a.encode(x_a)[0]
        content_b = self.gen_b.encode(x_b)[0]
        # Cross-domain translations.
        fake_a = self.gen_a.decode(content_b, style_a)
        fake_b = self.gen_b.decode(content_a, style_b)
        # Detach so the discriminator loss does not backpropagate into the generators.
        self.loss_dis_a = self.dis_a.calc_dis_loss(fake_a.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(fake_b.detach(), x_b)
        gan_weight = hyperparameters['gan_w']
        self.loss_dis_total = gan_weight * self.loss_dis_a + gan_weight * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Step the discriminator scheduler, then the generator scheduler.

        A scheduler that is ``None`` is skipped.
        """
        for scheduler in (self.dis_scheduler, self.gen_scheduler):
            if scheduler is None:
                continue
            scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks, optimizers and schedulers from ``checkpoint_dir``.

        The latest generator/discriminator files are located with
        ``get_model_list``, and the optimizer states are read from
        ``optimizer.pt``. The iteration count is parsed from the generator file
        name: the eight characters just before its last three characters
        (``.pt``). That count is passed to ``get_scheduler``, logged, and
        returned.
        """
        # Generators; their file name also encodes the iteration count.
        gen_path = get_model_list(checkpoint_dir, "gen")
        gen_state = torch.load(gen_path)
        self.gen_a.load_state_dict(gen_state['a'])
        self.gen_b.load_state_dict(gen_state['b'])
        iterations = int(gen_path[-11:-3])

        # Discriminators.
        dis_state = torch.load(get_model_list(checkpoint_dir, "dis"))
        self.dis_a.load_state_dict(dis_state['a'])
        self.dis_b.load_state_dict(dis_state['b'])

        # Optimizers.
        opt_state = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(opt_state['dis'])
        self.gen_opt.load_state_dict(opt_state['gen'])

        # Rebuild the schedulers around the restored optimizers. The iteration
        # count is presumably used as the scheduler's starting point; confirm
        # in get_scheduler.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Write generator, discriminator and optimizer checkpoints.

        Network files are named ``gen_%08d.pt`` / ``dis_%08d.pt`` using
        ``iterations + 1``. Optimizer states go to ``optimizer.pt``, which is
        overwritten on every call.
        """
        tag = iterations + 1
        gen_file = os.path.join(snapshot_dir, 'gen_%08d.pt' % tag)
        dis_file = os.path.join(snapshot_dir, 'dis_%08d.pt' % tag)
        opt_file = os.path.join(snapshot_dir, 'optimizer.pt')
        outputs = [
            (gen_file, {'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}),
            (dis_file, {'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}),
            (opt_file, {'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}),
        ]
        for destination, payload in outputs:
            torch.save(payload, destination)
Beispiel #20
0
class MUNIT_Trainer(nn.Module):
    """MUNIT trainer extended with a segmentation network.

    Components:

    - two AdaIN auto-encoders, ``gen_a`` and ``gen_b``;
    - two multi-scale discriminators, ``dis_a`` and ``dis_b``;
    - a two-class segmentor, ``seg``.

    ``seg_update`` trains the segmentor on a->b translations (labelled with
    ``target_a``) and on real b images (labelled with ``target_b``).
    ``gen_update`` can feed the segmentor's loss on a->b translations back to
    the generators as a semantic-consistency term.
    """

    # gen_update keeps the semantic loss only when the segmentor's most recent
    # a->b loss (stored by seg_update as ``loss_seg_ab``) is at or below this
    # value. The threshold is negative, so criterion_seg presumably returns a
    # negated Dice score; TODO confirm against DiceLoss.
    SEM_LOSS_GATE = -0.3

    def __init__(self, hyperparameters, opts):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        self.opts = opts

        # Initiate the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.seg = segmentor(num_classes=2, channels=hyperparameters['input_dim_b'], hyperpars=hyperparameters['seg'])

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        # The segmentor has its own SGD optimizer with fixed settings.
        self.seg_opt = torch.optim.SGD(self.seg.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters['lr_policy'], hyperparameters=hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters['lr_policy'], hyperparameters=hyperparameters)
        self.seg_scheduler = get_scheduler(self.seg_opt, 'constant', hyperparameters=None)

        # Network weight initialization (self.apply also reaches self.seg).
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        self.criterion_seg = DiceLoss(ignore_index=hyperparameters['seg']['ignore_index'])

    def recon_criterion(self, input, target):
        """Mean absolute (L1) error between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate ``x_a`` to domain b and ``x_b`` to domain a.

        The fixed sampling style codes are used. Inference runs in eval mode
        without autograd, and training mode is restored on exit (also on
        error).

        Returns:
            ``(x_ab, x_ba)``.
        """
        self.eval()
        try:
            with torch.no_grad():
                c_a, _ = self.gen_a.encode(x_a)
                c_b, _ = self.gen_b.encode(x_b)
                x_ba = self.gen_a.decode(c_b, self.s_a)
                x_ab = self.gen_b.decode(c_a, self.s_b)
        finally:
            self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters, target_a, iters):
        """Run one optimisation step of both generators.

        Args:
            x_a, x_b: image batches from domains a and b.
            hyperparameters: loss weights (``gan_w``, ``recon_x_w``,
                ``recon_s_w``, ``recon_c_w``, ``recon_x_cyc_w``, ``sem_w``) and
                ``guide_gen_iters``, the iteration from which the segmentor is
                applied to the a->b translation.
            target_a: segmentation labels of ``x_a``. They are reused as labels
                for ``x_ab``, which assumes translation keeps the layout.
            iters: current training iteration.
        """
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        if iters >= hyperparameters['guide_gen_iters']:
            # config.task is a global switch read by the segmentor; 0 is the
            # setting seg_update uses for translated a->b images.
            config.task = 0
            # Eval mode for the segmentor's layers, but gradients still flow
            # back through it into the generators.
            self.seg.eval()
            self.pred_x_ab = self.seg(x_ab)
            self.seg.train()

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0

        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)

        # semantic loss ab
        if iters >= hyperparameters['guide_gen_iters']:
            self.loss_sem_ab, _ = self.criterion_seg(self.pred_x_ab, target_a)
        else:
            self.loss_sem_ab = 0

        # Only use the semantic loss once the segmentor has reached a
        # reasonably low loss; before any seg_update there is nothing to judge.
        if not hasattr(self, 'loss_seg_ab') or self.loss_seg_ab.detach().item() > self.SEM_LOSS_GATE:
            self.loss_sem_ab = 0

        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['sem_w'] * self.loss_sem_ab

        self.loss_gen_total.backward()
        self.gen_opt.step()

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one optimisation step of both discriminators.

        Detached cross-domain translations are the fake samples; the inputs
        are the real samples.
        """
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def seg_update(self, x_a,  x_b, target_a, target_b):
        """Run one optimisation step of the segmentor.

        The segmentor is trained on a->b translations (labels ``target_a``)
        and on real b images (labels ``target_b``). The generators are run
        without gradients, so only the segmentor is updated.
        """
        self.seg.train()
        self.seg_opt.zero_grad()
        s_b = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        with torch.no_grad():
            # encode
            c_a, _ = self.gen_a.encode(x_a)
            # decode (cross domain)
            x_ab = self.gen_b.decode(c_a, s_b)

        # config.task is a global switch read by the segmentor:
        # 0 for translated a->b images, 1 for real b images.
        config.task = 0
        self.pred_x_ab = self.seg(x_ab.detach())

        config.task = 1
        self.pred_x_b = self.seg(x_b)

        self.loss_seg_ab, _ = self.criterion_seg(self.pred_x_ab, target_a)
        self.loss_seg_b, _ = self.criterion_seg(self.pred_x_b, target_b)

        self.loss_seg_total = self.loss_seg_ab + self.loss_seg_b
        self.loss_seg_total.backward()
        self.seg_opt.step()

    def update_learning_rate(self):
        """Step every scheduler that is not ``None``."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()
        if self.seg_scheduler is not None:
            self.seg_scheduler.step()

    def sample(self, x_a, x_b):
        """Produce per-image reconstructions and translations for visualisation.

        Returns:
            ``(x_a, x_a_recon, x_aba, x_ab1, x_ab2, x_b, x_b_recon, x_bab,
            x_ba1, x_ba2)``, where suffix ``1`` uses the fixed style codes and
            ``2`` uses style codes drawn on this call.

        Inference runs in eval mode without autograd; training mode is
        restored afterwards, even if decoding raises.
        """
        self.eval()
        try:
            with torch.no_grad():
                s_a1 = self.s_a
                s_b1 = self.s_b
                s_a2 = torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()
                s_b2 = torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()
                x_a_recon, x_b_recon, x_aba, x_bab, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], [], [], []
                for i in range(x_b.size(0)):
                    # encode
                    c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
                    c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
                    # decode (within domain)
                    x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
                    x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
                    # decode (cross domain)
                    x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
                    x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
                    x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
                    x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
                    # encode again and decode back to the original domain (cycle)
                    c_b_recon, s_a_recon = self.gen_a.encode(x_ba1[-1])
                    c_a_recon, s_b_recon = self.gen_b.encode(x_ab1[-1])
                    x_aba.append(self.gen_a.decode(c_a_recon, s_a_fake))
                    x_bab.append(self.gen_b.decode(c_b_recon, s_b_fake))

                x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
                x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
                x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
                x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab)
        finally:
            self.train()
        return x_a, x_a_recon, x_aba, x_ab1, x_ab2, x_b, x_b_recon, x_bab, x_ba1, x_ba2

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore all networks, optimizers and schedulers from ``checkpoint_dir``.

        Returns the iteration count parsed from the latest generator
        checkpoint's file name.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, 'gen')
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, 'dis')
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load segmentor
        last_model_name = get_model_list(checkpoint_dir, 'seg')
        state_dict = torch.load(last_model_name)
        self.seg.load_state_dict(state_dict)

        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'opt.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])

        state_dict = torch.load(os.path.join(checkpoint_dir, 'opt_seg.pt'))
        self.seg_opt.load_state_dict(state_dict)

        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters['lr_policy'], hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters['lr_policy'], hyperparameters, iterations)
        self.seg_scheduler = get_scheduler(self.seg_opt, 'constant', None, iterations)

        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators, the segmentor and the optimizers.

        Network files are tagged with ``iterations + 1``. The optimizer files
        ``opt.pt`` and ``opt_seg.pt`` are overwritten on every call.
        """
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        seg_name = os.path.join(snapshot_dir, 'seg_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'opt.pt')
        opt_seg_name = os.path.join(snapshot_dir, 'opt_seg.pt')

        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save(self.seg.state_dict(), seg_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
        torch.save(self.seg_opt.state_dict(), opt_seg_name)
Beispiel #21
0
class MUNIT_Trainer(nn.Module):
    """MUNIT trainer with an optional VGG perceptual loss.

    Components:

    - two AdaIN auto-encoders, ``gen_a`` and ``gen_b``;
    - two multi-scale discriminators, ``dis_a`` and ``dis_b``;
    - ``self.vgg``, loaded only when ``hyperparameters['vgg_w'] > 0``.
    """

    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # Fix the noise used in sampling. sample() indexes these rows per
        # image, so they must cover the display batch. Keep at least the
        # historical 8 rows, and more if a larger display_size is configured.
        num_fixed = max(8, int(hyperparameters.get('display_size', 8)))
        self.s_a = torch.randn(num_fixed, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(num_fixed, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed (frozen, eval mode).
        if hyperparameters.get('vgg_w', 0) > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        """Mean absolute (L1) error between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate ``x_a`` to domain b and ``x_b`` to domain a.

        The fixed sampling style codes are used. Runs in eval mode under
        ``torch.no_grad()``. (The ``volatile`` flag used previously was
        removed from PyTorch and no longer disables autograd.) Training mode
        is restored on exit, also on error.

        Returns:
            ``(x_ab, x_ba)``.
        """
        self.eval()
        try:
            with torch.no_grad():
                c_a, _ = self.gen_a.encode(x_a)
                c_b, _ = self.gen_b.encode(x_b)
                x_ba = self.gen_a.decode(c_b, self.s_a)
                x_ab = self.gen_b.decode(c_a, self.s_b)
        finally:
            self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        """Run one optimisation step of both generators.

        ``vgg_w`` is optional in ``hyperparameters``, matching ``__init__``.
        When it is missing or 0, the perceptual terms are 0.
        """
        self.gen_opt.zero_grad()
        vgg_w = hyperparameters.get('vgg_w', 0)
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if vgg_w > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if vgg_w > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              vgg_w * self.loss_gen_vgg_a + \
                              vgg_w * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Mean squared difference of instance-normalised VGG features."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        """Produce per-image reconstructions and translations for visualisation.

        Returns:
            ``(x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2)``,
            where suffix ``1`` uses the fixed style codes and ``2`` uses codes
            drawn on this call.

        Runs in eval mode under ``torch.no_grad()``. The input tensors are no
        longer mutated, since ``volatile`` is not set on them any more.
        Training mode is restored afterwards, even on error.
        """
        self.eval()
        try:
            with torch.no_grad():
                s_a1 = self.s_a
                s_b1 = self.s_b
                s_a2 = torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()
                s_b2 = torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()
                x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
                for i in range(x_a.size(0)):
                    c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
                    c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
                    x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
                    x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
                    x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
                    x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
                    x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
                    x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
                x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
                x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
                x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        finally:
            self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one optimisation step of both discriminators on detached translations."""
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Step every scheduler that is not ``None``."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks, optimizers and schedulers from ``checkpoint_dir``.

        Returns the iteration count parsed from the generator file name.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators and optimizers.

        Network files are tagged with ``iterations + 1``.
        """
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
Beispiel #22
0
    def __init__(self, hyperparameters):
        """Build the networks, fixed sampling codes, optimizers and schedulers.

        ``label_a`` / ``label_b`` give the number of attribute classes per
        domain. When non-zero, the last that-many channels of each fixed
        sampling style code are a one-hot label instead of Gaussian noise.
        """
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        # Attribute names keep their original (misspelled) form because other
        # code may read them.
        self.a_attibute = hyperparameters['label_a']
        self.b_attibute = hyperparameters['label_b']

        # fix the noise used in sampling (domain a first, then b, as before)
        display_size = int(hyperparameters['display_size'])
        self.s_a = self._fixed_style_codes(display_size, self.style_dim, self.a_attibute)
        self.s_b = self._fixed_style_codes(display_size, self.style_dim, self.b_attibute)

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed (frozen, eval mode)
        if hyperparameters.get('vgg_w', 0) > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    @staticmethod
    def _fixed_style_codes(display_size, style_dim, num_labels):
        """Build the fixed sampling style codes, shape (display_size, style_dim, 1, 1), on CUDA.

        With ``num_labels == 0`` the codes are pure Gaussian noise. Otherwise
        the first ``style_dim - num_labels`` channels are noise and the last
        ``num_labels`` channels are a one-hot label. Labels cycle through
        ``0 .. num_labels - 1`` across the display rows.
        """
        if num_labels == 0:
            return torch.randn(display_size, style_dim, 1, 1).cuda()
        noise = torch.randn(display_size, style_dim - num_labels, 1, 1).cuda()
        label_idx = torch.tensor([i % num_labels for i in range(display_size)],
                                 dtype=torch.long).reshape((display_size, 1))
        one_hot = torch.zeros(display_size, num_labels,
                              dtype=torch.float32).scatter_(1, label_idx, 1)
        one_hot = one_hot.reshape(display_size, num_labels, 1, 1).cuda()
        return torch.cat([noise, one_hot], 1)
Beispiel #23
0
    def __init__(self, param):
        """Build the networks, fixed display styles, optimizers and schedulers.

        Args:
            param: config mapping. Reads 'lr', 'gen', 'dis', 'input_dim_a/b/c',
                'bias_shift', 'display_size', 'beta1', 'beta2', 'weight_decay',
                'init', and optionally 'ablation_study' with the keys
                'with_mapping', 'wo_phy_loss' and 'wo_content_loss'.
        """
        super(UnsupIntrinsicTrainer, self).__init__()
        lr = param['lr']
        gp = param['gen']
        # Initiate the networks
        self.gen_i = AdaINGen(param['input_dim_a'], param['input_dim_a'],
                              gp)  # auto-encoder for domain I
        self.gen_r = AdaINGen(param['input_dim_b'], param['input_dim_b'],
                              gp)  # auto-encoder for domain R
        self.gen_s = AdaINGen(param['input_dim_c'], param['input_dim_c'],
                              gp)  # auto-encoder for domain S
        self.dis_r = MsImageDis(param['input_dim_b'],
                                param['dis'])  # discriminator for domain R
        self.dis_s = MsImageDis(param['input_dim_c'],
                                param['dis'])  # discriminator for domain S

        # Ablation switches: every component stays enabled unless the optional
        # 'ablation_study' section of the config turns it off.
        self.with_mapping = True
        self.use_phy_loss = True
        self.use_content_loss = True
        if 'ablation_study' in param:
            ablation = param['ablation_study']
            if 'with_mapping' in ablation:
                self.with_mapping = bool(ablation['with_mapping'] != 0)
            if 'wo_phy_loss' in ablation:
                self.use_phy_loss = bool(ablation['wo_phy_loss'] == 0)
            if 'wo_content_loss' in ablation:
                self.use_content_loss = bool(ablation['wo_content_loss'] == 0)

        if self.with_mapping:
            self.fea_s = IntrinsicSplitor(gp['style_dim'], gp['mlp_dim'],
                                          gp['n_layer'],
                                          gp['activ'])  # split style for I
            self.fea_m = IntrinsicMerger(gp['style_dim'], gp['mlp_dim'],
                                         gp['n_layer'],
                                         gp['activ'])  # merge style for R, S
        self.bias_shift = param['bias_shift']
        self.instance_norm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = gp['style_dim']

        # Fix the noise used in sampling. Shift it by bias_shift, like the
        # random R/S styles drawn during training, so displayed samples come
        # from the same prior.
        display_size = int(param['display_size'])
        self.s_r = torch.randn(display_size, self.style_dim, 1,
                               1).cuda() + self.bias_shift
        self.s_s = torch.randn(display_size, self.style_dim, 1,
                               1).cuda() - self.bias_shift

        # Setup the optimizers
        beta1 = param['beta1']
        beta2 = param['beta2']
        dis_params = list(self.dis_r.parameters()) + list(
            self.dis_s.parameters())
        gen_modules = [self.gen_i, self.gen_r, self.gen_s]
        if self.with_mapping:
            gen_modules += [self.fea_s, self.fea_m]
        gen_params = [p for m in gen_modules for p in m.parameters()]
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=param['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=param['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, param)
        self.gen_scheduler = get_scheduler(self.gen_opt, param)

        # Network weight initialization
        self.apply(weights_init(param['init']))
        self.dis_r.apply(weights_init('gaussian'))
        self.dis_s.apply(weights_init('gaussian'))
        self.best_result = float('inf')
        self.reflectance_loss = LocalAlbedoSmoothnessLoss(param)
Beispiel #24
0
class UnsupIntrinsicTrainer(nn.Module):
    """Unsupervised intrinsic image decomposition trainer.

    Three AdaIN auto-encoders split images into content and style codes:
    ``gen_i`` for the input image domain I, ``gen_r`` for reflectance R and
    ``gen_s`` for shading S. Two multi-scale discriminators (``dis_r`` and
    ``dis_s``) score generated R and S images. With ``with_mapping`` enabled,
    ``fea_s`` splits an I style code into (R, S) style codes and ``fea_m``
    merges them back. Otherwise R/S style codes are Gaussian samples shifted
    by ``+bias_shift`` (R) and ``-bias_shift`` (S). The physical loss ties the
    domains together through ``I ~ R * S``.
    """

    def __init__(self, param):
        """Build the networks, fixed display styles, optimizers and schedulers.

        Args:
            param: config mapping. Reads 'lr', 'gen', 'dis', 'input_dim_a/b/c',
                'bias_shift', 'display_size', 'beta1', 'beta2', 'weight_decay',
                'init', and optionally 'ablation_study' with the keys
                'with_mapping', 'wo_phy_loss' and 'wo_content_loss'.
        """
        super(UnsupIntrinsicTrainer, self).__init__()
        lr = param['lr']
        gp = param['gen']
        # Initiate the networks
        self.gen_i = AdaINGen(param['input_dim_a'], param['input_dim_a'],
                              gp)  # auto-encoder for domain I
        self.gen_r = AdaINGen(param['input_dim_b'], param['input_dim_b'],
                              gp)  # auto-encoder for domain R
        self.gen_s = AdaINGen(param['input_dim_c'], param['input_dim_c'],
                              gp)  # auto-encoder for domain S
        self.dis_r = MsImageDis(param['input_dim_b'],
                                param['dis'])  # discriminator for domain R
        self.dis_s = MsImageDis(param['input_dim_c'],
                                param['dis'])  # discriminator for domain S

        # Ablation switches: every component stays enabled unless the optional
        # 'ablation_study' section of the config turns it off.
        self.with_mapping = True
        self.use_phy_loss = True
        self.use_content_loss = True
        if 'ablation_study' in param:
            ablation = param['ablation_study']
            if 'with_mapping' in ablation:
                self.with_mapping = bool(ablation['with_mapping'] != 0)
            if 'wo_phy_loss' in ablation:
                self.use_phy_loss = bool(ablation['wo_phy_loss'] == 0)
            if 'wo_content_loss' in ablation:
                self.use_content_loss = bool(ablation['wo_content_loss'] == 0)

        if self.with_mapping:
            self.fea_s = IntrinsicSplitor(gp['style_dim'], gp['mlp_dim'],
                                          gp['n_layer'],
                                          gp['activ'])  # split style for I
            self.fea_m = IntrinsicMerger(gp['style_dim'], gp['mlp_dim'],
                                         gp['n_layer'],
                                         gp['activ'])  # merge style for R, S
        self.bias_shift = param['bias_shift']
        self.instance_norm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = gp['style_dim']

        # Fix the noise used in sampling. Shift it by bias_shift, like the
        # random R/S styles drawn during training, so displayed samples come
        # from the same prior.
        display_size = int(param['display_size'])
        self.s_r = torch.randn(display_size, self.style_dim, 1,
                               1).cuda() + self.bias_shift
        self.s_s = torch.randn(display_size, self.style_dim, 1,
                               1).cuda() - self.bias_shift

        # Setup the optimizers
        beta1 = param['beta1']
        beta2 = param['beta2']
        dis_params = list(self.dis_r.parameters()) + list(
            self.dis_s.parameters())
        gen_modules = [self.gen_i, self.gen_r, self.gen_s]
        if self.with_mapping:
            gen_modules += [self.fea_s, self.fea_m]
        gen_params = [p for m in gen_modules for p in m.parameters()]
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=param['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=param['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, param)
        self.gen_scheduler = get_scheduler(self.gen_opt, param)

        # Network weight initialization
        self.apply(weights_init(param['init']))
        self.dis_r.apply(weights_init('gaussian'))
        self.dis_s.apply(weights_init('gaussian'))
        self.best_result = float('inf')
        self.reflectance_loss = LocalAlbedoSmoothnessLoss(param)

    def _random_style(self, like, shift):
        """Sample a ``(N, style_dim, 1, 1)`` Gaussian style code plus ``shift``.

        ``N`` is ``like.size(0)``. The noise is drawn from the default (CPU)
        generator and moved to ``like``'s device.
        """
        noise = torch.randn(like.size(0), self.style_dim, 1, 1)
        return noise.to(like.device) + shift

    def _split_style(self, s_i, like):
        """Return ``(s_r, s_s)`` style codes for an image whose style is ``s_i``.

        Uses the learned splitter ``fea_s`` when ``with_mapping`` is on.
        Otherwise draws R (+bias_shift) and S (-bias_shift) prior samples with
        the batch size and device of ``like``.
        """
        if self.with_mapping:
            return self.fea_s(s_i)
        return (self._random_style(like, self.bias_shift),
                self._random_style(like, -self.bias_shift))

    def recon_criterion(self, input, target):
        """Mean absolute (L1) error between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def physical_criterion(self, x_i, x_r, x_s):
        """L1 distance between the image and the product ``x_r * x_s``."""
        return torch.mean(torch.abs(x_i - x_r * x_s))

    def forward(self, x_i):
        """Decompose ``x_i`` into ``(reflectance, shading)`` images."""
        c_i, s_i = self.gen_i.encode(x_i)
        s_r, s_s = self._split_style(s_i, x_i)
        return self.gen_r.decode(c_i, s_r), self.gen_s.decode(c_i, s_s)

    def inference(self, x_i, use_rand_fea=False):
        """Gradient-free decomposition of ``x_i`` into ``(reflectance, shading)``.

        Args:
            x_i: batch of input images.
            use_rand_fea: if True, ignore the image's style and use shifted
                Gaussian R/S style codes instead.
        """
        with torch.no_grad():
            c_i, s_i = self.gen_i.encode(x_i)
            if use_rand_fea:
                s_r = self._random_style(x_i, self.bias_shift)
                s_s = self._random_style(x_i, -self.bias_shift)
            else:
                s_r, s_s = self._split_style(s_i, x_i)
            x_ri = self.gen_r.decode(c_i, s_r)
            x_si = self.gen_s.decode(c_i, s_s)
        return x_ri, x_si

    # noinspection PyAttributeOutsideInit
    def gen_update(self, x_i, x_r, x_s, targets=None, param=None):
        """Run one generator optimisation step and record every loss term.

        Args:
            x_i, x_r, x_s: batches from the image, reflectance and shading
                domains. They are assumed to share a batch size, since the
                cross-domain decodes mix their codes.
            targets: optional second argument to ``reflectance_loss``. When
                None, the reflectance-smoothness terms are 0.
            param: loss-weight mapping. It is required despite the None
                default. Reads 'gan_w', 'recon_x_w', 'recon_s_w', 'recon_c_w',
                'recon_x_cyc_w', 'phy_x_w' and 'refl_smooth_w'.
        """
        self.gen_opt.zero_grad()
        # ============= Domain Translations =============
        # encode
        c_i, s_i_prime = self.gen_i.encode(x_i)
        c_r, s_r_prime = self.gen_r.encode(x_r)
        c_s, s_s_prime = self.gen_s.encode(x_s)
        # style codes: I's style split into (R, S), plus R/S prior samples
        s_ri, s_si = self._split_style(s_i_prime, x_i)
        s_r_rand = self._random_style(x_r, self.bias_shift)
        s_s_rand = self._random_style(x_s, -self.bias_shift)
        s_i_recon = self.fea_m(s_ri, s_si) if self.with_mapping else None
        # decode (within domain): each auto-encoder rebuilds its own domain
        x_i_recon = self.gen_i.decode(c_i, s_i_prime)
        x_r_recon = self.gen_r.decode(c_r, s_r_prime)
        x_s_recon = self.gen_s.decode(c_s, s_s_prime)
        # decode (cross domain); the S decodes use shading styles
        x_rs = self.gen_r.decode(c_s, s_r_rand)
        x_ri = self.gen_r.decode(c_i, s_ri)
        x_ri_rand = self.gen_r.decode(c_i, s_r_rand)
        x_sr = self.gen_s.decode(c_r, s_s_rand)
        x_si = self.gen_s.decode(c_i, s_si)
        x_si_rand = self.gen_s.decode(c_i, s_s_rand)
        # encode again, for feature domain consistency constraints
        c_rs_recon, s_rs_recon = self.gen_r.encode(x_rs)
        c_ri_recon, s_ri_recon = self.gen_r.encode(x_ri)
        c_ri_rand_recon, s_ri_rand_recon = self.gen_r.encode(x_ri_rand)
        c_sr_recon, s_sr_recon = self.gen_s.encode(x_sr)
        c_si_recon, s_si_recon = self.gen_s.encode(x_si)
        c_si_rand_recon, s_si_rand_recon = self.gen_s.encode(x_si_rand)
        # decode again, for image domain cycle consistency
        x_rsr = self.gen_r.decode(c_sr_recon, s_r_prime)
        x_iri = self.gen_i.decode(c_ri_recon, s_i_prime)
        x_iri_rand = self.gen_i.decode(c_ri_rand_recon, s_i_prime)
        x_srs = self.gen_s.decode(c_rs_recon, s_s_prime)
        x_isi = self.gen_i.decode(c_si_recon, s_i_prime)
        x_isi_rand = self.gen_i.decode(c_si_rand_recon, s_i_prime)

        # ============= Loss Functions =============
        recon = self.recon_criterion
        # Encoder decoder reconstruction loss for three domain
        self.loss_gen_recon_x_i = recon(x_i_recon, x_i)
        self.loss_gen_recon_x_r = recon(x_r_recon, x_r)
        self.loss_gen_recon_x_s = recon(x_s_recon, x_s)
        # Style-level reconstruction: each re-encoded style must match the
        # style that was injected when decoding that image.
        self.loss_gen_recon_s_ii = (recon(s_i_recon, s_i_prime)
                                    if self.with_mapping else 0)
        self.loss_gen_recon_s_ri = recon(s_ri_recon, s_ri)
        self.loss_gen_recon_s_ri_rand = recon(s_ri_rand_recon, s_r_rand)
        self.loss_gen_recon_s_rs = recon(s_rs_recon, s_r_rand)
        self.loss_gen_recon_s_sr = recon(s_sr_recon, s_s_rand)
        self.loss_gen_recon_s_si = recon(s_si_recon, s_si)
        self.loss_gen_recon_s_si_rand = recon(s_si_rand_recon, s_s_rand)
        # Content-level reconstruction loss for cross domain
        use_c = self.use_content_loss
        self.loss_gen_recon_c_rs = recon(c_rs_recon, c_s)
        self.loss_gen_recon_c_ri = recon(c_ri_recon, c_i) if use_c else 0
        self.loss_gen_recon_c_ri_rand = (recon(c_ri_rand_recon, c_i)
                                         if use_c else 0)
        self.loss_gen_recon_c_sr = recon(c_sr_recon, c_r)
        self.loss_gen_recon_c_si = recon(c_si_recon, c_i) if use_c else 0
        self.loss_gen_recon_c_si_rand = (recon(c_si_rand_recon, c_i)
                                         if use_c else 0)
        # Cycle consistency loss for three image domain
        self.loss_gen_cyc_recon_x_rs = recon(x_rsr, x_r)
        self.loss_gen_cyc_recon_x_ir = recon(x_iri, x_i)
        self.loss_gen_cyc_recon_x_ir_rand = recon(x_iri_rand, x_i)
        self.loss_gen_cyc_recon_x_sr = recon(x_srs, x_s)
        self.loss_gen_cyc_recon_x_is = recon(x_isi, x_i)
        self.loss_gen_cyc_recon_x_is_rand = recon(x_isi_rand, x_i)
        # GAN loss
        self.loss_gen_adv_rs = self.dis_r.calc_gen_loss(x_rs)
        self.loss_gen_adv_ri = self.dis_r.calc_gen_loss(x_ri)
        self.loss_gen_adv_ri_rand = self.dis_r.calc_gen_loss(x_ri_rand)
        self.loss_gen_adv_sr = self.dis_s.calc_gen_loss(x_sr)
        self.loss_gen_adv_si = self.dis_s.calc_gen_loss(x_si)
        self.loss_gen_adv_si_rand = self.dis_s.calc_gen_loss(x_si_rand)
        # Physical loss: the image should factor as reflectance * shading
        if self.use_phy_loss:
            self.loss_gen_phy_i = self.physical_criterion(x_i, x_ri, x_si)
            self.loss_gen_phy_i_rand = self.physical_criterion(
                x_i, x_ri_rand, x_si_rand)
        else:
            self.loss_gen_phy_i = 0
            self.loss_gen_phy_i_rand = 0

        # Reflectance smoothness loss
        if targets is not None:
            self.loss_refl_ri = self.reflectance_loss(x_ri, targets)
            self.loss_refl_ri_rand = self.reflectance_loss(x_ri_rand, targets)
        else:
            self.loss_refl_ri = 0
            self.loss_refl_ri_rand = 0

        # total loss, grouped by weight
        self.loss_gen_total = (
            param['gan_w'] * (self.loss_gen_adv_rs + self.loss_gen_adv_ri +
                              self.loss_gen_adv_ri_rand +
                              self.loss_gen_adv_sr + self.loss_gen_adv_si +
                              self.loss_gen_adv_si_rand) +
            param['recon_x_w'] * (self.loss_gen_recon_x_i +
                                  self.loss_gen_recon_x_r +
                                  self.loss_gen_recon_x_s) +
            param['recon_s_w'] * (self.loss_gen_recon_s_ii +
                                  self.loss_gen_recon_s_ri +
                                  self.loss_gen_recon_s_ri_rand +
                                  self.loss_gen_recon_s_rs +
                                  self.loss_gen_recon_s_si +
                                  self.loss_gen_recon_s_si_rand +
                                  self.loss_gen_recon_s_sr) +
            param['recon_c_w'] * (self.loss_gen_recon_c_ri +
                                  self.loss_gen_recon_c_rs +
                                  self.loss_gen_recon_c_ri_rand +
                                  self.loss_gen_recon_c_si +
                                  self.loss_gen_recon_c_sr +
                                  self.loss_gen_recon_c_si_rand) +
            param['recon_x_cyc_w'] * (self.loss_gen_cyc_recon_x_ir +
                                      self.loss_gen_cyc_recon_x_ir_rand +
                                      self.loss_gen_cyc_recon_x_is +
                                      self.loss_gen_cyc_recon_x_is_rand +
                                      self.loss_gen_cyc_recon_x_rs +
                                      self.loss_gen_cyc_recon_x_sr) +
            param['phy_x_w'] * (self.loss_gen_phy_i +
                                self.loss_gen_phy_i_rand) +
            param['refl_smooth_w'] * (self.loss_refl_ri +
                                      self.loss_refl_ri_rand))

        self.loss_gen_total.backward()
        self.gen_opt.step()

    def sample(self, x_i, x_r, x_s):
        """Produce per-image reconstructions and translations for display.

        Runs without autograd in eval mode, then restores the previous mode.
        Indexes the fixed styles ``self.s_r``/``self.s_s`` per image, so the
        batch must not exceed ``display_size``.

        Returns:
            (x_i, x_i_recon, x_r, x_r_recon, x_rs, x_ri, x_s, x_s_recon,
            x_sr, x_si)
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                x_i_recon, x_r_recon, x_s_recon = [], [], []
                x_rs, x_ri, x_sr, x_si = [], [], [], []
                for k in range(x_i.size(0)):
                    xi_k = x_i[k].unsqueeze(0)
                    c_i, s_i_fake = self.gen_i.encode(xi_k)
                    c_r, s_r_fake = self.gen_r.encode(x_r[k].unsqueeze(0))
                    c_s, s_s_fake = self.gen_s.encode(x_s[k].unsqueeze(0))
                    s_ri, s_si = self._split_style(s_i_fake, xi_k)
                    x_i_recon.append(self.gen_i.decode(c_i, s_i_fake))
                    x_r_recon.append(self.gen_r.decode(c_r, s_r_fake))
                    x_s_recon.append(self.gen_s.decode(c_s, s_s_fake))
                    # S content rendered as R, and R content rendered as S
                    x_rs.append(
                        self.gen_r.decode(c_s, self.s_r[k].unsqueeze(0)))
                    x_ri.append(self.gen_r.decode(c_i, s_ri))
                    x_sr.append(
                        self.gen_s.decode(c_r, self.s_s[k].unsqueeze(0)))
                    x_si.append(self.gen_s.decode(c_i, s_si))
                x_i_recon = torch.cat(x_i_recon)
                x_r_recon = torch.cat(x_r_recon)
                x_s_recon = torch.cat(x_s_recon)
                x_rs, x_ri = torch.cat(x_rs), torch.cat(x_ri)
                x_sr, x_si = torch.cat(x_sr), torch.cat(x_si)
        finally:
            self.train(was_training)
        return x_i, x_i_recon, x_r, x_r_recon, x_rs, x_ri, x_s, x_s_recon, x_sr, x_si

    # noinspection PyAttributeOutsideInit
    def dis_update(self, x_i, x_r, x_s, params):
        """Run one discriminator step on freshly generated R and S images.

        Args:
            x_i, x_r, x_s: batches from the image, reflectance and shading
                domains.
            params: mapping providing 'gan_w'.
        """
        self.dis_opt.zero_grad()
        # R styles are shifted by +bias_shift and S styles by -bias_shift,
        # matching every other random-style draw in this class.
        s_r = self._random_style(x_r, self.bias_shift)
        s_s = self._random_style(x_s, -self.bias_shift)
        # encode
        c_r, _ = self.gen_r.encode(x_r)
        c_s, _ = self.gen_s.encode(x_s)
        c_i, s_i = self.gen_i.encode(x_i)
        s_ri, s_si = self._split_style(s_i, x_i)
        # decode (cross domain)
        x_rs = self.gen_r.decode(c_s, s_r)
        x_ri = self.gen_r.decode(c_i, s_ri)
        x_sr = self.gen_s.decode(c_r, s_s)
        x_si = self.gen_s.decode(c_i, s_si)
        # D loss; fakes are detached so only the discriminators get gradients
        self.loss_dis_rs = self.dis_r.calc_dis_loss(x_rs.detach(), x_r)
        self.loss_dis_ri = self.dis_r.calc_dis_loss(x_ri.detach(), x_r)
        self.loss_dis_sr = self.dis_s.calc_dis_loss(x_sr.detach(), x_s)
        self.loss_dis_si = self.dis_s.calc_dis_loss(x_si.detach(), x_s)

        self.loss_dis_total = params['gan_w'] * (
            self.loss_dis_rs + self.loss_dis_ri +
            self.loss_dis_sr + self.loss_dis_si)

        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Advance both LR schedulers (each only if it exists)."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, param):
        """Restore networks, optimizers and schedulers from ``checkpoint_dir``.

        Expects the files written by ``save``: 'gen_XXXXXXXX.pt',
        'dis_XXXXXXXX.pt' (8-digit iteration) and 'optimizer.pt'.

        Returns:
            The iteration number parsed from the generator checkpoint name.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_i.load_state_dict(state_dict['i'])
        self.gen_r.load_state_dict(state_dict['r'])
        self.gen_s.load_state_dict(state_dict['s'])
        if self.with_mapping:
            self.fea_m.load_state_dict(state_dict['fm'])
            self.fea_s.load_state_dict(state_dict['fs'])
        self.best_result = state_dict['best_result']
        # 'gen_%08d.pt': the 8 digits sit just before the 3-char '.pt' suffix
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_r.load_state_dict(state_dict['r'])
        self.dis_s.load_state_dict(state_dict['s'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers at the resumed iteration
        self.dis_scheduler = get_scheduler(self.dis_opt, param, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, param, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators and optimizers to ``snapshot_dir``.

        Checkpoint names use ``iterations + 1``. The optimizer state
        overwrites a single 'optimizer.pt'.
        """
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        gen_state = {
            'i': self.gen_i.state_dict(),
            'r': self.gen_r.state_dict(),
            's': self.gen_s.state_dict(),
        }
        if self.with_mapping:
            gen_state['fs'] = self.fea_s.state_dict()
            gen_state['fm'] = self.fea_m.state_dict()
        gen_state['best_result'] = self.best_result
        torch.save(gen_state, gen_name)
        torch.save({
            'r': self.dis_r.state_dict(),
            's': self.dis_s.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
Beispiel #25
0
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters, opts):
        """Build MUNIT generators/discriminators, optimizers and schedulers.

        Args:
            hyperparameters: config mapping. Reads 'lr', 'input_dim_a',
                'input_dim_b', 'gen', 'dis', 'beta1', 'beta2', 'weight_decay'
                and 'init'. Optionally reads 'display_size' (number of fixed
                sampling styles, default 8) and 'vgg_w'.
            opts: options object. ``opts.output_base + '/models'`` is passed to
                ``load_vgg16`` when 'vgg_w' > 0.
        """
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']

        # Initiate the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        self.loss = {}

        # Fix the noise used in sampling. sample() indexes one style per
        # image, so size it by the configured display batch.
        display_size = int(hyperparameters.get('display_size', 8))
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load a frozen VGG model only when the perceptual loss is enabled
        if hyperparameters.get('vgg_w', 0) > 0:
            self.vgg = load_vgg16(opts.output_base + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def print_network(self, model, name):
        """Log ``model`` once, together with its total parameter count.

        Args:
            model: module whose ``parameters()`` are counted.
            name: label printed before the model description.
        """
        num_params = sum(p.numel() for p in model.parameters())
        logger.info(
            '{} - {} - Number of parameters: {}'.format(name, model,
                                                        num_params))

    def recon_criterion(self, input, target):
        """L1 reconstruction loss: mean absolute difference of two tensors."""
        residual = input - target
        return residual.abs().mean()

    def forward(self, x_a, x_b):
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, s_a_fake = self.gen_a.encode(x_a)
        c_b, s_b_fake = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss['G/rec_x_A'] = self.recon_criterion(x_a_recon, x_a)
        self.loss['G/rec_x_B'] = self.recon_criterion(x_b_recon, x_b)
        self.loss['G/rec_s_A'] = self.recon_criterion(s_a_recon, s_a)
        self.loss['G/rec_s_B'] = self.recon_criterion(s_b_recon, s_b)
        self.loss['G/rec_c_A'] = self.recon_criterion(c_a_recon, c_a)
        self.loss['G/rec_c_B'] = self.recon_criterion(c_b_recon, c_b)
        self.loss['G/cycrec_x_A'] = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss['G/cycrec_x_B'] = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0

        # GAN loss
        self.loss['G/adv_A'] = self.dis_a.calc_gen_loss(x_ba)
        self.loss['G/adv_B'] = self.dis_b.calc_gen_loss(x_ab)

        # domain-invariant perceptual loss
        self.loss['G/vgg_A'] = self.compute_vgg_loss(self.vgg, x_ba.cuda(), x_b.cuda()) if hyperparameters['vgg_w'] > 0 else 0
        self.loss['G/vgg_B'] = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0

        # total loss
        self.loss['G/total'] = hyperparameters['gan_w'] * self.loss['G/adv_A'] + \
                              hyperparameters['gan_w'] * self.loss['G/adv_B'] + \
                              hyperparameters['recon_x_w'] * self.loss['G/rec_x_A'] + \
                              hyperparameters['recon_s_w'] * self.loss['G/rec_s_A'] + \
                              hyperparameters['recon_c_w'] * self.loss['G/rec_c_A'] + \
                              hyperparameters['recon_x_w'] * self.loss['G/rec_x_B'] + \
                              hyperparameters['recon_s_w'] * self.loss['G/rec_s_B'] + \
                              hyperparameters['recon_c_w'] * self.loss['G/rec_c_B'] + \
                              hyperparameters['recon_x_cyc_w'] * self.loss['G/cycrec_x_A'] + \
                              hyperparameters['recon_x_cyc_w'] * self.loss['G/cycrec_x_B'] + \
                              hyperparameters['vgg_w'] * self.loss['G/vgg_A'] + \
                              hyperparameters['vgg_w'] * self.loss['G/vgg_B']
        self.loss['G/total'].backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Perceptual loss between two image batches.

        Both images go through ``vgg_preprocess`` and ``vgg``. The resulting
        features are instance-normalised with ``self.instancenorm``, and the
        mean squared difference between them is returned.
        """
        normed = [self.instancenorm(vgg(vgg_preprocess(x))) for x in (img, target)]
        diff = normed[0] - normed[1]
        return (diff ** 2).mean()

    def sample(self, x_a, x_b):
        """Reconstruct and translate each image of a paired batch for display.

        Images are processed one at a time. For index ``i``:

        * ``A/rec`` / ``B/rec``: within-domain reconstruction from the image's own
          content and style codes.
        * ``B/A_random_style`` / ``A/B_random_style``: cross-domain translation
          with the fixed style codes ``self.s_a[i]`` / ``self.s_b[i]``.
        * ``B/A`` / ``A/B``: cross-domain translation using the style encoded
          from the paired image of the target domain (``x_a[i]`` / ``x_b[i]``).

        The module is in eval mode during the call and is put back into train
        mode afterwards, also when an exception is raised.

        Args:
            x_a: batch of domain-A images.
            x_b: batch of domain-B images. It is indexed with the same ``i`` as
                ``x_a``, so it must hold at least ``x_a.size(0)`` images.
                ``self.s_a`` / ``self.s_b`` must also have that many rows.

        Returns:
            dict mapping ``A/real``, ``B/real`` and the keys above to image batches.
        """
        self.eval()
        try:
            fixed_s_a = Variable(self.s_a)
            fixed_s_b = Variable(self.s_b)
            x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
            for i in range(x_a.size(0)):
                c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
                c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
                x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
                x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
                x_ba1.append(self.gen_a.decode(c_b, fixed_s_a[i].unsqueeze(0)))
                # Encoded style codes already carry the batch dimension of the
                # single-image input (they are decoded as-is just above), so they
                # are passed without an extra unsqueeze.
                x_ba2.append(self.gen_a.decode(c_b, s_a_fake))
                x_ab1.append(self.gen_b.decode(c_a, fixed_s_b[i].unsqueeze(0)))
                x_ab2.append(self.gen_b.decode(c_a, s_b_fake))

            outputs = {}
            outputs['A/real'] = x_a
            outputs['B/real'] = x_b

            outputs['A/rec'] = torch.cat(x_a_recon)
            outputs['B/rec'] = torch.cat(x_b_recon)

            outputs['A/B_random_style'] = torch.cat(x_ab1)
            outputs['A/B'] = torch.cat(x_ab2)
            outputs['B/A_random_style'] = torch.cat(x_ba1)
            outputs['B/A'] = torch.cat(x_ba2)
        finally:
            self.train()

        return outputs

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one optimisation step of both discriminators.

        Random style codes are drawn per image. The content of each batch is
        re-rendered in the other domain, and each discriminator is trained to
        separate its real images from these (detached) translations. The losses
        are stored in ``self.loss`` under ``'D/A'``, ``'D/B'`` and ``'D/total'``
        (weighted by ``hyperparameters['gan_w']``).
        """
        self.dis_opt.zero_grad()
        # One random style code per image: domain A first, then domain B.
        rand_s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        rand_s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # Only the content codes are needed here; the encoded styles are dropped.
        content_a, _ = self.gen_a.encode(x_a)
        content_b, _ = self.gen_b.encode(x_b)
        fake_a = self.gen_a.decode(content_b, rand_s_a)  # B content in domain A
        fake_b = self.gen_b.decode(content_a, rand_s_b)  # A content in domain B
        # detach() keeps the discriminator loss from back-propagating into the generators.
        self.loss['D/A'] = self.dis_a.calc_dis_loss(fake_a.detach(), x_a)
        self.loss['D/B'] = self.dis_b.calc_dis_loss(fake_b.detach(), x_b)
        gan_w = hyperparameters['gan_w']
        total = gan_w * self.loss['D/A'] + gan_w * self.loss['D/B']
        self.loss['D/total'] = total
        total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Step both LR schedulers once and log any learning-rate change.

        A scheduler that is ``None`` is skipped. Only the first parameter group
        of each optimizer is compared before and after the step.
        """
        def _step_and_report(scheduler, optimizer, message):
            lr_before = optimizer.param_groups[0]['lr']
            scheduler.step()
            lr_after = optimizer.param_groups[0]['lr']
            if lr_before != lr_after:
                logger.info(message.format(lr_after))

        if self.dis_scheduler is not None:
            _step_and_report(self.dis_scheduler, self.dis_opt, 'Updated D learning rate: {}')
        if self.gen_scheduler is not None:
            _step_and_report(self.gen_scheduler, self.gen_opt, 'Updated G learning rate: {}')

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore generators, discriminators, optimizers and schedulers from disk.

        The ``gen`` and ``dis`` checkpoint paths come from ``get_model_list``.
        Optimizer state is read from ``optimizer.pt`` in the same directory. The
        iteration count is parsed from the generator file name: the 8 digits
        before the ``.pt`` suffix, matching the ``gen_%08d.pt`` pattern written
        by ``save``. That count is handed to ``get_scheduler`` when the
        schedulers are rebuilt.

        Returns:
            int: the iteration number encoded in the generator checkpoint name.
        """
        def _load_pair(path, net_a, net_b):
            # Each checkpoint stores the two domain networks under keys 'a' and 'b'.
            weights = torch.load(path)
            net_a.load_state_dict(weights['a'])
            net_b.load_state_dict(weights['b'])

        gen_path = get_model_list(checkpoint_dir, "gen")
        _load_pair(gen_path, self.gen_a, self.gen_b)
        iterations = int(gen_path[-11:-3])

        _load_pair(get_model_list(checkpoint_dir, "dis"), self.dis_a, self.dis_b)

        opt_states = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(opt_states['dis'])
        self.gen_opt.load_state_dict(opt_states['gen'])

        # Fresh schedulers built for the restored iteration count.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        logger.info('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Write generator, discriminator and optimizer state into ``snapshot_dir``.

        Network files are tagged with ``iterations + 1`` (zero-padded to 8
        digits). ``optimizer.pt`` has a fixed name and is overwritten on every
        call.
        """
        step = iterations + 1
        gen_path = os.path.join(snapshot_dir, 'gen_%08d.pt' % step)
        dis_path = os.path.join(snapshot_dir, 'dis_%08d.pt' % step)
        opt_path = os.path.join(snapshot_dir, 'optimizer.pt')
        for state, path in (
            ({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_path),
            ({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_path),
            ({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_path),
        ):
            torch.save(state, path)

        logger.info('Saving snapshots to: {}'.format(snapshot_dir))
Beispiel #26
0
    def __init__(self, hyperparameters):
        """Create the MUNIT networks, optimizers and LR schedulers from a config.

        Args:
            hyperparameters: configuration mapping. Keys read here: ``lr``;
                ``input_dim_a`` / ``input_dim_b`` (input image channels, e.g. 3
                for RGB); ``gen`` / ``dis`` (architecture sub-configs, where
                ``gen`` must contain ``style_dim``); ``display_size``;
                ``beta1``; ``beta2``; ``weight_decay``; ``init``.
                ``vgg_model_path`` is read only when ``vgg_w`` is present and
                positive. ``get_scheduler`` may read further keys.
        """
        super(MUNIT_Trainer, self).__init__()
        cfg = hyperparameters
        lr = cfg['lr']

        # One AdaIN auto-encoder and one multi-scale discriminator per domain.
        self.gen_a = AdaINGen(cfg['input_dim_a'], cfg['gen'])
        self.gen_b = AdaINGen(cfg['input_dim_b'], cfg['gen'])
        self.dis_a = MsImageDis(cfg['input_dim_a'], cfg['dis'])
        self.dis_b = MsImageDis(cfg['input_dim_b'], cfg['dis'])
        # Parameter-free instance normalisation over 512 channels.
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = cfg['gen']['style_dim']

        # Fixed random style codes, one per displayed image, so that sampled
        # translations stay comparable during training. The author's notes
        # mention 16 images with 8-dim styles; the real values come from the
        # config.
        n_shown = int(cfg['display_size'])
        self.s_a = torch.randn(n_shown, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(n_shown, self.style_dim, 1, 1).cuda()

        # Separate Adam optimizers for the discriminators and the generators.
        # Parameters with requires_grad=False are left out. The author's notes
        # mention betas 0.5/0.999 and weight decay 1e-4; confirm against the
        # YAML config.
        adam_betas = (cfg['beta1'], cfg['beta2'])

        def trainable(*modules):
            return [p for m in modules for p in m.parameters() if p.requires_grad]

        self.dis_opt = torch.optim.Adam(trainable(self.dis_a, self.dis_b), lr=lr,
                                        betas=adam_betas, weight_decay=cfg['weight_decay'])
        self.gen_opt = torch.optim.Adam(trainable(self.gen_a, self.gen_b), lr=lr,
                                        betas=adam_betas, weight_decay=cfg['weight_decay'])
        # The schedule is defined by get_scheduler and the config. The author
        # notes the default halves the LR every 100000 steps; TODO confirm.
        self.dis_scheduler = get_scheduler(self.dis_opt, cfg)
        self.gen_scheduler = get_scheduler(self.gen_opt, cfg)

        # Weight init: apply() recurses into every submodule with the configured
        # scheme, then both discriminators are re-initialised as Gaussian.
        self.apply(weights_init(cfg['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Optional frozen VGG16, loaded only when vgg_w > 0. Per the author's
        # note, the default config does not use it.
        if 'vgg_w' in cfg and cfg['vgg_w'] > 0:
            self.vgg = load_vgg16(cfg['vgg_model_path'] + '/models')
            self.vgg.eval()
            for weight in self.vgg.parameters():
                weight.requires_grad = False
Beispiel #27
0
class MUNIT_Trainer(nn.Module):
    """Two-domain trainer whose generators exchange only style codes.

    In this variant ``gen_x.encode(img)`` returns a single style tensor, and
    ``gen_x.decode(img, style)`` re-renders an image with the given style.
    There is no separate content code. The trainer keeps one generator and one
    multi-scale discriminator per domain (a and b).
    """

    def __init__(self, hyperparameters):
        """Build the networks, Adam optimizers and StepLR schedulers.

        Args:
            hyperparameters: configuration mapping. Keys read here: ``lr``,
                ``input_dim_a``, ``input_dim_b``, ``gen`` (must contain
                ``style_dim``), ``dis``, ``display_size``, ``weight_decay``,
                ``step_size``, ``gamma``, ``init``. ``beta1`` / ``beta2``
                (Adam betas) are optional and default to 0.5 / 0.999.
        """
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # Fix the noise used in sampling: one style code per displayed image,
        # reused by forward() and sample() so results stay comparable.
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers. The Adam betas are configurable, defaulting to
        # 0.5 / 0.999 when the config does not provide them.
        beta1 = hyperparameters.get('beta1', 0.5)
        beta2 = hyperparameters.get('beta2', 0.999)
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        # StepLR multiplies the learning rate by `gamma` every `step_size` scheduler steps.
        self.dis_scheduler = lr_scheduler.StepLR(self.dis_opt, step_size=hyperparameters['step_size'],
                                                 gamma=hyperparameters['gamma'], last_epoch=-1)
        self.gen_scheduler = lr_scheduler.StepLR(self.gen_opt, step_size=hyperparameters['step_size'],
                                                 gamma=hyperparameters['gamma'], last_epoch=-1)

        # Network weight initialization: configured scheme everywhere, then
        # Gaussian init for both discriminators.
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

    def recon_criterion(self, input, target):
        """Mean absolute (L1) difference between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate ``x_a`` to domain B and ``x_b`` to domain A with the fixed styles.

        NOTE(review): the fixed styles ``self.s_a`` / ``self.s_b`` have
        ``display_size`` rows; presumably the batch size must match that.

        Returns:
            tuple ``(x_ab, x_ba)``. The module is returned to train mode afterwards.
        """
        self.eval()
        try:
            s_a = Variable(self.s_a)
            s_b = Variable(self.s_b)
            # The inputs' own styles are not needed for fixed-style translation,
            # so the encoders are not run here.
            x_ba = self.gen_a.decode(x_b, s_a)
            x_ab = self.gen_b.decode(x_a, s_b)
        finally:
            self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        """Run one optimisation step of both generators.

        The step combines these losses:

        * within-domain image reconstruction (weight ``recon_x_w``);
        * reconstruction of the randomly drawn style codes from the translated
          images (weight ``recon_s_w``);
        * optional cycle reconstruction, only when ``recon_x_cyc_w > 0``;
        * each discriminator's adversarial loss (weight ``gan_w``).

        The individual losses are stored as ``self.loss_gen_*`` attributes.
        """
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # encode
        s_a_prime = self.gen_a.encode(x_a)
        s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(x_a, s_a_prime)
        x_b_recon = self.gen_b.decode(x_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(x_b, s_a)
        x_ab = self.gen_b.decode(x_a, s_b)
        # encode again: the random styles should be recoverable from the translations
        s_a_recon = self.gen_a.encode(x_ba)
        s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(x_ba, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(x_ab, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b

        self.loss_gen_total.backward()
        self.gen_opt.step()

    def sample(self, x_a, x_b):
        """Reconstruct and translate each image of a batch for display.

        Images are processed one at a time. ``x_b`` is indexed with the same
        ``i`` as ``x_a``, and the fixed styles must have at least
        ``x_a.size(0)`` rows.

        Returns:
            tuple ``(x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2)``.
            The ``*1`` entries use the fixed styles ``self.s_a`` / ``self.s_b``;
            the ``*2`` entries use freshly drawn random styles. The module is
            returned to train mode afterwards, also on error.
        """
        self.eval()
        try:
            s_a1 = Variable(self.s_a)
            s_b1 = Variable(self.s_b)
            s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
            s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
            x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
            for i in range(x_a.size(0)):
                s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
                s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
                x_a_recon.append(self.gen_a.decode(x_a[i].unsqueeze(0), s_a_fake))
                x_b_recon.append(self.gen_b.decode(x_b[i].unsqueeze(0), s_b_fake))
                x_ba1.append(self.gen_a.decode(x_b[i].unsqueeze(0), s_a1[i].unsqueeze(0)))
                x_ba2.append(self.gen_a.decode(x_b[i].unsqueeze(0), s_a2[i].unsqueeze(0)))
                x_ab1.append(self.gen_b.decode(x_a[i].unsqueeze(0), s_b1[i].unsqueeze(0)))
                x_ab2.append(self.gen_b.decode(x_a[i].unsqueeze(0), s_b2[i].unsqueeze(0)))
            x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
            x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
            x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        finally:
            self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one optimisation step of both discriminators.

        Each discriminator is trained on its real images against detached
        translations rendered with random styles. The total loss is weighted by
        ``hyperparameters['gan_w']``.
        """
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        # decode (cross domain)
        x_ba = self.gen_a.decode(x_b, s_a)
        x_ab = self.gen_b.decode(x_a, s_b)
        # D loss; detach() keeps gradients out of the generators
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Advance each scheduler that is not ``None`` by one step."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    # NOTE(review): resume() is disabled. __init__ builds StepLR schedulers
    # directly, so the get_scheduler-based reinitialisation below would need
    # adapting before this could be re-enabled.
    # def resume(self, checkpoint_dir, hyperparameters):
    #     # Load generators
    #     last_model_name = get_model_list(checkpoint_dir, "gen")
    #     state_dict = torch.load(last_model_name)
    #     self.gen_a.load_state_dict(state_dict['a'])
    #     self.gen_b.load_state_dict(state_dict['b'])
    #     iterations = int(last_model_name[-11:-3])
    #     # Load discriminators
    #     last_model_name = get_model_list(checkpoint_dir, "dis")
    #     state_dict = torch.load(last_model_name)
    #     self.dis_a.load_state_dict(state_dict['a'])
    #     self.dis_b.load_state_dict(state_dict['b'])
    #     # Load optimizers
    #     state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
    #     self.dis_opt.load_state_dict(state_dict['dis'])
    #     self.gen_opt.load_state_dict(state_dict['gen'])
    #     # Reinitilize schedulers
    #     self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
    #     self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
    #     print('Resume from iteration %d' % iterations)
    #     return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators (tagged ``iterations + 1``) and optimizers."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
Beispiel #28
0
class MUNIT_Trainer(nn.Module):
    """MUNIT trainer with per-domain AdaIN auto-encoders and multi-scale discriminators.

    ``gen_x.encode(img)`` returns ``(content, style)``, and ``gen_x.decode``
    combines a content code with a style code. Training alternates
    ``dis_update`` and ``gen_update``; ``sample`` and ``sampleB_toA`` produce
    images for inspection.
    """

    def __init__(self, hyperparameters):
        """Build networks, optimizers, schedulers and, if requested, a frozen VGG16.

        Args:
            hyperparameters: configuration mapping. Keys read here: ``lr``,
                ``input_dim_a``, ``input_dim_b``, ``gen`` (with ``style_dim``),
                ``dis``, ``display_size``, ``beta1``, ``beta2``,
                ``weight_decay``, ``init``. ``vgg_w`` is optional, and
                ``vgg_model_path`` is read only when ``vgg_w > 0``.
                ``get_scheduler`` may read further keys.
        """
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # Fix the noise used in sampling: one style code per displayed image.
        self.display_size = int(hyperparameters['display_size'])
        self.s_a = self.random_style()
        self.s_b = self.random_style()

        # Setup the optimizers (parameters with requires_grad=False are excluded)
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed. 'vgg_w' is optional; a missing key means 0.
        if hyperparameters.get('vgg_w', 0) > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        """Mean absolute (L1) difference between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate ``x_a`` to domain B and ``x_b`` to domain A with the fixed styles.

        Returns:
            tuple ``(x_ab, x_ba)``. The module is put back into train mode afterwards.
        """
        self.eval()
        s_a = Variable(self.s_a)
        s_b = Variable(self.s_b)
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        self.train()
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        """Run one optimisation step of both generators.

        Args:
            x_a: batch of domain-A images.
            x_b: batch of domain-B images.
            hyperparameters: mapping with the loss weights ``gan_w``,
                ``recon_x_w``, ``recon_s_w``, ``recon_c_w``, ``recon_x_cyc_w``.
                ``vgg_w`` is optional and treated as 0 when absent; ``self.vgg``
                only exists when ``vgg_w > 0``.

        The individual losses are stored as ``self.loss_gen_*`` attributes.
        """
        self.gen_opt.zero_grad()
        s_a = self.random_style(x_a)
        s_b = self.random_style(x_b)

        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)  # c_a - content encoding, s_a_prime - style encoding
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)  # reconstruction from content and style codes
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)  # content b, style a
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again: recover content_b and style_a from the cross-domain image
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss (optional, matching __init__'s handling of 'vgg_w')
        vgg_w = hyperparameters.get('vgg_w', 0)
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if vgg_w > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if vgg_w > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              vgg_w * self.loss_gen_vgg_a + \
                              vgg_w * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Mean squared difference of instance-normalised VGG features of two images."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def random_style(self, x=None, factor=1):
        """Draw standard-normal style codes of shape ``(N, style_dim, 1, 1)`` on the GPU.

        Args:
            x: if given, ``N = x.size(0)``; otherwise ``N = self.display_size``.
            factor: scale multiplied onto the sampled codes.
        """
        dim = self.display_size if x is None else x.size(0)
        return Variable(torch.randn(dim, self.style_dim, 1, 1).cuda()) * factor

    def sample(self, x_a, x_b):
        """Reconstruct and translate each image of a batch for visual inspection.

        Args:
            x_a: batch of domain-A images.
            x_b: batch of domain-B images, indexed with the same ``i`` as ``x_a``.
                The fixed styles must have at least ``x_a.size(0)`` rows.

        Returns:
            (tuple): ``(x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2)``.
            For each domain, the entries are:

            * the original images;
            * their reconstructions;
            * the translation with the fixed sampling styles (``self.s_a`` / ``self.s_b``);
            * the translation with fresh random styles scaled by 5.

            The module is put back into train mode afterwards, also on error.
        """
        self.eval()
        try:
            s_a1 = Variable(self.s_a)
            s_b1 = Variable(self.s_b)
            s_a2 = self.random_style(x_a, factor=5)
            s_b2 = self.random_style(x_b, factor=5)
            x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
            for i in range(x_a.size(0)):
                c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
                c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
                x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
                x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
                x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
                x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
                x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
                x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
            x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
            x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
            x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        finally:
            self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def sampleB_toA(self, x_b):
        """Translate every image of ``x_b`` to domain A with random styles.

        The module is in eval mode while translating and is put back into train
        mode afterwards (like ``sample``), so a training loop calling this is
        not left in eval mode.

        Args:
            x_b: batch of domain-B images.

        Returns:
            list: one decoder output per input image, in input order.
        """
        self.eval()
        try:
            s_a2 = self.random_style(x_b)
            x_ba1 = []
            for i in range(x_b.size(0)):  # one image at a time
                c_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0))
                x_ba1.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
        finally:
            self.train()
        return x_ba1

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one optimisation step of both discriminators.

        Each discriminator is trained on its real images against detached
        translations rendered with random styles. The total loss is weighted by
        ``hyperparameters['gan_w']``.
        """
        self.dis_opt.zero_grad()
        s_a = self.random_style(x_a)
        s_b = self.random_style(x_b)
        # encode
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # D loss; detach() keeps gradients out of the generators
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Advance each scheduler that is not ``None`` by one step."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters, gen_model=None, dis_model=None):
        """Restore networks, optimizers and schedulers from ``checkpoint_dir``.

        Args:
            checkpoint_dir: directory holding ``optimizer.pt`` (and, by default,
                the network checkpoints).
            hyperparameters: configuration passed to ``get_scheduler``.
            gen_model, dis_model: explicit checkpoint paths. When ``None``, the
                path returned by ``get_model_list`` is used.

        Returns:
            int: the iteration parsed from the generator file name (the 8 digits
            before ``.pt``, as written by ``save``).
        """
        # Load generators
        if gen_model is None:
            gen_model = get_model_list(checkpoint_dir, "gen")  # last gen model
        gen_state_dict = torch.load(gen_model)
        self.gen_a.load_state_dict(gen_state_dict['a'])
        self.gen_b.load_state_dict(gen_state_dict['b'])
        iterations = int(gen_model[-11:-3])

        # Load discriminators
        if dis_model is None:
            dis_model = get_model_list(checkpoint_dir, "dis")
        dis_state_dict = torch.load(dis_model)
        self.dis_a.load_state_dict(dis_state_dict['a'])
        self.dis_b.load_state_dict(dis_state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers at the restored iteration
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def load_model(self, checkpoint_dir, hyperparameters, iteration):
        """Resume from the ``gen_XXXXXXXX.pt`` / ``dis_XXXXXXXX.pt`` pair for ``iteration``."""
        from pathlib import Path
        gen_model = Path(checkpoint_dir) / f"gen_{iteration:08d}.pt"
        dis_model = Path(checkpoint_dir) / f"dis_{iteration:08d}.pt"
        return self.resume(checkpoint_dir, hyperparameters, gen_model.as_posix(), dis_model.as_posix())

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators (tagged ``iterations + 1``) and optimizers."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
Beispiel #29
0
class DGNet_Trainer(nn.Module):
    """Trainer for DG-Net: joint re-ID learning and image generation.

    Owns a shared structure generator (``gen_a``), an appearance / identity
    encoder (``id_a``), a multi-scale discriminator (``dis_a``), optional
    teacher re-ID models, and their optimizers and LR schedulers. The ``*_b``
    attributes are aliases of the ``*_a`` modules, so both "domains" share
    weights.

    NOTE(review): ``gen_update`` calls ``self.PCB_loss`` and
    ``self.compute_vgg_loss``, which are not defined in this class body as
    shown. They are presumably defined elsewhere; verify before relying on
    those code paths.
    """

    def __init__(self, hyperparameters, gpu_ids=None):
        """Build the networks, optimizers and schedulers.

        Args:
            hyperparameters: configuration dict. The optional keys 'apex',
                'ID_stride' and 'T_w' are filled in place with defaults when
                missing.
            gpu_ids: unused by this constructor. Kept for backward
                compatibility; the old default was a shared mutable ``[0]``.
        """
        super(DGNet_Trainer, self).__init__()
        lr_g = hyperparameters['lr_g']  # generator learning rate
        lr_d = hyperparameters['lr_d']  # discriminator learning rate
        ID_class = hyperparameters['ID_class']
        if 'apex' not in hyperparameters:
            hyperparameters['apex'] = False
        self.fp16 = hyperparameters['apex']
        # Initiate the networks
        # We do not need to manually set fp16 in the network for the new apex. So here I set fp16=False.
        # Structure encoder Es + decoder G: .encode() produces the structure code
        # and .decode() synthesizes an image. gen_b is the same module as gen_a.
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'], fp16=False)  # auto-encoder for domain a
        self.gen_b = self.gen_a  # auto-encoder for domain b

        # ID_stride: stride argument of the 'AB' appearance encoder (default 2)
        if 'ID_stride' not in hyperparameters:
            hyperparameters['ID_stride'] = 2
        # Appearance / identity encoder Ea
        if hyperparameters['ID_style'] == 'PCB':
            self.id_a = PCB(ID_class)
        elif hyperparameters['ID_style'] == 'AB':
            self.id_a = ft_netAB(ID_class, stride=hyperparameters['ID_stride'],
                                 norm=hyperparameters['norm_id'], pool=hyperparameters['pool'])
        else:
            self.id_a = ft_net(ID_class, norm=hyperparameters['norm_id'],
                               pool=hyperparameters['pool'])  # return 2048 now
        # Alias, not a copy: id_b is the very same module as id_a
        self.id_b = self.id_a

        # Multi-scale discriminator. Per the original author's note, it scores the
        # image at several downscaled resolutions and sums the losses.
        self.dis_a = MsImageDis(3, hyperparameters['dis'], fp16=False)  # discriminator for domain a
        self.dis_b = self.dis_a  # discriminator for domain b

        # Load teacher models: 'teacher' is a comma-separated list of model names
        if hyperparameters['teacher'] != "":
            teacher_name = hyperparameters['teacher']
            print(teacher_name)
            teacher_model = nn.ModuleList()
            for name in teacher_name.split(','):
                config_tmp = load_config(name)
                stride = config_tmp['stride'] if 'stride' in config_tmp else 2
                # Build the teacher network and load its trained weights
                model_tmp = ft_net(ID_class, stride=stride)
                teacher_model_tmp = load_network(model_tmp, name)
                teacher_model_tmp.model.fc = nn.Sequential()  # remove the original fc layer in ImageNet
                teacher_model_tmp = teacher_model_tmp.cuda()
                if self.fp16:
                    teacher_model_tmp = amp.initialize(teacher_model_tmp, opt_level="O1")
                teacher_model.append(teacher_model_tmp.cuda().eval())
            self.teacher_model = teacher_model
            # Optionally apply the project's train_bn hook to every teacher submodule
            if hyperparameters['train_bn']:
                self.teacher_model = self.teacher_model.apply(train_bn)
        # Instance normalization without affine parameters (512 channels)
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # RGB to one channel
        if hyperparameters['single'] == 'edge':
            self.single = to_edge
        else:
            self.single = to_gray(False)

        # Random Erasing when training; erasing_p is the erasing probability
        if 'erasing_p' not in hyperparameters:
            self.erasing_p = 0
        else:
            self.erasing_p = hyperparameters['erasing_p']
        # Random-erasing augmentation (used only for the ID / PID losses)
        self.single_re = RandomErasing(probability=self.erasing_p, mean=[0.0, 0.0, 0.0])

        if 'T_w' not in hyperparameters:
            hyperparameters['T_w'] = 1
        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        # dis_b / gen_b alias dis_a / gen_a, so their parameters are not added twice
        dis_params = list(self.dis_a.parameters())
        gen_params = list(self.gen_a.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr_d, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr_g, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        # id params: classifier heads train at 10x the base learning rate lr2
        if hyperparameters['ID_style'] == 'PCB':
            ignored_params = (list(map(id, self.id_a.classifier0.parameters()))
                              + list(map(id, self.id_a.classifier1.parameters()))
                              + list(map(id, self.id_a.classifier2.parameters()))
                              + list(map(id, self.id_a.classifier3.parameters())))
            base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD([
                {'params': base_params, 'lr': lr2},
                {'params': self.id_a.classifier0.parameters(), 'lr': lr2 * 10},
                {'params': self.id_a.classifier1.parameters(), 'lr': lr2 * 10},
                {'params': self.id_a.classifier2.parameters(), 'lr': lr2 * 10},
                {'params': self.id_a.classifier3.parameters(), 'lr': lr2 * 10}
            ], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True)
        elif hyperparameters['ID_style'] == 'AB':
            ignored_params = (list(map(id, self.id_a.classifier1.parameters()))
                              + list(map(id, self.id_a.classifier2.parameters())))
            # Every non-classifier parameter uses the base learning rate
            base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD([
                {'params': base_params, 'lr': lr2},
                {'params': self.id_a.classifier1.parameters(), 'lr': lr2 * 10},
                {'params': self.id_a.classifier2.parameters(), 'lr': lr2 * 10}
            ], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True)
        else:
            ignored_params = list(map(id, self.id_a.classifier.parameters()))
            base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters())
            lr2 = hyperparameters['lr2']
            self.id_opt = torch.optim.SGD([
                {'params': base_params, 'lr': lr2},
                {'params': self.id_a.classifier.parameters(), 'lr': lr2 * 10}
            ], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True)

        # Learning-rate schedulers for each network
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.id_scheduler = get_scheduler(self.id_opt, hyperparameters)
        # The ID encoder has its own decay factor. The guard assumes get_scheduler
        # may return None (e.g. a constant-LR policy); verify against get_scheduler.
        if self.id_scheduler is not None:
            self.id_scheduler.gamma = hyperparameters['gamma2']

        # ID Loss
        self.id_criterion = nn.CrossEntropyLoss()
        # Summed KL divergence (the non-deprecated spelling of size_average=False);
        # gen_update divides by the batch size itself.
        self.criterion_teacher = nn.KLDivLoss(reduction='sum')
        # Load VGG model if needed
        if hyperparameters.get('vgg_w', 0) > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        # save memory
        if self.fp16:
            # Name the FP16_Optimizer instance to replace the existing optimizer
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.id_a = self.id_a.cuda()

            self.gen_b = self.gen_a
            self.dis_b = self.dis_a
            self.id_b = self.id_a

            self.gen_a, self.gen_opt = amp.initialize(self.gen_a, self.gen_opt, opt_level="O1")
            self.dis_a, self.dis_opt = amp.initialize(self.dis_a, self.dis_opt, opt_level="O1")
            self.id_a, self.id_opt = amp.initialize(self.id_a, self.id_opt, opt_level="O1")

    def to_re(self, x):
        """Apply random erasing (``self.single_re``) to each sample of a 4-D batch.

        Returns a new float32 tensor on the default CUDA device with the same
        size as *x*.
        """
        out = torch.FloatTensor(x.size(0), x.size(1), x.size(2), x.size(3))
        out = out.cuda()
        for i in range(x.size(0)):
            out[i, :, :, :] = self.single_re(x[i, :, :, :])
        return out

    def recon_criterion(self, input, target):
        """Mean absolute error; *target* is detached so no gradient flows into it."""
        diff = input - target.detach()
        return torch.mean(torch.abs(diff[:]))

    def recon_criterion_sqrt(self, input, target):
        """Mean of sqrt(|input - target| + 1e-8)."""
        diff = input - target
        return torch.mean(torch.sqrt(torch.abs(diff[:]) + 1e-8))

    def recon_criterion2(self, input, target):
        """Mean squared error."""
        diff = input - target
        return torch.mean(diff[:] ** 2)

    def recon_cos(self, input, target):
        """Mean cosine distance (1 - cosine similarity along dim 1)."""
        cos = torch.nn.CosineSimilarity()
        cos_dis = 1 - cos(input, target)
        return torch.mean(cos_dis[:])

    def forward(self, x_a, x_b, xp_a, xp_b):
        """Encode, swap and decode codes for two identities and their positives.

        Args:
            x_a, x_b: images of two different identities from the training set.
            xp_a, xp_b: positives, i.e. different photos of the same identity
                as x_a and x_b respectively.

        Returns:
            (x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b,
             x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p)
        """
        # Structure codes from the structure encoder Es (the original note gives
        # their shape as [batch, 128, 64, 32]; not verified here). self.single
        # converts the input to grayscale or edges depending on config.
        s_a = self.gen_a.encode(self.single(x_a))
        s_b = self.gen_b.encode(self.single(x_b))

        # f_*: appearance codes from the appearance encoder Ea; p_*: identity predictions
        f_a, p_a = self.id_a(scale2(x_a))
        f_b, p_b = self.id_b(scale2(x_b))

        # Decode with swapped codes: cross-identity generation
        x_ba = self.gen_a.decode(s_b, f_a)
        x_ab = self.gen_b.decode(s_a, f_b)

        # Decode each image's own structure and appearance: self-reconstruction
        x_a_recon = self.gen_a.decode(s_a, f_a)
        x_b_recon = self.gen_b.decode(s_b, f_b)

        # Appearance codes of the positives (same identity, different photo)
        fp_a, pp_a = self.id_a(scale2(xp_a))
        fp_b, pp_b = self.id_b(scale2(xp_b))
        # Decode each image's structure with its positive's appearance
        x_a_recon_p = self.gen_a.decode(s_a, fp_a)
        x_b_recon_p = self.gen_b.decode(s_b, fp_b)

        # Random Erasing only effect the ID and PID loss.
        # Erase pixels, then re-run the appearance encoder to get new predictions.
        if self.erasing_p > 0:
            x_a_re = self.to_re(scale2(x_a.clone()))
            x_b_re = self.to_re(scale2(x_b.clone()))
            xp_a_re = self.to_re(scale2(xp_a.clone()))
            xp_b_re = self.to_re(scale2(xp_b.clone()))
            _, p_a = self.id_a(x_a_re)
            _, p_b = self.id_b(x_b_re)
            # encode the same ID different photo
            _, pp_a = self.id_a(xp_a_re)
            _, pp_b = self.id_b(xp_b_re)

        # Summary of outputs:
        # x_ab: structure of x_a + appearance of x_b; x_ba: structure of x_b + appearance of x_a
        # s_a / s_b: structure codes of x_a / x_b (encoder Es)
        # f_a / f_b: appearance codes of x_a / x_b (encoder Ea)
        # p_a / p_b: identity predictions for x_a / x_b (encoder Ea)
        # pp_a / pp_b: identity predictions for the positives xp_a / xp_b
        # x_a_recon: x_a rebuilt from s_a and f_a (should resemble x_a)
        # x_b_recon: x_b rebuilt from s_b and f_b (should resemble x_b)
        # x_a_recon_p: s_a combined with the positive's appearance fp_a (should resemble x_a)
        # x_b_recon_p: s_b combined with the positive's appearance fp_b (should resemble x_b)
        return x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b, x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p

    def gen_update(self, x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b, x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p, x_a, x_b, xp_a, xp_b, l_a, l_b, hyperparameters, iteration, num_gpu):
        """One optimization step for the generator and the ID encoder.

        Takes the outputs of ``forward``, the real images (x_a, x_b, xp_a,
        xp_b) and the identity labels (l_a, l_b). Computes adversarial,
        reconstruction, ID and teacher losses, back-propagates their weighted
        sum, and steps ``gen_opt`` and ``id_opt``.

        Side effect: once the warm-up iterations have passed, several loss
        weights in *hyperparameters* are increased in place.
        """
        # ppa, ppb is the same person
        self.gen_opt.zero_grad()
        self.id_opt.zero_grad()

        # Detached copies of the synthesized images (no gradient to the generator)
        x_ba_copy = x_ba.detach()
        x_ab_copy = x_ab.detach()

        rand_num = random.uniform(0, 1)
        #################################
        # encode structure
        if hyperparameters['use_encoder_again'] >= rand_num:
            # encode again (encoder is tuned, input is fixed)
            s_a_recon = self.gen_b.enc_content(self.single(x_ab_copy))
            s_b_recon = self.gen_a.enc_content(self.single(x_ba_copy))
        else:
            # copy the encoder
            self.enc_content_copy = copy.deepcopy(self.gen_a.enc_content)
            self.enc_content_copy = self.enc_content_copy.eval()
            # encode again (encoder is fixed, input is tuned)
            s_a_recon = self.enc_content_copy(self.single(x_ab))
            s_b_recon = self.enc_content_copy(self.single(x_ba))

        #################################
        # encode appearance
        self.id_a_copy = copy.deepcopy(self.id_a)
        self.id_a_copy = self.id_a_copy.eval()
        if hyperparameters['train_bn']:
            self.id_a_copy = self.id_a_copy.apply(train_bn)
        self.id_b_copy = self.id_a_copy
        # encode again (encoder is fixed, input is tuned):
        # f_*_recon are the appearance codes and p_*_recon the identity
        # predictions of the synthesized images (L_recon^code2).
        f_a_recon, p_a_recon = self.id_a_copy(scale2(x_ba))
        f_b_recon, p_b_recon = self.id_b_copy(scale2(x_ab))

        # teacher Loss
        #  Tune the ID model
        log_sm = nn.LogSoftmax(dim=1)
        # Default also covers ID styles without a teacher branch (e.g. 'PCB'),
        # which previously left loss_teacher undefined.
        self.loss_teacher = 0.0
        if hyperparameters['teacher_w'] > 0 and hyperparameters['teacher'] != "":
            if hyperparameters['ID_style'] == 'normal':
                _, p_a_student = self.id_a(scale2(x_ba_copy))
                p_a_student = log_sm(p_a_student)
                # Teacher predictions are fixed targets: no graph through the teachers
                with torch.no_grad():
                    p_a_teacher = predict_label(self.teacher_model, scale2(x_ba_copy), num_class=hyperparameters['ID_class'], alabel=l_a, slabel=l_b, teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher = self.criterion_teacher(p_a_student, p_a_teacher) / p_a_student.size(0)

                _, p_b_student = self.id_b(scale2(x_ab_copy))
                p_b_student = log_sm(p_b_student)
                with torch.no_grad():
                    p_b_teacher = predict_label(self.teacher_model, scale2(x_ab_copy), num_class=hyperparameters['ID_class'], alabel=l_b, slabel=l_a, teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(p_b_student, p_b_teacher) / p_b_student.size(0)
            elif hyperparameters['ID_style'] == 'AB':
                # normal teacher-student loss
                # BA -> LabelA(smooth) + LabelB(batchB) Lprim Lrecon^code1
                _, p_ba_student = self.id_a(scale2(x_ba_copy))  # f_a, s_b
                p_a_student = log_sm(p_ba_student[0])
                # Summed KL divergence between the student (Ea) and teacher
                # predictions, averaged over the batch: the closer the student
                # is to the teacher, the lower the loss.
                with torch.no_grad():
                    p_a_teacher = predict_label(self.teacher_model, scale2(x_ba_copy), num_class=hyperparameters['ID_class'], alabel=l_a, slabel=l_b, teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher = self.criterion_teacher(p_a_student, p_a_teacher) / p_a_student.size(0)

                _, p_ab_student = self.id_b(scale2(x_ab_copy))  # f_b, s_a
                p_b_student = log_sm(p_ab_student[0])
                with torch.no_grad():
                    p_b_teacher = predict_label(self.teacher_model, scale2(x_ab_copy), num_class=hyperparameters['ID_class'], alabel=l_b, slabel=l_a, teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(p_b_student, p_b_teacher) / p_b_student.size(0)

                # branch b loss
                # here we give different label Lfine
                # ID loss of the student (Ea) on its second output p_*_student[1]
                # (the fine-grained identity branch per the original note)
                loss_B = self.id_criterion(p_ba_student[1], l_b) + self.id_criterion(p_ab_student[1], l_a)
                self.loss_teacher = hyperparameters['T_w'] * self.loss_teacher + hyperparameters['B_w'] * loss_B

        # auto-encoder image reconstruction Limg1^recon Limg2^recon
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_xp_a = self.recon_criterion(x_a_recon_p, x_a)
        self.loss_gen_recon_xp_b = self.recon_criterion(x_b_recon_p, x_b)

        # feature reconstruction
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_f_a = self.recon_criterion(f_a_recon, f_a) if hyperparameters['recon_f_w'] > 0 else 0
        self.loss_gen_recon_f_b = self.recon_criterion(f_b_recon, f_b) if hyperparameters['recon_f_w'] > 0 else 0

        x_aba = self.gen_a.decode(s_a_recon, f_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(s_b_recon, f_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # ID loss AND Tune the Generated image
        if hyperparameters['ID_style'] == 'PCB':
            self.loss_id = self.PCB_loss(p_a, l_a) + self.PCB_loss(p_b, l_b)
            self.loss_pid = self.PCB_loss(pp_a, l_a) + self.PCB_loss(pp_b, l_b)
            self.loss_gen_recon_id = self.PCB_loss(p_a_recon, l_a) + self.PCB_loss(p_b_recon, l_b)
        elif hyperparameters['ID_style'] == 'AB':
            weight_B = hyperparameters['teacher_w'] * hyperparameters['B_w']
            self.loss_id = self.id_criterion(p_a[0], l_a) + self.id_criterion(p_b[0], l_b) \
                + weight_B * (self.id_criterion(p_a[1], l_a) + self.id_criterion(p_b[1], l_b))
            self.loss_pid = self.id_criterion(pp_a[0], l_a) + self.id_criterion(pp_b[0], l_b)
            self.loss_gen_recon_id = self.id_criterion(p_a_recon[0], l_a) + self.id_criterion(p_b_recon[0], l_b)
        else:
            self.loss_id = self.id_criterion(p_a, l_a) + self.id_criterion(p_b, l_b)
            self.loss_pid = self.id_criterion(pp_a, l_a) + self.id_criterion(pp_b, l_b)
            self.loss_gen_recon_id = self.id_criterion(p_a_recon, l_a) + self.id_criterion(p_b_recon, l_b)

        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        if num_gpu > 1:
            self.loss_gen_adv_a = self.dis_a.module.calc_gen_loss(self.dis_a, x_ba)
            self.loss_gen_adv_b = self.dis_b.module.calc_gen_loss(self.dis_b, x_ab)
        else:
            self.loss_gen_adv_a = self.dis_a.calc_gen_loss(self.dis_a, x_ba)
            self.loss_gen_adv_b = self.dis_b.calc_gen_loss(self.dis_b, x_ab)
        # domain-invariant perceptual loss ('vgg_w' is optional, as in __init__)
        vgg_w = hyperparameters.get('vgg_w', 0)
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if vgg_w > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if vgg_w > 0 else 0
        # Warm-up: after the warm-up iterations, ramp selected loss weights (in place)
        if iteration > hyperparameters['warm_iter']:
            hyperparameters['recon_f_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_f_w'] = min(hyperparameters['recon_f_w'], hyperparameters['max_w'])
            hyperparameters['recon_s_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_s_w'] = min(hyperparameters['recon_s_w'], hyperparameters['max_w'])
            hyperparameters['recon_x_cyc_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_x_cyc_w'] = min(hyperparameters['recon_x_cyc_w'], hyperparameters['max_cyc_w'])

        if iteration > hyperparameters['warm_teacher_iter']:
            hyperparameters['teacher_w'] += hyperparameters['warm_scale']
            hyperparameters['teacher_w'] = min(hyperparameters['teacher_w'], hyperparameters['max_teacher_w'])
        # total loss (parenthesized so the per-term comments are legal; the
        # summation order is unchanged)
        self.loss_gen_total = (
            # GAN loss
            hyperparameters['gan_w'] * self.loss_gen_adv_a
            + hyperparameters['gan_w'] * self.loss_gen_adv_b
            # image reconstruction loss
            + hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a
            + hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_a
            # code reconstruction loss
            + hyperparameters['recon_f_w'] * self.loss_gen_recon_f_a
            + hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a
            + hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b
            + hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_b
            + hyperparameters['recon_f_w'] * self.loss_gen_recon_f_b
            + hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b
            + hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a
            + hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b
            + hyperparameters['id_w'] * self.loss_id
            + hyperparameters['pid_w'] * self.loss_pid
            # id loss
            + hyperparameters['recon_id_w'] * self.loss_gen_recon_id
            + vgg_w * self.loss_gen_vgg_a
            + vgg_w * self.loss_gen_vgg_b
            # teacher (discriminative module) loss
            + hyperparameters['teacher_w'] * self.loss_teacher
        )
        if self.fp16:
            with amp.scale_loss(self.loss_gen_total, [self.gen_opt, self.id_opt]) as scaled_loss:
                scaled_loss.backward()
            self.gen_opt.step()
            self.id_opt.step()
        else:
            self.loss_gen_total.backward()
            self.gen_opt.step()
            self.id_opt.step()
        print("L_total: %.4f, L_gan: %.4f,  Lx: %.4f, Lxp: %.4f, Lrecycle:%.4f, Lf: %.4f, Ls: %.4f, Recon-id: %.4f, id: %.4f, pid:%.4f, teacher: %.4f" % (
            self.loss_gen_total,
            hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b),
            hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a + self.loss_gen_recon_x_b),
            hyperparameters['recon_xp_w'] * (self.loss_gen_recon_xp_a + self.loss_gen_recon_xp_b),
            hyperparameters['recon_x_cyc_w'] * (self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b),
            hyperparameters['recon_f_w'] * (self.loss_gen_recon_f_a + self.loss_gen_recon_f_b),
            hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b),
            hyperparameters['recon_id_w'] * self.loss_gen_recon_id,
            hyperparameters['id_w'] * self.loss_id,
            hyperparameters['pid_w'] * self.loss_pid,
            hyperparameters['teacher_w'] * self.loss_teacher))
Beispiel #30
0
class MUNIT_Trainer(nn.Module):
    """MUNIT trainer: two AdaIN auto-encoders and two multi-scale discriminators.

    Holds the generators and discriminators for domains a and b, their Adam
    optimizers and LR schedulers. Implements the generator and discriminator
    update steps, sampling, and checkpoint save/resume. Freshly drawn style
    codes are created on ``self.device``.
    """

    def __init__(self, hyperparameters, device):
        """Build networks, optimizers, schedulers and the fixed sampling noise.

        Args:
            hyperparameters: configuration dict (learning rate, Adam betas,
                network configs, loss weights, ...).
            device: torch device on which style codes are created and onto
                which checkpoints are mapped in ``resume``.
        """
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        self.device = device
        # Initiate the networks
        self.gen_a = AdaINGen(
            hyperparameters['input_dim_a'],
            hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(
            hyperparameters['input_dim_b'],
            hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(
            hyperparameters['input_dim_a'],
            hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(
            hyperparameters['input_dim_b'],
            hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1,
                               1).to(self.device)
        self.s_b = torch.randn(display_size, self.style_dim, 1,
                               1).to(self.device)

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(
            self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(
            self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    def recon_criterion(self, input, target):
        """Mean absolute error between *input* and *target*."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate x_a -> b and x_b -> a using the fixed style codes s_b / s_a.

        Runs the networks in eval mode and restores the caller's train/eval
        mode afterwards (previously it always switched back to train mode).

        Returns:
            (x_ab, x_ba)
        """
        was_training = self.training
        self.eval()
        try:
            c_a, _ = self.gen_a.encode(x_a)
            c_b, _ = self.gen_b.encode(x_b)
            x_ba = self.gen_a.decode(c_b, self.s_a)
            x_ab = self.gen_b.decode(c_a, self.s_b)
        finally:
            self.train(was_training)
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        """One generator optimization step on a batch from each domain."""
        self.gen_opt.zero_grad()
        s_a = torch.randn(x_a.size(0), self.style_dim, 1, 1).to(self.device)
        s_b = torch.randn(x_b.size(0), self.style_dim, 1, 1).to(self.device)
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain)
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (if needed)
        x_aba = self.gen_a.decode(
            c_a_recon,
            s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(
            c_b_recon,
            s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # reconstruction loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(
            x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(
            x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss; 'vgg_w' is optional, as in __init__
        vgg_w = hyperparameters.get('vgg_w', 0)
        self.loss_gen_vgg_a = self.compute_vgg_loss(
            self.vgg, x_ba, x_b) if vgg_w > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(
            self.vgg, x_ab, x_a) if vgg_w > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              vgg_w * self.loss_gen_vgg_a + \
                              vgg_w * self.loss_gen_vgg_b
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """MSE between instance-normalized VGG features of *img* and *target*."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean(
            (self.instancenorm(img_fea) - self.instancenorm(target_fea))**2)

    def sample(self, x_a, x_b):
        """Per-sample reconstructions and translations for visualization.

        Translations use both the fixed codes (s_a / s_b) and fresh random
        codes. Runs in eval mode and restores the caller's train/eval mode
        afterwards.

        Returns:
            (x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2)
        """
        was_training = self.training
        self.eval()
        try:
            s_a2 = torch.randn(x_a.size(0), self.style_dim, 1, 1).to(self.device)
            s_b2 = torch.randn(x_b.size(0), self.style_dim, 1, 1).to(self.device)
            x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
            for i in range(x_a.size(0)):
                c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
                c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
                x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
                x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
                x_ba1.append(self.gen_a.decode(c_b, self.s_a[i].unsqueeze(0)))
                x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
                x_ab1.append(self.gen_b.decode(c_a, self.s_b[i].unsqueeze(0)))
                x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
            x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
            x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
            x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        finally:
            self.train(was_training)
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, hyperparameters):
        """One discriminator optimization step on a batch from each domain."""
        self.dis_opt.zero_grad()
        s_a = torch.randn(x_a.size(0), self.style_dim, 1, 1).to(self.device)
        s_b = torch.randn(x_b.size(0), self.style_dim, 1, 1).to(self.device)
        # The fakes are detached below anyway, so the generator pass needs no
        # autograd graph; skipping it saves memory and time.
        with torch.no_grad():
            # encode
            c_a, _ = self.gen_a.encode(x_a)
            c_b, _ = self.gen_b.encode(x_b)
            # decode (cross domain)
            x_ba = self.gen_a.decode(c_b, s_a)
            x_ab = self.gen_b.decode(c_a, s_b)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters[
            'gan_w'] * self.loss_dis_a + hyperparameters[
                'gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Advance whichever LR schedulers exist (``None`` schedulers are skipped)."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Load the latest checkpoints from *checkpoint_dir*.

        Checkpoints are mapped onto ``self.device``, so a GPU-saved checkpoint
        can be resumed on CPU or on a different GPU. The iteration number is
        parsed from the generator file name ('gen_XXXXXXXX.pt').

        Returns:
            The iteration the checkpoint was saved at.
        """
        # Load generators
        gen_path = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(gen_path, map_location=self.device)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        iterations = int(gen_path[-11:-3])
        # Load discriminators
        dis_path = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(dis_path, map_location=self.device)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'),
                                map_location=self.device)
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitilize schedulers
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators and optimizers to *snapshot_dir*.

        Network files are tagged with ``iterations + 1``; 'optimizer.pt' is
        overwritten on each call.
        """
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({
            'a': self.gen_a.state_dict(),
            'b': self.gen_b.state_dict()
        }, gen_name)
        torch.save({
            'a': self.dis_a.state_dict(),
            'b': self.dis_b.state_dict()
        }, dis_name)
        torch.save(
            {
                'gen': self.gen_opt.state_dict(),
                'dis': self.dis_opt.state_dict()
            }, opt_name)
Beispiel #31
0
class MUNIT_Trainer(nn.Module):
    """Two-domain MUNIT-style translation trainer with an extra ``T`` branch.

    Owns one AdaIN auto-encoder (``gen_a``/``gen_b``) and one multi-scale
    discriminator (``dis_a``/``dis_b``) per domain, their Adam optimizers and
    LR schedulers, and the update / sample / checkpoint logic.  ``gen_a`` must
    additionally provide ``encodeT``/``decodeT``, which ``gen_update`` uses for
    two extra reconstruction terms.
    """

    def __init__(self, hyperparameters):
        """Build networks, optimizers, schedulers and (optionally) the VGG net.

        Args:
            hyperparameters: config dict. Keys read here: ``lr``,
                ``input_dim_a``, ``input_dim_b``, ``gen`` (incl. ``style_dim``),
                ``dis``, ``display_size``, ``beta1``, ``beta2``,
                ``weight_decay``, ``init``, optional ``vgg_w`` /
                ``vgg_model_path``, plus whatever ``get_scheduler`` reads.
        """
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # Fixed style noise for sampling (not read by the methods of this class).
        # NOTE(review): `cun` is not defined in this chunk; presumably a
        # module-level CUDA device index -- confirm.
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda(cun)
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda(cun)

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization; the discriminators are then
        # re-initialised with weights_init('gaussian').
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load a frozen VGG16 only when the perceptual loss is enabled
        if self._vgg_weight(hyperparameters) > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

    @staticmethod
    def _vgg_weight(hyperparameters):
        """Return the perceptual-loss weight, treating a missing ``vgg_w`` as 0."""
        return hyperparameters.get('vgg_w', 0)

    def recon_criterion(self, input, target):
        """Mean absolute (L1) difference between ``input`` and ``target``."""
        return torch.mean(torch.abs(input - target))

    def forward(self, x_a, x_b):
        """Translate both batches across domains in eval mode.

        Returns:
            ``(x_ab, x_ba)``: x_a's content decoded by ``gen_b`` with the style
            encoded from x_b, and x_b's content decoded by ``gen_a`` with the
            style encoded from x_a.  The previous train/eval mode is restored
            even if a network call raises.
        """
        was_training = self.training
        self.eval()
        try:
            c_a, s_a_enc = self.gen_a.encode(x_a)
            c_b, s_b_enc = self.gen_b.encode(x_b)
            x_ba = self.gen_a.decode(c_b, s_a_enc)
            x_ab = self.gen_b.decode(c_a, s_b_enc)
        finally:
            self.train(was_training)
        return x_ab, x_ba

    def gen_update(self, x_a, x_b, hyperparameters):
        """Run one generator optimisation step on a pair of batches.

        Stores every loss term on ``self.loss_gen_*`` and the weighted sum on
        ``self.loss_gen_total``.  Besides the standard MUNIT weights
        (``gan_w``, ``recon_x_w``, ``recon_s_w``, ``recon_c_w``,
        ``recon_x_cyc_w``, optional ``vgg_w``), the ``encodeT``/``decodeT``
        terms are weighted by the optional keys ``content_T_w`` and
        ``style_T_w`` (both default to 5).
        """
        self.gen_opt.zero_grad()
        # encode
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        c_aT, s_aT_prime = self.gen_a.encodeT(x_a)
        # decode (within domain); the T branch decodes the *regular* content c_a
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_aT_recon = self.gen_a.decodeT(c_a, s_aT_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)
        # decode (cross domain), each generator using its own domain's style
        x_ba = self.gen_a.decode(c_b, s_a_prime)
        x_ab = self.gen_b.decode(c_a, s_b_prime)
        # encode again
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
        # decode again (only needed for the cycle loss)
        use_cycle = hyperparameters['recon_x_cyc_w'] > 0
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if use_cycle else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if use_cycle else None

        # reconstruction losses
        self.loss_gen_styleT = self.recon_criterion(x_a, x_aT_recon)
        self.loss_gen_content = self.recon_criterion(c_a, c_aT)
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a_prime)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b_prime)
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if use_cycle else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if use_cycle else 0
        # GAN loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
        # domain-invariant perceptual loss (self.vgg only exists when vgg_w > 0)
        vgg_w = self._vgg_weight(hyperparameters)
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if vgg_w > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if vgg_w > 0 else 0
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              vgg_w * self.loss_gen_vgg_a + \
                              vgg_w * self.loss_gen_vgg_b
        # T-branch terms (weights default to the previously hard-coded 5)
        self.loss_gen_total += hyperparameters.get('content_T_w', 5) * self.loss_gen_content
        self.loss_gen_total += hyperparameters.get('style_T_w', 5) * self.loss_gen_styleT
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Mean squared distance between instance-normalised VGG features."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def sample(self, x_a, x_b):
        """Produce reconstructions and cross-domain translations for display.

        Returns:
            ``(x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba)``, where every
            decode uses the style encoded from the real batch of the target
            domain.  The previous train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            c_a, s_a_enc = self.gen_a.encode(x_a)
            c_b, s_b_enc = self.gen_b.encode(x_b)
            x_a_recon = self.gen_a.decode(c_a, s_a_enc)
            x_b_recon = self.gen_b.decode(c_b, s_b_enc)
            x_ba = self.gen_a.decode(c_b, s_a_enc)
            x_ab = self.gen_b.decode(c_a, s_b_enc)
        finally:
            self.train(was_training)
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, hyperparameters):
        """Run one discriminator optimisation step on a pair of batches.

        Fakes are produced exactly as in ``gen_update``: each domain's
        generator decodes the other domain's content with its *own* encoded
        style.  Stores ``loss_dis_a``, ``loss_dis_b`` and ``loss_dis_total``.
        """
        self.dis_opt.zero_grad()
        # The fakes are detached below, so the generator pass needs no graph.
        with torch.no_grad():
            c_a, s_a_prime = self.gen_a.encode(x_a)
            c_b, s_b_prime = self.gen_b.encode(x_b)
            x_ba = self.gen_a.decode(c_b, s_a_prime)
            x_ab = self.gen_b.decode(c_a, s_b_prime)
        # D loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Advance both LR schedulers by one step (if they exist)."""
        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Reload the newest snapshots from ``checkpoint_dir``.

        Returns:
            The iteration number parsed from the generator file name, which
            ``save`` writes as ``gen_%08d.pt``.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b.load_state_dict(state_dict['b'])
        # the 8 characters before '.pt' are the zero-padded iteration count
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b.load_state_dict(state_dict['b'])
        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.dis_opt.load_state_dict(state_dict['dis'])
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize schedulers with the resumed iteration count
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save generators, discriminators and optimizers (each exactly once)."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
        torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)


# NOTE: disabled UNIT trainer, kept for reference only (fully commented out).
# class UNIT_Trainer(nn.Module):
#     def __init__(self, hyperparameters):
#         super(UNIT_Trainer, self).__init__()
#         lr = hyperparameters['lr']
#         # Initiate the networks
#         self.gen_a = VAEGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
#         self.gen_b = VAEGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
#         self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
#         self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
#         self.instancenorm = nn.InstanceNorm2d(512, affine=False)

#         # Setup the optimizers
#         beta1 = hyperparameters['beta1']
#         beta2 = hyperparameters['beta2']
#         dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
#         gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
#         self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
#                                         lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
#         self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
#                                         lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
#         self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
#         self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

#         # Network weight initialization
#         self.apply(weights_init(hyperparameters['init']))
#         self.dis_a.apply(weights_init('gaussian'))
#         self.dis_b.apply(weights_init('gaussian'))

#         # Load VGG model if needed
#         if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
#             self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
#             self.vgg.eval()
#             for param in self.vgg.parameters():
#                 param.requires_grad = False

#     def recon_criterion(self, input, target):
#         return torch.mean(torch.abs(input - target))

#     def forward(self, x_a, x_b):
#         self.eval()
#         h_a, _ = self.gen_a.encode(x_a)
#         h_b, _ = self.gen_b.encode(x_b)
#         x_ba = self.gen_a.decode(h_b)
#         x_ab = self.gen_b.decode(h_a)
#         self.train()
#         return x_ab, x_ba

#     def __compute_kl(self, mu):
#         # def _compute_kl(self, mu, sd):
#         # mu_2 = torch.pow(mu, 2)
#         # sd_2 = torch.pow(sd, 2)
#         # encoding_loss = (mu_2 + sd_2 - torch.log(sd_2)).sum() / mu_2.size(0)
#         # return encoding_loss
#         mu_2 = torch.pow(mu, 2)
#         encoding_loss = torch.mean(mu_2)
#         return encoding_loss

#     def gen_update(self, x_a, x_b, hyperparameters):
#         self.gen_opt.zero_grad()
#         # encode
#         h_a, n_a = self.gen_a.encode(x_a)
#         h_b, n_b = self.gen_b.encode(x_b)
#         # decode (within domain)
#         x_a_recon = self.gen_a.decode(h_a + n_a)
#         x_b_recon = self.gen_b.decode(h_b + n_b)
#         # decode (cross domain)
#         x_ba = self.gen_a.decode(h_b + n_b)
#         x_ab = self.gen_b.decode(h_a + n_a)
#         # encode again
#         h_b_recon, n_b_recon = self.gen_a.encode(x_ba)
#         h_a_recon, n_a_recon = self.gen_b.encode(x_ab)
#         # decode again (if needed)
#         x_aba = self.gen_a.decode(h_a_recon + n_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
#         x_bab = self.gen_b.decode(h_b_recon + n_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None

#         # reconstruction loss
#         self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
#         self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
#         self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
#         self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
#         self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a)
#         self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b)
#         self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
#         self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)
#         # GAN loss
#         self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
#         self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
#         # domain-invariant perceptual loss
#         self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
#         self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
#         # total loss
#         # self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
#         #                       hyperparameters['gan_w'] * self.loss_gen_adv_b + \
#         #                       hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
#         #                       hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
#         #                       hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
#         #                       hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
#         #                       hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
#         #                       hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
#         #                       hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
#         #                       hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
#         #                       hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
#         #                       hyperparameters['vgg_w'] * self.loss_gen_vgg_b
#         # self.loss_gen_total.backward()
#         # self.gen_opt.step()
#         self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
#                               hyperparameters['gan_w'] * self.loss_gen_adv_b + \
#                               hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
#                               hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
#                               hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
#                               hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
#                               hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
#                               hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
#                               hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
#                               hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab + \
#                               hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
#                               hyperparameters['vgg_w'] * self.loss_gen_vgg_b
#         self.loss_gen_total.backward()
#         self.gen_opt.step()

#     def compute_vgg_loss(self, vgg, img, target):
#         img_vgg = vgg_preprocess(img)
#         target_vgg = vgg_preprocess(target)
#         img_fea = vgg(img_vgg)
#         target_fea = vgg(target_vgg)
#         return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

#     def sample(self, x_a, x_b):
#         self.eval()
#         x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
#         for i in range(x_a.size(0)):
#             h_a, _ = self.gen_a.encode(x_a[i].unsqueeze(0))
#             h_b, _ = self.gen_b.encode(x_b[i].unsqueeze(0))
#             x_a_recon.append(self.gen_a.decode(h_a))
#             x_b_recon.append(self.gen_b.decode(h_b))
#             x_ba.append(self.gen_a.decode(h_b))
#             x_ab.append(self.gen_b.decode(h_a))
#         x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
#         x_ba = torch.cat(x_ba)
#         x_ab = torch.cat(x_ab)
#         self.train()
#         return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

#     def dis_update(self, x_a, x_b, hyperparameters):
#         self.dis_opt.zero_grad()
#         # encode
#         h_a, n_a = self.gen_a.encode(x_a)
#         h_b, n_b = self.gen_b.encode(x_b)
#         # decode (cross domain)
#         x_ba = self.gen_a.decode(h_b + n_b)
#         x_ab = self.gen_b.decode(h_a + n_a)
#         # D loss
#         self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
#         self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
#         self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
#         self.loss_dis_total.backward()
#         self.dis_opt.step()

#     def update_learning_rate(self):
#         if self.dis_scheduler is not None:
#             self.dis_scheduler.step()
#         if self.gen_scheduler is not None:
#             self.gen_scheduler.step()

#     def resume(self, checkpoint_dir, hyperparameters):
#         # Load generators
#         last_model_name = get_model_list(checkpoint_dir, "gen")
#         state_dict = torch.load(last_model_name)
#         self.gen_a.load_state_dict(state_dict['a'])
#         self.gen_b.load_state_dict(state_dict['b'])
#         iterations = int(last_model_name[-11:-3])
#         # Load discriminators
#         last_model_name = get_model_list(checkpoint_dir, "dis")
#         state_dict = torch.load(last_model_name)
#         self.dis_a.load_state_dict(state_dict['a'])
#         self.dis_b.load_state_dict(state_dict['b'])
#         # Load optimizers
#         state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
#         self.dis_opt.load_state_dict(state_dict['dis'])
#         self.gen_opt.load_state_dict(state_dict['gen'])
#         # Reinitilize schedulers
#         self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
#         self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
#         print('Resume from iteration %d' % iterations)
#         return iterations

#     def save(self, snapshot_dir, iterations):
#         # Save generators, discriminators, and optimizers
#         gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
#         dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
#         opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
#         torch.save({'a': self.gen_a.state_dict(), 'b': self.gen_b.state_dict()}, gen_name)
#         torch.save({'a': self.dis_a.state_dict(), 'b': self.dis_b.state_dict()}, dis_name)
#         torch.save({'gen': self.gen_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
Beispiel #32
0
class SupIntrinsicTrainer(nn.Module):
    """Supervised trainer for a single generator with a two-part output.

    The generator maps an input image to ``2 * input_dim_b`` stacked channels.
    The first ``input_dim_b`` channels are compared with ``x_r`` and the rest
    with ``x_s`` (presumably reflectance and shading, given the class name --
    confirm), using an optionally masked L1 loss.
    """

    def __init__(self, param):
        """Build the generator, its Adam optimizer and LR scheduler.

        Args:
            param: config dict; keys read here are ``lr``, ``input_dim_a``,
                ``input_dim_b``, ``gen``, ``beta1``, ``beta2``,
                ``weight_decay``, ``init`` plus whatever ``get_scheduler`` reads.
        """
        super(SupIntrinsicTrainer, self).__init__()
        lr = param['lr']
        # Channel index separating the two stacked predictions in the output.
        self.split_dim = param['input_dim_b']
        # Initiate the networks
        self.model = AdaINGen(param['input_dim_a'],
                              param['input_dim_b'] + param['input_dim_b'],
                              param['gen'])  # auto-encoder

        # Setup the optimizers
        beta1 = param['beta1']
        beta2 = param['beta2']
        gen_params = list(self.model.parameters())
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=param['weight_decay'])
        self.gen_scheduler = get_scheduler(self.gen_opt, param)

        # Network weight initialization
        self.apply(weights_init(param['init']))
        self.best_result = float('inf')

    def _split_output(self, out):
        """Split the stacked generator output into its two parts along dim 1."""
        k = self.split_dim
        return out[:, :k, :, :], out[:, k:, :, :]

    def recon_criterion(self, input, target, mask=None):
        """Mean absolute error, restricted to ``mask`` (boolean index) if given."""
        if mask is not None:
            return torch.mean(torch.abs(input[mask] - target[mask]))
        return torch.mean(torch.abs(input - target))

    def forward(self, x):
        """Predict ``(x_r, x_s)`` for ``x`` in eval mode, restoring the previous mode."""
        was_training = self.training
        self.eval()
        try:
            return self._split_output(self.model(x))
        finally:
            self.train(was_training)

    def gen_update(self, x_i, x_r, x_s, x_m, param):
        """One optimisation step: masked L1 on both predictions.

        Stores ``loss_r``, ``loss_s`` and ``loss_gen_total``.  ``x_m`` is used
        as a boolean index on the predictions and targets (assumed to be
        shape-compatible with them -- confirm).
        """
        self.gen_opt.zero_grad()

        pred_r, pred_s = self._split_output(self.model(x_i))

        # reconstruction loss
        self.loss_r = self.recon_criterion(pred_r, x_r, x_m)
        self.loss_s = self.recon_criterion(pred_s, x_s, x_m)
        # total loss
        self.loss_gen_total = self.loss_r + self.loss_s
        self.loss_gen_total.backward()
        self.gen_opt.step()

    def sample(self, x_i, x_r, x_s):
        """Run the model one image at a time for display.

        Returns:
            ``(x_i, x_r, x_ri, x_s, x_si)``: the inputs and ground truths
            unchanged, plus the predictions concatenated over the batch.
        """
        was_training = self.training
        self.eval()
        try:
            # Loop names must not shadow the x_r/x_s ground-truth arguments.
            preds = [self._split_output(self.model(img.unsqueeze(0))) for img in x_i]
        finally:
            self.train(was_training)
        x_ri = torch.cat([pred_r for pred_r, _ in preds])
        x_si = torch.cat([pred_s for _, pred_s in preds])
        return x_i, x_r, x_ri, x_s, x_si

    # noinspection PyAttributeOutsideInit
    def dis_update(self, x_i, x_r, x_s, param=None):
        """No discriminator in this trainer; kept for interface compatibility."""
        pass

    def update_learning_rate(self):
        """Advance the LR scheduler by one step (if it exists)."""
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, param):
        """Reload the newest generator snapshot and optimizer state.

        Returns:
            The iteration number parsed from the snapshot name, which ``save``
            writes as ``gen_%08d.pt``.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        # save() stores the weights under 'model'; accept the legacy 'i' key too.
        model_state = state_dict['model'] if 'model' in state_dict else state_dict['i']
        self.model.load_state_dict(model_state)
        self.best_result = state_dict.get('best_result', self.best_result)
        iterations = int(last_model_name[-11:-3])

        # Load optimizers
        state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
        self.gen_opt.load_state_dict(state_dict['gen'])
        # Reinitialize the scheduler with the resumed iteration count
        self.gen_scheduler = get_scheduler(self.gen_opt, param, iterations)
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations):
        """Save the generator (with ``best_result``) and the optimizer state."""
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save(
            {
                'model': self.model.state_dict(),
                'best_result': self.best_result
            }, gen_name)
        torch.save({'gen': self.gen_opt.state_dict()}, opt_name)