def test_tensor(self):
  with self.test_session():
    a = Normal(mu=0.0, sigma=1.0)
    b = tf.constant(2.0)
    c = a + b
    d = Normal(mu=c, sigma=1.0)

    self.assertEqual(get_descendants(a), [d])
    self.assertEqual(get_descendants(b), [d])
    self.assertEqual(get_descendants(c), [d])
    self.assertEqual(get_descendants(d), [])

def test_chain_structure(self):
  with self.test_session():
    a = Normal(mu=0.0, sigma=1.0)
    b = Normal(mu=a, sigma=1.0)
    c = Normal(mu=b, sigma=1.0)
    d = Normal(mu=c, sigma=1.0)
    e = Normal(mu=d, sigma=1.0)

    self.assertEqual(set(get_descendants(a)), set([b, c, d, e]))
    self.assertEqual(set(get_descendants(b)), set([c, d, e]))
    self.assertEqual(set(get_descendants(c)), set([d, e]))
    self.assertEqual(get_descendants(d), [e])
    self.assertEqual(get_descendants(e), [])

def test_v_structure(self):
  with self.test_session():
    a = Normal(mu=0.0, sigma=1.0)
    b = Normal(mu=a, sigma=1.0)
    c = Normal(mu=0.0, sigma=1.0)
    d = Normal(mu=c, sigma=1.0)
    e = Normal(mu=tf.multiply(b, d), sigma=1.0)

    self.assertEqual(set(get_descendants(a)), set([b, e]))
    self.assertEqual(get_descendants(b), [e])
    self.assertEqual(set(get_descendants(c)), set([d, e]))
    self.assertEqual(get_descendants(d), [e])
    self.assertEqual(get_descendants(e), [])

def test_control_flow(self):
  with self.test_session():
    a = Bernoulli(p=0.5)
    b = Normal(mu=0.0, sigma=1.0)
    c = tf.constant(0.0)
    d = tf.cond(tf.cast(a, tf.bool), lambda: b, lambda: c)
    e = Normal(mu=d, sigma=1.0)

    self.assertEqual(get_descendants(a), [e])
    self.assertEqual(get_descendants(b), [e])
    self.assertEqual(get_descendants(c), [e])
    self.assertEqual(get_descendants(d), [e])
    self.assertEqual(get_descendants(e), [])

def test_chain_structure(self):
  """a -> b -> c -> d -> e"""
  with self.test_session():
    a = Normal(0.0, 1.0)
    b = Normal(a, 1.0)
    c = Normal(b, 1.0)
    d = Normal(c, 1.0)
    e = Normal(d, 1.0)

    self.assertEqual(set(get_descendants(a)), set([b, c, d, e]))
    self.assertEqual(set(get_descendants(b)), set([c, d, e]))
    self.assertEqual(set(get_descendants(c)), set([d, e]))
    self.assertEqual(get_descendants(d), [e])
    self.assertEqual(get_descendants(e), [])

def test_a_structure(self):
  """e <- d <- a -> b -> c"""
  with self.test_session():
    a = Normal(0.0, 1.0)
    b = Normal(a, 1.0)
    c = Normal(b, 1.0)
    d = Normal(a, 1.0)
    e = Normal(d, 1.0)

    self.assertEqual(set(get_descendants(a)), set([b, c, d, e]))
    self.assertEqual(get_descendants(b), [c])
    self.assertEqual(get_descendants(c), [])
    self.assertEqual(get_descendants(d), [e])
    self.assertEqual(get_descendants(e), [])

def test_v_structure(self):
  """a -> b -> e <- d <- c"""
  with self.test_session():
    a = Normal(0.0, 1.0)
    b = Normal(a, 1.0)
    c = Normal(0.0, 1.0)
    d = Normal(c, 1.0)
    e = Normal(b * d, 1.0)

    self.assertEqual(set(get_descendants(a)), set([b, e]))
    self.assertEqual(get_descendants(b), [e])
    self.assertEqual(set(get_descendants(c)), set([d, e]))
    self.assertEqual(get_descendants(d), [e])
    self.assertEqual(get_descendants(e), [])

def test_scan(self):
  """copied from test_chain_structure"""
  def cumsum(x):
    return tf.scan(lambda a, x: a + x, x)

  with self.test_session():
    a = Normal(tf.ones([3]), tf.ones([3]))
    b = Normal(cumsum(a), tf.ones([3]))
    c = Normal(cumsum(b), tf.ones([3]))
    d = Normal(cumsum(c), tf.ones([3]))
    e = Normal(cumsum(d), tf.ones([3]))

    self.assertEqual(set(get_descendants(a)), set([b, c, d, e]))
    self.assertEqual(set(get_descendants(b)), set([c, d, e]))
    self.assertEqual(set(get_descendants(c)), set([d, e]))
    self.assertEqual(get_descendants(d), [e])
    self.assertEqual(get_descendants(e), [])

def build_score_entropy_loss_and_gradients(inference, var_list):
  """Build loss function and gradients based on the score function
  estimator [@paisley2012variational]. It assumes the entropy is
  analytic.

  Computed by sampling from $q(z;\lambda)$ and evaluating the
  expectation using Monte Carlo sampling.
  """
  p_log_prob = [0.0] * inference.n_samples
  q_log_prob = [0.0] * inference.n_samples
  base_scope = tf.get_default_graph().unique_name("inference") + '/'
  for s in range(inference.n_samples):
    # Form dictionary in order to replace conditioning on prior or
    # observed variable with conditioning on a specific value.
    scope = base_scope + tf.get_default_graph().unique_name("sample")
    dict_swap = {}
    for x, qx in six.iteritems(inference.data):
      if isinstance(x, RandomVariable):
        if isinstance(qx, RandomVariable):
          qx_copy = copy(qx, scope=scope)
          dict_swap[x] = qx_copy.value()
        else:
          dict_swap[x] = qx

    for z, qz in six.iteritems(inference.latent_vars):
      # Copy q(z) to obtain new set of posterior samples.
      qz_copy = copy(qz, scope=scope)
      dict_swap[z] = qz_copy.value()
      q_log_prob[s] += tf.reduce_sum(
          inference.scale.get(z, 1.0) *
          qz_copy.log_prob(tf.stop_gradient(dict_swap[z])))

    for z in six.iterkeys(inference.latent_vars):
      z_copy = copy(z, dict_swap, scope=scope)
      p_log_prob[s] += tf.reduce_sum(
          inference.scale.get(z, 1.0) * z_copy.log_prob(dict_swap[z]))

    for x in six.iterkeys(inference.data):
      if isinstance(x, RandomVariable):
        x_copy = copy(x, dict_swap, scope=scope)
        p_log_prob[s] += tf.reduce_sum(
            inference.scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x]))

  p_log_prob = tf.stack(p_log_prob)
  q_log_prob = tf.stack(q_log_prob)

  q_entropy = tf.reduce_sum([
      tf.reduce_sum(qz.entropy())
      for z, qz in six.iteritems(inference.latent_vars)])

  reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())

  if inference.logging:
    tf.summary.scalar("loss/p_log_prob", tf.reduce_mean(p_log_prob),
                      collections=[inference._summary_key])
    tf.summary.scalar("loss/q_log_prob", tf.reduce_mean(q_log_prob),
                      collections=[inference._summary_key])
    tf.summary.scalar("loss/q_entropy", q_entropy,
                      collections=[inference._summary_key])
    tf.summary.scalar("loss/reg_penalty", reg_penalty,
                      collections=[inference._summary_key])

  loss = -(tf.reduce_mean(p_log_prob) + q_entropy - reg_penalty)

  q_rvs = list(six.itervalues(inference.latent_vars))
  q_vars = [v for v in var_list
            if len(get_descendants(tf.convert_to_tensor(v), q_rvs)) != 0]
  q_grads = tf.gradients(
      -(tf.reduce_mean(q_log_prob * tf.stop_gradient(p_log_prob)) +
        q_entropy - reg_penalty),
      q_vars)
  p_vars = [v for v in var_list if v not in q_vars]
  p_grads = tf.gradients(loss, p_vars)
  grads_and_vars = list(zip(q_grads, q_vars)) + list(zip(p_grads, p_vars))
  return loss, grads_and_vars

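# The q-gradient surrogate above follows the score-function (REINFORCE)
# identity; in the docstring's notation, a sketch of the gradient it targets:
#   \nabla_\lambda ELBO
#     = E_{q(z; \lambda)} [ \nabla_\lambda \log q(z; \lambda) \log p(x, z) ]
#       + \nabla_\lambda H[q(z; \lambda)].
# Because the samples pass through tf.stop_gradient, differentiating
# q_log_prob * tf.stop_gradient(p_log_prob) yields the first term, and the
# analytic q_entropy supplies the second.
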
def build_score_rb_loss_and_gradients(inference, var_list):
  """Build loss function and gradients based on the score function
  estimator [@paisley2012variational] and Rao-Blackwellization
  [@ranganath2014black].

  Computed by sampling from :math:`q(z;\lambda)` and evaluating the
  expectation using Monte Carlo sampling and Rao-Blackwellization.
  """
  # Build tensors for loss and gradient calculations. There is one set
  # for each sample from the variational distribution. Use independent
  # dicts per sample; `[{}] * n` would alias a single dict across all
  # samples.
  p_log_probs = [{} for _ in range(inference.n_samples)]
  q_log_probs = [{} for _ in range(inference.n_samples)]
  base_scope = tf.get_default_graph().unique_name("inference") + '/'
  for s in range(inference.n_samples):
    # Form dictionary in order to replace conditioning on prior or
    # observed variable with conditioning on a specific value.
    scope = base_scope + tf.get_default_graph().unique_name("sample")
    dict_swap = {}
    for x, qx in six.iteritems(inference.data):
      if isinstance(x, RandomVariable):
        if isinstance(qx, RandomVariable):
          qx_copy = copy(qx, scope=scope)
          dict_swap[x] = qx_copy.value()
        else:
          dict_swap[x] = qx

    for z, qz in six.iteritems(inference.latent_vars):
      # Copy q(z) to obtain new set of posterior samples.
      qz_copy = copy(qz, scope=scope)
      dict_swap[z] = qz_copy.value()
      q_log_probs[s][qz] = tf.reduce_sum(
          inference.scale.get(z, 1.0) *
          qz_copy.log_prob(tf.stop_gradient(dict_swap[z])))

    for z in six.iterkeys(inference.latent_vars):
      z_copy = copy(z, dict_swap, scope=scope)
      p_log_probs[s][z] = tf.reduce_sum(
          inference.scale.get(z, 1.0) * z_copy.log_prob(dict_swap[z]))

    for x in six.iterkeys(inference.data):
      if isinstance(x, RandomVariable):
        x_copy = copy(x, dict_swap, scope=scope)
        p_log_probs[s][x] = tf.reduce_sum(
            inference.scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x]))

  # Take gradients of Rao-Blackwellized loss for each variational parameter.
  p_rvs = list(six.iterkeys(inference.latent_vars)) + \
      [x for x in six.iterkeys(inference.data)
       if isinstance(x, RandomVariable)]
  q_rvs = list(six.itervalues(inference.latent_vars))
  reverse_latent_vars = {v: k for k, v in
                         six.iteritems(inference.latent_vars)}
  grads = []
  grads_vars = []
  for var in var_list:
    # Get all variational factors depending on the parameter.
    descendants = get_descendants(tf.convert_to_tensor(var), q_rvs)
    if len(descendants) == 0:
      continue  # skip if not a variational parameter

    # Get p and q's Markov blanket wrt these latent variables.
    var_p_rvs = set()
    for qz in descendants:
      z = reverse_latent_vars[qz]
      var_p_rvs.update(z.get_blanket(p_rvs) + [z])

    var_q_rvs = set()
    for qz in descendants:
      var_q_rvs.update(qz.get_blanket(q_rvs) + [qz])

    pi_log_prob = [0.0] * inference.n_samples
    qi_log_prob = [0.0] * inference.n_samples
    for s in range(inference.n_samples):
      pi_log_prob[s] = tf.reduce_sum(
          [p_log_probs[s][rv] for rv in var_p_rvs])
      qi_log_prob[s] = tf.reduce_sum(
          [q_log_probs[s][rv] for rv in var_q_rvs])

    pi_log_prob = tf.stack(pi_log_prob)
    qi_log_prob = tf.stack(qi_log_prob)
    grad = tf.gradients(
        -tf.reduce_mean(qi_log_prob *
                        tf.stop_gradient(pi_log_prob - qi_log_prob)) +
        tf.reduce_sum(tf.losses.get_regularization_losses()),
        var)
    grads.extend(grad)
    grads_vars.append(var)

  # Take gradients of total loss function for model parameters.
  loss = -(tf.reduce_mean([tf.reduce_sum(list(six.itervalues(p_log_prob)))
                           for p_log_prob in p_log_probs]) -
           tf.reduce_mean([tf.reduce_sum(list(six.itervalues(q_log_prob)))
                           for q_log_prob in q_log_probs]) -
           tf.reduce_sum(tf.losses.get_regularization_losses()))
  model_vars = [v for v in var_list if v not in grads_vars]
  model_grads = tf.gradients(loss, model_vars)
  grads.extend(model_grads)
  grads_vars.extend(model_vars)
  grads_and_vars = list(zip(grads, grads_vars))
  return loss, grads_and_vars

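# The per-parameter surrogate above is the Rao-Blackwellized estimator of
# [@ranganath2014black]; a sketch for a parameter \lambda_i of factor q_i:
#   \nabla_{\lambda_i} L
#     = E_q [ \nabla_{\lambda_i} \log q_i(z_i; \lambda_i)
#             ( \log p_i(x, z_{(i)}) - \log q_i(z_i; \lambda_i) ) ],
# where p_i and q_i keep only the factors in the Markov blanket of z_i,
# which is exactly what var_p_rvs and var_q_rvs collect.
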
def build_loss_and_gradients(self, var_list):
  p_log_prob = [0.0] * self.n_samples
  q_log_prob = [0.0] * self.n_samples
  base_scope = tf.get_default_graph().unique_name("inference") + '/'
  for s in range(self.n_samples):
    # Form dictionary in order to replace conditioning on prior or
    # observed variable with conditioning on a specific value.
    scope = base_scope + tf.get_default_graph().unique_name("q_sample")
    dict_swap = {}
    for x, qx in six.iteritems(self.data):
      if isinstance(x, RandomVariable):
        if isinstance(qx, RandomVariable):
          qx_copy = copy(qx, scope=scope)
          dict_swap[x] = qx_copy.value()
        else:
          dict_swap[x] = qx

    # Sample z ~ q(z), then compute log p(x, z).
    q_dict_swap = dict_swap.copy()
    for z, qz in six.iteritems(self.latent_vars):
      # Copy q(z) to obtain new set of posterior samples.
      qz_copy = copy(qz, scope=scope)
      q_dict_swap[z] = qz_copy.value()
      if self.phase_q != 'sleep':
        # If not sleep phase, compute log q(z).
        q_log_prob[s] += tf.reduce_sum(
            self.scale.get(z, 1.0) *
            qz_copy.log_prob(tf.stop_gradient(q_dict_swap[z])))

    for z in six.iterkeys(self.latent_vars):
      z_copy = copy(z, q_dict_swap, scope=scope)
      p_log_prob[s] += tf.reduce_sum(
          self.scale.get(z, 1.0) * z_copy.log_prob(q_dict_swap[z]))

    for x in six.iterkeys(self.data):
      if isinstance(x, RandomVariable):
        x_copy = copy(x, q_dict_swap, scope=scope)
        p_log_prob[s] += tf.reduce_sum(
            self.scale.get(x, 1.0) * x_copy.log_prob(q_dict_swap[x]))

    if self.phase_q == 'sleep':
      # Sample z ~ p(z), then compute log q(z).
      scope = base_scope + tf.get_default_graph().unique_name("p_sample")
      p_dict_swap = dict_swap.copy()
      for z, qz in six.iteritems(self.latent_vars):
        # Copy p(z) to obtain new set of prior samples.
        z_copy = copy(z, scope=scope)
        p_dict_swap[qz] = z_copy.value()

      # Iterate over items so the scale factor is keyed by the latent
      # variable matching each qz.
      for z, qz in six.iteritems(self.latent_vars):
        qz_copy = copy(qz, p_dict_swap, scope=scope)
        q_log_prob[s] += tf.reduce_sum(
            self.scale.get(z, 1.0) *
            qz_copy.log_prob(tf.stop_gradient(p_dict_swap[qz])))

  p_log_prob = tf.reduce_mean(p_log_prob)
  q_log_prob = tf.reduce_mean(q_log_prob)
  reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())

  if self.logging:
    tf.summary.scalar("loss/p_log_prob", p_log_prob,
                      collections=[self._summary_key])
    tf.summary.scalar("loss/q_log_prob", q_log_prob,
                      collections=[self._summary_key])
    tf.summary.scalar("loss/reg_penalty", reg_penalty,
                      collections=[self._summary_key])

  loss_p = -p_log_prob + reg_penalty
  loss_q = -q_log_prob + reg_penalty

  q_rvs = list(six.itervalues(self.latent_vars))
  q_vars = [v for v in var_list
            if len(get_descendants(tf.convert_to_tensor(v), q_rvs)) != 0]
  q_grads = tf.gradients(loss_q, q_vars)
  p_vars = [v for v in var_list if v not in q_vars]
  p_grads = tf.gradients(loss_p, p_vars)
  grads_and_vars = list(zip(q_grads, q_vars)) + list(zip(p_grads, p_vars))
  return loss_p, grads_and_vars

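# A sketch of the two q-updates selected by phase_q, assuming this is a
# wake-sleep style scheme (Hinton et al., 1995):
#   phase_q != 'sleep': fit q at its own (stop-gradiented) samples,
#                       maximizing E_{z ~ q(z; \lambda)} [ \log q(z; \lambda) ];
#   phase_q == 'sleep': fit q to prior samples, maximizing
#                       E_{z ~ p(z)} [ \log q(z; \lambda) ].
# The model update maximizes E_{z ~ q} [ \log p(x, z) ] in both cases.
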
def build_score_kl_loss_and_gradients(inference, var_list):
  """Build loss function and gradients based on the score function
  estimator [@paisley2012variational]. It assumes the KL is analytic.

  Computed by sampling from $q(z;\lambda)$ and evaluating the
  expectation using Monte Carlo sampling.
  """
  p_log_lik = [0.0] * inference.n_samples
  q_log_prob = [0.0] * inference.n_samples
  base_scope = tf.get_default_graph().unique_name("inference") + '/'
  for s in range(inference.n_samples):
    # Form dictionary in order to replace conditioning on prior or
    # observed variable with conditioning on a specific value.
    scope = base_scope + tf.get_default_graph().unique_name("sample")
    dict_swap = {}
    for x, qx in six.iteritems(inference.data):
      if isinstance(x, RandomVariable):
        if isinstance(qx, RandomVariable):
          qx_copy = copy(qx, scope=scope)
          dict_swap[x] = qx_copy.value()
        else:
          dict_swap[x] = qx

    for z, qz in six.iteritems(inference.latent_vars):
      # Copy q(z) to obtain new set of posterior samples.
      qz_copy = copy(qz, scope=scope)
      dict_swap[z] = qz_copy.value()
      q_log_prob[s] += tf.reduce_sum(
          inference.scale.get(z, 1.0) *
          qz_copy.log_prob(tf.stop_gradient(dict_swap[z])))

    for x in six.iterkeys(inference.data):
      if isinstance(x, RandomVariable):
        x_copy = copy(x, dict_swap, scope=scope)
        p_log_lik[s] += tf.reduce_sum(
            inference.scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x]))

  p_log_lik = tf.stack(p_log_lik)
  q_log_prob = tf.stack(q_log_prob)

  kl_penalty = tf.reduce_sum([
      tf.reduce_sum(inference.kl_scaling.get(z, 1.0) * kl_divergence(qz, z))
      for z, qz in six.iteritems(inference.latent_vars)])

  reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())

  if inference.logging:
    tf.summary.scalar("loss/p_log_lik", tf.reduce_mean(p_log_lik),
                      collections=[inference._summary_key])
    tf.summary.scalar("loss/kl_penalty", kl_penalty,
                      collections=[inference._summary_key])
    tf.summary.scalar("loss/reg_penalty", reg_penalty,
                      collections=[inference._summary_key])

  loss = -(tf.reduce_mean(p_log_lik) - kl_penalty - reg_penalty)

  q_rvs = list(six.itervalues(inference.latent_vars))
  q_vars = [v for v in var_list
            if len(get_descendants(tf.convert_to_tensor(v), q_rvs)) != 0]
  q_grads = tf.gradients(
      -(tf.reduce_mean(q_log_prob * tf.stop_gradient(p_log_lik)) -
        kl_penalty - reg_penalty),
      q_vars)
  p_vars = [v for v in var_list if v not in q_vars]
  p_grads = tf.gradients(loss, p_vars)
  grads_and_vars = list(zip(q_grads, q_vars)) + list(zip(p_grads, p_vars))
  return loss, grads_and_vars

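# The loss assembled above is the analytic-KL form of the ELBO; in the
# docstring's notation, a sketch:
#   ELBO = E_{q(z; \lambda)} [ \log p(x | z) ]
#          - \sum_z KL( q(z; \lambda) || p(z) ),
# where kl_divergence(qz, z) evaluates the KL term exactly, p_log_lik
# estimates the likelihood term by Monte Carlo, and the q-gradient uses
# the score-function surrogate q_log_prob * tf.stop_gradient(p_log_lik).
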
def build_loss_and_gradients(self, var_list):
  """Build loss function

  $\\text{KL}( p(z \mid x) \| q(z) )
    = \mathbb{E}_{p(z \mid x)} [ \log p(z \mid x) - \log q(z; \lambda) ]$

  and stochastic gradients based on importance sampling.

  The loss function can be estimated as

  $\sum_{s=1}^S [ w_{\\text{norm}}(z^s; \lambda)
                  (\log p(x, z^s) - \log q(z^s; \lambda)) ],$

  where for $z^s \sim q(z; \lambda)$,

  $w_{\\text{norm}}(z^s; \lambda) = w(z^s; \lambda) /
    \sum_{s=1}^S w(z^s; \lambda)$

  normalizes the importance weights,
  $w(z^s; \lambda) = p(x, z^s) / q(z^s; \lambda)$.

  This provides a gradient,

  $- \sum_{s=1}^S [ w_{\\text{norm}}(z^s; \lambda)
                    \\nabla_{\lambda} \log q(z^s; \lambda) ].$
  """
  p_log_prob = [0.0] * self.n_samples
  q_log_prob = [0.0] * self.n_samples
  base_scope = tf.get_default_graph().unique_name("inference") + '/'
  for s in range(self.n_samples):
    # Form dictionary in order to replace conditioning on prior or
    # observed variable with conditioning on a specific value.
    scope = base_scope + tf.get_default_graph().unique_name("sample")
    dict_swap = {}
    for x, qx in six.iteritems(self.data):
      if isinstance(x, RandomVariable):
        if isinstance(qx, RandomVariable):
          qx_copy = copy(qx, scope=scope)
          dict_swap[x] = qx_copy.value()
        else:
          dict_swap[x] = qx

    for z, qz in six.iteritems(self.latent_vars):
      # Copy q(z) to obtain new set of posterior samples.
      qz_copy = copy(qz, scope=scope)
      dict_swap[z] = qz_copy.value()
      q_log_prob[s] += tf.reduce_sum(
          qz_copy.log_prob(tf.stop_gradient(dict_swap[z])))

    for z in six.iterkeys(self.latent_vars):
      z_copy = copy(z, dict_swap, scope=scope)
      p_log_prob[s] += tf.reduce_sum(z_copy.log_prob(dict_swap[z]))

    for x in six.iterkeys(self.data):
      if isinstance(x, RandomVariable):
        x_copy = copy(x, dict_swap, scope=scope)
        p_log_prob[s] += tf.reduce_sum(x_copy.log_prob(dict_swap[x]))

  p_log_prob = tf.stack(p_log_prob)
  q_log_prob = tf.stack(q_log_prob)
  reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())

  if self.logging:
    tf.summary.scalar("loss/p_log_prob", tf.reduce_mean(p_log_prob),
                      collections=[self._summary_key])
    tf.summary.scalar("loss/q_log_prob", tf.reduce_mean(q_log_prob),
                      collections=[self._summary_key])
    tf.summary.scalar("loss/reg_penalty", reg_penalty,
                      collections=[self._summary_key])

  log_w = p_log_prob - q_log_prob
  log_w_norm = log_w - tf.reduce_logsumexp(log_w)
  w_norm = tf.exp(log_w_norm)
  loss = tf.reduce_sum(w_norm * log_w) - reg_penalty

  q_rvs = list(six.itervalues(self.latent_vars))
  q_vars = [v for v in var_list
            if len(get_descendants(tf.convert_to_tensor(v), q_rvs)) != 0]
  q_grads = tf.gradients(
      -(tf.reduce_sum(q_log_prob * tf.stop_gradient(w_norm)) - reg_penalty),
      q_vars)
  p_vars = [v for v in var_list if v not in q_vars]
  p_grads = tf.gradients(-loss, p_vars)
  grads_and_vars = list(zip(q_grads, q_vars)) + list(zip(p_grads, p_vars))
  return loss, grads_and_vars

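# A minimal, self-contained NumPy sketch of the self-normalized importance
# weighting above (hypothetical values; the helper name is illustrative).
# np.logaddexp.reduce plays the role of tf.reduce_logsumexp.
def _sketch_normalized_importance_weights():
  import numpy as np
  log_p = np.array([-1.2, -0.8, -2.5])  # log p(x, z^s) for S = 3 samples
  log_q = np.array([-1.0, -1.0, -1.0])  # log q(z^s; lambda)
  log_w = log_p - log_q                 # unnormalized log-weights
  log_w_norm = log_w - np.logaddexp.reduce(log_w)  # normalize in log space
  w_norm = np.exp(log_w_norm)           # weights now sum to 1
  return np.sum(w_norm * log_w)         # estimate of KL(p || q)
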
def build_loss_and_gradients(self, var_list):
  p_log_prob = [0.0] * self.n_samples
  q_log_prob = [0.0] * self.n_samples
  base_scope = tf.get_default_graph().unique_name("inference") + '/'
  for s in range(self.n_samples):
    # Form dictionary in order to replace conditioning on prior or
    # observed variable with conditioning on a specific value.
    scope = base_scope + tf.get_default_graph().unique_name("q_sample")
    dict_swap = {}
    for x, qx in six.iteritems(self.data):
      if isinstance(x, RandomVariable):
        if isinstance(qx, RandomVariable):
          qx_copy = copy(qx, scope=scope)
          dict_swap[x] = qx_copy.value()
        else:
          dict_swap[x] = qx

    # Sample z ~ q(z), then compute log p(x, z).
    q_dict_swap = dict_swap.copy()
    for z, qz in six.iteritems(self.latent_vars):
      # Copy q(z) to obtain new set of posterior samples.
      qz_copy = copy(qz, scope=scope)
      q_dict_swap[z] = qz_copy.value()
      if self.phase_q != 'sleep':
        # If not sleep phase, compute log q(z).
        q_log_prob[s] += tf.reduce_sum(
            self.scale.get(z, 1.0) *
            qz_copy.log_prob(tf.stop_gradient(q_dict_swap[z])))

    for z in six.iterkeys(self.latent_vars):
      z_copy = copy(z, q_dict_swap, scope=scope)
      p_log_prob[s] += tf.reduce_sum(
          self.scale.get(z, 1.0) * z_copy.log_prob(q_dict_swap[z]))

    for x in six.iterkeys(self.data):
      if isinstance(x, RandomVariable):
        x_copy = copy(x, q_dict_swap, scope=scope)
        p_log_prob[s] += tf.reduce_sum(
            self.scale.get(x, 1.0) * x_copy.log_prob(q_dict_swap[x]))

    if self.phase_q == 'sleep':
      # Sample z ~ p(z), then compute log q(z).
      scope = base_scope + tf.get_default_graph().unique_name("p_sample")
      p_dict_swap = dict_swap.copy()
      for z, qz in six.iteritems(self.latent_vars):
        # Copy p(z) to obtain new set of prior samples.
        z_copy = copy(z, scope=scope)
        p_dict_swap[qz] = z_copy.value()

      # Iterate over items so the scale factor is keyed by the latent
      # variable matching each qz.
      for z, qz in six.iteritems(self.latent_vars):
        qz_copy = copy(qz, p_dict_swap, scope=scope)
        q_log_prob[s] += tf.reduce_sum(
            self.scale.get(z, 1.0) *
            qz_copy.log_prob(tf.stop_gradient(p_dict_swap[qz])))

  p_log_prob = tf.reduce_mean(p_log_prob)
  q_log_prob = tf.reduce_mean(q_log_prob)

  if self.logging:
    tf.summary.scalar("loss/p_log_prob", p_log_prob,
                      collections=[self._summary_key])
    tf.summary.scalar("loss/q_log_prob", q_log_prob,
                      collections=[self._summary_key])

  loss_p = -p_log_prob
  loss_q = -q_log_prob

  q_rvs = list(six.itervalues(self.latent_vars))
  q_vars = [v for v in var_list
            if len(get_descendants(tf.convert_to_tensor(v), q_rvs)) != 0]
  q_grads = tf.gradients(loss_q, q_vars)
  p_vars = [v for v in var_list if v not in q_vars]
  p_grads = tf.gradients(loss_p, p_vars)
  grads_and_vars = list(zip(q_grads, q_vars)) + list(zip(p_grads, p_vars))
  return loss_p, grads_and_vars

def build_loss_and_gradients(self, var_list):
  p_log_prob = [0.0] * self.n_samples
  q_log_prob = [0.0] * self.n_samples
  base_scope = tf.get_default_graph().unique_name("inference") + '/'
  for s in range(self.n_samples):
    # Form dictionary in order to replace conditioning on prior or
    # observed variable with conditioning on a specific value.
    scope = base_scope + tf.get_default_graph().unique_name("sample")
    dict_swap = {}
    for x, qx in six.iteritems(self.data):
      if isinstance(x, RandomVariable):
        if isinstance(qx, RandomVariable):
          qx_copy = copy(qx, scope=scope)
          dict_swap[x] = qx_copy.value()
        else:
          dict_swap[x] = qx

    for z, qz in six.iteritems(self.latent_vars):
      # Copy q(z) to obtain new set of posterior samples.
      qz_copy = copy(qz, scope=scope)
      dict_swap[z] = qz_copy.value()
      q_log_prob[s] += tf.reduce_sum(
          qz_copy.log_prob(tf.stop_gradient(dict_swap[z])))

    for z in six.iterkeys(self.latent_vars):
      z_copy = copy(z, dict_swap, scope=scope)
      p_log_prob[s] += tf.reduce_sum(z_copy.log_prob(dict_swap[z]))

    for x in six.iterkeys(self.data):
      if isinstance(x, RandomVariable):
        x_copy = copy(x, dict_swap, scope=scope)
        p_log_prob[s] += tf.reduce_sum(x_copy.log_prob(dict_swap[x]))

  p_log_prob = tf.stack(p_log_prob)
  q_log_prob = tf.stack(q_log_prob)
  reg_penalty = tf.reduce_sum(tf.losses.get_regularization_losses())

  if self.logging:
    tf.summary.scalar("loss/p_log_prob", tf.reduce_mean(p_log_prob),
                      collections=[self._summary_key])
    tf.summary.scalar("loss/q_log_prob", tf.reduce_mean(q_log_prob),
                      collections=[self._summary_key])
    tf.summary.scalar("loss/reg_penalty", reg_penalty,
                      collections=[self._summary_key])

  log_w = p_log_prob - q_log_prob
  log_w_norm = log_w - tf.reduce_logsumexp(log_w)
  w_norm = tf.exp(log_w_norm)
  loss = tf.reduce_sum(w_norm * log_w) - reg_penalty

  q_rvs = list(six.itervalues(self.latent_vars))
  q_vars = [v for v in var_list
            if len(get_descendants(tf.convert_to_tensor(v), q_rvs)) != 0]
  q_grads = tf.gradients(
      -(tf.reduce_sum(q_log_prob * tf.stop_gradient(w_norm)) - reg_penalty),
      q_vars)
  p_vars = [v for v in var_list if v not in q_vars]
  p_log_prob_grads = tf.gradients(p_log_prob, p_vars)
  # Kernel-weighted scaling of the model gradients: the gram matrix of the
  # kernel over the sampled log-probabilities times their gradients, plus
  # the gram matrix of the kernel's derivative.
  dx = tf.reduce_sum(
      tf.matmul(self._gram(self.kern, p_log_prob), p_log_prob_grads) +
      self._gram(self.dkern, p_log_prob))
  # tf.gradients returns a list of tensors, so scale each one; multiplying
  # the list itself by a tensor would raise a TypeError.
  p_grads = [g * dx for g in tf.gradients(-loss, p_vars)]
  grads_and_vars = list(zip(q_grads, q_vars)) + list(zip(p_grads, p_vars))
  return loss, grads_and_vars

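# The _gram terms above resemble a Stein variational gradient update
# (Liu and Wang, 2016). Under that reading -- an assumption, since _gram,
# kern, and dkern belong to the enclosing class and are not shown here --
# the kernel-smoothed direction is
#   \phi(.) = (1/S) \sum_s [ k(z^s, .) \nabla_{z^s} \log p(x, z^s)
#                            + \nabla_{z^s} k(z^s, .) ],
# with kern supplying k and dkern its gradient.
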