コード例 #1
0
def _jd_log_prob_ratio(p, x, q, y, name=None):
  """Computes `log p(x) - log q(y)` for a pair of joint distributions.

  One distribution per part is obtained from each joint distribution via
  `sample_distributions(value=...)`. The `log_prob_ratio` of each pair of
  corresponding parts is computed, and the results are summed.

  Args:
    p: Joint distribution evaluated at `x`.
    x: Value for `p`; must have the same nested structure as `y`.
    q: Joint distribution evaluated at `y`.
    y: Value for `q`.
    name: Optional name scope; defaults to 'jd_log_prob_ratio'.

  Returns:
    The sum (`tf.add_n`) of the per-part log-prob ratios.
  """
  import collections.abc  # pylint: disable=g-import-not-at-top

  with tf.name_scope(name or 'jd_log_prob_ratio'):
    tf.nest.assert_same_structure(x, y)
    # NOTE(review): no `seed` is passed here. Some backends may require one
    # even when all values are given; confirm against the callers.
    ps, _ = p.sample_distributions(value=x)
    qs, _ = q.sample_distributions(value=y)
    tf.nest.assert_same_structure(ps, qs)
    if isinstance(ps, collections.abc.Mapping):
      # Iterating a dict yields its keys, not its values, so a plain `zip`
      # would hand key strings to `log_prob_ratio`. Pair the parts by key.
      part_tuples = [(ps[k], x[k], qs[k], y[k]) for k in ps]
    else:
      part_tuples = zip(ps, x, qs, y)
    parts = [log_prob_ratio.log_prob_ratio(p_, x_, q_, y_)
             for p_, x_, q_, y_ in part_tuples]
    return tf.add_n(parts)
コード例 #2
0
 def log_prob_ratio_parts_fn(x_y):
     """Returns per-part log-prob ratios for a structure of `(x, y)` pairs.

     `x_y` is mapped to two structures: element 0 of each leaf gives `x`, and
     element 1 gives `y`. Per-part distributions are then taken from `p` and
     `q` (from the enclosing scope) conditioned on those values.
     `lp_ratio.log_prob_ratio` is applied to every corresponding part.
     """
     x_vals = tf.nest.map_structure(lambda pair: pair[0], x_y)
     y_vals = tf.nest.map_structure(lambda pair: pair[1], x_y)
     p_parts = p.sample_distributions(value=x_vals, seed=jd_lib.dummy_seed())[0]
     q_parts = q.sample_distributions(value=y_vals, seed=jd_lib.dummy_seed())[0]

     def _part_ratio(p_part, x_part, q_part, y_part, sharded):
         # Sharded parts are asked not to reduce over shards here, which
         # defers that reduction to a later step.
         extra = {'reduce_over_shards': False} if sharded else {}
         return lp_ratio.log_prob_ratio(p_part, x_part, q_part, y_part, **extra)

     # NOTE(review): `tf.nest.map_structure` needs all five structures to match
     # exactly. If a part's value is itself nested, a `map_structure_up_to`
     # keyed on `p_parts` may be needed; confirm.
     return tf.nest.map_structure(
         _part_ratio, p_parts, x_vals, q_parts, y_vals, is_sharded)
コード例 #3
0
def _jd_log_prob_ratio(p, x, q, y, name=None):
    """Implements `log_prob_ratio` for tfd.JointDistribution*.

    Computes `log p(x) - log q(y)`. One distribution per part is obtained from
    each joint distribution via `sample_distributions(value=...)`. The
    `log_prob_ratio` of each pair of corresponding parts is taken, and the
    results are summed.

    Args:
        p: Joint distribution evaluated at `x`.
        x: Value for `p`; must have the same nested structure as `y`.
        q: Joint distribution evaluated at `y`.
        y: Value for `q`.
        name: Optional name scope; defaults to 'jd_log_prob_ratio'.

    Returns:
        The sum (`tf.add_n`) of the per-part log-prob ratios.
    """
    import collections.abc  # pylint: disable=g-import-not-at-top

    with tf.name_scope(name or 'jd_log_prob_ratio'):
        tf.nest.assert_same_structure(x, y)
        # Values are supplied via `value=`, so a fixed dummy seed is passed.
        # Presumably this avoids generating a real random seed.
        ps, _ = p.sample_distributions(value=x, seed=dummy_seed())
        qs, _ = q.sample_distributions(value=y, seed=dummy_seed())
        tf.nest.assert_same_structure(ps, qs)
        if isinstance(ps, collections.abc.Mapping):
            # Iterating a dict yields its keys, not its values, so a plain
            # `zip` would hand key strings to `log_prob_ratio`. Pair by key.
            part_tuples = [(ps[k], x[k], qs[k], y[k]) for k in ps]
        else:
            part_tuples = zip(ps, x, qs, y)
        parts = [
            log_prob_ratio.log_prob_ratio(p_, x_, q_, y_)
            for p_, x_, q_, y_ in part_tuples
        ]
        return tf.add_n(parts)
コード例 #4
0
 def log_prob_ratio_parts_fn(x_y):
   """Returns per-part log-prob ratios for a structure of `(x, y)` pairs.

   `x_y` is mapped to two structures: element 0 of each leaf gives `x`, and
   element 1 gives `y`. Per-part distributions are then taken from `p` and `q`
   (from the enclosing scope) conditioned on those values. The ratio is
   computed part by part, mapping only down to the depth of `p`'s part
   structure.
   """
   x_vals = tf.nest.map_structure(lambda pair: pair[0], x_y)
   y_vals = tf.nest.map_structure(lambda pair: pair[1], x_y)
   p_parts = p.sample_distributions(value=x_vals, seed=samplers.zeros_seed())[0]
   q_parts = q.sample_distributions(value=y_vals, seed=samplers.zeros_seed())[0]

   def _part_ratio(p_part, x_part, q_part, y_part, axis_names):
     # A part with truthy axis names is treated as sharded. It is asked not to
     # reduce over shards here, which defers that reduction to a later step.
     extra = {'reduce_over_shards': False} if axis_names else {}
     return lp_ratio.log_prob_ratio(p_part, x_part, q_part, y_part, **extra)

   return nest.map_structure_up_to(
       p_parts, _part_ratio, p_parts, x_vals, q_parts, y_vals, p_axis_names)
コード例 #5
0
def _sample_log_prob_ratio(p, x, q, y, name=None):
  """Implements `log_prob_ratio` for `Sample` distributions.

  Each value is first passed through its distribution's
  `_prepare_for_underlying`. The ratio of the underlying distributions is then
  computed, and the result is post-processed with `p._finish_log_prob` using
  the auxiliary data returned for `x`.

  Args:
    p: `Sample` distribution evaluated at `x`.
    x: Value for `p`.
    q: `Sample` distribution evaluated at `y`.
    y: Value for `q`.
    name: Optional name scope; defaults to 'sample_log_prob_ratio'.

  Returns:
    The log-prob ratio as finished by `p._finish_log_prob`.
  """
  with tf.name_scope(name or 'sample_log_prob_ratio'):
    checks = []
    if p.validate_args or q.validate_args:
      # `assert_equal` broadcasts its arguments. Without a length check,
      # sample shapes [3] vs [3, 3] (or [] vs [2]) would compare equal
      # element-wise, so compare the number of sample dimensions as well.
      checks.append(tf.debugging.assert_equal(
          tf.size(p.sample_shape), tf.size(q.sample_shape)))
      checks.append(tf.debugging.assert_equal(p.sample_shape, q.sample_shape))
    with tf.control_dependencies(checks):
      # pylint: disable=protected-access
      x, aux = p._prepare_for_underlying(x)
      y, _ = q._prepare_for_underlying(y)
      return p._finish_log_prob(
          log_prob_ratio.log_prob_ratio(p.distribution, x, q.distribution, y),
          aux)
コード例 #6
0
def _independent_log_prob_ratio(p, x, q, y, name=None):
  """Computes log(p(x) / q(y)) for two `Independent` distributions.

  The `log_prob_ratio` of the wrapped distributions is computed and then
  summed over the rightmost `p.reinterpreted_batch_ndims` axes. Kahan
  summation is used if either distribution requests it.

  Args:
    p: `Independent` distribution evaluated at `x`.
    x: Value for `p`.
    q: `Independent` distribution evaluated at `y`.
    y: Value for `q`.
    name: Optional name scope; defaults to 'independent_log_prob_ratio'.

  Returns:
    The summed log-prob ratio.
  """
  with tf.name_scope(name or 'independent_log_prob_ratio'):
    assertions = []
    if p.validate_args or q.validate_args:
      assertions.append(
          tf.debugging.assert_equal(p.reinterpreted_batch_ndims,
                                    q.reinterpreted_batch_ndims))
    # pylint: disable=protected-access
    use_kahan = (p._experimental_use_kahan_sum or
                 q._experimental_use_kahan_sum)
    # pylint: enable=protected-access
    if use_kahan:
      def reduce_fn(values, axis):
        # `reduce_kahan_sum` returns a result whose `.total` is the sum.
        return tfp_math.reduce_kahan_sum(values, axis).total
    else:
      reduce_fn = tf.reduce_sum
    with tf.control_dependencies(assertions):
      underlying_ratio = log_prob_ratio.log_prob_ratio(
          p.distribution, x, q.distribution, y)
      # -1 - [0, 1, ..., n-1] == [-1, -2, ..., -n]: the rightmost n axes.
      event_axes = -1 - ps.range(p.reinterpreted_batch_ndims)
      return reduce_fn(underlying_ratio, axis=event_axes)