Example #1

import os

import numpy as np
import tensorflow as tf

# Project-specific helpers; the module paths below are assumptions about
# the layout of the surrounding Bayesian GAN code base.
from bgan_models import BGAN
from bgan_util import get_session, gmm_ms, analyze_div, FigPrinter

def bgan_synth(synth_dataset,
               z_dim,
               batch_size=64,
               numz=5,
               num_iter=1000,
               wasserstein=False,
               rpath="synth_results",
               gen_observed=10,
               base_learning_rate=1e-2,
               lr_decay=3.,
               save_weights=False):

    # J = numz samples of the generator weights; with numz == 1 the model
    # falls back to standard maximum-likelihood (non-Bayesian) GAN training.
    bgan = BGAN(
        [synth_dataset.x_dim],
        z_dim,
        synth_dataset.N,
        batch_size=batch_size,
        prior_std=10.0,
        alpha=1e-3,
        J=numz,
        M=1,
        ml=(numz == 1),
        num_classes=1,
        wasserstein=wasserstein,  # unsupervised only
        gen_observed=gen_observed)
    print "Starting session"
    session = get_session()
    tf.global_variables_initializer().run()

    print "Starting training loop"

    num_train_iter = num_iter

    sample_z = np.random.uniform(-1, 1, size=(batch_size, z_dim))

    all_aics_fake, all_data_fake, all_dists = [], [], []

    for train_iter in range(num_train_iter):

        # Exponentially decay the learning rate, saturating after roughly
        # one epoch's worth of minibatches.
        learning_rate = base_learning_rate * np.exp(-lr_decay * min(
            1.0, (train_iter * batch_size) / float(synth_dataset.N)))
        print(learning_rate)

        batch_z = np.random.uniform(-1, 1, [batch_size, z_dim])

        input_batch = synth_dataset.next_batch(batch_size)
        _, d_loss = session.run(
            [bgan.d_optim, bgan.d_loss],
            feed_dict={
                bgan.inputs: input_batch,
                bgan.z: batch_z,
                bgan.d_learning_rate: learning_rate
            })
        if wasserstein:
            session.run(bgan.clip_d, feed_dict={})

        g_losses = []
        for gi in range(bgan.num_gen):

            # compute g_sample loss
            batch_z = np.random.uniform(-1, 1, [batch_size, z_dim])
            _, g_loss = session.run(
                [bgan.g_optims[gi], bgan.generation["g_losses"][gi]],
                feed_dict={
                    bgan.z: batch_z,
                    bgan.g_learning_rate: learning_rate
                })
            g_losses.append(g_loss)

        print "Disc loss = %.2f, Gen loss = %s" % (d_loss, ", ".join(
            ["%.2f" % gl for gl in g_losses]))

        if (train_iter + 1) % 100 == 0:
            print "Disc loss = %.2f, Gen loss = %s" % (d_loss, ", ".join(
                ["%.2f" % gl for gl in g_losses]))
            print "Running GMM on sampled data"
            fake_data = []
            for num_samples in range(10):
                for gi in range(bgan.num_gen):
                    # collect sample
                    sample_z = np.random.uniform(-1,
                                                 1,
                                                 size=(batch_size, z_dim))
                    sampled_data = session.run(
                        bgan.generation["gen_samplers"][gi],
                        feed_dict={bgan.z: sample_z})
                    fake_data.append(sampled_data)

            X_real = synth_dataset.X
            X_sample = np.concatenate(fake_data)

            all_data_fake.append(X_sample)

            aics_fake = gmm_ms(X_sample)
            print "Fake number of clusters (AIC estimate):", aics_fake.argmin()
            dist, X_trans_real, X_trans_fake = analyze_div(X_real, X_sample)
            print "JS div:", dist
            fp = FigPrinter((1, 2))
            xmin1 = np.min(X_trans_real[:, 0]) - 1.0
            xmax1 = np.max(X_trans_real[:, 0]) + 1.0
            xmin2 = np.min(X_trans_real[:, 1]) - 1.0
            xmax2 = np.max(X_trans_real[:, 1]) + 1.0
            fp.ax_arr[0].plot(X_trans_real[:, 0], X_trans_real[:, 1], '.r')
            fp.ax_arr[0].set_xlim([xmin1, xmax1])
            fp.ax_arr[0].set_ylim([xmin2, xmax2])
            fp.ax_arr[1].plot(X_trans_fake[:, 0], X_trans_fake[:, 1], '.g')
            fp.ax_arr[1].set_xlim([xmin1, xmax1])
            fp.ax_arr[1].set_ylim([xmin2, xmax2])
            fp.ax_arr[0].set_aspect('equal', adjustable='box')
            fp.ax_arr[1].set_aspect('equal', adjustable='box')
            fp.ax_arr[1].set_title("Iter %i" % (train_iter + 1))
            fp.print_to_file(
                os.path.join(
                    rpath,
                    "pca_distribution_%i_%i.png" % (numz, train_iter + 1)))

            all_aics_fake.append(aics_fake)
            all_dists.append(dist)

            if save_weights:
                var_dict = {}
                for var in tf.trainable_variables():
                    var_dict[var.name] = session.run(var.name)

                np.savez_compressed(
                    os.path.join(rpath, "weights_%i.npz" % train_iter),
                    **var_dict)

    return {
        "data_fake": all_data_fake,
        "data_real": synth_dataset.X,
        "z_dim": z_dim,
        "numz": numz,
        "num_iter": num_iter,
        "divergences": all_dists,
        "all_aics_fake": np.array(all_aics_fake)
    }
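
A minimal driver sketch for the function above. The SynthDataset import and its constructor arguments are assumptions about the surrounding project; the function itself only requires an object exposing x_dim, N, X, and next_batch(batch_size):

if __name__ == "__main__":
    from bgan_util import SynthDataset  # assumed helper location

    os.makedirs("synth_results", exist_ok=True)  # rpath must exist for the PNG output
    synth_dataset = SynthDataset(x_dim=100)      # hypothetical constructor
    results = bgan_synth(synth_dataset, z_dim=10, numz=5, num_iter=1000)
    print("Final JS divergence:", results["divergences"][-1])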
Example #2

import os

import numpy as np
import tensorflow as tf

# Project-specific helpers; the module paths below are assumptions about
# the layout of the surrounding Bayesian GAN code base.
from bgan_models import BGAN
from bgan_util import get_session, analyze_div, FigPrinter

def bgan_synth(synth_dataset,
               z_dim,
               batch_size=64,
               numz=5,
               num_iter=1000,
               wasserstein=False,
               rpath="synth_results",
               base_learning_rate=1e-2,
               lr_decay=3.,
               save_weights=False):

    bgan = BGAN(
        [synth_dataset.x_dim],
        z_dim,
        synth_dataset.N,
        batch_size=batch_size,
        prior_std=10.0,
        alpha=1e-3,
        J=numz,
        M=1,
        ml=(numz == 1),
        num_classes=1,
        wasserstein=wasserstein,  # unsupervised only
    )

    print("Starting session")
    session = get_session()

    tf.global_variables_initializer().run()

    print("Starting training loop")

    num_train_iter = num_iter

    all_aics_fake, all_data_fake, all_dists = [], [], []
    it_loss, it_acc, acc = [], [], []

    # mean/cov parameterize the alternative multivariate-normal prior over z,
    # used only by the commented-out batch_z / sample_z lines below.
    mean = np.random.uniform(-1, 1, size=z_dim)
    cov = np.random.uniform(-1, 1, size=(z_dim, z_dim))
    cov = np.dot(cov, cov.transpose())  # make the covariance symmetric PSD

    for train_iter in range(num_train_iter):

        learning_rate = base_learning_rate * np.exp(-lr_decay * min(
            1.0, (train_iter * batch_size) / float(synth_dataset.N)))
        print(learning_rate)

        batch_z = np.random.uniform(-1, 1, [batch_size, z_dim])
        #batch_z = np.random.multivariate_normal(mean, cov, size=batch_size)
        #batch_z = np.random.normal(0, 1.0, size=(batch_size, z_dim))

        _, input_batch = synth_dataset.next_batch(batch_size)
        _, d_loss = session.run(
            [bgan.d_optim, bgan.d_loss],
            feed_dict={
                bgan.inputs: input_batch,
                bgan.z: batch_z,
                bgan.d_learning_rate: learning_rate
            })
        if wasserstein:
            session.run(bgan.clip_d, feed_dict={})

        g_losses = []
        for gi in range(bgan.num_gen):

            # compute g_sample loss
            batch_z = np.random.uniform(-1, 1, [batch_size, z_dim])
            #batch_z = np.random.multivariate_normal(mean, cov, size=batch_size)
            #batch_z = np.random.normal(0, 1.0, size=(batch_size, z_dim))

            _, g_loss, prior = session.run(
                [bgan.g_optims[gi], bgan.generation["g_losses"][gi],
                 bgan.generation["g_prior"][gi]],
                feed_dict={
                    bgan.z: batch_z,
                    bgan.g_learning_rate: learning_rate
                })
            print(prior, gi)
            g_losses.append(g_loss)

        print("Disc loss = %.2f, Gen loss = %s" %
              (d_loss, ", ".join(["%.2f" % gl for gl in g_losses])))

        if (train_iter + 1) % 100 == 0:
            print("Disc loss = %.2f, Gen loss = %s" %
                  (d_loss, ", ".join(["%.2f" % gl for gl in g_losses])))
            print("Running GMM on sampled data")
            fake_data = []
            for num_samples in range(10):
                for gi in range(bgan.num_gen):
                    # collect sample
                    sample_z = np.random.uniform(-1,
                                                 1,
                                                 size=(batch_size, z_dim))
                    #sample_z = np.random.multivariate_normal(mean, cov, size=batch_size)
                    #sample_z = np.random.normal(0, 1.0, size=(batch_size, z_dim))

                    sampled_data = session.run(
                        bgan.generation["gen_samplers"][gi],
                        feed_dict={bgan.z: sample_z})
                    fake_data.append(sampled_data)

            X_real = synth_dataset.X
            X_sample = np.concatenate(fake_data)

            all_data_fake.append(X_sample)
            """aics_fake = gmm_ms(X_sample)
            print ("Fake number of clusters (AIC estimate):", aics_fake.argmin())"""
            dist, X_trans_real, X_trans_fake = analyze_div(X_real, X_sample)
            print("JS div:", dist)
            """fp = FigPrinter((1,2))
            xmin1 = np.min(X_trans_real[:, 0]) - 1.0
            xmax1 = np.max(X_trans_real[:, 0]) + 1.0
            xmin2 = np.min(X_trans_real[:, 1]) - 1.0
            xmax2 = np.max(X_trans_real[:, 1]) + 1.0
            fp.ax_arr[0].plot(X_trans_real[:, 0], X_trans_real[:, 1], '.r')
            fp.ax_arr[0].set_xlim([xmin1, xmax1]); fp.ax_arr[0].set_ylim([xmin2, xmax2])
            fp.ax_arr[1].plot(X_trans_fake[:, 0], X_trans_fake[:, 1], '.g')
            fp.ax_arr[1].set_xlim([xmin1, xmax1]); fp.ax_arr[1].set_ylim([xmin2, xmax2])
            fp.ax_arr[0].set_aspect('equal', adjustable='box')
            fp.ax_arr[1].set_aspect('equal', adjustable='box')
            fp.ax_arr[1].set_title("Iter %i" % (train_iter+1))            
            fp.print_to_file(os.path.join(rpath, "pca_distribution_%i_%i.png" % (numz, train_iter+1)))"""

            all_dists.append(dist)
            it_loss.append(train_iter + 1)

            if save_weights:
                var_dict = {}
                for var in tf.trainable_variables():
                    var_dict[var.name] = session.run(var.name)

                np.savez_compressed(
                    os.path.join(rpath, "weights_%i.npz" % train_iter),
                    **var_dict)

        # Every 200 iterations, estimate how often the discriminator labels
        # held-out real data as real (assumes synth_dataset exposes X_test
        # and N_test).
        if (train_iter + 1) % 200 == 0:
            count_pos = 0
            count = 0
            for i in range(int(synth_dataset.N_test / batch_size)):
                d_logit = session.run(
                    [bgan.D],
                    feed_dict={
                        bgan.inputs:
                        synth_dataset.X_test[i * batch_size:(i + 1) *
                                             batch_size]
                    })
                d_logit = d_logit[0]
                count_pos += np.sum(np.argmax(d_logit, 1))
                count += batch_size

            acc.append(100.0 * count_pos / count)
            it_acc.append(train_iter + 1)

    fp = FigPrinter((1, 2))
    fp.ax_arr[0].plot(it_loss, all_dists)
    fp.ax_arr[0].set_title("JS Divergence")
    fp.ax_arr[1].plot(it_acc, acc)
    fp.ax_arr[1].set_title("Test Accuracy")
    fp.print_to_file(os.path.join(rpath, "divergence_disc.png"))

    return {
        "data_fake": all_data_fake,
        "data_real": synth_dataset.X,
        "z_dim": z_dim,
        "numz": numz,
        "num_iter": num_iter,
        "divergences": all_dists,
        "it_loss": it_loss,
        "acc": acc,
        "it_acc": it_acc
    }
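
As with Example #1, a minimal driver sketch. Here next_batch must return a (labels, batch) pair, and the dataset must additionally expose X_test and N_test for the discriminator accuracy probe; the SynthDataset import and its constructor are again assumptions about the surrounding project:

if __name__ == "__main__":
    from bgan_util import SynthDataset  # assumed helper location

    os.makedirs("synth_results", exist_ok=True)  # rpath must exist for divergence_disc.png
    synth_dataset = SynthDataset(x_dim=100)      # hypothetical constructor
    results = bgan_synth(synth_dataset, z_dim=10, numz=5, num_iter=1000)
    print("JS divergence trace:", results["divergences"])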