# NOTE(review): this fragment arrived collapsed onto a single line and the
# enclosing loop headers are not visible here. Statements are laid out flat,
# one per line; re-indent to match the surrounding loops when splicing.

# Finish the discriminator update: combine the real and fake losses,
# backpropagate through the sum, and apply one optimizer step.
D_loss = D_real_loss + D_fake_loss
D_loss.backward()
D_optimizer.step()

# Generator update. The bare string below is a no-op expression statement the
# original author used as a section marker; it reads "train the generator".
''' 训练生成器 '''
Net_G.zero_grad()
# Fresh noise for the generator pass; the labels (`label_var`) are reused
# from earlier in the iteration.
Noise_var = Variable(torch.randn(BATCH_SIZE, NOISE_DIM))
image_fake = Net_G(Noise_var, label_var)
D_fake = Net_D(image_fake, label_var)
# The generator is scored against `label_true_var`, i.e. on how strongly the
# discriminator rates the generated batch as matching the "true" target.
G_loss = BCELoss()(D_fake, label_true_var)
G_loss.backward()
G_optimizer.step()

# Report both losses as plain Python numbers. `.item()` replaces the old
# `.data[0]` idiom, which raises IndexError on 0-dim tensors in
# PyTorch >= 0.5.
proBar.show(D_loss.item(), G_loss.item())

# Build a preview grid: new noise paired with a label tensor in which every
# entry is the class index 7.
Noise_var = Variable(torch.randn(BATCH_SIZE, NOISE_DIM))
y = (torch.ones(BATCH_SIZE) * 7).long()
y = one_hot(y)  # presumably one-hot encodes the indices - confirm in helper
# NOTE(review): the threshold here is `> 1`, so with exactly one GPU `y`
# stays on the CPU (and `Noise_var` is never moved) - confirm this is intended.
y = Variable(y.cuda() if GPU_NUMS > 1 else y)
# Only the first 24 generated samples are kept for the grid.
samples = Net_G(Noise_var, y)[:24]
img = torchvision.utils.make_grid(samples.data)
npimg = img.cpu().numpy()
# Reorder axes (0, 1, 2) -> (1, 2, 0) before handing the array to matplotlib
# (presumably channels-first to channels-last - confirm).
plt.imshow(np.transpose(npimg, (1, 2, 0)))

# Make sure the output directory exists before anything is written to it.
if not os.path.exists('out/'):
    os.makedirs('out/')
# NOTE(review): this fragment arrived collapsed onto a single line and the
# enclosing loop headers are not visible here. Statements are laid out flat,
# one per line; re-indent to match the surrounding loops when splicing.

# ---- Discriminator: pass on generated samples --------------------------------
# Look up the two label tensors for the sampled `label_fake` indices (one from
# `onehot` for the generator, one from `fill` for the discriminator), moving
# them to the GPU when one is configured.
if CONFIG["GPU_NUMS"] > 0:
    label_fake_G_var = Variable(onehot[label_fake].cuda())
    label_fake_D_var = Variable(fill[label_fake].cuda())
else:
    label_fake_G_var = Variable(onehot[label_fake])
    label_fake_D_var = Variable(fill[label_fake])

g_result = NetG(img_fake_var, label_fake_G_var)
d_result = NetD(g_result, label_fake_D_var).squeeze()
D_LOSS_FAKE = BCELoss()(d_result, label_false_var)

# Combine with the real-sample loss computed earlier and update the
# discriminator once.
D_train_loss = D_LOSS_REAL + D_LOSS_FAKE
D_train_loss.backward()
D_optimizer.step()

# ---- Generator update --------------------------------------------------------
NetG.zero_grad()

# Draw a new noise batch reshaped to (-1, 100, 1, 1) and new integer labels
# obtained by scaling uniform samples by 10 and casting to LongTensor.
img_fake = torch.randn((mini_batch, 100)).view(-1, 100, 1, 1)
label_fake = (torch.rand(mini_batch, 1) * 10).type(torch.LongTensor).squeeze()

if CONFIG["GPU_NUMS"] > 0:
    img_fake_var = Variable(img_fake.cuda())
    label_fake_G_var = Variable(onehot[label_fake].cuda())
    label_fake_D_var = Variable(fill[label_fake].cuda())
else:
    img_fake_var = Variable(img_fake)
    label_fake_G_var = Variable(onehot[label_fake])
    label_fake_D_var = Variable(fill[label_fake])

g_result = NetG(img_fake_var, label_fake_G_var)
d_result = NetD(g_result, label_fake_D_var).squeeze()
# The generator is scored against `label_true_var` rather than
# `label_false_var`, the target used in the discriminator pass above.
G_train_loss = BCELoss()(d_result, label_true_var)
G_train_loss.backward()
G_optimizer.step()

# Progress report with both losses as Python scalars.
bar.show(epoch, D_train_loss.item(), G_train_loss.item())

# ---- Snapshot ----------------------------------------------------------------
# Generate images from the fixed noise/label pair and write the first 100 as a
# 10-per-row grid named after the current epoch.
test_images = NetG(fixed_z_, fixed_y_label_)
# NOTE(review): newer torchvision releases rename the `range` keyword to
# `value_range` - confirm against the installed version before upgrading.
torchvision.utils.save_image(
    test_images.data[:100],
    'outputs/mnist_%03d.png' % (epoch),
    nrow=10,
    normalize=True,
    range=(-1,1),
    padding=0,
)