class gan(nn.Module):
    """Adversarial trainer pairing the generator ``G`` with the discriminator ``D``.

    Owns both networks, one Adam optimizer per network, the image
    reconstruction losses (VGG perceptual + L1) and a TensorBoard writer.

    NOTE: ``train(dl, epoch)`` shadows ``nn.Module.train(mode)``.  Calling
    ``.train()`` / ``.eval()`` on a ``gan`` instance itself therefore does not
    switch modes; switch them on ``self.G`` / ``self.D`` directly.
    """

    def __init__(self, params, args):
        """Build networks, losses, optimizers and the summary writer.

        Args:
            params: model configuration forwarded to ``MModel`` and
                ``Discriminator``.
            args: namespace providing ``use_cuda``, ``g_weight_dir``,
                ``d_weight_dir``, ``g_lr``, ``d_lr``, ``save_dir``,
                ``d_update_freq``, ``save_freq`` and ``start_epoch``.
        """
        super().__init__()
        # NOTE(review): the generator is always constructed with use_cuda=True,
        # independent of args.use_cuda -- confirm this is intended.
        self.G = MModel(params, use_cuda=True)
        self.D = Discriminator(params, bias=True)
        self.vgg_loss = VGGPerceptualLoss()
        self.L1_loss = nn.L1Loss()

        if args.use_cuda:
            self.G = self.G.cuda()
            self.D = self.D.cuda()
            self.vgg_loss = self.vgg_loss.cuda()
            self.L1_loss = self.L1_loss.cuda()

        # Optional warm start.  The generator checkpoint must match exactly,
        # the discriminator checkpoint may be partial (strict=False).
        if args.g_weight_dir:
            self.G.load_state_dict(torch.load(args.g_weight_dir), strict=True)
        if args.d_weight_dir:
            self.D.load_state_dict(torch.load(args.d_weight_dir), strict=False)

        self.optimizer_G = torch.optim.Adam(self.G.parameters(), lr=args.g_lr)
        self.optimizer_D = torch.optim.Adam(self.D.parameters(), lr=args.d_lr)

        self.save_dir = args.save_dir
        os.makedirs(args.save_dir, exist_ok=True)

        self.d_update_freq = args.d_update_freq
        self.save_freq = args.save_freq
        self.writer = SummaryWriter('runs/' + args.save_dir)
        self.use_cuda = args.use_cuda
        self.start_epoch = args.start_epoch

    def _to_device(self, *tensors):
        """Return ``tensors`` moved to the GPU when ``use_cuda`` is set."""
        if not self.use_cuda:
            return tensors
        return tuple(t.cuda() for t in tensors)

    def _save_checkpoints(self, epoch):
        """Write the generator and discriminator state dicts for ``epoch``."""
        torch.save(self.G.state_dict(),
                   os.path.join(self.save_dir, 'g_epoch_%d.pth' % epoch))
        torch.save(self.D.state_dict(),
                   os.path.join(self.save_dir, 'd_epoch_%d.pth' % epoch))

    def G_loss(self, input, target):
        """Image reconstruction loss: VGG perceptual loss plus L1 loss."""
        vgg = self.vgg_loss(input, target)
        L1 = self.L1_loss(input, target)
        return vgg + L1

    def update_D(self, loss, epoch):
        """Back-propagate ``loss`` and step D, only on epochs that are a
        multiple of ``d_update_freq``."""
        if epoch % self.d_update_freq == 0:
            loss.backward()
            self.optimizer_D.step()

    def get_patch_weight(self, pose, size=62):
        """Build a weight map from channel 0 of ``pose``.

        Channel 0 is resized to ``size`` and mapped to ``1 + 5 * value``.
        Presumably channel 0 is the head heat-map -- TODO confirm.
        Currently not used by ``gan_loss``.
        """
        heads = pose[:, 0, :, :].unsqueeze(1)
        heads = torch.nn.functional.interpolate(heads, size=size)
        return heads * 5 + torch.ones_like(heads)

    def gan_loss(self, out, label, pose):
        """Binary cross-entropy of ``out`` against a constant target.

        Args:
            out: discriminator output (probabilities).
            label: ``1`` selects an all-ones target, anything else all-zeros.
            pose: unused; kept so the loss can be weighted with
                ``get_patch_weight(pose)`` without changing callers.
        """
        target = torch.ones_like(out) if label == 1 else torch.zeros_like(out)
        return nn.BCELoss()(out, target)

    def train(self, dl, epoch):
        """Run one training epoch over ``dl`` and log per-epoch averages.

        Args:
            dl: iterable yielding 10-element batches; only the first six
                entries (source image, target image, source pose, target pose,
                source mask prior, transforms) are used here.
            epoch: current epoch index (used for D update gating, logging
                step and checkpoint names).
        """
        # test() leaves G in eval mode; restore training mode before updating.
        self.G.train()

        cnt = 0
        loss_D_real_sum = loss_D_fake_sum = loss_D_sum = 0.0
        loss_G_gan_sum = loss_G_img_sum = loss_G_sum = 0.0

        for step, (src_img, y, src_pose, tgt_pose, src_mask_prior, x_trans,
                   _src_mask_gt, _tgt_face, _tgt_face_box,
                   _src_face_box) in enumerate(dl):
            print('epoch:', epoch, 'iter:', step)
            (src_img, y, src_pose, tgt_pose, src_mask_prior,
             x_trans) = self._to_device(src_img, y, src_pose, tgt_pose,
                                        src_mask_prior, x_trans)

            out = self.G(src_img, src_pose, tgt_pose, src_mask_prior, x_trans)
            gen = out[0]

            # --- discriminator step (generator output detached) ---
            self.optimizer_D.zero_grad()
            loss_D_real = self.gan_loss(self.D(y, tgt_pose), 1, tgt_pose)
            loss_D_fake = self.gan_loss(self.D(gen.detach(), tgt_pose), 0,
                                        tgt_pose)
            loss_D = loss_D_real + loss_D_fake
            self.update_D(loss_D, epoch)

            # --- generator step: fool D + reconstruct the target image ---
            self.optimizer_G.zero_grad()
            loss_G_gan = self.gan_loss(self.D(gen, tgt_pose), 1, tgt_pose)
            loss_G_img = self.G_loss(gen, y)  # vgg_loss + L1_loss
            loss_G = loss_G_gan + loss_G_img
            loss_G.backward()
            self.optimizer_G.step()

            loss_D_real_sum += loss_D_real.item()
            loss_D_fake_sum += loss_D_fake.item()
            loss_D_sum += loss_D.item()
            loss_G_gan_sum += loss_G_gan.item()
            loss_G_img_sum += loss_G_img.item()
            loss_G_sum += loss_G.item()
            cnt += 1

        if cnt > 0:
            self.writer.add_scalar('loss_D_real', loss_D_real_sum / cnt, epoch)
            self.writer.add_scalar('loss_D_fake', loss_D_fake_sum / cnt, epoch)
            self.writer.add_scalar('loss_D', loss_D_sum / cnt, epoch)
            self.writer.add_scalar('loss_G_gan', loss_G_gan_sum / cnt, epoch)
            self.writer.add_scalar('loss_G_img', loss_G_img_sum / cnt, epoch)
            self.writer.add_scalar('loss_G', loss_G_sum / cnt, epoch)
            # Log epoch averages here as well (not the last batch's tensors).
            self.writer.add_scalars('DG', {
                'D': loss_D_sum / cnt,
                'G': loss_G_sum / cnt
            }, epoch)
        else:
            # Averages are undefined without any batch; skip logging.
            print('epoch:', epoch, 'no batches to train on')

        if epoch % self.save_freq == 0:
            self._save_checkpoints(epoch)

    def test(self, test_dl, epoch):
        """Run the generator over ``test_dl`` in eval mode without gradients
        and log the images of the first batch to TensorBoard."""
        self.G.eval()
        for step, (src_img, y, src_pose, tgt_pose, src_mask_prior, x_trans,
                   _src_mask_gt, _tgt_face, _tgt_face_box,
                   _src_face_box) in enumerate(test_dl):
            print('test', 'epoch:', epoch, 'iter:', step)
            (src_img, y, src_pose, tgt_pose, src_mask_prior,
             x_trans) = self._to_device(src_img, y, src_pose, tgt_pose,
                                        src_mask_prior, x_trans)
            with torch.no_grad():
                out = self.G(src_img, src_pose, tgt_pose, src_mask_prior,
                             x_trans)
                gen = out[0]
                if step == 0:
                    # "* 0.5 + 0.5" rescales before logging; presumably the
                    # images live in [-1, 1] -- TODO confirm.
                    self.writer.add_images('test_gen/epoch%d' % epoch,
                                           gen * 0.5 + 0.5)
                    self.writer.add_images('test_y/epoch%d' % epoch,
                                           y * 0.5 + 0.5)
                    self.writer.add_images('test_src/epoch%d' % epoch,
                                           src_img * 0.5 + 0.5)
                    # Flatten (batch, channel) so each mask channel of out[2]
                    # becomes its own single-channel image.
                    masks = out[2]
                    self.writer.add_images(
                        'test_src_mask/epoch%d' % epoch,
                        masks.view((masks.size(0) * masks.size(1), 1,
                                    masks.size(2), masks.size(3))))
# Fragment of a generator-only training script.
# NOTE(review): `mini`, `full_y`, `bs`, `shuffle`, `lr`, `model_dir`,
# `start_epoch` and `n_epoch` are not defined in the visible code; presumably
# they are parameters of an enclosing function or module globals -- confirm.

# Model/dataset configuration: 256x256 images, pose maps downsampled by 2.
params = get_general_params()
params['IMG_HEIGHT'] = 256
params['IMG_WIDTH'] = 256
params['posemap_downsample'] = 2
ds = mtdataset(params, mini=mini, full_y=full_y)
dl = DataLoader(ds, bs, shuffle)
# The model is unconditionally placed on the GPU and put into training mode.
model = MModel(params, use_cuda=True).cuda()
model.train()
# model.load_state_dict(torch.load('/versa/kangliwei/motion_transfer/0424-gan/g_epoch_1670.pth'), strict=False)
# model.load_state_dict(torch.load('/versa/kangliwei/motion_transfer/0429-256-gan/g_epoch_720.pth'))
# Warm start from a hard-coded absolute checkpoint path (strict loading).
model.load_state_dict(torch.load('/versa/kangliwei/motion_transfer/0604-gan/epoch_290.pth'))
# if start_epoch != 0:
#     model.load_state_dict(torch.load('/versa/kangliwei/motion_transfer/'+weight_dir+'/epoch_%d.pth'%(start_epoch-1)))
# Loss modules and a single Adam optimizer over all generator parameters.
vgg_loss = VGGPerceptualLoss().cuda()
l1mask_loss = L1MaskLoss().cuda()
model_params = model.parameters()
optimizer = torch.optim.Adam(model_params, lr=lr)
writer = SummaryWriter('runs/'+model_dir)
print('len(dl)', len(dl))
print('len(ds)', len(ds))
for epoch in range(start_epoch, n_epoch):
    # Per-epoch accumulators, reset at the start of every epoch.
    cnt = 0
    loss_sum = 0
    fg_loss_sum = 0
    mask_loss_sum = 0
    L1_sum = 0
    for i, (src_img, y, src_pose, tgt_pose, src_mask_prior, x_trans,
            src_mask_gt, tgt_face, tgt_face_box, src_face_box) in enumerate(dl):
        optimizer.zero_grad()
        print('epoch:', epoch, 'iter:', i)
        # Move the tensors used by the model to the GPU; src_mask_gt and the
        # face boxes are left where the loader put them.
        src_img, y, src_pose, tgt_pose, src_mask_prior, x_trans, tgt_face = src_img.cuda(), y.cuda(), src_pose.cuda(), tgt_pose.cuda(), src_mask_prior.cuda(), x_trans.cuda(), tgt_face.cuda()
        # NOTE(review): the visible fragment ends here -- the forward pass,
        # loss computation and optimizer step are not part of this excerpt.
class gan(nn.Module):
    """Adversarial trainer with a generator ``G``, an image discriminator ``D``
    and a face discriminator ``Face_D``.

    Owns the three networks, one Adam optimizer per network, the VGG
    perceptual loss and a TensorBoard writer.

    NOTE: ``train(dl, epoch)`` shadows ``nn.Module.train(mode)``.  Calling
    ``.train()`` / ``.eval()`` on a ``gan`` instance itself therefore does not
    switch modes; switch them on the sub-networks directly.
    """

    def __init__(self, params, args):
        """Build networks, losses, optimizers and the summary writer.

        Args:
            params: model configuration forwarded to ``MModel`` and
                ``Discriminator``.
            args: namespace providing ``use_cuda``, ``g_weight_dir``,
                ``d_weight_dir``, ``g_lr``, ``d_lr``, ``face_d_lr``,
                ``save_dir``, ``d_update_freq``, ``save_freq`` and
                ``start_epoch``.
        """
        super().__init__()
        # NOTE(review): the generator is always constructed with use_cuda=True,
        # independent of args.use_cuda -- confirm this is intended.
        self.G = MModel(params, use_cuda=True)
        self.D = Discriminator(params, bias=True)
        self.Face_D = FaceDisc()
        self.vgg_loss = VGGPerceptualLoss()
        self.L1_loss = nn.L1Loss()

        if args.use_cuda:
            self.G = self.G.cuda()
            self.D = self.D.cuda()
            self.Face_D = self.Face_D.cuda()
            self.vgg_loss = self.vgg_loss.cuda()
            self.L1_loss = self.L1_loss.cuda()

        # Optional warm start.  The generator checkpoint must match exactly,
        # the discriminator checkpoint may be partial (strict=False).
        if args.g_weight_dir:
            self.G.load_state_dict(torch.load(args.g_weight_dir), strict=True)
        if args.d_weight_dir:
            self.D.load_state_dict(torch.load(args.d_weight_dir), strict=False)

        self.optimizer_G = torch.optim.Adam(self.G.parameters(), lr=args.g_lr)
        self.optimizer_D = torch.optim.Adam(self.D.parameters(), lr=args.d_lr)
        self.optimizer_Face_D = torch.optim.Adam(self.Face_D.parameters(),
                                                 lr=args.face_d_lr)

        self.save_dir = args.save_dir
        os.makedirs(args.save_dir, exist_ok=True)

        self.d_update_freq = args.d_update_freq
        self.save_freq = args.save_freq
        self.writer = SummaryWriter('runs/' + args.save_dir)
        self.use_cuda = args.use_cuda
        self.start_epoch = args.start_epoch

    def _to_device(self, *tensors):
        """Return ``tensors`` moved to the GPU when ``use_cuda`` is set."""
        if not self.use_cuda:
            return tensors
        return tuple(t.cuda() for t in tensors)

    def _save_checkpoints(self, epoch):
        """Write the state dicts of G, D and Face_D for ``epoch``."""
        torch.save(self.G.state_dict(),
                   os.path.join(self.save_dir, 'g_epoch_%d.pth' % epoch))
        torch.save(self.D.state_dict(),
                   os.path.join(self.save_dir, 'd_epoch_%d.pth' % epoch))
        torch.save(self.Face_D.state_dict(),
                   os.path.join(self.save_dir, 'faced_epoch_%d.pth' % epoch))

    def set_lr(self, op, lr):
        """Set the learning rate of every parameter group of an optimizer.

        Args:
            op: ``'G'``, ``'D'`` or ``'Face_D'`` to select one of this
                trainer's optimizers, or an optimizer object.
            lr: new learning rate.
        """
        if isinstance(op, str):
            named = {
                'G': self.optimizer_G,
                'D': self.optimizer_D,
                'Face_D': self.optimizer_Face_D,
            }
            # Unknown names fall through unchanged and fail on .param_groups.
            op = named.get(op, op)
        for group in op.param_groups:
            group['lr'] = lr

    def G_loss(self, input, target):
        """Image reconstruction loss: VGG perceptual loss only (no L1 term)."""
        return self.vgg_loss(input, target)

    def update_D(self, loss, epoch):
        """Back-propagate ``loss`` and step D, only on epochs that are a
        multiple of ``d_update_freq``."""
        if epoch % self.d_update_freq == 0:
            loss.backward()
            self.optimizer_D.step()

    def update_G(self, loss, epoch):
        """Back-propagate ``loss`` and step G; G is frozen for epochs < 10."""
        if epoch >= 10:
            loss.backward()
            self.optimizer_G.step()

    def update_Face_D(self, loss, epoch):
        """Back-propagate ``loss`` and step Face_D (every epoch)."""
        loss.backward()
        self.optimizer_Face_D.step()

    def get_patch_weight(self, pose, size=62):
        """Build a weight map from channel 0 of ``pose``.

        Channel 0 is resized to ``size`` and mapped to ``1 + 5 * value``.
        Presumably channel 0 is the head heat-map -- TODO confirm.
        Currently not used by ``gan_loss``.
        """
        heads = pose[:, 0, :, :].unsqueeze(1)
        heads = torch.nn.functional.interpolate(heads, size=size)
        return heads * 5 + torch.ones_like(heads)

    def gan_loss(self, out, label, pose):
        """Binary cross-entropy of ``out`` against a constant target.

        Args:
            out: discriminator output (probabilities).
            label: ``1`` selects an all-ones target, anything else all-zeros.
            pose: unused; kept so the loss can be weighted with
                ``get_patch_weight(pose)`` without changing callers.
        """
        target = torch.ones_like(out) if label == 1 else torch.zeros_like(out)
        return nn.BCELoss()(out, target)

    def face_gan_loss(self, out, label):
        """Binary cross-entropy of the face discriminator output against an
        all-ones (``label == 1``) or all-zeros target."""
        target = torch.ones_like(out) if label == 1 else torch.zeros_like(out)
        return nn.BCELoss()(out, target)

    def crop_face(self, gen, size, tgt_face_box):
        """Crop one face region per sample from ``gen`` and resize it.

        Args:
            gen: batch of generated images, indexed as ``[i, :, rows, cols]``.
            size: shape of the returned batch; its last two entries are the
                output height and width.
            tgt_face_box: per-sample boxes, read as
                ``(col_start, row_start, col_end, row_end)``.

        Returns:
            Tensor of shape ``size`` on the same device and dtype as ``gen``.
        """
        # Allocate next to ``gen`` instead of forcing CUDA, so CPU runs work.
        gen_face = gen.new_zeros(size)
        out_hw = tuple(size[-2:])
        for i in range(gen.size(0)):
            box = tgt_face_box[i]
            x0, y0, x1, y1 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
            face = gen[i, :, y0:y1, x0:x1]
            # Resize to the target face resolution (256 in the original setup).
            gen_face[i] = F.interpolate(face.unsqueeze(0), size=out_hw)
        return gen_face

    def get_face_lambda(self, epoch):
        """Weight of the face adversarial term in the generator loss.

        Always 1.  A staged warm-up (1e-4 for epochs 10-49, 1e-3 for 50-69,
        1e-2 for 70-79, 1e-1 for 80-89, then 1) used to follow an
        unconditional ``return 1`` and was never executed; the effective
        behaviour is preserved here.
        """
        return 1

    def train(self, dl, epoch):
        """Run one training epoch over ``dl`` and log per-epoch averages.

        Args:
            dl: iterable yielding 10-element batches (source image, target
                image, source pose, target pose, source mask prior,
                transforms, source mask GT, target face, target face box,
                source face box).
            epoch: current epoch index (gates the D and G updates, and is the
                logging step and checkpoint suffix).
        """
        # test() leaves G in eval mode; restore training mode before updating.
        self.G.train()

        cnt = 0
        loss_face_D_sum = loss_D_sum = 0.0
        loss_G_gan_sum = loss_G_img_sum = loss_G_sum = 0.0

        for step, (src_img, y, src_pose, tgt_pose, src_mask_prior, x_trans,
                   _src_mask_gt, tgt_face, tgt_face_box,
                   _src_face_box) in enumerate(dl):
            print('epoch:', epoch, 'iter:', step)
            (src_img, y, src_pose, tgt_pose, src_mask_prior, x_trans,
             tgt_face) = self._to_device(src_img, y, src_pose, tgt_pose,
                                         src_mask_prior, x_trans, tgt_face)

            out = self.G(src_img, src_pose, tgt_pose, src_mask_prior, x_trans)
            gen = out[0]
            gen_face = self.crop_face(gen, tgt_face.size(), tgt_face_box)

            # --- image discriminator step (generator output detached) ---
            self.optimizer_D.zero_grad()
            loss_D_real = self.gan_loss(self.D(y, tgt_pose), 1, tgt_pose)
            loss_D_fake = self.gan_loss(self.D(gen.detach(), tgt_pose), 0,
                                        tgt_pose)
            loss_D = loss_D_real + loss_D_fake
            loss_D_sum += loss_D.item()
            self.update_D(loss_D, epoch)

            # --- face discriminator step ---
            self.optimizer_Face_D.zero_grad()
            pred_tgt = self.Face_D(tgt_face)
            pred_gen = self.Face_D(gen_face.detach())
            print('pred_tgt', pred_tgt)
            print('pred_gen', pred_gen)
            loss_face_real = self.face_gan_loss(pred_tgt, 1)
            loss_face_fake = self.face_gan_loss(pred_gen, 0)
            loss_face_D = loss_face_real + loss_face_fake
            loss_face_D_sum += loss_face_D.item()
            self.update_Face_D(loss_face_D, epoch)

            # --- generator step: fool both discriminators + perceptual loss ---
            self.optimizer_G.zero_grad()
            loss_G_gan = self.gan_loss(self.D(gen, tgt_pose), 1, tgt_pose)
            loss_G_face_gan = self.face_gan_loss(self.Face_D(gen_face), 1)
            loss_G_img = self.G_loss(gen, y)  # vgg_loss only
            lmd = self.get_face_lambda(epoch)
            loss_G = loss_G_gan + loss_G_img + lmd * loss_G_face_gan
            self.update_G(loss_G, epoch)

            loss_G_gan_sum += loss_G_gan.item()
            loss_G_img_sum += loss_G_img.item()
            loss_G_sum += loss_G.item()
            cnt += 1

        if cnt > 0:
            self.writer.add_scalar('loss_D', loss_D_sum / cnt, epoch)
            self.writer.add_scalar('loss_face_D', loss_face_D_sum / cnt, epoch)
            # All three entries are epoch averages (not last-batch tensors).
            self.writer.add_scalars(
                'DG', {
                    'D': loss_D_sum / cnt,
                    'face_D': loss_face_D_sum / cnt,
                    'G': loss_G_sum / cnt
                }, epoch)
            self.writer.add_scalar('loss_G_gan', loss_G_gan_sum / cnt, epoch)
            self.writer.add_scalar('loss_G_img', loss_G_img_sum / cnt, epoch)
            self.writer.add_scalar('loss_G', loss_G_sum / cnt, epoch)
            # Images of the last batch of the epoch.
            self.writer.add_images('train_gen_face/epoch%d' % epoch,
                                   gen_face * 0.5 + 0.5)
            self.writer.add_images('train_tgt_face/epoch%d' % epoch,
                                   tgt_face * 0.5 + 0.5)
            self.writer.add_images('train_src/epoch%d' % epoch,
                                   src_img * 0.5 + 0.5)
            self.writer.add_images('train_tgt/epoch%d' % epoch, y * 0.5 + 0.5)
            self.writer.add_images('train_gen/epoch%d' % epoch,
                                   gen * 0.5 + 0.5)
        else:
            # Averages and sample images are undefined without any batch.
            print('epoch:', epoch, 'no batches to train on')

        if epoch % self.save_freq == 0:
            self._save_checkpoints(epoch)

    def test(self, test_dl, epoch):
        """Run the generator over ``test_dl`` in eval mode without gradients
        and log the images of the first batch to TensorBoard."""
        self.G.eval()
        for step, (src_img, y, src_pose, tgt_pose, src_mask_prior, x_trans,
                   _src_mask_gt, _tgt_face, _tgt_face_box,
                   _src_face_box) in enumerate(test_dl):
            print('test', 'epoch:', epoch, 'iter:', step)
            (src_img, y, src_pose, tgt_pose, src_mask_prior,
             x_trans) = self._to_device(src_img, y, src_pose, tgt_pose,
                                        src_mask_prior, x_trans)
            with torch.no_grad():
                out = self.G(src_img, src_pose, tgt_pose, src_mask_prior,
                             x_trans)
                gen = out[0]
                if step == 0:
                    # "* 0.5 + 0.5" rescales before logging; presumably the
                    # images live in [-1, 1] -- TODO confirm.
                    self.writer.add_images('test_gen/epoch%d' % epoch,
                                           gen * 0.5 + 0.5)
                    self.writer.add_images('test_y/epoch%d' % epoch,
                                           y * 0.5 + 0.5)
                    self.writer.add_images('test_src/epoch%d' % epoch,
                                           src_img * 0.5 + 0.5)
                    # Flatten (batch, channel) so each mask channel of out[2]
                    # becomes its own single-channel image.
                    masks = out[2]
                    self.writer.add_images(
                        'test_src_mask/epoch%d' % epoch,
                        masks.view((masks.size(0) * masks.size(1), 1,
                                    masks.size(2), masks.size(3))))
                    # out[3] is viewed as 11 RGB images per sample.
                    warped = out[3]
                    self.writer.add_images(
                        'test_warped/epoch%d' % epoch,
                        warped.view((warped.size(0) * 11, 3, warped.size(2),
                                     warped.size(3))) * 0.5 + 0.5)