def _log_joint(self, z_sample):
  """Utility function to calculate model's log joint density,
  log p(x, z), for inputs z (and fixed data x).

  Args:
    z_sample: dict.
      Latent variable keys to samples.
  """
  scope = self._scope + tf.get_default_graph().unique_name("sample")
  # Form dictionary in order to replace conditioning on prior or
  # observed variable with conditioning on a specific value.
  dict_swap = z_sample.copy()
  for x, qx in six.iteritems(self.data):
    if isinstance(x, RandomVariable):
      if isinstance(qx, RandomVariable):
        qx_copy = copy(qx, scope=scope)
        dict_swap[x] = qx_copy.value()
      else:
        dict_swap[x] = qx

  log_joint = 0.0
  for z in six.iterkeys(self.latent_vars):
    z_copy = copy(z, dict_swap, scope=scope)
    log_joint += tf.reduce_sum(z_copy.log_prob(dict_swap[z]))

  for x in six.iterkeys(self.data):
    if isinstance(x, RandomVariable):
      x_copy = copy(x, dict_swap, scope=scope)
      log_joint += tf.reduce_sum(x_copy.log_prob(dict_swap[x]))

  return log_joint
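
# A minimal sketch of the density this utility assembles, assuming
# Edward's public API (edward.models); the beta-Bernoulli model below is
# illustrative only and not part of this module.
def _example_log_joint_by_hand():
  import tensorflow as tf
  from edward.models import Bernoulli, Beta

  pi = Beta(1.0, 1.0)
  x = Bernoulli(probs=pi, sample_shape=10)
  x_data = tf.ones(10, dtype=tf.int32)
  pi_sample = tf.constant(0.3)

  # log p(x, pi) = log p(pi) + sum_n log p(x_n | pi); this mirrors the
  # two accumulation loops in `_log_joint` above, with pi_sample playing
  # the role of dict_swap[pi] and x_data the role of dict_swap[x].
  log_joint = tf.reduce_sum(pi.log_prob(pi_sample))
  log_joint += tf.reduce_sum(Bernoulli(probs=pi_sample).log_prob(x_data))
  return log_joint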
def build_loss_and_gradients(self, var_list):
  """Build loss function. Its automatic differentiation
  is the gradient of

  $-\log p(x, z).$
  """
  # Form dictionary in order to replace conditioning on prior or
  # observed variable with conditioning on a specific value.
  scope = tf.get_default_graph().unique_name("inference")
  dict_swap = {z: qz.value()
               for z, qz in six.iteritems(self.latent_vars)}
  for x, qx in six.iteritems(self.data):
    if isinstance(x, RandomVariable):
      if isinstance(qx, RandomVariable):
        dict_swap[x] = qx.value()
      else:
        dict_swap[x] = qx

  p_log_prob = 0.0
  for z in six.iterkeys(self.latent_vars):
    z_copy = copy(z, dict_swap, scope=scope)
    p_log_prob += tf.reduce_sum(
        self.scale.get(z, 1.0) * z_copy.log_prob(dict_swap[z]))

  for x in six.iterkeys(self.data):
    if isinstance(x, RandomVariable):
      if dict_swap:
        x_copy = copy(x, dict_swap, scope=scope)
      else:
        x_copy = x
      p_log_prob += tf.reduce_sum(
          self.scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x]))

  loss = -p_log_prob

  grads = tf.gradients(loss, var_list)
  grads_and_vars = list(zip(grads, var_list))
  return loss, grads_and_vars
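
# A minimal usage sketch for the loss built above, assuming the
# Edward-style ed.MAP class this method belongs to; the model below is
# illustrative only.
def _example_map():
  import edward as ed
  import numpy as np
  import tensorflow as tf
  from edward.models import Normal, PointMass

  mu = Normal(loc=0.0, scale=1.0)
  x = Normal(loc=mu, scale=1.0, sample_shape=50)
  qmu = PointMass(params=tf.Variable(0.0))

  # Gradient descent on -log p(x, z); qmu's parameter is the point
  # estimate substituted for mu via dict_swap.
  inference = ed.MAP({mu: qmu},
                     data={x: np.random.randn(50).astype(np.float32)})
  inference.run(n_iter=250)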
def complete_conditional(rv, cond_set=None):
  """Returns the conditional distribution `RandomVariable`
  $p(\\text{rv}\mid \cdot)$.

  This function tries to infer the conditional distribution of `rv`
  given `cond_set`, a set of other `RandomVariable`s in the graph. It
  will only be able to do this if

  1. $p(\\text{rv}\mid \\text{cond\_set})$ is in a tractable
     exponential family; and
  2. the truth of assumption 1 is not obscured in the TensorFlow graph.

  In other words, this function will do its best to recognize conjugate
  relationships when they exist. But it may not always be able to do
  the necessary algebra.

  Args:
    rv: RandomVariable.
      The random variable whose conditional distribution we are
      interested in.
    cond_set: iterable of RandomVariable, optional.
      The set of random variables we want to condition on. Default is
      all random variables in the graph. (It makes no difference if
      `cond_set` does or does not include `rv`.)

  #### Notes

  When calling `complete_conditional()` multiple times, one should
  usually pass an explicit `cond_set`. Otherwise
  `complete_conditional()` will try to condition on the
  `RandomVariable`s returned by previous calls to itself. This may
  result in unpredictable behavior.
  """
  if cond_set is None:
    # Default to Markov blanket, excluding conditionals. This is useful
    # if calling complete_conditional many times without passing in
    # cond_set.
    cond_set = get_blanket(rv)
    cond_set = [i for i in cond_set
                if not ('complete_conditional' in i.name and
                        'cond_dist' in i.name)]

  cond_set = set([rv] + list(cond_set))
  with tf.name_scope('complete_conditional_%s' % rv.name) as scope:
    # log_joint holds all the information we need to get a conditional.
    log_joint = get_log_joint(cond_set)

    # Pull out the nodes that are nonlinear functions of rv into s_stats.
    stop_nodes = set([i.value() for i in cond_set])
    subgraph = extract_subgraph(log_joint, stop_nodes)
    s_stats = suff_stat_nodes(subgraph, rv.value(), cond_set)
    s_stats = list(set(s_stats))

    # Simplify those nodes, and put any new linear terms into
    # multipliers_i.
    s_stat_exprs = defaultdict(list)
    for s_stat in s_stats:
      expr = symbolic_suff_stat(s_stat, rv.value(), stop_nodes)
      expr = full_simplify(expr)
      multipliers_i, s_stats_i = extract_s_stat_multipliers(expr)
      s_stat_exprs[s_stats_i].append(
          (s_stat, reconstruct_multiplier(multipliers_i)))

    # Sort out the sufficient statistics to identify this conditional's
    # family.
    s_stat_keys = list(six.iterkeys(s_stat_exprs))
    order = np.argsort([str(i) for i in s_stat_keys])
    dist_key = tuple(s_stat_keys[i] for i in order)
    dist_constructor, constructor_params = (
        _suff_stat_to_dist[rv.support].get(dist_key, (None, None)))
    if dist_constructor is None:
      raise NotImplementedError('Conditional distribution has sufficient '
                                'statistics %s, but no available '
                                'exponential-family distribution has '
                                'those sufficient statistics.' %
                                str(dist_key))

    # Swap sufficient statistics for placeholders, then take gradients
    # w.r.t. those placeholders to get natural parameters. The original
    # nodes involving the sufficient statistic nodes are swapped for new
    # nodes that depend linearly on the sufficient statistic
    # placeholders.
    s_stat_placeholders = []
    swap_dict = {}
    swap_back = {}
    for s_stat_expr in six.itervalues(s_stat_exprs):
      s_stat_placeholder = tf.placeholder(tf.float32,
                                          s_stat_expr[0][0].get_shape())
      swap_back[s_stat_placeholder] = tf.cast(rv.value(), tf.float32)
      s_stat_placeholders.append(s_stat_placeholder)
      for s_stat_node, multiplier in s_stat_expr:
        fake_node = s_stat_placeholder * multiplier
        swap_dict[s_stat_node] = fake_node
        swap_back[fake_node] = s_stat_node

    for i in cond_set:
      if i != rv:
        val = i.value()
        val_placeholder = tf.placeholder(val.dtype)
        swap_dict[val] = val_placeholder
        swap_back[val_placeholder] = val
        # Prevent random variable nodes from being copied.
        swap_back[val] = val

    scope_name = scope + str(time.time())  # ensure unique scope when copying
    log_joint_copy = copy(log_joint, swap_dict, scope=scope_name + 'swap')
    nat_params = tf.gradients(log_joint_copy, s_stat_placeholders)

    # Remove any dependencies on those old placeholders.
    nat_params = [copy(nat_param, swap_back, scope=scope_name + 'swapback')
                  for nat_param in nat_params]
    nat_params = [nat_params[i] for i in order]

    return dist_constructor(name='cond_dist',
                            **constructor_params(*nat_params))
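
# A minimal usage sketch, assuming Edward's public API; the
# beta-Bernoulli model is illustrative only. For conjugate pairs, the
# returned RandomVariable's parameters are functions of the
# conditioning set's values.
def _example_complete_conditional():
  import numpy as np
  import tensorflow as tf
  import edward as ed
  from edward.models import Bernoulli, Beta

  pi = Beta(1.0, 1.0)
  x = Bernoulli(probs=pi, sample_shape=10)
  pi_cond = ed.complete_conditional(pi)  # a Beta random variable

  x_data = np.ones(10, dtype=np.int32)
  with tf.Session() as sess:
    # With 10 successes and 0 failures, p(pi | x) = Beta(1 + 10, 1 + 0).
    return sess.run([pi_cond.parameters['concentration1'],
                     pi_cond.parameters['concentration0']],
                    feed_dict={x.value(): x_data})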
def build_loss_and_gradients(self, var_list):
  """Build loss functions for the wake and sleep phases.

  The wake phase maximizes log p(x, z) over model parameters using
  samples z ~ q(z). The q update depends on `phase_q`: if not 'sleep',
  log q(z) is evaluated on the same posterior samples; if 'sleep',
  log q(z) is evaluated on prior samples z ~ p(z).
  """
  p_log_prob = [0.0] * self.n_samples
  q_log_prob = [0.0] * self.n_samples
  base_scope = tf.get_default_graph().unique_name("inference") + '/'
  for s in range(self.n_samples):
    # Form dictionary in order to replace conditioning on prior or
    # observed variable with conditioning on a specific value.
    scope = base_scope + tf.get_default_graph().unique_name("q_sample")
    dict_swap = {}
    for x, qx in six.iteritems(self.data):
      if isinstance(x, RandomVariable):
        if isinstance(qx, RandomVariable):
          qx_copy = copy(qx, scope=scope)
          dict_swap[x] = qx_copy.value()
        else:
          dict_swap[x] = qx

    # Sample z ~ q(z), then compute log p(x, z).
    q_dict_swap = dict_swap.copy()
    for z, qz in six.iteritems(self.latent_vars):
      # Copy q(z) to obtain new set of posterior samples.
      qz_copy = copy(qz, scope=scope)
      q_dict_swap[z] = qz_copy.value()
      if self.phase_q != 'sleep':
        # If not sleep phase, compute log q(z).
        q_log_prob[s] += tf.reduce_sum(
            self.scale.get(z, 1.0) *
            qz_copy.log_prob(tf.stop_gradient(q_dict_swap[z])))

    for z in six.iterkeys(self.latent_vars):
      z_copy = copy(z, q_dict_swap, scope=scope)
      p_log_prob[s] += tf.reduce_sum(
          self.scale.get(z, 1.0) * z_copy.log_prob(q_dict_swap[z]))

    for x in six.iterkeys(self.data):
      if isinstance(x, RandomVariable):
        x_copy = copy(x, q_dict_swap, scope=scope)
        p_log_prob[s] += tf.reduce_sum(
            self.scale.get(x, 1.0) * x_copy.log_prob(q_dict_swap[x]))

    if self.phase_q == 'sleep':
      # Sample z ~ p(z), then compute log q(z).
      scope = base_scope + tf.get_default_graph().unique_name("p_sample")
      p_dict_swap = dict_swap.copy()
      for z, qz in six.iteritems(self.latent_vars):
        # Copy p(z) to obtain new set of prior samples.
        z_copy = copy(z, scope=scope)
        p_dict_swap[qz] = z_copy.value()

      for z, qz in six.iteritems(self.latent_vars):
        qz_copy = copy(qz, p_dict_swap, scope=scope)
        q_log_prob[s] += tf.reduce_sum(
            self.scale.get(z, 1.0) *
            qz_copy.log_prob(tf.stop_gradient(p_dict_swap[qz])))

  p_log_prob = tf.reduce_mean(p_log_prob)
  q_log_prob = tf.reduce_mean(q_log_prob)

  if self.logging:
    tf.summary.scalar("loss/p_log_prob", p_log_prob,
                      collections=[self._summary_key])
    tf.summary.scalar("loss/q_log_prob", q_log_prob,
                      collections=[self._summary_key])

  loss_p = -p_log_prob
  loss_q = -q_log_prob

  q_rvs = list(six.itervalues(self.latent_vars))
  q_vars = [v for v in var_list
            if len(get_descendants(tf.convert_to_tensor(v), q_rvs)) != 0]
  q_grads = tf.gradients(loss_q, q_vars)
  p_vars = [v for v in var_list if v not in q_vars]
  p_grads = tf.gradients(loss_p, p_vars)
  grads_and_vars = list(zip(q_grads, q_vars)) + list(zip(p_grads, p_vars))
  return loss_p, grads_and_vars
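
# A minimal usage sketch, assuming the Edward-style WakeSleep class this
# method belongs to; the model is illustrative, and passing `phase_q` to
# `run` is an assumption about how the class forwards initialize kwargs.
def _example_wake_sleep():
  import edward as ed
  import numpy as np
  import tensorflow as tf
  from edward.models import Normal

  z = Normal(loc=0.0, scale=1.0)
  x = Normal(loc=z, scale=1.0, sample_shape=50)
  qz = Normal(loc=tf.Variable(0.0),
              scale=tf.nn.softplus(tf.Variable(0.0)))

  inference = ed.WakeSleep({z: qz},
                           data={x: np.random.randn(50).astype(np.float32)})
  # phase_q='sleep' trains q on prior samples, as in the branch above.
  inference.run(n_iter=200, phase_q='sleep')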
def build_loss_and_gradients(self, var_list):
  """Build loss function

  $-\Big(\mathbb{E}_{q(\\beta)} [\log p(\\beta) - \log q(\\beta) ] +
      \sum_{n=1}^N \mathbb{E}_{q(\\beta)q(z_n\mid\\beta)} [
      r^*(x_n, z_n, \\beta) ] \Big).$

  We minimize it with respect to parameterized variational
  families $q(z, \\beta; \lambda)$.

  $r^*(x_n, z_n, \\beta)$ is a function of a single data point $x_n$,
  single local variable $z_n$, and all global variables $\\beta$. It is
  equal to the log-ratio

  $\log p(x_n, z_n\mid \\beta) - \log q(x_n, z_n\mid \\beta),$

  where $q(x_n)$ is the empirical data distribution. Rather than
  explicit calculation, $r^*(x, z, \\beta)$ is the solution to a ratio
  estimation problem, minimizing the specified `ratio_loss`.

  Gradients are taken using the reparameterization trick
  [@kingma2014auto].

  #### Notes

  This also includes model parameters $p(x, z, \\beta; \\theta)$ and
  variational distributions with inference networks $q(z\mid x)$.

  There are a bunch of extensions we could easily do in this
  implementation:

  + further factorizations can be used to better leverage the graph
    structure for more complicated models;
  + score function gradients for global variables;
  + use more samples; this would require the `copy()` utility function
    for q's as well, and an additional loop. We opt not to because it
    complicates the code;
  + analytic KL / swapping out the penalty term for the globals.
  """
  # Collect tensors used in calculation of losses.
  scope = tf.get_default_graph().unique_name("inference")
  qbeta_sample = {}
  pbeta_log_prob = 0.0
  qbeta_log_prob = 0.0
  for beta, qbeta in six.iteritems(self.global_vars):
    # Draw a sample beta' ~ q(beta) and calculate
    # log p(beta') and log q(beta').
    qbeta_sample[beta] = qbeta.value()
    pbeta_log_prob += tf.reduce_sum(beta.log_prob(qbeta_sample[beta]))
    qbeta_log_prob += tf.reduce_sum(qbeta.log_prob(qbeta_sample[beta]))

  pz_sample = {}
  qz_sample = {}
  for z, qz in six.iteritems(self.latent_vars):
    if z not in self.global_vars:
      # Copy local variables p(z), q(z) to draw samples
      # z' ~ p(z | beta'), z' ~ q(z | beta').
      pz_copy = copy(z, dict_swap=qbeta_sample, scope=scope)
      pz_sample[z] = pz_copy.value()
      qz_sample[z] = qz.value()

  # Collect x' ~ p(x | z', beta') and x' ~ q(x).
  dict_swap = qbeta_sample.copy()
  dict_swap.update(qz_sample)
  x_psample = {}
  x_qsample = {}
  for x, x_data in six.iteritems(self.data):
    if isinstance(x, tf.Tensor):
      if "Placeholder" not in x.op.type:
        # Copy p(x | z, beta) to get draw p(x | z', beta').
        x_copy = copy(x, dict_swap=dict_swap, scope=scope)
        x_psample[x] = x_copy
        x_qsample[x] = x_data
    elif isinstance(x, RandomVariable):
      # Copy p(x | z, beta) to get draw p(x | z', beta').
      x_copy = copy(x, dict_swap=dict_swap, scope=scope)
      x_psample[x] = x_copy.value()
      x_qsample[x] = x_data

  with tf.variable_scope("Disc"):
    r_psample = self.discriminator(x_psample, pz_sample, qbeta_sample)

  with tf.variable_scope("Disc", reuse=True):
    r_qsample = self.discriminator(x_qsample, qz_sample, qbeta_sample)

  # Form ratio loss and ratio estimator.
  if len(self.scale) <= 1:
    loss_d = tf.reduce_mean(self.ratio_loss(r_psample, r_qsample))
    scale = list(six.itervalues(self.scale))
    scale = scale[0] if scale else 1.0
    scaled_ratio = tf.reduce_sum(scale * r_qsample)
  else:
    loss_d = [tf.reduce_mean(self.ratio_loss(r_psample[key], r_qsample[key]))
              for key in six.iterkeys(self.scale)]
    loss_d = tf.reduce_sum(loss_d)
    scaled_ratio = [tf.reduce_sum(self.scale[key] * r_qsample[key])
                    for key in six.iterkeys(self.scale)]
    scaled_ratio = tf.reduce_sum(scaled_ratio)

  # Form variational objective.
  loss = -(pbeta_log_prob - qbeta_log_prob + scaled_ratio)

  var_list_d = tf.get_collection(
      tf.GraphKeys.TRAINABLE_VARIABLES, scope="Disc")
  if var_list is None:
    var_list = [v for v in tf.trainable_variables() if v not in var_list_d]

  grads = tf.gradients(loss, var_list)
  grads_d = tf.gradients(loss_d, var_list_d)
  grads_and_vars = list(zip(grads, var_list))
  grads_and_vars_d = list(zip(grads_d, var_list_d))
  return loss, grads_and_vars, loss_d, grads_and_vars_d
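
# A minimal usage sketch, assuming the Edward-style ImplicitKLqp class
# this method belongs to; the ratio estimator below is illustrative
# only. Its positional arguments mirror the discriminator calls above:
# a dict of data, a dict of local variables, and a dict of globals.
def _example_implicit_klqp():
  import edward as ed
  import numpy as np
  import tensorflow as tf
  from edward.models import Normal

  z = Normal(loc=tf.zeros(2), scale=tf.ones(2))
  x = Normal(loc=z, scale=tf.ones(2))
  qz = Normal(loc=tf.Variable(tf.zeros(2)),
              scale=tf.nn.softplus(tf.Variable(tf.zeros(2))))

  def ratio_estimator(data, local_vars, global_vars):
    # r^*(x, z): a scalar estimate of the log-ratio for a (x, z) pair.
    net = tf.concat([tf.reshape(t, [1, -1]) for t in
                     list(data.values()) + list(local_vars.values())], 1)
    return tf.reshape(tf.layers.dense(net, 1), [])

  inference = ed.ImplicitKLqp(
      {z: qz}, data={x: np.zeros(2, dtype=np.float32)},
      discriminator=ratio_estimator)
  inference.run(n_iter=200)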
def build_update(self):
  """Draw sample from proposal conditional on last sample. Then accept
  or reject the sample based on the ratio,

  $\\text{ratio} =
      \log p(x, z^{\\text{new}}) - \log p(x, z^{\\text{old}}) -
      \log g(z^{\\text{new}} \mid z^{\\text{old}}) +
      \log g(z^{\\text{old}} \mid z^{\\text{new}})$

  #### Notes

  The updates assume each Empirical random variable is directly
  parameterized by `tf.Variable`s.
  """
  old_sample = {z: tf.gather(qz.params, tf.maximum(self.t - 1, 0))
                for z, qz in six.iteritems(self.latent_vars)}
  old_sample = OrderedDict(old_sample)

  # Form dictionary in order to replace conditioning on prior or
  # observed variable with conditioning on a specific value.
  dict_swap = {}
  for x, qx in six.iteritems(self.data):
    if isinstance(x, RandomVariable):
      if isinstance(qx, RandomVariable):
        qx_copy = copy(qx, scope='conditional')
        dict_swap[x] = qx_copy.value()
      else:
        dict_swap[x] = qx

  dict_swap_old = dict_swap.copy()
  dict_swap_old.update(old_sample)
  base_scope = tf.get_default_graph().unique_name("inference") + '/'
  scope_old = base_scope + 'old'
  scope_new = base_scope + 'new'

  # Draw proposed sample and calculate acceptance ratio.
  new_sample = old_sample.copy()  # copy to ensure same order
  ratio = 0.0
  for z, proposal_z in six.iteritems(self.proposal_vars):
    # Build proposal g(znew | zold).
    proposal_znew = copy(proposal_z, dict_swap_old, scope=scope_old)
    # Sample znew ~ g(znew | zold).
    new_sample[z] = proposal_znew.value()
    # Increment ratio; the forward proposal density enters negatively.
    ratio -= tf.reduce_sum(proposal_znew.log_prob(new_sample[z]))

  dict_swap_new = dict_swap.copy()
  dict_swap_new.update(new_sample)

  for z, proposal_z in six.iteritems(self.proposal_vars):
    # Build proposal g(zold | znew).
    proposal_zold = copy(proposal_z, dict_swap_new, scope=scope_new)
    # Increment ratio; the reverse proposal density enters positively.
    ratio += tf.reduce_sum(proposal_zold.log_prob(dict_swap_old[z]))

  for z in six.iterkeys(self.latent_vars):
    # Build priors p(znew) and p(zold).
    znew = copy(z, dict_swap_new, scope=scope_new)
    zold = copy(z, dict_swap_old, scope=scope_old)
    # Increment ratio.
    ratio += tf.reduce_sum(znew.log_prob(dict_swap_new[z]))
    ratio -= tf.reduce_sum(zold.log_prob(dict_swap_old[z]))

  for x in six.iterkeys(self.data):
    if isinstance(x, RandomVariable):
      # Build likelihoods p(x | znew) and p(x | zold).
      x_znew = copy(x, dict_swap_new, scope=scope_new)
      x_zold = copy(x, dict_swap_old, scope=scope_old)
      # Increment ratio.
      ratio += tf.reduce_sum(x_znew.log_prob(dict_swap[x]))
      ratio -= tf.reduce_sum(x_zold.log_prob(dict_swap[x]))

  # Accept or reject sample.
  u = Uniform(low=tf.constant(0.0, dtype=ratio.dtype),
              high=tf.constant(1.0, dtype=ratio.dtype)).sample()
  accept = tf.log(u) < ratio
  sample_values = tf.cond(accept,
                          lambda: list(six.itervalues(new_sample)),
                          lambda: list(six.itervalues(old_sample)))
  if not isinstance(sample_values, list):
    # `tf.cond` returns tf.Tensor if output is a list of size 1.
    sample_values = [sample_values]

  sample = {z: sample_value for z, sample_value in
            zip(six.iterkeys(new_sample), sample_values)}

  # Update Empirical random variables.
  assign_ops = []
  for z, qz in six.iteritems(self.latent_vars):
    variable = qz.get_variables()[0]
    assign_ops.append(tf.scatter_update(variable, self.t, sample[z]))

  # Increment n_accept (if accepted).
  assign_ops.append(self.n_accept.assign_add(tf.where(accept, 1, 0)))
  return tf.group(*assign_ops)
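
# A minimal usage sketch, assuming the Edward-style MetropolisHastings
# class this method belongs to; the model and random-walk proposal are
# illustrative only.
def _example_metropolis_hastings():
  import edward as ed
  import numpy as np
  import tensorflow as tf
  from edward.models import Empirical, Normal

  mu = Normal(loc=0.0, scale=1.0)
  x = Normal(loc=mu, scale=1.0, sample_shape=50)

  qmu = Empirical(params=tf.Variable(tf.zeros(1000)))
  # Random-walk proposal g(znew | zold): conditioning on mu means the
  # dict_swap_old machinery above centers it at the last sample.
  proposal_mu = Normal(loc=mu, scale=0.5)

  inference = ed.MetropolisHastings(
      {mu: qmu}, {mu: proposal_mu},
      data={x: np.random.randn(50).astype(np.float32)})
  inference.run()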
def build_score_rb_loss_and_gradients(inference, var_list):
  """Build loss function and gradients based on the score function
  estimator [@paisley2012variational] and Rao-Blackwellization
  [@ranganath2014black].

  Computed by sampling from $q(z;\lambda)$ and evaluating the
  expectation using Monte Carlo sampling and Rao-Blackwellization.
  """
  # Build tensors for loss and gradient calculations. There is one set
  # for each sample from the variational distribution. (Use a list
  # comprehension, not `[{}] * n`, so each sample gets its own dict.)
  p_log_probs = [{} for _ in range(inference.n_samples)]
  q_log_probs = [{} for _ in range(inference.n_samples)]
  base_scope = tf.get_default_graph().unique_name("inference") + '/'
  for s in range(inference.n_samples):
    # Form dictionary in order to replace conditioning on prior or
    # observed variable with conditioning on a specific value.
    scope = base_scope + tf.get_default_graph().unique_name("sample")
    dict_swap = {}
    for x, qx in six.iteritems(inference.data):
      if isinstance(x, RandomVariable):
        if isinstance(qx, RandomVariable):
          qx_copy = copy(qx, scope=scope)
          dict_swap[x] = qx_copy.value()
        else:
          dict_swap[x] = qx

    for z, qz in six.iteritems(inference.latent_vars):
      # Copy q(z) to obtain new set of posterior samples.
      qz_copy = copy(qz, scope=scope)
      dict_swap[z] = qz_copy.value()
      q_log_probs[s][qz] = tf.reduce_sum(
          inference.scale.get(z, 1.0) *
          qz_copy.log_prob(tf.stop_gradient(dict_swap[z])))

    for z in six.iterkeys(inference.latent_vars):
      z_copy = copy(z, dict_swap, scope=scope)
      p_log_probs[s][z] = tf.reduce_sum(
          inference.scale.get(z, 1.0) * z_copy.log_prob(dict_swap[z]))

    for x in six.iterkeys(inference.data):
      if isinstance(x, RandomVariable):
        x_copy = copy(x, dict_swap, scope=scope)
        p_log_probs[s][x] = tf.reduce_sum(
            inference.scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x]))

  # Take gradients of Rao-Blackwellized loss for each variational
  # parameter.
  p_rvs = list(six.iterkeys(inference.latent_vars)) + \
      [x for x in six.iterkeys(inference.data)
       if isinstance(x, RandomVariable)]
  q_rvs = list(six.itervalues(inference.latent_vars))
  reverse_latent_vars = {v: k for k, v in
                         six.iteritems(inference.latent_vars)}
  grads = []
  grads_vars = []
  for var in var_list:
    # Get all variational factors depending on the parameter.
    descendants = get_descendants(tf.convert_to_tensor(var), q_rvs)
    if len(descendants) == 0:
      continue  # skip if not a variational parameter
    # Get p and q's Markov blanket wrt these latent variables.
    var_p_rvs = set()
    for qz in descendants:
      z = reverse_latent_vars[qz]
      var_p_rvs.update(z.get_blanket(p_rvs) + [z])

    var_q_rvs = set()
    for qz in descendants:
      var_q_rvs.update(qz.get_blanket(q_rvs) + [qz])

    pi_log_prob = [0.0] * inference.n_samples
    qi_log_prob = [0.0] * inference.n_samples
    for s in range(inference.n_samples):
      pi_log_prob[s] = tf.reduce_sum(
          [p_log_probs[s][rv] for rv in var_p_rvs])
      qi_log_prob[s] = tf.reduce_sum(
          [q_log_probs[s][rv] for rv in var_q_rvs])

    pi_log_prob = tf.stack(pi_log_prob)
    qi_log_prob = tf.stack(qi_log_prob)

    if inference.control_variates:
      # Control variate for the score-function gradient: the per-sample
      # score h_i = grad(-log q_i) has expectation zero under q, so
      # subtracting cv_coefs * h_i from each per-sample gradient f_i
      # reduces variance without changing the mean.
      fs = []
      hs = []
      for i in range(inference.n_samples):
        fs.append(tf.gradients(
            -qi_log_prob[i] *
            tf.stop_gradient(pi_log_prob[i] - qi_log_prob[i]),
            var)[0])
        hs.append(tf.gradients(-qi_log_prob[i], var)[0])

      fs = tf.stack(fs)  # n_samples x (shape of var)
      hs = tf.stack(hs)  # n_samples x (shape of var)
      f_mu = tf.reduce_mean(fs, 0)
      h_mu = tf.reduce_mean(hs, 0)
      # Element-wise sample covariance between f and h across samples,
      # used as the control-variate coefficient.
      cv_coefs = tf.divide(
          tf.reduce_sum(tf.multiply(fs - f_mu, hs - h_mu), 0),
          inference.n_samples - 1)
      grad_unmeaned = tf.subtract(
          fs, tf.multiply(tf.expand_dims(cv_coefs, 0), hs))
      grads.append(tf.reduce_mean(grad_unmeaned, 0))
    else:
      grad = tf.gradients(
          -tf.reduce_mean(
              qi_log_prob * tf.stop_gradient(pi_log_prob - qi_log_prob)),
          var)
      grads.extend(grad)

    grads_vars.append(var)

  # Take gradients of total loss function for model parameters.
  loss = -(tf.reduce_mean([tf.reduce_sum(list(six.itervalues(p_log_prob)))
                           for p_log_prob in p_log_probs]) -
           tf.reduce_mean([tf.reduce_sum(list(six.itervalues(q_log_prob)))
                           for q_log_prob in q_log_probs]))
  model_vars = [v for v in var_list if v not in grads_vars]
  model_grads = tf.gradients(loss, model_vars)
  grads.extend(model_grads)
  grads_vars.extend(model_vars)
  grads_and_vars = list(zip(grads, grads_vars))
  return loss, grads_and_vars
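
# A minimal usage sketch, assuming an Edward-style ScoreRBKLqp class
# wires up this estimator; `control_variates` is this code's addition
# (a flag read off the inference object above), not a stock Edward
# option, so how it gets set is left to the surrounding class.
def _example_score_rb_klqp():
  import edward as ed
  import numpy as np
  import tensorflow as tf
  from edward.models import Normal

  mu = Normal(loc=0.0, scale=1.0)
  x = Normal(loc=mu, scale=1.0, sample_shape=50)
  qmu = Normal(loc=tf.Variable(0.0),
               scale=tf.nn.softplus(tf.Variable(0.0)))

  inference = ed.ScoreRBKLqp({mu: qmu},
                             data={x: np.random.randn(50).astype(np.float32)})
  # More samples lower the estimator's variance; the control-variate
  # path above needs n_samples > 1 for its sample covariance.
  inference.run(n_samples=8, n_iter=200)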
def build_score_entropy_loss_and_gradients(inference, var_list):
  """Build loss function and gradients based on the score function
  estimator [@paisley2012variational]. It assumes the entropy is
  analytic.

  Computed by sampling from $q(z;\lambda)$ and evaluating the
  expectation using Monte Carlo sampling.
  """
  p_log_prob = [0.0] * inference.n_samples
  q_log_prob = [0.0] * inference.n_samples
  base_scope = tf.get_default_graph().unique_name("inference") + '/'
  for s in range(inference.n_samples):
    # Form dictionary in order to replace conditioning on prior or
    # observed variable with conditioning on a specific value.
    scope = base_scope + tf.get_default_graph().unique_name("sample")
    dict_swap = {}
    for x, qx in six.iteritems(inference.data):
      if isinstance(x, RandomVariable):
        if isinstance(qx, RandomVariable):
          qx_copy = copy(qx, scope=scope)
          dict_swap[x] = qx_copy.value()
        else:
          dict_swap[x] = qx

    for z, qz in six.iteritems(inference.latent_vars):
      # Copy q(z) to obtain new set of posterior samples.
      qz_copy = copy(qz, scope=scope)
      dict_swap[z] = qz_copy.value()
      q_log_prob[s] += tf.reduce_sum(
          inference.scale.get(z, 1.0) *
          qz_copy.log_prob(tf.stop_gradient(dict_swap[z])))

    for z in six.iterkeys(inference.latent_vars):
      z_copy = copy(z, dict_swap, scope=scope)
      p_log_prob[s] += tf.reduce_sum(
          inference.scale.get(z, 1.0) * z_copy.log_prob(dict_swap[z]))

    for x in six.iterkeys(inference.data):
      if isinstance(x, RandomVariable):
        x_copy = copy(x, dict_swap, scope=scope)
        p_log_prob[s] += tf.reduce_sum(
            inference.scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x]))

  p_log_prob = tf.stack(p_log_prob)
  q_log_prob = tf.stack(q_log_prob)
  q_entropy = tf.reduce_sum(
      [tf.reduce_sum(qz.entropy())
       for z, qz in six.iteritems(inference.latent_vars)])

  if inference.logging:
    tf.summary.scalar("loss/p_log_prob", tf.reduce_mean(p_log_prob),
                      collections=[inference._summary_key])
    tf.summary.scalar("loss/q_log_prob", tf.reduce_mean(q_log_prob),
                      collections=[inference._summary_key])
    tf.summary.scalar("loss/q_entropy", q_entropy,
                      collections=[inference._summary_key])

  loss = -(tf.reduce_mean(p_log_prob) + q_entropy)

  q_rvs = list(six.itervalues(inference.latent_vars))
  q_vars = [v for v in var_list
            if len(get_descendants(tf.convert_to_tensor(v), q_rvs)) != 0]
  q_grads = tf.gradients(
      -(tf.reduce_mean(q_log_prob * tf.stop_gradient(p_log_prob)) +
        q_entropy),
      q_vars)
  p_vars = [v for v in var_list if v not in q_vars]
  p_grads = tf.gradients(loss, p_vars)
  grads_and_vars = list(zip(q_grads, q_vars)) + list(zip(p_grads, p_vars))
  return loss, grads_and_vars
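
# A small worked check of the analytic-entropy term above, assuming
# Edward's Normal random variable; for Normal(mu, sigma),
# H(q) = 0.5 * log(2 * pi * e * sigma^2), which is exactly what
# `qz.entropy()` contributes to the loss.
def _example_analytic_entropy():
  import numpy as np
  import tensorflow as tf
  from edward.models import Normal

  sigma = 1.5
  qz = Normal(loc=0.0, scale=sigma)
  closed_form = 0.5 * np.log(2.0 * np.pi * np.e * sigma ** 2)
  with tf.Session() as sess:
    assert abs(sess.run(qz.entropy()) - closed_form) < 1e-5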
def build_reparam_entropy_loss_and_gradients(inference, var_list):
  """Build loss function. Its automatic differentiation
  is a stochastic gradient of

  $-\\text{ELBO} = -(\mathbb{E}_{q(z; \lambda)} [\log p(x, z)] +
      \mathbb{H}(q(z; \lambda)))$

  based on the reparameterization trick [@kingma2014auto]. It assumes
  the entropy is analytic.

  Computed by sampling from $q(z;\lambda)$ and evaluating the
  expectation using Monte Carlo sampling.
  """
  p_log_prob = [0.0] * inference.n_samples
  base_scope = tf.get_default_graph().unique_name("inference") + '/'
  for s in range(inference.n_samples):
    # Form dictionary in order to replace conditioning on prior or
    # observed variable with conditioning on a specific value.
    scope = base_scope + tf.get_default_graph().unique_name("sample")
    dict_swap = {}
    for x, qx in six.iteritems(inference.data):
      if isinstance(x, RandomVariable):
        if isinstance(qx, RandomVariable):
          qx_copy = copy(qx, scope=scope)
          dict_swap[x] = qx_copy.value()
        else:
          dict_swap[x] = qx

    for z, qz in six.iteritems(inference.latent_vars):
      # Copy q(z) to obtain new set of posterior samples.
      qz_copy = copy(qz, scope=scope)
      dict_swap[z] = qz_copy.value()

    for z in six.iterkeys(inference.latent_vars):
      z_copy = copy(z, dict_swap, scope=scope)
      p_log_prob[s] += tf.reduce_sum(
          inference.scale.get(z, 1.0) * z_copy.log_prob(dict_swap[z]))

    for x in six.iterkeys(inference.data):
      if isinstance(x, RandomVariable):
        x_copy = copy(x, dict_swap, scope=scope)
        p_log_prob[s] += tf.reduce_sum(
            inference.scale.get(x, 1.0) * x_copy.log_prob(dict_swap[x]))

  p_log_prob = tf.reduce_mean(p_log_prob)
  q_entropy = tf.reduce_sum(
      [tf.reduce_sum(qz.entropy())
       for z, qz in six.iteritems(inference.latent_vars)])

  if inference.logging:
    tf.summary.scalar("loss/p_log_prob", p_log_prob,
                      collections=[inference._summary_key])
    tf.summary.scalar("loss/q_entropy", q_entropy,
                      collections=[inference._summary_key])

  loss = -(p_log_prob + q_entropy)

  grads = tf.gradients(loss, var_list)
  grads_and_vars = list(zip(grads, var_list))
  return loss, grads_and_vars
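
# A minimal usage sketch, assuming an Edward-style
# ReparameterizationEntropyKLqp class wires up this loss; the model is
# illustrative only. q must be reparameterizable (e.g. Normal), since
# gradients flow through the sampled dict_swap[z] rather than through a
# score term.
def _example_reparam_entropy_klqp():
  import edward as ed
  import numpy as np
  import tensorflow as tf
  from edward.models import Normal

  mu = Normal(loc=0.0, scale=1.0)
  x = Normal(loc=mu, scale=1.0, sample_shape=50)
  qmu = Normal(loc=tf.Variable(0.0),
               scale=tf.nn.softplus(tf.Variable(0.0)))

  inference = ed.ReparameterizationEntropyKLqp(
      {mu: qmu}, data={x: np.random.randn(50).astype(np.float32)})
  inference.run(n_iter=200)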
def build_loss_and_gradients(self, var_list):
  """Build loss function

  $\\text{KL}( p(z \mid x) \| q(z) )
      = \mathbb{E}_{p(z \mid x)} [ \log p(z \mid x) - \log q(z; \lambda) ]$

  and stochastic gradients based on importance sampling.

  The loss function can be estimated as

  $\sum_{s=1}^S [ w_{\\text{norm}}(z^s; \lambda) (\log p(x, z^s) -
      \log q(z^s; \lambda)) ],$

  where for $z^s \sim q(z; \lambda)$,

  $w_{\\text{norm}}(z^s; \lambda) =
      w(z^s; \lambda) / \sum_{s'=1}^S w(z^{s'}; \lambda)$

  normalizes the importance weights,
  $w(z^s; \lambda) = p(x, z^s) / q(z^s; \lambda)$.

  This provides a gradient,

  $- \sum_{s=1}^S [ w_{\\text{norm}}(z^s; \lambda)
      \\nabla_{\lambda} \log q(z^s; \lambda) ].$
  """
  p_log_prob = [0.0] * self.n_samples
  q_log_prob = [0.0] * self.n_samples
  base_scope = tf.get_default_graph().unique_name("inference") + '/'
  for s in range(self.n_samples):
    # Form dictionary in order to replace conditioning on prior or
    # observed variable with conditioning on a specific value.
    scope = base_scope + tf.get_default_graph().unique_name("sample")
    dict_swap = {}
    for x, qx in six.iteritems(self.data):
      if isinstance(x, RandomVariable):
        if isinstance(qx, RandomVariable):
          qx_copy = copy(qx, scope=scope)
          dict_swap[x] = qx_copy.value()
        else:
          dict_swap[x] = qx

    for z, qz in six.iteritems(self.latent_vars):
      # Copy q(z) to obtain new set of posterior samples.
      qz_copy = copy(qz, scope=scope)
      dict_swap[z] = qz_copy.value()
      q_log_prob[s] += tf.reduce_sum(
          qz_copy.log_prob(tf.stop_gradient(dict_swap[z])))

    for z in six.iterkeys(self.latent_vars):
      z_copy = copy(z, dict_swap, scope=scope)
      p_log_prob[s] += tf.reduce_sum(z_copy.log_prob(dict_swap[z]))

    for x in six.iterkeys(self.data):
      if isinstance(x, RandomVariable):
        x_copy = copy(x, dict_swap, scope=scope)
        p_log_prob[s] += tf.reduce_sum(x_copy.log_prob(dict_swap[x]))

  p_log_prob = tf.stack(p_log_prob)
  q_log_prob = tf.stack(q_log_prob)

  if self.logging:
    tf.summary.scalar("loss/p_log_prob", tf.reduce_mean(p_log_prob),
                      collections=[self._summary_key])
    tf.summary.scalar("loss/q_log_prob", tf.reduce_mean(q_log_prob),
                      collections=[self._summary_key])

  log_w = p_log_prob - q_log_prob
  log_w_norm = log_w - tf.reduce_logsumexp(log_w)
  w_norm = tf.exp(log_w_norm)
  loss = tf.reduce_sum(w_norm * log_w)

  q_rvs = list(six.itervalues(self.latent_vars))
  q_vars = [v for v in var_list
            if len(get_descendants(tf.convert_to_tensor(v), q_rvs)) != 0]
  q_grads = tf.gradients(
      -tf.reduce_sum(q_log_prob * tf.stop_gradient(w_norm)),
      q_vars)
  p_vars = [v for v in var_list if v not in q_vars]
  p_grads = tf.gradients(-loss, p_vars)
  grads_and_vars = list(zip(q_grads, q_vars)) + list(zip(p_grads, p_vars))
  return loss, grads_and_vars
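
# A small worked sketch of the normalized importance weights used above,
# assuming only TensorFlow; `log_w` stands in for the S per-sample
# values of log p(x, z^s) - log q(z^s) computed in the loop.
def _example_normalized_importance_weights():
  import tensorflow as tf

  log_w = tf.constant([-1.0, -2.5, -0.5])  # illustrative log-weights
  # Self-normalize in log space for numerical stability:
  # w_norm_s = w(z^s) / sum_s' w(z^s').
  log_w_norm = log_w - tf.reduce_logsumexp(log_w)
  w_norm = tf.exp(log_w_norm)
  with tf.Session() as sess:
    weights = sess.run(w_norm)  # sums to 1; largest weight at -0.5
  return weights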