def __init__(self, configer):
    """Create the CycleGAN sub-networks, image pools and loss criteria.

    Args:
        configer: Configuration object queried as ``configer.get(section, key)``.
            The keys read here are ``network/{generatorA, generatorB,
            discriminatorA, discriminatorB, use_dropout, norm_type,
            imgpool_size}`` and ``loss/params`` (indexed with ``'gan_mode'``).
    """
    super(CycleGAN, self).__init__()
    self.configer = configer
    cfg = self.configer.get

    # Settings shared by several sub-networks.
    dropout_flag = cfg('network', 'use_dropout')
    norm_kind = cfg('network', 'norm_type')
    pool_capacity = cfg('network', 'imgpool_size')

    # Networks. The naming convention differs from the paper:
    # code G_A is the paper's G, code G_B is the paper's F.
    # NOTE(review): the original comment maps D_A -> D_Y and D_B -> D_X;
    # confirm against how forward() wires the discriminators.
    self.netG_A = SubNetSelector.generator(
        net_dict=cfg('network', 'generatorA'),
        use_dropout=dropout_flag,
        norm_type=norm_kind)
    self.netG_B = SubNetSelector.generator(
        net_dict=cfg('network', 'generatorB'),
        use_dropout=dropout_flag,
        norm_type=norm_kind)
    self.netD_A = SubNetSelector.discriminator(
        net_dict=cfg('network', 'discriminatorA'),
        norm_type=norm_kind)
    self.netD_B = SubNetSelector.discriminator(
        net_dict=cfg('network', 'discriminatorB'),
        norm_type=norm_kind)

    # One image pool per generated domain, both with the same size.
    self.fake_A_pool = ImagePool(pool_capacity)
    self.fake_B_pool = ImagePool(pool_capacity)

    # Loss criteria.
    gan_mode = cfg('loss', 'params')['gan_mode']
    self.criterionGAN = GANLoss(gan_mode=gan_mode)
    self.criterionCycle = nn.L1Loss()
    self.criterionIdt = nn.L1Loss()
def __init__(self, configer):
    """Create the CycleGAN sub-networks, image pools and loss criteria.

    Args:
        configer: Configuration object queried as ``configer.get(section, key)``.
            The keys read here are ``network/{generatorA, generatorB,
            discriminatorA, discriminatorB}`` (the discriminator entries are
            also indexed with ``'pool_size'``) and ``loss/use_lsgan``.
    """
    super(CycleGAN, self).__init__()
    self.configer = configer
    cfg = self.configer.get

    # Networks. The naming convention differs from the paper:
    # code G_A is the paper's G, code G_B is the paper's F.
    # NOTE(review): the original comment maps D_A -> D_Y and D_B -> D_X;
    # confirm against how forward() wires the discriminators.
    gen_a_cfg = cfg('network', 'generatorA')
    self.netG_A = SubnetSelector.generator(gen_a_cfg)
    gen_b_cfg = cfg('network', 'generatorB')
    self.netG_B = SubnetSelector.generator(gen_b_cfg)
    dis_a_cfg = cfg('network', 'discriminatorA')
    self.netD_A = SubnetSelector.discriminator(dis_a_cfg)
    dis_b_cfg = cfg('network', 'discriminatorB')
    self.netD_B = SubnetSelector.discriminator(dis_b_cfg)

    # Each pool takes its size from the matching discriminator's config entry.
    pool_a_size = cfg('network', 'discriminatorA')['pool_size']
    self.fake_A_pool = ImagePool(pool_a_size)
    pool_b_size = cfg('network', 'discriminatorB')['pool_size']
    self.fake_B_pool = ImagePool(pool_b_size)

    # Loss criteria.
    lsgan_flag = cfg('loss', 'use_lsgan')
    self.criterionGAN = GANLoss(use_lsgan=lsgan_flag)
    self.criterionCycle = nn.L1Loss()
    self.criterionIdt = nn.L1Loss()
def __init__(self, configer):
    """Create the Pix2Pix generator, discriminator, image pool and losses.

    Args:
        configer: Configuration object queried as ``configer.get(section, key)``.
            The keys read here are ``network/{generator, discriminator,
            imgpool_size}`` and ``loss/gan_mode``.
    """
    super(Pix2Pix, self).__init__()
    self.configer = configer
    cfg = self.configer.get

    # Networks.
    generator_cfg = cfg('network', 'generator')
    self.netG = SubnetSelector.generator(net_dict=generator_cfg)
    discriminator_cfg = cfg('network', 'discriminator')
    self.netD = SubnetSelector.discriminator(net_dict=discriminator_cfg)

    # Image pool used for the (input, output) pairs shown to the discriminator.
    pool_capacity = cfg('network', 'imgpool_size')
    self.fake_AB_pool = ImagePool(pool_capacity)

    # Loss criteria.
    gan_mode = cfg('loss', 'gan_mode')
    self.criterionGAN = GANLoss(gan_mode=gan_mode)
    self.criterionL1 = nn.L1Loss()
def initialize(self, opt):
    """Create the generator, discriminator, image pool and loss criteria.

    Network and loss settings are read from ``self.configer``, which must
    already be set on the instance; only the pool size comes from ``opt``.

    Args:
        opt: Options object; only ``opt.pool_size`` is read here.
    """
    cfg = self.configer.get

    # Networks.
    generator_cfg = cfg('network', 'generator')
    self.netG = SubnetSelector.generator(generator_cfg)
    discriminator_cfg = cfg('network', 'discriminator')
    self.netD = SubnetSelector.discriminator(discriminator_cfg)

    pool_capacity = opt.pool_size
    self.fake_AB_pool = ImagePool(pool_capacity)

    # Loss criteria.
    lsgan_flag = cfg('loss', 'use_lsgan')
    self.criterionGAN = GANLoss(use_lsgan=lsgan_flag)
    self.criterionL1 = nn.L1Loss()
class Pix2Pix(nn.Module):
    """Pix2Pix wrapper holding one generator, one discriminator and their losses.

    ``forward`` returns the generator and discriminator losses for a batch or,
    with ``testing=True``, the input image(s) together with the generated one.
    """

    def __init__(self, configer):
        """Build the sub-networks, image pool and criteria from ``configer``.

        Args:
            configer: Configuration object queried as
                ``configer.get(section, key)``.
        """
        super(Pix2Pix, self).__init__()
        self.configer = configer
        cfg = self.configer.get
        norm_kind = cfg('network', 'norm_type')

        # Networks.
        self.netG = SubNetSelector.generator(
            net_dict=cfg('network', 'generator'),
            use_dropout=cfg('network', 'use_dropout'),
            norm_type=norm_kind)
        self.netD = SubNetSelector.discriminator(
            net_dict=cfg('network', 'discriminator'),
            norm_type=norm_kind)

        pool_capacity = cfg('network', 'imgpool_size')
        self.fake_AB_pool = ImagePool(pool_capacity)

        # Loss criteria.
        gan_mode = cfg('loss', 'params')['gan_mode']
        self.criterionGAN = GANLoss(gan_mode=gan_mode)
        self.criterionL1 = nn.L1Loss()

    def _inference_outputs(self, data_dict):
        """Return the images available in test mode.

        ``imgA`` (if present) is echoed as ``realA`` and translated into
        ``fakeB``; ``imgB`` (if present) is echoed as ``realB``.
        """
        outputs = {}
        if 'imgA' in data_dict:
            source = data_dict['imgA']
            outputs['realA'] = source
            outputs['fakeB'] = self.netG.forward(source)

        if 'imgB' in data_dict:
            outputs['realB'] = data_dict['imgB']

        return outputs

    def forward(self, data_dict, testing=False):
        """Compute the training losses, or the test-time outputs.

        Args:
            data_dict: Mapping with tensors under ``'imgA'`` and ``'imgB'``
                (both required for training; each optional when testing).
            testing: When true, skip the losses and return images instead.

        Returns:
            ``dict(loss_G=..., loss_D=...)`` in training mode; in testing mode
            a dict with any of ``realA``, ``fakeB``, ``realB``.
        """
        if testing:
            return self._inference_outputs(data_dict)

        real_A = data_dict['imgA']

        # Generator: the (input, generated) pair should be scored as real ...
        fake_B = self.netG.forward(real_A)
        generated_pair = torch.cat((real_A, fake_B), 1)
        loss_G_GAN = self.criterionGAN(self.netD.forward(generated_pair), True)

        # ... and the generated image should match the target, weighted L1.
        real_B = data_dict['imgB']
        l1_weight = self.configer.get('loss', 'loss_weights')['l1_loss']
        loss_G_L1 = self.criterionL1(fake_B, real_B) * l1_weight
        loss_G = loss_G_GAN + loss_G_L1

        # Discriminator, fake side: the pair goes through the image pool and
        # is detached so this term does not back-propagate into the generator.
        pooled_pair = self.fake_AB_pool.query(torch.cat((real_A, fake_B), 1))
        fake_score = self.netD.forward(pooled_pair.detach())
        loss_D_fake = self.criterionGAN(fake_score, False)

        # Discriminator, real side. The prediction is kept on the instance.
        real_pair = torch.cat((real_A, real_B), 1)
        self.pred_real = self.netD.forward(real_pair)
        loss_D_real = self.criterionGAN(self.pred_real, True)

        loss_D = (loss_D_fake + loss_D_real) * 0.5
        return {'loss_G': loss_G, 'loss_D': loss_D}
class Pix2Pix(nn.Module):
    """Pix2Pix wrapper holding one generator, one discriminator and their losses.

    The model is set up in two steps: construct it, then call
    :meth:`initialize`. ``forward`` returns the summed training loss and
    ``forward_test`` returns the generated image.

    NOTE(review): ``self.configer`` is read by ``initialize`` but is not
    assigned anywhere in this class; the caller must attach it first.
    """

    def initialize(self, opt):
        """Create the sub-networks, image pool and loss criteria.

        Args:
            opt: Options object. ``opt.pool_size`` is read here and
                ``opt.lambda_A`` is read later by :meth:`forward`.
        """
        # Keep the options on the instance: forward() reads self.opt.lambda_A
        # and would otherwise fail with AttributeError.
        self.opt = opt

        # Networks.
        self.netG = SubnetSelector.generator(
            self.configer.get('network', 'generator'))
        self.netD = SubnetSelector.discriminator(
            self.configer.get('network', 'discriminator'))

        self.fake_AB_pool = ImagePool(opt.pool_size)

        # Loss criteria.
        self.criterionGAN = GANLoss(
            use_lsgan=self.configer.get('loss', 'use_lsgan'))
        self.criterionL1 = nn.L1Loss()

    def forward(self, data_dict):
        """Compute the combined generator + discriminator training loss.

        Args:
            data_dict: Mapping with tensors under ``'imgA'`` (input) and
                ``'imgB'`` (target).

        Returns:
            ``dict(loss=loss_G + loss_D)``.
        """
        real_A = data_dict['imgA']
        real_B = data_dict['imgB']

        # Generator: the (input, generated) pair should be scored as real.
        fake_B = self.netG.forward(real_A)
        G_fake_AB = torch.cat((real_A, fake_B), 1)
        pred_fake = self.netD.forward(G_fake_AB)
        loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Generator: the generated image should match the target (weighted L1).
        loss_G_L1 = self.criterionL1(fake_B, real_B) * self.opt.lambda_A
        loss_G = loss_G_GAN + loss_G_L1

        # Discriminator, fake side: the pair goes through the image pool and
        # is detached so this term does not back-propagate into the generator.
        D_fake_AB = self.fake_AB_pool.query(torch.cat((real_A, fake_B), 1))
        pred_fake = self.netD.forward(D_fake_AB.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)

        # Discriminator, real side. The prediction is kept on the instance.
        D_real_AB = torch.cat((real_A, real_B), 1)
        self.pred_real = self.netD.forward(D_real_AB)
        loss_D_real = self.criterionGAN(self.pred_real, True)

        # Combined discriminator loss.
        loss_D = (loss_D_fake + loss_D_real) * 0.5
        return dict(loss=loss_G + loss_D)

    def forward_test(self, data_dict):
        """Return ``dict(fakeB=G(imgA))`` for inference."""
        return dict(fakeB=self.netG.forward(data_dict['imgA']))
class CycleGAN(nn.Module):
    """CycleGAN wrapper: two generators, two discriminators and their losses.

    Naming used by this implementation:
      * ``netG_A`` maps domain A -> B and ``netG_B`` maps domain B -> A.
      * ``netD_A`` scores domain-A images (real ``imgA`` / generated
        ``fake_A``) and ``netD_B`` scores domain-B images (real ``imgB`` /
        generated ``fake_B``).
    """

    def __init__(self, configer):
        """Build the sub-networks, image pools and criteria from ``configer``.

        Args:
            configer: Configuration object queried as
                ``configer.get(section, key)``.
        """
        super(CycleGAN, self).__init__()
        self.configer = configer
        cfg = self.configer.get
        use_dropout = cfg('network', 'use_dropout')
        norm_type = cfg('network', 'norm_type')
        pool_size = cfg('network', 'imgpool_size')

        # Networks.
        self.netG_A = SubNetSelector.generator(
            net_dict=cfg('network', 'generatorA'),
            use_dropout=use_dropout,
            norm_type=norm_type)
        self.netG_B = SubNetSelector.generator(
            net_dict=cfg('network', 'generatorB'),
            use_dropout=use_dropout,
            norm_type=norm_type)
        self.netD_A = SubNetSelector.discriminator(
            net_dict=cfg('network', 'discriminatorA'),
            norm_type=norm_type)
        self.netD_B = SubNetSelector.discriminator(
            net_dict=cfg('network', 'discriminatorB'),
            norm_type=norm_type)

        # One image pool per generated domain.
        self.fake_A_pool = ImagePool(pool_size)
        self.fake_B_pool = ImagePool(pool_size)

        # Loss criteria.
        self.criterionGAN = GANLoss(
            gan_mode=cfg('loss', 'params')['gan_mode'])
        self.criterionCycle = nn.L1Loss()
        self.criterionIdt = nn.L1Loss()

    def forward(self, data_dict, testing=False):
        """Compute the combined training loss, or the test-time translations.

        Args:
            data_dict: Mapping with tensors under ``'imgA'`` and ``'imgB'``
                (both required for training; each optional when testing).
            testing: When true, skip the losses and return translated and
                reconstructed images instead.

        Returns:
            Training: ``dict(loss=loss_G + loss_D_A + loss_D_B)``.
            Testing: a dict with ``fakeB``/``recA`` when ``imgA`` is given and
            ``fakeA``/``recB`` when ``imgB`` is given.
        """
        if testing:
            out_dict = dict()
            if 'imgA' in data_dict:
                fake_B = self.netG_A.forward(data_dict['imgA'])
                out_dict['fakeB'] = fake_B
                out_dict['recA'] = self.netG_B.forward(fake_B)

            if 'imgB' in data_dict:
                fake_A = self.netG_B.forward(data_dict['imgB'])
                out_dict['fakeA'] = fake_A
                out_dict['recB'] = self.netG_A.forward(fake_A)

            return out_dict

        real_A = data_dict['imgA']
        real_B = data_dict['imgB']
        loss_weights = self.configer.get('loss', 'loss_weights')
        cycle_loss_weight = loss_weights['cycle_loss']
        idt_loss_weight = loss_weights['idt_loss']

        # Identity loss: a generator fed an image that is already in its
        # target domain should return it unchanged.
        if idt_loss_weight > 0:
            idt_A = self.netG_A.forward(real_B)
            loss_idt_A = self.criterionIdt(idt_A, real_B) * idt_loss_weight
            idt_B = self.netG_B.forward(real_A)
            loss_idt_B = self.criterionIdt(idt_B, real_A) * idt_loss_weight
        else:
            loss_idt_A = 0
            loss_idt_B = 0

        # Generator GAN losses: D_B(G_A(A)) and D_A(G_B(B)) should look real.
        fake_B = self.netG_A.forward(real_A)
        loss_G_A = self.criterionGAN(self.netD_B.forward(fake_B), True)
        fake_A = self.netG_B.forward(real_B)
        loss_G_B = self.criterionGAN(self.netD_A.forward(fake_A), True)

        # Forward cycle loss: G_B(G_A(A)) should reconstruct A.
        rec_A = self.netG_B.forward(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, real_A) * cycle_loss_weight
        # Backward cycle loss: G_A(G_B(B)) should reconstruct B.
        rec_B = self.netG_A.forward(fake_A)
        loss_cycle_B = self.criterionCycle(rec_B, real_B) * cycle_loss_weight

        # Combined generator loss.
        loss_G = (loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B
                  + loss_idt_A + loss_idt_B)

        # Discriminator A: real imgA vs. pooled, detached fake A.
        pooled_fake_A = self.fake_A_pool.query(fake_A.clone())
        loss_D_real_A = self.criterionGAN(self.netD_A.forward(real_A), True)
        pred_fake_A = self.netD_A.forward(pooled_fake_A.detach())
        loss_D_fake_A = self.criterionGAN(pred_fake_A, False)
        loss_D_A = (loss_D_real_A + loss_D_fake_A) * 0.5

        # Discriminator B: real imgB vs. pooled, detached fake B.
        # The fake B image must be scored by netD_B (the original code passed
        # it to netD_A, so D_B never saw fake samples).
        pooled_fake_B = self.fake_B_pool.query(fake_B.clone())
        loss_D_real_B = self.criterionGAN(self.netD_B.forward(real_B), True)
        pred_fake_B = self.netD_B.forward(pooled_fake_B.detach())
        loss_D_fake_B = self.criterionGAN(pred_fake_B, False)
        loss_D_B = (loss_D_real_B + loss_D_fake_B) * 0.5

        return dict(loss=loss_G + loss_D_A + loss_D_B)