def _jd_log_prob_ratio(p, x, q, y, name=None):
  """Implements `log_prob_ratio` for tfd.JointDistribution*.

  Computes log(p(x) / q(y)) as the sum of per-component ratios. Each component
  distribution of `p` (conditioned on `x`) is paired with the corresponding
  component of `q` (conditioned on `y`).

  Args:
    p: Joint distribution for the numerator.
    x: Value structure at which `p` is evaluated; must have the same
      structure as `y`.
    q: Joint distribution for the denominator.
    y: Value structure at which `q` is evaluated.
    name: Optional name scope. Defaults to 'jd_log_prob_ratio'.

  Returns:
    The sum (via `tf.add_n`) of the per-component log-prob ratios.
  """
  with tf.name_scope(name or 'jd_log_prob_ratio'):
    tf.nest.assert_same_structure(x, y)
    ps, _ = p.sample_distributions(value=x)
    qs, _ = q.sample_distributions(value=y)
    tf.nest.assert_same_structure(ps, qs)
    # Pair components by structure rather than with `zip`:
    # - if the component structure is a dict, `zip` would iterate its keys
    #   instead of its values;
    # - `zip` silently truncates when `x` has a different length than `ps`.
    # Mapping only down to the depth of `ps` hands a structured component
    # value to its distribution intact. `check_types=False` keeps `zip`'s
    # tolerance of list-vs-tuple containers.
    parts = nest.map_structure_up_to(
        ps, log_prob_ratio.log_prob_ratio, ps, x, qs, y, check_types=False)
    return tf.add_n(tf.nest.flatten(parts))
def log_prob_ratio_parts_fn(x_y):
  """Returns per-component log-prob ratios for a paired value structure.

  Relies on `p`, `q` and `is_sharded` from an enclosing scope that is not
  visible here. Presumably `is_sharded` mirrors the component structure of
  `p` — TODO confirm against the enclosing function.

  Args:
    x_y: Structure whose leaves are indexable. Element `[0]` of each leaf is
      taken as the value for `p`, and element `[1]` as the value for `q`.

  Returns:
    A structure shaped like `p`'s component distributions, holding one
    `lp_ratio.log_prob_ratio` result per component.
  """
  x = tf.nest.map_structure(lambda part: part[0], x_y)
  y = tf.nest.map_structure(lambda part: part[1], x_y)
  p_dists = p.sample_distributions(value=x, seed=jd_lib.dummy_seed())[0]
  q_dists = q.sample_distributions(value=y, seed=jd_lib.dummy_seed())[0]

  def _part_log_prob_ratio(p_part, x_part, q_part, y_part, sharded):
    # Ensure sharded distributions defer reductions.
    kwargs = {'reduce_over_shards': False} if sharded else {}
    return lp_ratio.log_prob_ratio(p_part, x_part, q_part, y_part, **kwargs)

  # Map only down to the depth of `p_dists`. A component whose value is
  # itself a structure (e.g. a nested joint distribution) then receives that
  # whole sub-structure, instead of `tf.nest.map_structure` raising a
  # structure mismatch.
  return nest.map_structure_up_to(
      p_dists, _part_log_prob_ratio, p_dists, x, q_dists, y, is_sharded)
def _jd_log_prob_ratio(p, x, q, y, name=None):
  """Implements `log_prob_ratio` for tfd.JointDistribution*.

  Computes log(p(x) / q(y)) as the sum of per-component ratios. Each component
  distribution of `p` (conditioned on `x`) is paired with the corresponding
  component of `q` (conditioned on `y`).

  Args:
    p: Joint distribution for the numerator.
    x: Value structure at which `p` is evaluated; must have the same
      structure as `y`.
    q: Joint distribution for the denominator.
    y: Value structure at which `q` is evaluated.
    name: Optional name scope. Defaults to 'jd_log_prob_ratio'.

  Returns:
    The sum (via `tf.add_n`) of the per-component log-prob ratios.
  """
  with tf.name_scope(name or 'jd_log_prob_ratio'):
    tf.nest.assert_same_structure(x, y)
    ps, _ = p.sample_distributions(value=x, seed=dummy_seed())
    qs, _ = q.sample_distributions(value=y, seed=dummy_seed())
    tf.nest.assert_same_structure(ps, qs)
    # Pair components by structure rather than with `zip`:
    # - if the component structure is a dict, `zip` would iterate its keys
    #   instead of its values;
    # - `zip` silently truncates when `x` has a different length than `ps`.
    # Mapping only down to the depth of `ps` hands a structured component
    # value to its distribution intact. `check_types=False` keeps `zip`'s
    # tolerance of list-vs-tuple containers.
    parts = nest.map_structure_up_to(
        ps, log_prob_ratio.log_prob_ratio, ps, x, qs, y, check_types=False)
    return tf.add_n(tf.nest.flatten(parts))
def log_prob_ratio_parts_fn(x_y):
  """Computes one log-prob ratio per component of `p` and `q`.

  Relies on `p`, `q` and `p_axis_names` from an enclosing scope that is not
  visible here. `p_axis_names` is walked alongside `p`'s component
  distributions.

  Args:
    x_y: Structure whose leaves are indexable. Element `[0]` of each leaf is
      the value for `p`, and element `[1]` is the value for `q`.

  Returns:
    A structure shaped like `p`'s component distributions, holding the
    `lp_ratio.log_prob_ratio` result for each component.
  """

  def _pick(index):
    return tf.nest.map_structure(lambda leaf: leaf[index], x_y)

  x, y = _pick(0), _pick(1)

  # The values are pinned through `value=`; a constant zeros seed is supplied
  # for the sampling call.
  p_parts = p.sample_distributions(value=x, seed=samplers.zeros_seed())[0]
  q_parts = q.sample_distributions(value=y, seed=samplers.zeros_seed())[0]

  def _ratio(p_part, x_part, q_part, y_part, axis_names):
    # A component with (truthy) axis names is sharded: ask it not to reduce
    # over shards here.
    if axis_names:
      return lp_ratio.log_prob_ratio(
          p_part, x_part, q_part, y_part, reduce_over_shards=False)
    return lp_ratio.log_prob_ratio(p_part, x_part, q_part, y_part)

  return nest.map_structure_up_to(
      p_parts, _ratio, p_parts, x, q_parts, y, p_axis_names)
def _sample_log_prob_ratio(p, x, q, y, name=None):
  """Computes log(p(x) / q(y)) for a pair of `Sample`-style distributions.

  Both values are first prepared for the wrapped `.distribution`. The ratio is
  computed there, and the result is finalized with `p._finish_log_prob` using
  the auxiliary state produced while preparing `x`. Presumably this undoes the
  sample-dimension handling; `_prepare_for_underlying` and `_finish_log_prob`
  are not visible here — TODO confirm.

  Args:
    p: Numerator distribution (exposes `sample_shape` and `distribution`).
    x: Value at which `p` is evaluated.
    q: Denominator distribution with the same interface as `p`.
    y: Value at which `q` is evaluated.
    name: Optional name scope. Defaults to 'sample_log_prob_ratio'.

  Returns:
    The log-prob ratio as returned by `p._finish_log_prob`.
  """
  with tf.name_scope(name or 'sample_log_prob_ratio'):
    checks = []
    if p.validate_args or q.validate_args:
      # `assert_equal` broadcasts its inputs, so e.g. sample shapes [3] and
      # [3, 3] (or [] and [5]) would compare equal elementwise. Compare the
      # number of sample dimensions too, so rank mismatches are caught.
      checks.append(tf.debugging.assert_equal(
          tf.size(p.sample_shape), tf.size(q.sample_shape)))
      checks.append(tf.debugging.assert_equal(p.sample_shape, q.sample_shape))
    with tf.control_dependencies(checks):
      # pylint: disable=protected-access
      x, aux = p._prepare_for_underlying(x)
      # `q`'s auxiliary state is discarded. Only `p`'s is used to finish,
      # which relies on the two sample shapes matching (checked above when
      # validation is enabled).
      y, _ = q._prepare_for_underlying(y)
      return p._finish_log_prob(
          log_prob_ratio.log_prob_ratio(p.distribution, x, q.distribution, y),
          aux)
def _independent_log_prob_ratio(p, x, q, y, name=None):
  """Computes log(p(x) / q(y)) for two `Independent` distributions.

  The underlying distributions produce the ratio. It is then summed over the
  rightmost `p.reinterpreted_batch_ndims` axes. Kahan summation is used when
  either `p` or `q` has `_experimental_use_kahan_sum` set.

  Args:
    p: `Independent` distribution for the numerator.
    x: Value at which `p` is evaluated.
    q: `Independent` distribution for the denominator.
    y: Value at which `q` is evaluated.
    name: Optional name scope. Defaults to 'independent_log_prob_ratio'.

  Returns:
    The log-prob ratio reduced over the reinterpreted batch axes.
  """
  with tf.name_scope(name or 'independent_log_prob_ratio'):
    if p.validate_args or q.validate_args:
      assertions = [
          tf.debugging.assert_equal(
              p.reinterpreted_batch_ndims, q.reinterpreted_batch_ndims)
      ]
    else:
      assertions = []

    # pylint: disable=protected-access
    kahan = p._experimental_use_kahan_sum or q._experimental_use_kahan_sum
    # pylint: enable=protected-access

    def _reduce(values, axis):
      if kahan:
        return tfp_math.reduce_kahan_sum(values, axis).total
      return tf.reduce_sum(values, axis=axis)

    with tf.control_dependencies(assertions):
      elementwise = log_prob_ratio.log_prob_ratio(
          p.distribution, x, q.distribution, y)
      # Axes [-1, -2, ..., -n] for n = p.reinterpreted_batch_ndims.
      trailing_axes = -1 - ps.range(p.reinterpreted_batch_ndims)
      return _reduce(elementwise, axis=trailing_axes)