def load_dataframe_for_run(runfile, offset=0):
    """Load a run and create a DataFrame from metrics.
    
    Args:
        runfile: path to load the run from
        offset: start iteration for plotting
    Returns:
        pandas DataFrame
    """
    data = np.load(runfile)
    # (n_line_search_iter, n_line_search_samples, 1)
    e_s = np.array([d['E_s'] for d in data])[offset:]
    e_q = np.array([d['E_q'] for d in data])[offset:]
    # (n_line_search_iter)
    gamma = np.array([d['gamma'] for d in data])[offset:]
    n_line_search_samples = e_s.shape[1]
    debug('line search samples %d' % n_line_search_samples)
    n_line_search_iter = e_s.shape[0]
    # (n_line_search_iter)
    iter_nos = np.arange(offset, offset + n_line_search_iter)

    # construct flattened columns for dataframe
    e_s_flat = e_s.flatten()
    e_q_flat = e_q.flatten()
    gamma_flat = np.repeat(gamma, n_line_search_samples)
    iters_flat = np.repeat(iter_nos, n_line_search_samples)
    line_search_samples = np.repeat(n_line_search_samples,
                                    n_line_search_samples * n_line_search_iter)
    return pd.DataFrame({
        'E_s': e_s_flat,
        'E_q': e_q_flat,
        'gamma': gamma_flat,
        'iterations': iters_flat,
        'n_samples': line_search_samples
    })
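# Hedged usage sketch (not part of the original module): 'line_search_run.npy'
# is a hypothetical file produced by np.save on a list of per-iteration dicts
# with 'E_s', 'E_q' and 'gamma' keys, and seaborn is assumed to be installed.
# The flattened frame lets seaborn aggregate the line search samples of each
# iteration into a mean curve with a confidence band.
def example_plot_line_search_run(runfile='line_search_run.npy'):
    import seaborn as sns
    import matplotlib.pyplot as plt
    df = load_dataframe_for_run(runfile, offset=0)
    sns.lineplot(data=df, x='iterations', y='E_q')
    plt.show()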
def main(argv):
    del argv
    if FLAGS.grid2d:
        raise NotImplementedError('Only 1D Normal supported...')

    if FLAGS.qt == "":
        eprint("provide some qt to the `--qt` option if you would like to "
               "plot")

    if FLAGS.label:
        label = FLAGS.label
    else:
        qt_file = os.path.splitext(FLAGS.qt)[0]
        label = qt_file[qt_file.find('qt_') + len('qt_'):]

    plt.figure(1)
    debug("visualizing %s" % FLAGS.qt)
    mixture_params = get_mixture_params_from_file(FLAGS.qt)
    #plot_normal_mix(mixture_params['weights'], mixture_params['locs'],
    #                mixture_params['scale_diags'], plt, label)

    plt.figure(2)
    w = mixture_params['weights']
    barlist = plt.bar(np.arange(len(w)), w, color='b', label=label)

    if FLAGS.iter_labels:
        label_name = os.path.basename(FLAGS.iter_labels)
        if label_name.startswith('iter_types'):
            # label which iterations came from adaptive steps
            # and which from fixed steps.
            with open(FLAGS.iter_labels, 'r') as f:
                iter_types = f.readlines()
            for i, it in enumerate(iter_types):
                it = it.strip()
                if it != 'adaptive':
                    if it == 'fixed':
                        barlist[i + 1].set_color('r')
                    elif it == 'fixed_adaptive_MAXITER':
                        barlist[i + 1].set_color('g')
                    else:
                        barlist[i + 1].set_color('k')
            ad = mpatches.Patch(color='b', label='Adaptive step')
            fi = mpatches.Patch(color='r',
                                label='Fixed step (adafw loop long)')
            fa = mpatches.Patch(color='g', label='Fixed step (-ve gap)')
            plt.legend(handles=[ad, fi, fa], loc=2)

    if FLAGS.outdir == 'stdout':
        plt.show()
    else:
        fig = plt.gcf()  # figure(2) with the weight bars
        fig.tight_layout()
        outname = os.path.join(os.path.expanduser(FLAGS.outdir), FLAGS.outfile)
        fig.savefig(outname, bbox_inches='tight')
        print('saved to ', outname)
def adafw_linit(q_0, p):
    """Initialization of L estimate for Adaptive
    Frank Wolfe algorithm. Given in v2 of the
    paper https://arxiv.org/pdf/1806.05123.pdf

    Args:
        q_0: initial iterate
        p: target distribution
    Returns:
        L initialized value, float
    """
    if FLAGS.linit == 'fixed':
        return FLAGS.linit_fixed
    elif FLAGS.linit != 'lipschitz_v2':
        raise NotImplementedError('v1 not implemented')

    logger.warning('AdaFW initializer might not be correct')
    # larger sample size for more accuracy
    N_samples = FLAGS.n_monte_carlo_samples * 5
    theta = q_0.sample([N_samples])
    # grad_q0 = grad_kl(q_0, p, theta).eval()
    log_q0 = q_0.log_prob(theta).eval()
    log_p = p.log_prob(theta).eval()
    grad_q0 = log_q0 - log_p
    prob_q0 = q_0.prob(theta).eval()

    def get_diff(L):
        h = -1.*grad_q0 / L
        # q_0 + h is not a valid probability distribution so values
        # can get negative. Performing clipping before taking log
        t0 = np.clip(prob_q0 + h, 1e-5, None)
        t1 = np.log(t0)
        t2 = np.mean(t1 - log_q0)
        t3 = t1 - log_p
        t4 = (h * t3) / prob_q0
        t5 = np.mean(t4)
        return t2 - t5

    L_init_estimate = FLAGS.linit_fixed
    while get_diff(L_init_estimate) > 0.:
        debug('L estimate diff is %.5f for L %.2f' %
              (get_diff(L_init_estimate), L_init_estimate))
        L_init_estimate *= 10.
    debug('L estimate diff is %.5f for L %.2f' %
          (get_diff(L_init_estimate), L_init_estimate))
    logger.info('initial Lipschitz estimate is %.5f\n' % L_init_estimate)
    return L_init_estimate
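# Minimal sketch of the doubling search above (illustration only, not part of
# the original module): `diff` stands in for get_diff, and the estimate is
# inflated by a fixed factor until the smoothness condition diff(L) <= 0 holds.
def example_double_until_nonpositive(diff, l_init=1e-3, factor=10., max_iter=64):
    l_estimate = l_init
    for _ in range(max_iter):
        if diff(l_estimate) <= 0.:
            break
        l_estimate *= factor
    return l_estimate

# example_double_until_nonpositive(lambda l: 1. - l) returns 1.0, the first
# estimate in the sequence 1e-3, 1e-2, ... with a non-positive diff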
def main(argv):
    # NOTE: keep values monotonic
    tau_list = [1.01, 1.1, 1.5, 2.0]
    eta_list = [0.01, 0.1, 0.5, 0.99]
    if FLAGS.metric != 'kl':
        raise NotImplementedError(
            'metric %s not supported, only kl supported' % (FLAGS.metric))
    val_matrix = np.full((len(tau_list), len(eta_list)), np.inf)
    for folder in FLAGS.dirlist:
        # sending string as some folders may not have meta.info
        hyper_params = parse(
            open(os.path.join(folder, "meta.info"), 'r').readline())
        x = tau_list.index(hyper_params['tau'])
        y = eta_list.index(hyper_params['eta'])
        val_matrix[x, y] = get_best_metric(os.path.join(folder, "kl.csv"))

    debug(val_matrix)

    fig, ax = plt.subplots()
    im = ax.imshow(val_matrix, cmap='magma_r')
    # imshow puts the row index (tau) on the y-axis and the
    # column index (eta) on the x-axis
    ax.set_xticks(np.arange(len(eta_list)))
    ax.set_yticks(np.arange(len(tau_list)))
    ax.set_xticklabels(eta_list)
    ax.set_yticklabels(tau_list)
    ax.set_xlabel('eta')
    ax.set_ylabel('tau')
    plt.setp(ax.get_xticklabels(), fontsize=16)
    plt.setp(ax.get_yticklabels(), fontsize=16)

    #for i in range(len(tau_list)):
    #    for j in range(len(eta_list)):
    #        text = ax.text(j, i, val_matrix[i, j],
    #                    ha="center", va="center", color="w")

    ax.set_title("%s value for different hp configurations" % (FLAGS.metric))
    fig.tight_layout()
    plt.show()
def adaptive_pfw(weights, comps, locs, diags, q_t, mu_s, cov_s, s_t, p,
                 k, l_prev):
    """
        Adaptive pairwise variant.
    Args:
        same as fixed
    """
    d_t_norm = divergence(s_t, q_t, metric=FLAGS.distance_metric).eval()
    logger.info('distance norm is %.5f' % d_t_norm)

    # Find v_t
    qcomps = q_t.components
    index_v_t, step_v_t = argmax_grad_dotp(p, q_t, qcomps,
                                           FLAGS.n_monte_carlo_samples)
    v_t = qcomps[index_v_t]

    # Pairwise gap
    sample_s = s_t.sample([FLAGS.n_monte_carlo_samples])
    step_s = tf.reduce_mean(grad_kl(q_t, p, sample_s)).eval()
    gap_pw = step_v_t - step_s
    if gap_pw < 0: eprint("Pairwise gap is negative")

    def default_fixed_step(fail_type='fixed'):
        # adaptive failed, return to fixed
        gamma = 2. / (k + 2.)
        new_comps = copy.copy(comps)
        new_comps.append({'loc': mu_s, 'scale_diag': cov_s})
        new_weights = [(1. - gamma) * w for w in weights]
        new_weights.append(gamma)
        return {
            'gamma': 2. / (k + 2.),
            'l_estimate': l_prev,
            'weights': new_weights,
            'comps': new_comps,
            'gap': gap_pw,
            'step_type': fail_type
        }

    logger.info('Pairwise gap %.5f' % gap_pw)

    # Set $q_{t+1}$'s params
    new_locs = copy.copy(locs)
    new_diags = copy.copy(diags)
    new_locs.append(mu_s)
    new_diags.append(cov_s)
    gap = gap_pw
    if gap <= 0:
        return default_fixed_step()
    gamma_max = weights[index_v_t]
    step_type = 'adaptive'

    tau = FLAGS.exp_adafw
    eta = FLAGS.damping_adafw
    pow_tau = 1.0
    i, l_t = 0, l_prev
    f_t =  kl_divergence(q_t, p, allow_nan_stats=False).eval()
    drop_step = False
    debug('f(q_t) = %.5f' % (f_t))
    gamma = 2. / (k + 2)
    while gamma >= MIN_GAMMA and i < FLAGS.adafw_MAXITER:
        # compute $L_t$ and $\gamma_t$
        l_t = pow_tau * eta * l_prev
        gamma = min(gap / (l_t * d_t_norm), gamma_max)

        d_1 = - gamma * gap
        d_2 = gamma * gamma * l_t * d_t_norm / 2.
        debug('linear d1 = %.5f, quad d2 = %.5f' % (d_1, d_2))
        quad_bound_rhs = f_t  + d_1 + d_2

        # construct $q_{t + 1}$
        new_weights = copy.copy(weights)
        new_weights.append(gamma)
        if gamma == gamma_max:
            # hardcoding to 0 for precision issues
            new_weights[index_v_t] = 0
            drop_step = True
        else:
            new_weights[index_v_t] -= gamma
            drop_step = False

        qt_new = Mixture(
            cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
            components=[
                MultivariateNormalDiag(loc=loc, scale_diag=diag)
                for loc, diag in zip(new_locs, new_diags)
            ])

        quad_bound_lhs = kl_divergence(qt_new, p, allow_nan_stats=False).eval()
        logger.info('lt = %.5f, gamma = %.3f, f_(qt_new) = %.5f, '
                    'linear extrapolated = %.5f' % (l_t, gamma, quad_bound_lhs,
                                                    quad_bound_rhs))
        if quad_bound_lhs <= quad_bound_rhs:
            new_comps = copy.copy(comps)
            new_comps.append({'loc': mu_s, 'scale_diag': cov_s})
            if drop_step:
                del new_comps[index_v_t]
                del new_weights[index_v_t]
                logger.info("...drop step")
                step_type = 'drop'
            return {
                'gamma': gamma,
                'l_estimate': l_t,
                'weights': new_weights,
                'comps': new_comps,
                'gap': gap,
                'step_type': step_type
            }
        pow_tau *= tau
        i += 1
    
    # gamma below MIN_GAMMA
    logger.warning("gamma below threshold value, returning fixed step")
    return default_fixed_step("fixed_adaptive_MAXITER")
def adaptive_fw(weights, locs, diags, q_t, mu_s, cov_s, s_t, p, k, l_prev,
                return_gamma=False):
    """Adaptive Frank-Wolfe algorithm.
    
    Sets step size as suggested in Algorithm 1 of
    https://arxiv.org/pdf/1806.05123.pdf

    Args:
        weights: [k], weights of the mixture components of q_t
        locs: [k x dim], means of mixture components of q_t
        diags: [k x dim], std deviations of mixture components of q_t
        q_t: current mixture iterate q_t
        mu_s: [dim], mean for LMO solution s
        cov_s: [dim], cov matrix for LMO solution s
        s_t: Current atom & LMO Solution s
        p: edward.model, target distribution p
        k: iteration number of Frank-Wolfe
        l_prev: previous lipschitz estimate
        return_gamma: only return the value of gamma
    Returns:
        If return_gamma is True, only the computed value of gamma
        is returned. Else returns a dictionary containing gamma, 
        lipschitz estimate, duality gap and step information
    """

    # Set $q_{t+1}$'s params
    new_locs = copy.copy(locs)
    new_diags = copy.copy(diags)
    new_locs.append(mu_s)
    new_diags.append(cov_s)

    d_t_norm = divergence(s_t, q_t, metric=FLAGS.distance_metric).eval()
    logger.info('distance norm is %.5f' % d_t_norm)

    N_samples = FLAGS.n_monte_carlo_samples
    # create and sample from $s_t, q_t$
    sample_q = q_t.sample([N_samples])
    sample_s = s_t.sample([N_samples])
    step_s = tf.reduce_mean(grad_kl(q_t, p, sample_s)).eval()
    step_q = tf.reduce_mean(grad_kl(q_t, p, sample_q)).eval()
    gap = step_q - step_s
    logger.info('duality gap %.5f' % gap)
    if gap < 0: logger.warning("Duality gap is negative returning 0 step")

    #gamma = 2. / (k + 2.)
    gamma = 0.
    tau = FLAGS.exp_adafw
    eta = FLAGS.damping_adafw
    # did the adaptive loop succeed or not
    step_type = "fixed"
    # NOTE: this is from v1 of the paper, new version
    # replaces multiplicative tau with divisor eta
    pow_tau = 1.0
    i, l_t = 0, l_prev
    f_t =  kl_divergence(q_t, p, allow_nan_stats=False).eval()
    debug('f(q_t) = %.5f' % (f_t))
    # return initial estimate if gap is -ve
    while gap >= 0:
        # compute $L_t$ and $\gamma_t$
        l_t = pow_tau * eta * l_prev
        gamma = min(gap / (l_t * d_t_norm), 1.0)
        d_1 = - gamma * gap
        d_2 = gamma * gamma * l_t * d_t_norm / 2.
        debug('linear d1 = %.5f, quad d2 = %.5f' % (d_1, d_2))
        quad_bound_rhs = f_t  + d_1 + d_2

        # $w_{t + 1} = [(1 - \gamma)w_t, \gamma]$
        new_weights = copy.copy(weights)
        new_weights = [(1. - gamma) * w for w in new_weights]
        new_weights.append(gamma)
        qt_new = Mixture(
            cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
            components=[
                MultivariateNormalDiag(loc=loc, scale_diag=diag)
                for loc, diag in zip(new_locs, new_diags)
            ])
        quad_bound_lhs = kl_divergence(qt_new, p, allow_nan_stats=False).eval()
        logger.info('lt = %.5f, gamma = %.3f, f_(qt_new) = %.5f, '
                    'linear extrapolated = %.5f' % (l_t, gamma, quad_bound_lhs,
                                                    quad_bound_rhs))
        if quad_bound_lhs <= quad_bound_rhs:
            step_type = "adaptive"
            break
        pow_tau *= tau
        i += 1
        #if i > FLAGS.adafw_MAXITER or gamma < MIN_GAMMA:
        if i > FLAGS.adafw_MAXITER:
            # estimate not good
            #gamma = 2. / (k + 2.)
            gamma = 0.
            l_t = l_prev
            step_type = "fixed_adaptive_MAXITER"
            break

    if return_gamma: return gamma
    return {
        'gamma': gamma,
        'l_estimate': l_t,
        'gap': gap,
        'step_type': step_type
    }
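# Hedged scalar illustration of the same backtracking rule (an assumption, not
# the original objective): f(x) = 0.5 * x**2, whose true curvature constant is
# 1. Starting from a too-small l_prev, the loop damps it by eta, tries the
# clipped step gamma = min(gap / (l_t * d^2), 1), and inflates l_t by tau until
# the quadratic upper bound f(x) - gamma * gap + gamma^2 * l_t * d^2 / 2 holds,
# mirroring the loop above.
def example_backtracking_step(x, s, l_prev, tau=2.0, eta=0.99, max_iter=32):
    f = lambda z: 0.5 * z * z
    d = s - x
    gap = -x * d                     # <grad f(x), x - s> for the toy objective
    d_norm = d * d
    l_t = eta * l_prev
    for _ in range(max_iter):
        gamma = min(gap / (l_t * d_norm), 1.0)
        rhs = f(x) - gamma * gap + 0.5 * gamma * gamma * l_t * d_norm
        if f(x + gamma * d) <= rhs:
            return gamma, l_t        # adaptive step accepted
        l_t *= tau
    return 0., l_prev                # give up, analogous to the fixed fallback

# example_backtracking_step(x=2., s=0., l_prev=0.1) accepts once l_t >= 1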
def adaptive_afw(weights, comps, locs, diags, q_t, mu_s, cov_s, s_t, p,
                 k, l_prev):
    """
        Away steps variant
    Args:
        same as fixed
    """
    d_t_norm = divergence(s_t, q_t, metric=FLAGS.distance_metric).eval()
    logger.info('distance norm is %.5f' % d_t_norm)

    # Find v_t
    qcomps = q_t.components
    index_v_t, step_v_t = argmax_grad_dotp(p, q_t, qcomps,
                                           FLAGS.n_monte_carlo_samples)
    v_t = qcomps[index_v_t]

    # Frank-Wolfe gap
    sample_q = q_t.sample([FLAGS.n_monte_carlo_samples])
    sample_s = s_t.sample([FLAGS.n_monte_carlo_samples])
    step_s = tf.reduce_mean(grad_kl(q_t, p, sample_s)).eval()
    step_q = tf.reduce_mean(grad_kl(q_t, p, sample_q)).eval()
    gap_fw = step_q - step_s
    if gap_fw < 0: logger.warning("Frank-Wolfe duality gap is negative")
    # Away gap
    gap_a = step_v_t - step_q
    if gap_a < 0: eprint('Away gap < 0!!!')
    logger.info('fw gap %.5f, away gap %.5f' % (gap_fw, gap_a))

    # Set $q_{t+1}$'s params
    new_locs = copy.copy(locs)
    new_diags = copy.copy(diags)
    if (gap_fw >= gap_a) or (len(comps) == 1):
        # FW direction, proceeds exactly as adafw
        logger.info('Proceeding in FW direction ')
        adaptive_step_type = 'fw'
        gap = gap_fw
        new_locs.append(mu_s)
        new_diags.append(cov_s)
        gamma_max = 1.0
    else:
        # Away direction
        logger.info('Proceeding in Away direction ')
        adaptive_step_type = 'away'
        gap = gap_a
        if weights[index_v_t] < 1.0:
            gamma_max = weights[index_v_t] / (1.0 - weights[index_v_t])
        else:
            gamma_max = 100. # Large value when t = 1

    def default_fixed_step(fail_type='fixed'):
        # adaptive failed, return to fixed
        gamma = 2. / (k + 2.)
        new_comps = copy.copy(comps)
        new_comps.append({'loc': mu_s, 'scale_diag': cov_s})
        new_weights = [(1. - gamma) * w for w in weights]
        new_weights.append(gamma)
        return {
            'gamma': 2. / (k + 2.),
            'l_estimate': l_prev,
            'weights': new_weights,
            'comps': new_comps,
            'gap': gap,
            'step_type': fail_type
        }
    
    if gap <= 0:
        return default_fixed_step()

    tau = FLAGS.exp_adafw
    eta = FLAGS.damping_adafw
    pow_tau = 1.0
    i, l_t = 0, l_prev
    f_t =  kl_divergence(q_t, p, allow_nan_stats=False).eval()
    debug('f(q_t) = %.5f' % (f_t))
    gamma = 2. / (k + 2)
    is_drop_step = False
    while gamma >= MIN_GAMMA and i < FLAGS.adafw_MAXITER:
        # compute $L_t$ and $\gamma_t$
        l_t = pow_tau * eta * l_prev
        # NOTE: Handle extreme values of gamma carefully
        gamma = min(gap / (l_t * d_t_norm), gamma_max)

        d_1 = - gamma * gap
        d_2 = gamma * gamma * l_t * d_t_norm / 2.
        debug('linear d1 = %.5f, quad d2 = %.5f' % (d_1, d_2))
        quad_bound_rhs = f_t  + d_1 + d_2

        # construct $q_{t + 1}$
        if adaptive_step_type == 'fw':
            if gamma == gamma_max:
                # gamma = 1.0, q_{t + 1} = s_t
                new_comps = [{'loc': mu_s, 'scale_diag': cov_s}]
                new_weights = [1.]
                qt_new = MultivariateNormalDiag(loc=mu_s, scale_diag=cov_s)
            else:
                new_comps = copy.copy(comps)
                new_comps.append({'loc': mu_s, 'scale_diag': cov_s})
                new_weights = copy.copy(weights)
                new_weights = [(1. - gamma) * w for w in new_weights]
                new_weights.append(gamma)
                qt_new = Mixture(
                    cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
                    components=[
                        MultivariateNormalDiag(loc=loc, scale_diag=diag)
                        for loc, diag in zip(new_locs, new_diags)
                    ])
        elif adaptive_step_type == 'away':
            new_weights = copy.copy(weights)
            new_comps = copy.copy(comps)
            if gamma == gamma_max:
                # drop v_t
                is_drop_step = True
                logger.info('...drop step')
                del new_weights[index_v_t]
                new_weights = [(1. + gamma) * w for w in new_weights]
                del new_comps[index_v_t]
                # NOTE: recompute locs and diags after dropping v_t
                drop_locs = [c['loc'] for c in new_comps]
                drop_diags = [c['scale_diag'] for c in new_comps]
                qt_new = Mixture(
                    cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
                    components=[
                        MultivariateNormalDiag(loc=loc, scale_diag=diag)
                        for loc, diag in zip(drop_locs, drop_diags)
                    ])
            else:
                is_drop_step = False
                new_weights = [(1. + gamma) * w for w in new_weights]
                new_weights[index_v_t] -= gamma
                qt_new = Mixture(
                    cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
                    components=[
                        MultivariateNormalDiag(loc=loc, scale_diag=diag)
                        for loc, diag in zip(new_locs, new_diags)
                    ])

        quad_bound_lhs = kl_divergence(qt_new, p, allow_nan_stats=False).eval()
        logger.info('lt = %.5f, gamma = %.3f, f_(qt_new) = %.5f, '
                    'linear extrapolated = %.5f' % (l_t, gamma, quad_bound_lhs,
                                                    quad_bound_rhs))
        if quad_bound_lhs <= quad_bound_rhs:
            step_type = "adaptive"
            if adaptive_step_type == "away": step_type = "away"
            if is_drop_step: step_type = "drop"
            return {
                'gamma': gamma,
                'l_estimate': l_t,
                'weights': new_weights,
                'comps': new_comps,
                'gap': gap,
                'step_type': step_type
            }
        pow_tau *= tau
        i += 1

    # adaptive loop failed, return fixed step size
    logger.warning("gamma below threshold value, returning fixed step")
    return default_fixed_step()
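# Hedged numeric check (illustration, not part of the original module): the
# away update scales every weight by (1 + gamma) and removes gamma from the
# away atom v_t, so the weights stay a probability vector whenever
# gamma <= w_v / (1 - w_v), the gamma_max used above; at gamma_max the weight
# of v_t becomes exactly zero, which triggers the drop step.
def example_away_weight_update(weights, index_v_t, gamma):
    new_weights = [(1. + gamma) * w for w in weights]
    new_weights[index_v_t] -= gamma
    assert abs(sum(new_weights) - 1.) < 1e-10
    assert min(new_weights) >= -1e-10
    return new_weights

# example_away_weight_update([0.5, 0.3, 0.2], 1, 0.3 / 0.7)
# -> the away atom's weight goes to 0 while the others are rescaled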
def adaptive_fw(weights,
                params,
                q_t,
                mu_s,
                cov_s,
                s_t,
                p,
                k,
                l_prev,
                gap=None):
    """Adaptive Frank-Wolfe algorithm.
    
    Sets step size as suggested in Algorithm 1 of
    https://arxiv.org/pdf/1806.05123.pdf

    Args:
        weights: [k], weights of the mixture components of q_t
        params: list containing dictionary of mixture params ('mu', 'scale')
        q_t: current mixture iterate q_t
        mu_s: [dim], mean for LMO solution s
        cov_s: [dim], cov matrix for LMO solution s
        s_t: Current atom & LMO Solution s
        p: edward.model, target distribution p
        k: iteration number of Frank-Wolfe
        l_prev: previous lipschitz estimate
        gap: Duality-Gap (if already computed)
    Returns:
        a dictionary containing gamma, new weights, new parameters
        lipschitz estimate, duality gap of current iterate
        and step information
    """

    # FIXME
    is_vector = FLAGS.base_dist in ['mvnormal', 'mvlaplace']

    d_t_norm = divergence(s_t, q_t, metric=FLAGS.distance_metric).eval()
    logger.info('\ndistance norm is %.3e' % d_t_norm)

    N_samples = FLAGS.n_monte_carlo_samples
    if gap is None:
        # create and sample from $s_t, q_t$
        sample_q = q_t.sample([N_samples])
        sample_s = s_t.sample([N_samples])
        step_s = tf.reduce_mean(grad_elbo(q_t, p, sample_s)).eval()
        step_q = tf.reduce_mean(grad_elbo(q_t, p, sample_q)).eval()
        gap = step_q - step_s
    logger.info('duality gap %.3e' % gap)
    if gap < 0:
        logger.warning("Duality gap is negative returning fixed step")
        return fixed(weights, params, q_t, mu_s, cov_s, s_t, p, k, gap)

    gamma = 2. / (k + 2.)
    tau = FLAGS.exp_adafw
    eta = FLAGS.damping_adafw
    # NOTE: this is from v1 of the paper, new version
    # replaces multiplicative tau with divisor eta
    pow_tau = 1.0
    i, l_t = 0, l_prev
    # Objective in this case is -ELBO
    f_t = -elbo(q_t, p, N_samples, return_std=False)
    debug('f(q_t) = %.3e' % (f_t))
    # return initial estimate if gap is -ve
    while gamma >= MIN_GAMMA and i < FLAGS.adafw_MAXITER:
        # compute $L_t$ and $\gamma_t$
        l_t = pow_tau * eta * l_prev
        gamma = min(gap / (l_t * d_t_norm), 1.0)
        d_1 = -gamma * gap
        d_2 = gamma * gamma * l_t * d_t_norm / 2.
        debug('linear d1 = %.3e, quad d2 = %.3e' % (d_1, d_2))
        quad_bound_rhs = f_t + d_1 + d_2

        # $w_{t + 1} = [(1 - \gamma)w_t, \gamma]$
        # Handling the case of gamma = 1.0
        # separately, weights might not get exactly 0 because
        # of precision issues. 0 wt components should be removed
        if gamma != 1.0:
            new_weights = copy.copy(weights)
            new_weights = [(1. - gamma) * w for w in new_weights]
            new_weights.append(gamma)
            new_params = copy.copy(params)
            new_params.append({'loc': mu_s, 'scale': cov_s})
            new_components = [
                coreutils.base_loc_scale(FLAGS.base_dist,
                                         c['loc'],
                                         c['scale'],
                                         multivariate=is_vector)
                for c in new_params
            ]
        else:
            new_weights = [1.]
            new_params = [{'loc': mu_s, 'scale': cov_s}]
            new_components = [s_t]

        qt_new = coreutils.get_mixture(new_weights, new_components)
        quad_bound_lhs = -elbo(qt_new, p, N_samples, return_std=False)
        logger.info('lt = %.3e, gamma = %.3f, f_(qt_new) = %.3e, '
                    'linear extrapolated = %.3e' %
                    (l_t, gamma, quad_bound_lhs, quad_bound_rhs))
        if quad_bound_lhs <= quad_bound_rhs:
            # Adaptive loop succeeded
            return {
                'gamma': gamma,
                'l_estimate': l_t,
                'weights': new_weights,
                'params': new_params,
                'gap': gap,
                'step_type': 'adaptive'
            }
        pow_tau *= tau
        i += 1

    # gamma below MIN_GAMMA
    logger.warning("gamma below threshold value, returning fixed step")
    return fixed(weights, params, q_t, mu_s, cov_s, s_t, p, k, gap)
def adaptive_afw(weights, params, q_t, mu_s, cov_s, s_t, p, k, l_prev):
    """Adaptive Away Steps algorithm.

    Args:
        weights: [k], weights of the mixture components of q_t
        params: list containing dictionary of mixture params ('mu', 'scale')
        q_t: current mixture iterate q_t
        mu_s: [dim], mean for LMO solution s
        cov_s: [dim], cov matrix for LMO solution s
        s_t: Current atom & LMO Solution s
        p: edward.model, target distribution p
        k: iteration number of Frank-Wolfe
        l_prev: previous lipschitz estimate
    Returns:
        a dictionary containing gamma, new weights, new parameters
        lipschitz estimate, duality gap of current iterate
        and step information
    """
    # FIXME
    is_vector = FLAGS.base_dist in ['mvnormal', 'mvlaplace']

    d_t_norm = divergence(s_t, q_t, metric=FLAGS.distance_metric).eval()
    logger.info('\ndistance norm is %.3e' % d_t_norm)

    # Find v_t
    qcomps = q_t.components
    index_v_t, step_v_t = argmax_grad_dotp(p, q_t, qcomps,
                                           FLAGS.n_monte_carlo_samples)
    v_t = qcomps[index_v_t]

    # Frank-Wolfe gap
    N_samples = FLAGS.n_monte_carlo_samples
    sample_q = q_t.sample([N_samples])
    sample_s = s_t.sample([N_samples])
    step_s = tf.reduce_mean(grad_elbo(q_t, p, sample_s)).eval()
    step_q = tf.reduce_mean(grad_elbo(q_t, p, sample_q)).eval()
    gap_fw = step_q - step_s
    if gap_fw < 0: logger.warning("Frank-Wolfe duality gap is negative")
    # Away gap
    gap_a = step_v_t - step_q
    if gap_a < 0: eprint('Away gap < 0!!!')
    logger.info('fw gap %.3e, away gap %.3e' % (gap_fw, gap_a))

    if (gap_fw >= gap_a) or (len(params) == 1):
        # FW direction, proceeds exactly as adafw
        logger.info('Proceeding in FW direction ')
        return adaptive_fw(weights, params, q_t, mu_s, cov_s, s_t, p, k,
                           l_prev, gap_fw)

    # Away direction
    logger.info('Proceeding in Away direction ')
    adaptive_step_type = 'away'
    gap = gap_a
    if weights[index_v_t] < 1.0:
        MAX_GAMMA = weights[index_v_t] / (1.0 - weights[index_v_t])
    else:
        MAX_GAMMA = 100.  # Large value when t = 1

    gamma = 2. / (k + 2.)
    tau = FLAGS.exp_adafw
    eta = FLAGS.damping_adafw
    pow_tau = 1.0
    i, l_t = 0, l_prev
    f_t = -elbo(q_t, p, N_samples, return_std=False)
    debug('f(q_t) = %.5f' % (f_t))
    is_drop_step = False
    while gamma >= MIN_GAMMA and i < FLAGS.adafw_MAXITER:
        # compute $L_t$ and $\gamma_t$
        l_t = pow_tau * eta * l_prev
        # NOTE: Handle extreme values of gamma carefully
        gamma = min(gap / (l_t * d_t_norm), MAX_GAMMA)

        d_1 = -gamma * gap
        d_2 = gamma * gamma * l_t * d_t_norm / 2.
        debug('linear d1 = %.5f, quad d2 = %.5f' % (d_1, d_2))
        quad_bound_rhs = f_t + d_1 + d_2

        # construct $q_{t + 1}$
        new_weights = copy.copy(weights)
        new_params = copy.copy(params)
        if gamma == MAX_GAMMA:
            # drop v_t
            is_drop_step = True
            del new_weights[index_v_t]
            new_weights = [(1. + gamma) * w for w in new_weights]
            del new_params[index_v_t]
        else:
            is_drop_step = False
            new_weights = [(1. + gamma) * w for w in new_weights]
            new_weights[index_v_t] -= gamma

        new_components = [
            coreutils.base_loc_scale(FLAGS.base_dist,
                                     c['loc'],
                                     c['scale'],
                                     multivariate=is_vector)
            for c in new_params
        ]

        qt_new = coreutils.get_mixture(new_weights, new_components)
        quad_bound_lhs = -elbo(qt_new, p, N_samples, return_std=False)
        logger.info('lt = %.3e, gamma = %.3f, f_(qt_new) = %.3e, '
                    'linear extrapolated = %.3e' %
                    (l_t, gamma, quad_bound_lhs, quad_bound_rhs))
        if quad_bound_lhs <= quad_bound_rhs:
            return {
                'gamma': gamma,
                'l_estimate': l_t,
                'weights': new_weights,
                'params': new_params,
                'gap': gap,
                'step_type': "drop" if is_drop_step else "away"
            }
        pow_tau *= tau
        i += 1

    # gamma below MIN_GAMMA
    logger.warning("gamma below threshold value, returning fixed step")
    return fixed(weights, params, q_t, mu_s, cov_s, s_t, p, k, gap)
def adaptive_pfw(weights, params, q_t, mu_s, cov_s, s_t, p, k, l_prev):
    """Adaptive pairwise variant.
    
    Args:
        weights: [k], weights of the mixture components of q_t
        params: list containing dictionary of mixture params ('mu', 'scale')
        q_t: current mixture iterate q_t
        mu_s: [dim], mean for LMO solution s
        cov_s: [dim], cov matrix for LMO solution s
        s_t: Current atom & LMO Solution s
        p: edward.model, target distribution p
        k: iteration number of Frank-Wolfe
        l_prev: previous lipschitz estimate
    Returns:
        a dictionary containing gamma, new weights, new parameters
        lipschitz estimate, duality gap of current iterate
        and step information
    """

    # FIXME
    is_vector = FLAGS.base_dist in ['mvnormal', 'mvlaplace']

    d_t_norm = divergence(s_t, q_t, metric=FLAGS.distance_metric).eval()
    logger.info('\ndistance norm is %.3e' % d_t_norm)

    # Find v_t
    qcomps = q_t.components
    index_v_t, step_v_t = argmax_grad_dotp(p, q_t, qcomps,
                                           FLAGS.n_monte_carlo_samples)
    v_t = qcomps[index_v_t]

    # Pairwise gap
    N_samples = FLAGS.n_monte_carlo_samples
    sample_s = s_t.sample([N_samples])
    step_s = tf.reduce_mean(grad_elbo(q_t, p, sample_s)).eval()
    gap_pw = step_v_t - step_s
    logger.info('Pairwise gap %.3e' % gap_pw)
    if gap_pw <= 0:
        logger.warning('Pairwise gap <= 0, returning fixed step')
        return fixed(weights, params, q_t, mu_s, cov_s, s_t, p, k, gap_pw)
    gap = gap_pw

    MAX_GAMMA = weights[index_v_t]

    gamma = 2. / (k + 2.)
    tau = FLAGS.exp_adafw
    eta = FLAGS.damping_adafw
    pow_tau = 1.0
    i, l_t = 0, l_prev
    f_t = -elbo(q_t, p, N_samples, return_std=False)
    debug('f(q_t) = %.3e' % f_t)
    is_drop_step = False
    while gamma >= MIN_GAMMA and i < FLAGS.adafw_MAXITER:
        # compute L_t and gamma_t
        l_t = pow_tau * eta * l_prev
        gamma = min(gap / (l_t * d_t_norm), MAX_GAMMA)

        d_1 = -gamma * gap
        d_2 = gamma * gamma * l_t * d_t_norm / 2.
        debug('linear d1 = %.5f, quad d2 = %.5f' % (d_1, d_2))
        quad_bound_rhs = f_t + d_1 + d_2

        # construct q_{t + 1}
        # handle the case of gamma = MAX_GAMMA separately
        new_weights = copy.copy(weights)
        new_weights.append(gamma)
        new_params = copy.copy(params)
        new_params.append({'loc': mu_s, 'scale': cov_s})
        if gamma != MAX_GAMMA:
            new_weights[index_v_t] -= gamma
            is_drop_step = False
        else:
            # hardcoding to 0
            del new_weights[index_v_t]
            del new_params[index_v_t]
            is_drop_step = True

        new_components = [
            coreutils.base_loc_scale(FLAGS.base_dist,
                                     c['loc'],
                                     c['scale'],
                                     multivariate=is_vector)
            for c in new_params
        ]

        qt_new = coreutils.get_mixture(new_weights, new_components)
        quad_bound_lhs = -elbo(qt_new, p, N_samples, return_std=False)
        logger.info('lt = %.3e, gamma = %.3f, f_(qt_new) = %.3e, '
                    'linear extrapolated = %.3e' %
                    (l_t, gamma, quad_bound_lhs, quad_bound_rhs))
        if quad_bound_lhs <= quad_bound_rhs:
            # Adaptive loop succeeded
            return {
                'gamma': gamma,
                'l_estimate': l_t,
                'weights': new_weights,
                'params': new_params,
                'gap': gap,
                'step_type': 'drop' if is_drop_step else 'adaptive'
            }
        pow_tau *= tau
        i += 1

    # gamma below MIN_GAMMA
    logger.warning("gamma below threshold value, returning fixed step")
    return fixed(weights, params, q_t, mu_s, cov_s, s_t, p, k, gap)
def main(argv):
    del argv

    x = deserialize_target_from_file(FLAGS.target)

    if FLAGS.widegrid:
        grid = np.arange(-25, 25, 0.1).astype(np.float32)
    else:
        grid = np.arange(-4, 4, 0.1).astype(np.float32)

    if FLAGS.grid2d:
        # 2D grid
        grid = np.arange(-2, 2, 0.1).astype(np.float32)
        gridx, gridy = np.meshgrid(grid, grid)
        grid = np.vstack((gridx.flatten(), gridy.flatten())).T

    if FLAGS.labels:
        labels = FLAGS.labels
    else:
        labels = ['approximation'] * len(FLAGS.qt)

    if FLAGS.styles:
        styles = FLAGS.styles
    else:
        styles = ['+', 'x', '.', '-']
        colors = ['Greens', 'Reds']

    sess = tf.Session()
    if FLAGS.grid2d:
        fig = plt.figure()
        ax = fig.add_subplot(211)
    else:
        fig, ax = plt.subplots()
        grid = np.expand_dims(grid, 1)  # package dims for tf
    with sess.as_default():
        xprobs = x.log_prob(grid)
        xprobs = tf.exp(xprobs).eval()
        if FLAGS.grid2d:
            ax.pcolormesh(gridx,
                          gridy,
                          xprobs.reshape(gridx.shape),
                          cmap='Blues')
        else:
            ax.plot(grid, xprobs, label='target', linewidth=2.0)

        if len(FLAGS.qt) == 0:
            eprint(
                "provide some qts to the `--qt` option if you would like to "
                "plot them")

        for i, (qt_filename, label) in enumerate(zip(FLAGS.qt, labels)):
            debug("visualizing %s" % qt_filename)
            qt = deserialize_mixture_from_file(qt_filename)
            qtprobs = tf.exp(qt.log_prob(grid))
            qtprobs = qtprobs.eval()
            if FLAGS.grid2d:
                ax2 = fig.add_subplot(212)
                ax2.pcolormesh(gridx,
                               gridy,
                               qtprobs.reshape(gridx.shape),
                               cmap='Greens')
            else:
                ax.plot(grid,
                        qtprobs,
                        styles[i % len(styles)],
                        label=label,
                        linewidth=2.0)

        if len(FLAGS.qt) == 1 and FLAGS.bars:
            locs = [comp.loc.eval() for comp in qt.components]
            ax.plot(locs, [0] * len(locs), '+')

            weights = qt.cat.probs.eval()
            for i in range(len(locs)):
                ax.bar(locs[i], weights[i], .05)

    ax.set_xticks([])
    ax.set_xticklabels([])
    ax.set_xlabel(FLAGS.xlabel)
    ax.set_ylabel(FLAGS.ylabel)
    fig.suptitle(FLAGS.title)
    if not FLAGS.grid2d:
        legend = plt.legend(loc='upper right',
                            prop={'size': 15},
                            bbox_to_anchor=(1.08, 1))
    if FLAGS.outdir == 'stdout':
        plt.show()
    else:
        fig.tight_layout()
        outname = os.path.join(os.path.expanduser(FLAGS.outdir), FLAGS.outfile)
        fig.savefig(outname,
                    bbox_extra_artists=(legend, ),
                    bbox_inches='tight')
        print('saved to ', outname)
def main(_):
    # setting up output directory
    outdir = os.path.expanduser(FLAGS.outdir)
    os.makedirs(outdir, exist_ok=True)

    N, M, D, R_true, I_train, I_test = get_data()
    debug('N, M, D: %s %s %s' % (N, M, D))

    # Solution components
    weights, qUVt_components = [], []

    # Files to log metrics
    times_filename = os.path.join(outdir, 'times.csv')
    mse_train_filename = os.path.join(outdir, 'mse_train.csv')
    mse_test_filename = os.path.join(outdir, 'mse_test.csv')
    ll_test_filename = os.path.join(outdir, 'll_test.csv')
    ll_train_filename = os.path.join(outdir, 'll_train.csv')
    elbos_filename = os.path.join(outdir, 'elbos.csv')
    gap_filename = os.path.join(outdir, 'gap.csv')
    step_filename = os.path.join(outdir, 'steps.csv')
    # 'adafw', 'ada_afw', 'ada_pfw'
    if FLAGS.fw_variant.startswith('ada'):
        lipschitz_filename = os.path.join(outdir, 'lipschitz.csv')
        iter_info_filename = os.path.join(outdir, 'iter_info.txt')

    start = 0
    if FLAGS.restore:
        #start = 50
        #qUVt_components = get_random_components(D, N, M, start)
        #weights = np.random.dirichlet([1.] * start).astype(np.float32)
        #lipschitz_estimate = opt.adafw_linit()
        parameters = np.load(os.path.join(outdir, 'qt_latest.npz'))
        weights = list(parameters['weights'])
        start = parameters['fw_iter']
        qUVt_components = list(parameters['comps'])
        assert len(weights) == len(qUVt_components), "Inconsistent storage"
        # get lipschitz estimate from the file, could've stored it
        # in params but that would mean different saved file for
        # adaptive variants
        if FLAGS.fw_variant.startswith('ada'):
            lipschitz_filename = os.path.join(outdir, 'lipschitz.csv')
            if not os.path.isfile(lipschitz_filename):
                raise ValueError("Inconsistent storage")
            with open(lipschitz_filename, 'r') as f:
                l = f.readlines()
                lipschitz_estimate = float(l[-1].strip())
    else:
        # empty the files present in the folder already
        open(times_filename, 'w').close()
        open(mse_train_filename, 'w').close()
        open(mse_test_filename, 'w').close()
        open(ll_test_filename, 'w').close()
        open(ll_train_filename, 'w').close()
        open(elbos_filename, 'w').close()
        open(gap_filename, 'w').close()
        open(step_filename, 'w').close()
        # 'adafw', 'ada_afw', 'ada_pfw'
        if FLAGS.fw_variant.startswith('ada'):
            open(lipschitz_filename, 'w').close()
            open(iter_info_filename, 'w').close()

    for t in range(start, start + FLAGS.n_fw_iter):
        g = tf.Graph()
        with g.as_default():
            tf.set_random_seed(FLAGS.seed)
            sess = tf.InteractiveSession()
            with sess.as_default():
                # MODEL
                I = tf.placeholder(tf.float32, [N, M])

                scale_uv = tf.concat(
                    [tf.ones([D, N]), tf.ones([D, M])], axis=1)
                mean_uv = tf.concat(
                    [tf.zeros([D, N]), tf.zeros([D, M])], axis=1)

                UV = Normal(loc=mean_uv, scale=scale_uv)
                R = Normal(loc=tf.matmul(tf.transpose(UV[:, :N]), UV[:, N:]),
                           scale=tf.ones([N, M]))  # generator dist. for matrix
                R_mask = R * I  # generated masked matrix

                p_joint = Joint(R_true, I_train, sess, D, N, M)

                if t == 0:
                    fw_iterates = {}
                else:
                    # Current solution
                    prev_components = [
                        coreutils.base_loc_scale('mvn0',
                                                 c['loc'],
                                                 c['scale'],
                                                 multivariate=False)
                        for c in qUVt_components
                    ]
                    qUV_prev = coreutils.get_mixture(weights, prev_components)
                    fw_iterates = {UV: qUV_prev}

                # LMO (via relbo INFERENCE)
                mean_suv = tf.concat([
                    tf.get_variable("qU/loc", [D, N]),
                    tf.get_variable("qV/loc", [D, M])
                ],
                                     axis=1)
                scale_suv = tf.concat([
                    tf.nn.softplus(tf.get_variable("qU/scale", [D, N])),
                    tf.nn.softplus(tf.get_variable("qV/scale", [D, M]))
                ],
                                      axis=1)

                sUV = Normal(loc=mean_suv, scale=scale_suv)

                #inference = relbo.KLqp({UV: sUV}, data={R: R_true, I: I_train},
                inference = relbo.KLqp({UV: sUV},
                                       data={
                                           R_mask: R_true,
                                           I: I_train
                                       },
                                       fw_iterates=fw_iterates,
                                       fw_iter=t)
                inference.run(n_iter=FLAGS.LMO_iter)

                loc_s = sUV.mean().eval()
                scale_s = sUV.stddev().eval()
                # sUV is a batched distribution, there are issues making
                # Mixture with batch distributions. mvn0
                # with event size (D, N + M) and batch size ()
                # NOTE log_prob(sample) still returns tensor
                # mvn and multivariatenormaldiag work for 1-D not 2-D shapes
                sUV_mv = coreutils.base_loc_scale('mvn0',
                                                  loc_s,
                                                  scale_s,
                                                  multivariate=False)
                # TODO send sUV or sUV_mv as argument to step size? sample
                # works the same way. same with log_prob

                total_time = 0.
                data = {R: R_true, I: I_train}
                if t == 0:
                    gamma = 1.
                    lipschitz_estimate = opt.adafw_linit()
                    step_type = 'init'
                elif FLAGS.fw_variant == 'fixed':
                    start_step_time = time.time()
                    step_result = opt.fixed(weights, qUVt_components, qUV_prev,
                                            loc_s, scale_s, sUV, p_joint, data,
                                            t)
                    end_step_time = time.time()
                    total_time += float(end_step_time - start_step_time)
                elif FLAGS.fw_variant == 'line_search':
                    start_step_time = time.time()
                    step_result = opt.line_search_dkl(weights, qUVt_components,
                                                      qUV_prev, loc_s, scale_s,
                                                      sUV, p_joint, data, t)
                    end_step_time = time.time()
                    total_time += float(end_step_time - start_step_time)
                elif FLAGS.fw_variant == 'adafw':
                    start_step_time = time.time()
                    step_result = opt.adaptive_fw(weights, qUVt_components,
                                                  qUV_prev, loc_s, scale_s,
                                                  sUV, p_joint, data, t,
                                                  lipschitz_estimate)
                    end_step_time = time.time()
                    total_time += float(end_step_time - start_step_time)

                    step_type = step_result['step_type']
                    if step_type == 'adaptive':
                        lipschitz_estimate = step_result['l_estimate']
                elif FLAGS.fw_variant == 'ada_pfw':
                    start_step_time = time.time()
                    step_result = opt.adaptive_pfw(weights, qUVt_components,
                                                   qUV_prev, loc_s, scale_s,
                                                   sUV, p_joint, data, t,
                                                   lipschitz_estimate)
                    end_step_time = time.time()
                    total_time += float(end_step_time - start_step_time)

                    step_type = step_result['step_type']
                    if step_type in ['adaptive', 'drop']:
                        lipschitz_estimate = step_result['l_estimate']
                elif FLAGS.fw_variant == 'ada_afw':
                    start_step_time = time.time()
                    step_result = opt.adaptive_afw(weights, qUVt_components,
                                                   qUV_prev, loc_s, scale_s,
                                                   sUV, p_joint, data, t,
                                                   lipschitz_estimate)
                    end_step_time = time.time()
                    total_time += float(end_step_time - start_step_time)

                    step_type = step_result['step_type']
                    if step_type in ['adaptive', 'away', 'drop']:
                        lipschitz_estimate = step_result['l_estimate']

                if t == 0:
                    gamma = 1.
                    weights.append(gamma)
                    qUVt_components.append({'loc': loc_s, 'scale': scale_s})
                    new_components = [sUV_mv]
                else:
                    qUVt_components = step_result['params']
                    weights = step_result['weights']
                    gamma = step_result['gamma']
                    new_components = [
                        coreutils.base_loc_scale('mvn0',
                                                 c['loc'],
                                                 c['scale'],
                                                 multivariate=False)
                        for c in qUVt_components
                    ]

                qUV_new = coreutils.get_mixture(weights, new_components)

                #qR = Normal(
                #    loc=tf.matmul(
                #        tf.transpose(qUV_new[:, :N]), qUV_new[:, N:]),
                #    scale=tf.ones([N, M]))
                qR = ed.copy(R, {UV: qUV_new})
                cR = ed.copy(R_mask, {UV: qUV_new})  # reconstructed matrix

                # Log metrics for current iteration
                logger.info('total time %f' % total_time)
                append_to_file(times_filename, total_time)

                logger.info('iter %d, gamma %.4f' % (t, gamma))
                append_to_file(step_filename, gamma)

                if t > 0:
                    gap_t = step_result['gap']
                    logger.info('iter %d, gap %.4f' % (t, gap_t))
                    append_to_file(gap_filename, gap_t)

                # CRITICISM
                if FLAGS.fw_variant.startswith('ada'):
                    append_to_file(lipschitz_filename, lipschitz_estimate)
                    append_to_file(iter_info_filename, step_type)
                    logger.info('lt = %.5f, iter_type = %s' %
                                (lipschitz_estimate, step_type))

                test_mse = ed.evaluate('mean_squared_error',
                                       data={
                                           cR: R_true,
                                           I: I_test
                                       })
                logger.info("iter %d ed test mse %.5f" % (t, test_mse))
                append_to_file(mse_test_filename, test_mse)

                train_mse = ed.evaluate('mean_squared_error',
                                        data={
                                            cR: R_true,
                                            I: I_train
                                        })
                logger.info("iter %d ed train mse %.5f" % (t, train_mse))
                append_to_file(mse_train_filename, train_mse)

                # very slow
                #train_ll = log_likelihood(qUV_new, R_true, I_train, sess, D, N,
                #                          M)
                train_ll = ed.evaluate('log_lik',
                                       data={
                                           qR: R_true.astype(np.float32),
                                           I: I_train
                                       })
                logger.info("iter %d train log lik %.5f" % (t, train_ll))
                append_to_file(ll_train_filename, train_ll)

                #test_ll = log_likelihood(qUV_new, R_true, I_test, sess, D, N, M)
                test_ll = ed.evaluate('log_lik',
                                      data={
                                          qR: R_true.astype(np.float32),
                                          I: I_test
                                      })
                logger.info("iter %d test log lik %.5f" % (t, test_ll))
                append_to_file(ll_test_filename, test_ll)

                # elbo_loss might be meaningless
                elbo_loss = elboModel.KLqp({UV: qUV_new},
                                           data={
                                               R: R_true,
                                               I: I_train
                                           })
                elbo_t = elbo(qUV_new, p_joint)
                res_update = elbo_loss.run()
                logger.info('iter %d -elbo loss %.2f or %.2f' %
                            (t, res_update['loss'], elbo_t))
                append_to_file(elbos_filename,
                               "%f,%f" % (elbo_t, res_update['loss']))

                # serialize the current iterate
                np.savez(os.path.join(outdir, 'qt_latest.npz'),
                         weights=weights,
                         comps=qUVt_components,
                         fw_iter=t + 1)

                sess.close()
        tf.reset_default_graph()
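# Hedged round-trip sketch of the qt_latest.npz checkpoint used above
# (illustration only): the weights, the list of component dicts and the next
# Frank-Wolfe iteration are stored with np.savez and read back for --restore.
# allow_pickle is needed here because the component dicts become object arrays.
def example_checkpoint_roundtrip(outdir='/tmp'):
    import os
    import numpy as np
    comps = [{'loc': np.zeros(3), 'scale': np.ones(3)}]
    path = os.path.join(outdir, 'qt_latest.npz')
    np.savez(path, weights=[1.0], comps=comps, fw_iter=1)
    parameters = np.load(path, allow_pickle=True)
    return (list(parameters['weights']), list(parameters['comps']),
            int(parameters['fw_iter']))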
def line_search_dkl(weights,
                    params,
                    q_t,
                    mu_s,
                    cov_s,
                    s_t,
                    p,
                    data,
                    k,
                    gap=None):
    """Performs line search for the best step size gamma.
    
    Uses gradient ascent to find gamma that minimizes
    ELBO(q_t + gamma (s - q_t) || p)
    
    Args:
        weights: [k], weights of mixture components of q_t
        params: list containing dictionary of mixture params ('mu', 'scale')
        q_t: current mixture iterate q_t
        mu_s: [dim], mean for LMO Solution s
        cov_s: [dim], cov matrix for LMO solution s
        s_t: Current atom & LMO Solution s
        p: edward.model, target distribution p
        k: iteration number of Frank-Wolfe
    Returns:
        a dictionary containing gamma, new weights, new parameters
        lipschitz estimate, duality gap of current iterate
        and step information
    """

    # initialize $\gamma$
    gamma = 2. / (k + 2.)
    for it in range(FLAGS.n_line_search_iter):
        print("line_search iter %d, %.5f" % (it, gamma))
        new_weights = copy.copy(weights)
        new_weights = [(1. - gamma) * w for w in new_weights]
        new_weights.append(gamma)
        new_params = copy.copy(params)
        new_params.append({'loc': mu_s, 'scale': cov_s})
        new_components = [
            coreutils.base_loc_scale(FLAGS.base_dist,
                                     c['loc'],
                                     c['scale'],
                                     multivariate=False) for c in new_params
        ]
        qt_new = coreutils.get_mixture(new_weights, new_components)
        step_s = grad_kl_dotp(qt_new, p, s_t)
        step_q = grad_kl_dotp(qt_new, p, q_t)
        gap = step_q - step_s
        # Gradient descent step size decreasing as $\frac{1}{it + 1}$
        gamma_prime = gamma - FLAGS.linit_fixed * (step_s - step_q) / (it + 1.)
        debug('line search gap %.3e gamma_p %f' % (gap, gamma_prime))
        # Projecting it back to [0, 1]
        if gamma_prime >= 1 or gamma_prime <= 0:
            gamma_prime = max(min(gamma_prime, 1.), 0.)

        if np.abs(gamma - gamma_prime) < 1e-6:
            gamma = gamma_prime
            break

        gamma = gamma_prime

    return {
        'gamma': gamma,
        'weights': new_weights,
        'params': new_params,
        'step_type': 'line_search',
        'gap': gap
    }
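# Hedged toy version of the same projected-gradient search (assumption: the
# objective in gamma is the simple quadratic (gamma - 0.7)**2 instead of the
# KL objective, and `grad` returns its derivative). The step size decays as
# 1 / (it + 1), gamma is projected back onto [0, 1], and the loop stops once
# the update becomes negligible, mirroring line_search_dkl above.
def example_line_search_1d(grad, gamma=0.5, lr=0.5, n_iter=25):
    for it in range(n_iter):
        gamma_prime = gamma - lr * grad(gamma) / (it + 1.)
        gamma_prime = max(min(gamma_prime, 1.), 0.)   # project to [0, 1]
        if abs(gamma - gamma_prime) < 1e-6:
            return gamma_prime
        gamma = gamma_prime
    return gamma

# example_line_search_1d(lambda g: 2. * (g - 0.7)) converges to 0.7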