def train_D(x_real):
    """Run one discriminator update on a batch of real samples.

    Samples latent noise, generates a matching batch of fake samples with
    ``G`` (no gradient flows into ``G``), scores real and fake samples with
    ``D``, and takes one ``D_optimizer`` step on
    ``d_loss_fn`` terms + ``args.gradient_penalty_weight`` * gradient penalty.

    Args:
        x_real: batch of real samples, expected to already be on ``device``.
            The number of generated fake samples equals ``x_real.size(0)``.

    Returns:
        dict with:
            'd_loss': sum of the real and fake loss terms from ``d_loss_fn``
                (gradient penalty excluded).
            'gp': the unweighted gradient penalty returned by
                ``g_penal.gradient_penalty``.
    """
    G.train()
    D.train()

    # Size the noise from the actual batch rather than args.batch_size: the
    # final batch of an epoch can be smaller (unless the loader drops it), and
    # the gradient penalty mixes x_real with x_fake, so their batch dimensions
    # must agree. `.to(device)` (instead of `device=`) keeps the same RNG
    # stream as before for seeded runs.
    z = torch.randn(x_real.size(0), args.z_dim, 1, 1).to(device)

    # G is not optimized in this step; no_grad avoids building its autograd
    # graph and yields a tensor equivalent to G(z).detach().
    with torch.no_grad():
        x_fake = G(z)

    x_real_d_logit = D(x_real)
    x_fake_d_logit = D(x_fake)
    x_real_d_loss, x_fake_d_loss = d_loss_fn(x_real_d_logit, x_fake_d_logit)

    # D is passed directly; wrapping it in functools.partial with no bound
    # arguments added nothing.
    gp = g_penal.gradient_penalty(
        D, x_real, x_fake,
        gp_mode=args.gradient_penalty_mode,
        sample_mode=args.gradient_penalty_sample_mode)

    adv_loss = x_real_d_loss + x_fake_d_loss
    D_loss = adv_loss + gp * args.gradient_penalty_weight

    D.zero_grad()
    D_loss.backward()
    D_optimizer.step()

    return {'d_loss': adv_loss, 'gp': gp}
for ep in tqdm.trange(args.epochs, desc='Epoch Loop'): it_d, it_g = 0, 0 #for x_real,flag in tqdm.tqdm(data_loader, desc='Inner Epoch Loop'): for x_real in tqdm.tqdm(data_loader, desc='Inner Epoch Loop'): x_real = x_real.to(device) z = torch.randn(args.batch_size, args.z_dim, 1, 1).to(device) #--------training D----------- x_fake = G(z) #print('x_real.shape:'+str(x_real.shape)) x_real_d_logit = D(x_real) x_fake_d_logit = D(x_fake.detach()) x_real_d_loss, x_fake_d_loss = d_loss_fn(x_real_d_logit, x_fake_d_logit) gp = g_penal.gradient_penalty(functools.partial(D), x_real, x_fake.detach(), gp_mode=args.gradient_penalty_mode, sample_mode=args.gradient_penalty_sample_mode) #gp = torch.tensor(0.0) D_loss = (x_real_d_loss + x_fake_d_loss) + gp * args.gradient_penalty_weight #D_loss = 1/(1+0.005*ep)*D_loss # 渐进式GP! D.zero_grad() D_loss.backward() D_optimizer.step() #decayD.step() D_loss_dict={'d_loss': x_real_d_loss + x_fake_d_loss, 'gp': gp} it_d += 1 for k, v in D_loss_dict.items(): writer.add_scalar('D/%s' % k, v.data.cpu().numpy(), global_step=it_d)