def main(argv):
    """Draw a bar chart of the mixture weights stored in ``FLAGS.qt``.

    The weights are read with ``get_mixture_params_from_file`` and plotted
    one bar per component. When ``FLAGS.iter_labels`` names a file whose
    basename starts with ``iter_types``, every line of that file is read as
    the step type of one iteration and the corresponding bar is recoloured.

    The figure is shown interactively when ``FLAGS.outdir`` is ``'stdout'``;
    otherwise it is saved to ``FLAGS.outdir/FLAGS.outfile``.

    Args:
        argv: unused command-line arguments.

    Raises:
        NotImplementedError: if ``FLAGS.grid2d`` is set.
    """
    del argv
    if FLAGS.grid2d:
        raise NotImplementedError('Only 1D Normal supported...')

    if FLAGS.qt == "":
        eprint("provide some qt to the `--qt` option if you would like to "
               "plot")

    if FLAGS.label:
        label = FLAGS.label
    else:
        # Derive the label from whatever follows 'qt_' in the file name.
        qt_file = os.path.splitext(FLAGS.qt)[0]
        label = qt_file[qt_file.find('qt_') + len('qt_'):]

    # Figure 1 is reserved for the density plot, which is currently disabled.
    plt.figure(1)
    debug("visualizing %s" % FLAGS.qt)
    mixture_params = get_mixture_params_from_file(FLAGS.qt)
    #plot_normal_mix(mixture_params['weights'], mixture_params['locs'],
    #                mixture_params['scale_diags'], plt, label)

    # Keep a handle on the figure: it is needed below for saving.
    fig = plt.figure(2)
    w = mixture_params['weights']
    barlist = plt.bar(np.arange(len(w)), w, color='b', label=label)

    # Only set when the step-type legend is actually drawn.
    legend = None
    if FLAGS.iter_labels:
        label_name = os.path.basename(FLAGS.iter_labels)
        if label_name.startswith('iter_types'):
            # label which iterations came from adaptive which
            # from fixed.
            with open(FLAGS.iter_labels, 'r') as f:
                iter_types = f.readlines()
            for i, it in enumerate(iter_types):
                it = it.strip()
                if it == 'adaptive':
                    continue  # adaptive bars keep the default blue
                # Line i colours bar i + 1; presumably bar 0 is the initial
                # component that no iteration produced -- TODO confirm.
                if it == 'fixed':
                    barlist[i + 1].set_color('r')
                elif it == 'fixed_adaptive_MAXITER':
                    barlist[i + 1].set_color('g')
                else:
                    barlist[i + 1].set_color('k')
            # NOTE(review): confirm these descriptions match the step_type
            # strings written by the optimizer ('fixed' is drawn red,
            # 'fixed_adaptive_MAXITER' green).
            ad = mpatches.Patch(color='b', label='Adaptive step')
            fi = mpatches.Patch(color='r',
                                label='Fixed step (adafw loop long)')
            fa = mpatches.Patch(color='g', label='Fixed step (-ve gap)')
            legend = plt.legend(handles=[ad, fi, fa], loc=2)

    if FLAGS.outdir == 'stdout':
        plt.show()
    else:
        fig.tight_layout()
        outname = os.path.join(os.path.expanduser(FLAGS.outdir),
                               FLAGS.outfile)
        # Passing None falls back to matplotlib's default artist list when
        # no legend was created.
        extra_artists = None if legend is None else (legend, )
        fig.savefig(outname,
                    bbox_extra_artists=extra_artists,
                    bbox_inches='tight')
        print('saved to ', outname)
def adaptive_pfw(weights, comps, locs, diags, q_t, mu_s, cov_s, s_t, p, k,
                 l_prev):
    """Adaptive pairwise Frank-Wolfe step.

    Shifts weight from the atom ``v_t`` selected by ``argmax_grad_dotp`` onto
    the new atom ``s_t``. The step size is found by backtracking: starting
    from ``FLAGS.damping_adafw * l_prev`` the Lipschitz estimate is multiplied
    by ``FLAGS.exp_adafw`` after every rejected candidate, until the KL
    divergence of the candidate satisfies the quadratic upper bound.

    Args:
        weights: list of mixture weights of q_t.
        comps: list of dicts with keys 'loc' and 'scale_diag', one per atom.
        locs: list of atom locations (presumably parallel to comps).
        diags: list of atom scale diagonals (presumably parallel to comps).
        q_t: current mixture iterate.
        mu_s: location of the new atom.
        cov_s: scale diagonal of the new atom.
        s_t: distribution object of the new atom.
        p: target distribution.
        k: iteration number, used for the fixed step 2 / (k + 2).
        l_prev: previous Lipschitz estimate.

    Returns:
        dict with keys 'gamma', 'l_estimate', 'weights', 'comps', 'gap' and
        'step_type'; step_type is one of 'adaptive', 'drop', 'fixed' or
        'fixed_adaptive_MAXITER'.
    """
    n_mc = FLAGS.n_monte_carlo_samples

    d_t_norm = divergence(s_t, q_t, metric=FLAGS.distance_metric).eval()
    logger.info('distance norm is %.5f' % d_t_norm)

    # Atom of q_t to take mass away from
    atoms = q_t.components
    index_v_t, step_v_t = argmax_grad_dotp(p, q_t, atoms, n_mc)
    v_t = atoms[index_v_t]  # not used further below

    # Pairwise gap
    s_draws = s_t.sample([n_mc])
    step_s = tf.reduce_mean(grad_kl(q_t, p, s_draws)).eval()
    gap = step_v_t - step_s
    if gap < 0:
        eprint("Pairwise gap is negative")
    logger.info('Pairwise gap %.5f' % gap)

    def fixed_step(step_label='fixed'):
        # Fallback: plain Frank-Wolfe update with step 2 / (k + 2).
        fixed_gamma = 2. / (k + 2.)
        fallback_weights = [(1. - fixed_gamma) * w for w in weights]
        fallback_weights.append(fixed_gamma)
        fallback_comps = list(comps)
        fallback_comps.append({'loc': mu_s, 'scale_diag': cov_s})
        return {
            'gamma': fixed_gamma,
            'l_estimate': l_prev,
            'weights': fallback_weights,
            'comps': fallback_comps,
            'gap': gap,
            'step_type': step_label
        }

    if gap <= 0:
        return fixed_step()

    # Parameters of q_{t+1}: every atom of q_t followed by s_t
    next_locs = list(locs) + [mu_s]
    next_diags = list(diags) + [cov_s]

    # At most the whole weight of v_t can be moved
    gamma_max = weights[index_v_t]
    tau, eta = FLAGS.exp_adafw, FLAGS.damping_adafw

    f_t = kl_divergence(q_t, p, allow_nan_stats=False).eval()
    debug('f(q_t) = %.5f' % (f_t))

    gamma = 2. / (k + 2)
    pow_tau = 1.0
    i = 0
    while gamma >= MIN_GAMMA and i < FLAGS.adafw_MAXITER:
        # Current Lipschitz estimate and the step size it implies
        l_t = pow_tau * eta * l_prev
        gamma = min(gap / (l_t * d_t_norm), gamma_max)

        d_1 = -gamma * gap
        d_2 = gamma * gamma * l_t * d_t_norm / 2.
        debug('linear d1 = %.5f, quad d2 = %.5f' % (d_1, d_2))
        quad_bound_rhs = f_t + d_1 + d_2

        # Candidate q_{t+1}
        drop_step = gamma == gamma_max
        new_weights = list(weights) + [gamma]
        if drop_step:
            # hardcoding to 0 for precision issues
            new_weights[index_v_t] = 0
        else:
            new_weights[index_v_t] -= gamma

        qt_new = Mixture(
            cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
            components=[
                MultivariateNormalDiag(loc=mu, scale_diag=sd)
                for mu, sd in zip(next_locs, next_diags)
            ])

        quad_bound_lhs = kl_divergence(qt_new, p,
                                       allow_nan_stats=False).eval()
        logger.info('lt = %.5f, gamma = %.3f, f_(qt_new) = %.5f, '
                    'linear extrapolated = %.5f' %
                    (l_t, gamma, quad_bound_lhs, quad_bound_rhs))

        if quad_bound_lhs <= quad_bound_rhs:
            new_comps = list(comps)
            new_comps.append({'loc': mu_s, 'scale_diag': cov_s})
            if drop_step:
                # v_t lost all of its weight: remove it entirely
                del new_comps[index_v_t]
                del new_weights[index_v_t]
                logger.info("...drop step")
            return {
                'gamma': gamma,
                'l_estimate': l_t,
                'weights': new_weights,
                'comps': new_comps,
                'gap': gap,
                'step_type': 'drop' if drop_step else 'adaptive'
            }

        pow_tau *= tau
        i += 1

    # gamma below MIN_GAMMA
    logger.warning("gamma below threshold value, returning fixed step")
    return fixed_step("fixed_adaptive_MAXITER")
def adaptive_afw(weights, comps, locs, diags, q_t, mu_s, cov_s, s_t, p, k,
                 l_prev):
    """Adaptive Frank-Wolfe step with away steps.

    Compares the Frank-Wolfe gap with the away gap and moves either towards
    the new atom ``s_t`` or away from the atom ``v_t`` selected by
    ``argmax_grad_dotp``. The step size is found by backtracking on the
    Lipschitz estimate, which starts at ``FLAGS.damping_adafw * l_prev`` and
    is multiplied by ``FLAGS.exp_adafw`` after every rejected candidate.

    Args:
        weights: list of mixture weights of q_t.
        comps: list of dicts with keys 'loc' and 'scale_diag', one per atom.
        locs: list of atom locations (presumably parallel to comps).
        diags: list of atom scale diagonals (presumably parallel to comps).
        q_t: current mixture iterate.
        mu_s: location of the new atom.
        cov_s: scale diagonal of the new atom.
        s_t: distribution object of the new atom.
        p: target distribution.
        k: iteration number, used for the fixed step 2 / (k + 2).
        l_prev: previous Lipschitz estimate.

    Returns:
        dict with keys 'gamma', 'l_estimate', 'weights', 'comps', 'gap' and
        'step_type'; step_type is one of 'adaptive', 'away', 'drop' or
        'fixed'.
    """
    n_mc = FLAGS.n_monte_carlo_samples

    d_t_norm = divergence(s_t, q_t, metric=FLAGS.distance_metric).eval()
    logger.info('distance norm is %.5f' % d_t_norm)

    # Candidate away atom
    atoms = q_t.components
    index_v_t, step_v_t = argmax_grad_dotp(p, q_t, atoms, n_mc)
    v_t = atoms[index_v_t]  # not used further below

    # Frank-Wolfe gap
    q_draws = q_t.sample([n_mc])
    s_draws = s_t.sample([n_mc])
    step_s = tf.reduce_mean(grad_kl(q_t, p, s_draws)).eval()
    step_q = tf.reduce_mean(grad_kl(q_t, p, q_draws)).eval()
    gap_fw = step_q - step_s
    if gap_fw < 0:
        logger.warning("Frank-Wolfe duality gap is negative")

    # Away gap
    gap_a = step_v_t - step_q
    if gap_a < 0:
        eprint('Away gap < 0!!!')
    logger.info('fw gap %.5f, away gap %.5f' % (gap_fw, gap_a))

    # Location / scale lists used to build q_{t+1}
    new_locs = list(locs)
    new_diags = list(diags)

    fw_direction = (gap_fw >= gap_a) or (len(comps) == 1)
    if fw_direction:
        # FW direction, proceeds exactly as adafw
        logger.info('Proceeding in FW direction ')
        gap = gap_fw
        new_locs.append(mu_s)
        new_diags.append(cov_s)
        gamma_max = 1.0
    else:
        logger.info('Proceeding in Away direction ')
        gap = gap_a
        w_v = weights[index_v_t]
        # Large value when t = 1 (v_t carries all the weight)
        gamma_max = w_v / (1.0 - w_v) if w_v < 1.0 else 100.

    def fixed_step(step_label='fixed'):
        # Fallback: plain Frank-Wolfe update with step 2 / (k + 2).
        fixed_gamma = 2. / (k + 2.)
        fallback_weights = [(1. - fixed_gamma) * w for w in weights]
        fallback_weights.append(fixed_gamma)
        fallback_comps = list(comps)
        fallback_comps.append({'loc': mu_s, 'scale_diag': cov_s})
        return {
            'gamma': fixed_gamma,
            'l_estimate': l_prev,
            'weights': fallback_weights,
            'comps': fallback_comps,
            'gap': gap,
            'step_type': step_label
        }

    def build_mixture(mix_weights, mix_locs, mix_diags):
        # Gaussian mixture with the given weights and diagonal atoms.
        return Mixture(
            cat=Categorical(probs=tf.convert_to_tensor(mix_weights)),
            components=[
                MultivariateNormalDiag(loc=mu, scale_diag=sd)
                for mu, sd in zip(mix_locs, mix_diags)
            ])

    if gap <= 0:
        return fixed_step()

    tau, eta = FLAGS.exp_adafw, FLAGS.damping_adafw
    f_t = kl_divergence(q_t, p, allow_nan_stats=False).eval()
    debug('f(q_t) = %.5f' % (f_t))

    gamma = 2. / (k + 2)
    is_drop_step = False
    pow_tau = 1.0
    i = 0
    while gamma >= MIN_GAMMA and i < FLAGS.adafw_MAXITER:
        # Current Lipschitz estimate and the step size it implies
        l_t = pow_tau * eta * l_prev
        # NOTE: Handle extreme values of gamma carefully
        gamma = min(gap / (l_t * d_t_norm), gamma_max)

        d_1 = -gamma * gap
        d_2 = gamma * gamma * l_t * d_t_norm / 2.
        debug('linear d1 = %.5f, quad d2 = %.5f' % (d_1, d_2))
        quad_bound_rhs = f_t + d_1 + d_2

        # Candidate q_{t+1}
        if fw_direction:
            if gamma == gamma_max:
                # gamma = 1.0, q_{t + 1} = s_t
                new_comps = [{'loc': mu_s, 'scale_diag': cov_s}]
                new_weights = [1.]
                qt_new = MultivariateNormalDiag(loc=mu_s, scale_diag=cov_s)
            else:
                new_comps = list(comps)
                new_comps.append({'loc': mu_s, 'scale_diag': cov_s})
                new_weights = [(1. - gamma) * w for w in weights]
                new_weights.append(gamma)
                qt_new = build_mixture(new_weights, new_locs, new_diags)
        else:
            is_drop_step = gamma == gamma_max
            new_comps = list(comps)
            new_weights = [(1. + gamma) * w for w in weights]
            if is_drop_step:
                # v_t loses all of its weight: remove it entirely
                logger.info('...drop step')
                del new_weights[index_v_t]
                del new_comps[index_v_t]
                # NOTE: recompute locs and diags after dropping v_t
                qt_new = build_mixture(
                    new_weights,
                    [c['loc'] for c in new_comps],
                    [c['scale_diag'] for c in new_comps])
            else:
                new_weights[index_v_t] -= gamma
                qt_new = build_mixture(new_weights, new_locs, new_diags)

        quad_bound_lhs = kl_divergence(qt_new, p,
                                       allow_nan_stats=False).eval()
        logger.info('lt = %.5f, gamma = %.3f, f_(qt_new) = %.5f, '
                    'linear extrapolated = %.5f' %
                    (l_t, gamma, quad_bound_lhs, quad_bound_rhs))

        if quad_bound_lhs <= quad_bound_rhs:
            if is_drop_step:
                step_type = "drop"
            elif fw_direction:
                step_type = "adaptive"
            else:
                step_type = "away"
            return {
                'gamma': gamma,
                'l_estimate': l_t,
                'weights': new_weights,
                'comps': new_comps,
                'gap': gap,
                'step_type': step_type
            }

        pow_tau *= tau
        i += 1

    # adaptive loop failed, return fixed step size
    logger.warning("gamma below threshold value, returning fixed step")
    return fixed_step()
def run_gap(pi, mus, stds):
    """Run fixed-step Frank-Wolfe and log the duality gap at every iteration.

    Each of the ``FLAGS.n_fw_iter`` iterations builds a fresh graph, rebuilds
    the target mixture and the current iterate, solves the LMO with
    ``relbo.KLqp`` and then appends the LMO solution to the iterate with step
    size ``2 / (t + 2)``. From the second iteration on, the Frank-Wolfe gap
    and the KL divergence of the current iterate to the target are logged.

    Args:
        pi: mixture weights of the target distribution.
        mus: per-component locations of the target distribution.
        stds: per-component scale diagonals of the target distribution.

    Returns:
        None; results are only reported through the logger.
    """
    weights, comps = [], []
    for t in range(FLAGS.n_fw_iter):
        logger.info('Frank Wolfe Iteration %d' % t)
        g = tf.Graph()
        with g.as_default():
            tf.set_random_seed(FLAGS.seed)
            sess = tf.InteractiveSession()
            try:
                with sess.as_default():
                    # target distribution components
                    pcomps = [
                        MultivariateNormalDiag(
                            loc=tf.convert_to_tensor(mus[i],
                                                     dtype=tf.float32),
                            scale_diag=tf.convert_to_tensor(
                                stds[i], dtype=tf.float32))
                        for i in range(len(mus))
                    ]
                    # target distribution
                    p = Mixture(
                        cat=Categorical(probs=tf.convert_to_tensor(pi)),
                        components=pcomps)
                    # LMO approximation
                    s = construct_normal([1], t, 's')
                    fw_iterates = {}
                    if t > 0:
                        # current iterate q_t, rebuilt inside this graph
                        qtx = Mixture(
                            cat=Categorical(
                                probs=tf.convert_to_tensor(weights)),
                            components=[
                                MultivariateNormalDiag(**c) for c in comps
                            ])
                        fw_iterates = {p: qtx}
                    sess.run(tf.global_variables_initializer())
                    # Run inference on relbo to solve LMO problem
                    # NOTE: KLqp has a side effect, it is modifying s
                    inference = relbo.KLqp({p: s},
                                           fw_iterates=fw_iterates,
                                           fw_iter=t)
                    inference.run(n_iter=FLAGS.LMO_iter)
                    # s now contains solution to LMO
                    if t > 0:
                        sample_s = s.sample([FLAGS.n_monte_carlo_samples])
                        sample_q = qtx.sample([FLAGS.n_monte_carlo_samples])
                        step_s = tf.reduce_mean(
                            grad_kl(qtx, p, sample_s)).eval()
                        step_q = tf.reduce_mean(
                            grad_kl(qtx, p, sample_q)).eval()
                        gap = step_q - step_s
                        logger.info('Frank-Wolfe gap at iter %d is %.5f' %
                                    (t, gap))
                        if gap < 0:
                            eprint('Frank-Wolfe gab becoming negative!')
                        # f(q*) = f(p) = 0
                        logger.info('Objective value (actual gap) is %.5f' %
                                    kl_divergence(qtx, p).eval())
                    gamma = 2. / (t + 2.)
                    # Store plain values so they outlive this graph
                    comps.append({
                        'loc': s.mean().eval(),
                        'scale_diag': s.stddev().eval()
                    })
                    weights = coreutils.update_weights(weights, gamma, t)
            finally:
                # An InteractiveSession stays open (and installed as the
                # default session) until closed explicitly; without this one
                # session leaked per iteration.
                sess.close()
        tf.reset_default_graph()
def adaptive_afw(weights, params, q_t, mu_s, cov_s, s_t, p, k, l_prev):
    """Adaptive Away Steps algorithm.

    Compares the Frank-Wolfe gap with the away gap. If the Frank-Wolfe
    direction wins (or q_t has a single component) the call is delegated to
    ``adaptive_fw``; otherwise weight is moved away from the atom selected by
    ``argmax_grad_dotp``, with the step size found by backtracking on the
    Lipschitz estimate.

    Args:
        weights: [k], weights of the mixture components of q_t
        params: list of dictionaries holding each component's parameters
            under the keys 'loc' and 'scale'
        q_t: current mixture iterate q_t
        mu_s: [dim], mean for LMO solution s
        cov_s: [dim], cov matrix for LMO solution s
        s_t: Current atom & LMO Solution s
        p: edward.model, target distribution p
        k: iteration number of Frank-Wolfe
        l_prev: previous lipschitz estimate

    Returns:
        a dictionary containing gamma, new weights, new parameters
        lipschitz estimate, duality gap of current iterate
        and step information
    """
    n_samples = FLAGS.n_monte_carlo_samples

    # FIXME
    is_vector = FLAGS.base_dist in ['mvnormal', 'mvlaplace']

    d_t_norm = divergence(s_t, q_t, metric=FLAGS.distance_metric).eval()
    logger.info('\ndistance norm is %.3e' % d_t_norm)

    # Candidate away atom
    atoms = q_t.components
    index_v_t, step_v_t = argmax_grad_dotp(p, q_t, atoms, n_samples)
    v_t = atoms[index_v_t]  # not used further below

    # Frank-Wolfe gap
    q_draws = q_t.sample([n_samples])
    s_draws = s_t.sample([n_samples])
    step_s = tf.reduce_mean(grad_elbo(q_t, p, s_draws)).eval()
    step_q = tf.reduce_mean(grad_elbo(q_t, p, q_draws)).eval()
    gap_fw = step_q - step_s
    if gap_fw < 0:
        logger.warning("Frank-Wolfe duality gap is negative")

    # Away gap
    gap_a = step_v_t - step_q
    if gap_a < 0:
        eprint('Away gap < 0!!!')
    logger.info('fw gap %.3e, away gap %.3e' % (gap_fw, gap_a))

    if (gap_fw >= gap_a) or (len(params) == 1):
        # FW direction, proceeds exactly as adafw
        logger.info('Proceeding in FW direction ')
        return adaptive_fw(weights, params, q_t, mu_s, cov_s, s_t, p, k,
                           l_prev, gap_fw)

    # Away direction
    logger.info('Proceeding in Away direction ')
    gap = gap_a
    w_v = weights[index_v_t]
    # Large value when t = 1 (v_t carries all the weight)
    max_gamma = w_v / (1.0 - w_v) if w_v < 1.0 else 100.

    gamma = 2. / (k + 2.)
    tau, eta = FLAGS.exp_adafw, FLAGS.damping_adafw
    f_t = -elbo(q_t, p, n_samples, return_std=False)
    debug('f(q_t) = %.5f' % (f_t))

    pow_tau = 1.0
    i = 0
    while gamma >= MIN_GAMMA and i < FLAGS.adafw_MAXITER:
        # Current Lipschitz estimate and the step size it implies
        l_t = pow_tau * eta * l_prev
        # NOTE: Handle extreme values of gamma carefully
        gamma = min(gap / (l_t * d_t_norm), max_gamma)

        d_1 = -gamma * gap
        d_2 = gamma * gamma * l_t * d_t_norm / 2.
        debug('linear d1 = %.5f, quad d2 = %.5f' % (d_1, d_2))
        quad_bound_rhs = f_t + d_1 + d_2

        # Candidate q_{t+1}: rescale every weight, then either remove v_t
        # (drop step) or take gamma off its weight.
        is_drop_step = gamma == max_gamma
        new_weights = [(1. + gamma) * w for w in weights]
        new_params = list(params)
        if is_drop_step:
            del new_weights[index_v_t]
            del new_params[index_v_t]
        else:
            new_weights[index_v_t] -= gamma

        new_components = [
            coreutils.base_loc_scale(FLAGS.base_dist,
                                     c['loc'],
                                     c['scale'],
                                     multivariate=is_vector)
            for c in new_params
        ]
        qt_new = coreutils.get_mixture(new_weights, new_components)

        quad_bound_lhs = -elbo(qt_new, p, n_samples, return_std=False)
        logger.info('lt = %.3e, gamma = %.3f, f_(qt_new) = %.3e, '
                    'linear extrapolated = %.3e' %
                    (l_t, gamma, quad_bound_lhs, quad_bound_rhs))

        if quad_bound_lhs <= quad_bound_rhs:
            return {
                'gamma': gamma,
                'l_estimate': l_t,
                'weights': new_weights,
                'params': new_params,
                'gap': gap,
                'step_type': "drop" if is_drop_step else "away"
            }

        pow_tau *= tau
        i += 1

    # gamma below MIN_GAMMA
    logger.warning("gamma below threshold value, returning fixed step")
    return fixed(weights, params, q_t, mu_s, cov_s, s_t, p, k, gap)
def main(argv):
    """Plot the target density together with stored mixture approximations.

    The target is loaded from ``FLAGS.target`` and evaluated on a 1D grid
    (or a 2D grid when ``FLAGS.grid2d`` is set). Every file listed in
    ``FLAGS.qt`` is loaded as a mixture and drawn on top (1D) or in a second
    subplot (2D). With a single mixture and ``FLAGS.bars`` set, the component
    locations and weights are added as markers and bars.

    The figure is shown interactively when ``FLAGS.outdir`` is ``'stdout'``;
    otherwise it is saved to ``FLAGS.outdir/FLAGS.outfile``.

    Args:
        argv: unused command-line arguments.
    """
    del argv

    x = deserialize_target_from_file(FLAGS.target)

    if FLAGS.widegrid:
        grid = np.arange(-25, 25, 0.1).astype(np.float32)
    else:
        grid = np.arange(-4, 4, 0.1).astype(np.float32)

    if FLAGS.grid2d:
        # 2D grid
        grid = np.arange(-2, 2, 0.1).astype(np.float32)
        gridx, gridy = np.meshgrid(grid, grid)
        grid = np.vstack((gridx.flatten(), gridy.flatten())).T

    labels = FLAGS.labels if FLAGS.labels else (['approximation'] *
                                                len(FLAGS.qt))
    styles = FLAGS.styles if FLAGS.styles else ['+', 'x', '.', '-']

    if FLAGS.grid2d:
        fig = plt.figure()
        ax = fig.add_subplot(211)
    else:
        fig, ax = plt.subplots()

    # NOTE(review): the original indentation was lost; this is assumed to
    # apply to both the 1D and the 2D grid -- confirm.
    grid = np.expand_dims(grid, 1)  # package dims for tf

    # The context manager installs the session as default for the .eval()
    # calls below and closes it afterwards.
    with tf.Session():
        xprobs = x.log_prob(grid)
        xprobs = tf.exp(xprobs).eval()
        if FLAGS.grid2d:
            ax.pcolormesh(gridx,
                          gridy,
                          xprobs.reshape(gridx.shape),
                          cmap='Blues')
        else:
            ax.plot(grid, xprobs, label='target', linewidth=2.0)

        if not FLAGS.qt:
            eprint(
                "provide some qts to the `--qt` option if you would like to "
                "plot them")

        for i, (qt_filename, label) in enumerate(zip(FLAGS.qt, labels)):
            debug("visualizing %s" % qt_filename)
            qt = deserialize_mixture_from_file(qt_filename)
            qtprobs = tf.exp(qt.log_prob(grid))
            qtprobs = qtprobs.eval()
            if FLAGS.grid2d:
                ax2 = fig.add_subplot(212)
                ax2.pcolormesh(gridx,
                               gridy,
                               qtprobs.reshape(gridx.shape),
                               cmap='Greens')
            else:
                ax.plot(grid,
                        qtprobs,
                        styles[i % len(styles)],
                        label=label,
                        linewidth=2.0)

            if len(FLAGS.qt) == 1 and FLAGS.bars:
                # Mark every component: a '+' at its location on the axis
                # and a thin bar whose height is its weight.
                locs = [comp.loc.eval() for comp in qt.components]
                ax.plot(locs, [0] * len(locs), '+')
                weights = qt.cat.probs.eval()
                for loc, weight in zip(locs, weights):
                    ax.bar(loc, weight, .05)

    # NOTE(review): the original indentation was lost; tick removal is
    # assumed to be unconditional -- confirm.
    ax.set_xticks([])
    ax.set_xticklabels([])
    ax.set_xlabel(FLAGS.xlabel)
    ax.set_ylabel(FLAGS.ylabel)
    fig.suptitle(FLAGS.title)

    # Only 1D plots get a legend; keep track of it for saving.
    legend = None
    if not FLAGS.grid2d:
        legend = plt.legend(loc='upper right',
                            prop={'size': 15},
                            bbox_to_anchor=(1.08, 1))

    if FLAGS.outdir == 'stdout':
        plt.show()
    else:
        fig.tight_layout()
        outname = os.path.join(os.path.expanduser(FLAGS.outdir),
                               FLAGS.outfile)
        # Passing None falls back to matplotlib's default artist list when
        # no legend exists (2D plots).
        extra_artists = None if legend is None else (legend, )
        fig.savefig(outname,
                    bbox_extra_artists=extra_artists,
                    bbox_inches='tight')
        print('saved to ', outname)