def train_experimental_vaes(random_states=(1, 3)):
    """Trains and saves VAEs on the GFP data for use in the weighted ML methods.

    For each random state, a VAE is built with ``util.build_vae``
    (latent_dim=20, n_tokens=20, enc1_units=50) and fit for 100 epochs on the
    5000-sequence experimental training split drawn with that random state.
    Its encoder, decoder and full-VAE weights are then written to
    ``../models/vae_0_{encoder,decoder,vae}_weights_5k_<state>.h5``.

    Args:
        random_states: Random states of the training splits to train on. The
            default (1, 3) reproduces the original behaviour. Pass e.g.
            (1, 2, 3) to also produce the state-2 weights that
            ``run_killoran`` loads.
    """
    TRAIN_SIZE = 5000
    train_size_str = "%ik" % (TRAIN_SIZE / 1000)
    suffix = '_%s' % train_size_str
    for RANDOM_STATE in random_states:
        X_train, _, _ = util.get_experimental_X_y(random_state=RANDOM_STATE, train_size=TRAIN_SIZE)
        vae_0 = util.build_vae(latent_dim=20,
                               n_tokens=20,
                               seq_length=X_train.shape[1],
                               enc1_units=50)
        # Targets pair the input with itself (reconstruction) and with a zero
        # vector for the model's second output. That output is presumably the
        # KL/regularization term; confirm in util.build_vae.
        vae_0.vae_.fit([X_train], [X_train, np.zeros(X_train.shape[0])],
                       epochs=100,
                       batch_size=10,
                       verbose=2)
        vae_0.encoder_.save_weights("../models/vae_0_encoder_weights%s_%i.h5" % (suffix, RANDOM_STATE))
        vae_0.decoder_.save_weights("../models/vae_0_decoder_weights%s_%i.h5" % (suffix, RANDOM_STATE))
        vae_0.vae_.save_weights("../models/vae_0_vae_weights%s_%i.h5" % (suffix, RANDOM_STATE))
def weighted_ml_opt(X_train, oracles, ground_truth, vae_0, weights_type='dbas',
                    LD=20, iters=20, samples=500, homoscedastic=False,
                    homo_y_var=0.1, quantile=0.95, verbose=False, alpha=1,
                    train_gt_evals=None, cutoff=1e-6, it_epochs=10,
                    enc1_units=50):
    """Runs weighted maximum likelihood optimization algorithms ('CbAS', 'DbAS', RWR, and CEM-PI).

    Iteration 0 evaluates the training set and copies ``vae_0``'s weights
    into a fresh search VAE. Every later iteration does the following:
    decode ``samples`` standard-normal latent draws with the search VAE,
    score them with the oracle ensemble and the ground truth, weight them
    according to ``weights_type``, drop samples whose weight is below
    ``cutoff``, and refit the search VAE on the remaining samples with those
    weights.

    Args:
        X_train: Training sequences. Presumably one-hot encoded with shape
            (n, L, 20); TODO confirm.
        oracles: Oracle ensemble passed to ``util.get_balaji_predictions``,
            which returns a predicted mean and variance per sequence.
        ground_truth: Model whose ``predict(idx_seqs, print_every=...)``
            returns an array whose first column is the ground-truth value.
        vae_0: Prior VAE (``encoder_``/``decoder_``/``vae_`` parts). Its
            weights initialize the search VAE. For 'cbas', its decoder also
            supplies the prior likelihood of each sample.
        weights_type: One of 'cbas', 'dbas', 'rwr', 'cem-pi'.
        LD: Latent dimension of the search VAE.
        iters: Number of iterations (iteration 0 uses the training set).
        samples: Number of latent draws per iteration.
        homoscedastic: If True, replace the oracle variances with ``homo_y_var``.
        homo_y_var: Constant variance used when ``homoscedastic`` is True.
        quantile: Quantile of the oracle predictions used as the 'cbas'/'dbas'
            threshold, and of the improvement probabilities for 'cem-pi'.
        verbose: If True, print the trajectory row of each iteration.
        alpha: Multiplier on the oracle prediction in the RWR exponent.
        train_gt_evals: Optional precomputed ground-truth values for X_train.
        cutoff: Samples with weight below this are dropped before refitting.
        it_epochs: Training epochs per iteration.
        enc1_units: Width of the first encoder layer of the search VAE.

    Returns:
        traj: (iters, 7) array. Columns are the max, mean and std of the
            ground-truth values; the max, mean and std of the oracle
            predictions; and the mean oracle variance.
        oracle_samples: (iters, samples) oracle predictions per iteration.
        gt_samples: (iters, samples) ground-truth values per iteration.
        max_dict: Best oracle prediction seen ('oracle_max'), its amino-acid
            sequence ('oracle_max_seq') and its ground-truth value
            ('gt_of_oracle_max').
    """
    assert weights_type in ['cbas', 'dbas', 'rwr', 'cem-pi']
    L = X_train.shape[1]
    vae = util.build_vae(latent_dim=LD, n_tokens=20, seq_length=L, enc1_units=enc1_units)

    traj = np.zeros((iters, 7))
    oracle_samples = np.zeros((iters, samples))
    gt_samples = np.zeros((iters, samples))
    oracle_max_seq = None
    oracle_max = -np.inf
    gt_of_oracle_max = -np.inf
    # Running threshold for 'cbas'/'dbas'. It only ever increases.
    y_star = -np.inf

    for t in range(iters):
        ### Take Samples ###
        zt = np.random.randn(samples, LD)
        if t > 0:
            Xt_p = vae.decoder_.predict(zt)
            Xt = util.get_samples(Xt_p)
        else:
            Xt = X_train

        ### Evaluate ground truth and oracle ###
        yt, yt_var = util.get_balaji_predictions(oracles, Xt)
        if homoscedastic:
            yt_var = np.ones_like(yt) * homo_y_var
        Xt_aa = np.argmax(Xt, axis=-1)
        if t == 0 and train_gt_evals is not None:
            yt_gt = train_gt_evals
        else:
            yt_gt = ground_truth.predict(Xt_aa, print_every=1000000)[:, 0]

        ### Calculate weights for different schemes ###
        if t > 0:
            if weights_type == 'cbas':
                # Log-likelihood of the sampled sequences under the search
                # decoder and under the prior decoder, both evaluated at the
                # same latent draws.
                log_pxt = np.sum(np.log(Xt_p) * Xt, axis=(1, 2))
                X0_p = vae_0.decoder_.predict(zt)
                log_px0 = np.sum(np.log(X0_p) * Xt, axis=(1, 2))
                w1 = np.exp(log_px0 - log_pxt)
                y_star_1 = np.percentile(yt, quantile * 100)
                if y_star_1 > y_star:
                    y_star = y_star_1
                w2 = scipy.stats.norm.sf(y_star, loc=yt, scale=np.sqrt(yt_var))
                weights = w1 * w2
            elif weights_type == 'cem-pi':
                # Probability of exceeding the best training ground-truth
                # value. Keep only samples above the quantile of that probability.
                pi = scipy.stats.norm.sf(max_train_gt, loc=yt, scale=np.sqrt(yt_var))
                pi_thresh = np.percentile(pi, quantile * 100)
                weights = (pi > pi_thresh).astype(int)
            elif weights_type == 'dbas':
                y_star_1 = np.percentile(yt, quantile * 100)
                if y_star_1 > y_star:
                    y_star = y_star_1
                weights = scipy.stats.norm.sf(y_star, loc=yt, scale=np.sqrt(yt_var))
            elif weights_type == 'rwr':
                # Shift the exponent by its max to avoid overflow. The
                # constant factor cancels in the normalization.
                scaled = alpha * yt
                weights = np.exp(scaled - np.max(scaled))
                weights /= np.sum(weights)
        else:
            weights = np.ones(yt.shape[0])
            max_train_gt = np.max(yt_gt)

        ### Track the best oracle prediction seen so far ###
        yt_max_idx = np.argmax(yt)
        yt_max = yt[yt_max_idx]
        if yt_max > oracle_max:
            oracle_max = yt_max
            # A one-row slice passes the converter a batch containing exactly
            # the argmax sequence.
            oracle_max_seq = util.convert_idx_array_to_aas(
                Xt_aa[yt_max_idx:yt_max_idx + 1])[0]
            gt_of_oracle_max = yt_gt[yt_max_idx]

        ### Record and print results ###
        if t == 0:
            # The training set size may differ from `samples`, so record a
            # random subsample of it (drawn with replacement).
            rand_idx = np.random.randint(0, len(yt), samples)
            oracle_samples[t, :] = yt[rand_idx]
            gt_samples[t, :] = yt_gt[rand_idx]
        else:
            oracle_samples[t, :] = yt
            gt_samples[t, :] = yt_gt

        traj[t, 0] = np.max(yt_gt)
        traj[t, 1] = np.mean(yt_gt)
        traj[t, 2] = np.std(yt_gt)
        traj[t, 3] = np.max(yt)
        traj[t, 4] = np.mean(yt)
        traj[t, 5] = np.std(yt)
        traj[t, 6] = np.mean(yt_var)

        if verbose:
            print(weights_type.upper(), t, traj[t, 0], color.BOLD + str(traj[t, 1]) + color.END,
                  traj[t, 2], traj[t, 3], color.BOLD + str(traj[t, 4]) + color.END,
                  traj[t, 5], traj[t, 6])

        ### Train model ###
        if t == 0:
            vae.encoder_.set_weights(vae_0.encoder_.get_weights())
            vae.decoder_.set_weights(vae_0.decoder_.get_weights())
            vae.vae_.set_weights(vae_0.vae_.get_weights())
        else:
            # Written as ~(w < cutoff) so NaN weights are kept.
            keep = ~(weights < cutoff)
            Xt = Xt[keep]
            weights = weights[keep]
            # If every sample fell below the cutoff there is nothing to fit
            # on, so the current model is kept for the next iteration.
            if Xt.shape[0] > 0:
                vae.vae_.fit([Xt], [Xt, np.zeros(Xt.shape[0])],
                             epochs=it_epochs,
                             batch_size=10,
                             shuffle=False,
                             sample_weight=[weights, weights],
                             verbose=0)

    max_dict = {
        'oracle_max': oracle_max,
        'oracle_max_seq': oracle_max_seq,
        'gt_of_oracle_max': gt_of_oracle_max
    }
    return traj, oracle_samples, gt_samples, max_dict
def fb_opt(X_train, oracles, ground_truth, vae_0, weights_type='fbvae', LD=20,
           iters=20, samples=500, quantile=0.8, verbose=False,
           train_gt_evals=None, it_epochs=10, enc1_units=50):
    """Runs FBVAE optimization algorithm.

    Iteration 0 does the following: score the training set, fix the feedback
    threshold at the ``quantile`` of the training oracle predictions, and
    copy ``vae_0``'s weights into a fresh search VAE. Each later iteration
    decodes ``samples`` latent draws and prepends to the pool those whose
    oracle prediction is at or above the threshold. It then drops the same
    number of rows from the end of the pool, so the pool size stays constant,
    and trains the search VAE for one epoch on the pool.

    Args:
        X_train: Initial pool of sequences. Presumably one-hot encoded with
            shape (n, L, 20); TODO confirm.
        oracles: Oracle ensemble passed to ``util.get_balaji_predictions``.
        ground_truth: Model whose ``predict(idx_seqs, print_every=...)``
            returns an array whose first column is the ground-truth value.
        vae_0: Prior VAE whose weights initialize the search VAE.
        weights_type: Must be 'fbvae'.
        LD: Latent dimension of the search VAE.
        iters: Number of iterations (iteration 0 uses the training set).
        samples: Number of latent draws per iteration.
        quantile: Quantile of the training oracle predictions used as the
            feedback threshold.
        verbose: If True, print the trajectory row of each iteration.
        train_gt_evals: Optional precomputed ground-truth values for X_train.
        it_epochs: NOTE(review): currently unused; each feedback round trains
            for one epoch. Confirm whether this was meant to be used.
        enc1_units: Width of the first encoder layer of the search VAE.

    Returns:
        traj: (iters, 7) array. Columns are the max, mean and std of the
            ground-truth values over the pool; the max, mean and std of the
            oracle predictions over the pool; and the number of samples fed
            back that iteration.
        oracle_samples: (iters, samples) oracle predictions of a random
            subsample of the pool per iteration.
        gt_samples: (iters, samples) ground-truth values of that subsample.
        max_dict: Best oracle prediction seen ('oracle_max'), its amino-acid
            sequence ('oracle_max_seq') and its ground-truth value
            ('gt_of_oracle_max').
    """
    assert weights_type in ['fbvae']
    L = X_train.shape[1]
    vae = util.build_vae(latent_dim=LD, n_tokens=20, seq_length=L, enc1_units=enc1_units)

    traj = np.zeros((iters, 7))
    oracle_samples = np.zeros((iters, samples))
    gt_samples = np.zeros((iters, samples))
    oracle_max_seq = None
    oracle_max = -np.inf
    gt_of_oracle_max = -np.inf

    for t in range(iters):
        ### Take samples and evaluate ground truth and oracle ###
        zt = np.random.randn(samples, LD)
        if t > 0:
            Xt_sample_p = vae.decoder_.predict(zt)
            Xt_sample = util.get_samples(Xt_sample_p)
            yt_sample, _ = util.get_balaji_predictions(oracles, Xt_sample)
            Xt_aa_sample = np.argmax(Xt_sample, axis=-1)
            yt_gt_sample = ground_truth.predict(Xt_aa_sample, print_every=1000000)[:, 0]
        else:
            Xt = X_train
            yt, _ = util.get_balaji_predictions(oracles, Xt)
            Xt_aa = np.argmax(Xt, axis=-1)
            fb_thresh = np.percentile(yt, quantile * 100)
            if train_gt_evals is not None:
                yt_gt = train_gt_evals
            else:
                yt_gt = ground_truth.predict(Xt_aa, print_every=1000000)[:, 0]

        ### Feed samples above the threshold back into the pool ###
        n_top = 0
        if t > 0:
            threshold_idx = np.where(yt_sample >= fb_thresh)[0]
            n_top = len(threshold_idx)
            updated = []
            for sample_arr, full_arr in zip((Xt_sample, yt_sample, yt_gt_sample, Xt_aa_sample),
                                            (Xt, yt, yt_gt, Xt_aa)):
                # Newest entries go to the front. The same number of rows is
                # removed from the end, so the pool keeps its size.
                merged = np.concatenate([sample_arr[threshold_idx], full_arr])
                updated.append(merged[:merged.shape[0] - n_top])
            Xt, yt, yt_gt, Xt_aa = updated

        ### Track the best oracle prediction seen so far ###
        yt_max_idx = np.argmax(yt)
        yt_max = yt[yt_max_idx]
        if yt_max > oracle_max:
            oracle_max = yt_max
            # A one-row slice passes the converter a batch containing exactly
            # the argmax sequence.
            oracle_max_seq = util.convert_idx_array_to_aas(
                Xt_aa[yt_max_idx:yt_max_idx + 1])[0]
            gt_of_oracle_max = yt_gt[yt_max_idx]

        ### Record and print results ###
        rand_idx = np.random.randint(0, len(yt), samples)
        oracle_samples[t, :] = yt[rand_idx]
        gt_samples[t, :] = yt_gt[rand_idx]
        traj[t, 0] = np.max(yt_gt)
        traj[t, 1] = np.mean(yt_gt)
        traj[t, 2] = np.std(yt_gt)
        traj[t, 3] = np.max(yt)
        traj[t, 4] = np.mean(yt)
        traj[t, 5] = np.std(yt)
        traj[t, 6] = n_top

        if verbose:
            print(weights_type.upper(), t, traj[t, 0], color.BOLD + str(traj[t, 1]) + color.END,
                  traj[t, 2], traj[t, 3], color.BOLD + str(traj[t, 4]) + color.END,
                  traj[t, 5], traj[t, 6])

        ### Train model ###
        if t == 0:
            vae.encoder_.set_weights(vae_0.encoder_.get_weights())
            vae.decoder_.set_weights(vae_0.decoder_.get_weights())
            vae.vae_.set_weights(vae_0.vae_.get_weights())
        else:
            vae.vae_.fit([Xt], [Xt, np.zeros(Xt.shape[0])],
                         epochs=1,
                         batch_size=10,
                         shuffle=False,
                         verbose=0)

    max_dict = {
        'oracle_max': oracle_max,
        'oracle_max_seq': oracle_max_seq,
        'gt_of_oracle_max': gt_of_oracle_max
    }
    return traj, oracle_samples, gt_samples, max_dict
def run_killoran(killoran=True):
    """Runs the GFP comparative tests on the Killoran (aka AM-VAE) optimization algorithm.

    For random states 1, 2 and 3 (using 1, 5 and 20 oracles respectively),
    this loads the following:

    - the 5000-sequence experimental training split;
    - the ground-truth GP;
    - the prior VAE weights;
    - the oracle ensemble.

    It then runs ``optimization_algs.killoran_opt`` and saves the returned
    results array (``.npy``) and the best-sequence summary (``.json``) under
    ``../results/``.

    Args:
        killoran: If True, run the Adam-based Killoran/AM-VAE settings
            (10000 steps) and write ``killoran_*`` files. Otherwise run the
            MALA settings (30000 steps, no Adam) and write ``mala_*`` files.
    """
    TRAIN_SIZE = 5000
    LD = 20
    train_size_str = "%ik" % (TRAIN_SIZE / 1000)
    for i in range(3):
        RANDOM_STATE = i + 1
        print(RANDOM_STATE)
        num_models = [1, 5, 20][i]
        X_train, _, _ = util.get_experimental_X_y(random_state=RANDOM_STATE, train_size=TRAIN_SIZE)
        # Suffix shared by the prior-VAE weight files and the result files, e.g. "_5k_1".
        suffix = '_%s_%i' % (train_size_str, RANDOM_STATE)
        oracle_suffix = '_%s_%i_%i' % (train_size_str, num_models, RANDOM_STATE)

        ground_truth = gfp_gp.SequenceGP(load=True, load_prefix="data/gfp_gp")
        # Register the custom loss under the name the saved oracle models
        # refer to, so keras.models.load_model can resolve it.
        keras.utils.get_custom_objects().update({"neg_log_likelihood": losses.neg_log_likelihood})

        sess = tf.Session(graph=tf.get_default_graph())
        K.set_session(sess)
        vae = util.build_vae(latent_dim=LD,
                             n_tokens=20,
                             seq_length=X_train.shape[1],
                             enc1_units=50)
        vae.encoder_.load_weights("../models/vae_0_encoder_weights%s.h5" % suffix)
        vae.decoder_.load_weights("../models/vae_0_decoder_weights%s.h5" % suffix)
        vae.vae_.load_weights("../models/vae_0_vae_weights%s.h5" % suffix)

        oracles = [
            keras.models.load_model("../models/oracle_%i%s.h5" % (model_idx, oracle_suffix))
            for model_idx in range(num_models)
        ]

        if killoran:
            results, test_max = optimization_algs.killoran_opt(
                X_train, vae, oracles, ground_truth, steps=10000, epsilon1=0.,
                epsilon2=0.1, noise_std=1e-6, LD=LD, verbose=False, adam=True)
            results_path = "../results/killoran_may_results%s.npy" % suffix
            max_path = '../results/%s_max%s.json' % ('killoran', suffix)
        else:
            results, test_max = optimization_algs.killoran_opt(
                X_train, vae, oracles, ground_truth, steps=30000, epsilon1=1e-5,
                epsilon2=1., noise_std=1e-5, LD=LD, verbose=False, adam=False)
            results_path = "../results/mala_results%s.npy" % suffix
            # Same ../results/ directory as the .npy output and as the
            # killoran branch.
            max_path = '../results/%s_max%s.json' % ('mala', suffix)

        np.save(results_path, results)
        with open(max_path, 'w') as outfile:
            json.dump(test_max, outfile)