def __init__(self, GS_nlayers=6, DS_nlayers=5, GS_nf=32, DS_nf=32,
             GT_nlayers=6, DT_nlayers=5, GT_nf=32, DT_nf=32, gpu=True):
    """Build both GAN stages (structure G_S/D_S, texture G_T/D_T),
    their losses, loss weights, and one Adam optimizer per network.

    :param GS_nlayers: layer count of the glyph (structure) generator G_S
    :param DS_nlayers: layer count of the structure discriminator D_S
    :param GS_nf: base feature width of G_S
    :param DS_nf: base feature width of D_S
    :param GT_nlayers: layer count of the texture generator G_T
    :param DT_nlayers: layer count of the texture discriminator D_T
    :param GT_nf: base feature width of G_T
    :param DT_nf: base feature width of D_T
    :param gpu: if True, CUDA-dependent modules are moved to the GPU
    """
    super(ShapeMatchingGAN, self).__init__()
    # Architecture hyper-parameters.
    self.GS_nlayers = GS_nlayers
    self.DS_nlayers = DS_nlayers
    self.GS_nf = GS_nf
    self.DS_nf = DS_nf
    self.GT_nlayers = GT_nlayers
    self.DT_nlayers = DT_nlayers
    self.GT_nf = GT_nf
    self.DT_nf = DT_nf
    self.gpu = gpu
    # Loss weights.
    self.lambda_l1 = 100
    self.lambda_gp = 10
    self.lambda_sadv = 0.1
    self.lambda_gly = 1.0
    self.lambda_tadv = 1.0
    self.lambda_sty = 0.01
    # Per-VGG-layer style-loss weights: 1e3 / channels^2.
    self.style_weights = [1e3 / channels**2 for channels in [64, 128, 256, 512, 512]]
    self.loss = nn.L1Loss()
    self.gramloss = GramMSELoss()
    if self.gpu:
        self.gramloss = self.gramloss.cuda()
    # Frozen mask-extraction network (never trained).
    self.getmask = SemanticFeature()
    for weight in self.getmask.parameters():
        weight.requires_grad = False
    # NOTE: construction order is preserved — it determines the order of
    # random weight-initialization draws.
    self.G_S = GlyphGenerator(self.GS_nf, self.GS_nlayers)
    self.D_S = Discriminator(3, self.DS_nf, self.DS_nlayers)
    self.G_T = TextureGenerator(self.GT_nf, self.GT_nlayers)
    self.D_T = Discriminator(6, self.DT_nf, self.DT_nlayers)

    def make_adam(params):
        # Shared optimizer settings for all four networks.
        return torch.optim.Adam(params, lr=0.0002, betas=(0.5, 0.999))

    self.trainerG_S = make_adam(self.G_S.parameters())
    self.trainerD_S = make_adam(self.D_S.parameters())
    self.trainerG_T = make_adam(self.G_T.parameters())
    self.trainerD_T = make_adam(self.D_T.parameters())
class ShapeMatchingGAN(nn.Module):
    """Shape Matching GAN: a structure-transfer stage (G_S/D_S) followed by a
    texture-transfer stage (G_T/D_T), trained with WGAN-GP objectives.

    Losses: WGAN-GP adversarial terms, L1 reconstruction, a glyph legibility
    term for the structure stage, a VGG Gram-matrix style term for the
    texture stage, and a distance-field consistency term (``lambda_distance``).
    """

    def __init__(self, GS_nlayers=6, DS_nlayers=5, GS_nf=32, DS_nf=32,
                 GT_nlayers=6, DT_nlayers=5, GT_nf=32, DT_nf=32, gpu=True):
        """
        :param GS_nlayers: layer count of the glyph (structure) generator G_S
        :param DS_nlayers: layer count of the structure discriminator D_S
        :param GS_nf: base feature width of G_S
        :param DS_nf: base feature width of D_S
        :param GT_nlayers: layer count of the texture generator G_T
        :param DT_nlayers: layer count of the texture discriminator D_T
        :param GT_nf: base feature width of G_T
        :param DT_nf: base feature width of D_T
        :param gpu: if True, CUDA-dependent modules are moved to the GPU
        """
        super(ShapeMatchingGAN, self).__init__()
        self.GS_nlayers = GS_nlayers
        self.DS_nlayers = DS_nlayers
        self.GS_nf = GS_nf
        self.DS_nf = DS_nf
        self.GT_nlayers = GT_nlayers
        self.DT_nlayers = DT_nlayers
        self.GT_nf = GT_nf
        self.DT_nf = DT_nf
        self.gpu = gpu
        # Loss weights.
        self.lambda_l1 = 100
        self.lambda_gp = 10
        self.lambda_distance = 0.01
        self.lambda_sadv = 0.1
        self.lambda_gly = 1.0
        self.lambda_tadv = 1.0
        self.lambda_sty = 0.01
        # Per-VGG-layer style-loss weights: 1e3 / channels^2.
        self.style_weights = [1e3 / n**2 for n in [64, 128, 256, 512, 512]]
        self.loss = nn.L1Loss()
        self.gramloss = GramMSELoss()
        self.gramloss = self.gramloss.cuda() if self.gpu else self.gramloss
        # Frozen mask-extraction network used by the style loss (never trained).
        self.getmask = SemanticFeature()
        for param in self.getmask.parameters():
            param.requires_grad = False
        self.G_S = GlyphGenerator(self.GS_nf, self.GS_nlayers)
        self.D_S = Discriminator(3, self.DS_nf, self.DS_nlayers)
        self.G_T = TextureGenerator(self.GT_nf, self.GT_nlayers)
        self.D_T = Discriminator(6, self.DT_nf, self.DT_nlayers)
        self.trainerG_S = torch.optim.Adam(self.G_S.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.trainerD_S = torch.optim.Adam(self.D_S.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.trainerG_T = torch.optim.Adam(self.G_T.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.trainerD_T = torch.optim.Adam(self.D_T.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # FOR TESTING
    def forward(self, x, l):
        """Full inference pass: structure transfer then texture transfer.

        NOTE(review): mutates channel 0 of the caller's ``x`` in place by
        adding Gaussian noise — callers should pass a disposable tensor.
        """
        x[:, 0:1] = gaussian(x[:, 0:1], stddev=0.2)
        xl = self.G_S(x, l)
        xl[:, 0:1] = gaussian(xl[:, 0:1], stddev=0.2)
        return self.G_T(xl)

    # FOR TRAINING
    def init_networks(self, weights_init):
        """Apply the given weight-initialization function to all four networks."""
        self.G_S.apply(weights_init)
        self.D_S.apply(weights_init)
        self.G_T.apply(weights_init)
        self.D_T.apply(weights_init)

    def calc_gradient_penalty(self, netD, real_data, fake_data):
        """WGAN-GP gradient penalty on random interpolates of real/fake batches."""
        alpha = torch.rand(real_data.shape[0], 1, 1, 1)
        alpha = alpha.cuda() if self.gpu else alpha
        interpolates = alpha * real_data + ((1 - alpha) * fake_data)
        interpolates = Variable(interpolates, requires_grad=True)
        disc_interpolates = netD(interpolates)
        gradients = autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones(disc_interpolates.size()).cuda()
            if self.gpu else torch.ones(disc_interpolates.size()),
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        gradient_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
        return gradient_penalty

    def update_structure_discriminator(self, x, xl, l):
        """One WGAN-GP update of the structure discriminator D_S.

        ``x`` and ``xl`` are crops taken at the same coordinates: ``xl`` is a
        noisy deformed distance image and ``x`` the corresponding target
        structure image (e.g. both [B, 3, 256, 256]).
        Returns the (weighted) Wasserstein estimate real - fake as a scalar.
        """
        with torch.no_grad():
            fake_x = self.G_S(xl, l)  # generator output, same shape as x
        fake_output = self.D_S(fake_x)
        real_output = self.D_S(x)
        gp = self.calc_gradient_penalty(self.D_S, x.data, fake_x.data)
        LSadv = self.lambda_sadv * (fake_output.mean() - real_output.mean()
                                    + self.lambda_gp * gp)
        self.trainerD_S.zero_grad()
        LSadv.backward()
        self.trainerD_S.step()
        return (real_output.mean() - fake_output.mean()).data.mean() * self.lambda_sadv

    def update_structure_generator(self, x, xl, l, t=None):
        """One update of G_S: adversarial + L1 reconstruction (+ glyph loss).

        When ``t`` (a text/glyph image) is given, an extra legibility loss
        weighted by a distance-field map keeps generated text readable.
        Returns (LSadv, LSrec, LSgly) as scalars; LSgly is 0 when t is None.
        """
        fake_x = self.G_S(xl, l)
        fake_output = self.D_S(fake_x)
        LSadv = -fake_output.mean() * self.lambda_sadv
        LSrec = self.loss(fake_x, x) * self.lambda_l1
        LS = LSadv + LSrec
        if t is not None:
            # Weight map based on the distance field, whose pixel value
            # increases with its distance to the nearest text contour of t.
            Mt = (t[:, 1:2] + t[:, 2:3]) * 0.5 + 1.0
            t_noise = t.clone()
            t_noise[:, 0:1] = gaussian(t_noise[:, 0:1], stddev=0.2)
            fake_t = self.G_S(t_noise, l)
            LSgly = self.loss(fake_t * Mt, t * Mt) * self.lambda_gly
            LS = LS + LSgly
        self.trainerG_S.zero_grad()
        LS.backward()
        self.trainerG_S.step()
        return LSadv.data.mean(), LSrec.data.mean(), LSgly.data.mean(
        ) if t is not None else 0

    def structure_one_pass(self, x, xl, l, t=None):
        """One D_S step followed by one G_S step; returns the four loss scalars."""
        LDadv = self.update_structure_discriminator(x, xl, l)
        LGadv, Lrec, Lgly = self.update_structure_generator(x, xl, l, t)
        return [LDadv, LGadv, Lrec, Lgly]

    def update_texture_discriminator(self, x, y):
        """One WGAN-GP update of the texture discriminator D_T.

        D_T is conditional (pix2pix style): it scores the channel-wise
        concatenation of the stylized distance image ``x`` with either the
        real style image ``y`` or the generated one.
        """
        with torch.no_grad():
            fake_y = self.G_T(x)  # generate the style image from x
        fake_concat = torch.cat((x, fake_y), dim=1)
        fake_output = self.D_T(fake_concat)
        real_concat = torch.cat((x, y), dim=1)
        real_output = self.D_T(real_concat)
        gp = self.calc_gradient_penalty(self.D_T, real_concat.data, fake_concat.data)
        LTadv = self.lambda_tadv * (fake_output.mean() - real_output.mean()
                                    + self.lambda_gp * gp)
        self.trainerD_T.zero_grad()
        LTadv.backward()
        self.trainerD_T.step()
        return (real_output.mean() - fake_output.mean()).data.mean() * self.lambda_tadv

    def update_texture_generator(self, x, y, t=None, l=None, VGGfeatures=None, style_targets=None):
        """One update of G_T: adversarial + L1 + distance-field (+ style) loss.

        Returns (Ldistance, LTadv, Lrec, Lsty) as scalars; Lsty is 0 when
        ``t`` is None.
        """
        fake_y = self.G_T(x)
        # L_distance: the first channel of the distance image acts as a
        # black/white mask, expanded to 3 channels.
        # BUGFIX: the original set `C.require_grad_ = False` /
        # `D.require_grad_ = False`, which only creates an unused Python
        # attribute (the tensor API is `requires_grad`). C is already
        # detached via clone().detach(); D is detached explicitly here so
        # no gradient flows into the inputs, as intended.
        BW = x[:, 0, :, :].clone().detach().unsqueeze(dim=1)
        BW = BW.expand(BW.shape[0], 3, BW.shape[2], BW.shape[3])
        C = BW
        D = x.detach()
        X = fake_y
        Ldistance = torch.norm(C * D - X * D, 'fro') * 0.5 * self.lambda_distance
        fake_concat = torch.cat((x, fake_y), dim=1)
        fake_output = self.D_T(fake_concat)
        LTadv = -fake_output.mean() * self.lambda_tadv
        Lrec = self.loss(fake_y, y) * self.lambda_l1
        LT = LTadv + Lrec + Ldistance
        if t is not None:
            with torch.no_grad():
                # NOTE(review): mutates channel 0 of the caller's t in place.
                t[:, 0:1] = gaussian(t[:, 0:1], stddev=0.2)
                source_mask = self.G_S(t, l).detach()
                source = source_mask.clone()
                source[:, 0:1] = gaussian(source[:, 0:1], stddev=0.2)
            # Foreground/background masks in [0, 1] for the masked style loss.
            smaps_fore = [(A.detach() + 1) * 0.5 for A in self.getmask(source_mask[:, 0:1])]
            smaps_back = [1 - A for A in smaps_fore]
            fake_t = self.G_T(source)
            out = VGGfeatures(fake_t)
            style_losses1 = [
                self.style_weights[a] * self.gramloss(A * smaps_fore[a], style_targets[0][a])
                for a, A in enumerate(out)
            ]
            style_losses2 = [
                self.style_weights[a] * self.gramloss(A * smaps_back[a], style_targets[1][a])
                for a, A in enumerate(out)
            ]
            Lsty = (sum(style_losses1) + sum(style_losses2)) * self.lambda_sty
            LT = LT + Lsty
        self.trainerG_T.zero_grad()
        LT.backward()
        self.trainerG_T.step()
        return Ldistance.data.mean(), LTadv.data.mean(), Lrec.data.mean(
        ), Lsty.data.mean() if t is not None else 0

    def texture_one_pass(self, x, y, t=None, l=None, VGGfeatures=None, style_targets=None):
        """One D_T step followed by one G_T step; returns the five loss scalars."""
        LDadv = self.update_texture_discriminator(x, y)
        # (fixes local-variable typo `Ldiatance` — return order unchanged)
        Ldistance, LGadv, Lrec, Lsty = self.update_texture_generator(
            x, y, t, l, VGGfeatures, style_targets)
        return [Ldistance, LDadv, LGadv, Lrec, Lsty]

    def save_structure_model(self, filepath, filename):
        """Save G_S/D_S weights as <filename>-GS.ckpt / <filename>-DS.ckpt."""
        torch.save(self.G_S.state_dict(), os.path.join(filepath, filename + '-GS.ckpt'))
        torch.save(self.D_S.state_dict(), os.path.join(filepath, filename + '-DS.ckpt'))

    def save_texture_model(self, filepath, filename):
        """Save G_T/D_T weights as <filename>-GT.ckpt / <filename>-DT.ckpt."""
        torch.save(self.G_T.state_dict(), os.path.join(filepath, filename + '-GT.ckpt'))
        torch.save(self.D_T.state_dict(), os.path.join(filepath, filename + '-DT.ckpt'))
class ShapeMatchingGAN(nn.Module):
    """Shape Matching GAN: a structure-transfer stage (G_S/D_S) followed by a
    texture-transfer stage (G_T/D_T), trained with WGAN-GP objectives plus
    L1 reconstruction, a glyph legibility loss, and a VGG Gram-matrix style
    loss.
    """

    def __init__(self, GS_nlayers=6, DS_nlayers=5, GS_nf=32, DS_nf=32,
                 GT_nlayers=6, DT_nlayers=5, GT_nf=32, DT_nf=32, gpu=True):
        """
        :param GS_nlayers: layer count of the glyph (structure) generator G_S
        :param DS_nlayers: layer count of the structure discriminator D_S
        :param GS_nf: base feature width of G_S
        :param DS_nf: base feature width of D_S
        :param GT_nlayers: layer count of the texture generator G_T
        :param DT_nlayers: layer count of the texture discriminator D_T
        :param GT_nf: base feature width of G_T
        :param DT_nf: base feature width of D_T
        :param gpu: if True, CUDA-dependent modules are moved to the GPU
        """
        super(ShapeMatchingGAN, self).__init__()
        self.GS_nlayers = GS_nlayers
        self.DS_nlayers = DS_nlayers
        self.GS_nf = GS_nf
        self.DS_nf = DS_nf
        self.GT_nlayers = GT_nlayers
        self.DT_nlayers = DT_nlayers
        self.GT_nf = GT_nf
        self.DT_nf = DT_nf
        self.gpu = gpu
        # Loss weights.
        self.lambda_l1 = 100
        self.lambda_gp = 10
        self.lambda_sadv = 0.1
        self.lambda_gly = 1.0
        self.lambda_tadv = 1.0
        self.lambda_sty = 0.01
        # Per-VGG-layer style-loss weights: 1e3 / channels^2.
        self.style_weights = [1e3 / n**2 for n in [64, 128, 256, 512, 512]]
        self.loss = nn.L1Loss()
        self.gramloss = GramMSELoss()
        self.gramloss = self.gramloss.cuda() if self.gpu else self.gramloss
        # Frozen mask-extraction network used by the style loss (never trained).
        self.getmask = SemanticFeature()
        for param in self.getmask.parameters():
            param.requires_grad = False
        self.G_S = GlyphGenerator(self.GS_nf, self.GS_nlayers)
        self.D_S = Discriminator(3, self.DS_nf, self.DS_nlayers)
        self.G_T = TextureGenerator(self.GT_nf, self.GT_nlayers)
        self.D_T = Discriminator(6, self.DT_nf, self.DT_nlayers)
        self.trainerG_S = torch.optim.Adam(self.G_S.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.trainerD_S = torch.optim.Adam(self.D_S.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.trainerG_T = torch.optim.Adam(self.G_T.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.trainerD_T = torch.optim.Adam(self.D_T.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # FOR TESTING
    def forward(self, x, l):
        """Full inference pass: structure transfer then texture transfer.

        NOTE(review): mutates channel 0 of the caller's ``x`` in place by
        adding Gaussian noise.
        """
        x[:, 0:1] = gaussian(x[:, 0:1], stddev=0.2)
        xl = self.G_S(x, l)
        xl[:, 0:1] = gaussian(xl[:, 0:1], stddev=0.2)
        return self.G_T(xl)

    # FOR TRAINING
    # init weight
    def init_networks(self, weights_init):
        """Apply the given weight-initialization function to all four networks."""
        self.G_S.apply(weights_init)
        self.D_S.apply(weights_init)
        self.G_T.apply(weights_init)
        self.D_T.apply(weights_init)

    # WGAN-GP: calculate gradient penalty
    def calc_gradient_penalty(self, netD, real_data, fake_data):
        """WGAN-GP gradient penalty on random interpolates of real/fake batches."""
        alpha = torch.rand(real_data.shape[0], 1, 1, 1)
        alpha = alpha.cuda() if self.gpu else alpha
        interpolates = alpha * real_data + ((1 - alpha) * fake_data)
        interpolates = Variable(interpolates, requires_grad=True)
        disc_interpolates = netD(interpolates)
        gradients = autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones(disc_interpolates.size()).cuda()
            if self.gpu else torch.ones(disc_interpolates.size()),
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        gradient_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
        return gradient_penalty

    def update_structure_discriminator(self, x, xl, l):
        """One WGAN-GP update of D_S.

        ``xl`` is the (noisy) deformed distance image and ``x`` the target
        structure image; both are crops at the same coordinates.
        Returns the weighted Wasserstein estimate real - fake as a scalar.
        """
        with torch.no_grad():
            fake_x = self.G_S(xl, l)  # no generator grads needed for the D step
        fake_output = self.D_S(fake_x)
        real_output = self.D_S(x)
        gp = self.calc_gradient_penalty(self.D_S, x.data, fake_x.data)
        LSadv = self.lambda_sadv * (fake_output.mean() - real_output.mean()
                                    + self.lambda_gp * gp)
        self.trainerD_S.zero_grad()
        LSadv.backward()
        self.trainerD_S.step()
        return (real_output.mean() - fake_output.mean()).data.mean() * self.lambda_sadv

    def update_structure_generator(self, x, xl, l, t=None):
        """One update of G_S: adversarial + L1 reconstruction (+ glyph loss).

        Returns (LSadv, LSrec, LSgly) scalars; LSgly is 0 when ``t`` is None.
        """
        fake_x = self.G_S(xl, l)
        fake_output = self.D_S(fake_x)
        LSadv = -fake_output.mean() * self.lambda_sadv
        LSrec = self.loss(fake_x, x) * self.lambda_l1
        LS = LSadv + LSrec
        if t is not None:
            # weight map based on the distance field
            # whose pixel value increases with its distance to the nearest text contour point of t
            Mt = (t[:, 1:2] + t[:, 2:3]) * 0.5 + 1.0
            t_noise = t.clone()
            t_noise[:, 0:1] = gaussian(t_noise[:, 0:1], stddev=0.2)
            fake_t = self.G_S(t_noise, l)
            LSgly = self.loss(fake_t * Mt, t * Mt) * self.lambda_gly
            LS = LS + LSgly
        self.trainerG_S.zero_grad()
        LS.backward()
        self.trainerG_S.step()
        #global id
        #if id % 60 == 0:
        #    viz_img = to_data(torch.cat((x[0], xl[0], fake_x[0]), dim=2))
        #    save_image(viz_img, '../output/structure_result%d.jpg'%id)
        #id += 1
        return LSadv.data.mean(), LSrec.data.mean(), LSgly.data.mean(
        ) if t is not None else 0

    def structure_one_pass(self, x, xl, l, t=None):
        """One D_S step followed by one G_S step; returns the four loss scalars."""
        LDadv = self.update_structure_discriminator(x, xl, l)
        LGadv, Lrec, Lgly = self.update_structure_generator(x, xl, l, t)
        return [LDadv, LGadv, Lrec, Lgly]

    def update_texture_discriminator(self, x, y):
        """One WGAN-GP update of D_T, conditional on the input (pix2pix style):
        D_T scores the concatenation of ``x`` with the real/generated style image.
        """
        with torch.no_grad():
            fake_y = self.G_T(x)  # generated style image
        fake_concat = torch.cat((x, fake_y), dim=1)
        fake_output = self.D_T(fake_concat)
        real_concat = torch.cat((x, y), dim=1)
        real_output = self.D_T(real_concat)
        gp = self.calc_gradient_penalty(self.D_T, real_concat.data, fake_concat.data)
        LTadv = self.lambda_tadv * (fake_output.mean() - real_output.mean()
                                    + self.lambda_gp * gp)
        self.trainerD_T.zero_grad()
        LTadv.backward()
        self.trainerD_T.step()
        return (real_output.mean() - fake_output.mean()).data.mean() * self.lambda_tadv

    def update_texture_generator(self, x, y, t=None, l=None, VGGfeatures=None, style_targets=None):
        """One update of G_T: adversarial + L1 (+ masked VGG style loss).

        Returns (LTadv, Lrec, Lsty) scalars; Lsty is 0 when ``t`` is None.
        NOTE(review): mutates channel 0 of the caller's ``t`` in place.
        """
        fake_y = self.G_T(x)
        fake_concat = torch.cat((x, fake_y), dim=1)
        fake_output = self.D_T(fake_concat)
        LTadv = -fake_output.mean() * self.lambda_tadv
        Lrec = self.loss(fake_y, y) * self.lambda_l1
        LT = LTadv + Lrec
        if t is not None:
            with torch.no_grad():
                t[:, 0:1] = gaussian(t[:, 0:1], stddev=0.2)
                source_mask = self.G_S(t, l).detach()
                source = source_mask.clone()
                source[:, 0:1] = gaussian(source[:, 0:1], stddev=0.2)
            # Foreground/background masks in [0, 1] for the masked style loss.
            smaps_fore = [(A.detach() + 1) * 0.5 for A in self.getmask(source_mask[:, 0:1])]
            smaps_back = [1 - A for A in smaps_fore]
            fake_t = self.G_T(source)
            out = VGGfeatures(fake_t)
            style_losses1 = [
                self.style_weights[a] * self.gramloss(A * smaps_fore[a], style_targets[0][a])
                for a, A in enumerate(out)
            ]
            style_losses2 = [
                self.style_weights[a] * self.gramloss(A * smaps_back[a], style_targets[1][a])
                for a, A in enumerate(out)
            ]
            Lsty = (sum(style_losses1) + sum(style_losses2)) * self.lambda_sty
            LT = LT + Lsty
        #global id
        #if id % 20 == 0:
        #    viz_img = to_data(torch.cat((x[0], y[0], fake_y[0]), dim=2))
        #    save_image(viz_img, '../output/texturee_result%d.jpg'%id)
        #id += 1
        self.trainerG_T.zero_grad()
        LT.backward()
        self.trainerG_T.step()
        return LTadv.data.mean(), Lrec.data.mean(), Lsty.data.mean(
        ) if t is not None else 0

    def texture_one_pass(self, x, y, t=None, l=None, VGGfeatures=None, style_targets=None):
        """One D_T step followed by one G_T step; returns the four loss scalars."""
        LDadv = self.update_texture_discriminator(x, y)
        LGadv, Lrec, Lsty = self.update_texture_generator(
            x, y, t, l, VGGfeatures, style_targets)
        return [LDadv, LGadv, Lrec, Lsty]

    def save_structure_model(self, filepath, filename):
        """Save G_S/D_S weights as <filename>-GS.ckpt / <filename>-DS.ckpt."""
        torch.save(self.G_S.state_dict(), os.path.join(filepath, filename + '-GS.ckpt'))
        torch.save(self.D_S.state_dict(), os.path.join(filepath, filename + '-DS.ckpt'))

    def save_texture_model(self, filepath, filename):
        """Save G_T/D_T weights as <filename>-GT.ckpt / <filename>-DT.ckpt."""
        torch.save(self.G_T.state_dict(), os.path.join(filepath, filename + '-GT.ckpt'))
        torch.save(self.D_T.state_dict(), os.path.join(filepath, filename + '-DT.ckpt'))