class gan(nn.Module):
    """Adversarial trainer bundling a generator ``G`` and a discriminator ``D``.

    The object owns both networks, one Adam optimizer per network, the image
    losses (VGG perceptual + L1) and a TensorBoard ``SummaryWriter``.
    ``train`` runs one epoch of alternating D / G updates and ``test`` logs
    sample outputs of the generator.

    NOTE: ``train(dl, epoch)`` shadows ``nn.Module.train(mode)``.  Because
    ``nn.Module.eval()`` is implemented as ``self.train(False)``, calling
    ``.eval()`` on this wrapper itself does not work; toggle ``self.G`` /
    ``self.D`` directly instead.
    """

    def __init__(self, params, args):
        """Build the networks, losses, optimizers and logging state.

        Args:
            params: forwarded unchanged to ``MModel`` and ``Discriminator``.
            args: namespace providing ``use_cuda``, ``g_weight_dir``,
                ``d_weight_dir``, ``g_lr``, ``d_lr``, ``save_dir``,
                ``d_update_freq``, ``save_freq`` and ``start_epoch``.
        """
        super().__init__()
        # NOTE(review): the generator is always constructed with
        # use_cuda=True, independent of args.use_cuda -- confirm intended.
        self.G = MModel(params, use_cuda=True)
        self.D = Discriminator(params, bias=True)
        self.vgg_loss = VGGPerceptualLoss()
        self.L1_loss = nn.L1Loss()

        if args.use_cuda:
            self.G = self.G.cuda()
            self.D = self.D.cuda()
            self.vgg_loss = self.vgg_loss.cuda()
            self.L1_loss = self.L1_loss.cuda()

        # map_location='cpu' lets a checkpoint written on a GPU be read on
        # any host; load_state_dict then copies the values onto whatever
        # device the module's parameters already live on.
        if args.g_weight_dir:
            self.G.load_state_dict(
                torch.load(args.g_weight_dir, map_location='cpu'),
                strict=True)
        if args.d_weight_dir:
            self.D.load_state_dict(
                torch.load(args.d_weight_dir, map_location='cpu'),
                strict=False)

        self.optimizer_G = torch.optim.Adam(self.G.parameters(), lr=args.g_lr)
        self.optimizer_D = torch.optim.Adam(self.D.parameters(), lr=args.d_lr)

        self.save_dir = args.save_dir
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(args.save_dir, exist_ok=True)

        self.d_update_freq = args.d_update_freq
        self.save_freq = args.save_freq
        self.writer = SummaryWriter('runs/' + args.save_dir)
        self.use_cuda = args.use_cuda
        self.start_epoch = args.start_epoch

    def _to_device(self, *tensors):
        """Return ``tensors`` moved to CUDA when ``use_cuda`` is set."""
        if self.use_cuda:
            return tuple(t.cuda() for t in tensors)
        return tensors

    def G_loss(self, input, target):
        """Return the generator image loss: VGG perceptual loss plus L1."""
        return self.vgg_loss(input, target) + self.L1_loss(input, target)

    def update_D(self, loss, epoch):
        """Backpropagate ``loss`` and step D on epochs divisible by ``d_update_freq``."""
        if epoch % self.d_update_freq == 0:
            loss.backward()
            self.optimizer_D.step()

    def get_patch_weight(self, pose, size=62):
        """Build a per-patch weight map from channel 0 of ``pose``.

        Channel 0 is resized to ``size`` and mapped ``v -> 5 * v + 1``, so
        the result has a single channel.  Presumably channel 0 is a head
        heat-map (going by the original variable name) -- TODO confirm.
        The map is not used by ``gan_loss`` at the moment.
        """
        heads = pose[:, 0, :, :].unsqueeze(1)
        heads = torch.nn.functional.interpolate(heads, size=size)
        return heads * 5 + 1

    def gan_loss(self, out, label, pose):
        """Binary cross-entropy of ``out`` against an all-real or all-fake target.

        Args:
            out: discriminator probabilities.
            label: ``1`` selects an all-ones target, anything else all-zeros.
            pose: unused; kept for a weighted variant based on
                ``get_patch_weight`` that is currently disabled.
        """
        target = torch.ones_like(out) if label == 1 else torch.zeros_like(out)
        return nn.BCELoss()(out, target)

    def train(self, dl, epoch):
        """Run one training epoch over ``dl`` and log the mean losses.

        Each batch is a 10-tuple; only the first six entries (``src_img``,
        ``y``, ``src_pose``, ``tgt_pose``, ``src_mask_prior``, ``x_trans``)
        are used here.  Checkpoints for G and D are written to ``save_dir``
        when ``epoch`` is a multiple of ``save_freq``.
        """
        # test() leaves G in eval mode; switch back before optimizing.
        self.G.train()
        self.D.train()

        cnt = 0
        loss_D_real_sum = loss_D_fake_sum = loss_D_sum = 0.0
        loss_G_gan_sum = loss_G_img_sum = loss_G_sum = 0.0

        for step, (src_img, y, src_pose, tgt_pose, src_mask_prior, x_trans,
                   _src_mask_gt, _tgt_face, _tgt_face_box,
                   _src_face_box) in enumerate(dl):
            print('epoch:', epoch, 'iter:', step)
            (src_img, y, src_pose, tgt_pose, src_mask_prior,
             x_trans) = self._to_device(src_img, y, src_pose, tgt_pose,
                                        src_mask_prior, x_trans)

            out = self.G(src_img, src_pose, tgt_pose, src_mask_prior, x_trans)
            gen = out[0]

            # --- discriminator step (generator output detached) ---
            self.optimizer_D.zero_grad()
            loss_D_real = self.gan_loss(self.D(y, tgt_pose), 1, tgt_pose)
            loss_D_fake = self.gan_loss(
                self.D(gen.detach(), tgt_pose), 0, tgt_pose)
            loss_D = loss_D_real + loss_D_fake
            self.update_D(loss_D, epoch)

            # --- generator step: adversarial term + image loss ---
            self.optimizer_G.zero_grad()
            loss_G_gan = self.gan_loss(self.D(gen, tgt_pose), 1, tgt_pose)
            loss_G_img = self.G_loss(gen, y)
            loss_G = loss_G_gan + loss_G_img
            loss_G.backward()
            self.optimizer_G.step()

            loss_D_real_sum += loss_D_real.item()
            loss_D_fake_sum += loss_D_fake.item()
            loss_D_sum += loss_D.item()
            loss_G_gan_sum += loss_G_gan.item()
            loss_G_img_sum += loss_G_img.item()
            loss_G_sum += loss_G.item()
            cnt += 1

        if cnt == 0:
            # Nothing was trained: there are no means to log.
            print('epoch:', epoch, 'empty data loader, nothing trained')
            return

        means = {
            'loss_D_real': loss_D_real_sum / cnt,
            'loss_D_fake': loss_D_fake_sum / cnt,
            'loss_D': loss_D_sum / cnt,
            'loss_G_gan': loss_G_gan_sum / cnt,
            'loss_G_img': loss_G_img_sum / cnt,
            'loss_G': loss_G_sum / cnt,
        }
        for tag, value in means.items():
            self.writer.add_scalar(tag, value, epoch)
        # Combined chart uses the same epoch means as the scalars above
        # (not the last batch's loss tensors).
        self.writer.add_scalars(
            'DG', {'D': means['loss_D'], 'G': means['loss_G']}, epoch)

        if epoch % self.save_freq == 0:
            torch.save(self.G.state_dict(),
                       os.path.join(self.save_dir, 'g_epoch_%d.pth' % epoch))
            torch.save(self.D.state_dict(),
                       os.path.join(self.save_dir, 'd_epoch_%d.pth' % epoch))

    def test(self, test_dl, epoch):
        """Run G over ``test_dl`` in eval mode and log images of the first batch.

        G is left in eval mode on return; ``train`` switches it back.
        """
        self.G.eval()
        for step, (src_img, y, src_pose, tgt_pose, src_mask_prior, x_trans,
                   _src_mask_gt, _tgt_face, _tgt_face_box,
                   _src_face_box) in enumerate(test_dl):
            print('test', 'epoch:', epoch, 'iter:', step)
            (src_img, y, src_pose, tgt_pose, src_mask_prior,
             x_trans) = self._to_device(src_img, y, src_pose, tgt_pose,
                                        src_mask_prior, x_trans)
            with torch.no_grad():
                out = self.G(src_img, src_pose, tgt_pose, src_mask_prior,
                             x_trans)
            gen = out[0]

            if step == 0:
                # Images are shifted by *0.5+0.5 before logging.
                self.writer.add_images('test_gen/epoch%d' % epoch,
                                       gen * 0.5 + 0.5)
                self.writer.add_images('test_y/epoch%d' % epoch,
                                       y * 0.5 + 0.5)
                self.writer.add_images('test_src/epoch%d' % epoch,
                                       src_img * 0.5 + 0.5)
                # Fold dim 1 of out[2] into the batch axis so each slice is
                # logged as its own single-channel image.
                mask = out[2]
                self.writer.add_images(
                    'test_src_mask/epoch%d' % epoch,
                    mask.view((mask.size(0) * mask.size(1), 1, mask.size(2),
                               mask.size(3))))
loss_sum += loss.item() fg_loss_sum += fg_loss.item() # mask_loss_sum += mask_loss.item() L1_sum += L1.item() # writer.add_images('Image/epoch%d/y'%epoch, y) # writer.add_images('Image/epoch%d/gen'%epoch, out) if (epoch % 10 == 0 or epoch < 5) and i == 0: # writer.add_scalar('Epoch%d/loss'%epoch, loss.item(), i) writer.add_images('genFG/epoch%d'%epoch, out[0]*0.5+0.5) writer.add_images('y/epoch%d'%epoch, y*0.5+0.5) writer.add_images('src_img/epoch%d'%epoch, src_img*0.5+0.5) # writer.add_images('src_mask_delta/epoch%d'%epoch, out[1].view((out[1].size(0)*out[1].size(1), 1, out[1].size(2), out[1].size(3)))) writer.add_images('src_mask/epoch%d'%epoch, out[2].view((out[2].size(0)*out[2].size(1), 1, out[2].size(2), out[2].size(3)))) # writer.add_images('src_mask_prior/epoch%d'%epoch, src_mask_prior.view((src_mask_prior.size(0)*src_mask_prior.size(1), 1, src_mask_prior.size(2), src_mask_prior.size(3)))) # writer.add_images('warped/epoch%d'%epoch, out[3].view((out[3].size(0)*11, 3, out[3].size(2), out[3].size(3)))*0.5+0.5) # writer.add_images('mask_sum/epoch%d'%epoch, mask_sum.unsqueeze(1)) # writer.add_histogram('src_mask_hist/epoch%d'%epoch, out[2]) # writer.add_histogram('src_mask_delta_hist/epoch%d'%epoch, out[1]) # writer.add_histogram('src_mask_gt/epoch%d'%epoch, src_mask_gt) writer.add_scalar('Train/loss', loss_sum/cnt, epoch) writer.add_scalar('Train/fg_loss', fg_loss_sum/cnt, epoch) # writer.add_scalar('Train/mask_loss', mask_loss_sum/cnt, epoch) writer.add_scalar('Train/L1_loss', L1_sum/cnt, epoch) if epoch % 10 == 0 and epoch != 0 and mini == False: torch.save(model.state_dict(), model_dir+'/epoch_%d.pth'%epoch)
class gan(nn.Module):
    """Adversarial trainer with a generator and two discriminators.

    Bundles the generator ``G``, the pose-conditioned discriminator ``D``
    and the face discriminator ``Face_D`` together with one Adam optimizer
    per network, the VGG perceptual loss and a TensorBoard
    ``SummaryWriter``.  ``train`` runs one epoch of D / Face_D / G updates
    and ``test`` logs sample outputs of the generator.

    NOTE: ``train(dl, epoch)`` shadows ``nn.Module.train(mode)``.  Because
    ``nn.Module.eval()`` is implemented as ``self.train(False)``, calling
    ``.eval()`` on this wrapper itself does not work; toggle the
    sub-modules directly instead.
    """

    def __init__(self, params, args):
        """Build the networks, losses, optimizers and logging state.

        Args:
            params: forwarded unchanged to ``MModel`` and ``Discriminator``.
            args: namespace providing ``use_cuda``, ``g_weight_dir``,
                ``d_weight_dir``, ``g_lr``, ``d_lr``, ``face_d_lr``,
                ``save_dir``, ``d_update_freq``, ``save_freq`` and
                ``start_epoch``.
        """
        super().__init__()
        # NOTE(review): the generator is always constructed with
        # use_cuda=True, independent of args.use_cuda -- confirm intended.
        self.G = MModel(params, use_cuda=True)
        self.D = Discriminator(params, bias=True)
        self.Face_D = FaceDisc()
        self.vgg_loss = VGGPerceptualLoss()
        self.L1_loss = nn.L1Loss()

        if args.use_cuda:
            self.G = self.G.cuda()
            self.D = self.D.cuda()
            self.Face_D = self.Face_D.cuda()
            self.vgg_loss = self.vgg_loss.cuda()
            self.L1_loss = self.L1_loss.cuda()

        # map_location='cpu' lets a checkpoint written on a GPU be read on
        # any host; load_state_dict then copies the values onto whatever
        # device the module's parameters already live on.
        if args.g_weight_dir:
            self.G.load_state_dict(
                torch.load(args.g_weight_dir, map_location='cpu'),
                strict=True)
        if args.d_weight_dir:
            self.D.load_state_dict(
                torch.load(args.d_weight_dir, map_location='cpu'),
                strict=False)
        # NOTE(review): Face_D checkpoints are saved by train() but there is
        # no option here to load them back.

        self.optimizer_G = torch.optim.Adam(self.G.parameters(), lr=args.g_lr)
        self.optimizer_D = torch.optim.Adam(self.D.parameters(), lr=args.d_lr)
        self.optimizer_Face_D = torch.optim.Adam(self.Face_D.parameters(),
                                                 lr=args.face_d_lr)

        self.save_dir = args.save_dir
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(args.save_dir, exist_ok=True)

        self.d_update_freq = args.d_update_freq
        self.save_freq = args.save_freq
        self.writer = SummaryWriter('runs/' + args.save_dir)
        self.use_cuda = args.use_cuda
        self.start_epoch = args.start_epoch

    def _to_device(self, *tensors):
        """Return ``tensors`` moved to CUDA when ``use_cuda`` is set."""
        if self.use_cuda:
            return tuple(t.cuda() for t in tensors)
        return tensors

    def set_lr(self, op, lr):
        """Set the learning rate of every parameter group of one optimizer.

        Args:
            op: ``'G'``, ``'D'`` or ``'Face_D'`` to select one of this
                object's optimizers, or an optimizer object itself.
            lr: the new learning rate.
        """
        named = {
            'G': self.optimizer_G,
            'D': self.optimizer_D,
            'Face_D': self.optimizer_Face_D,
        }
        if isinstance(op, str) and op in named:
            optimizer = named[op]
        else:
            optimizer = op
        for group in optimizer.param_groups:
            group['lr'] = lr

    def G_loss(self, input, target):
        """Return the generator image loss (VGG perceptual loss only)."""
        return self.vgg_loss(input, target)

    def update_D(self, loss, epoch):
        """Backpropagate ``loss`` and step D on epochs divisible by ``d_update_freq``."""
        if epoch % self.d_update_freq == 0:
            loss.backward()
            self.optimizer_D.step()

    def update_G(self, loss, epoch):
        """Backpropagate ``loss`` and step G; a no-op while ``epoch < 10``."""
        if epoch >= 10:
            loss.backward()
            self.optimizer_G.step()

    def update_Face_D(self, loss, epoch):
        """Backpropagate ``loss`` and step Face_D (``epoch`` is unused)."""
        loss.backward()
        self.optimizer_Face_D.step()

    def get_patch_weight(self, pose, size=62):
        """Build a per-patch weight map from channel 0 of ``pose``.

        Channel 0 is resized to ``size`` and mapped ``v -> 5 * v + 1``, so
        the result has a single channel.  Presumably channel 0 is a head
        heat-map (going by the original variable name) -- TODO confirm.
        The map is not used by ``gan_loss`` at the moment.
        """
        heads = pose[:, 0, :, :].unsqueeze(1)
        heads = torch.nn.functional.interpolate(heads, size=size)
        return heads * 5 + 1

    def gan_loss(self, out, label, pose):
        """Binary cross-entropy of ``out`` against an all-real or all-fake target.

        Args:
            out: discriminator probabilities.
            label: ``1`` selects an all-ones target, anything else all-zeros.
            pose: unused; kept for a weighted variant based on
                ``get_patch_weight`` that is currently disabled.
        """
        target = torch.ones_like(out) if label == 1 else torch.zeros_like(out)
        return nn.BCELoss()(out, target)

    def face_gan_loss(self, out, label):
        """Binary cross-entropy of face-discriminator output ``out``.

        ``label == 1`` selects an all-ones target, anything else all-zeros.
        """
        target = torch.ones_like(out) if label == 1 else torch.zeros_like(out)
        return nn.BCELoss()(out, target)

    def crop_face(self, gen, size, tgt_face_box):
        """Crop one box per sample out of ``gen`` and resize it to ``size``.

        Args:
            gen: batch of generated images, indexed ``[sample, channel, row, col]``.
            size: full shape of the returned batch; its last two entries
                are the height and width each crop is resized to.
            tgt_face_box: per-sample box, read as ``box[1]:box[3]`` along
                rows and ``box[0]:box[2]`` along columns.

        Returns:
            A tensor of shape ``size`` on the same device as ``gen``.
        """
        # Allocate next to ``gen`` instead of forcing CUDA, so the CPU path
        # (use_cuda=False) works as well.
        gen_face = torch.zeros(size, dtype=gen.dtype, device=gen.device)
        out_hw = tuple(size[-2:])
        for i in range(gen.size(0)):
            box = tgt_face_box[i]
            x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
            face = gen[i, :, y1:y2, x1:x2]
            resized = F.interpolate(face.unsqueeze(0), size=out_hw)
            gen_face[i] = resized[0]
        return gen_face

    def get_face_lambda(self, epoch):
        """Return the weight of the face adversarial term in the G loss.

        Always 1.  A stepped warm-up (1e-4 for epochs 10-49, 1e-3 for
        50-69, 1e-2 for 70-79, 1e-1 for 80-89, then 1) used to follow an
        unconditional ``return 1`` and was therefore never applied.
        """
        return 1

    def train(self, dl, epoch):
        """Run one training epoch over ``dl`` and log losses and sample images.

        Each batch is a 10-tuple ``(src_img, y, src_pose, tgt_pose,
        src_mask_prior, x_trans, src_mask_gt, tgt_face, tgt_face_box,
        src_face_box)``.  Per batch, D and Face_D are updated on detached
        generator output, then G is updated through ``update_G``.
        Checkpoints for G, D and Face_D are written to ``save_dir`` when
        ``epoch`` is a multiple of ``save_freq``.
        """
        # test() leaves G in eval mode; switch back before optimizing.
        self.G.train()
        self.D.train()
        self.Face_D.train()

        cnt = 0
        loss_face_D_sum = loss_D_sum = 0.0
        loss_G_gan_sum = loss_G_img_sum = loss_G_sum = 0.0

        for step, (src_img, y, src_pose, tgt_pose, src_mask_prior, x_trans,
                   _src_mask_gt, tgt_face, tgt_face_box,
                   _src_face_box) in enumerate(dl):
            print('epoch:', epoch, 'iter:', step)
            (src_img, y, src_pose, tgt_pose, src_mask_prior, x_trans,
             tgt_face) = self._to_device(src_img, y, src_pose, tgt_pose,
                                         src_mask_prior, x_trans, tgt_face)

            out = self.G(src_img, src_pose, tgt_pose, src_mask_prior, x_trans)
            gen = out[0]
            gen_face = self.crop_face(gen, tgt_face.size(), tgt_face_box)

            # --- pose discriminator step (generator output detached) ---
            self.optimizer_D.zero_grad()
            loss_D_real = self.gan_loss(self.D(y, tgt_pose), 1, tgt_pose)
            loss_D_fake = self.gan_loss(
                self.D(gen.detach(), tgt_pose), 0, tgt_pose)
            loss_D = loss_D_real + loss_D_fake
            loss_D_sum += loss_D.item()
            self.update_D(loss_D, epoch)

            # --- face discriminator step ---
            self.optimizer_Face_D.zero_grad()
            pred_tgt = self.Face_D(tgt_face)
            pred_gen = self.Face_D(gen_face.detach())
            print('pred_tgt', pred_tgt)
            print('pred_gen', pred_gen)
            loss_face_D_real = self.face_gan_loss(pred_tgt, 1)
            loss_face_D_fake = self.face_gan_loss(pred_gen, 0)
            loss_face_D = loss_face_D_real + loss_face_D_fake
            loss_face_D_sum += loss_face_D.item()
            self.update_Face_D(loss_face_D, epoch)

            # --- generator step: both adversarial terms + image loss ---
            self.optimizer_G.zero_grad()
            loss_G_gan = self.gan_loss(self.D(gen, tgt_pose), 1, tgt_pose)
            loss_G_face_gan = self.face_gan_loss(self.Face_D(gen_face), 1)
            loss_G_img = self.G_loss(gen, y)
            lmd = self.get_face_lambda(epoch)
            loss_G = loss_G_gan + loss_G_img + lmd * loss_G_face_gan
            self.update_G(loss_G, epoch)

            loss_G_gan_sum += loss_G_gan.item()
            loss_G_img_sum += loss_G_img.item()
            loss_G_sum += loss_G.item()
            cnt += 1

        if cnt == 0:
            # Nothing was trained: no means and no last batch to log.
            print('epoch:', epoch, 'empty data loader, nothing trained')
            return

        mean_D = loss_D_sum / cnt
        mean_face_D = loss_face_D_sum / cnt
        mean_G = loss_G_sum / cnt

        self.writer.add_scalar('loss_D', mean_D, epoch)
        self.writer.add_scalar('loss_face_D', mean_face_D, epoch)
        # Combined chart uses the same epoch means as the scalars (not the
        # last batch's loss tensor for D).
        self.writer.add_scalars(
            'DG', {'D': mean_D, 'face_D': mean_face_D, 'G': mean_G}, epoch)
        self.writer.add_scalar('loss_G_gan', loss_G_gan_sum / cnt, epoch)
        self.writer.add_scalar('loss_G_img', loss_G_img_sum / cnt, epoch)
        self.writer.add_scalar('loss_G', mean_G, epoch)

        # Sample images come from the last batch of the epoch.
        self.writer.add_images('train_gen_face/epoch%d' % epoch,
                               gen_face * 0.5 + 0.5)
        self.writer.add_images('train_tgt_face/epoch%d' % epoch,
                               tgt_face * 0.5 + 0.5)
        self.writer.add_images('train_src/epoch%d' % epoch,
                               src_img * 0.5 + 0.5)
        self.writer.add_images('train_tgt/epoch%d' % epoch, y * 0.5 + 0.5)
        self.writer.add_images('train_gen/epoch%d' % epoch, gen * 0.5 + 0.5)

        if epoch % self.save_freq == 0:
            torch.save(self.G.state_dict(),
                       os.path.join(self.save_dir, 'g_epoch_%d.pth' % epoch))
            torch.save(self.D.state_dict(),
                       os.path.join(self.save_dir, 'd_epoch_%d.pth' % epoch))
            torch.save(
                self.Face_D.state_dict(),
                os.path.join(self.save_dir, 'faced_epoch_%d.pth' % epoch))

    def test(self, test_dl, epoch):
        """Run G over ``test_dl`` in eval mode and log images of the first batch.

        G is left in eval mode on return; ``train`` switches it back.
        """
        self.G.eval()
        for step, (src_img, y, src_pose, tgt_pose, src_mask_prior, x_trans,
                   _src_mask_gt, _tgt_face, _tgt_face_box,
                   _src_face_box) in enumerate(test_dl):
            print('test', 'epoch:', epoch, 'iter:', step)
            (src_img, y, src_pose, tgt_pose, src_mask_prior,
             x_trans) = self._to_device(src_img, y, src_pose, tgt_pose,
                                        src_mask_prior, x_trans)
            with torch.no_grad():
                out = self.G(src_img, src_pose, tgt_pose, src_mask_prior,
                             x_trans)
            gen = out[0]

            if step == 0:
                # Images are shifted by *0.5+0.5 before logging.
                self.writer.add_images('test_gen/epoch%d' % epoch,
                                       gen * 0.5 + 0.5)
                self.writer.add_images('test_y/epoch%d' % epoch,
                                       y * 0.5 + 0.5)
                self.writer.add_images('test_src/epoch%d' % epoch,
                                       src_img * 0.5 + 0.5)
                # Fold dim 1 of out[2] into the batch axis so each slice is
                # logged as its own single-channel image.
                mask = out[2]
                self.writer.add_images(
                    'test_src_mask/epoch%d' % epoch,
                    mask.view((mask.size(0) * mask.size(1), 1, mask.size(2),
                               mask.size(3))))
                # assumes out[3] packs 11 three-channel images per sample
                # -- TODO confirm against the generator.
                warped = out[3]
                self.writer.add_images(
                    'test_warped/epoch%d' % epoch,
                    warped.view((warped.size(0) * 11, 3, warped.size(2),
                                 warped.size(3))) * 0.5 + 0.5)