import os

import numpy as np
import tensorflow as tf

# Project helpers (module paths assumed; adjust to the repo's actual layout):
# BGAN builds the Bayesian GAN graph, get_session returns a shared TF session,
# gmm_ms fits GMMs and returns their AICs, analyze_div computes a JS-divergence
# estimate plus 2D PCA projections, and FigPrinter wraps matplotlib figure output.
from bgan_models import BGAN
from bgan_util import get_session, gmm_ms, analyze_div, FigPrinter


def bgan_synth(synth_dataset, z_dim, batch_size=64, numz=5, num_iter=1000,
               wasserstein=False, rpath="synth_results", gen_observed=10,
               base_learning_rate=1e-2, lr_decay=3., save_weights=False):

    bgan = BGAN([synth_dataset.x_dim], z_dim,
                synth_dataset.N, batch_size=batch_size,
                prior_std=10.0, alpha=1e-3,
                J=numz, M=1, ml=(numz == 1), num_classes=1,
                wasserstein=wasserstein,  # unsupervised only
                gen_observed=gen_observed)

    print("Starting session")
    session = get_session()
    tf.global_variables_initializer().run()

    print("Starting training loop")
    num_train_iter = num_iter

    sample_z = np.random.uniform(-1, 1, size=(batch_size, z_dim))

    all_aics_fake, all_data_fake, all_dists = [], [], []

    for train_iter in range(num_train_iter):

        # Exponentially decay the learning rate over roughly one pass through the data.
        learning_rate = base_learning_rate * np.exp(
            -lr_decay * min(1.0, (train_iter * batch_size) / float(synth_dataset.N)))
        print(learning_rate)

        # Discriminator update on a real batch and fresh noise.
        batch_z = np.random.uniform(-1, 1, [batch_size, z_dim])
        input_batch = synth_dataset.next_batch(batch_size)
        _, d_loss = session.run([bgan.d_optim, bgan.d_loss],
                                feed_dict={bgan.inputs: input_batch,
                                           bgan.z: batch_z,
                                           bgan.d_learning_rate: learning_rate})
        if wasserstein:
            session.run(bgan.clip_d, feed_dict={})

        # One update per generator sample (J generators in the Bayesian GAN posterior).
        g_losses = []
        for gi in range(bgan.num_gen):
            # compute g_sample loss
            batch_z = np.random.uniform(-1, 1, [batch_size, z_dim])
            _, g_loss = session.run([bgan.g_optims[gi], bgan.generation["g_losses"][gi]],
                                    feed_dict={bgan.z: batch_z,
                                               bgan.g_learning_rate: learning_rate})
            g_losses.append(g_loss)

        print("Disc loss = %.2f, Gen loss = %s" % (d_loss,
                                                   ", ".join(["%.2f" % gl for gl in g_losses])))

        if (train_iter + 1) % 100 == 0:
            print("Disc loss = %.2f, Gen loss = %s" % (d_loss,
                                                       ", ".join(["%.2f" % gl for gl in g_losses])))
            print("Running GMM on sampled data")

            # Draw samples from every generator and pool them.
            fake_data = []
            for num_samples in range(10):
                for gi in range(bgan.num_gen):
                    # collect sample
                    sample_z = np.random.uniform(-1, 1, size=(batch_size, z_dim))
                    sampled_data = session.run(bgan.generation["gen_samplers"][gi],
                                               feed_dict={bgan.z: sample_z})
                    fake_data.append(sampled_data)

            X_real = synth_dataset.X
            X_sample = np.concatenate(fake_data)
            all_data_fake.append(X_sample)

            # Estimate the number of modes via GMM AICs and the real/fake JS divergence.
            aics_fake = gmm_ms(X_sample)
            print("Fake number of clusters (AIC estimate):", aics_fake.argmin())
            dist, X_trans_real, X_trans_fake = analyze_div(X_real, X_sample)
            print("JS div:", dist)

            # Side-by-side PCA scatter of real (red) and generated (green) data.
            fp = FigPrinter((1, 2))
            xmin1 = np.min(X_trans_real[:, 0]) - 1.0
            xmax1 = np.max(X_trans_real[:, 0]) + 1.0
            xmin2 = np.min(X_trans_real[:, 1]) - 1.0
            xmax2 = np.max(X_trans_real[:, 1]) + 1.0
            fp.ax_arr[0].plot(X_trans_real[:, 0], X_trans_real[:, 1], '.r')
            fp.ax_arr[0].set_xlim([xmin1, xmax1])
            fp.ax_arr[0].set_ylim([xmin2, xmax2])
            fp.ax_arr[1].plot(X_trans_fake[:, 0], X_trans_fake[:, 1], '.g')
            fp.ax_arr[1].set_xlim([xmin1, xmax1])
            fp.ax_arr[1].set_ylim([xmin2, xmax2])
            fp.ax_arr[0].set_aspect('equal', adjustable='box')
            fp.ax_arr[1].set_aspect('equal', adjustable='box')
            fp.ax_arr[1].set_title("Iter %i" % (train_iter + 1))
            fp.print_to_file(os.path.join(rpath,
                                          "pca_distribution_%i_%i.png" % (numz, train_iter + 1)))

            all_aics_fake.append(aics_fake)
            all_dists.append(dist)

        if save_weights:
            var_dict = {}
            for var in tf.trainable_variables():
                var_dict[var.name] = session.run(var.name)
            np.savez_compressed(os.path.join(rpath, "weights_%i.npz" % train_iter),
                                **var_dict)

    return {"data_fake": all_data_fake,
            "data_real": synth_dataset.X,
            "z_dim": z_dim,
            "numz": numz,
            "num_iter": num_iter,
            "divergences": all_dists,
            "all_aics_fake": np.array(all_aics_fake)}
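
# The function above estimates the number of modes in the generated data with
# gmm_ms(), a project helper that returns per-model AIC scores. The sketch below
# is illustrative only: it shows the kind of AIC sweep that call implies, assuming
# scikit-learn is available, and is NOT the project's actual implementation (the
# name, component range, and covariance settings here are assumptions).
def gmm_ms_sketch(X, max_components=10):
    """Fit GMMs with 1..max_components components and return their AIC scores."""
    from sklearn.mixture import GaussianMixture
    aics = []
    for k in range(1, max_components + 1):
        gmm = GaussianMixture(n_components=k, covariance_type="full", max_iter=200)
        gmm.fit(X)
        aics.append(gmm.aic(X))
    # The caller takes argmin over these AICs to pick the preferred model.
    return np.array(aics)
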
# Variant of bgan_synth above: the GMM/AIC and PCA-scatter diagnostics are kept but
# disabled, and the run instead records the JS divergence over time together with the
# discriminator's accuracy on a held-out test set. If both definitions live in the
# same module, this one shadows the first.
def bgan_synth(synth_dataset, z_dim, batch_size=64, numz=5, num_iter=1000,
               wasserstein=False, rpath="synth_results",
               base_learning_rate=1e-2, lr_decay=3., save_weights=False):

    bgan = BGAN([synth_dataset.x_dim], z_dim,
                synth_dataset.N, batch_size=batch_size,
                prior_std=10.0, alpha=1e-3,
                J=numz, M=1, ml=(numz == 1), num_classes=1,
                wasserstein=wasserstein)  # unsupervised only

    print("Starting session")
    session = get_session()
    tf.global_variables_initializer().run()

    print("Starting training loop")
    num_train_iter = num_iter

    all_aics_fake, all_data_fake, all_dists = [], [], []
    it_loss, it_acc, acc = [], [], []

    # mean/cov are only used by the commented-out Gaussian alternatives for sampling z.
    mean = np.random.uniform(-1, 1, size=z_dim)
    cov = np.random.uniform(-1, 1, size=(z_dim, z_dim))
    cov = np.dot(cov, cov.transpose())

    for train_iter in range(num_train_iter):

        # Exponentially decay the learning rate over roughly one pass through the data.
        learning_rate = base_learning_rate * np.exp(
            -lr_decay * min(1.0, (train_iter * batch_size) / float(synth_dataset.N)))
        print(learning_rate)

        batch_z = np.random.uniform(-1, 1, [batch_size, z_dim])
        #batch_z = np.random.multivariate_normal(mean, cov, size=batch_size)
        #batch_z = np.random.normal(0, 1.0, size=(batch_size, z_dim))

        # Discriminator update on a real batch and fresh noise.
        _, input_batch = synth_dataset.next_batch(batch_size)
        _, d_loss = session.run([bgan.d_optim, bgan.d_loss],
                                feed_dict={bgan.inputs: input_batch,
                                           bgan.z: batch_z,
                                           bgan.d_learning_rate: learning_rate})
        if wasserstein:
            session.run(bgan.clip_d, feed_dict={})

        g_losses = []
        for gi in range(bgan.num_gen):
            # compute g_sample loss
            batch_z = np.random.uniform(-1, 1, [batch_size, z_dim])
            #batch_z = np.random.multivariate_normal(mean, cov, size=batch_size)
            #batch_z = np.random.normal(0, 1.0, size=(batch_size, z_dim))
            _, g_loss, prior = session.run([bgan.g_optims[gi],
                                            bgan.generation["g_losses"][gi],
                                            bgan.generation["g_prior"][gi]],
                                           feed_dict={bgan.z: batch_z,
                                                      bgan.g_learning_rate: learning_rate})
            print(prior, gi)
            g_losses.append(g_loss)

        print("Disc loss = %.2f, Gen loss = %s" % (d_loss,
                                                   ", ".join(["%.2f" % gl for gl in g_losses])))

        if (train_iter + 1) % 100 == 0:
            print("Disc loss = %.2f, Gen loss = %s" % (d_loss,
                                                       ", ".join(["%.2f" % gl for gl in g_losses])))
            print("Running GMM on sampled data")

            # Draw samples from every generator and pool them.
            fake_data = []
            for num_samples in range(10):
                for gi in range(bgan.num_gen):
                    # collect sample
                    sample_z = np.random.uniform(-1, 1, size=(batch_size, z_dim))
                    #sample_z = np.random.multivariate_normal(mean, cov, size=batch_size)
                    #sample_z = np.random.normal(0, 1.0, size=(batch_size, z_dim))
                    sampled_data = session.run(bgan.generation["gen_samplers"][gi],
                                               feed_dict={bgan.z: sample_z})
                    fake_data.append(sampled_data)

            X_real = synth_dataset.X
            X_sample = np.concatenate(fake_data)
            all_data_fake.append(X_sample)

            # GMM/AIC mode counting disabled in this variant.
            """aics_fake = gmm_ms(X_sample)
            print("Fake number of clusters (AIC estimate):", aics_fake.argmin())"""

            dist, X_trans_real, X_trans_fake = analyze_div(X_real, X_sample)
            print("JS div:", dist)

            # PCA scatter plots disabled in this variant.
            """fp = FigPrinter((1, 2))
            xmin1 = np.min(X_trans_real[:, 0]) - 1.0
            xmax1 = np.max(X_trans_real[:, 0]) + 1.0
            xmin2 = np.min(X_trans_real[:, 1]) - 1.0
            xmax2 = np.max(X_trans_real[:, 1]) + 1.0
            fp.ax_arr[0].plot(X_trans_real[:, 0], X_trans_real[:, 1], '.r')
            fp.ax_arr[0].set_xlim([xmin1, xmax1]); fp.ax_arr[0].set_ylim([xmin2, xmax2])
            fp.ax_arr[1].plot(X_trans_fake[:, 0], X_trans_fake[:, 1], '.g')
            fp.ax_arr[1].set_xlim([xmin1, xmax1]); fp.ax_arr[1].set_ylim([xmin2, xmax2])
            fp.ax_arr[0].set_aspect('equal', adjustable='box')
            fp.ax_arr[1].set_aspect('equal', adjustable='box')
            fp.ax_arr[1].set_title("Iter %i" % (train_iter + 1))
            fp.print_to_file(os.path.join(rpath, "pca_distribution_%i_%i.png" % (numz, train_iter + 1)))"""

            all_dists.append(dist)
            it_loss.append(train_iter + 1)

        if save_weights:
            var_dict = {}
            for var in tf.trainable_variables():
                var_dict[var.name] = session.run(var.name)
            np.savez_compressed(os.path.join(rpath, "weights_%i.npz" % train_iter),
                                **var_dict)

        if (train_iter + 1) % 200 == 0:
            # Discriminator accuracy on the held-out test set: count test points whose
            # argmax logit is class 1.
            count_pos = 0
            count = 0
            for i in range(int(synth_dataset.N_test / batch_size)):
                d_logit = session.run(
                    [bgan.D],
                    feed_dict={bgan.inputs:
                               synth_dataset.X_test[i * batch_size:(i + 1) * batch_size]})
                d_logit = d_logit[0]
                count_pos += np.sum(np.argmax(d_logit, 1))
                count += batch_size
            acc.append(100.0 * count_pos / count)
            it_acc.append(train_iter + 1)

            # Plot JS divergence and discriminator test accuracy over training so far.
            fp = FigPrinter((1, 2))
            fp.ax_arr[0].plot(it_loss, all_dists)
            fp.ax_arr[0].set_title("JS Divergence")
            fp.ax_arr[1].plot(it_acc, acc)
            fp.ax_arr[1].set_title("Test Accuracy")
            fp.print_to_file(os.path.join(rpath, "divergence_disc.png"))

    return {"data_fake": all_data_fake,
            "data_real": synth_dataset.X,
            "z_dim": z_dim,
            "numz": numz,
            "num_iter": num_iter,
            "divergences": all_dists,
            "it_loss": it_loss,
            "acc": acc,
            "it_acc": it_acc}
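
# Minimal usage sketch, not the project's actual entry point. Assumptions: a
# SynthDataset-style helper exposing x_dim, N, N_test, X, X_test and next_batch(),
# importable from the project's utility module; its constructor signature here is
# illustrative. The call goes to the second bgan_synth definition above.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Synthetic Bayesian GAN experiment")
    parser.add_argument("--x_dim", type=int, default=100)
    parser.add_argument("--z_dim", type=int, default=10)
    parser.add_argument("--numz", type=int, default=5)
    parser.add_argument("--num_iter", type=int, default=1000)
    parser.add_argument("--wasserstein", action="store_true")
    parser.add_argument("--out_dir", default="synth_results")
    args = parser.parse_args()

    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)

    from bgan_util import SynthDataset  # assumed location of the dataset helper
    synth_dataset = SynthDataset(args.x_dim)  # illustrative constructor

    results = bgan_synth(synth_dataset, args.z_dim,
                         numz=args.numz,
                         num_iter=args.num_iter,
                         wasserstein=args.wasserstein,
                         rpath=args.out_dir)
    print("JS divergences over training:", results["divergences"])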