import os
import json
import random
import time
import pickle

import numpy as np
import torch
import torch.optim as optim
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter

# Repo-local helpers used below (import paths are assumptions; adjust them to
# this repo's layout): Generator, Discriminator, weights_init, load_mnist,
# define_model_loss, ExtraAdam, compute_eigenvalues, compute_path_stats,
# plot_path_stats, plot_eigenvalues, mnist_inception_score.


def main(config):
    print("Hyper-params:")
    print(config)

    # create exp folder and save config
    exp_dir = os.path.join(config.exp_dir, config.exp_name)
    if not os.path.exists(exp_dir):
        os.makedirs(exp_dir)
    plots_dir = os.path.join(exp_dir, 'extra_plots')
    if not os.path.exists(plots_dir):
        os.makedirs(plots_dir)

    # seed every RNG we rely on so runs are reproducible
    if config.manualSeed is None:
        config.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", config.manualSeed)
    random.seed(config.manualSeed)
    torch.manual_seed(config.manualSeed)
    np.random.seed(config.manualSeed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config.manualSeed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device {0!s}".format(device))

    # training loader plus smaller loaders for evaluation and eigenvalue estimation
    dataloader = load_mnist(config.batchSize)
    eval_dataloader = load_mnist(config.batchSize, subset=5000)
    eig_dataloader = load_mnist(1000, train=True, subset=1000)

    # fixed noise so the sample grids logged every epoch are comparable
    fixed_noise = torch.randn(64, config.nz, 1, 1, device=device)

    # define the model
    netG = Generator(config.ngpu, config.nc, config.ngf, config.nz).to(device)
    netG.apply(weights_init)
    if config.netG != '':
        print('loading generator from %s' % config.netG)
        netG.load_state_dict(torch.load(config.netG)['state_gen'])
    print(netG)

    # sigmoid = config.model == 'dcgan'
    sigmoid = False
    netD = Discriminator(config.ngpu, config.nc, config.ndf, config.dnorm,
                         sigmoid).to(device)
    netD.apply(weights_init)
    if config.netD != '':
        print('loading discriminator from %s' % config.netD)
        netD.load_state_dict(torch.load(config.netD)['state_dis'])
    print(netD)

    # separate copies of G and D used only for evaluation (eigenvalues, path
    # statistics), so checkpoints can be loaded without touching the networks
    # being trained
    evalG = Generator(config.ngpu, config.nc, config.ngf, config.nz).to(device)
    evalG.apply(weights_init)
    evalD = Discriminator(config.ngpu, config.nc, config.ndf, config.dnorm,
                          sigmoid).to(device)
    evalD.apply(weights_init)

    # defining the loss functions
    model_loss_dis, model_loss_gen = define_model_loss(config)

    # # defining learning rates based on the model
    # if config.model in ['wgan', 'wgan_gp']:
    #     config.lrG = config.lrD / config.n_critic
    #     warnings.warn('modifying learning rates to lrD=%f, lrG=%f' % (config.lrD, config.lrG))

    if config.lrG is None:
        config.lrG = config.lrD

    # setup optimizers
    if config.optimizer == 'adam':
        optimizerD = optim.Adam(netD.parameters(), lr=config.lrD,
                                betas=(config.beta1, config.beta2))
        optimizerG = optim.Adam(netG.parameters(), lr=config.lrG,
                                betas=(config.beta1, config.beta2))
    elif config.optimizer == 'extraadam':
        optimizerD = ExtraAdam(netD.parameters(), lr=config.lrD)
        optimizerG = ExtraAdam(netG.parameters(), lr=config.lrG)
    elif config.optimizer == 'rmsprop':
        optimizerD = optim.RMSprop(netD.parameters(), lr=config.lrD)
        optimizerG = optim.RMSprop(netG.parameters(), lr=config.lrG)
    elif config.optimizer == 'sgd':
        optimizerD = optim.SGD(netD.parameters(), lr=config.lrD, momentum=config.beta1)
        optimizerG = optim.SGD(netG.parameters(), lr=config.lrG, momentum=config.beta1)
    else:
        raise ValueError('Optimizer %s not supported' % config.optimizer)

    with open(os.path.join(exp_dir, 'config.json'), 'w') as f:
        json.dump(vars(config), f, indent=4)

    summary_writer = SummaryWriter(log_dir=exp_dir)

    # save the initial parameters as checkpoint 0
    global_step = 0
    torch.save({
        'state_gen': netG.state_dict(),
        'state_dis': netD.state_dict()
    }, '%s/checkpoint_step_%06d.pth' % (exp_dir, global_step))

    # compute and save eigenvalues at a given checkpoint
    def comp_and_save_eigs(step, n_eigs=20):
        eig_checkpoint = torch.load('%s/checkpoint_step_%06d.pth' % (exp_dir, step),
                                    map_location=device)
        evalG.load_state_dict(eig_checkpoint['state_gen'])
        evalD.load_state_dict(eig_checkpoint['state_dis'])
        gen_eigs, dis_eigs, game_eigs = compute_eigenvalues(
            evalG,
            evalD, eig_dataloader, config, model_loss_gen, model_loss_dis,
            device, verbose=True, n_eigs=n_eigs)
        np.savez(os.path.join(plots_dir, 'eigenvalues_%d' % step),
                 gen_eigs=gen_eigs, dis_eigs=dis_eigs, game_eigs=game_eigs)
        return gen_eigs, dis_eigs, game_eigs

    if config.compute_eig:
        # eigenvalues at initialization
        gen_eigs_init, dis_eigs_init, game_eigs_init = comp_and_save_eigs(0)

    # errG / D_G_z2 are only set on iterations where the generator is updated;
    # initialize them so the periodic logging below cannot raise a NameError
    # before the first generator step (e.g. with n_critic > 1)
    errG = torch.tensor(0.0)
    D_G_z2 = 0.0

    for epoch in range(config.niter):
        for i, data in enumerate(dataloader, 0):
            global_step += 1
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ############################
            netD.zero_grad()
            x_real = data[0].to(device)
            batch_size = x_real.size(0)
            noise = torch.randn(batch_size, config.nz, 1, 1, device=device)
            x_fake = netG(noise)
            # detach the fake batch so this backward pass only touches D
            errD, D_x, D_G_z1 = model_loss_dis(x_real, x_fake.detach(), netD, device)
            # gradient penalty
            if config.model == 'wgan_gp':
                errD += config.gp_lambda * netD.get_penalty(x_real.detach(),
                                                            x_fake.detach())
            errD.backward()
            D_x = D_x.mean().item()
            D_G_z1 = D_G_z1.mean().item()

            if config.optimizer == "extraadam":
                # extra-gradient scheme: alternate extrapolation and update steps
                if i % 2 == 0:
                    optimizerD.extrapolation()
                else:
                    optimizerD.step()
            else:
                optimizerD.step()

            # weight clipping
            if config.model == 'wgan':
                for p in netD.parameters():
                    p.data.clamp_(-config.clip, config.clip)

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ############################
            if config.model == 'dcgan' or (config.model in ['wgan', 'wgan_gp']
                                           and i % config.n_critic == 0):
                netG.zero_grad()
                errG, D_G_z2 = model_loss_gen(x_fake, netD, device)
                errG.backward()
                D_G_z2 = D_G_z2.mean().item()
                if config.optimizer == "extraadam":
                    if i % 2 == 0:
                        optimizerG.extrapolation()
                    else:
                        optimizerG.step()
                else:
                    optimizerG.step()

            if global_step % config.printFreq == 0:
                print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f '
                      'D(x): %.4f D(G(z)): %.4f / %.4f'
                      % (epoch, config.niter, i, len(dataloader),
                         errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
                summary_writer.add_scalar("loss/D", errD.item(), global_step)
                summary_writer.add_scalar("loss/G", errG.item(), global_step)
                summary_writer.add_scalar("output/D_real", D_x, global_step)
                summary_writer.add_scalar("output/D_fake", D_G_z1, global_step)

        # every epoch, log a grid of samples from the fixed noise
        with torch.no_grad():
            fake = netG(fixed_noise)
        # vutils.save_image(fake.detach(),
        #                   '%s/fake_samples_step-%06d.png' % (exp_dir, global_step),
        #                   normalize=True)
        fake_grid = vutils.make_grid(fake.detach(), normalize=True)
        summary_writer.add_image("G_samples", fake_grid, global_step)

        # generate samples for IS evaluation
        IS_fake = []
        with torch.no_grad():
            for _ in range(10):
                noise = torch.randn(500, config.nz, 1, 1, device=device)
                IS_fake.append(netG(noise))
        IS_fake = torch.cat(IS_fake)
        IS_mean, IS_std = mnist_inception_score(IS_fake, device)
        print("IS score: mean=%.4f, std=%.4f" % (IS_mean, IS_std))
        summary_writer.add_scalar("IS_mean", IS_mean, global_step)

        # do checkpointing
        checkpoint = {
            'state_gen': netG.state_dict(),
            'state_dis': netD.state_dict()
        }
        torch.save(checkpoint, '%s/checkpoint_step_%06d.pth' % (exp_dir, global_step))
        last_chkpt = '%s/checkpoint_step_%06d.pth' % (exp_dir, global_step)

        if epoch == 0:
            # keep the first-epoch checkpoint as the reference point for path statistics
            # last_chkpt = '%s/checkpoint_step_%06d.pth' % (exp_dir, 0)  # for now
            checkpoint_1 = torch.load(last_chkpt, map_location=device)
            if config.compute_eig:
                # compute eigenvalues for epoch 1, just in case
                gen_eigs_curr, dis_eigs_curr, game_eigs_curr = comp_and_save_eigs(global_step)

        # if (epoch + 1) % 10 == 0:
        if global_step > 30000 and epoch % 5 == 0:
            checkpoint_2 = torch.load(last_chkpt, map_location=device)
            print("Computing path statistics...")
            t = time.time()
            hist = compute_path_stats(evalG, evalD, checkpoint_1, checkpoint_2,
                                      eval_dataloader, config, model_loss_gen,
                                      model_loss_dis, device, verbose=True)
            with open("%s/hist_%d.pkl" % (plots_dir, global_step), 'wb') as f:
                pickle.dump(hist, f)
            plot_path_stats(hist, plots_dir, summary_writer, global_step)
            print("Took %.2f minutes" % ((time.time() - t) / 60.))

        if config.compute_eig and global_step > 30000 and epoch % 10 == 0:
            # compute eigenvalues and save them
            gen_eigs_curr, dis_eigs_curr, game_eigs_curr = comp_and_save_eigs(global_step)
            plot_eigenvalues([gen_eigs_init, gen_eigs_curr],
                             [dis_eigs_init, dis_eigs_curr],
                             [game_eigs_init, game_eigs_curr],
                             ['init', 'step_%d' % global_step],
                             plots_dir, summary_writer, step=global_step)
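
# ---------------------------------------------------------------------------
# Minimal entry-point sketch (an assumption, not part of the original file):
# it builds a `config` namespace exposing the attributes `main` reads above.
# Flag names mirror those attribute names; the defaults are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_dir', default='./experiments')
    parser.add_argument('--exp_name', default='mnist_gan')
    parser.add_argument('--model', default='dcgan', choices=['dcgan', 'wgan', 'wgan_gp'])
    parser.add_argument('--optimizer', default='adam',
                        choices=['adam', 'extraadam', 'rmsprop', 'sgd'])
    parser.add_argument('--batchSize', type=int, default=64)
    parser.add_argument('--niter', type=int, default=100)
    parser.add_argument('--nz', type=int, default=100)
    parser.add_argument('--nc', type=int, default=1)
    parser.add_argument('--ngf', type=int, default=64)
    parser.add_argument('--ndf', type=int, default=64)
    parser.add_argument('--ngpu', type=int, default=1)
    parser.add_argument('--dnorm', default=None)
    parser.add_argument('--lrD', type=float, default=2e-4)
    parser.add_argument('--lrG', type=float, default=None)
    parser.add_argument('--beta1', type=float, default=0.5)
    parser.add_argument('--beta2', type=float, default=0.999)
    parser.add_argument('--n_critic', type=int, default=5)
    parser.add_argument('--gp_lambda', type=float, default=10.0)
    parser.add_argument('--clip', type=float, default=0.01)
    parser.add_argument('--printFreq', type=int, default=100)
    parser.add_argument('--netG', default='')
    parser.add_argument('--netD', default='')
    parser.add_argument('--manualSeed', type=int, default=None)
    parser.add_argument('--compute_eig', action='store_true')
    main(parser.parse_args())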
# ---------------------------------------------------------------------------
# Fragment from a second training script (it follows `args` / OUTPUT_PATH /
# `writer` / `n_gen_update` conventions rather than `config` / exp_dir): it
# logs histograms of the generator outputs and their average, and optionally
# path statistics and eigenvalues between the first and current checkpoint.
# ---------------------------------------------------------------------------
fig = plt.figure()
plt.hist(x_gen.cpu().squeeze().data, bins=100)
writer.add_figure('hist', fig, n_gen_update)
plt.clf()
fig = plt.figure()
plt.hist(x_gen_avg.cpu().squeeze().data, bins=100)
writer.add_figure('hist_avg', fig, n_gen_update)
plt.clf()

if args.save_stats:
    if n_gen_update == 1:
        checkpoint_1 = torch.load(
            os.path.join(OUTPUT_PATH, 'checkpoints/%i.state' % n_gen_update),
            map_location=device)
    if n_gen_update > 1:
        checkpoint_2 = torch.load(
            os.path.join(OUTPUT_PATH, 'checkpoints/%i.state' % n_gen_update),
            map_location=device)
        hist = compute_path_stats(gen, dis, checkpoint_1, checkpoint_2,
                                  dataloader, args, model_loss_gen,
                                  model_loss_dis, device, verbose=True)
        gen_eigs2, dis_eigs2, game_eigs2 = compute_eigenvalues(
            gen, dis, dataloader, args, model_loss_gen, model_loss_dis,
            device, verbose=True, n_eigs=100)
        # gen_eigs1 / dis_eigs1 / game_eigs1 are assumed to have been computed
        # earlier, at n_gen_update == 1 (not shown in this fragment)
        hist.update({'gen_eigs': [gen_eigs1, gen_eigs2],
                     'dis_eigs': [dis_eigs1, dis_eigs2],
                     'game_eigs': [game_eigs1, game_eigs2]})
        if not os.path.exists(os.path.join(OUTPUT_PATH, "extra_plots/data")):
            os.makedirs(os.path.join(OUTPUT_PATH, "extra_plots/data"))
        if not os.path.exists(os.path.join(OUTPUT_PATH, "extra_plots/plots")):
            os.makedirs(os.path.join(OUTPUT_PATH, "extra_plots/plots"))
        with open(os.path.join(OUTPUT_PATH,
                               "extra_plots/data/%i.pkl" % n_gen_update), 'wb') as f:
            pickle.dump(hist, f)
        plot_path_stats(hist, os.path.join(OUTPUT_PATH, "extra_plots/plots"),
                        writer, n_gen_update)
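
# ---------------------------------------------------------------------------
# For reference, a minimal self-contained sketch of the idea behind the path
# statistics above: evaluate a loss along the straight line in parameter space
# between two checkpoints. This is an assumption about what a routine shaped
# like compute_path_stats measures, not its actual implementation; the helper
# name `linear_path_losses` and the `loss_fn` callable are hypothetical.
# ---------------------------------------------------------------------------
import copy

import torch


def linear_path_losses(model, state_1, state_2, loss_fn, n_points=20):
    """Losses of `model` at the interpolations (1 - a) * state_1 + a * state_2."""
    probe = copy.deepcopy(model)  # work on a copy; never touch the trained model
    losses = []
    for alpha in torch.linspace(0.0, 1.0, n_points):
        interp = {k: (1 - alpha) * state_1[k].float() + alpha * state_2[k].float()
                  for k in state_1}
        probe.load_state_dict(interp)
        with torch.no_grad():
            losses.append(loss_fn(probe).item())
    return losses

# Example usage (checkpoint keys as in the checkpoints saved above):
#   path = linear_path_losses(netG, checkpoint_1['state_gen'],
#                             checkpoint_2['state_gen'], my_gen_loss_fn)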