Beispiel #1
0
 def _log_prob(self, value, reduce_over_shards=True, **kwargs):
     """Log-probability of `value`, optionally reduced across shards.

     Args:
       value: Point at which to evaluate the log-probability; forwarded to the
         parent class's `_log_prob` unchanged.
       reduce_over_shards: Python `bool`. When `True` (the default), the
         local result is reduced with `distribute_lib.psum` over
         `self.shard_axis_name`. When `False`, the local result is returned
         as-is.
       **kwargs: Extra keyword arguments forwarded to the parent `_log_prob`.

     Returns:
       The parent class's log-probability, passed through `distribute_lib.psum`
       over `self.shard_axis_name` when `reduce_over_shards` is `True`.
     """
     local_log_prob = super(ShardedIndependent, self)._log_prob(
         value, **kwargs)
     if not reduce_over_shards:
         # Caller wants the local value only; skip the cross-shard reduction.
         return local_log_prob
     return distribute_lib.psum(
         local_log_prob, axis_name=self.shard_axis_name)
Beispiel #2
0
def _sharded_independent_log_prob_ratio(p, x, q, y, reduce_over_shards=True):
    """Distributed log-prob ratio for ShardedIndependent.

    Args:
      p: First distribution; must expose `shard_axis_name`.
      x: Value paired with `p`; forwarded unchanged to
        `independent_lib._independent_log_prob_ratio`.
      q: Second distribution; its `shard_axis_name` must equal `p`'s.
      y: Value paired with `q`; forwarded unchanged to
        `independent_lib._independent_log_prob_ratio`.
      reduce_over_shards: Python `bool`. If `True` (the default), the ratio
        is reduced with `distribute_lib.psum` over the shared axis name.

    Returns:
      The result of `independent_lib._independent_log_prob_ratio`, passed
      through `distribute_lib.psum` when `reduce_over_shards` is `True`.

    Raises:
      ValueError: If `p.shard_axis_name` differs from `q.shard_axis_name`.
    """
    axis_name = p.shard_axis_name
    if axis_name != q.shard_axis_name:
        raise ValueError(
            f'Mismatched axis names "{axis_name}" vs "{q.shard_axis_name}"')
    # Ratio from the non-sharded `independent_lib` helper.
    local_ratio = independent_lib._independent_log_prob_ratio(  # pylint: disable=protected-access
        p, x, q, y)
    if not reduce_over_shards:
        return local_ratio
    return distribute_lib.psum(local_ratio, axis_name=axis_name)
Beispiel #3
0
def _sharded_sample_log_prob_ratio(p, x, q, y, name=None,
                                   reduce_over_shards=True):
    """Distributed log-prob ratio for ShardedSample.

    Args:
      p: First distribution; must expose `shard_axis_name`.
      x: Value paired with `p`; forwarded unchanged to
        `sample_lib._sample_log_prob_ratio`.
      q: Second distribution; its `shard_axis_name` must equal `p`'s.
      y: Value paired with `q`; forwarded unchanged to
        `sample_lib._sample_log_prob_ratio`.
      name: Optional name scope; defaults to
        `'sharded_sample_log_prob_ratio'`.
      reduce_over_shards: Python `bool`. If `True` (the default), the ratio
        is reduced with `distribute_lib.psum` over the shared axis name.

    Returns:
      The result of `sample_lib._sample_log_prob_ratio`, passed through
      `distribute_lib.psum` when `reduce_over_shards` is `True`.

    Raises:
      ValueError: If `p.shard_axis_name` differs from `q.shard_axis_name`.
    """
    scope_name = name or 'sharded_sample_log_prob_ratio'
    with tf.name_scope(scope_name):
        axis_name = p.shard_axis_name
        if axis_name != q.shard_axis_name:
            raise ValueError(
                f'Mismatched axis names "{axis_name}" vs "{q.shard_axis_name}"')
        # Ratio from the non-sharded `sample_lib` helper.
        ratio = sample_lib._sample_log_prob_ratio(p, x, q, y)  # pylint: disable=protected-access
        if reduce_over_shards:
            # Keep the reduction inside the name scope, as before.
            ratio = distribute_lib.psum(ratio, axis_name=axis_name)
    return ratio