def _log_prob(self, value, reduce_over_shards=True, **kwargs):
    """Evaluate the parent log-prob, optionally reducing it across shards.

    Args:
      value: The value forwarded to the parent class's `_log_prob`.
      reduce_over_shards: When truthy (the default), the parent's result is
        passed through `distribute_lib.psum` over `self.shard_axis_name`.
        When falsy, the parent's result is returned unchanged.
      **kwargs: Extra keyword arguments forwarded to the parent `_log_prob`.

    Returns:
      The parent's log-prob, either as-is or after the `psum` reduction.
    """
    shard_local_log_prob = super(ShardedIndependent, self)._log_prob(
        value, **kwargs)
    if not reduce_over_shards:
        return shard_local_log_prob
    return distribute_lib.psum(
        shard_local_log_prob, axis_name=self.shard_axis_name)
def _sharded_independent_log_prob_ratio(p, x, q, y, reduce_over_shards=True):
    """Log-prob ratio of two ShardedIndependent distributions.

    Args:
      p: First distribution; must expose `shard_axis_name`.
      x: Value paired with `p`.
      q: Second distribution; must expose `shard_axis_name`.
      y: Value paired with `q`.
      reduce_over_shards: When truthy (the default), the underlying ratio is
        passed through `distribute_lib.psum` over `p.shard_axis_name`.

    Returns:
      The result of `independent_lib._independent_log_prob_ratio(p, x, q, y)`,
      either unchanged or after the `psum` reduction.

    Raises:
      ValueError: If `p` and `q` report different shard axis names.
    """
    p_axis = p.shard_axis_name
    q_axis = q.shard_axis_name
    # Both sides must agree on the axis before anything is computed.
    if p_axis != q_axis:
        raise ValueError(f'Mismatched axis names "{p_axis}" vs "{q_axis}"')

    local_ratio = independent_lib._independent_log_prob_ratio(  # pylint: disable=protected-access
        p, x, q, y)
    if not reduce_over_shards:
        return local_ratio
    return distribute_lib.psum(local_ratio, axis_name=p.shard_axis_name)
def _sharded_sample_log_prob_ratio(p, x, q, y, name=None,
                                   reduce_over_shards=True):
    """Log-prob ratio of two ShardedSample distributions.

    Args:
      p: First distribution; must expose `shard_axis_name`.
      x: Value paired with `p`.
      q: Second distribution; must expose `shard_axis_name`.
      y: Value paired with `q`.
      name: Optional name for the enclosing `tf.name_scope`; falls back to
        'sharded_sample_log_prob_ratio' when falsy.
      reduce_over_shards: When truthy (the default), the underlying ratio is
        passed through `distribute_lib.psum` over `p.shard_axis_name`.

    Returns:
      The result of `sample_lib._sample_log_prob_ratio(p, x, q, y)`, either
      unchanged or after the `psum` reduction.

    Raises:
      ValueError: If `p` and `q` report different shard axis names.
    """
    scope_name = name or 'sharded_sample_log_prob_ratio'
    with tf.name_scope(scope_name):
        # The axis check stays inside the scope, as in the original ordering.
        if p.shard_axis_name != q.shard_axis_name:
            raise ValueError(
                f'Mismatched axis names "{p.shard_axis_name}" vs '
                f'"{q.shard_axis_name}"')

        local_ratio = sample_lib._sample_log_prob_ratio(  # pylint: disable=protected-access
            p, x, q, y)
        if not reduce_over_shards:
            return local_ratio
        return distribute_lib.psum(local_ratio, axis_name=p.shard_axis_name)