def _loop_body(t, ta_z_prior, ta_z_post, ta_kl):
    """Loop body: iterates over trading days."""
    with tf.variable_scope('iter_body', reuse=tf.AUTO_REUSE):
        init = lambda: tf.random_normal(shape=[self.batch_size, self.z_size],
                                        name='z_post_t_1')
        subsequent = lambda: tf.reshape(ta_z_post.read(t - 1),
                                        [self.batch_size, self.z_size])
        z_post_t_1 = tf.cond(t >= 1, subsequent, init)

        with tf.variable_scope('h_z_prior'):
            h_z_prior_t = self._linear([x[t], h_s[t], z_post_t_1],
                                       self.z_size, 'tanh')
        with tf.variable_scope('z_prior'):
            z_prior_t, z_prior_t_pdf = self._z(h_z_prior_t, is_prior=True)

        with tf.variable_scope('h_z_post'):
            h_z_post_t = self._linear([x[t], h_s[t], y_[t], z_post_t_1],
                                      self.z_size, 'tanh')
        with tf.variable_scope('z_post'):
            z_post_t, z_post_t_pdf = self._z(h_z_post_t, is_prior=False)

        kl_t = ds.kl_divergence(z_post_t_pdf, z_prior_t_pdf)  # batch_size * z_size

        ta_z_prior = ta_z_prior.write(t, z_prior_t)  # write: batch_size * z_size
        ta_z_post = ta_z_post.write(t, z_post_t)     # write: batch_size * z_size
        ta_kl = ta_kl.write(t, kl_t)                 # write: batch_size * 1

    return t + 1, ta_z_prior, ta_z_post, ta_kl
def gumbel_reparmeterization(logits_z, tau, rnd_sample=None, hard=True,
                             eps=1e-9):
    '''The gumbel-softmax reparameterization.'''
    latent_size = logits_z.get_shape().as_list()[1]

    # Prior
    p_z = d.OneHotCategorical(probs=tf.constant(1.0 / latent_size,
                                                shape=[latent_size]))
    # p_z = d.RelaxedOneHotCategorical(probs=tf.constant(1.0/latent_size,
    #                                                    shape=[latent_size]),
    #                                  temperature=10.0)
    # p_z = 1.0 / latent_size
    # log_p_z = tf.log(p_z + eps)

    with st.value_type(st.SampleValue()):
        q_z = st.StochasticTensor(d.RelaxedOneHotCategorical(temperature=tau,
                                                             logits=logits_z))
        q_z_full = st.StochasticTensor(d.OneHotCategorical(logits=logits_z))

    reduce_index = [1] if len(logits_z.get_shape().as_list()) == 2 else [1, 2]
    kl = d.kl_divergence(q_z_full.distribution, p_z, allow_nan_stats=False)
    if len(shp(kl)) > 1:
        return [q_z, tf.reduce_sum(kl, reduce_index)]
    else:
        return [q_z, kl]
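# Background sketch (not part of the snippet above): the Gumbel-softmax draw
# that d.RelaxedOneHotCategorical performs internally, written out in plain
# TensorFlow 1.x for logits of shape [batch, K]. Names here are illustrative.
import tensorflow as tf

def gumbel_softmax_sample(logits, tau, eps=1e-9):
    # Sample Gumbel(0, 1) noise via inverse transform of Uniform(0, 1).
    u = tf.random_uniform(tf.shape(logits), minval=0.0, maxval=1.0)
    gumbel = -tf.log(-tf.log(u + eps) + eps)
    # Perturb the logits and anneal with temperature tau; as tau -> 0
    # the softmax approaches a discrete one-hot sample.
    return tf.nn.softmax((logits + gumbel) / tau)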
def build_loss_and_gradients(self, var_list):
    cof = tf.constant(self.alpha, tf.float32)
    cof2 = tf.constant(self.alpha + 1, tf.float32)
    M = tf.constant(self.size, tf.float32)
    N = tf.constant(self.tot, tf.float32)
    p_log_prob = [0.0] * self.n_samples
    q_log_prob = [0.0] * self.n_samples
    base_scope = tf.get_default_graph().unique_name("inference") + '/'
    for s in range(self.n_samples):
        # Form dictionary in order to replace conditioning on prior or
        # observed variable with conditioning on a specific value.
        scope = base_scope + tf.get_default_graph().unique_name("sample")
        dict_swap = {}
        for x, qx in six.iteritems(self.data):
            if isinstance(x, RandomVariable):
                if isinstance(qx, RandomVariable):
                    qx_copy = copy(qx, scope=scope)
                    dict_swap[x] = qx_copy.value()
                else:
                    dict_swap[x] = qx

        for z, qz in six.iteritems(self.latent_vars):
            # Copy q(z) to obtain new set of posterior samples.
            qz_copy = copy(qz, scope=scope)
            dict_swap[z] = qz_copy.value()
            q_log_prob[s] += tf.reduce_sum(
                self.scale.get(z, 1.0) * qz_copy.log_prob(dict_swap[z]))

        for z in six.iterkeys(self.latent_vars):
            z_copy = copy(z, dict_swap, scope=scope)
            q_log_prob[s] -= tf.reduce_sum(
                self.scale.get(z, 1.0) * z_copy.log_prob(dict_swap[z]))

        for x in six.iterkeys(self.data):
            if isinstance(x, RandomVariable):
                x_copy = copy(x, dict_swap, scope=scope)
                p_log_prob[s] += cof2 / cof * tf.reduce_sum(
                    tf.exp(x_copy.log_prob(dict_swap[x]) * cof)) * N / M \
                    - tf.exp(tf.reduce_logsumexp(
                        tf.log((x_copy.mean()) ** cof2
                               + (1 - x_copy.mean()) ** cof2))) * N / M

    kl_penalty = tf.reduce_sum([
        self.kl_scaling.get(z, 1.0) * tf.reduce_sum(kl_divergence(qz, z))
        for z, qz in six.iteritems(self.latent_vars)])

    p_log_prob = tf.reduce_mean(p_log_prob)
    q_log_prob = tf.reduce_mean(q_log_prob)

    if self.logging:
        tf.summary.scalar("loss/p_log_prob", p_log_prob,
                          collections=[self._summary_key])
        tf.summary.scalar("loss/q_log_prob", q_log_prob,
                          collections=[self._summary_key])

    loss = -(p_log_prob - q_log_prob)
    grads = tf.gradients(loss, var_list)
    grads_and_vars = list(zip(grads, var_list))
    return loss, grads_and_vars
def build_reparam_kl_loss_and_gradients(inference, var_list):
    """Build loss function. Its automatic differentiation
    is a stochastic gradient of

    .. math::

      -\\text{ELBO} = -\mathbb{E}_{q(z; \lambda)} [ \log p(x \mid z) ]
            + \\text{KL}(q(z; \lambda) \| p(z))

    based on the reparameterization trick (Kingma and Welling, 2014).

    It assumes the KL is analytic.

    Computed by sampling from $q(z;\lambda)$ and evaluating the
    expectation using Monte Carlo sampling.
    """
    p_log_lik = [0.0] * inference.n_samples
    base_scope = tf.get_default_graph().unique_name("inference") + '/'
    for s in range(inference.n_samples):
        # Form dictionary in order to replace conditioning on prior or
        # observed variable with conditioning on a specific value.
        scope = base_scope + tf.get_default_graph().unique_name("sample")
        dict_swap = {}
        for x, qx in six.iteritems(inference.data):
            if isinstance(x, RandomVariable):
                if isinstance(qx, RandomVariable):
                    qx_copy = copy(qx, scope=scope)
                    dict_swap[x] = qx_copy.value()
                else:
                    dict_swap[x] = qx

        for z, qz in six.iteritems(inference.latent_vars):
            # Copy q(z) to obtain new set of posterior samples.
            qz_copy = copy(qz, scope=scope)
            dict_swap[z] = qz_copy.value()

        for x in six.iterkeys(inference.data):
            if isinstance(x, RandomVariable):
                x_copy = copy(x, dict_swap, scope=scope)
                p_log_lik[s] += tf.reduce_sum(
                    inference.scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x]))

    p_log_lik = tf.reduce_mean(p_log_lik)

    kl_penalty = tf.reduce_sum([
        inference.kl_scaling.get(z, 1.0) * tf.reduce_sum(kl_divergence(qz, z))
        for z, qz in six.iteritems(inference.latent_vars)])

    if inference.logging:
        tf.summary.scalar("loss/p_log_lik", p_log_lik,
                          collections=[inference._summary_key])
        tf.summary.scalar("loss/kl_penalty", kl_penalty,
                          collections=[inference._summary_key])

    loss = -(p_log_lik - kl_penalty)
    grads = tf.gradients(loss, var_list)
    grads_and_vars = list(zip(grads, var_list))
    return loss, grads_and_vars
def build_reparam_kl_loss_and_gradients(inference, var_list):
    """Build loss function. Its automatic differentiation
    is a stochastic gradient of

    .. math::

      -\\text{ELBO} = -\mathbb{E}_{q(z; \lambda)} [ \log p(x \mid z) ]
            + \\text{KL}(q(z; \lambda) \| p(z))

    based on the reparameterization trick [@kingma2014auto].

    It assumes the KL is analytic.

    Computed by sampling from $q(z;\lambda)$ and evaluating the
    expectation using Monte Carlo sampling.
    """
    p_log_lik = [0.0] * inference.n_samples
    base_scope = tf.get_default_graph().unique_name("inference") + '/'
    for s in range(inference.n_samples):
        # Form dictionary in order to replace conditioning on prior or
        # observed variable with conditioning on a specific value.
        scope = base_scope + tf.get_default_graph().unique_name("sample")
        dict_swap = {}
        for x, qx in six.iteritems(inference.data):
            if isinstance(x, RandomVariable):
                if isinstance(qx, RandomVariable):
                    qx_copy = copy(qx, scope=scope)
                    dict_swap[x] = qx_copy.value()
                else:
                    dict_swap[x] = qx

        for z, qz in six.iteritems(inference.latent_vars):
            # Copy q(z) to obtain new set of posterior samples.
            qz_copy = copy(qz, scope=scope)
            dict_swap[z] = qz_copy.value()

        for x in six.iterkeys(inference.data):
            if isinstance(x, RandomVariable):
                x_copy = copy(x, dict_swap, scope=scope)
                p_log_lik[s] += tf.reduce_sum(
                    inference.scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x]))

    p_log_lik = tf.reduce_mean(p_log_lik)

    kl_penalty = tf.reduce_sum([
        tf.reduce_sum(inference.kl_scaling.get(z, 1.0) * kl_divergence(qz, z))
        for z, qz in six.iteritems(inference.latent_vars)])

    if inference.logging:
        tf.summary.scalar("loss/p_log_lik", p_log_lik,
                          collections=[inference._summary_key])
        tf.summary.scalar("loss/kl_penalty", kl_penalty,
                          collections=[inference._summary_key])

    loss = -(p_log_lik - kl_penalty)
    grads = tf.gradients(loss, var_list)
    grads_and_vars = list(zip(grads, var_list))
    return loss, grads_and_vars
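# Minimal self-contained sketch of the -ELBO that the two estimators above
# compute, assuming a toy diagonal-Gaussian posterior, a standard normal
# prior, and a Gaussian likelihood (all names below are illustrative).
import tensorflow as tf
tfd = tf.contrib.distributions

x = tf.random_normal([16, 10])                           # stand-in data batch
qz = tfd.Normal(loc=tf.layers.dense(x, 2), scale=1.0)    # q(z | x)
pz = tfd.Normal(loc=0.0, scale=1.0)                      # p(z)
z = qz.sample()                                          # reparameterized draw
px = tfd.Normal(loc=tf.layers.dense(z, 10), scale=1.0)   # p(x | z)
log_lik = tf.reduce_sum(px.log_prob(x), axis=-1)
kl = tf.reduce_sum(tfd.kl_divergence(qz, pz), axis=-1)   # analytic KL term
loss = -tf.reduce_mean(log_lik - kl)                     # -ELBO, as above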
def sample_qH(self, H):
    h_mu = H[:, :self.dim_h]
    h_var = tf.exp(H[:, self.dim_h:])
    qh = dist.Normal(h_mu, tf.sqrt(h_var))
    ph = dist.Normal(tf.zeros_like(h_mu), tf.ones_like(h_var))
    kl_h = dist.kl_divergence(qh, ph)
    h_sample = qh.sample()
    return h_sample, kl_h
def build_score_kl_loss_and_gradients(inference, var_list):
    """Build loss function and gradients based on the score function
    estimator (Paisley et al., 2012).

    It assumes the KL is analytic.

    Computed by sampling from $q(z;\lambda)$ and evaluating the
    expectation using Monte Carlo sampling.
    """
    p_log_lik = [0.0] * inference.n_samples
    q_log_prob = [0.0] * inference.n_samples
    base_scope = tf.get_default_graph().unique_name("inference") + '/'
    for s in range(inference.n_samples):
        # Form dictionary in order to replace conditioning on prior or
        # observed variable with conditioning on a specific value.
        scope = base_scope + tf.get_default_graph().unique_name("sample")
        dict_swap = {}
        for x, qx in six.iteritems(inference.data):
            if isinstance(x, RandomVariable):
                if isinstance(qx, RandomVariable):
                    qx_copy = copy(qx, scope=scope)
                    dict_swap[x] = qx_copy.value()
                else:
                    dict_swap[x] = qx

        for z, qz in six.iteritems(inference.latent_vars):
            # Copy q(z) to obtain new set of posterior samples.
            qz_copy = copy(qz, scope=scope)
            dict_swap[z] = qz_copy.value()
            q_log_prob[s] += tf.reduce_sum(
                inference.scale.get(z, 1.0) *
                qz_copy.log_prob(tf.stop_gradient(dict_swap[z])))

        for x in six.iterkeys(inference.data):
            if isinstance(x, RandomVariable):
                x_copy = copy(x, dict_swap, scope=scope)
                p_log_lik[s] += tf.reduce_sum(
                    inference.scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x]))

    p_log_lik = tf.stack(p_log_lik)
    q_log_prob = tf.stack(q_log_prob)

    kl_penalty = tf.reduce_sum([
        inference.kl_scaling.get(z, 1.0) * tf.reduce_sum(kl_divergence(qz, z))
        for z, qz in six.iteritems(inference.latent_vars)])

    if inference.logging:
        tf.summary.scalar("loss/p_log_lik", tf.reduce_mean(p_log_lik),
                          collections=[inference._summary_key])
        tf.summary.scalar("loss/kl_penalty", kl_penalty,
                          collections=[inference._summary_key])

    loss = -(tf.reduce_mean(p_log_lik) - kl_penalty)
    grads = tf.gradients(
        -(tf.reduce_mean(q_log_prob * tf.stop_gradient(p_log_lik)) -
          kl_penalty),
        var_list)
    grads_and_vars = list(zip(grads, var_list))
    return loss, grads_and_vars
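# Sketch of the score-function (REINFORCE) surrogate that the gradient call
# above builds: gradients flow only through q's log-prob, with the sampled
# value and the objective value held fixed by stop_gradient. Toy example;
# all names are illustrative.
import tensorflow as tf
tfd = tf.contrib.distributions

loc = tf.Variable([0.0])
q = tfd.Normal(loc=loc, scale=1.0)
z = q.sample()
f_z = tf.stop_gradient(-tf.square(z - 3.0))      # black-box objective value
surrogate = -(q.log_prob(tf.stop_gradient(z)) * f_z)
grads = tf.gradients(tf.reduce_sum(surrogate), [loc])  # score-function grad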
def build_loss_and_gradients(self, var_list):
    cof = tf.constant(self.alpha, tf.float32)
    cof2 = tf.constant(self.alpha + 1, tf.float32)
    M = tf.constant(self.size, tf.float32)
    N = tf.constant(self.tot, tf.float32)
    p_log_prob = [0.0] * self.n_samples
    q_log_prob = [0.0] * self.n_samples
    base_scope = tf.get_default_graph().unique_name("inference") + '/'
    for s in range(self.n_samples):
        # Form dictionary in order to replace conditioning on prior or
        # observed variable with conditioning on a specific value.
        scope = base_scope + tf.get_default_graph().unique_name("sample")
        dict_swap = {}
        for x, qx in six.iteritems(self.data):
            if isinstance(x, RandomVariable):
                if isinstance(qx, RandomVariable):
                    qx_copy = copy(qx, scope=scope)
                    dict_swap[x] = qx_copy.value()
                else:
                    dict_swap[x] = qx

        for z, qz in six.iteritems(self.latent_vars):
            # Copy q(z) to obtain new set of posterior samples.
            qz_copy = copy(qz, scope=scope)
            dict_swap[z] = qz_copy.value()
            q_log_prob[s] += tf.reduce_sum(
                self.scale.get(z, 1.0) * qz_copy.log_prob(dict_swap[z]))

        for z in six.iterkeys(self.latent_vars):
            z_copy = copy(z, dict_swap, scope=scope)
            q_log_prob[s] -= tf.reduce_sum(
                self.scale.get(z, 1.0) * z_copy.log_prob(dict_swap[z]))

        for x in six.iterkeys(self.data):
            if isinstance(x, RandomVariable):
                x_copy = copy(x, dict_swap, scope=scope)
                p_log_prob[s] += cof2 / cof * tf.reduce_sum(
                    self.scale.get(x, 1.0) *
                    tf.exp(x_copy.log_prob(dict_swap[x]) * cof))
                # - self.scale.get(x, 1.0) * 1 / cof2
                #   * (2 * 3.1415 * 1) ** (cof / 2) * (1 + cof) ** 0.5
                # The second (commented-out) term, required for unbiasedness,
                # need not be included in the objective: it is constant for
                # the regression problem considered here, so it vanishes when
                # taking the gradient.

    kl_penalty = tf.reduce_sum([
        self.kl_scaling.get(z, 1.0) * tf.reduce_sum(kl_divergence(qz, z))
        for z, qz in six.iteritems(self.latent_vars)])

    p_log_prob = tf.reduce_mean(p_log_prob)
    q_log_prob = tf.reduce_mean(q_log_prob)

    if self.logging:
        tf.summary.scalar("loss/p_log_prob", p_log_prob,
                          collections=[self._summary_key])
        tf.summary.scalar("loss/q_log_prob", q_log_prob,
                          collections=[self._summary_key])

    loss = -(p_log_prob - q_log_prob)
    grads = tf.gradients(loss, var_list)
    grads_and_vars = list(zip(grads, var_list))
    return loss, grads_and_vars
def __init__(self, batch_size=1000, latent_dim=25, epochs=50):
    self.epochs = epochs
    self.latent_dim = latent_dim

    # Data input
    plink_dataset = SingleDataset(
        plink_file='/plink_tensorflow/data/test/scz_easy-access_wave2.no_trio.bgn',
        scratch_dir='/plink_tensorflow/data/test/',
        overwrite=False)
    self.m_variants = plink_dataset.bim.shape[0]
    self.total_train_batches = (len(plink_dataset.train_files) // batch_size) + 1
    self.total_test_batches = (len(plink_dataset.test_files) // batch_size) + 1

    print('\nTraining Summary:')
    print('\tTraining files: {}'.format(len(plink_dataset.train_files)))
    print('\tTesting files: {}'.format(len(plink_dataset.test_files)))
    print('\tTraining batches: {}'.format(self.total_train_batches))
    print('\tTesting batches: {}'.format(self.total_test_batches))

    print('\nBuilding computational graph...')
    # Input pipeline
    test_dataset = self.build_test_dataset(plink_dataset, batch_size)
    training_dataset = self.build_training_dataset(plink_dataset, batch_size)
    self.handle = tf.placeholder(tf.string, shape=[])
    self.iterator = tf.data.Iterator.from_string_handle(
        self.handle, training_dataset.output_types,
        training_dataset.output_shapes)
    self.training_iterator = training_dataset.make_initializable_iterator()
    self.test_iterator = test_dataset.make_initializable_iterator()
    genotypes = self.iterator.get_next()
    genotypes = tf.cast(genotypes, tf.float32, name='cast_genotypes')
    genotypes.set_shape([None, self.m_variants])

    # Define the model.
    prior = self.make_prior(latent_dim=self.latent_dim)
    make_encoder = tf.make_template('encoder', self.make_encoder)
    posterior = make_encoder(genotypes, latent_dim=self.latent_dim)
    self.latent_z = posterior.sample()

    # Define the loss.
    make_decoder = tf.make_template('decoder', self.make_decoder)
    likelihood = make_decoder(self.latent_z).log_prob(genotypes)
    divergence = tfd.kl_divergence(posterior, prior)
    self.elbo = tf.reduce_mean(likelihood - divergence)
    self.optimizer = tf.train.AdamOptimizer(0.001).minimize(-self.elbo)
    print('Done')
def __init__(self, s_dim, a_dim, kl_target):
    self.a_dim = a_dim
    self.s_dim = s_dim
    self.kl_target = kl_target
    self.sess = tf.Session()
    self.tfs = tf.placeholder(tf.float32, [None, s_dim])

    # critic
    with tf.variable_scope('critic'):
        l1 = tf.layers.dense(self.tfs, 100, tf.nn.relu)
        self.v = tf.layers.dense(l1, 1)
        self.tfdc_r = tf.placeholder(tf.float32, [None, ])
        self.advantage = self.tfdc_r - tf.squeeze(self.v)
        self.closs = tf.reduce_mean(tf.square(self.advantage))
        self.ctrain_op = tf.train.AdamOptimizer(C_LR).minimize(self.closs)

    # actor
    pi, pi_params = self._build_anet('pi', trainable=True)
    oldpi, oldpi_params = self._build_anet('oldpi', trainable=False)
    with tf.variable_scope('update_oldpi'):
        self.update_oldpi_op = [
            oldp.assign(p) for p, oldp in zip(pi_params, oldpi_params)
        ]
    self.sample_op = pi.sample(1)

    self.tfa = tf.placeholder(tf.float32, [None, ], 'action')
    with tf.variable_scope('ratio'):
        # ratio = tf.exp(pi.log_prob(self.tfa) - oldpi.log_prob(self.tfa))
        ratio = pi.prob(self.tfa) / oldpi.prob(self.tfa)
    with tf.variable_scope('kl'):
        self.kl = tf.stop_gradient(tf.reduce_mean(kl_divergence(oldpi, pi)))

    self.tflam = tf.placeholder(tf.float32, None, 'lambda')
    self.tfadv = tf.placeholder(tf.float32, [None, ], 'advantage')
    with tf.variable_scope('loss'):
        self.aloss = -(tf.reduce_mean(ratio * self.tfadv)
                       - self.tflam * self.kl)
    with tf.variable_scope('atrain'):
        self.atrain_op = tf.train.AdamOptimizer(A_LR).minimize(self.aloss)

    tf.summary.FileWriter("log/", self.sess.graph)
    self.sess.run(tf.global_variables_initializer())
def __init__(self, s_dim, a_dim):
    self.a_dim = a_dim
    self.s_dim = s_dim
    self.sess = tf.Session()
    self.tfs = tf.placeholder(tf.float32, [None, s_dim], 'state')

    # critic
    with tf.variable_scope('critic'):
        l1 = tf.layers.dense(self.tfs, 100, tf.nn.relu)
        self.v = tf.layers.dense(l1, 1)
        self.tfdc_r = tf.placeholder(tf.float32, [None, 1], 'discounted_r')
        self.advantage = self.tfdc_r - self.v
        self.closs = tf.reduce_mean(tf.square(self.advantage))
        self.ctrain_op = tf.train.AdamOptimizer(C_LR).minimize(self.closs)

    # actor
    pi, pi_params = self._build_anet('pi', trainable=True)
    oldpi, oldpi_params = self._build_anet('oldpi', trainable=False)
    self.sample_op = tf.squeeze(pi.sample(1), axis=0)  # choosing action
    with tf.variable_scope('update_oldpi'):
        self.update_oldpi_op = [
            oldp.assign(p) for p, oldp in zip(pi_params, oldpi_params)
        ]

    self.tfa = tf.placeholder(tf.float32, [None, a_dim], 'action')
    self.tfadv = tf.placeholder(tf.float32, [None, 1], 'advantage')
    with tf.variable_scope('surrogate'):
        # ratio = tf.exp(pi.log_prob(self.tfa) - oldpi.log_prob(self.tfa))
        ratio = pi.prob(self.tfa) / oldpi.prob(self.tfa)
        surr = ratio * self.tfadv
    if METHOD['name'] == 'kl_pen':
        self.tflam = tf.placeholder(tf.float32, None, 'lambda')
        with tf.variable_scope('loss'):
            kl = tf.stop_gradient(kl_divergence(oldpi, pi))
            self.kl_mean = tf.reduce_mean(kl)
            self.aloss = -(tf.reduce_mean(surr - self.tflam * kl))
    else:  # clipping method, find this is better
        with tf.variable_scope('loss'):
            self.aloss = -tf.reduce_mean(
                tf.minimum(
                    surr,
                    tf.clip_by_value(ratio, 1. - METHOD['epsilon'],
                                     1. + METHOD['epsilon']) * self.tfadv))
    with tf.variable_scope('atrain'):
        self.atrain_op = tf.train.AdamOptimizer(A_LR).minimize(self.aloss)

    tf.summary.FileWriter("log/", self.sess.graph)
    self.sess.run(tf.global_variables_initializer())
def kl_q_p(self):
    kl_separated = distributions.kl_divergence(self.q_dist, self.p_dist)  # [bs/nbs, N]
    kl_minibatch = tf.reduce_mean(kl_separated, 0, keep_dims=True)  # [1, N]
    tf.summary.scalar("true_kl", tf.reduce_sum(kl_minibatch))
    if self.kl_min > 0:
        # Lower-bound each latent dimension's average KL by kl_min
        # (a "free bits"-style floor), so dimensions cannot collapse to zero.
        kl_lower_bounded = tf.maximum(kl_minibatch, self.kl_min)
        kl = tf.reduce_sum(kl_lower_bounded)  # [], i.e., scalar
    else:
        kl = tf.reduce_sum(kl_minibatch)  # [], i.e., scalar
    return kl
def _loop_body(t, ta_h_s, ta_z_prior, ta_z_post, ta_kl):
    with tf.variable_scope('iter_body', reuse=tf.AUTO_REUSE):
        def _init():
            h_s_init = tf.nn.tanh(
                tf.random_normal(shape=[self.batch_size, self.h_size]))
            h_z_init = tf.nn.tanh(
                tf.random_normal(shape=[self.batch_size, self.z_size]))
            z_init, _ = self._z(arg=h_z_init, is_prior=False)
            return h_s_init, z_init

        def _subsequent():
            h_s_t_1 = tf.reshape(ta_h_s.read(t - 1),
                                 [self.batch_size, self.h_size])
            z_t_1 = tf.reshape(ta_z_post.read(t - 1),
                               [self.batch_size, self.z_size])
            return h_s_t_1, z_t_1

        h_s_t_1, z_t_1 = tf.cond(t >= 1, _subsequent, _init)

        gate_args = [x[t], h_s_t_1, z_t_1]
        with tf.variable_scope('gru_r'):
            r = self._linear(gate_args, self.h_size, 'sigmoid')
        with tf.variable_scope('gru_u'):
            u = self._linear(gate_args, self.h_size, 'sigmoid')
        h_args = [x[t], tf.multiply(r, h_s_t_1), z_t_1]
        with tf.variable_scope('gru_h'):
            h_tilde = self._linear(h_args, self.h_size, 'tanh')
        h_s_t = tf.multiply(1 - u, h_s_t_1) + tf.multiply(u, h_tilde)

        with tf.variable_scope('h_z_prior'):
            h_z_prior_t = self._linear([x[t], h_s_t], self.z_size, 'tanh')
        with tf.variable_scope('z_prior'):
            z_prior_t, z_prior_t_pdf = self._z(h_z_prior_t, is_prior=True)

        with tf.variable_scope('h_z_post'):
            h_z_post_t = self._linear([x[t], h_s_t, y_[t]], self.z_size, 'tanh')
        with tf.variable_scope('z_post'):
            z_post_t, z_post_t_pdf = self._z(h_z_post_t, is_prior=False)

        kl_t = ds.kl_divergence(z_post_t_pdf, z_prior_t_pdf)

        # write
        ta_h_s = ta_h_s.write(t, h_s_t)
        ta_z_prior = ta_z_prior.write(t, z_prior_t)  # write: batch_size * z_size
        ta_z_post = ta_z_post.write(t, z_post_t)     # write: batch_size * z_size
        ta_kl = ta_kl.write(t, kl_t)                 # write: batch_size * 1

    return t + 1, ta_h_s, ta_z_prior, ta_z_post, ta_kl
def _build(self, inputs, hvar_labels, n_samples=10, analytic_kl=True):
    datum_shape = inputs.get_shape().as_list()[1:]
    enc_repr = self._encoder(inputs)

    self.hvar_prior = tfd.ExpRelaxedOneHotCategorical(
        temperature=self._temperature, logits=hvar_labels)
    self.hvar_posterior = tfd.ExpRelaxedOneHotCategorical(
        temperature=self._temperature, logits=self._hvar(enc_repr))
    hvar_sample_shape = ([n_samples]
                         + self.hvar_posterior.batch_shape.as_list()
                         + self.hvar_posterior.event_shape.as_list())
    hvar_sample = tf.reshape(self.hvar_posterior.sample(n_samples),
                             hvar_sample_shape)

    self.latent_posterior = self._latent_posterior_fn(
        self._loc(enc_repr), self._scale(enc_repr))
    latent_posterior_sample = self.latent_posterior.sample(n_samples)

    joint_sample = tf.concat([hvar_sample, latent_posterior_sample], axis=-1)

    sample_decoder = snt.BatchApply(self._decoder)
    self.output_distribution = tfd.Independent(
        tfd.Bernoulli(logits=sample_decoder(joint_sample)),
        reinterpreted_batch_ndims=len(datum_shape))

    distortion = -self.output_distribution.log_prob(inputs)
    if analytic_kl and n_samples == 1:
        rate = tfd.kl_divergence(self.latent_posterior, self.latent_prior)
    else:
        rate = (self.latent_posterior.log_prob(latent_posterior_sample)
                - self.latent_prior.log_prob(latent_posterior_sample))
    hrate = (self.hvar_posterior.log_prob(hvar_sample)
             - self.hvar_prior.log_prob(hvar_sample))
    # hrate = tf.Print(hrate, [temperature])
    # hrate = tf.Print(hrate, [hvar_sample], summarize=10)
    # hrate = tf.Print(hrate, [self.hvar_posterior.log_prob(hvar_sample)])
    # hrate = tf.Print(hrate, [self.hvar_prior.log_prob(hvar_sample)])
    # hrate = tf.Print(hrate, [hrate], summarize=10)
    elbo_local = -(rate + hrate + distortion)
    self.elbo = tf.reduce_mean(elbo_local)
    self.importance_weighted_elbo = tf.reduce_mean(
        tf.reduce_logsumexp(elbo_local, axis=0)
        - tf.log(tf.to_float(n_samples)))

    self.hvar_sample = tf.exp(tf.split(hvar_sample, n_samples)[0])
    self.hvar_cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=hvar_labels, logits=tf.split(hvar_sample, n_samples)[0])
    self.hvar_labels = hvar_labels
    self.distortion = distortion
    self.rate = rate
    self.hrate = hrate
def testKL(self):
    mu1, sd1, mu2, sd2 = [np.random.rand(4, 6) for _ in range(4)]
    pair_kl = pair_kl_divergence(mu1, sd1, mu2, sd2)
    dist1 = distributions.Normal(mu1, sd1)
    dist2 = distributions.Normal(mu2, sd2)
    kl_tf = distributions.kl_divergence(dist1, dist2)
    with tf.Session() as sess:
        kl_val = sess.run(kl_tf)
        kl_val = kl_val.sum(axis=-1)
    self.assertAllClose(np.diag(pair_kl), kl_val)
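# A possible NumPy reference for the pair_kl_divergence helper the test
# exercises (assumed semantics: closed-form KL between diagonal Gaussians,
# summed over the last axis, for every (row i, row j) pair; the test above
# only checks the diagonal against TensorFlow).
import numpy as np

def pair_kl_divergence_ref(mu1, sd1, mu2, sd2):
    mu1, sd1 = mu1[:, None, :], sd1[:, None, :]   # [n, 1, d]
    mu2, sd2 = mu2[None, :, :], sd2[None, :, :]   # [1, m, d]
    per_dim = (np.log(sd2 / sd1)
               + (sd1 ** 2 + (mu1 - mu2) ** 2) / (2.0 * sd2 ** 2)
               - 0.5)
    return per_dim.sum(axis=-1)                   # [n, m]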
def compute_loss(inf_mean_list, inf_var_list, gen_mean_list, gen_var_list,
                 q_log_discrete, log_px, batch_size):
    gaussian_div = []
    for mean0, var0, mean1, var1 in zip(inf_mean_list, inf_var_list,
                                        reversed(gen_mean_list),
                                        reversed(gen_var_list)):
        kl_gauss = dist.kl_divergence(dist.MultivariateNormalDiag(mean0, var0),
                                      dist.MultivariateNormalDiag(mean1, var1))
        gaussian_div.append(kl_gauss)
    kl_gauss = tf.reshape(tf.concat(gaussian_div, axis=0),
                          [batch_size, len(gaussian_div)])
    kl_dis = dist.kl_divergence(
        dist.OneHotCategorical(logits=q_log_discrete),
        dist.OneHotCategorical(
            logits=tf.log(tf.ones_like(q_log_discrete) * 1 / 10)))
    mean_KL = tf.reduce_mean(tf.reduce_sum(kl_gauss, axis=1) + kl_dis)
    mean_rec = tf.reduce_mean(log_px)
    loss = tf.reduce_mean(
        log_px - 0.5 * (tf.reduce_sum(kl_gauss, axis=1) + kl_dis))
    return loss, mean_rec, mean_KL
def get_loss(self, config):
    self.divergence = tf.reduce_mean(
        tfd.kl_divergence(self.posterior, self.prior))
    self.crossent = tf.contrib.seq2seq.sequence_loss(
        self.logits,
        self.targets,
        self.target_weights,
        average_across_timesteps=True,
        average_across_batch=True)
    loss = self.divergence + self.crossent
    return loss
def kl_divergence(distribution_a, distribution_b,
                  average_across_latent_dim=False,
                  average_across_batch=True):
    kl_div = distributions.kl_divergence(distribution_a, distribution_b)

    if average_across_latent_dim:
        kl_div = tf.reduce_mean(kl_div, axis=1)  # [b]
    else:
        kl_div = tf.reduce_sum(kl_div, axis=1)  # [b]

    if average_across_batch:
        kl_div = tf.reduce_mean(kl_div, axis=0)
    else:
        kl_div = tf.reduce_sum(kl_div, axis=0)
    return kl_div
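# Minimal usage sketch for the wrapper above, assuming `distributions` is
# tf.distributions and a [batch, latent_dim] factorized Gaussian pair.
import tensorflow as tf
from tensorflow import distributions

q = distributions.Normal(loc=tf.zeros([8, 4]), scale=tf.ones([8, 4]))
p = distributions.Normal(loc=tf.ones([8, 4]), scale=tf.ones([8, 4]))
kl_scalar = kl_divergence(q, p)  # sum over latent dim, mean over batch
with tf.Session() as sess:
    print(sess.run(kl_scalar))   # 0.5 per dim * 4 dims = 2.0 here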
def one_step(self, a, x):
    z = a[0]
    u, enc = x
    q_mean, q_var = self.q_transition(z, enc, u)
    p_mean, p_var = self.p_transition(z, u)
    q = MultivariateNormalDiag(q_mean, tf.sqrt(q_var))
    p = MultivariateNormalDiag(p_mean, tf.sqrt(p_var))
    z_step = q.sample()
    kl = kl_divergence(q, p)
    return z_step, kl
def kl_categorical(p=None, q=None, p_logits=None, q_logits=None, eps=1e-6):
    '''
    Given p and q (as EITHER BOTH logits or softmax outputs)
    this function returns the KL between them.

    Utilizes an eps in order to resolve divide-by-zero / log issues.
    '''
    if p_logits is not None and q_logits is not None:
        Q = distributions.Categorical(logits=q_logits, dtype=tf.float32)
        P = distributions.Categorical(logits=p_logits, dtype=tf.float32)
    elif p is not None and q is not None:
        print('p shp = ', p.get_shape().as_list(),
              ' | q shp = ', q.get_shape().as_list())
        Q = distributions.Categorical(probs=q + eps, dtype=tf.float32)
        P = distributions.Categorical(probs=p + eps, dtype=tf.float32)
    else:
        raise Exception("please provide either logits or dists")

    return distributions.kl_divergence(P, Q)
def divergence(q, p, metric='kl', n_monte_carlo_samples=1000):
    """Compute divergence measure between probability distributions.

    Args:
        q, p: probability distributions
        metric: divergence metric
        n_monte_carlo_samples: number of Monte Carlo samples for estimate
    """
    if metric == 'kl':
        return kl_divergence(q, p, allow_nan_stats=False)
    elif metric == 'dotproduct':
        samples_q = q.sample([n_monte_carlo_samples])
        distance_wrt_q = tf.reduce_mean(q.prob(samples_q) - p.prob(samples_q))
        samples_p = p.sample([n_monte_carlo_samples])
        distance_wrt_p = tf.reduce_mean(q.prob(samples_p) - p.prob(samples_p))
        return (distance_wrt_q - distance_wrt_p)
    elif metric == 'gradkl':
        raise NotImplementedError('Metric not supported %s' % metric)
    else:
        raise NotImplementedError('Metric not supported %s' % metric)
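# Usage sketch for divergence(), assuming kl_divergence here is
# tf.contrib.distributions.kl_divergence and q, p are scalar Normals.
import tensorflow as tf
from tensorflow.contrib.distributions import Normal

q = Normal(loc=0.0, scale=1.0)
p = Normal(loc=1.0, scale=2.0)
kl = divergence(q, p, metric='kl')            # analytic KL, a scalar tensor
dot = divergence(q, p, metric='dotproduct',   # Monte Carlo estimate
                 n_monte_carlo_samples=2000)
with tf.Session() as sess:
    print(sess.run([kl, dot]))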
def gaussian_reparmeterization(logits_z, rnd_sample=None):
    '''
    The vanilla Gaussian reparameterization from Kingma et al.:
    z = mu + sigma * N(0, I)
    '''
    zshp = logits_z.get_shape().as_list()
    assert zshp[1] % 2 == 0
    q_sigma = 1e-6 + tf.nn.softplus(logits_z[:, 0:zshp[1] // 2])
    q_mu = logits_z[:, zshp[1] // 2:]

    # Prior
    p_z = d.Normal(loc=tf.zeros(zshp[1] // 2), scale=tf.ones(zshp[1] // 2))

    with st.value_type(st.SampleValue()):
        q_z = st.StochasticTensor(d.Normal(loc=q_mu, scale=q_sigma))

    reduce_index = [1] if len(zshp) == 2 else [1, 2]
    kl = d.kl_divergence(q_z.distribution, p_z, allow_nan_stats=False)
    return [q_z, tf.reduce_sum(kl, reduce_index)]
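# The underlying pathwise (reparameterized) draw, sketched without
# StochasticTensor; shapes and names below are illustrative only.
import tensorflow as tf

logits_z = tf.random_normal([32, 8])        # stand-in encoder output
q_sigma = 1e-6 + tf.nn.softplus(logits_z[:, :4])
q_mu = logits_z[:, 4:]
eps = tf.random_normal(tf.shape(q_mu))      # N(0, I) noise
z = q_mu + q_sigma * eps                    # gradients flow to mu and sigma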
def plot_objective():
    weights_q = [0.6, 0.4]
    # weights_s = gamma is what we iterate on
    gammas = np.arange(0., 1., 0.02)
    # for exact gamma
    mus = [2., -1., 0.]
    stds = [.6, .4, 0.5]
    # for inexact approx
    mus2 = [-1., 1., 0., 2.0]
    stds2 = [3.3, 0.9, 0.5, 0.4]
    g = tf.Graph()
    with g.as_default():
        sess = tf.InteractiveSession()
        with sess.as_default():
            comps = [
                Normal(loc=tf.convert_to_tensor(mus[i], dtype=tf.float32),
                       scale=tf.convert_to_tensor(stds[i], dtype=tf.float32))
                for i in range(len(mus))
            ]
            comps2 = [
                Normal(loc=tf.convert_to_tensor(mus2[i], dtype=tf.float32),
                       scale=tf.convert_to_tensor(stds2[i], dtype=tf.float32))
                for i in range(len(mus2))
            ]
            # p = pi[0] * N(mus[0], stds[0]) + ... + pi[2] * N(mus[2], stds[2])
            weight_s = 0.5
            logger.info('true gamma for exact mixture %.2f' % (weight_s))
            final_weights = [(1 - weight_s) * w for w in weights_q]
            final_weights.append(weight_s)
            p = Mixture(
                cat=Categorical(probs=tf.convert_to_tensor(final_weights)),
                components=comps)

            objective_exact = []
            objective_inexact = []
            for gamma in gammas:
                new_weights = [(1 - gamma) * w for w in weights_q]
                new_weights.append(gamma)
                q = Mixture(
                    cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
                    components=comps)
                objective = kl_divergence(q, p, allow_nan_stats=False).eval()
                objective_exact.append(objective)

                new_weights2 = [(1 - gamma) * w for w in final_weights]
                new_weights2.append(gamma)
                q2 = Mixture(
                    cat=Categorical(probs=tf.convert_to_tensor(new_weights2)),
                    components=comps2)
                objective2 = kl_divergence(q2, p, allow_nan_stats=False).eval()
                objective_inexact.append(objective2)

                logger.info(
                    'gamma = %.2f, D_kl_exact = %.5f, D_kl_inexact = %.5f' %
                    (gamma, objective, objective2))

    plt.plot(gammas, objective_exact, '-', color='r', linewidth=2.0,
             label='exact mixture')
    plt.plot(gammas, objective_inexact, '-', color='b', linewidth=2.0,
             label='inexact mixture')
    plt.legend()
    plt.xlabel('gamma')
    plt.ylabel('kl divergence of mixture')
    plt.show()
    return tfd.Independent(tfd.Bernoulli(logit), 2)

make_encoder = tf.make_template('encoder', make_encoder)
make_decoder = tf.make_template('decoder', make_decoder)

data = tf.placeholder(tf.float32, [None, input_size, 1])
prior = make_prior(code_size)
posterior, loc, scale = make_encoder(data, code_size)
code_all = posterior.sample(10)
code = tf.reduce_mean(code_all, reduction_indices=0)
likelihood = make_decoder(code, [input_size, 1]).log_prob(data)
divergence = tfd.kl_divergence(posterior, prior)
elbo = tf.reduce_mean(likelihood - divergence)
optimize = tf.train.AdamOptimizer(lr).minimize(-elbo)
samples = make_decoder(prior.sample(10), [input_size, 1]).mean()

init1 = tf.global_variables_initializer()
sess1 = tf.Session()
sess1.run(init1)
saver = tf.train.Saver()

if __name__ == '__main__':
    for epoch in range(5):
def main(_):
    epoch_size = 20
    logdir = './logdir'

    # In TensorFlow, if you call a network function twice, it will create
    # two separate networks. TensorFlow templates allow you to wrap a
    # function so that multiple calls to it will reuse the same network
    # parameters.
    make_encoder = tf.make_template('encoder', _make_encoder)
    make_decoder = tf.make_template('decoder', _make_decoder)

    data = tf.placeholder(tf.float32, [None, 28, 28])
    prior = make_prior(code_size=2)
    posterior = make_encoder(data, code_size=2)
    code = posterior.sample()
    likelihood = make_decoder(code, [28, 28]).log_prob(data)
    divergence = tfd.kl_divergence(posterior, prior)
    elbo = tf.reduce_mean(likelihood - divergence)
    optimizer = tf.train.AdamOptimizer(0.001).minimize(-elbo)
    samples = make_decoder(prior.sample(10), [28, 28]).mean()

    mnist = input_data.read_data_sets('/tmp/MNIST_data/')
    fig, ax = plt.subplots(nrows=epoch_size, ncols=11, figsize=(10, 20))

    # Merge all summaries.
    _summary('likelihood', likelihood)
    _summary('divergence', divergence)
    _summary('elbo', elbo)
    _summary('samples', samples)
    merged = tf.summary.merge_all()

    saver = tf.train.Saver()
    _global_step = tf.get_variable('global_step', [], dtype=tf.int32,
                                   trainable=False)
    global_step_op = tf.assign_add(_global_step, 1)

    with tf.train.MonitoredSession() as sess:
        writer = tf.summary.FileWriter(logdir, sess.graph)
        for epoch in range(epoch_size):
            feed = {data: mnist.test.images.reshape([-1, 28, 28])}
            test_elbo, test_codes, test_samples = sess.run(
                [elbo, code, samples], feed)
            test_likelihood, test_divergence = sess.run(
                [likelihood, divergence], feed)
            print('likelihood {}, divergence {}'.format(
                test_likelihood.mean(), test_divergence.mean()))

            # Plot codes and samples
            ax[epoch, 0].set_ylabel('Epoch {}'.format(epoch))
            plot_codes(ax[epoch, 0], test_codes, mnist.test.labels)
            plot_samples(ax[epoch, 1:], test_samples)
            print('\rEpoch {}, elbo {}, labels {}, test_codes {}, '
                  'test_samples {}'.format(epoch, test_elbo,
                                           mnist.test.labels.shape,
                                           test_codes.shape,
                                           test_samples.shape),
                  end='', flush=True)

            for step in range(1, 600):
                feed = {
                    data: mnist.train.next_batch(100)[0].reshape([-1, 28, 28])
                }
                _, summary, global_step = sess.run(
                    [optimizer, merged, global_step_op], feed)
                writer.add_summary(summary, global_step=global_step)
def f(gamma):
    weights = [(1 - gamma), gamma]
    q_l = Mixture(cat=Categorical(probs=tf.convert_to_tensor(weights)),
                  components=[MultivariateNormalDiag(**c) for c in comps])
    return kl_divergence(q_l, qt).eval()
def main(argv):
    del argv

    outdir = FLAGS.outdir
    if '~' in outdir:
        outdir = os.path.expanduser(outdir)
    os.makedirs(outdir, exist_ok=True)

    # Files to log metrics
    times_filename = os.path.join(outdir, 'times.csv')
    elbos_filename = os.path.join(outdir, 'elbos.csv')
    objective_filename = os.path.join(outdir, 'kl.csv')
    reference_filename = os.path.join(outdir, 'ref_kl.csv')
    step_filename = os.path.join(outdir, 'steps.csv')
    # 'adafw', 'ada_afw', 'ada_pfw'
    if FLAGS.fw_variant.startswith('ada'):
        curvature_filename = os.path.join(outdir, 'curvature.csv')
        gap_filename = os.path.join(outdir, 'gap.csv')
        iter_info_filename = os.path.join(outdir, 'iter_info.txt')
    elif FLAGS.fw_variant == 'line_search':
        goutdir = os.path.join(outdir, 'gradients')

    # empty the files present in the folder already
    open(times_filename, 'w').close()
    open(elbos_filename, 'w').close()
    open(objective_filename, 'w').close()
    open(reference_filename, 'w').close()
    open(step_filename, 'w').close()
    # 'adafw', 'ada_afw', 'ada_pfw'
    if FLAGS.fw_variant.startswith('ada'):
        open(curvature_filename, 'w').close()
        append_to_file(curvature_filename, "c_local,c_global")
        open(gap_filename, 'w').close()
        open(iter_info_filename, 'w').close()
    elif FLAGS.fw_variant == 'line_search':
        os.makedirs(goutdir, exist_ok=True)

    for i in range(FLAGS.n_fw_iter):
        # NOTE: First iteration (t = 0) is initialization
        g = tf.Graph()
        with g.as_default():
            tf.set_random_seed(FLAGS.seed)
            sess = tf.InteractiveSession()
            with sess.as_default():
                p, mus, stds = create_target_dist()

                # current iterate (solution until now)
                if FLAGS.init == 'random':
                    muq = np.random.randn(D).astype(np.float32)
                    stdq = softplus(np.random.randn(D).astype(np.float32))
                    raise ValueError
                else:
                    muq = mus[0]
                    stdq = stds[0]

                # 1 correct LMO
                t = 1
                comps = [{'loc': muq, 'scale_diag': stdq}]
                weights = [1.0]
                curvature_estimate = opt.adafw_linit()

                qtx = MultivariateNormalDiag(
                    loc=tf.convert_to_tensor(muq, dtype=tf.float32),
                    scale_diag=tf.convert_to_tensor(stdq, dtype=tf.float32))
                fw_iterates = {p: qtx}

                # calculate kl-div with 1 component
                objective_old = kl_divergence(qtx, p).eval()
                logger.info("kl with init %.4f" % (objective_old))
                append_to_file(reference_filename, objective_old)

                # s is the solution to LMO. It is initialized randomly
                # mu ~ N(0, 1), std ~ softplus(N(0, 1))
                s = coreutils.construct_multivariatenormaldiag([D], t, 's')
                sess.run(tf.global_variables_initializer())

                total_time = 0
                start_inference_time = time.time()
                if FLAGS.LMO == 'vi':
                    # we have to iterate over parameter space
                    raise ValueError
                inference = relbo.KLqp({p: s}, fw_iterates=fw_iterates,
                                       fw_iter=t)
                inference.run(n_iter=FLAGS.LMO_iter)
                # s now contains solution to LMO
                end_inference_time = time.time()

                mu_s = s.mean().eval()
                cov_s = s.stddev().eval()

                # NOTE: keep only step size time
                # total_time += end_inference_time - start_inference_time

                # compute step size to update the next iterate
                step_result = {}
                if FLAGS.fw_variant == 'fixed':
                    gamma = 2. / (t + 2.)
                elif FLAGS.fw_variant == 'line_search':
                    start_line_search_time = time.time()
                    step_result = opt.line_search_dkl(
                        weights, [c['loc'] for c in comps],
                        [c['scale_diag'] for c in comps], qtx, mu_s, cov_s,
                        s, p, t)
                    end_line_search_time = time.time()
                    total_time += (end_line_search_time
                                   - start_line_search_time)
                    gamma = step_result['gamma']
                elif FLAGS.fw_variant == 'adafw':
                    start_adafw_time = time.time()
                    step_result = opt.adaptive_fw(
                        weights, [c['loc'] for c in comps],
                        [c['scale_diag'] for c in comps], qtx, mu_s, cov_s,
                        s, p, t, curvature_estimate)
                    end_adafw_time = time.time()
                    total_time += end_adafw_time - start_adafw_time
                    gamma = step_result['gamma']
                else:
                    raise NotImplementedError

                comps.append({'loc': mu_s, 'scale_diag': cov_s})
                weights = [(1. - gamma), gamma]

                c_global = estimate_global_curvature(comps, qtx)

                q_latest = Mixture(
                    cat=Categorical(probs=tf.convert_to_tensor(weights)),
                    components=[MultivariateNormalDiag(**c) for c in comps])

                # Log metrics for current iteration
                time_t = float(total_time)
                logger.info('total time %f' % (time_t))
                append_to_file(times_filename, time_t)

                elbo_t = elbo(q_latest, p, n_samples=1000)
                logger.info("iter, %d, elbo, %.2f +/- %.2f" %
                            (t, elbo_t[0], elbo_t[1]))
                append_to_file(elbos_filename,
                               "%f,%f" % (elbo_t[0], elbo_t[1]))

                logger.info('iter %d, gamma %.4f' % (t, gamma))
                append_to_file(step_filename, gamma)

                objective_t = kl_divergence(q_latest, p).eval()
                logger.info("run %d, kl %.4f" % (i, objective_t))
                append_to_file(objective_filename, objective_t)

                if FLAGS.fw_variant.startswith('ada'):
                    curvature_estimate = step_result['c_estimate']
                    append_to_file(gap_filename, step_result['gap'])
                    append_to_file(iter_info_filename,
                                   step_result['step_type'])
                    logger.info('gap = %.3f, ct = %.5f, iter_type = %s' %
                                (step_result['gap'],
                                 step_result['c_estimate'],
                                 step_result['step_type']))
                    append_to_file(curvature_filename,
                                   '%f,%f' % (curvature_estimate, c_global))
                elif FLAGS.fw_variant == 'line_search':
                    n_line_search_samples = step_result['n_samples']
                    grad_t = step_result['grad_gamma']
                    g_outfile = os.path.join(
                        goutdir, 'line_search_samples_%d.npy.%d' %
                        (n_line_search_samples, t))
                    logger.info('saving line search data to, %s' % g_outfile)
                    np.save(open(g_outfile, 'wb'), grad_t)

            sess.close()
        tf.reset_default_graph()
def kl_q_p(self):
    return tf.reduce_mean(
        distributions.kl_divergence(self.q_dist, self.p_dist))
def adaptive_pfw(weights, comps, locs, diags, q_t, mu_s, cov_s, s_t, p, k,
                 l_prev):
    """Adaptive pairwise variant.

    Args:
        same as fixed
    """
    d_t_norm = divergence(s_t, q_t, metric=FLAGS.distance_metric).eval()
    logger.info('distance norm is %.5f' % d_t_norm)

    # Find v_t
    qcomps = q_t.components
    index_v_t, step_v_t = argmax_grad_dotp(p, q_t, qcomps,
                                           FLAGS.n_monte_carlo_samples)
    v_t = qcomps[index_v_t]

    # Pairwise gap
    sample_s = s_t.sample([FLAGS.n_monte_carlo_samples])
    step_s = tf.reduce_mean(grad_kl(q_t, p, sample_s)).eval()
    gap_pw = step_v_t - step_s
    if gap_pw < 0:
        eprint("Pairwise gap is negative")

    def default_fixed_step(fail_type='fixed'):
        # adaptive failed, return to fixed
        gamma = 2. / (k + 2.)
        new_comps = copy.copy(comps)
        new_comps.append({'loc': mu_s, 'scale_diag': cov_s})
        new_weights = [(1. - gamma) * w for w in weights]
        new_weights.append(gamma)
        return {
            'gamma': 2. / (k + 2.),
            'l_estimate': l_prev,
            'weights': new_weights,
            'comps': new_comps,
            'gap': gap_pw,
            'step_type': fail_type
        }

    logger.info('Pairwise gap %.5f' % gap_pw)

    # Set $q_{t+1}$'s params
    new_locs = copy.copy(locs)
    new_diags = copy.copy(diags)
    new_locs.append(mu_s)
    new_diags.append(cov_s)

    gap = gap_pw
    if gap <= 0:
        return default_fixed_step()
    gamma_max = weights[index_v_t]
    step_type = 'adaptive'

    tau = FLAGS.exp_adafw
    eta = FLAGS.damping_adafw
    pow_tau = 1.0
    i, l_t = 0, l_prev
    f_t = kl_divergence(q_t, p, allow_nan_stats=False).eval()
    drop_step = False
    debug('f(q_t) = %.5f' % (f_t))
    gamma = 2. / (k + 2)
    while gamma >= MIN_GAMMA and i < FLAGS.adafw_MAXITER:
        # compute $L_t$ and $\gamma_t$
        l_t = pow_tau * eta * l_prev
        gamma = min(gap / (l_t * d_t_norm), gamma_max)

        d_1 = -gamma * gap
        d_2 = gamma * gamma * l_t * d_t_norm / 2.
        debug('linear d1 = %.5f, quad d2 = %.5f' % (d_1, d_2))
        quad_bound_rhs = f_t + d_1 + d_2

        # construct $q_{t + 1}$
        new_weights = copy.copy(weights)
        new_weights.append(gamma)
        if gamma == gamma_max:
            # hardcoding to 0 for precision issues
            new_weights[index_v_t] = 0
            drop_step = True
        else:
            new_weights[index_v_t] -= gamma
            drop_step = False

        qt_new = Mixture(
            cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
            components=[
                MultivariateNormalDiag(loc=loc, scale_diag=diag)
                for loc, diag in zip(new_locs, new_diags)
            ])
        quad_bound_lhs = kl_divergence(qt_new, p,
                                       allow_nan_stats=False).eval()
        logger.info('lt = %.5f, gamma = %.3f, f_(qt_new) = %.5f, '
                    'linear extrapolated = %.5f' %
                    (l_t, gamma, quad_bound_lhs, quad_bound_rhs))
        if quad_bound_lhs <= quad_bound_rhs:
            new_comps = copy.copy(comps)
            new_comps.append({'loc': mu_s, 'scale_diag': cov_s})
            if drop_step:
                del new_comps[index_v_t]
                del new_weights[index_v_t]
                logger.info("...drop step")
                step_type = 'drop'
            return {
                'gamma': gamma,
                'l_estimate': l_t,
                'weights': new_weights,
                'comps': new_comps,
                'gap': gap,
                'step_type': step_type
            }
        pow_tau *= tau
        i += 1

    # gamma below MIN_GAMMA
    logger.warning("gamma below threshold value, returning fixed step")
    return default_fixed_step("fixed_adaptive_MAXITER")
def adaptive_fw(weights, locs, diags, q_t, mu_s, cov_s, s_t, p, k, l_prev,
                return_gamma=False):
    """Adaptive Frank-Wolfe algorithm.

    Sets step size as suggested in Algorithm 1 of
    https://arxiv.org/pdf/1806.05123.pdf

    Args:
        weights: [k], weights of the mixture components of q_t
        locs: [k x dim], means of mixture components of q_t
        diags: [k x dim], std deviations of mixture components of q_t
        q_t: current mixture iterate q_t
        mu_s: [dim], mean for LMO solution s
        cov_s: [dim], cov matrix for LMO solution s
        s_t: current atom & LMO solution s
        p: edward.model, target distribution p
        k: iteration number of Frank-Wolfe
        l_prev: previous Lipschitz estimate
        return_gamma: only return the value of gamma

    Returns:
        If return_gamma is True, only the computed value of gamma is
        returned. Else returns a dictionary containing gamma, Lipschitz
        estimate, duality gap and step information.
    """
    # Set $q_{t+1}$'s params
    new_locs = copy.copy(locs)
    new_diags = copy.copy(diags)
    new_locs.append(mu_s)
    new_diags.append(cov_s)

    d_t_norm = divergence(s_t, q_t, metric=FLAGS.distance_metric).eval()
    logger.info('distance norm is %.5f' % d_t_norm)

    N_samples = FLAGS.n_monte_carlo_samples
    # create and sample from $s_t, q_t$
    sample_q = q_t.sample([N_samples])
    sample_s = s_t.sample([N_samples])
    step_s = tf.reduce_mean(grad_kl(q_t, p, sample_s)).eval()
    step_q = tf.reduce_mean(grad_kl(q_t, p, sample_q)).eval()
    gap = step_q - step_s
    logger.info('duality gap %.5f' % gap)
    if gap < 0:
        logger.warning("Duality gap is negative returning 0 step")
        # gamma = 2. / (k + 2.)
        gamma = 0.

    tau = FLAGS.exp_adafw
    eta = FLAGS.damping_adafw
    # did the adaptive loop succeed or not
    step_type = "fixed"
    # NOTE: this is from v1 of the paper, new version
    # replaces multiplicative tau with divisor eta
    pow_tau = 1.0
    i, l_t = 0, l_prev
    f_t = kl_divergence(q_t, p, allow_nan_stats=False).eval()
    debug('f(q_t) = %.5f' % (f_t))
    # return initial estimate if gap is -ve
    while gap >= 0:
        # compute $L_t$ and $\gamma_t$
        l_t = pow_tau * eta * l_prev
        gamma = min(gap / (l_t * d_t_norm), 1.0)

        d_1 = -gamma * gap
        d_2 = gamma * gamma * l_t * d_t_norm / 2.
        debug('linear d1 = %.5f, quad d2 = %.5f' % (d_1, d_2))
        quad_bound_rhs = f_t + d_1 + d_2

        # $w_{t + 1} = [(1 - \gamma)w_t, \gamma]$
        new_weights = copy.copy(weights)
        new_weights = [(1. - gamma) * w for w in new_weights]
        new_weights.append(gamma)
        qt_new = Mixture(
            cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
            components=[
                MultivariateNormalDiag(loc=loc, scale_diag=diag)
                for loc, diag in zip(new_locs, new_diags)
            ])
        quad_bound_lhs = kl_divergence(qt_new, p,
                                       allow_nan_stats=False).eval()
        logger.info('lt = %.5f, gamma = %.3f, f_(qt_new) = %.5f, '
                    'linear extrapolated = %.5f' %
                    (l_t, gamma, quad_bound_lhs, quad_bound_rhs))
        if quad_bound_lhs <= quad_bound_rhs:
            step_type = "adaptive"
            break
        pow_tau *= tau
        i += 1
        # if i > FLAGS.adafw_MAXITER or gamma < MIN_GAMMA:
        if i > FLAGS.adafw_MAXITER:
            # estimate not good
            # gamma = 2. / (k + 2.)
            gamma = 0.
            l_t = l_prev
            step_type = "fixed_adaptive_MAXITER"
            break

    if return_gamma:
        return gamma
    return {
        'gamma': gamma,
        'l_estimate': l_t,
        'gap': gap,
        'step_type': step_type
    }
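# Worked example of the adaptive step-size rule above, with made-up numbers:
# gamma = min(gap / (L_t * d), 1), then test the quadratic upper bound.
gap, l_t, d_t_norm = 0.8, 4.0, 0.5
gamma = min(gap / (l_t * d_t_norm), 1.0)    # min(0.4, 1.0) = 0.4
d_1 = -gamma * gap                          # linear decrease: -0.32
d_2 = gamma ** 2 * l_t * d_t_norm / 2.0     # quadratic penalty: +0.16
# Accept gamma if KL(q_new || p) <= KL(q_t || p) + d_1 + d_2,
# otherwise increase the Lipschitz estimate l_t and retry.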
def adaptive_afw(weights, comps, locs, diags, q_t, mu_s, cov_s, s_t, p, k,
                 l_prev):
    """Away-steps variant.

    Args:
        same as fixed
    """
    d_t_norm = divergence(s_t, q_t, metric=FLAGS.distance_metric).eval()
    logger.info('distance norm is %.5f' % d_t_norm)

    # Find v_t
    qcomps = q_t.components
    index_v_t, step_v_t = argmax_grad_dotp(p, q_t, qcomps,
                                           FLAGS.n_monte_carlo_samples)
    v_t = qcomps[index_v_t]

    # Frank-Wolfe gap
    sample_q = q_t.sample([FLAGS.n_monte_carlo_samples])
    sample_s = s_t.sample([FLAGS.n_monte_carlo_samples])
    step_s = tf.reduce_mean(grad_kl(q_t, p, sample_s)).eval()
    step_q = tf.reduce_mean(grad_kl(q_t, p, sample_q)).eval()
    gap_fw = step_q - step_s
    if gap_fw < 0:
        logger.warning("Frank-Wolfe duality gap is negative")
    # Away gap
    gap_a = step_v_t - step_q
    if gap_a < 0:
        eprint('Away gap < 0!!!')
    logger.info('fw gap %.5f, away gap %.5f' % (gap_fw, gap_a))

    # Set $q_{t+1}$'s params
    new_locs = copy.copy(locs)
    new_diags = copy.copy(diags)

    if (gap_fw >= gap_a) or (len(comps) == 1):
        # FW direction, proceeds exactly as adafw
        logger.info('Proceeding in FW direction ')
        adaptive_step_type = 'fw'
        gap = gap_fw
        new_locs.append(mu_s)
        new_diags.append(cov_s)
        gamma_max = 1.0
    else:
        # Away direction
        logger.info('Proceeding in Away direction ')
        adaptive_step_type = 'away'
        gap = gap_a
        if weights[index_v_t] < 1.0:
            gamma_max = weights[index_v_t] / (1.0 - weights[index_v_t])
        else:
            gamma_max = 100.  # Large value when t = 1

    def default_fixed_step(fail_type='fixed'):
        # adaptive failed, return to fixed
        gamma = 2. / (k + 2.)
        new_comps = copy.copy(comps)
        new_comps.append({'loc': mu_s, 'scale_diag': cov_s})
        new_weights = [(1. - gamma) * w for w in weights]
        new_weights.append(gamma)
        return {
            'gamma': 2. / (k + 2.),
            'l_estimate': l_prev,
            'weights': new_weights,
            'comps': new_comps,
            'gap': gap,
            'step_type': fail_type
        }

    if gap <= 0:
        return default_fixed_step()

    tau = FLAGS.exp_adafw
    eta = FLAGS.damping_adafw
    pow_tau = 1.0
    i, l_t = 0, l_prev
    f_t = kl_divergence(q_t, p, allow_nan_stats=False).eval()
    debug('f(q_t) = %.5f' % (f_t))
    gamma = 2. / (k + 2)
    is_drop_step = False
    while gamma >= MIN_GAMMA and i < FLAGS.adafw_MAXITER:
        # compute $L_t$ and $\gamma_t$
        l_t = pow_tau * eta * l_prev
        # NOTE: Handle extreme values of gamma carefully
        gamma = min(gap / (l_t * d_t_norm), gamma_max)

        d_1 = -gamma * gap
        d_2 = gamma * gamma * l_t * d_t_norm / 2.
        debug('linear d1 = %.5f, quad d2 = %.5f' % (d_1, d_2))
        quad_bound_rhs = f_t + d_1 + d_2

        # construct $q_{t + 1}$
        if adaptive_step_type == 'fw':
            if gamma == gamma_max:
                # gamma = 1.0, q_{t + 1} = s_t
                new_comps = [{'loc': mu_s, 'scale_diag': cov_s}]
                new_weights = [1.]
                qt_new = MultivariateNormalDiag(loc=mu_s, scale_diag=cov_s)
            else:
                new_comps = copy.copy(comps)
                new_comps.append({'loc': mu_s, 'scale_diag': cov_s})
                new_weights = copy.copy(weights)
                new_weights = [(1. - gamma) * w for w in new_weights]
                new_weights.append(gamma)
                qt_new = Mixture(
                    cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
                    components=[
                        MultivariateNormalDiag(loc=loc, scale_diag=diag)
                        for loc, diag in zip(new_locs, new_diags)
                    ])
        elif adaptive_step_type == 'away':
            new_weights = copy.copy(weights)
            new_comps = copy.copy(comps)
            if gamma == gamma_max:
                # drop v_t
                is_drop_step = True
                logger.info('...drop step')
                del new_weights[index_v_t]
                new_weights = [(1. + gamma) * w for w in new_weights]
                del new_comps[index_v_t]
                # NOTE: recompute locs and diags after dropping v_t
                drop_locs = [c['loc'] for c in new_comps]
                drop_diags = [c['scale_diag'] for c in new_comps]
                qt_new = Mixture(
                    cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
                    components=[
                        MultivariateNormalDiag(loc=loc, scale_diag=diag)
                        for loc, diag in zip(drop_locs, drop_diags)
                    ])
            else:
                is_drop_step = False
                new_weights = [(1. + gamma) * w for w in new_weights]
                new_weights[index_v_t] -= gamma
                qt_new = Mixture(
                    cat=Categorical(probs=tf.convert_to_tensor(new_weights)),
                    components=[
                        MultivariateNormalDiag(loc=loc, scale_diag=diag)
                        for loc, diag in zip(new_locs, new_diags)
                    ])

        quad_bound_lhs = kl_divergence(qt_new, p,
                                       allow_nan_stats=False).eval()
        logger.info('lt = %.5f, gamma = %.3f, f_(qt_new) = %.5f, '
                    'linear extrapolated = %.5f' %
                    (l_t, gamma, quad_bound_lhs, quad_bound_rhs))
        if quad_bound_lhs <= quad_bound_rhs:
            step_type = "adaptive"
            if adaptive_step_type == "away":
                step_type = "away"
            if is_drop_step:
                step_type = "drop"
            return {
                'gamma': gamma,
                'l_estimate': l_t,
                'weights': new_weights,
                'comps': new_comps,
                'gap': gap,
                'step_type': step_type
            }
        pow_tau *= tau
        i += 1

    # adaptive loop failed, return fixed step size
    logger.warning("gamma below threshold value, returning fixed step")
    return default_fixed_step()