Example #1

import json
import os
import pickle
import random
import time

import numpy as np
import torch
import torch.optim as optim
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter  # or: from tensorboardX import SummaryWriter

# project-local helpers assumed by this script (exact module paths depend on
# the repo layout): Generator, Discriminator, load_mnist, define_model_loss,
# ExtraAdam, compute_eigenvalues, compute_path_stats, plot_path_stats,
# plot_eigenvalues, mnist_inception_score; a sketch of weights_init is given
# at the end of the file

def main(config):
    print("Hyper-params:")
    print(config)

    # create the experiment folder and save the config
    exp_dir = os.path.join(config.exp_dir, config.exp_name)
    os.makedirs(exp_dir, exist_ok=True)

    plots_dir = os.path.join(exp_dir, 'extra_plots')
    os.makedirs(plots_dir, exist_ok=True)

    if config.manualSeed is None:
        config.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", config.manualSeed)
    random.seed(config.manualSeed)
    torch.manual_seed(config.manualSeed)
    np.random.seed(config.manualSeed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config.manualSeed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device {0!s}".format(device))

    # three loaders: the full training stream, a 5,000-image subset for the
    # path statistics, and a single 1,000-image batch for eigenvalue estimation
    dataloader = load_mnist(config.batchSize)
    eval_dataloader = load_mnist(config.batchSize, subset=5000)
    eig_dataloader = load_mnist(1000, train=True, subset=1000)
    # fixed noise so the sample grids logged each epoch are comparable
    fixed_noise = torch.randn(64, config.nz, 1, 1, device=device)

    # define the model
    netG = Generator(config.ngpu, config.nc, config.ngf, config.nz).to(device)
    netG.apply(weights_init)
    if config.netG != '':
        print('loading generator from %s' % config.netG)
        netG.load_state_dict(torch.load(config.netG)['state_gen'])
    print(netG)

    # a sigmoid output head would only apply to the dcgan (BCE) loss; it is
    # disabled here for all models
    # sigmoid = config.model == 'dcgan'
    sigmoid = False

    netD = Discriminator(config.ngpu, config.nc, config.ndf, config.dnorm,
                         sigmoid).to(device)
    netD.apply(weights_init)
    if config.netD != '':
        print('loading discriminator from %s' % config.netD)
        netD.load_state_dict(torch.load(config.netD)['state_dis'])
    print(netD)

    # separate copies of G and D used to load checkpoints for offline
    # evaluation (eigenvalues, path statistics) without touching the
    # training networks
    evalG = Generator(config.ngpu, config.nc, config.ngf, config.nz).to(device)
    evalG.apply(weights_init)
    evalD = Discriminator(config.ngpu, config.nc, config.ndf, config.dnorm,
                          sigmoid).to(device)
    evalD.apply(weights_init)

    # define the model-specific discriminator and generator losses
    model_loss_dis, model_loss_gen = define_model_loss(config)

    # # defining learning rates based on the model
    # if config.model in ['wgan', 'wgan_gp']:
    #     config.lrG = config.lrD / config.n_critic
    #     warnings.warn('modifying learning rates to lrD=%f, lrG=%f' % (config.lrD, config.lrG))

    if config.lrG is None:
        config.lrG = config.lrD

    # setup optimizer
    if config.optimizer == 'adam':
        optimizerD = optim.Adam(netD.parameters(),
                                lr=config.lrD,
                                betas=(config.beta1, config.beta2))
        optimizerG = optim.Adam(netG.parameters(),
                                lr=config.lrG,
                                betas=(config.beta1, config.beta2))
    elif config.optimizer == 'extraadam':
        # extragradient variant of Adam; extrapolation and update steps
        # alternate inside the training loop below
        optimizerD = ExtraAdam(netD.parameters(), lr=config.lrD)
        optimizerG = ExtraAdam(netG.parameters(), lr=config.lrG)

    elif config.optimizer == 'rmsprop':
        optimizerD = optim.RMSprop(netD.parameters(), lr=config.lrD)
        optimizerG = optim.RMSprop(netG.parameters(), lr=config.lrG)

    elif config.optimizer == 'sgd':
        optimizerD = optim.SGD(netD.parameters(),
                               lr=config.lrD,
                               momentum=config.beta1)
        optimizerG = optim.SGD(netG.parameters(),
                               lr=config.lrG,
                               momentum=config.beta1)
    else:
        raise ValueError('Optimizer %s not supported' % config.optimizer)

    with open(os.path.join(exp_dir, 'config.json'), 'w') as f:
        json.dump(vars(config), f, indent=4)

    summary_writer = SummaryWriter(log_dir=exp_dir)

    global_step = 0
    # save the networks at initialization (step 0) so that the eigenvalues at
    # the starting point can be computed later
    torch.save({
        'state_gen': netG.state_dict(),
        'state_dis': netD.state_dict()
    }, '%s/checkpoint_step_%06d.pth' % (exp_dir, global_step))

    # helper: load a saved checkpoint into the evaluation networks, compute
    # the top eigenvalues, and save them to disk
    def comp_and_save_eigs(step, n_eigs=20):
        eig_checkpoint = torch.load('%s/checkpoint_step_%06d.pth' %
                                    (exp_dir, step),
                                    map_location=device)
        evalG.load_state_dict(eig_checkpoint['state_gen'])
        evalD.load_state_dict(eig_checkpoint['state_dis'])
        gen_eigs, dis_eigs, game_eigs = \
            compute_eigenvalues(evalG, evalD, eig_dataloader, config,
                                model_loss_gen, model_loss_dis,
                                device, verbose=True, n_eigs=n_eigs)
        np.savez(os.path.join(plots_dir, 'eigenvalues_%d' % step),
                 gen_eigs=gen_eigs,
                 dis_eigs=dis_eigs,
                 game_eigs=game_eigs)

        return gen_eigs, dis_eigs, game_eigs

    if config.compute_eig:
        # eigenvalues of initialization
        gen_eigs_init, dis_eigs_init, game_eigs_init = comp_and_save_eigs(0)

    for epoch in range(config.niter):
        for i, data in enumerate(dataloader, 0):
            global_step += 1
            ############################
            # (1) Update D network (for dcgan: maximize log(D(x)) +
            # log(1 - D(G(z))); for the WGAN variants, the critic loss
            # returned by model_loss_dis)
            ############################
            netD.zero_grad()
            x_real = data[0].to(device)
            batch_size = x_real.size(0)
            noise = torch.randn(batch_size, config.nz, 1, 1, device=device)
            x_fake = netG(noise)

            errD, D_x, D_G_z1 = model_loss_dis(x_real, x_fake.detach(), netD,
                                               device)

            # WGAN-GP gradient penalty on the critic
            if config.model == 'wgan_gp':
                errD += config.gp_lambda * netD.get_penalty(
                    x_real.detach(), x_fake.detach())

            errD.backward()
            D_x = D_x.mean().item()
            D_G_z1 = D_G_z1.mean().item()

            if config.optimizer == "extraadam":
                if i % 2 == 0:
                    optimizerD.extrapolation()
                else:
                    optimizerD.step()
            else:
                optimizerD.step()

            # WGAN weight clipping to keep the critic (approximately) Lipschitz
            if config.model == 'wgan':
                for p in netD.parameters():
                    p.data.clamp_(-config.clip, config.clip)

            ############################
            # (2) Update G network (for dcgan: maximize log(D(G(z)))); updated
            # every iteration for dcgan and once every n_critic critic steps
            # for the WGAN variants
            ############################

            if config.model == 'dcgan' or (config.model in ['wgan', 'wgan_gp']
                                           and i % config.n_critic == 0):
                netG.zero_grad()
                errG, D_G_z2 = model_loss_gen(x_fake, netD, device)
                errG.backward()
                D_G_z2 = D_G_z2.mean().item()

                if config.optimizer == "extraadam":
                    if i % 2 == 0:
                        optimizerG.extrapolation()
                    else:
                        optimizerG.step()
                else:
                    optimizerG.step()

            if global_step % config.printFreq == 0:
                print(
                    '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                    % (epoch, config.niter, i, len(dataloader), errD.item(),
                       errG.item(), D_x, D_G_z1, D_G_z2))
                summary_writer.add_scalar("loss/D", errD.item(), global_step)
                summary_writer.add_scalar("loss/G", errG.item(), global_step)
                summary_writer.add_scalar("output/D_real", D_x, global_step)
                summary_writer.add_scalar("output/D_fake", D_G_z1, global_step)

        # at the end of each epoch, log a grid of samples from the fixed noise
        with torch.no_grad():
            fake = netG(fixed_noise)
        fake_grid = vutils.make_grid(fake, normalize=True)
        summary_writer.add_image("G_samples", fake_grid, global_step)

        # generate 5,000 samples (10 batches of 500) for the Inception Score
        with torch.no_grad():
            IS_fake = []
            for _ in range(10):
                noise = torch.randn(500, config.nz, 1, 1, device=device)
                IS_fake.append(netG(noise))
            IS_fake = torch.cat(IS_fake)

        IS_mean, IS_std = mnist_inception_score(IS_fake, device)
        print("IS score: mean=%.4f, std=%.4f" % (IS_mean, IS_std))
        summary_writer.add_scalar("IS_mean", IS_mean, global_step)

        # checkpoint the current networks
        last_chkpt = '%s/checkpoint_step_%06d.pth' % (exp_dir, global_step)
        torch.save({
            'state_gen': netG.state_dict(),
            'state_dis': netD.state_dict()
        }, last_chkpt)

        if epoch == 0:
            # keep the first epoch's checkpoint as the reference point for the
            # path statistics computed below
            checkpoint_1 = torch.load(last_chkpt, map_location=device)

            if config.compute_eig:
                # also compute the eigenvalues after the first epoch
                gen_eigs_curr, dis_eigs_curr, game_eigs_curr = comp_and_save_eigs(
                    global_step)

        # once past 30,000 steps, compute path statistics every 5 epochs
        if global_step > 30000 and epoch % 5 == 0:
            checkpoint_2 = torch.load(last_chkpt, map_location=device)
            print("Computing path statistics...")
            t = time.time()

            hist = compute_path_stats(evalG,
                                      evalD,
                                      checkpoint_1,
                                      checkpoint_2,
                                      eval_dataloader,
                                      config,
                                      model_loss_gen,
                                      model_loss_dis,
                                      device,
                                      verbose=True)

            with open("%s/hist_%d.pkl" % (plots_dir, global_step), 'wb') as f:
                pickle.dump(hist, f)

            plot_path_stats(hist, plots_dir, summary_writer, global_step)

            print("Took %.2f minutes" % ((time.time() - t) / 60.))

        # once past 30,000 steps, recompute the eigenvalues every 10 epochs and
        # compare them against the values at initialization
        if config.compute_eig and global_step > 30000 and epoch % 10 == 0:
            gen_eigs_curr, dis_eigs_curr, game_eigs_curr = comp_and_save_eigs(
                global_step)

            plot_eigenvalues([gen_eigs_init, gen_eigs_curr],
                             [dis_eigs_init, dis_eigs_curr],
                             [game_eigs_init, game_eigs_curr],
                             ['init', 'step_%d' % global_step],
                             plots_dir,
                             summary_writer,
                             step=global_step)
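

# A sketch of the DCGAN-style weights_init used above; this follows the
# standard DCGAN initialization and is an assumption, since the project's
# own helper is defined elsewhere:
def weights_init(m):
    # conv layers: weights ~ N(0, 0.02); batch-norm layers: weights ~ N(1, 0.02)
    # with zero bias, as in the DCGAN paper
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)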
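

# A minimal invocation sketch. The attribute names below are inferred from how
# `config` is used in main() above; the defaults are illustrative assumptions,
# not the project's actual CLI.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # experiment bookkeeping
    parser.add_argument('--exp_dir', default='experiments')
    parser.add_argument('--exp_name', default='dcgan_mnist')
    parser.add_argument('--manualSeed', type=int, default=None)
    # data and model sizes
    parser.add_argument('--batchSize', type=int, default=64)
    parser.add_argument('--nz', type=int, default=100)   # latent dimension
    parser.add_argument('--nc', type=int, default=1)     # image channels
    parser.add_argument('--ngf', type=int, default=64)   # generator width
    parser.add_argument('--ndf', type=int, default=64)   # discriminator width
    parser.add_argument('--ngpu', type=int, default=1)
    parser.add_argument('--dnorm', default='batchnorm')  # assumed flag name/value
    # checkpoints to resume from ('' trains from scratch)
    parser.add_argument('--netG', default='')
    parser.add_argument('--netD', default='')
    # loss and optimization
    parser.add_argument('--model', choices=['dcgan', 'wgan', 'wgan_gp'],
                        default='dcgan')
    parser.add_argument('--optimizer',
                        choices=['adam', 'extraadam', 'rmsprop', 'sgd'],
                        default='adam')
    parser.add_argument('--lrD', type=float, default=2e-4)
    parser.add_argument('--lrG', type=float, default=None)  # falls back to lrD
    parser.add_argument('--beta1', type=float, default=0.5)
    parser.add_argument('--beta2', type=float, default=0.999)
    parser.add_argument('--n_critic', type=int, default=5)
    parser.add_argument('--gp_lambda', type=float, default=10.0)
    parser.add_argument('--clip', type=float, default=0.01)
    # training schedule and diagnostics
    parser.add_argument('--niter', type=int, default=25)
    parser.add_argument('--printFreq', type=int, default=100)
    parser.add_argument('--compute_eig', action='store_true')
    main(parser.parse_args())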