Esempio n. 1
0
def train_D(x_real):
    """Run one discriminator update on a batch of real samples.

    Samples latent noise, generates a fake batch with ``G`` (no gradient
    flows into ``G``), scores the real and fake batches with ``D``, and adds
    the adversarial loss from ``d_loss_fn`` to the gradient penalty from
    ``g_penal.gradient_penalty``, weighted by ``args.gradient_penalty_weight``.
    It then takes one ``D_optimizer`` step.

    Relies on module-level names not passed as arguments: ``G``, ``D``,
    ``D_optimizer``, ``args``, ``device``, ``d_loss_fn`` and ``g_penal``.

    Args:
        x_real: Batch of real samples, already on ``device``. Its first
            dimension is used as the batch size for the fake batch.

    Returns:
        dict with ``'d_loss'``, the real + fake adversarial loss without the
        penalty, and ``'gp'``, the unweighted gradient penalty.
    """
    G.train()
    D.train()

    # Size the noise from the real batch, not args.batch_size. A loader
    # without drop_last yields a smaller final batch. x_real and x_fake are
    # handed to the gradient penalty together, so their sizes must agree.
    z = torch.randn(x_real.size(0), args.z_dim, 1, 1).to(device)
    # Only D is updated here, so G's forward pass needs no autograd graph.
    # This gives the same values as G(z).detach() without building the graph.
    with torch.no_grad():
        x_fake = G(z)

    x_real_d_logit = D(x_real)
    x_fake_d_logit = D(x_fake)

    x_real_d_loss, x_fake_d_loss = d_loss_fn(x_real_d_logit, x_fake_d_logit)
    d_loss = x_real_d_loss + x_fake_d_loss

    # D is passed as-is: functools.partial(D) with no bound arguments is an
    # equivalent callable.
    gp = g_penal.gradient_penalty(
        D,
        x_real,
        x_fake,
        gp_mode=args.gradient_penalty_mode,
        sample_mode=args.gradient_penalty_sample_mode)

    D_loss = d_loss + gp * args.gradient_penalty_weight

    D.zero_grad()
    D_loss.backward()
    D_optimizer.step()

    return {'d_loss': d_loss, 'gp': gp}
Esempio n. 2
0
	# Epoch loop (excerpt). Each inner iteration does one discriminator
	# update on a real batch and logs its losses. Any generator update
	# happens after this excerpt.
	for ep in tqdm.trange(args.epochs, desc='Epoch Loop'):
	    # NOTE(review): these counters are reset at the start of every epoch,
	    # and it_d is used as TensorBoard's global_step below. Step numbers
	    # therefore repeat across epochs, and later epochs plot over earlier
	    # ones. A counter initialised outside the epoch loop is probably
	    # intended.
	    it_d, it_g = 0, 0
	    # Commented-out variant for a loader that yields (x_real, flag) tuples.
	    #for x_real,flag in tqdm.tqdm(data_loader, desc='Inner Epoch Loop'):
	    for x_real in tqdm.tqdm(data_loader, desc='Inner Epoch Loop'):
	        x_real = x_real.to(device)
	        # NOTE(review): the noise uses args.batch_size, not x_real.size(0).
	        # If data_loader can yield a short final batch (drop_last=False),
	        # x_real and x_fake differ in size. That presumably breaks the
	        # gradient penalty below -- confirm the loader settings.
	        z = torch.randn(args.batch_size, args.z_dim, 1, 1).to(device)

	        #--------training D-----------
	        # x_fake keeps its autograd graph here and is detached only where
	        # D consumes it. Presumably the generator step after this excerpt
	        # reuses it -- confirm.
	        x_fake = G(z)
	        #print('x_real.shape:'+str(x_real.shape))
	        x_real_d_logit = D(x_real)
	        # Detached so the D loss does not backpropagate into G.
	        x_fake_d_logit = D(x_fake.detach())

	        x_real_d_loss, x_fake_d_loss = d_loss_fn(x_real_d_logit, x_fake_d_logit)

	        # Gradient penalty on D, computed from the real batch and the
	        # detached fake batch. The penalty and sampling modes come from args.
	        gp = g_penal.gradient_penalty(functools.partial(D), x_real, x_fake.detach(), gp_mode=args.gradient_penalty_mode, sample_mode=args.gradient_penalty_sample_mode)
	        # Disabled ablation: train without the gradient penalty.
	        #gp = torch.tensor(0.0)
	        # Total D objective: adversarial loss plus the weighted penalty.
	        D_loss = (x_real_d_loss + x_fake_d_loss) + gp * args.gradient_penalty_weight
	        # Disabled experiment: scale the whole D loss by 1/(1+0.005*ep),
	        # so it shrinks as training progresses.
	        #D_loss = 1/(1+0.005*ep)*D_loss # progressive GP!

	        D.zero_grad()
	        D_loss.backward()
	        D_optimizer.step()
	        # Disabled: decayD is presumably a learning-rate scheduler for D.
	        #decayD.step()

	        # Logged values: the adversarial D loss without the penalty, and
	        # the unweighted gradient penalty.
	        D_loss_dict={'d_loss': x_real_d_loss + x_fake_d_loss, 'gp': gp}

	        it_d += 1
	        for k, v in D_loss_dict.items():
	            writer.add_scalar('D/%s' % k, v.data.cpu().numpy(), global_step=it_d)