Exemplo n.º 1
0
 def test_blanket_structure(self):
     """Check `get_blanket` on a graph with two stacked colliders.

     Graph under test::

         a -> c <- b
              |
              v
         d -> f <- e
     """
     with self.test_session():
         a = Normal(0.0, 1.0)
         b = Normal(0.0, 1.0)
         c = Normal(a * b, 1.0)
         d = Normal(0.0, 1.0)
         e = Normal(0.0, 1.0)
         f = Normal(c * d * e, 1.0)

         # Each entry pairs a node with the exact set `get_blanket` must
         # return for it; the checks run in the order listed.
         expected_blankets = [
             (a, {b, c}),
             (b, {a, c}),
             (c, {a, b, d, e, f}),
             (d, {c, e, f}),
             (e, {c, d, f}),
             (f, {c, d, e}),
         ]
         for node, neighbours in expected_blankets:
             self.assertEqual(set(get_blanket(node)), neighbours)
Exemplo n.º 2
0
 def test_blanket_structure(self):
   """Verify the blanket `get_blanket` reports for every node of the graph.

   The graph has two colliders, `c` and `f`, with `c` feeding `f`::

       a -> c <- b
            |
            v
       d -> f <- e
   """
   def blanket_of(node):
     # Order of the returned nodes is irrelevant; compare as a set.
     return set(get_blanket(node))

   with self.test_session():
     a = Normal(0.0, 1.0)
     b = Normal(0.0, 1.0)
     c = Normal(a * b, 1.0)
     d = Normal(0.0, 1.0)
     e = Normal(0.0, 1.0)
     f = Normal(c * d * e, 1.0)

     # Top collider: a and b share the child c.
     self.assertEqual(blanket_of(a), {b, c})
     self.assertEqual(blanket_of(b), {a, c})
     # c sits between both colliders.
     self.assertEqual(blanket_of(c), {a, b, d, e, f})
     # Bottom collider: c, d and e share the child f.
     self.assertEqual(blanket_of(d), {c, e, f})
     self.assertEqual(blanket_of(e), {c, d, f})
     self.assertEqual(blanket_of(f), {c, d, e})
Exemplo n.º 3
0
def complete_conditional(rv, cond_set=None):
  """Returns the conditional distribution `RandomVariable` p(`rv` | .).

  This function tries to infer the conditional distribution of `rv`
  given `cond_set`, a set of other `RandomVariable`s in the graph. It
  will only be able to do this if
  a) p(`rv` | `cond_set`) is in a tractable exponential family AND
  b) the truth of assumption (a) is not obscured in the TensorFlow graph.
  In other words, this function will do its best to recognize conjugate
  relationships when they exist, but it may not always be able to do the
  necessary algebra.

  Parameters
  ----------
  rv : RandomVariable
    The `RandomVariable` whose conditional distribution we are interested in.
  cond_set : iterable of RandomVariables, optional
    The set of `RandomVariable`s we want to condition on. Defaults to
    `get_blanket(rv)` (the Markov blanket of `rv`) with every variable
    whose name contains both 'complete_conditional' and 'cond_dist'
    removed. (It makes no difference if `cond_set` does or does not
    include `rv`; `rv` is always added.)

  Returns
  -------
  RandomVariable
    The object built by the distribution constructor registered in
    `_suff_stat_to_dist[rv.support]` for the detected sufficient
    statistics. It is created with `name='cond_dist'` inside a
    'complete_conditional_<rv.name>' name scope.

  Raises
  ------
  NotImplementedError
    If no distribution is registered in `_suff_stat_to_dist[rv.support]`
    for the sufficient statistics that were found.

  Notes
  -----
  The name-based filter described for `cond_set` is applied only when
  `cond_set` is omitted; an explicitly passed `cond_set` is used as
  given. When calling `complete_conditional()` multiple times with an
  explicit `cond_set`, the caller decides whether the `RandomVariable`s
  returned by previous calls are conditioned on.

  Every call adds new ops (placeholders, copies, gradients) to the
  TensorFlow graph, under scopes made unique with the current time.
  """
  if cond_set is None:
    # Default to Markov blanket, excluding conditionals. This is useful if
    # calling complete_conditional many times without passing in cond_set.
    cond_set = get_blanket(rv)
    # Earlier results are recognized purely by name: this function builds its
    # return value with name='cond_dist' inside a 'complete_conditional_...'
    # name scope (see the end of the function).
    # NOTE(review): presumably `i.name` includes the enclosing name scope so
    # that both substrings are present -- confirm.
    cond_set = [i for i in cond_set if not
                ('complete_conditional' in i.name and 'cond_dist' in i.name)]

  # rv itself is always part of the set; the set also drops duplicates.
  cond_set = set([rv] + list(cond_set))
  with tf.name_scope('complete_conditional_%s' % rv.name) as scope:
    # log_joint holds all the information we need to get a conditional.
    log_joint = get_log_joint(cond_set)

    # Pull out the nodes that are nonlinear functions of rv into s_stats.
    stop_nodes = set([i.value() for i in cond_set])
    subgraph = extract_subgraph(log_joint, stop_nodes)
    s_stats = suff_stat_nodes(subgraph, rv.value(), cond_set)
    s_stats = list(set(s_stats))  # drop duplicate nodes

    # Simplify those nodes, and put any new linear terms into multipliers_i.
    # Nodes that simplify to the same statistic (s_stats_i) are grouped under
    # one key; each group entry is a (graph node, multiplier) pair.
    s_stat_exprs = defaultdict(list)
    for s_stat in s_stats:
      expr = symbolic_suff_stat(s_stat, rv.value(), stop_nodes)
      expr = full_simplify(expr)
      multipliers_i, s_stats_i = extract_s_stat_multipliers(expr)
      s_stat_exprs[s_stats_i].append(
          (s_stat, reconstruct_multiplier(multipliers_i)))

    # Sort out the sufficient statistics to identify this conditional's family.
    # Sorting the keys by their string form makes dist_key independent of the
    # dict's iteration order.
    s_stat_keys = list(six.iterkeys(s_stat_exprs))
    order = np.argsort([str(i) for i in s_stat_keys])
    dist_key = tuple((s_stat_keys[i] for i in order))
    dist_constructor, constructor_params = (
        _suff_stat_to_dist[rv.support].get(dist_key, (None, None)))
    if dist_constructor is None:
      raise NotImplementedError('Conditional distribution has sufficient '
                                'statistics %s, but no available '
                                'exponential-family distribution has those '
                                'sufficient statistics.' % str(dist_key))

    # Swap sufficient statistics for placeholders, then take gradients
    # w.r.t. those placeholders to get natural parameters. The original
    # nodes involving the sufficient statistic nodes are swapped for new
    # nodes that depend linearly on the sufficient statistic placeholders.
    s_stat_placeholders = []
    swap_dict = {}  # original node -> stand-in used while differentiating
    swap_back = {}  # stand-in -> original node, used to undo the swap
    for s_stat_expr in six.itervalues(s_stat_exprs):
      # One float32 placeholder per group, shaped like the group's first node.
      s_stat_placeholder = tf.placeholder(tf.float32,
                                          s_stat_expr[0][0].get_shape())
      swap_back[s_stat_placeholder] = tf.cast(rv.value(), tf.float32)
      s_stat_placeholders.append(s_stat_placeholder)
      for s_stat_node, multiplier in s_stat_expr:
        fake_node = s_stat_placeholder * multiplier
        swap_dict[s_stat_node] = fake_node
        swap_back[fake_node] = s_stat_node

    # Values of the conditioned-on variables are also replaced by placeholders
    # in the copy, and mapped back to the originals afterwards.
    for i in cond_set:
      if i != rv:
        val = i.value()
        val_placeholder = tf.placeholder(val.dtype)
        swap_dict[val] = val_placeholder
        swap_back[val_placeholder] = val
        swap_back[val] = val  # prevent random variable nodes from being copied

    scope_name = scope + str(time.time())  # ensure unique scope when copying
    log_joint_copy = copy(log_joint, swap_dict, scope=scope_name + 'swap')
    nat_params = tf.gradients(log_joint_copy, s_stat_placeholders)

    # Remove any dependencies on those old placeholders.
    nat_params = [copy(nat_param, swap_back, scope=scope_name + 'swapback')
                  for nat_param in nat_params]
    # The gradients follow the dict's iteration order; rearrange them into the
    # sorted order that was used to build dist_key.
    nat_params = [nat_params[i] for i in order]

    return dist_constructor(name='cond_dist', **constructor_params(*nat_params))
Exemplo n.º 4
0
def complete_conditional(rv, cond_set=None):
  """Returns the conditional distribution `RandomVariable`
  $p(\\text{rv}\\mid \\cdot)$.

  This function tries to infer the conditional distribution of `rv`
  given `cond_set`, a set of other `RandomVariable`s in the graph. It
  will only be able to do this if

  1. $p(\\text{rv}\\mid \\text{cond\\_set})$ is in a tractable
     exponential family; and
  2. the truth of assumption 1 is not obscured in the TensorFlow graph.

  In other words, this function will do its best to recognize conjugate
  relationships when they exist. But it may not always be able to do the
  necessary algebra.

  Args:
    rv: RandomVariable.
      The random variable whose conditional distribution we are interested in.
    cond_set: iterable of RandomVariable, optional.
      The set of random variables we want to condition on. Default is
      `get_blanket(rv)` (the Markov blanket of `rv`) with every variable
      whose name contains both 'complete_conditional' and 'cond_dist'
      removed. (It makes no difference if `cond_set` does or does not
      include `rv`; `rv` is always added.)

  Returns:
    The object built by the distribution constructor registered in
    `_suff_stat_to_dist[rv.support]` for the detected sufficient
    statistics. It is created with `name='cond_dist'` inside a
    'complete_conditional_<rv.name>' name scope.

  Raises:
    NotImplementedError: If no distribution is registered in
      `_suff_stat_to_dist[rv.support]` for the sufficient statistics
      that were found.

  #### Notes

  The name-based filter described for `cond_set` is applied only when
  `cond_set` is omitted; an explicitly passed `cond_set` is used as
  given. When calling `complete_conditional()` multiple times with an
  explicit `cond_set`, the caller decides whether the `RandomVariable`s
  returned by previous calls are conditioned on.

  Every call adds new ops (placeholders, copies, gradients) to the
  TensorFlow graph, under scopes made unique with the current time.
  """
  if cond_set is None:
    # Default to Markov blanket, excluding conditionals. This is useful if
    # calling complete_conditional many times without passing in cond_set.
    cond_set = get_blanket(rv)
    # Earlier results are recognized purely by name: this function builds its
    # return value with name='cond_dist' inside a 'complete_conditional_...'
    # name scope (see the end of the function).
    # NOTE(review): presumably `i.name` includes the enclosing name scope so
    # that both substrings are present -- confirm.
    cond_set = [i for i in cond_set if not
                ('complete_conditional' in i.name and 'cond_dist' in i.name)]

  # rv itself is always part of the set; the set also drops duplicates.
  cond_set = set([rv] + list(cond_set))
  with tf.name_scope('complete_conditional_%s' % rv.name) as scope:
    # log_joint holds all the information we need to get a conditional.
    log_joint = get_log_joint(cond_set)

    # Pull out the nodes that are nonlinear functions of rv into s_stats.
    stop_nodes = set([i.value() for i in cond_set])
    subgraph = extract_subgraph(log_joint, stop_nodes)
    s_stats = suff_stat_nodes(subgraph, rv.value(), cond_set)
    s_stats = list(set(s_stats))  # drop duplicate nodes

    # Simplify those nodes, and put any new linear terms into multipliers_i.
    # Nodes that simplify to the same statistic (s_stats_i) are grouped under
    # one key; each group entry is a (graph node, multiplier) pair.
    s_stat_exprs = defaultdict(list)
    for s_stat in s_stats:
      expr = symbolic_suff_stat(s_stat, rv.value(), stop_nodes)
      expr = full_simplify(expr)
      multipliers_i, s_stats_i = extract_s_stat_multipliers(expr)
      s_stat_exprs[s_stats_i].append(
          (s_stat, reconstruct_multiplier(multipliers_i)))

    # Sort out the sufficient statistics to identify this conditional's family.
    # Sorting the keys by their string form makes dist_key independent of the
    # dict's iteration order.
    s_stat_keys = list(six.iterkeys(s_stat_exprs))
    order = np.argsort([str(i) for i in s_stat_keys])
    dist_key = tuple((s_stat_keys[i] for i in order))
    dist_constructor, constructor_params = (
        _suff_stat_to_dist[rv.support].get(dist_key, (None, None)))
    if dist_constructor is None:
      raise NotImplementedError('Conditional distribution has sufficient '
                                'statistics %s, but no available '
                                'exponential-family distribution has those '
                                'sufficient statistics.' % str(dist_key))

    # Swap sufficient statistics for placeholders, then take gradients
    # w.r.t. those placeholders to get natural parameters. The original
    # nodes involving the sufficient statistic nodes are swapped for new
    # nodes that depend linearly on the sufficient statistic placeholders.
    s_stat_placeholders = []
    swap_dict = {}  # original node -> stand-in used while differentiating
    swap_back = {}  # stand-in -> original node, used to undo the swap
    for s_stat_expr in six.itervalues(s_stat_exprs):
      # One float32 placeholder per group, shaped like the group's first node.
      s_stat_placeholder = tf.placeholder(tf.float32,
                                          s_stat_expr[0][0].get_shape())
      swap_back[s_stat_placeholder] = tf.cast(rv.value(), tf.float32)
      s_stat_placeholders.append(s_stat_placeholder)
      for s_stat_node, multiplier in s_stat_expr:
        fake_node = s_stat_placeholder * multiplier
        swap_dict[s_stat_node] = fake_node
        swap_back[fake_node] = s_stat_node

    # Values of the conditioned-on variables are also replaced by placeholders
    # in the copy, and mapped back to the originals afterwards.
    for i in cond_set:
      if i != rv:
        val = i.value()
        val_placeholder = tf.placeholder(val.dtype)
        swap_dict[val] = val_placeholder
        swap_back[val_placeholder] = val
        swap_back[val] = val  # prevent random variable nodes from being copied

    scope_name = scope + str(time.time())  # ensure unique scope when copying
    log_joint_copy = copy(log_joint, swap_dict, scope=scope_name + 'swap')
    nat_params = tf.gradients(log_joint_copy, s_stat_placeholders)

    # Remove any dependencies on those old placeholders.
    nat_params = [copy(nat_param, swap_back, scope=scope_name + 'swapback')
                  for nat_param in nat_params]
    # The gradients follow the dict's iteration order; rearrange them into the
    # sorted order that was used to build dist_key.
    nat_params = [nat_params[i] for i in order]

    return dist_constructor(name='cond_dist', **constructor_params(*nat_params))