def main(_):
    ed.set_seed(FLAGS.seed)
    # setting up output directory
    outdir = FLAGS.outdir
    if '~' in outdir: outdir = os.path.expanduser(outdir)
    os.makedirs(outdir, exist_ok=True)

    is_vector = FLAGS.base_dist in ['mvnormal', 'mvlaplace']

    ((Xtrain, ytrain), (Xtest, ytest)) = blr_utils.get_data()
    N, D = Xtrain.shape
    N_test, D_test = Xtest.shape
    assert D_test == D, 'Test dimension %d differs from train %d' % (
        D_test, D)
    logger.info('D = %d, Ntrain = %d, Ntest = %d' % (D, N, N_test))

    # Solution components
    weights, q_params = [], []
    # Lipschitz-continuous gradient estimate
    lipschitz_estimate = None

    # Metrics to log
    times_filename = os.path.join(outdir, 'times.csv')
    open(times_filename, 'w').close()

    # (mean, +- std)
    elbos_filename = os.path.join(outdir, 'elbos.csv')
    logger.info('saving elbos to, %s' % elbos_filename)
    open(elbos_filename, 'w').close()

    rocs_filename = os.path.join(outdir, 'roc.csv')
    logger.info('saving rocs to, %s' % rocs_filename)
    open(rocs_filename, 'w').close()

    gap_filename = os.path.join(outdir, 'gap.csv')
    open(gap_filename, 'w').close()

    step_filename = os.path.join(outdir, 'steps.csv')
    open(step_filename, 'w').close()

    # (mean, std)
    ll_train_filename = os.path.join(outdir, 'll_train.csv')
    open(ll_train_filename, 'w').close()
    ll_test_filename = os.path.join(outdir, 'll_test.csv')
    open(ll_test_filename, 'w').close()

    # (bin_ac_train, bin_ac_test)
    bin_ac_filename = os.path.join(outdir, 'bin_ac.csv')
    open(bin_ac_filename, 'w').close()

    # 'adafw', 'ada_afw', 'ada_pfw'
    if FLAGS.fw_variant.startswith('ada'):
        lipschitz_filename = os.path.join(outdir, 'lipschitz.csv')
        open(lipschitz_filename, 'w').close()

        iter_info_filename = os.path.join(outdir, 'iter_info.txt')
        open(iter_info_filename, 'w').close()

    for t in range(FLAGS.n_fw_iter):
        g = tf.Graph()
        with g.as_default():
            sess = tf.InteractiveSession()
            with sess.as_default():
                tf.set_random_seed(FLAGS.seed)

                # Build Model
                w = Normal(loc=tf.zeros(D, tf.float32),
                           scale=tf.ones(D, tf.float32))

                X = tf.placeholder(tf.float32, [None, D])
                y = Bernoulli(logits=ed.dot(X, w))

                p_joint = blr_utils.Joint(Xtrain, ytrain, sess,
                                          FLAGS.n_monte_carlo_samples, logger)

                # vectorized Model evaluations
                n_test_samples = 100
                W = tf.placeholder(tf.float32, [n_test_samples, D])
                y_data = tf.placeholder(tf.float32, [None])  # N -> (N, n_test)
                y_data_matrix = tf.tile(tf.expand_dims(y_data, 1),
                                        (1, n_test_samples))
                pred_logits = tf.matmul(X, tf.transpose(W))  # (N, n_test)
                ypred = tf.sigmoid(tf.reduce_mean(pred_logits, axis=1))
                pY = Bernoulli(logits=pred_logits)  # (N, n_test)
                log_likelihoods = pY.log_prob(y_data_matrix)  # (N, n_test)
                log_likelihood_expectation = tf.reduce_mean(log_likelihoods,
                                                            axis=1)  # (N, )
                # tf.nn.moments returns (mean, variance)
                ll_mean, ll_var = tf.nn.moments(log_likelihood_expectation,
                                                axes=[0])
                ll_std = tf.sqrt(ll_var)

                if t == 0:
                    fw_iterates = {}
                else:
                    # Current solution
                    prev_components = [
                        coreutils.base_loc_scale(FLAGS.base_dist,
                                                 c['loc'],
                                                 c['scale'],
                                                 multivariate=is_vector)
                        for c in q_params
                    ]
                    qtw_prev = coreutils.get_mixture(weights, prev_components)
                    fw_iterates = {w: qtw_prev}

                # s is the solution to LMO, random initialization
                s = coreutils.construct_base(FLAGS.base_dist, [D],
                                             t,
                                             's',
                                             multivariate=is_vector)

                sess.run(tf.global_variables_initializer())

                total_time = 0.
                inference_time_start = time.time()
                # Run relbo to solve the LMO problem.
                # Selecting the first atom by running the LMO is equivalent
                # to running VI against a uniform prior. Since the uniform is
                # not in our variational family, use a random element
                # (without LMO inference) as the initial iterate instead.
                if FLAGS.iter0 == 'vi' or t > 0:
                    inference = relbo.KLqp({w: s},
                                           fw_iterates=fw_iterates,
                                           data={
                                               X: Xtrain,
                                               y: ytrain
                                           },
                                           fw_iter=t)
                    inference.run(n_iter=FLAGS.LMO_iter)
                inference_time_end = time.time()
                # compute only step size selection time
                #total_time += float(inference_time_end - inference_time_start)

                loc_s = s.mean().eval()
                scale_s = s.stddev().eval()

                # Evaluate the next step
                step_result = {}
                if t == 0:
                    # Initialization, q_0
                    q_params.append({'loc': loc_s, 'scale': scale_s})
                    weights.append(1.)
                    if FLAGS.fw_variant.startswith('ada'):
                        lipschitz_estimate = opt.adafw_linit(s, p_joint)
                    step_type = 'init'
                elif FLAGS.fw_variant == 'fixed':
                    start_step_time = time.time()
                    step_result = opt.fixed(weights, q_params, qtw_prev, loc_s,
                                            scale_s, s, p_joint, t)
                    end_step_time = time.time()
                    total_time += float(end_step_time - start_step_time)
                elif FLAGS.fw_variant == 'adafw':
                    start_step_time = time.time()
                    step_result = opt.adaptive_fw(weights, q_params, qtw_prev,
                                                  loc_s, scale_s, s, p_joint,
                                                  t, lipschitz_estimate)
                    end_step_time = time.time()
                    total_time += float(end_step_time - start_step_time)
                    step_type = step_result['step_type']
                    if step_type == 'adaptive':
                        lipschitz_estimate = step_result['l_estimate']
                elif FLAGS.fw_variant == 'ada_pfw':
                    start_step_time = time.time()
                    step_result = opt.adaptive_pfw(weights, q_params, qtw_prev,
                                                   loc_s, scale_s, s, p_joint,
                                                   t, lipschitz_estimate)
                    end_step_time = time.time()
                    total_time += float(end_step_time - start_step_time)
                    step_type = step_result['step_type']
                    if step_type in ['adaptive', 'drop']:
                        lipschitz_estimate = step_result['l_estimate']
                elif FLAGS.fw_variant == 'ada_afw':
                    start_step_time = time.time()
                    step_result = opt.adaptive_afw(weights, q_params, qtw_prev,
                                                   loc_s, scale_s, s, p_joint,
                                                   t, lipschitz_estimate)
                    end_step_time = time.time()
                    total_time += float(end_step_time - start_step_time)
                    step_type = step_result['step_type']
                    if step_type in ['adaptive', 'away', 'drop']:
                        lipschitz_estimate = step_result['l_estimate']
                elif FLAGS.fw_variant == 'line_search':
                    start_step_time = time.time()
                    step_result = opt.line_search_dkl(weights, q_params,
                                                      qtw_prev, loc_s, scale_s,
                                                      s, p_joint, t)
                    end_step_time = time.time()
                    total_time += float(end_step_time - start_step_time)
                    step_type = step_result['step_type']
                else:
                    raise NotImplementedError(
                        'Step size variant %s not implemented' %
                        FLAGS.fw_variant)

                if t == 0:
                    gamma = 1.
                    new_components = [s]
                else:
                    q_params = step_result['params']
                    weights = step_result['weights']
                    gamma = step_result['gamma']
                    new_components = [
                        coreutils.base_loc_scale(FLAGS.base_dist,
                                                 c['loc'],
                                                 c['scale'],
                                                 multivariate=is_vector)
                        for c in q_params
                    ]
                qtw_new = coreutils.get_mixture(weights, new_components)

                # Log metrics for current iteration
                logger.info('total time %f' % total_time)
                append_to_file(times_filename, total_time)

                elbo_t = elbo(qtw_new, p_joint, return_std=False)
                # cross-check the elbo directly via KLqp
                elbo_loss = elboModel.KLqp({w: qtw_new},
                                           data={
                                               X: Xtrain,
                                               y: ytrain
                                           })
                res_update = elbo_loss.run()

                logger.info("iter, %d, elbo, %.2f loss %.2f" %
                            (t, elbo_t, res_update['loss']))
                append_to_file(elbos_filename,
                               "%f,%f" % (elbo_t, res_update['loss']))

                logger.info('iter %d, gamma %.4f' % (t, gamma))
                append_to_file(step_filename, gamma)

                if t > 0:
                    gap_t = step_result['gap']
                    logger.info('iter %d, gap %.4f' % (t, gap_t))
                    append_to_file(gap_filename, gap_t)

                if FLAGS.fw_variant.startswith('ada'):
                    append_to_file(lipschitz_filename, lipschitz_estimate)
                    append_to_file(iter_info_filename, step_type)
                    logger.info('lt = %.5f, iter_type = %s' %
                                (lipschitz_estimate, step_type))

                # get weight samples to evaluate expectations
                w_samples = qtw_new.sample([n_test_samples]).eval()
                ll_train_mean, ll_train_std = sess.run([ll_mean, ll_std],
                                                       feed_dict={
                                                           W: w_samples,
                                                           X: Xtrain,
                                                           y_data: ytrain
                                                       })
                logger.info("iter, %d, train ll, %.2f +/- %.2f" %
                            (t, ll_train_mean, ll_train_std))
                append_to_file(ll_train_filename,
                               "%f,%f" % (ll_train_mean, ll_train_std))

                ll_test_mean, ll_test_std, y_test_pred = sess.run(
                    [ll_mean, ll_std, ypred],
                    feed_dict={
                        W: w_samples,
                        X: Xtest,
                        y_data: ytest
                    })
                logger.info("iter, %d, test ll, %.2f +/- %.2f" %
                            (t, ll_test_mean, ll_test_std))
                append_to_file(ll_test_filename,
                               "%f,%f" % (ll_test_mean, ll_test_std))

                roc_score = roc_auc_score(ytest, y_test_pred)
                logger.info("iter %d, roc %.4f" % (t, roc_score))
                append_to_file(rocs_filename, roc_score)

                y_post = ed.copy(y, {w: qtw_new})
                # eq. to y = Bernoulli(logits=ed.dot(X, qtw_new))

                ed_train_ll = ed.evaluate('log_likelihood',
                                          data={
                                              X: Xtrain,
                                              y_post: ytrain,
                                          })
                ed_test_ll = ed.evaluate('log_likelihood',
                                         data={
                                             X: Xtest,
                                             y_post: ytest,
                                         })
                logger.info("edward train ll %.2f test ll %.2f" %
                            (ed_train_ll, ed_test_ll))

                bin_ac_train = ed.evaluate('binary_accuracy',
                                           data={
                                               X: Xtrain,
                                               y_post: ytrain,
                                           })
                bin_ac_test = ed.evaluate('binary_accuracy',
                                          data={
                                              X: Xtest,
                                              y_post: ytest,
                                          })
                append_to_file(bin_ac_filename,
                               "%f,%f" % (bin_ac_train, bin_ac_test))
                logger.info("edward binary accuracy train %.2f test %.2f" %
                            (bin_ac_train, bin_ac_test))

                mse_test = ed.evaluate('mean_squared_error',
                                       data={
                                           X: Xtest,
                                           y_post: ytest,
                                       })
                logger.info("edward mse test ll %.2f" % (mse_test))

            sess.close()
        tf.reset_default_graph()
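
# NOTE: `append_to_file` is used throughout these examples but its
# definition is not part of this dump. A minimal sketch of its assumed
# behavior (one value or preformatted string appended per line of a
# CSV-style log); the real implementation may differ:
def append_to_file(path, value):
    """Append `value` as a single line to the file at `path`."""
    with open(path, 'a') as f:
        f.write('%s\n' % value)
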
def main(argv):
    del argv

    outdir = FLAGS.outdir
    if '~' in outdir: outdir = os.path.expanduser(outdir)
    os.makedirs(outdir, exist_ok=True)

    # Files to log metrics
    times_filename = os.path.join(outdir, 'times.csv')
    elbos_filename = os.path.join(outdir, 'elbos.csv')
    objective_filename = os.path.join(outdir, 'kl.csv')
    reference_filename = os.path.join(outdir, 'ref_kl.csv')
    step_filename = os.path.join(outdir, 'steps.csv')
    # 'adafw', 'ada_afw', 'ada_pfw'
    if FLAGS.fw_variant.startswith('ada'):
        curvature_filename = os.path.join(outdir, 'curvature.csv')
        gap_filename = os.path.join(outdir, 'gap.csv')
        iter_info_filename = os.path.join(outdir, 'iter_info.txt')
    elif FLAGS.fw_variant == 'line_search':
        goutdir = os.path.join(outdir, 'gradients')

    # empty the files present in the folder already
    open(times_filename, 'w').close()
    open(elbos_filename, 'w').close()
    open(objective_filename, 'w').close()
    open(reference_filename, 'w').close()
    open(step_filename, 'w').close()
    # 'adafw', 'ada_afw', 'ada_pfw'
    if FLAGS.fw_variant.startswith('ada'):
        open(curvature_filename, 'w').close()
        append_to_file(curvature_filename, "c_local,c_global")
        open(gap_filename, 'w').close()
        open(iter_info_filename, 'w').close()
    elif FLAGS.fw_variant == 'line_search':
        os.makedirs(goutdir, exist_ok=True)

    for i in range(FLAGS.n_fw_iter):
        # NOTE: First iteration (t = 0) is initialization
        g = tf.Graph()
        with g.as_default():
            tf.set_random_seed(FLAGS.seed)
            sess = tf.InteractiveSession()
            with sess.as_default():
                p, mus, stds = create_target_dist()

                # current iterate (solution until now)
                if FLAGS.init == 'random':
                    # random initialization is disabled for this experiment
                    raise ValueError('FLAGS.init == random is not supported')
                    muq = np.random.randn(D).astype(np.float32)
                    stdq = softplus(np.random.randn(D).astype(np.float32))
                else:
                    muq = mus[0]
                    stdq = stds[0]

                # start from a single correct component (t = 1)
                t = 1
                comps = [{'loc': muq, 'scale_diag': stdq}]
                weights = [1.0]
                curvature_estimate = opt.adafw_linit()

                qtx = MultivariateNormalDiag(
                    loc=tf.convert_to_tensor(muq, dtype=tf.float32),
                    scale_diag=tf.convert_to_tensor(stdq, dtype=tf.float32))
                fw_iterates = {p: qtx}

                # calculate kl-div with 1 component
                objective_old = kl_divergence(qtx, p).eval()
                logger.info("kl with init %.4f" % (objective_old))
                append_to_file(reference_filename, objective_old)

                # s is the solution to LMO. It is initialized randomly
                # mu ~ N(0, 1), std ~ softplus(N(0, 1))
                s = coreutils.construct_multivariatenormaldiag([D], t, 's')

                sess.run(tf.global_variables_initializer())

                total_time = 0
                start_inference_time = time.time()
                if FLAGS.LMO == 'vi':
                    # disabled: this would require iterating over the
                    # full parameter space
                    raise ValueError('LMO == vi is not supported here')
                    inference = relbo.KLqp({p: s},
                                           fw_iterates=fw_iterates,
                                           fw_iter=t)
                    inference.run(n_iter=FLAGS.LMO_iter)
                # s now contains solution to LMO
                end_inference_time = time.time()

                mu_s = s.mean().eval()
                cov_s = s.stddev().eval()  # diagonal scale, not a full covariance

                # NOTE: keep only step size time
                #total_time += end_inference_time - start_inference_time

                # compute step size to update the next iterate
                step_result = {}
                if FLAGS.fw_variant == 'fixed':
                    gamma = 2. / (t + 2.)
                elif FLAGS.fw_variant == 'line_search':
                    start_line_search_time = time.time()
                    step_result = opt.line_search_dkl(
                        weights, [c['loc'] for c in comps],
                        [c['scale_diag'] for c in comps],
                        qtx, mu_s, cov_s, s, p, t)
                    end_line_search_time = time.time()
                    total_time += (end_line_search_time -
                                   start_line_search_time)
                    gamma = step_result['gamma']
                elif FLAGS.fw_variant == 'adafw':
                    start_adafw_time = time.time()
                    step_result = opt.adaptive_fw(
                        weights, [c['loc'] for c in comps],
                        [c['scale_diag'] for c in comps], qtx, mu_s, cov_s, s,
                        p, t, curvature_estimate)
                    end_adafw_time = time.time()
                    total_time += end_adafw_time - start_adafw_time
                    gamma = step_result['gamma']
                else:
                    raise NotImplementedError

                comps.append({'loc': mu_s, 'scale_diag': cov_s})
                weights = [(1. - gamma), gamma]

                c_global = estimate_global_curvature(comps, qtx)

                q_latest = Mixture(
                    cat=Categorical(probs=tf.convert_to_tensor(weights)),
                    components=[MultivariateNormalDiag(**c) for c in comps])

                # Log metrics for current iteration
                time_t = float(total_time)
                logger.info('total time %f' % (time_t))
                append_to_file(times_filename, time_t)

                elbo_t = elbo(q_latest, p, n_samples=1000)
                logger.info("iter, %d, elbo, %.2f +/- %.2f" %
                            (t, elbo_t[0], elbo_t[1]))
                append_to_file(elbos_filename,
                               "%f,%f" % (elbo_t[0], elbo_t[1]))

                logger.info('iter %d, gamma %.4f' % (t, gamma))
                append_to_file(step_filename, gamma)

                objective_t = kl_divergence(q_latest, p).eval()
                logger.info("run %d, kl %.4f" % (i, objective_t))
                append_to_file(objective_filename, objective_t)

                if FLAGS.fw_variant.startswith('ada'):
                    curvature_estimate = step_result['c_estimate']
                    append_to_file(gap_filename, step_result['gap'])
                    append_to_file(iter_info_filename,
                                   step_result['step_type'])
                    logger.info('gap = %.3f, ct = %.5f, iter_type = %s' %
                                (step_result['gap'], step_result['c_estimate'],
                                 step_result['step_type']))
                    append_to_file(curvature_filename,
                                   '%f,%f' % (curvature_estimate, c_global))
                elif FLAGS.fw_variant == 'line_search':
                    n_line_search_samples = step_result['n_samples']
                    grad_t = step_result['grad_gamma']
                    g_outfile = os.path.join(
                        goutdir, 'line_search_samples_%d.npy.%d' %
                        (n_line_search_samples, t))
                    logger.info('saving line search data to, %s' % g_outfile)
                    np.save(open(g_outfile, 'wb'), grad_t)

            sess.close()

        tf.reset_default_graph()
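
# NOTE: `create_target_dist` is not shown in this dump. A minimal sketch
# of its assumed behavior: build a fixed mixture-of-Gaussians target p and
# return it together with its component parameters. The number of
# components, their values, and the module-level D are assumptions.
import numpy as np
import tensorflow as tf
from edward.models import Categorical, Mixture, MultivariateNormalDiag

D = 2  # dimensionality; in the original this likely comes from FLAGS

def create_target_dist():
    """Return a two-component Gaussian-mixture target and its parameters."""
    pi = np.array([0.5, 0.5], dtype=np.float32)
    mus = [np.zeros(D, np.float32), 2. * np.ones(D, np.float32)]
    stds = [np.ones(D, np.float32), 0.5 * np.ones(D, np.float32)]
    pcomps = [
        MultivariateNormalDiag(loc=tf.convert_to_tensor(mus[i]),
                               scale_diag=tf.convert_to_tensor(stds[i]))
        for i in range(len(mus))
    ]
    p = Mixture(cat=Categorical(probs=tf.convert_to_tensor(pi)),
                components=pcomps)
    return p, mus, stds
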
def main(_):
    # true latent factors
    U_true = np.random.randn(FLAGS.D, FLAGS.N)
    V_true = np.random.randn(FLAGS.D, FLAGS.M)

    ## DATA
    #R_true = build_toy_dataset(U_true, V_true, FLAGS.N, FLAGS.M)
    #I_train = get_indicators(FLAGS.N, FLAGS.M)
    #I_test = 1 - I_train
    #N = FLAGS.N
    #M = FLAGS.M

    #tr = sio.loadmat(os.path.expanduser("~/data/bbbvi/trainData1.mat"))['X']
    #te = sio.loadmat(os.path.expanduser("~/data/bbbvi/testData1.mat"))['X']
    #tr = tr[:,:100]
    #te = te[:,:100]
    #I_train = tr != 0
    #I_test = te != 0
    #R_true = (tr + te).astype(np.float32)
    #N,M = R_true.shape

    tr = sio.loadmat(os.path.expanduser("~/data/bbbvi/cbcl.mat"))['V']
    #I_train = np.ones(tr.shape)
    #I_test = np.ones(tr.shape)
    # the train/test split comes from the indicator masks below
    R_true = tr
    N, M = tr.shape
    D = FLAGS.D
    I_train = get_indicators(N, M, FLAGS.mask_ratio)
    I_test = 1 - I_train

    it_best = 0
    weights, qUVt_components, mses = [], [], []
    test_mses, test_lls = [], []
    for iter in range(FLAGS.n_fw_iter):
        print("iter", iter)
        g = tf.Graph()
        with g.as_default():
            tf.set_random_seed(FLAGS.seed)
            sess = tf.InteractiveSession()
            with sess.as_default():
                # MODEL
                I = tf.placeholder(tf.float32, [N, M])

                scale_uv = tf.concat(
                    [tf.ones([FLAGS.D, N]),
                     tf.ones([FLAGS.D, M])], axis=1)
                mean_uv = tf.concat(
                    [tf.zeros([FLAGS.D, N]),
                     tf.zeros([FLAGS.D, M])], axis=1)

                UV = Normal(loc=mean_uv, scale=scale_uv)
                R = Normal(loc=tf.matmul(tf.transpose(UV[:, :N]), UV[:, N:]) *
                           I,
                           scale=tf.ones([N, M]))
                mean_quv = tf.concat(
                    [tf.get_variable("qU/loc", [FLAGS.D, N]),
                     tf.get_variable("qV/loc", [FLAGS.D, M])], axis=1)
                scale_quv = tf.concat(
                    [tf.nn.softplus(tf.get_variable("qU/scale", [FLAGS.D, N])),
                     tf.nn.softplus(tf.get_variable("qV/scale", [FLAGS.D, M]))],
                    axis=1)

                qUV = Normal(loc=mean_quv, scale=scale_quv)

                inference = relbo.KLqp({UV: qUV},
                                       data={
                                           R: R_true,
                                           I: I_train
                                       },
                                       fw_iterates=get_fw_iterates(
                                           iter, weights, UV, qUVt_components),
                                       fw_iter=iter)
                inference.run(n_iter=100)

                gamma = 2. / (iter + 2.)
                weights = [(1. - gamma) * w for w in weights]
                weights.append(gamma)

                qUVt_components = update_iterate(qUVt_components, qUV)

                qUV_new = build_mixture(weights, qUVt_components)
                qR = Normal(loc=tf.matmul(tf.transpose(qUV_new[:, :N]),
                                          qUV_new[:, N:]),
                            scale=tf.ones([N, M]))

                # CRITICISM
                test_mse = ed.evaluate('mean_squared_error',
                                       data={
                                           qR: R_true,
                                           I: I_test.astype(bool)
                                       })
                test_mses.append(test_mse)
                print('test mse', test_mse)

                test_ll = ed.evaluate('log_lik',
                                      data={
                                          qR: R_true.astype('float32'),
                                          I: I_test.astype(bool)
                                      })
                test_lls.append(test_ll)
                print('test_ll', test_ll)

                np.savetxt(os.path.join(FLAGS.outdir, 'test_mse.csv'),
                           test_mses,
                           delimiter=',')
                np.savetxt(os.path.join(FLAGS.outdir, 'test_ll.csv'),
                           test_lls,
                           delimiter=',')
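
# NOTE: `get_fw_iterates`, `update_iterate`, and `build_mixture` are used
# above but not defined in this dump. Minimal sketches of their assumed
# behavior for vector-shaped components (the matrix-variate case above has
# the Mixture-over-batch caveats noted in a later example). The signatures
# follow the matrix-factorization usage; other examples pass slightly
# different arguments:
import tensorflow as tf
from edward.models import Categorical, Mixture, Normal

def update_iterate(components, q):
    """Append the newest atom's parameters to the component list."""
    return components + [{'loc': q.mean().eval(),
                          'scale': q.stddev().eval()}]

def build_mixture(weights, components):
    """Rebuild the mixture iterate from stored (loc, scale) dicts."""
    comps = [Normal(loc=tf.convert_to_tensor(c['loc']),
                    scale=tf.convert_to_tensor(c['scale']))
             for c in components]
    cat = Categorical(probs=tf.convert_to_tensor(weights, dtype=tf.float32))
    return Mixture(cat=cat, components=comps)

def get_fw_iterates(t, weights, rv, components):
    """Map the model variable to the current iterate; empty at t = 0."""
    if t == 0:
        return {}
    return {rv: build_mixture(weights, components)}
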
def main(_):
    outdir = setup_outdir()
    ed.set_seed(FLAGS.seed)

    ((Xtrain, ytrain), (Xtest, ytest)) = blr_utils.get_data()
    N, D = Xtrain.shape
    N_test, D_test = Xtest.shape

    print("Xtrain")
    print(Xtrain)
    print(Xtrain.shape)

    if 'synthetic' in FLAGS.exp:
        w = Normal(loc=tf.zeros(D), scale=1.0 * tf.ones(D))
        X = tf.placeholder(tf.float32, [N, D])
        y = Bernoulli(logits=ed.dot(X, w))

        #n_posterior_samples = 100000
        n_posterior_samples = 10
        qw_empirical = Empirical(
            params=tf.get_variable("qw/params", [n_posterior_samples, D]))
        inference = ed.HMC({w: qw_empirical}, data={X: Xtrain, y: ytrain})
        inference.initialize(n_print=10, step_size=0.6)

        tf.global_variables_initializer().run()
        inference.run()

        empirical_samples = qw_empirical.sample(50).eval()
        #fig, ax = plt.subplots()
        #ax.scatter(posterior_samples[:,0], posterior_samples[:,1])
        #plt.show()

    weights, q_components = [], []
    ll_trains, ll_tests, bin_ac_trains, bin_ac_tests = [], [], [], []
    elbos, rocs, gaps = [], [], []
    total_time, times = 0., []
    for iter in range(0, FLAGS.n_fw_iter):
        print("iter %d" % iter)
        g = tf.Graph()
        with g.as_default():
            sess = tf.InteractiveSession()
            with sess.as_default():
                tf.set_random_seed(FLAGS.seed)
                # MODEL
                w = Normal(loc=tf.zeros(D), scale=1.0 * tf.ones(D))

                X = tf.placeholder(tf.float32, [N, D])
                y = Bernoulli(logits=ed.dot(X, w))

                X_test = tf.placeholder(tf.float32, [N_test, D_test])
                y_test = Bernoulli(logits=ed.dot(X_test, w))

                qw = construct_base_dist([D], iter, 'qw')
                inference_time_start = time.time()
                inference = relbo.KLqp({w: qw},
                                       fw_iterates=get_fw_iterates(
                                           weights, w, q_components),
                                       data={
                                           X: Xtrain,
                                           y: ytrain
                                       },
                                       fw_iter=iter)
                tf.global_variables_initializer().run()
                inference.run(n_iter=FLAGS.LMO_iter)
                inference_time_end = time.time()
                total_time += float(inference_time_end - inference_time_start)

                joint = Joint(Xtrain, ytrain, sess)
                if iter > 0:
                    qtw_prev = build_mixture(weights, q_components)
                    gap = compute_duality_gap(joint, qtw_prev, qw)
                    gaps.append(gap)
                    np.savetxt(os.path.join(outdir, "gaps.csv"),
                               gaps,
                               delimiter=',')
                    print("duality gap", gap)

                # update weights
                gamma = 2. / (iter + 2.)
                weights = [(1. - gamma) * w for w in weights]
                weights.append(gamma)

                # update components
                q_components = update_iterate(q_components, qw)

                if len(q_components) > 1 and FLAGS.fw_variant == 'fc':
                    print("running fully corrective")
                    # overwrite the weights
                    weights = fully_corrective(
                        build_mixture(weights, q_components), joint)

                    # remove inactive iterates
                    weights = list(weights)
                    for i in reversed(range(len(weights))):
                        if weights[i] == 0:
                            del weights[i]
                            del q_components[i]
                    # TODO type acrobatics to make elements deletable
                    weights = np.array(weights)
                elif len(q_components) > 1 and FLAGS.fw_variant == 'line_search':
                    print("running line search")
                    weights = line_search(
                        build_mixture(weights[:-1], q_components[:-1]), qw,
                        joint)

                qtw_new = build_mixture(weights, q_components)

                if False:
                    for i, comp in enumerate(qtw_new.components):
                        print("component", i, "\tmean",
                              comp.mean().eval(), "\tstddev",
                              comp.stddev().eval())

                train_lls = [
                    sess.run(y.log_prob(ytrain),
                             feed_dict={
                                 X: Xtrain,
                                 w: qtw_new.sample().eval()
                             }) for _ in range(50)
                ]
                train_lls = np.mean(train_lls, axis=0)
                ll_trains.append((np.mean(train_lls), np.std(train_lls)))

                test_lls = [
                    sess.run(y_test.log_prob(ytest),
                             feed_dict={
                                 X_test: Xtest,
                                 w: qtw_new.sample().eval()
                             }) for _ in range(50)
                ]
                test_lls = np.mean(test_lls, axis=0)
                ll_tests.append((np.mean(test_lls), np.std(test_lls)))

                logits = np.mean(
                    [np.dot(Xtest, qtw_new.sample().eval()) for _ in range(50)],
                    axis=0)
                ypred = tf.sigmoid(logits).eval()
                roc_score = roc_auc_score(ytest, ypred)
                rocs.append(roc_score)

                print('roc_score', roc_score)
                print('ytrain', np.mean(train_lls), np.std(train_lls))
                print('ytest', np.mean(test_lls), np.std(test_lls))

                order = np.argsort(ytest)
                plt.scatter(range(len(ypred)), ypred[order], c=ytest[order])
                plt.savefig(os.path.join(outdir, 'ypred%d.pdf' % iter))
                plt.close()

                np.savetxt(os.path.join(outdir, "train_lls.csv"),
                           ll_trains,
                           delimiter=',')
                np.savetxt(os.path.join(outdir, "test_lls.csv"),
                           ll_tests,
                           delimiter=',')
                np.savetxt(os.path.join(outdir, "rocs.csv"),
                           rocs,
                           delimiter=',')

                x_post = ed.copy(y, {w: qtw_new})
                x_post_t = ed.copy(y_test, {w: qtw_new})

                print(
                    'log lik train',
                    ed.evaluate('log_likelihood',
                                data={
                                    x_post: ytrain,
                                    X: Xtrain
                                }))
                print(
                    'log lik test',
                    ed.evaluate('log_likelihood',
                                data={
                                    x_post_t: ytest,
                                    X_test: Xtest
                                }))

                #ll_train = ed.evaluate('log_likelihood', data={x_post: ytrain, X:Xtrain})
                #ll_test = ed.evaluate('log_likelihood', data={x_post_t: ytest, X_test:Xtest})
                bin_ac_train = ed.evaluate('binary_accuracy',
                                           data={
                                               x_post: ytrain,
                                               X: Xtrain
                                           })
                bin_ac_test = ed.evaluate('binary_accuracy',
                                          data={
                                              x_post_t: ytest,
                                              X_test: Xtest
                                          })
                print('binary accuracy train', bin_ac_train)
                print('binary accuracy test', bin_ac_test)
                #latest_elbo = elbo(qtw_new, joint, w)

                #foo = ed.KLqp({w: qtw_new}, data={X: Xtrain, y: ytrain})
                #op = myloss(foo)
                #print("myloss", sess.run(op[0], feed_dict={X: Xtrain, y: ytrain}), sess.run(op[1], feed_dict={X: Xtrain, y: ytrain}))

                #append_and_save(ll_trains, ll_train, "loglik_train.csv", np.savetxt)
                #append_and_save(ll_tests, ll_test, "loglik_test.csv", np.savetxt)
                #append_and_save(bin_ac_trains, bin_ac_train, "bin_acc_train.csv", np.savetxt)
                #append_and_save(bin_ac_tests, bin_ac_test, "bin_acc_test.csv", np.savetxt)
                ##append_and_save(elbos, latest_elbo, "elbo.csv", np.savetxt)

                #print('log-likelihood train ', ll_train)
                #print('log-likelihood test ', ll_test)
                #print('binary_accuracy train ', bin_ac_train)
                #print('binary_accuracy test ', bin_ac_test)
                #print('elbo', latest_elbo)
                times.append(total_time)
                np.savetxt(os.path.join(setup_outdir(), 'times.csv'), times)

        tf.reset_default_graph()
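
# NOTE: `compute_duality_gap` above is not part of this dump. A minimal
# Monte Carlo sketch of its assumed behavior, mirroring the estimator in
# run_gap below: for f(q) = KL(q || p), the gradient evaluated at x is
# log q(x) - log p(x) (up to a constant that cancels in the difference),
# and the Frank-Wolfe gap is <grad f(q), q - s>. The `joint` object is
# assumed to expose log_prob:
import tensorflow as tf

def compute_duality_gap(p, q, s, n_samples=100):
    """Monte Carlo estimate of the Frank-Wolfe gap <grad f(q), q - s>."""
    sample_q = q.sample([n_samples])
    sample_s = s.sample([n_samples])
    grad_at = lambda x: q.log_prob(x) - p.log_prob(x)
    step_q = tf.reduce_mean(grad_at(sample_q))
    step_s = tf.reduce_mean(grad_at(sample_s))
    return (step_q - step_s).eval()
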
def run_gap(pi, mus, stds):
    weights, comps = [], []
    elbos = []
    relbo_vals = []
    for t in range(FLAGS.n_fw_iter):
        logger.info('Frank Wolfe Iteration %d' % t)
        g = tf.Graph()
        with g.as_default():
            tf.set_random_seed(FLAGS.seed)
            sess = tf.InteractiveSession()
            with sess.as_default():
                # target distribution components
                pcomps = [
                    MultivariateNormalDiag(
                        loc=tf.convert_to_tensor(mus[i], dtype=tf.float32),
                        scale_diag=tf.convert_to_tensor(stds[i],
                                                        dtype=tf.float32))
                    for i in range(len(mus))
                ]
                # target distribution
                p = Mixture(cat=Categorical(probs=tf.convert_to_tensor(pi)),
                            components=pcomps)

                # LMO approximation
                s = construct_normal([1], t, 's')
                fw_iterates = {}
                if t > 0:
                    qtx = Mixture(
                        cat=Categorical(probs=tf.convert_to_tensor(weights)),
                        components=[
                            MultivariateNormalDiag(**c) for c in comps
                        ])
                    fw_iterates = {p: qtx}
                sess.run(tf.global_variables_initializer())
                # Run inference on relbo to solve the LMO problem
                # NOTE: KLqp has a side effect: it modifies s in place
                inference = relbo.KLqp({p: s},
                                       fw_iterates=fw_iterates,
                                       fw_iter=t)
                inference.run(n_iter=FLAGS.LMO_iter)
                # s now contains solution to LMO

                if t > 0:
                    sample_s = s.sample([FLAGS.n_monte_carlo_samples])
                    sample_q = qtx.sample([FLAGS.n_monte_carlo_samples])
                    step_s = tf.reduce_mean(grad_kl(qtx, p, sample_s)).eval()
                    step_q = tf.reduce_mean(grad_kl(qtx, p, sample_q)).eval()
                    gap = step_q - step_s
                    logger.info('Frank-Wolfe gap at iter %d is %.5f' %
                                (t, gap))
                    if gap < 0:
                        eprint('Frank-Wolfe gap becoming negative!')
                    # f(q*) = f(p) = 0
                    logger.info('Objective value (actual gap) is %.5f' %
                                kl_divergence(qtx, p).eval())

                gamma = 2. / (t + 2.)
                comps.append({
                    'loc': s.mean().eval(),
                    'scale_diag': s.stddev().eval()
                })
                weights = coreutils.update_weights(weights, gamma, t)

        tf.reset_default_graph()
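
# NOTE: `update_weights` and `construct_normal` are used above but not
# defined in this dump. Minimal sketches of their assumed behavior: the
# classical Frank-Wolfe convex update of the atom weights, and a fresh
# diagonal-Gaussian atom with trainable parameters, mirroring the
# initialization described for construct_multivariatenormaldiag above
# (mu ~ N(0, 1), std = softplus(N(0, 1))):
import tensorflow as tf
from edward.models import Normal

def update_weights(weights, gamma, t):
    """Shrink existing weights by (1 - gamma); give gamma to the new atom."""
    if t == 0:
        return [1.]
    return [(1. - gamma) * w for w in weights] + [gamma]

def construct_normal(dims, t, name):
    """Fresh Normal atom with trainable loc and softplus scale."""
    loc = tf.get_variable('%s_loc_%d' % (name, t),
                          initializer=tf.random_normal(dims))
    scale = tf.nn.softplus(
        tf.get_variable('%s_scale_%d' % (name, t),
                        initializer=tf.random_normal(dims)))
    return Normal(loc=loc, scale=scale)
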
def main(argv):
    del argv

    x_train, components = build_toy_dataset(N)
    n_examples, n_features = x_train.shape

    # save the target
    outdir = setup_outdir()
    np.savez(os.path.join(outdir, 'target_dist.npz'),
             pi=pi,
             mus=mus,
             stds=stds)

    weights, comps = [], []
    elbos = []
    relbo_vals = []
    times = []
    for iter in range(FLAGS.n_fw_iter):
        g = tf.Graph()
        with g.as_default():
            tf.set_random_seed(FLAGS.seed)
            sess = tf.InteractiveSession()
            with sess.as_default():
                # build model
                xcomps = [
                    Normal(loc=tf.convert_to_tensor(mus[i]),
                           scale=tf.convert_to_tensor(stds[i]))
                    for i in range(len(mus))
                ]
                x = Mixture(cat=Categorical(probs=tf.convert_to_tensor(pi)),
                            components=xcomps,
                            sample_shape=N)

                qx = construct_normal([n_features], iter, 'qx')
                if iter > 0:
                    qtx = Mixture(
                        cat=Categorical(probs=tf.convert_to_tensor(weights)),
                        components=[
                            Normal(
                                loc=c['loc'][0],
                                #scale_diag=tf.nn.softplus(c['scale_diag'])) for c in comps], sample_shape=N)
                                scale=c['scale_diag'][0]) for c in comps
                        ],
                        sample_shape=N)
                    fw_iterates = {x: qtx}
                else:
                    fw_iterates = {}

                sess.run(tf.global_variables_initializer())

                total_time = 0
                start_inference_time = time.time()
                inference = relbo.KLqp({x: qx},
                                       fw_iterates=fw_iterates,
                                       fw_iter=iter)
                inference.run(n_iter=FLAGS.LMO_iter)
                end_inference_time = time.time()

                total_time += end_inference_time - start_inference_time

                if iter > 0:
                    relbo_vals.append(-utils.compute_relbo(
                        qx, fw_iterates[x], x, np.log(iter + 1)))

                if iter == 0:
                    gamma = 1.
                elif iter > 0 and FLAGS.fw_variant == 'fixed':
                    gamma = 2. / (iter + 2.)
                elif iter > 0 and FLAGS.fw_variant == 'line_search':
                    start_line_search_time = time.time()
                    gamma = line_search_dkl(weights, [c['loc'] for c in comps],
                                            [c['scale_diag'] for c in comps],
                                            qx.loc.eval(),
                                            qx.stddev().eval(), x, iter)
                    end_line_search_time = time.time()
                    total_time += end_line_search_time - start_line_search_time
                elif iter > 0 and FLAGS.fw_variant == 'fc':
                    gamma = 2. / (iter + 2.)

                comps.append({
                    'loc': qx.mean().eval(),
                    'scale_diag': qx.stddev().eval()
                })
                weights = utils.update_weights(weights, gamma, iter)

                print("weights", weights)
                print("comps", [c['loc'] for c in comps])
                print("scale_diags", [c['scale_diag'] for c in comps])

                q_latest = Mixture(
                    cat=Categorical(probs=tf.convert_to_tensor(weights)),
                    components=[MultivariateNormalDiag(**c) for c in comps],
                    sample_shape=N)

                if FLAGS.fw_variant == "fc":
                    start_fc_time = time.time()
                    weights = fully_corrective(q_latest, x)
                    weights = list(weights)
                    for i in reversed(range(len(weights))):
                        w = weights[i]
                        if w == 0:
                            del weights[i]
                            del comps[i]
                    weights = np.array(weights)
                    end_fc_time = time.time()
                    total_time += end_fc_time - start_fc_time

                q_latest = Mixture(
                    cat=Categorical(probs=tf.convert_to_tensor(weights)),
                    components=[MultivariateNormalDiag(**c) for c in comps],
                    sample_shape=N)

                elbos.append(elbo(q_latest, x))

                outdir = setup_outdir()

                print("total time", total_time)
                times.append(float(total_time))
                utils.save_times(os.path.join(outdir, 'times.csv'), times)

                elbos_filename = os.path.join(outdir, 'elbos.csv')
                logger.info("iter, %d, elbo, %.2f +/- %.2f" %
                            (iter, *elbos[-1]))
                np.savetxt(elbos_filename, elbos, delimiter=',')
                logger.info("saving elbos to, %s" % elbos_filename)

                relbos_filename = os.path.join(outdir, 'relbos.csv')
                np.savetxt(relbos_filename, relbo_vals, delimiter=',')
                logger.info("saving relbo values to, %s" % relbos_filename)

                for_serialization = {
                    'locs': np.array([c['loc'] for c in comps]),
                    'scale_diags': np.array([c['scale_diag'] for c in comps])
                }
                qt_outfile = os.path.join(outdir, 'qt_iter%d.npz' % iter)
                np.savez(qt_outfile, weights=weights, **for_serialization)
                np.savez(os.path.join(outdir, 'qt_latest.npz'),
                         weights=weights,
                         **for_serialization)
                logger.info("saving qt to, %s" % qt_outfile)
        tf.reset_default_graph()
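
# NOTE: the `elbo` estimator used above is not part of this dump. A
# minimal Monte Carlo sketch of its assumed behavior,
# ELBO(q) = E_q[log p(x) - log q(x)], returning a (mean, std) pair as the
# logging above expects, with a return_std switch matching the
# logistic-regression example:
import tensorflow as tf

def elbo(q, p, n_samples=1000, return_std=True):
    """Monte Carlo estimate of E_q[log p - log q]."""
    samples = q.sample([n_samples])
    values = p.log_prob(samples) - q.log_prob(samples)
    mean, var = tf.nn.moments(values, axes=[0])
    if return_std:
        return mean.eval(), tf.sqrt(var).eval()
    return mean.eval()
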
def main(_):
    # setting up output directory
    outdir = os.path.expanduser(FLAGS.outdir)
    os.makedirs(outdir, exist_ok=True)

    N, M, D, R_true, I_train, I_test = get_data()
    debug('N, M, D', N, M, D)

    # Solution components
    weights, qUVt_components = [], []

    # Files to log metrics
    times_filename = os.path.join(outdir, 'times.csv')
    mse_train_filename = os.path.join(outdir, 'mse_train.csv')
    mse_test_filename = os.path.join(outdir, 'mse_test.csv')
    ll_test_filename = os.path.join(outdir, 'll_test.csv')
    ll_train_filename = os.path.join(outdir, 'll_train.csv')
    elbos_filename = os.path.join(outdir, 'elbos.csv')
    gap_filename = os.path.join(outdir, 'gap.csv')
    step_filename = os.path.join(outdir, 'steps.csv')
    # 'adafw', 'ada_afw', 'ada_pfw'
    if FLAGS.fw_variant.startswith('ada'):
        lipschitz_filename = os.path.join(outdir, 'lipschitz.csv')
        iter_info_filename = os.path.join(outdir, 'iter_info.txt')

    start = 0
    if FLAGS.restore:
        #start = 50
        #qUVt_components = get_random_components(D, N, M, start)
        #weights = np.random.dirichlet([1.] * start).astype(np.float32)
        #lipschitz_estimate = opt.adafw_linit()
        parameters = np.load(os.path.join(outdir, 'qt_latest.npz'))
        weights = list(parameters['weights'])
        start = parameters['fw_iter']
        qUVt_components = list(parameters['comps'])
        assert len(weights) == len(qUVt_components), "Inconsistent storage"
        # get lipschitz estimate from the file, could've stored it
        # in params but that would mean different saved file for
        # adaptive variants
        if FLAGS.fw_variant.startswith('ada'):
            lipschitz_filename = os.path.join(outdir, 'lipschitz.csv')
            if not os.path.isfile(lipschitz_filename):
                raise ValueError("Inconsistent storage")
            with open(lipschitz_filename, 'r') as f:
                l = f.readlines()
                lipschitz_estimate = float(l[-1].strip())
    else:
        # empty the files present in the folder already
        open(times_filename, 'w').close()
        open(mse_train_filename, 'w').close()
        open(mse_test_filename, 'w').close()
        open(ll_test_filename, 'w').close()
        open(ll_train_filename, 'w').close()
        open(elbos_filename, 'w').close()
        open(gap_filename, 'w').close()
        open(step_filename, 'w').close()
        # 'adafw', 'ada_afw', 'ada_pfw'
        if FLAGS.fw_variant.startswith('ada'):
            open(lipschitz_filename, 'w').close()
            open(iter_info_filename, 'w').close()

    for t in range(start, start + FLAGS.n_fw_iter):
        g = tf.Graph()
        with g.as_default():
            tf.set_random_seed(FLAGS.seed)
            sess = tf.InteractiveSession()
            with sess.as_default():
                # MODEL
                I = tf.placeholder(tf.float32, [N, M])

                scale_uv = tf.concat(
                    [tf.ones([D, N]), tf.ones([D, M])], axis=1)
                mean_uv = tf.concat(
                    [tf.zeros([D, N]), tf.zeros([D, M])], axis=1)

                UV = Normal(loc=mean_uv, scale=scale_uv)
                R = Normal(loc=tf.matmul(tf.transpose(UV[:, :N]), UV[:, N:]),
                           scale=tf.ones([N, M]))  # generator dist. for matrix
                R_mask = R * I  # generated masked matrix

                p_joint = Joint(R_true, I_train, sess, D, N, M)

                if t == 0:
                    fw_iterates = {}
                else:
                    # Current solution
                    prev_components = [
                        coreutils.base_loc_scale('mvn0',
                                                 c['loc'],
                                                 c['scale'],
                                                 multivariate=False)
                        for c in qUVt_components
                    ]
                    qUV_prev = coreutils.get_mixture(weights, prev_components)
                    fw_iterates = {UV: qUV_prev}

                # LMO (via relbo INFERENCE)
                mean_suv = tf.concat(
                    [tf.get_variable("qU/loc", [D, N]),
                     tf.get_variable("qV/loc", [D, M])], axis=1)
                scale_suv = tf.concat(
                    [tf.nn.softplus(tf.get_variable("qU/scale", [D, N])),
                     tf.nn.softplus(tf.get_variable("qV/scale", [D, M]))],
                    axis=1)

                sUV = Normal(loc=mean_suv, scale=scale_suv)

                #inference = relbo.KLqp({UV: sUV}, data={R: R_true, I: I_train},
                inference = relbo.KLqp({UV: sUV},
                                       data={
                                           R_mask: R_true,
                                           I: I_train
                                       },
                                       fw_iterates=fw_iterates,
                                       fw_iter=t)
                inference.run(n_iter=FLAGS.LMO_iter)

                loc_s = sUV.mean().eval()
                scale_s = sUV.stddev().eval()
                # sUV is a batched distribution; building a Mixture out of
                # batch distributions is problematic, so use 'mvn0' with
                # event shape (D, N + M) and empty batch shape.
                # NOTE: log_prob(sample) still returns a tensor.
                # mvn and multivariatenormaldiag handle 1-D, not 2-D, shapes.
                sUV_mv = coreutils.base_loc_scale('mvn0',
                                                  loc_s,
                                                  scale_s,
                                                  multivariate=False)
                # TODO send sUV or sUV_mv as argument to step size? sample
                # works the same way. same with log_prob

                total_time = 0.
                data = {R: R_true, I: I_train}
                if t == 0:
                    gamma = 1.
                    lipschitz_estimate = opt.adafw_linit()
                    step_type = 'init'
                elif FLAGS.fw_variant == 'fixed':
                    start_step_time = time.time()
                    step_result = opt.fixed(weights, qUVt_components, qUV_prev,
                                            loc_s, scale_s, sUV, p_joint, data,
                                            t)
                    end_step_time = time.time()
                    total_time += float(end_step_time - start_step_time)
                elif FLAGS.fw_variant == 'line_search':
                    start_step_time = time.time()
                    step_result = opt.line_search_dkl(weights, qUVt_components,
                                                      qUV_prev, loc_s, scale_s,
                                                      sUV, p_joint, data, t)
                    end_step_time = time.time()
                    total_time += float(end_step_time - start_step_time)
                elif FLAGS.fw_variant == 'adafw':
                    start_step_time = time.time()
                    step_result = opt.adaptive_fw(weights, qUVt_components,
                                                  qUV_prev, loc_s, scale_s,
                                                  sUV, p_joint, data, t,
                                                  lipschitz_estimate)
                    end_step_time = time.time()
                    total_time += float(end_step_time - start_step_time)

                    step_type = step_result['step_type']
                    if step_type == 'adaptive':
                        lipschitz_estimate = step_result['l_estimate']
                elif FLAGS.fw_variant == 'ada_pfw':
                    start_step_time = time.time()
                    step_result = opt.adaptive_pfw(weights, qUVt_components,
                                                   qUV_prev, loc_s, scale_s,
                                                   sUV, p_joint, data, t,
                                                   lipschitz_estimate)
                    end_step_time = time.time()
                    total_time += float(end_step_time - start_step_time)

                    step_type = step_result['step_type']
                    if step_type in ['adaptive', 'drop']:
                        lipschitz_estimate = step_result['l_estimate']
                elif FLAGS.fw_variant == 'ada_afw':
                    start_step_time = time.time()
                    step_result = opt.adaptive_afw(weights, qUVt_components,
                                                   qUV_prev, loc_s, scale_s,
                                                   sUV, p_joint, data, t,
                                                   lipschitz_estimate)
                    end_step_time = time.time()
                    total_time += float(end_step_time - start_step_time)

                    step_type = step_result['step_type']
                    if step_type in ['adaptive', 'away', 'drop']:
                        lipschitz_estimate = step_result['l_estimate']

                if t == 0:
                    gamma = 1.
                    weights.append(gamma)
                    qUVt_components.append({'loc': loc_s, 'scale': scale_s})
                    new_components = [sUV_mv]
                else:
                    qUVt_components = step_result['params']
                    weights = step_result['weights']
                    gamma = step_result['gamma']
                    new_components = [
                        coreutils.base_loc_scale('mvn0',
                                                 c['loc'],
                                                 c['scale'],
                                                 multivariate=False)
                        for c in qUVt_components
                    ]

                qUV_new = coreutils.get_mixture(weights, new_components)

                #qR = Normal(
                #    loc=tf.matmul(
                #        tf.transpose(qUV_new[:, :N]), qUV_new[:, N:]),
                #    scale=tf.ones([N, M]))
                qR = ed.copy(R, {UV: qUV_new})
                cR = ed.copy(R_mask, {UV: qUV_new})  # reconstructed matrix

                # Log metrics for current iteration
                logger.info('total time %f' % total_time)
                append_to_file(times_filename, total_time)

                logger.info('iter %d, gamma %.4f' % (t, gamma))
                append_to_file(step_filename, gamma)

                if t > 0:
                    gap_t = step_result['gap']
                    logger.info('iter %d, gap %.4f' % (t, gap_t))
                    append_to_file(gap_filename, gap_t)

                # CRITICISM
                if FLAGS.fw_variant.startswith('ada'):
                    append_to_file(lipschitz_filename, lipschitz_estimate)
                    append_to_file(iter_info_filename, step_type)
                    logger.info('lt = %.5f, iter_type = %s' %
                                (lipschitz_estimate, step_type))

                test_mse = ed.evaluate('mean_squared_error',
                                       data={
                                           cR: R_true,
                                           I: I_test
                                       })
                logger.info("iter %d ed test mse %.5f" % (t, test_mse))
                append_to_file(mse_test_filename, test_mse)

                train_mse = ed.evaluate('mean_squared_error',
                                        data={
                                            cR: R_true,
                                            I: I_train
                                        })
                logger.info("iter %d ed train mse %.5f" % (t, train_mse))
                append_to_file(mse_train_filename, train_mse)

                # very slow
                #train_ll = log_likelihood(qUV_new, R_true, I_train, sess, D, N,
                #                          M)
                train_ll = ed.evaluate('log_lik',
                                       data={
                                           qR: R_true.astype(np.float32),
                                           I: I_train
                                       })
                logger.info("iter %d train log lik %.5f" % (t, train_ll))
                append_to_file(ll_train_filename, train_ll)

                #test_ll = log_likelihood(qUV_new, R_true, I_test, sess, D, N, M)
                test_ll = ed.evaluate('log_lik',
                                      data={
                                          qR: R_true.astype(np.float32),
                                          I: I_test
                                      })
                logger.info("iter %d test log lik %.5f" % (t, test_ll))
                append_to_file(ll_test_filename, test_ll)

                # NOTE: elbo_loss from KLqp may be meaningless here; log it
                # alongside our own estimate elbo_t for comparison
                elbo_loss = elboModel.KLqp({UV: qUV_new},
                                           data={
                                               R: R_true,
                                               I: I_train
                                           })
                elbo_t = elbo(qUV_new, p_joint)
                res_update = elbo_loss.run()
                logger.info('iter %d -elbo loss %.2f or %.2f' %
                            (t, res_update['loss'], elbo_t))
                append_to_file(elbos_filename,
                               "%f,%f" % (elbo_t, res_update['loss']))

                # serialize the current iterate
                np.savez(os.path.join(outdir, 'qt_latest.npz'),
                         weights=weights,
                         comps=qUVt_components,
                         fw_iter=t + 1)

                sess.close()
        tf.reset_default_graph()
    def run(self, outdir, pi, mus, stds, n_features):
        """Run Boosted BBVI.

        Args:
            outdir: output directory
            pi: weights of the target mixture
            mus: means of the target mixture components
            stds: scales of the target mixture components
            n_features: dimensionality

        Returns:
            None. Runs FLAGS.n_fw_iter iterations of Frank-Wolfe and
            logs the relevant metrics.
        """

        # comps: component atoms of boosting (contains a dict of params)
        # weights: weights given to every atom over comps
        # Together S = {weights, comps} make the active set
        weights, comps = [], []
        # running estimate of the gradient's Lipschitz constant L
        lipschitz_estimate = None

        #debug('target', mus, stds)
        start = 0
        if FLAGS.restore:
            # start from iteration 1: seed the active set with the first
            # target component as an exact LMO solution
            start = 1
            comps.append({'loc': mus[0], 'scale_diag': stds[0]})
            weights.append(1.0)
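            # adafw_linit presumably returns the FLAGS-configured initial
            # Lipschitz estimate; the s and p arguments are unused here.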
            lipschitz_estimate = opt.adafw_linit(None, None)

        # Metrics to log
        times_filename = os.path.join(outdir, 'times.csv')
        open(times_filename, 'w').close() # truncate the file if exists

        elbos_filename = os.path.join(outdir, 'elbos.csv')
        logger.info("saving elbos to, %s" % elbos_filename)
        open(elbos_filename, 'w').close()

        relbos_filename = os.path.join(outdir, 'relbos.csv')
        logger.info('saving relbos to, %s' % relbos_filename)
        open(relbos_filename, 'w').close()

        objective_filename = os.path.join(outdir, 'kl.csv')
        logger.info("saving kl divergence to, %s" % objective_filename)
        if not FLAGS.restore:
            open(objective_filename, 'w').close()

        step_filename = os.path.join(outdir, 'steps.csv')
        logger.info("saving gamma values to, %s" % step_filename)
        if not FLAGS.restore:
            open(step_filename, 'w').close()

        # 'adafw', 'ada_afw', 'ada_pfw'
        if FLAGS.fw_variant.startswith('ada'):
            lipschitz_filename = os.path.join(outdir, 'lipschitz.csv')
            open(lipschitz_filename, 'w').close()

            gap_filename = os.path.join(outdir, 'gap.csv')
            open(gap_filename, 'w').close()

            iter_info_filename = os.path.join(outdir, 'iter_info.txt')
            open(iter_info_filename, 'w').close()
        elif FLAGS.fw_variant == 'line_search':
            goutdir = os.path.join(outdir, 'gradients')
            os.makedirs(goutdir, exist_ok=True)

        for t in range(start, start + FLAGS.n_fw_iter):
            # NOTE: First iteration (t = 0) is initialization
            g = tf.Graph()
            with g.as_default():
                tf.set_random_seed(FLAGS.seed)
                sess = tf.InteractiveSession()
                with sess.as_default():
                    # build target distribution
                    p = self.target_dist(pi=pi, mus=mus, stds=stds)

                    if t == 0:
                        fw_iterates = {}
                    else:
                        # current iterate (solution until now)
                        qtx = Mixture(
                            cat=Categorical(
                                probs=tf.convert_to_tensor(weights)),
                            components=[
                                MultivariateNormalDiag(**c) for c in comps
                            ])
                        fw_iterates = {p: qtx}

                    # s is the solution to LMO. It is initialized randomly
                    #s = coreutils.construct_normal([n_features], t, 's')
                    s = coreutils.construct_multivariatenormaldiag(
                        [n_features], t, 's')

                    sess.run(tf.global_variables_initializer())

                    total_time = 0
                    start_inference_time = time.time()
                    # Run inference on the RELBO to solve the LMO problem.
                    # If the mixture is initialized randomly, the first
                    # component is a random distribution and no inference
                    # is needed for it.
                    # NOTE: KLqp has a side effect: it modifies s in place.
                    #if FLAGS.iter0 == 'vi' or t > 0:
                    if FLAGS.iter0 == 'vi':
                        inference = relbo.KLqp(
                            {
                                p: s
                            }, fw_iterates=fw_iterates, fw_iter=t)
                        inference.run(n_iter=FLAGS.LMO_iter)
                    # s now contains solution to LMO
                    end_inference_time = time.time()

                    mu_s = s.mean().eval()
                    cov_s = s.stddev().eval()
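                    # NOTE: despite its name, cov_s holds the diagonal
                    # standard deviations (scale_diag), not a covariance.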
                    #debug('LMO', mu_s, cov_s)

                    # NOTE: keep only step size time
                    #total_time += end_inference_time - start_inference_time

                    # compute step size to update the next iterate
                    step_result = {}
                    if t == 0:
                        gamma = 1.
                        if FLAGS.fw_variant.startswith('ada'):
                            lipschitz_estimate = opt.adafw_linit(s, p)
                    elif FLAGS.fw_variant == 'fixed':
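                        # standard Frank-Wolfe step-size schedule,
                        # gamma_t = 2 / (t + 2)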
                        gamma = 2. / (t + 2.)
                    elif FLAGS.fw_variant == 'line_search':
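                        # line_search_dkl presumably minimizes the KL
                        # objective over gamma in [0, 1] with Monte Carlo
                        # gradient estimates; its 'n_samples' and
                        # 'grad_gamma' outputs are logged further below.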
                        start_line_search_time = time.time()
                        step_result = opt.line_search_dkl(
                            weights, [c['loc'] for c in comps],
                            [c['scale_diag'] for c in comps], qtx, mu_s, cov_s,
                            s, p, t)
                        end_line_search_time = time.time()
                        total_time += (
                            end_line_search_time - start_line_search_time)
                        gamma = step_result['gamma']
                    elif FLAGS.fw_variant == 'fc':
                        # fully corrective variant: take the standard fixed
                        # step now; weights are re-optimized in the
                        # correction block below
                        gamma = 2. / (t + 2.)
                    elif FLAGS.fw_variant == 'adafw':
                        start_adafw_time = time.time()
                        step_result = opt.adaptive_fw(
                            weights, [c['loc'] for c in comps],
                            [c['scale_diag'] for c in comps], qtx, mu_s, cov_s,
                            s, p, t, lipschitz_estimate)
                        end_adafw_time = time.time()
                        total_time += end_adafw_time - start_adafw_time
                        gamma = step_result['gamma']
                    elif FLAGS.fw_variant == 'ada_afw':
                        start_adaafw_time = time.time()
                        step_result = opt.adaptive_afw(
                            weights, comps, [c['loc'] for c in comps],
                            [c['scale_diag'] for c in comps], qtx, mu_s, cov_s,
                            s, p, t, lipschitz_estimate)
                        end_adaafw_time = time.time()
                        total_time += end_adaafw_time - start_adaafw_time
                        gamma = step_result['gamma'] # just for logging
                    elif FLAGS.fw_variant == 'ada_pfw':
                        start_adapfw_time = time.time()
                        step_result = opt.adaptive_pfw(
                            weights, comps, [c['loc'] for c in comps],
                            [c['scale_diag'] for c in comps], qtx, mu_s, cov_s,
                            s, p, t, lipschitz_estimate)
                        end_adapfw_time = time.time()
                        total_time += end_adapfw_time - start_adapfw_time
                        gamma = step_result['gamma'] # just for logging

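                    # Away-step and pairwise variants may rescale or drop
                    # existing atoms, so they return the whole updated
                    # active set; the other variants simply append the new
                    # atom and rescale the weights.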
                    if ((FLAGS.fw_variant == 'ada_afw'
                         or FLAGS.fw_variant == 'ada_pfw') and t > 0):
                        comps = step_result['comps']
                        weights = step_result['weights']
                    else:
                        comps.append({'loc': mu_s, 'scale_diag': cov_s})
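                        # update_weights presumably rescales the existing
                        # weights by (1 - gamma) and appends gamma for the
                        # new atom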
                        weights = coreutils.update_weights(weights, gamma, t)

                    # TODO: Move this to fw_step_size.py
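                    # Fully corrective FW: re-optimize all mixture weights
                    # over the current active set, then prune atoms whose
                    # weight has dropped to zero.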
                    if FLAGS.fw_variant == "fc":
                        q_latest = Mixture(
                            cat=Categorical(
                                probs=tf.convert_to_tensor(weights)),
                            components=[
                                MultivariateNormalDiag(**c) for c in comps
                            ])
                        # Correction
                        start_fc_time = time.time()
                        weights = opt.fully_corrective(q_latest, p)
                        weights = list(weights)
                        for i in reversed(range(len(weights))):
                            # Remove components whose weight is 0
                            w = weights[i]
                            if w == 0:
                                del weights[i]
                                del comps[i]
                        weights = np.array(weights)
                        end_fc_time = time.time()
                        total_time += end_fc_time - start_fc_time

                    q_latest = Mixture(
                        cat=Categorical(probs=tf.convert_to_tensor(weights)),
                        components=[
                            MultivariateNormalDiag(**c) for c in comps
                        ])

                    # Log metrics for current iteration
                    time_t = float(total_time)
                    logger.info('total time %f' % (time_t))
                    append_to_file(times_filename, time_t)

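                    # elbo returns a (mean, std) pair estimated from
                    # n_samples Monte Carlo draws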
                    elbo_t = elbo(q_latest, p, n_samples=10)
                    logger.info("iter, %d, elbo, %.2f +/- %.2f" %
                                (t, elbo_t[0], elbo_t[1]))
                    append_to_file(elbos_filename,
                                   "%f,%f" % (elbo_t[0], elbo_t[1]))

                    logger.info('iter %d, gamma %.4f' % (t, gamma))
                    append_to_file(step_filename, gamma)

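                    # Log the negated RELBO at the LMO solution; np.log(t + 1)
                    # is presumably the regularization weight.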
                    if t > 0:
                        relbo_t = -coreutils.compute_relbo(
                            s, fw_iterates[p], p, np.log(t + 1))
                        append_to_file(relbos_filename, relbo_t)

                    objective_t = kl_divergence(q_latest, p).eval()
                    logger.info("iter, %d, kl, %.2f" % (t, objective_t))
                    append_to_file(objective_filename, objective_t)

                    if FLAGS.fw_variant.startswith('ada'):
                        if t > 0:
                            lipschitz_estimate = step_result['l_estimate']
                            append_to_file(gap_filename, step_result['gap'])
                            append_to_file(iter_info_filename,
                                           step_result['step_type'])
                            logger.info(
                                'gap = %.3f, lt = %.5f, iter_type = %s' %
                                (step_result['gap'], step_result['l_estimate'],
                                 step_result['step_type']))
                        # l_estimate for iter 0 is the initial value
                        append_to_file(lipschitz_filename, lipschitz_estimate)
                    elif FLAGS.fw_variant == 'line_search' and t > 0:
                        n_line_search_samples = step_result['n_samples']
                        grad_t = step_result['grad_gamma']
                        g_outfile = os.path.join(
                            goutdir, 'line_search_samples_%d.npy.%d' %
                            (n_line_search_samples, t))
                        logger.info(
                            'saving line search data to, %s' % g_outfile)
                        np.save(open(g_outfile, 'wb'), grad_t)

                    for_serialization = {
                        'locs': np.array([c['loc'] for c in comps]),
                        'scale_diags':
                        np.array([c['scale_diag'] for c in comps])
                    }
                    qt_outfile = os.path.join(outdir, 'qt_iter%d.npz' % t)
                    np.savez(qt_outfile, weights=weights, **for_serialization)
                    np.savez(
                        os.path.join(outdir, 'qt_latest.npz'),
                        weights=weights,
                        **for_serialization)
                    logger.info("saving qt to, %s" % qt_outfile)
            tf.reset_default_graph()