Example #1
def weighted_ml_opt(X_train,
                    oracles,
                    ground_truth,
                    vae_0,
                    weights_type='dbas',
                    LD=20,
                    iters=20,
                    samples=500,
                    homoscedastic=False,
                    homo_y_var=0.1,
                    quantile=0.95,
                    verbose=False,
                    alpha=1,
                    train_gt_evals=None,
                    cutoff=1e-6,
                    it_epochs=10,
                    enc1_units=50):
    """
    Runs the weighted maximum likelihood optimization algorithms (CbAS, DbAS,
    RWR, and CEM-PI), selected via the weights_type argument.
    """

    assert weights_type in ['cbas', 'dbas', 'rwr', 'cem-pi']
    L = X_train.shape[1]
    vae = util.build_vae(latent_dim=LD,
                         n_tokens=20,
                         seq_length=L,
                         enc1_units=enc1_units)

    # traj columns: gt max, gt mean, gt std, oracle max, oracle mean,
    # oracle std, mean oracle predictive variance (one row per iteration)
    traj = np.zeros((iters, 7))
    oracle_samples = np.zeros((iters, samples))
    gt_samples = np.zeros((iters, samples))
    oracle_max_seq = None
    oracle_max = -np.inf
    gt_of_oracle_max = -np.inf
    y_star = -np.inf

    for t in range(iters):
        ### Take Samples ###
        zt = np.random.randn(samples, LD)
        if t > 0:
            Xt_p = vae.decoder_.predict(zt)
            Xt = util.get_samples(Xt_p)
        else:
            Xt = X_train

        ### Evaluate ground truth and oracle ###
        yt, yt_var = util.get_balaji_predictions(oracles, Xt)
        if homoscedastic:
            yt_var = np.ones_like(yt) * homo_y_var
        Xt_aa = np.argmax(Xt, axis=-1)
        if t == 0 and train_gt_evals is not None:
            yt_gt = train_gt_evals
        else:
            yt_gt = ground_truth.predict(Xt_aa, print_every=1000000)[:, 0]

        ### Calculate weights for different schemes ###
        if t > 0:
            if weights_type == 'cbas':
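                # CbAS weights: likelihood ratio of the samples under the
                # initial decoder (vae_0) vs. the current decoder, times the
                # oracle's probability of exceeding the threshold y_star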
                log_pxt = np.sum(np.log(Xt_p) * Xt, axis=(1, 2))
                X0_p = vae_0.decoder_.predict(zt)
                log_px0 = np.sum(np.log(X0_p) * Xt, axis=(1, 2))
                w1 = np.exp(log_px0 - log_pxt)
                y_star_1 = np.percentile(yt, quantile * 100)
                if y_star_1 > y_star:
                    y_star = y_star_1
                w2 = scipy.stats.norm.sf(y_star, loc=yt, scale=np.sqrt(yt_var))
                weights = w1 * w2
            elif weights_type == 'cem-pi':
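                # CEM-PI: indicator weights that keep only samples whose
                # probability of improving on the best training ground-truth
                # value exceeds the chosen quantile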
                pi = scipy.stats.norm.sf(max_train_gt,
                                         loc=yt,
                                         scale=np.sqrt(yt_var))
                pi_thresh = np.percentile(pi, quantile * 100)
                weights = (pi > pi_thresh).astype(int)
            elif weights_type == 'dbas':
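                # DbAS: weights are the oracle's probability that each sample
                # exceeds the (monotonically increasing) threshold y_star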
                y_star_1 = np.percentile(yt, quantile * 100)
                if y_star_1 > y_star:
                    y_star = y_star_1
                weights = scipy.stats.norm.sf(y_star,
                                              loc=yt,
                                              scale=np.sqrt(yt_var))
            elif weights_type == 'rwr':
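                # RWR: exponentially reward-weighted, normalized to sum to one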
                weights = np.exp(alpha * yt)
                weights /= np.sum(weights)
        else:
            weights = np.ones(yt.shape[0])
            max_train_gt = np.max(yt_gt)

        yt_max_idx = np.argmax(yt)
        yt_max = yt[yt_max_idx]
        if yt_max > oracle_max:
            oracle_max = yt_max
            # Slice keeps the 2-D shape expected by the conversion utility
            oracle_max_seq = util.convert_idx_array_to_aas(
                Xt_aa[yt_max_idx:yt_max_idx + 1])[0]
            gt_of_oracle_max = yt_gt[yt_max_idx]

        ### Record and print results ###
        if t == 0:
            rand_idx = np.random.randint(0, len(yt), samples)
            oracle_samples[t, :] = yt[rand_idx]
            gt_samples[t, :] = yt_gt[rand_idx]
        if t > 0:
            oracle_samples[t, :] = yt
            gt_samples[t, :] = yt_gt

        traj[t, 0] = np.max(yt_gt)
        traj[t, 1] = np.mean(yt_gt)
        traj[t, 2] = np.std(yt_gt)
        traj[t, 3] = np.max(yt)
        traj[t, 4] = np.mean(yt)
        traj[t, 5] = np.std(yt)
        traj[t, 6] = np.mean(yt_var)

        if verbose:
            print(weights_type.upper(), t, traj[t, 0],
                  color.BOLD + str(traj[t, 1]) + color.END, traj[t, 2],
                  traj[t, 3], color.BOLD + str(traj[t, 4]) + color.END,
                  traj[t, 5], traj[t, 6])

        ### Train model ###
        if t == 0:
            vae.encoder_.set_weights(vae_0.encoder_.get_weights())
            vae.decoder_.set_weights(vae_0.decoder_.get_weights())
            vae.vae_.set_weights(vae_0.vae_.get_weights())
        else:
            cutoff_idx = np.where(weights < cutoff)
            Xt = np.delete(Xt, cutoff_idx, axis=0)
            yt = np.delete(yt, cutoff_idx, axis=0)
            weights = np.delete(weights, cutoff_idx, axis=0)
            vae.fit([Xt], [Xt, np.zeros(Xt.shape[0])],
                    epochs=it_epochs,
                    batch_size=10,
                    shuffle=False,
                    sample_weight=[weights, weights],
                    verbose=0)

    max_dict = {
        'oracle_max': oracle_max,
        'oracle_max_seq': oracle_max_seq,
        'gt_of_oracle_max': gt_of_oracle_max
    }
    return traj, oracle_samples, gt_samples, max_dict
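A minimal call sketch for the function above, illustrative only: it assumes X_train is a one-hot encoded array of shape (N, L, 20), oracles is the list of trained oracle models expected by util.get_balaji_predictions, ground_truth provides the predict interface used above, and vae_0 is a prior VAE built with util.build_vae and trained on X_train.

# Illustrative usage; the objects passed in are assumed to exist as described.
traj, oracle_samples, gt_samples, max_dict = weighted_ml_opt(
    X_train,
    oracles,
    ground_truth,
    vae_0,
    weights_type='cbas',
    iters=20,
    samples=500,
    quantile=0.95,
    verbose=True)
print(max_dict['oracle_max_seq'], max_dict['gt_of_oracle_max'])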
Example #2
def gomez_bombarelli_opt(X_train,
                         pred_vae,
                         ground_truth,
                         total_it=10000,
                         constrained=True):
    """Runs the Gomez-Bombarelli optimization algorithm: fits a GP to the
    VAE latent embeddings and maximizes its predictive mean with COBYLA."""
    LD = 20
    L = X_train.shape[1]

    embeddings = pred_vae.encoder_.predict([X_train])[0]
    predictions = pred_vae.predictor_.predict(embeddings)[0]
    predictions = predictions.reshape((predictions.shape[0]))

    kernel = C(1.0, (1e-3, 1e3)) * RBF(10, (1e-2, 1e2))

    gp = GaussianProcessRegressor(kernel=kernel,
                                  n_restarts_optimizer=0,
                                  normalize_y=True)
    print("Fitting GP...")
    gp.fit(embeddings, predictions)
    print("Finished fitting GP")

    def gp_z_estimator(z):
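        # Negate the GP predictive mean so that minimize() maximizes it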
        val = gp.predict(z.reshape(1, -1)).ravel()[0]
        return -val

    method = 'COBYLA'
    num_it = 0
    f_opts = []
    gt_opts = []

    z0 = np.random.randn(LD)
    if constrained:
        z_norm = np.linalg.norm(embeddings, axis=1)
        z_norm_mean = np.mean(z_norm)
        z_norm_std = np.std(z_norm)

        lower_rad = z_norm_mean - z_norm_std * 2
        higher_rad = z_norm_mean + z_norm_std * 2
        constraints = [{
            'type': 'ineq',
            'fun': lambda x: np.linalg.norm(x) - lower_rad
        }, {
            'type': 'ineq',
            'fun': lambda x: higher_rad - np.linalg.norm(x)
        }]
        res = minimize(gp_z_estimator,
                       z0,
                       method=method,
                       constraints=constraints,
                       options={
                           'disp': False,
                           'tol': 0.1,
                           'maxiter': 10000
                       },
                       tol=0.1)
    else:
        res = minimize(gp_z_estimator,
                       z0,
                       method=method,
                       options={
                           'disp': False,
                           'tol': 0.1,
                           'maxiter': 10000
                       },
                       tol=0.1)
    z_opt = res.x.reshape((1, LD))
    f_opt = -res.fun

    x_opt = np.argmax(pred_vae.decoder_.predict(z_opt), axis=-1)
    gt_opt = ground_truth.predict(x_opt)[:, 0][0]
    print(f_opt, gt_opt)

    oracle_max_seq = util.convert_idx_array_to_aas(x_opt)

    max_dict = {
        'oracle_max': f_opt,
        'oracle_max_seq': oracle_max_seq,
        'gt_of_oracle_max': gt_opt
    }

    results = np.array([f_opt, gt_opt])
    return results, max_dict
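A minimal call sketch, illustrative only: pred_vae is assumed to be a VAE exposing the encoder_, decoder_, and predictor_ sub-models used above, and X_train and ground_truth are as in the previous example.

# Illustrative usage; pred_vae, X_train, and ground_truth are assumed to exist.
results, max_dict = gomez_bombarelli_opt(X_train,
                                         pred_vae,
                                         ground_truth,
                                         constrained=True)
f_opt, gt_opt = results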
Example #3
def fb_opt(X_train,
           oracles,
           ground_truth,
           vae_0,
           weights_type='fbvae',
           LD=20,
           iters=20,
           samples=500,
           quantile=0.8,
           verbose=False,
           train_gt_evals=None,
           it_epochs=10,
           enc1_units=50):
    """Runs the FBVAE (feedback VAE) optimization algorithm"""

    assert weights_type in ['fbvae']
    L = X_train.shape[1]
    vae = util.build_vae(latent_dim=LD,
                         n_tokens=20,
                         seq_length=L,
                         enc1_units=enc1_units)

    # traj columns: gt max, gt mean, gt std, oracle max, oracle mean,
    # oracle std, number of fed-back samples above the threshold (n_top)
    traj = np.zeros((iters, 7))
    oracle_samples = np.zeros((iters, samples))
    gt_samples = np.zeros((iters, samples))
    oracle_max_seq = None
    oracle_max = -np.inf
    gt_of_oracle_max = -np.inf
    y_star = -np.inf
    for t in range(iters):
        ### Take Samples and evaluate ground truth and oracle ###
        zt = np.random.randn(samples, LD)
        if t > 0:
            Xt_sample_p = vae.decoder_.predict(zt)
            Xt_sample = util.get_samples(Xt_sample_p)
            yt_sample, _ = util.get_balaji_predictions(oracles, Xt_sample)
            Xt_aa_sample = np.argmax(Xt_sample, axis=-1)
            yt_gt_sample = ground_truth.predict(Xt_aa_sample,
                                                print_every=1000000)[:, 0]
        else:
            Xt = X_train
            yt, _ = util.get_balaji_predictions(oracles, Xt)
            Xt_aa = np.argmax(Xt, axis=-1)
            fb_thresh = np.percentile(yt, quantile * 100)
            if train_gt_evals is not None:
                yt_gt = train_gt_evals
            else:
                yt_gt = ground_truth.predict(Xt_aa, print_every=1000000)[:, 0]

        ### Calculate threshold ###
        if t > 0:
            threshold_idx = np.where(yt_sample >= fb_thresh)[0]
            n_top = len(threshold_idx)
            sample_arrs = [Xt_sample, yt_sample, yt_gt_sample, Xt_aa_sample]
            full_arrs = [Xt, yt, yt_gt, Xt_aa]

            # Prepend the above-threshold samples and drop the same number of
            # entries from the end, keeping the dataset size constant
            for l in range(len(full_arrs)):
                sample_arr = sample_arrs[l]
                full_arr = full_arrs[l]
                sample_top = sample_arr[threshold_idx]
                full_arr = np.concatenate([sample_top, full_arr])
                full_arr = np.delete(full_arr,
                                     range(full_arr.shape[0] - n_top,
                                           full_arr.shape[0]),
                                     axis=0)
                full_arrs[l] = full_arr
            Xt, yt, yt_gt, Xt_aa = full_arrs
        yt_max_idx = np.argmax(yt)
        yt_max = yt[yt_max_idx]
        if yt_max > oracle_max:
            oracle_max = yt_max
            # Slice keeps the 2-D shape expected by the conversion utility
            oracle_max_seq = util.convert_idx_array_to_aas(
                Xt_aa[yt_max_idx:yt_max_idx + 1])[0]
            gt_of_oracle_max = yt_gt[yt_max_idx]

        ### Record and print results ###

        rand_idx = np.random.randint(0, len(yt), samples)
        oracle_samples[t, :] = yt[rand_idx]
        gt_samples[t, :] = yt_gt[rand_idx]

        traj[t, 0] = np.max(yt_gt)
        traj[t, 1] = np.mean(yt_gt)
        traj[t, 2] = np.std(yt_gt)
        traj[t, 3] = np.max(yt)
        traj[t, 4] = np.mean(yt)
        traj[t, 5] = np.std(yt)
        if t > 0:
            traj[t, 6] = n_top
        else:
            traj[t, 6] = 0

        if verbose:
            print(weights_type.upper(), t, traj[t, 0],
                  color.BOLD + str(traj[t, 1]) + color.END, traj[t, 2],
                  traj[t, 3], color.BOLD + str(traj[t, 4]) + color.END,
                  traj[t, 5], traj[t, 6])

        ### Train model ###
        if t == 0:
            vae.encoder_.set_weights(vae_0.encoder_.get_weights())
            vae.decoder_.set_weights(vae_0.decoder_.get_weights())
            vae.vae_.set_weights(vae_0.vae_.get_weights())
        else:
            vae.fit([Xt], [Xt, np.zeros(Xt.shape[0])],
                    epochs=1,
                    batch_size=10,
                    shuffle=False,
                    verbose=0)

    max_dict = {
        'oracle_max': oracle_max,
        'oracle_max_seq': oracle_max_seq,
        'gt_of_oracle_max': gt_of_oracle_max
    }
    return traj, oracle_samples, gt_samples, max_dict
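A minimal call sketch, illustrative only, under the same assumptions about X_train, oracles, ground_truth, and the prior VAE vae_0 as in Example #1.

# Illustrative usage; the inputs are assumed to be prepared as in Example #1.
traj, oracle_samples, gt_samples, max_dict = fb_opt(X_train,
                                                    oracles,
                                                    ground_truth,
                                                    vae_0,
                                                    quantile=0.8,
                                                    verbose=True)
print(max_dict['oracle_max_seq'], max_dict['gt_of_oracle_max'])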