def comp_and_save_eigs(step, n_eigs=20):
    """Compute the eigenvalues for the checkpoint saved at ``step`` and store them.

    The checkpoint ``checkpoint_step_<step>.pth`` is read from ``exp_dir``
    and its ``'state_gen'`` / ``'state_dis'`` entries are loaded into the
    module-level ``evalG`` / ``evalD`` models. ``compute_eigenvalues`` is then
    run on ``eig_dataloader`` and the three resulting arrays are written with
    ``np.savez`` to ``eigenvalues_<step>`` inside ``plots_dir``.

    Args:
        step: Training step whose checkpoint should be analysed.
        n_eigs: Forwarded to ``compute_eigenvalues`` as ``n_eigs``.

    Returns:
        The tuple ``(gen_eigs, dis_eigs, game_eigs)`` as returned by
        ``compute_eigenvalues``.
    """
    # Restore both evaluation networks from the requested checkpoint.
    checkpoint_path = '%s/checkpoint_step_%06d.pth' % (exp_dir, step)
    saved_state = torch.load(checkpoint_path, map_location=device)
    evalG.load_state_dict(saved_state['state_gen'])
    evalD.load_state_dict(saved_state['state_dis'])

    spectra = compute_eigenvalues(
        evalG, evalD, eig_dataloader, config,
        model_loss_gen, model_loss_dis,
        device, verbose=True, n_eigs=n_eigs,
    )
    gen_eigs, dis_eigs, game_eigs = spectra

    # Persist the spectra next to the plots, keyed by the training step.
    output_file = os.path.join(plots_dir, 'eigenvalues_%d' % step)
    np.savez(
        output_file,
        gen_eigs=gen_eigs,
        dis_eigs=dis_eigs,
        game_eigs=game_eigs,
    )

    return gen_eigs, dis_eigs, game_eigs
# SIGMA repeated into a (2, DIM_DATA) array, handed to GaussianMixture below.
sigmas = np.repeat(np.array([SIGMA, SIGMA]).reshape(2, 1), DIM_DATA, axis=1)

# Draw the full pool of real samples and latent codes once, up front.
gm = GaussianMixture(prob, mus, sigmas)
x_examples = gm.sample(NUM_SAMPLES)
z_examples = torch.zeros(NUM_SAMPLES, DIM_LATENT).normal_()
# z_examples is already a tensor: copy it with clone().detach() instead of
# torch.tensor(), which emits a UserWarning when given a tensor.
dataset = TensorDataset(torch.tensor(x_examples), z_examples.clone().detach())
if DETERMINISTIC:
    # Keep the exact samples on disk and feed all of them as a single batch.
    np.savez(os.path.join(OUTPUT_PATH, 'data.npz'), x=x_examples, z=z_examples)
    dataloader = DataLoader(dataset, batch_size=NUM_SAMPLES)
else:
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE)

# Counters for the training loop that follows (i is the iteration index;
# the other two presumably count discriminator / generator updates).
i = 0
n_dis_update = 0
n_gen_update = 0
# Eigenvalues of the untrained models, computed before the first update.
gen_eigs1, dis_eigs1, game_eigs1 = compute_eigenvalues(
    gen, dis, dataloader, args, model_loss_gen, model_loss_dis, device,
    verbose=True)
for epoch in range(NUM_ITER):
    for x, z in dataloader:
        update_gen = False
        if CUDA:
            z = z.cuda()
            x = x.cuda()

        x_gen = gen(z)
        loss_dis, D_x, D_G_z1 = model_loss_dis(x, x_gen, dis, device)
        loss_gen, D_G_z2 = model_loss_gen(x_gen, dis, device)
        if MODEL == 'wgan_gp':
            penalty = dis.get_penalty(x, x_gen, mode="data").mean()
            loss_dis += GRADIENT_PENALTY*penalty

        if i%2 == 0 or METHOD == "sim":