def test_blanket_structure(self):
    """a -> c <- b
         |
         v
    d -> f <- e
    """
    with self.test_session():
      a = Normal(0.0, 1.0)
      b = Normal(0.0, 1.0)
      c = Normal(a * b, 1.0)
      d = Normal(0.0, 1.0)
      e = Normal(0.0, 1.0)
      f = Normal(c * d * e, 1.0)

      self.assertEqual(set(get_blanket(a)), set([b, c]))
      self.assertEqual(set(get_blanket(b)), set([a, c]))
      self.assertEqual(set(get_blanket(c)), set([a, b, d, e, f]))
      self.assertEqual(set(get_blanket(d)), set([c, e, f]))
      self.assertEqual(set(get_blanket(e)), set([c, d, f]))
      self.assertEqual(set(get_blanket(f)), set([c, d, e]))
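# For reference, the Markov blanket of a node consists of its parents, its
# children, and its children's other parents; the assertions above encode
# exactly that for the two-collider graph in the docstring.
#
# A minimal sketch of calling get_blanket outside a test. The import paths
# below are assumptions; use whatever namespace the test above imports from.
from edward.models import Normal
from edward.util import get_blanket

# Chain x -> y -> z: the blanket of y is its parent x plus its child z;
# z has no other parents, so no co-parents are added.
x = Normal(0.0, 1.0)
y = Normal(x, 1.0)
z = Normal(y, 1.0)

assert set(get_blanket(y)) == {x, z}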
# Standard-library and third-party imports used below; the graph-analysis
# helpers (get_blanket, get_log_joint, extract_subgraph, suff_stat_nodes,
# symbolic_suff_stat, full_simplify, extract_s_stat_multipliers,
# reconstruct_multiplier, copy) and the _suff_stat_to_dist lookup table are
# defined elsewhere in this module.
from collections import defaultdict

import time

import numpy as np
import six
import tensorflow as tf


def complete_conditional(rv, cond_set=None):
  """Returns the conditional distribution `RandomVariable` p(`rv` | .).

  This function tries to infer the conditional distribution of `rv` given
  `cond_set`, a set of other `RandomVariable`s in the graph. It will only
  be able to do this if

  a) p(`rv` | `cond_set`) is in a tractable exponential family, AND
  b) the truth of assumption (a) is not obscured in the TensorFlow graph.

  In other words, this function will do its best to recognize conjugate
  relationships when they exist, but it may not always be able to do the
  necessary algebra.

  Parameters
  ----------
  rv : RandomVariable
    The `RandomVariable` whose conditional distribution we are interested
    in.
  cond_set : iterable of RandomVariable, optional
    The set of `RandomVariable`s we want to condition on. Defaults to all
    `RandomVariable`s in the graph. (It makes no difference if `cond_set`
    does or does not include `rv`.)

  Notes
  -----
  When calling `complete_conditional()` multiple times, one should usually
  pass an explicit `cond_set`. Otherwise `complete_conditional()` will try
  to condition on the `RandomVariable`s returned by previous calls to
  itself, which may result in unpredictable behavior.
  """
  if cond_set is None:
    # Default to Markov blanket, excluding conditionals. This is useful if
    # calling complete_conditional many times without passing in cond_set.
    cond_set = get_blanket(rv)
    cond_set = [i for i in cond_set
                if not ('complete_conditional' in i.name and
                        'cond_dist' in i.name)]
  cond_set = set([rv] + list(cond_set))
  with tf.name_scope('complete_conditional_%s' % rv.name) as scope:
    # log_joint holds all the information we need to get a conditional.
    log_joint = get_log_joint(cond_set)

    # Pull out the nodes that are nonlinear functions of rv into s_stats.
    stop_nodes = set([i.value() for i in cond_set])
    subgraph = extract_subgraph(log_joint, stop_nodes)
    s_stats = suff_stat_nodes(subgraph, rv.value(), cond_set)
    s_stats = list(set(s_stats))

    # Simplify those nodes, and put any new linear terms into multipliers_i.
    s_stat_exprs = defaultdict(list)
    for s_stat in s_stats:
      expr = symbolic_suff_stat(s_stat, rv.value(), stop_nodes)
      expr = full_simplify(expr)
      multipliers_i, s_stats_i = extract_s_stat_multipliers(expr)
      s_stat_exprs[s_stats_i].append(
          (s_stat, reconstruct_multiplier(multipliers_i)))

    # Sort out the sufficient statistics to identify this conditional's
    # family.
    s_stat_keys = list(six.iterkeys(s_stat_exprs))
    order = np.argsort([str(i) for i in s_stat_keys])
    dist_key = tuple(s_stat_keys[i] for i in order)
    dist_constructor, constructor_params = (
        _suff_stat_to_dist[rv.support].get(dist_key, (None, None)))
    if dist_constructor is None:
      raise NotImplementedError('Conditional distribution has sufficient '
                                'statistics %s, but no available '
                                'exponential-family distribution has those '
                                'sufficient statistics.' % str(dist_key))

    # Swap sufficient statistics for placeholders, then take gradients
    # w.r.t. those placeholders to get natural parameters. The original
    # nodes involving the sufficient statistic nodes are swapped for new
    # nodes that depend linearly on the sufficient statistic placeholders.
    s_stat_placeholders = []
    swap_dict = {}
    swap_back = {}
    for s_stat_expr in six.itervalues(s_stat_exprs):
      s_stat_placeholder = tf.placeholder(tf.float32,
                                          s_stat_expr[0][0].get_shape())
      swap_back[s_stat_placeholder] = tf.cast(rv.value(), tf.float32)
      s_stat_placeholders.append(s_stat_placeholder)
      for s_stat_node, multiplier in s_stat_expr:
        fake_node = s_stat_placeholder * multiplier
        swap_dict[s_stat_node] = fake_node
        swap_back[fake_node] = s_stat_node

    for i in cond_set:
      if i != rv:
        val = i.value()
        val_placeholder = tf.placeholder(val.dtype)
        swap_dict[val] = val_placeholder
        swap_back[val_placeholder] = val
        swap_back[val] = val  # prevent random variable nodes from being copied

    scope_name = scope + str(time.time())  # ensure unique scope when copying
    log_joint_copy = copy(log_joint, swap_dict, scope=scope_name + 'swap')
    nat_params = tf.gradients(log_joint_copy, s_stat_placeholders)

    # Remove any dependencies on those old placeholders.
    nat_params = [copy(nat_param, swap_back, scope=scope_name + 'swapback')
                  for nat_param in nat_params]
    nat_params = [nat_params[i] for i in order]

    return dist_constructor(name='cond_dist',
                            **constructor_params(*nat_params))
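# Why gradients w.r.t. the placeholders yield natural parameters: if the
# conditional is in an exponential family, then the log joint, viewed as a
# function of rv's sufficient statistics t_i(x) with everything else held
# fixed, is linear:
#
#   log p(x | .) = sum_i eta_i * t_i(x) + const.
#
# After substituting a placeholder for each t_i(x), the dependence of the
# copied log joint on that placeholder is linear by construction, so
# d(log_joint)/d(t_i) is exactly the natural parameter eta_i.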
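# A minimal usage sketch, assuming a conjugate Beta-Bernoulli model and that
# Beta and Bernoulli are importable as below (the import path, the `probs`
# keyword, and the `sample_shape` argument follow Edward's 1.x API and are
# assumptions here).
from edward.models import Bernoulli, Beta

pi = Beta(1.0, 1.0)
x = Bernoulli(probs=pi, sample_shape=10)

# Passing an explicit cond_set, per the Notes above. The Beta-Bernoulli pair
# is conjugate, so p(pi | x) is recognized as a Beta random variable whose
# parameters are TensorFlow nodes that depend on x.
pi_cond = complete_conditional(pi, cond_set={pi, x})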