def adaptive_pfw(weights, comps, locs, diags, q_t, mu_s, cov_s, s_t, p,
                 k, l_prev):
    """
        Adaptive pairwise variant.
    Args:
        same as fixed
    """
    d_t_norm = divergence(s_t, q_t, metric=FLAGS.distance_metric).eval()
    logger.info('distance norm is %.5f' % d_t_norm)

    # Find v_t
    qcomps = q_t.components
    index_v_t, step_v_t = argmax_grad_dotp(p, q_t, qcomps,
                                           FLAGS.n_monte_carlo_samples)
    v_t = qcomps[index_v_t]

    # Pairwise gap
    sample_s = s_t.sample([FLAGS.n_monte_carlo_samples])
    step_s = tf.reduce_mean(grad_kl(q_t, p, sample_s)).eval()
    gap_pw = step_v_t - step_s
    if gap_pw < 0: eprint("Pairwise gap is negative")

    def default_fixed_step(fail_type='fixed'):
        # adaptive step-size search failed, fall back to the fixed step size
        gamma = 2. / (k + 2.)
        new_comps = copy.copy(comps)
        new_comps.append({'loc': mu_s, 'scale_diag': cov_s})
        new_weights = [(1. - gamma) * w for w in weights]
        new_weights.append(gamma)
        return {
            'gamma': gamma,
            'l_estimate': l_prev,
            'weights': new_weights,
            'comps': new_comps,
            'gap': gap_pw,
            'step_type': fail_type
        }

    logger.info('Pairwise gap %.5f' % gap_pw)

    # Set $q_{t+1}$'s params
    new_locs = copy.copy(locs)
    new_diags = copy.copy(diags)
    new_locs.append(mu_s)
    new_diags.append(cov_s)
    gap = gap_pw
    if gap <= 0:
        return default_fixed_step()
    gamma_max = weights[index_v_t]
    step_type = 'adaptive'

    tau = FLAGS.exp_adafw
    eta = FLAGS.damping_adafw
    pow_tau = 1.0
    i, l_t = 0, l_prev
    f_t = kl_divergence(q_t, p, allow_nan_stats=False).eval()
    drop_step = False
    debug('f(q_t) = %.5f' % (f_t))
    gamma = 2. / (k + 2.)
    while gamma >= MIN_GAMMA and i < FLAGS.adafw_MAXITER:
        # compute $L_t$ and $\gamma_t$
        l_t = pow_tau * eta * l_prev
        gamma = min(gap / (l_t * d_t_norm), gamma_max)

        d_1 = - gamma * gap
        d_2 = gamma * gamma * l_t * d_t_norm / 2.
        debug('linear d1 = %.5f, quad d2 = %.5f' % (d_1, d_2))
        quad_bound_rhs = f_t  + d_1 + d_2

        # construct $q_{t + 1}$
        new_weights = copy.copy(weights)
        new_weights.append(gamma)
        if gamma == gamma_max:
            # hardcoding to 0 for precision issues
            new_weights[index_v_t] = 0
            drop_step = True
        else:
            new_weights[index_v_t] -= gamma
            drop_step = False

        qt_new = Mixture(
            cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
            components=[
                MultivariateNormalDiag(loc=loc, scale_diag=diag)
                for loc, diag in zip(new_locs, new_diags)
            ])

        quad_bound_lhs = kl_divergence(qt_new, p, allow_nan_stats=False).eval()
        logger.info('lt = %.5f, gamma = %.3f, f_(qt_new) = %.5f, '
                    'linear extrapolated = %.5f' % (l_t, gamma, quad_bound_lhs,
                                                    quad_bound_rhs))
        if quad_bound_lhs <= quad_bound_rhs:
            new_comps = copy.copy(comps)
            new_comps.append({'loc': mu_s, 'scale_diag': cov_s})
            if drop_step:
                del new_comps[index_v_t]
                del new_weights[index_v_t]
                logger.info("...drop step")
                step_type = 'drop'
            return {
                'gamma': gamma,
                'l_estimate': l_t,
                'weights': new_weights,
                'comps': new_comps,
                'gap': gap,
                'step_type': step_type
            }
        pow_tau *= tau
        i += 1
    
    # gamma below MIN_GAMMA
    logger.warning("gamma below threshold value, returning fixed step")
    return default_fixed_step("fixed_adaptive_MAXITER")
def line_search_dkl(weights, locs, diags, q_t, mu_s, cov_s, s_t, p, k,
                    return_gamma=False):
    """Performs line search for the best step size gamma.
    
    Uses gradient ascent to find gamma that minimizes
    KL(q_t + gamma (s - q_t) || p)
    
    Args:
        weights: [k], weights of mixture components of q_t
        locs: [k x dim], means of mixture components of q_t
        diags: [k x dim], deviations of mixture components of q_t
        q_t: current mixture iterate q_t
        mu_s: [dim], mean for LMO Solution s
        cov_s: [dim], cov matrix for LMO solution s
        s_t: Current atom & LMO Solution s
        p: edward.model, target distribution p
        k: iteration number of Frank-Wolfe
        return_gamma: only return the value of gamma
    Returns:
        If return_gamma is True, only the computed value of
        gamma is returned. Else along with gradient data
        is returned in a dict
    """
    N_samples = FLAGS.n_monte_carlo_samples
    # sample from $q_t$ and s
    sample_q = q_t.sample([N_samples])
    sample_s = s_t.sample([N_samples])
    # set $q_{t+1}$'s parameters
    new_locs = copy.copy(locs)
    new_diags = copy.copy(diags)
    new_locs.append(mu_s)
    new_diags.append(cov_s)
    # initialize $\gamma$
    gamma = 2. / (k + 2.)
    n_steps = FLAGS.n_line_search_iter
    prog_bar = ed.util.Progbar(n_steps)
    # storing gradients for analysis
    grad_gamma = []
    for it in range(n_steps):
        print("line_search iter %d, %.5f" % (it, gamma))
        new_weights = copy.copy(weights)
        new_weights = [(1. - gamma) * w for w in new_weights]
        new_weights.append(gamma)
        qt_new = Mixture(
            cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
            components=[
                MultivariateNormalDiag(loc=loc, scale_diag=diag)
                for loc, diag in zip(new_locs, new_diags)
            ])
        rez_s = grad_kl(qt_new, p, sample_s).eval()
        rez_q = grad_kl(qt_new, p, sample_q).eval()
        grad_gamma.append({'E_s': rez_s, 'E_q': rez_q, 'gamma': gamma})
        # Gradient descent step size decreasing as $\frac{1}{it + 1}$
        gamma_prime = gamma - 0.1 * (np.mean(rez_s) - np.mean(rez_q)) / (it + 1.)
        # Projecting it back to [0, 1]
        if gamma_prime >= 1 or gamma_prime <= 0:
            gamma_prime = max(min(gamma_prime, 1.), 0.)

        if np.abs(gamma - gamma_prime) < 1e-6:
            gamma = gamma_prime
            break

        gamma = gamma_prime

    if return_gamma: return gamma
    return {'gamma': gamma, 'n_samples': N_samples, 'grad_gamma': grad_gamma}
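
# A minimal numpy-free sketch of the projected gradient descent on gamma that
# line_search_dkl performs above, assuming the two Monte Carlo gradient
# estimates E_s = E_{x~s}[grad KL] and E_q = E_{x~q_t}[grad KL] are supplied
# by a callable. `grad_estimates_fn` and `_line_search_gamma_sketch` are
# illustrative names only, not part of this module.
def _line_search_gamma_sketch(grad_estimates_fn, gamma_init, n_steps=10,
                              lr=0.1, tol=1e-6):
    gamma = gamma_init
    for it in range(n_steps):
        e_s, e_q = grad_estimates_fn(gamma)      # Monte Carlo estimates at gamma
        gamma_prime = gamma - lr * (e_s - e_q) / (it + 1.)   # decaying step size
        gamma_prime = max(min(gamma_prime, 1.), 0.)          # project onto [0, 1]
        if abs(gamma - gamma_prime) < tol:
            return gamma_prime
        gamma = gamma_prime
    return gamma

# Example with a toy surrogate whose gradient difference vanishes at
# gamma = 0.3 (purely illustrative):
# _line_search_gamma_sketch(lambda g: (g - 0.3, 0.), 0.5)

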
def adaptive_afw(weights, comps, locs, diags, q_t, mu_s, cov_s, s_t, p,
                 k, l_prev):
    """
        Away steps variant
    Args:
        same as fixed
    """
    d_t_norm = divergence(s_t, q_t, metric=FLAGS.distance_metric).eval()
    logger.info('distance norm is %.5f' % d_t_norm)

    # Find v_t
    qcomps = q_t.components
    index_v_t, step_v_t = argmax_grad_dotp(p, q_t, qcomps,
                                           FLAGS.n_monte_carlo_samples)
    v_t = qcomps[index_v_t]

    # Frank-Wolfe gap
    sample_q = q_t.sample([FLAGS.n_monte_carlo_samples])
    sample_s = s_t.sample([FLAGS.n_monte_carlo_samples])
    step_s = tf.reduce_mean(grad_kl(q_t, p, sample_s)).eval()
    step_q = tf.reduce_mean(grad_kl(q_t, p, sample_q)).eval()
    gap_fw = step_q - step_s
    if gap_fw < 0: logger.warning("Frank-Wolfe duality gap is negative")
    # Away gap
    gap_a = step_v_t - step_q
    if gap_a < 0: eprint('Away gap is negative')
    logger.info('fw gap %.5f, away gap %.5f' % (gap_fw, gap_a))

    # Set $q_{t+1}$'s params
    new_locs = copy.copy(locs)
    new_diags = copy.copy(diags)
    if (gap_fw >= gap_a) or (len(comps) == 1):
        # FW direction, proceeds exactly as adafw
        logger.info('Proceeding in FW direction ')
        adaptive_step_type = 'fw'
        gap = gap_fw
        new_locs.append(mu_s)
        new_diags.append(cov_s)
        gamma_max = 1.0
    else:
        # Away direction
        logger.info('Proceeding in Away direction ')
        adaptive_step_type = 'away'
        gap = gap_a
        if weights[index_v_t] < 1.0:
            gamma_max = weights[index_v_t] / (1.0 - weights[index_v_t])
        else:
            gamma_max = 100. # Large value when t = 1

    def default_fixed_step(fail_type='fixed'):
        # adaptive step-size search failed, fall back to the fixed step size
        gamma = 2. / (k + 2.)
        new_comps = copy.copy(comps)
        new_comps.append({'loc': mu_s, 'scale_diag': cov_s})
        new_weights = [(1. - gamma) * w for w in weights]
        new_weights.append(gamma)
        return {
            'gamma': gamma,
            'l_estimate': l_prev,
            'weights': new_weights,
            'comps': new_comps,
            'gap': gap,
            'step_type': fail_type
        }
    
    if gap <= 0:
        return default_fixed_step()

    tau = FLAGS.exp_adafw
    eta = FLAGS.damping_adafw
    pow_tau = 1.0
    i, l_t = 0, l_prev
    f_t = kl_divergence(q_t, p, allow_nan_stats=False).eval()
    debug('f(q_t) = %.5f' % (f_t))
    gamma = 2. / (k + 2.)
    is_drop_step = False
    while gamma >= MIN_GAMMA and i < FLAGS.adafw_MAXITER:
        # compute $L_t$ and $\gamma_t$
        l_t = pow_tau * eta * l_prev
        # NOTE: Handle extreme values of gamma carefully
        gamma = min(gap / (l_t * d_t_norm), gamma_max)

        d_1 = - gamma * gap
        d_2 = gamma * gamma * l_t * d_t_norm / 2.
        debug('linear d1 = %.5f, quad d2 = %.5f' % (d_1, d_2))
        quad_bound_rhs = f_t  + d_1 + d_2

        # construct $q_{t + 1}$
        if adaptive_step_type == 'fw':
            if gamma == gamma_max:
                # gamma = 1.0, q_{t + 1} = s_t
                new_comps = [{'loc': mu_s, 'scale_diag': cov_s}]
                new_weights = [1.]
                qt_new = MultivariateNormalDiag(loc=mu_s, scale_diag=cov_s)
            else:
                new_comps = copy.copy(comps)
                new_comps.append({'loc': mu_s, 'scale_diag': cov_s})
                new_weights = copy.copy(weights)
                new_weights = [(1. - gamma) * w for w in new_weights]
                new_weights.append(gamma)
                qt_new = Mixture(
                    cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
                    components=[
                        MultivariateNormalDiag(loc=loc, scale_diag=diag)
                        for loc, diag in zip(new_locs, new_diags)
                    ])
        elif adaptive_step_type == 'away':
            new_weights = copy.copy(weights)
            new_comps = copy.copy(comps)
            if gamma == gamma_max:
                # drop v_t
                is_drop_step = True
                logger.info('...drop step')
                del new_weights[index_v_t]
                new_weights = [(1. + gamma) * w for w in new_weights]
                del new_comps[index_v_t]
                # NOTE: recompute locs and diags after dropping v_t
                drop_locs = [c['loc'] for c in new_comps]
                drop_diags = [c['scale_diag'] for c in new_comps]
                qt_new = Mixture(
                    cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
                    components=[
                        MultivariateNormalDiag(loc=loc, scale_diag=diag)
                        for loc, diag in zip(drop_locs, drop_diags)
                    ])
            else:
                is_drop_step = False
                new_weights = [(1. + gamma) * w for w in new_weights]
                new_weights[index_v_t] -= gamma
                qt_new = Mixture(
                    cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
                    components=[
                        MultivariateNormalDiag(loc=loc, scale_diag=diag)
                        for loc, diag in zip(new_locs, new_diags)
                    ])

        quad_bound_lhs = kl_divergence(qt_new, p, allow_nan_stats=False).eval()
        logger.info('lt = %.5f, gamma = %.3f, f_(qt_new) = %.5f, '
                    'linear extrapolated = %.5f' % (l_t, gamma, quad_bound_lhs,
                                                    quad_bound_rhs))
        if quad_bound_lhs <= quad_bound_rhs:
            step_type = "adaptive"
            if adaptive_step_type == "away": step_type = "away"
            if is_drop_step: step_type = "drop"
            return {
                'gamma': gamma,
                'l_estimate': l_t,
                'weights': new_weights,
                'comps': new_comps,
                'gap': gap,
                'step_type': step_type
            }
        pow_tau *= tau
        i += 1

    # adaptive loop failed, return fixed step size
    logger.warning("gamma below threshold value, returning fixed step")
    return default_fixed_step()
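
# A minimal sketch of the away-step weight update used in adaptive_afw above:
# all weights are rescaled by (1 + gamma) and gamma is removed from the away
# atom v_t, with gamma capped at w_v / (1 - w_v); hitting the cap drops v_t
# entirely. `_away_weight_update_sketch` is an illustrative helper, not part
# of this module.
def _away_weight_update_sketch(weights, index_v_t, gamma):
    w_v = weights[index_v_t]
    gamma_max = w_v / (1. - w_v) if w_v < 1.0 else 100.
    gamma = min(gamma, gamma_max)
    if gamma == gamma_max:
        # drop step: remove v_t, rescale the remaining weights
        new_weights = [w for i, w in enumerate(weights) if i != index_v_t]
        return [(1. + gamma) * w for w in new_weights]
    new_weights = [(1. + gamma) * w for w in weights]
    new_weights[index_v_t] -= gamma
    return new_weights

# Example: _away_weight_update_sketch([0.5, 0.3, 0.2], 2, 0.1)
# -> [0.55, 0.33, 0.12]  (mass moves away from the third atom, weights sum to 1)

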
def adaptive_fw(weights, locs, diags, q_t, mu_s, cov_s, s_t, p, k, l_prev,
                return_gamma=False):
    """Adaptive Frank-Wolfe algorithm.
    
    Sets step size as suggested in Algorithm 1 of
    https://arxiv.org/pdf/1806.05123.pdf

    Args:
        weights: [k], weights of the mixture components of q_t
        locs: [k x dim], means of mixture components of q_t
        diags: [k x dim], std deviations of mixture components of q_t
        q_t: current mixture iterate q_t
        mu_s: [dim], mean for LMO solution s
        cov_s: [dim], cov matrix for LMO solution s
        s_t: Current atom & LMO Solution s
        p: edward.model, target distribution p
        k: iteration number of Frank-Wolfe
        l_prev: previous Lipschitz estimate
        return_gamma: only return the value of gamma
    Returns:
        If return_gamma is True, only the computed value of gamma
        is returned. Else returns a dictionary containing gamma, the
        Lipschitz estimate, the duality gap and step information.
    """

    # Set $q_{t+1}$'s params
    new_locs = copy.copy(locs)
    new_diags = copy.copy(diags)
    new_locs.append(mu_s)
    new_diags.append(cov_s)

    d_t_norm = divergence(s_t, q_t, metric=FLAGS.distance_metric).eval()
    logger.info('distance norm is %.5f' % d_t_norm)

    N_samples = FLAGS.n_monte_carlo_samples
    # create and sample from $s_t, q_t$
    sample_q = q_t.sample([N_samples])
    sample_s = s_t.sample([N_samples])
    step_s = tf.reduce_mean(grad_kl(q_t, p, sample_s)).eval()
    step_q = tf.reduce_mean(grad_kl(q_t, p, sample_q)).eval()
    gap = step_q - step_s
    logger.info('duality gap %.5f' % gap)
    if gap < 0: logger.warning("Duality gap is negative, returning zero step")

    #gamma = 2. / (k + 2.)
    gamma = 0.
    tau = FLAGS.exp_adafw
    eta = FLAGS.damping_adafw
    # did the adaptive loop succeed or not
    step_type = "fixed"
    # NOTE: this is from v1 of the paper, new version
    # replaces multiplicative tau with divisor eta
    pow_tau = 1.0
    i, l_t = 0, l_prev
    f_t = kl_divergence(q_t, p, allow_nan_stats=False).eval()
    debug('f(q_t) = %.5f' % (f_t))
    # return initial estimate if gap is negative
    while gap >= 0:
        # compute $L_t$ and $\gamma_t$
        l_t = pow_tau * eta * l_prev
        gamma = min(gap / (l_t * d_t_norm), 1.0)
        d_1 = - gamma * gap
        d_2 = gamma * gamma * l_t * d_t_norm / 2.
        debug('linear d1 = %.5f, quad d2 = %.5f' % (d_1, d_2))
        quad_bound_rhs = f_t  + d_1 + d_2

        # $w_{t + 1} = [(1 - \gamma)w_t, \gamma]$
        new_weights = copy.copy(weights)
        new_weights = [(1. - gamma) * w for w in new_weights]
        new_weights.append(gamma)
        qt_new = Mixture(
            cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
            components=[
                MultivariateNormalDiag(loc=loc, scale_diag=diag)
                for loc, diag in zip(new_locs, new_diags)
            ])
        quad_bound_lhs = kl_divergence(qt_new, p, allow_nan_stats=False).eval()
        logger.info('lt = %.5f, gamma = %.3f, f_(qt_new) = %.5f, '
                    'linear extrapolated = %.5f' % (l_t, gamma, quad_bound_lhs,
                                                    quad_bound_rhs))
        if quad_bound_lhs <= quad_bound_rhs:
            step_type = "adaptive"
            break
        pow_tau *= tau
        i += 1
        #if i > FLAGS.adafw_MAXITER or gamma < MIN_GAMMA:
        if i > FLAGS.adafw_MAXITER:
            # estimate not good
            #gamma = 2. / (k + 2.)
            gamma = 0.
            l_t = l_prev
            step_type = "fixed_adaptive_MAXITER"
            break

    if return_gamma: return gamma
    return {
        'gamma': gamma,
        'l_estimate': l_t,
        'gap': gap,
        'step_type': step_type
    }
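
# A minimal sketch of the adaptive step-size rule implemented in adaptive_fw
# above (Algorithm 1 of https://arxiv.org/pdf/1806.05123.pdf, v1 variant with
# a multiplicative tau): the Lipschitz estimate L_t is inflated until the
# quadratic upper bound
#     f(q_t) - gamma * gap + gamma^2 * L_t * ||d_t||^2 / 2
# dominates the actual objective at q_{t+1}(gamma). `objective_fn` evaluates
# f(q_{t+1}(gamma)); all names here are illustrative, not part of this module.
def _adaptive_gamma_sketch(objective_fn, f_t, gap, d_t_norm, l_prev,
                           eta=0.99, tau=2.0, max_iter=10):
    pow_tau, l_t = 1.0, l_prev
    for _ in range(max_iter):
        l_t = pow_tau * eta * l_prev
        gamma = min(gap / (l_t * d_t_norm), 1.0)
        quad_bound_rhs = f_t - gamma * gap + gamma * gamma * l_t * d_t_norm / 2.
        if objective_fn(gamma) <= quad_bound_rhs:
            return gamma, l_t                   # bound holds, accept this step
        pow_tau *= tau                          # inflate the Lipschitz estimate
    return 0., l_prev                           # search failed, take a zero step

# Example with a toy quadratic objective f(gamma) = f_t - gap * gamma + gamma^2
# and an optimistic initial estimate l_prev (purely illustrative):
# _adaptive_gamma_sketch(lambda g: 1. - 0.5 * g + g * g, 1., 0.5, 1., 0.5)

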
def run_gap(pi, mus, stds):
    weights, comps = [], []
    elbos = []
    relbo_vals = []
    for t in range(FLAGS.n_fw_iter):
        logger.info('Frank Wolfe Iteration %d' % t)
        g = tf.Graph()
        with g.as_default():
            tf.set_random_seed(FLAGS.seed)
            sess = tf.InteractiveSession()
            with sess.as_default():
                # target distribution components
                pcomps = [
                    MultivariateNormalDiag(
                        loc=tf.convert_to_tensor(mus[i], dtype=tf.float32),
                        scale_diag=tf.convert_to_tensor(stds[i],
                                                        dtype=tf.float32))
                    for i in range(len(mus))
                ]
                # target distribution
                p = Mixture(cat=Categorical(probs=tf.convert_to_tensor(pi)),
                            components=pcomps)

                # LMO approximation
                s = construct_normal([1], t, 's')
                fw_iterates = {}
                if t > 0:
                    qtx = Mixture(
                        cat=Categorical(probs=tf.convert_to_tensor(weights)),
                        components=[
                            MultivariateNormalDiag(**c) for c in comps
                        ])
                    fw_iterates = {p: qtx}
                sess.run(tf.global_variables_initializer())
                # Run inference on relbo to solve LMO problem
                # NOTE: KLqp has a side effect, it is modifying s
                inference = relbo.KLqp({p: s},
                                       fw_iterates=fw_iterates,
                                       fw_iter=t)
                inference.run(n_iter=FLAGS.LMO_iter)
                # s now contains solution to LMO

                if t > 0:
                    sample_s = s.sample([FLAGS.n_monte_carlo_samples])
                    sample_q = qtx.sample([FLAGS.n_monte_carlo_samples])
                    step_s = tf.reduce_mean(grad_kl(qtx, p, sample_s)).eval()
                    step_q = tf.reduce_mean(grad_kl(qtx, p, sample_q)).eval()
                    gap = step_q - step_s
                    logger.info('Frank-Wolfe gap at iter %d is %.5f' %
                                (t, gap))
                    if gap < 0:
                        eprint('Frank-Wolfe gap becoming negative!')
                    # f(q*) = f(p) = 0
                    logger.info('Objective value (actual gap) is %.5f' %
                                kl_divergence(qtx, p).eval())

                gamma = 2. / (t + 2.)
                comps.append({
                    'loc': s.mean().eval(),
                    'scale_diag': s.stddev().eval()
                })
                weights = coreutils.update_weights(weights, gamma, t)

        tf.reset_default_graph()
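

# A minimal sketch of the fixed step-size schedule driving run_gap above:
# gamma_t = 2 / (t + 2), and (assuming coreutils.update_weights rescales the
# existing weights by (1 - gamma) before appending gamma, which matches the
# fixed-step branches earlier in this file) the weights always sum to one.
# `_fixed_step_weights_sketch` is an illustrative helper, not part of this
# module.
def _fixed_step_weights_sketch(n_fw_iter):
    weights = []
    for t in range(n_fw_iter):
        gamma = 2. / (t + 2.)
        weights = [(1. - gamma) * w for w in weights]
        weights.append(gamma)
    return weights

# Example: _fixed_step_weights_sketch(3) -> [1/6, 1/3, 1/2], which sums to 1.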