def adaptive_pfw(weights, comps, locs, diags, q_t, mu_s, cov_s, s_t, p, k,
                 l_prev):
    """Adaptive pairwise variant.

    Args:
        same as adaptive_fw, plus comps: list of dicts with 'loc' and
        'scale_diag' entries for the current mixture components
    """
    d_t_norm = divergence(s_t, q_t, metric=FLAGS.distance_metric).eval()
    logger.info('distance norm is %.5f' % d_t_norm)

    # Find the away atom v_t
    qcomps = q_t.components
    index_v_t, step_v_t = argmax_grad_dotp(p, q_t, qcomps,
                                           FLAGS.n_monte_carlo_samples)
    v_t = qcomps[index_v_t]

    # Pairwise gap
    sample_s = s_t.sample([FLAGS.n_monte_carlo_samples])
    step_s = tf.reduce_mean(grad_kl(q_t, p, sample_s)).eval()
    gap_pw = step_v_t - step_s
    if gap_pw < 0:
        eprint("Pairwise gap is negative")

    def default_fixed_step(fail_type='fixed'):
        # adaptive loop failed, fall back to the fixed step size
        gamma = 2. / (k + 2.)
        new_comps = copy.copy(comps)
        new_comps.append({'loc': mu_s, 'scale_diag': cov_s})
        new_weights = [(1. - gamma) * w for w in weights]
        new_weights.append(gamma)
        return {
            'gamma': gamma,
            'l_estimate': l_prev,
            'weights': new_weights,
            'comps': new_comps,
            'gap': gap_pw,
            'step_type': fail_type
        }

    logger.info('Pairwise gap %.5f' % gap_pw)

    # Set $q_{t+1}$'s params
    new_locs = copy.copy(locs)
    new_diags = copy.copy(diags)
    new_locs.append(mu_s)
    new_diags.append(cov_s)

    gap = gap_pw
    if gap <= 0:
        return default_fixed_step()
    gamma_max = weights[index_v_t]
    step_type = 'adaptive'

    tau = FLAGS.exp_adafw
    eta = FLAGS.damping_adafw
    pow_tau = 1.0
    i, l_t = 0, l_prev
    f_t = kl_divergence(q_t, p, allow_nan_stats=False).eval()
    drop_step = False
    debug('f(q_t) = %.5f' % f_t)
    gamma = 2. / (k + 2.)
    while gamma >= MIN_GAMMA and i < FLAGS.adafw_MAXITER:
        # compute $L_t$ and $\gamma_t$
        l_t = pow_tau * eta * l_prev
        gamma = min(gap / (l_t * d_t_norm), gamma_max)

        d_1 = -gamma * gap
        d_2 = gamma * gamma * l_t * d_t_norm / 2.
        debug('linear d1 = %.5f, quad d2 = %.5f' % (d_1, d_2))
        quad_bound_rhs = f_t + d_1 + d_2

        # construct $q_{t+1}$: transfer gamma mass from v_t onto s
        new_weights = copy.copy(weights)
        new_weights.append(gamma)
        if gamma == gamma_max:
            # set exactly to 0 to avoid precision issues
            new_weights[index_v_t] = 0
            drop_step = True
        else:
            new_weights[index_v_t] -= gamma
            drop_step = False

        qt_new = Mixture(
            cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
            components=[
                MultivariateNormalDiag(loc=loc, scale_diag=diag)
                for loc, diag in zip(new_locs, new_diags)
            ])
        quad_bound_lhs = kl_divergence(qt_new, p, allow_nan_stats=False).eval()
        logger.info('lt = %.5f, gamma = %.3f, f_(qt_new) = %.5f, '
                    'linear extrapolated = %.5f' % (l_t, gamma, quad_bound_lhs,
                                                    quad_bound_rhs))
        if quad_bound_lhs <= quad_bound_rhs:
            new_comps = copy.copy(comps)
            new_comps.append({'loc': mu_s, 'scale_diag': cov_s})
            if drop_step:
                del new_comps[index_v_t]
                del new_weights[index_v_t]
                logger.info("...drop step")
                step_type = 'drop'
            return {
                'gamma': gamma,
                'l_estimate': l_t,
                'weights': new_weights,
                'comps': new_comps,
                'gap': gap,
                'step_type': step_type
            }
        pow_tau *= tau
        i += 1

    # gamma fell below MIN_GAMMA or MAXITER was exceeded
    logger.warning("gamma below threshold value, returning fixed step")
    return default_fixed_step("fixed_adaptive_MAXITER")
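
# Illustrative sketch (not part of the original pipeline, names are
# hypothetical): the pairwise update above moves `gamma` mass from the away
# atom v_t onto the new atom s, capped at gamma_max = weights[index_v_t].
# The pure-Python version below isolates just that weight arithmetic.
def _pairwise_weight_sketch(weights, index_v_t, gamma):
    """Return mixture weights after a pairwise step; drops v_t at the cap."""
    w = list(weights) + [gamma]         # new atom s enters with weight gamma
    if abs(gamma - weights[index_v_t]) < 1e-12:
        del w[index_v_t]                # all of v_t's mass moved: drop step
    else:
        w[index_v_t] -= gamma           # partial transfer from v_t to s
    return w                            # sum(w) == sum(weights) either way
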
def line_search_dkl(weights, locs, diags, q_t, mu_s, cov_s, s_t, p, k,
                    return_gamma=False):
    """Performs line search for the best step size gamma.

    Uses projected gradient descent to find the gamma that minimizes
    KL(q_t + gamma (s - q_t) || p).

    Args:
        weights: [k], weights of mixture components of q_t
        locs: [k x dim], means of mixture components of q_t
        diags: [k x dim], std deviations of mixture components of q_t
        q_t: current mixture iterate q_t
        mu_s: [dim], mean for LMO solution s
        cov_s: [dim], cov matrix for LMO solution s
        s_t: current atom & LMO solution s
        p: edward.model, target distribution p
        k: iteration number of Frank-Wolfe
        return_gamma: only return the value of gamma
    Returns:
        If return_gamma is True, only the computed value of gamma is
        returned. Else a dict containing gamma along with the gradient
        data is returned.
    """
    N_samples = FLAGS.n_monte_carlo_samples
    # sample from $q_t$ and $s$
    sample_q = q_t.sample([N_samples])
    sample_s = s_t.sample([N_samples])

    # set $q_{t+1}$'s parameters
    new_locs = copy.copy(locs)
    new_diags = copy.copy(diags)
    new_locs.append(mu_s)
    new_diags.append(cov_s)

    # initialize $\gamma$
    gamma = 2. / (k + 2.)
    n_steps = FLAGS.n_line_search_iter
    prog_bar = ed.util.Progbar(n_steps)
    # store gradients for analysis
    grad_gamma = []
    for it in range(n_steps):
        print("line_search iter %d, %.5f" % (it, gamma))
        new_weights = [(1. - gamma) * w for w in weights]
        new_weights.append(gamma)
        qt_new = Mixture(
            cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
            components=[
                MultivariateNormalDiag(loc=loc, scale_diag=diag)
                for loc, diag in zip(new_locs, new_diags)
            ])
        rez_s = grad_kl(qt_new, p, sample_s).eval()
        rez_q = grad_kl(qt_new, p, sample_q).eval()
        grad_gamma.append({'E_s': rez_s, 'E_q': rez_q, 'gamma': gamma})

        # gradient descent step, step size decreasing as $\frac{1}{it + 1}$
        gamma_prime = gamma - 0.1 * (np.mean(rez_s) - np.mean(rez_q)) / (it + 1.)

        # project gamma back onto [0, 1]
        if gamma_prime >= 1 or gamma_prime <= 0:
            gamma_prime = max(min(gamma_prime, 1.), 0.)

        if np.abs(gamma - gamma_prime) < 1e-6:
            gamma = gamma_prime
            break

        gamma = gamma_prime

    if return_gamma:
        return gamma
    return {'gamma': gamma, 'n_samples': N_samples, 'grad_gamma': grad_gamma}
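
# Illustrative sketch (assumed toy setting, not the module's API): the
# projected gradient descent on gamma used by line_search_dkl, demonstrated
# on a one-dimensional convex objective f(gamma) = (gamma - 0.3)^2 whose
# gradient plays the role of E_s[grad KL] - E_q[grad KL].
def _line_search_sketch(grad_fn=lambda g: 2. * (g - 0.3),
                        gamma=0.5, n_steps=20, lr=0.1, tol=1e-6):
    for it in range(n_steps):
        # step size decays as 1 / (it + 1), as in line_search_dkl
        gamma_prime = gamma - lr * grad_fn(gamma) / (it + 1.)
        gamma_prime = max(min(gamma_prime, 1.), 0.)  # project onto [0, 1]
        if abs(gamma - gamma_prime) < tol:           # converged
            return gamma_prime
        gamma = gamma_prime
    return gamma
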
def adaptive_afw(weights, comps, locs, diags, q_t, mu_s, cov_s, s_t, p, k,
                 l_prev):
    """Away-steps variant.

    Args:
        same as adaptive_pfw
    """
    d_t_norm = divergence(s_t, q_t, metric=FLAGS.distance_metric).eval()
    logger.info('distance norm is %.5f' % d_t_norm)

    # Find the away atom v_t
    qcomps = q_t.components
    index_v_t, step_v_t = argmax_grad_dotp(p, q_t, qcomps,
                                           FLAGS.n_monte_carlo_samples)
    v_t = qcomps[index_v_t]

    # Frank-Wolfe gap
    sample_q = q_t.sample([FLAGS.n_monte_carlo_samples])
    sample_s = s_t.sample([FLAGS.n_monte_carlo_samples])
    step_s = tf.reduce_mean(grad_kl(q_t, p, sample_s)).eval()
    step_q = tf.reduce_mean(grad_kl(q_t, p, sample_q)).eval()
    gap_fw = step_q - step_s
    if gap_fw < 0:
        logger.warning("Frank-Wolfe duality gap is negative")
    # Away gap
    gap_a = step_v_t - step_q
    if gap_a < 0:
        eprint('Away gap is negative')
    logger.info('fw gap %.5f, away gap %.5f' % (gap_fw, gap_a))

    # Set $q_{t+1}$'s params
    new_locs = copy.copy(locs)
    new_diags = copy.copy(diags)

    if (gap_fw >= gap_a) or (len(comps) == 1):
        # FW direction, proceeds exactly as adafw
        logger.info('Proceeding in FW direction')
        adaptive_step_type = 'fw'
        gap = gap_fw
        new_locs.append(mu_s)
        new_diags.append(cov_s)
        gamma_max = 1.0
    else:
        # Away direction
        logger.info('Proceeding in Away direction')
        adaptive_step_type = 'away'
        gap = gap_a
        if weights[index_v_t] < 1.0:
            gamma_max = weights[index_v_t] / (1.0 - weights[index_v_t])
        else:
            gamma_max = 100.  # large value when t = 1

    def default_fixed_step(fail_type='fixed'):
        # adaptive loop failed, fall back to the fixed step size
        gamma = 2. / (k + 2.)
        new_comps = copy.copy(comps)
        new_comps.append({'loc': mu_s, 'scale_diag': cov_s})
        new_weights = [(1. - gamma) * w for w in weights]
        new_weights.append(gamma)
        return {
            'gamma': gamma,
            'l_estimate': l_prev,
            'weights': new_weights,
            'comps': new_comps,
            'gap': gap,
            'step_type': fail_type
        }

    if gap <= 0:
        return default_fixed_step()

    tau = FLAGS.exp_adafw
    eta = FLAGS.damping_adafw
    pow_tau = 1.0
    i, l_t = 0, l_prev
    f_t = kl_divergence(q_t, p, allow_nan_stats=False).eval()
    debug('f(q_t) = %.5f' % f_t)
    gamma = 2. / (k + 2.)
    is_drop_step = False
    while gamma >= MIN_GAMMA and i < FLAGS.adafw_MAXITER:
        # compute $L_t$ and $\gamma_t$
        l_t = pow_tau * eta * l_prev
        # NOTE: handle extreme values of gamma carefully
        gamma = min(gap / (l_t * d_t_norm), gamma_max)

        d_1 = -gamma * gap
        d_2 = gamma * gamma * l_t * d_t_norm / 2.
        debug('linear d1 = %.5f, quad d2 = %.5f' % (d_1, d_2))
        quad_bound_rhs = f_t + d_1 + d_2

        # construct $q_{t+1}$
        if adaptive_step_type == 'fw':
            if gamma == gamma_max:
                # gamma = 1.0, q_{t+1} = s_t
                new_comps = [{'loc': mu_s, 'scale_diag': cov_s}]
                new_weights = [1.]
                qt_new = MultivariateNormalDiag(loc=mu_s, scale_diag=cov_s)
            else:
                new_comps = copy.copy(comps)
                new_comps.append({'loc': mu_s, 'scale_diag': cov_s})
                new_weights = [(1. - gamma) * w for w in weights]
                new_weights.append(gamma)
                qt_new = Mixture(
                    cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
                    components=[
                        MultivariateNormalDiag(loc=loc, scale_diag=diag)
                        for loc, diag in zip(new_locs, new_diags)
                    ])
        elif adaptive_step_type == 'away':
            new_weights = copy.copy(weights)
            new_comps = copy.copy(comps)
            if gamma == gamma_max:
                # drop v_t: deleting it and rescaling the remaining weights
                # by (1 + gamma) renormalizes them to sum to 1
                is_drop_step = True
                logger.info('...drop step')
                del new_weights[index_v_t]
                new_weights = [(1. + gamma) * w for w in new_weights]
                del new_comps[index_v_t]
                # NOTE: recompute locs and diags after dropping v_t
                drop_locs = [c['loc'] for c in new_comps]
                drop_diags = [c['scale_diag'] for c in new_comps]
                qt_new = Mixture(
                    cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
                    components=[
                        MultivariateNormalDiag(loc=loc, scale_diag=diag)
                        for loc, diag in zip(drop_locs, drop_diags)
                    ])
            else:
                is_drop_step = False
                # move gamma mass away from v_t onto the rest of the mixture
                new_weights = [(1. + gamma) * w for w in new_weights]
                new_weights[index_v_t] -= gamma
                qt_new = Mixture(
                    cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
                    components=[
                        MultivariateNormalDiag(loc=loc, scale_diag=diag)
                        for loc, diag in zip(new_locs, new_diags)
                    ])

        quad_bound_lhs = kl_divergence(qt_new, p, allow_nan_stats=False).eval()
        logger.info('lt = %.5f, gamma = %.3f, f_(qt_new) = %.5f, '
                    'linear extrapolated = %.5f' % (l_t, gamma, quad_bound_lhs,
                                                    quad_bound_rhs))
        if quad_bound_lhs <= quad_bound_rhs:
            step_type = "adaptive"
            if adaptive_step_type == "away":
                step_type = "away"
            if is_drop_step:
                step_type = "drop"
            return {
                'gamma': gamma,
                'l_estimate': l_t,
                'weights': new_weights,
                'comps': new_comps,
                'gap': gap,
                'step_type': step_type
            }
        pow_tau *= tau
        i += 1

    # adaptive loop failed, return the fixed step size
    logger.warning("gamma below threshold value, returning fixed step")
    return default_fixed_step()
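
# Illustrative sketch (hypothetical names, not in the original module): why
# gamma_max = w_v / (1 - w_v) for away steps. An away step rescales every
# weight by (1 + gamma) and subtracts gamma from v_t; at gamma_max the
# weight of v_t hits exactly zero, and deleting it leaves the rest scaled
# by (1 + gamma_max) = 1 / (1 - w_v), i.e. renormalized to sum to 1.
def _away_weight_sketch(weights, index_v_t, gamma):
    w = [(1. + gamma) * x for x in weights]
    w[index_v_t] -= gamma
    return w
# e.g. for weights [0.25, 0.75] and v_t = 0, gamma_max = 0.25 / 0.75, and
# _away_weight_sketch gives [0.0, 1.0]: v_t can then be dropped outright.
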
def adaptive_fw(weights, locs, diags, q_t, mu_s, cov_s, s_t, p, k, l_prev,
                return_gamma=False):
    """Adaptive Frank-Wolfe algorithm.

    Sets the step size as suggested in Algorithm 1 of
    https://arxiv.org/pdf/1806.05123.pdf

    Args:
        weights: [k], weights of the mixture components of q_t
        locs: [k x dim], means of mixture components of q_t
        diags: [k x dim], std deviations of mixture components of q_t
        q_t: current mixture iterate q_t
        mu_s: [dim], mean for LMO solution s
        cov_s: [dim], cov matrix for LMO solution s
        s_t: current atom & LMO solution s
        p: edward.model, target distribution p
        k: iteration number of Frank-Wolfe
        l_prev: previous Lipschitz estimate
        return_gamma: only return the value of gamma
    Returns:
        If return_gamma is True, only the computed value of gamma is
        returned. Else returns a dictionary containing gamma, the
        Lipschitz estimate, the duality gap and step information.
    """
    # Set $q_{t+1}$'s params
    new_locs = copy.copy(locs)
    new_diags = copy.copy(diags)
    new_locs.append(mu_s)
    new_diags.append(cov_s)

    d_t_norm = divergence(s_t, q_t, metric=FLAGS.distance_metric).eval()
    logger.info('distance norm is %.5f' % d_t_norm)

    N_samples = FLAGS.n_monte_carlo_samples
    # create and sample from $s_t, q_t$
    sample_q = q_t.sample([N_samples])
    sample_s = s_t.sample([N_samples])
    step_s = tf.reduce_mean(grad_kl(q_t, p, sample_s)).eval()
    step_q = tf.reduce_mean(grad_kl(q_t, p, sample_q)).eval()
    gap = step_q - step_s
    logger.info('duality gap %.5f' % gap)
    if gap < 0:
        logger.warning("Duality gap is negative, returning 0 step")

    gamma = 0.
    tau = FLAGS.exp_adafw
    eta = FLAGS.damping_adafw
    # did the adaptive loop succeed or not
    step_type = "fixed"
    # NOTE: this is from v1 of the paper; the new version
    # replaces the multiplicative tau with a divisor eta
    pow_tau = 1.0
    i, l_t = 0, l_prev
    f_t = kl_divergence(q_t, p, allow_nan_stats=False).eval()
    debug('f(q_t) = %.5f' % f_t)
    # if the gap is negative the loop is skipped and the initial
    # estimate (gamma = 0, l_t = l_prev) is returned
    while gap >= 0:
        # compute $L_t$ and $\gamma_t$
        l_t = pow_tau * eta * l_prev
        gamma = min(gap / (l_t * d_t_norm), 1.0)

        d_1 = -gamma * gap
        d_2 = gamma * gamma * l_t * d_t_norm / 2.
        debug('linear d1 = %.5f, quad d2 = %.5f' % (d_1, d_2))
        quad_bound_rhs = f_t + d_1 + d_2

        # $w_{t + 1} = [(1 - \gamma) w_t, \gamma]$
        new_weights = [(1. - gamma) * w for w in weights]
        new_weights.append(gamma)
        qt_new = Mixture(
            cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
            components=[
                MultivariateNormalDiag(loc=loc, scale_diag=diag)
                for loc, diag in zip(new_locs, new_diags)
            ])
        quad_bound_lhs = kl_divergence(qt_new, p, allow_nan_stats=False).eval()
        logger.info('lt = %.5f, gamma = %.3f, f_(qt_new) = %.5f, '
                    'linear extrapolated = %.5f' % (l_t, gamma, quad_bound_lhs,
                                                    quad_bound_rhs))
        if quad_bound_lhs <= quad_bound_rhs:
            step_type = "adaptive"
            break
        pow_tau *= tau
        i += 1
        if i > FLAGS.adafw_MAXITER:
            # Lipschitz estimate is not good, return a 0 step
            gamma = 0.
            l_t = l_prev
            step_type = "fixed_adaptive_MAXITER"
            break

    if return_gamma:
        return gamma
    return {
        'gamma': gamma,
        'l_estimate': l_t,
        'gap': gap,
        'step_type': step_type
    }
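
# Illustrative sketch of the backtracking rule above (assumed toy setting,
# hypothetical names): for f(x) = x^2 / 2 and a descent direction d with
# gap g = -<grad f(x), d>, the loop inflates L_t by tau until the quadratic
# upper bound f(x) - gamma * g + gamma^2 * L_t * ||d||^2 / 2 holds at
# gamma = min(g / (L_t * ||d||^2), 1), mirroring adaptive_fw's test.
def _adafw_sketch(x=2., d=-2., l_prev=0.1, eta=0.99, tau=2., max_iter=10):
    f = lambda y: 0.5 * y * y
    gap = -x * d                     # -<grad f(x), d>, here grad f(x) = x
    d_norm = d * d                   # squared norm of the direction
    l_t, pow_tau = l_prev, 1.0
    for _ in range(max_iter):
        l_t = pow_tau * eta * l_prev
        gamma = min(gap / (l_t * d_norm), 1.0)
        rhs = f(x) - gamma * gap + gamma * gamma * l_t * d_norm / 2.
        if f(x + gamma * d) <= rhs:  # quadratic bound satisfied
            return gamma, l_t
        pow_tau *= tau               # inflate the Lipschitz estimate
    return 0., l_prev                # fall back, like step_type 'fixed'
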
def run_gap(pi, mus, stds):
    weights, comps = [], []
    elbos = []
    relbo_vals = []
    for t in range(FLAGS.n_fw_iter):
        logger.info('Frank-Wolfe Iteration %d' % t)
        g = tf.Graph()
        with g.as_default():
            tf.set_random_seed(FLAGS.seed)
            sess = tf.InteractiveSession()
            with sess.as_default():
                # target distribution components
                pcomps = [
                    MultivariateNormalDiag(
                        loc=tf.convert_to_tensor(mus[i], dtype=tf.float32),
                        scale_diag=tf.convert_to_tensor(
                            stds[i], dtype=tf.float32))
                    for i in range(len(mus))
                ]
                # target distribution
                p = Mixture(
                    cat=Categorical(probs=tf.convert_to_tensor(pi)),
                    components=pcomps)

                # LMO approximation
                s = construct_normal([1], t, 's')
                fw_iterates = {}
                if t > 0:
                    qtx = Mixture(
                        cat=Categorical(probs=tf.convert_to_tensor(weights)),
                        components=[
                            MultivariateNormalDiag(**c) for c in comps
                        ])
                    fw_iterates = {p: qtx}
                sess.run(tf.global_variables_initializer())

                # Run inference on the RELBO to solve the LMO problem
                # NOTE: KLqp has a side effect: it modifies s in place
                inference = relbo.KLqp({p: s}, fw_iterates=fw_iterates,
                                       fw_iter=t)
                inference.run(n_iter=FLAGS.LMO_iter)
                # s now contains the solution to the LMO

                if t > 0:
                    sample_s = s.sample([FLAGS.n_monte_carlo_samples])
                    sample_q = qtx.sample([FLAGS.n_monte_carlo_samples])
                    step_s = tf.reduce_mean(grad_kl(qtx, p, sample_s)).eval()
                    step_q = tf.reduce_mean(grad_kl(qtx, p, sample_q)).eval()
                    gap = step_q - step_s
                    logger.info('Frank-Wolfe gap at iter %d is %.5f' % (t, gap))
                    if gap < 0:
                        eprint('Frank-Wolfe gap becoming negative!')
                    # f(q*) = f(p) = 0, so KL(q_t || p) is the true
                    # optimality gap
                    logger.info('Objective value (actual gap) is %.5f' %
                                kl_divergence(qtx, p).eval())

                gamma = 2. / (t + 2.)
                comps.append({
                    'loc': s.mean().eval(),
                    'scale_diag': s.stddev().eval()
                })
                weights = coreutils.update_weights(weights, gamma, t)

        tf.reset_default_graph()
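
# Hedged usage sketch: run_gap expects the target mixture's weights,
# per-component means and std deviations. construct_normal([1], ...) above
# suggests a 1-D target, so the toy values below are assumptions chosen to
# match that shape.
if __name__ == '__main__':
    pi = [0.4, 0.6]        # target mixture weights (sum to 1)
    mus = [[-3.], [3.]]    # [k x dim] component means, dim = 1
    stds = [[1.], [1.]]    # [k x dim] component std deviations
    run_gap(pi, mus, stds)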