def compute_sum_sq(v, shard_axes):
  """Sum of squares of `v` over all non-chain dims, psum-ed across shards.

  Reduces `v**2` over every axis from `independent_chain_ndims` (captured
  from the enclosing scope) to the rank of `v`, then — when `shard_axes` is
  given — aggregates the result across the named shard axes.
  """
  reduction_axes = ps.range(independent_chain_ndims, ps.rank(v))
  squared_total = tf.reduce_sum(tf.square(v), axis=reduction_axes)
  if shard_axes is None:
    return squared_total
  return distribute_lib.psum(squared_total, shard_axes)
def lp_fn(self, x, reduce_over_shards=True, **kwargs):
  """Evaluates the wrapped distribution's `fn_name` log-prob-like method.

  Forwards `reduce_over_shards` to the inner distribution only when it
  declares shard axis names, then (by default) psums the result over this
  object's own shard axis names.
  """
  forwarded_kwargs = {**kwargs}
  # Only pass the flag through if the inner distribution is shard-aware.
  if self.distribution.experimental_shard_axis_names:
    forwarded_kwargs['reduce_over_shards'] = reduce_over_shards
  inner_fn = getattr(self.distribution, fn_name)
  lp = inner_fn(x, **forwarded_kwargs)
  if reduce_over_shards:
    lp = distribute_lib.psum(lp, self.experimental_shard_axis_names)
  return lp
def reduce_sum(shard_axes, x, axis=None):
  """Reduce-sums `x` over `axis`, then psums over `shard_axes` if any."""
  local_sum = tf.reduce_sum(x, axis)
  if shard_axes is None:
    return local_sum
  return distribute_lib.psum(local_sum, shard_axes)
def reduce_sum(x, m, shard_axes):
  """Sums `x` over the trailing (non-log-prob) dims of `m`, sharded-aware.

  The reduction axes run from `log_prob_rank` (captured from the enclosing
  scope) through the rank of `m`; the result is psum-ed over `shard_axes`
  when those are provided.
  """
  trailing_axes = tf.range(log_prob_rank, ps.rank(m))
  total = tf.reduce_sum(x, axis=trailing_axes)
  if shard_axes is None:
    return total
  return distribute_lib.psum(total, shard_axes)
def target_log_prob(a, b):
  """Joint log-prob of `a` (global) and `b` (sharded over axis 'foo')."""
  lp_a = tfd.Normal(0., 1.).log_prob(a)
  # Broadcast `a` across the 'foo' shard axis before using it as the loc,
  # then aggregate the sharded log-prob contributions with a psum.
  a_broadcast = distribute_lib.pbroadcast(a, 'foo')
  lp_b = distribute_lib.psum(
      tfd.Normal(a_broadcast, 1.).log_prob(b), 'foo')
  return lp_a + lp_b
def _inverse_log_det_jacobian(self, y, **kwargs):
  """Inner bijector's ILDJ, psum-ed over this object's shard axis name."""
  ildj = self.bijector.inverse_log_det_jacobian(y, **kwargs)
  return distribute_lib.psum(ildj, named_axis=self.shard_axis_name)
def _forward_log_det_jacobian(self, x, **kwargs):
  """Inner bijector's FLDJ, psum-ed over this object's shard axis name."""
  fldj = self.bijector.forward_log_det_jacobian(x, **kwargs)
  return distribute_lib.psum(fldj, named_axis=self.shard_axis_name)
def reduce_sum(v, axis, shard_axes):
  """Reduce-sums `v` over `axis`; adds a cross-shard psum when sharded."""
  partial = tf.reduce_sum(v, axis=axis)
  if shard_axes is None:
    return partial
  return distribute_lib.psum(partial, shard_axes)
def lp_fn(self, x):
  """Calls the wrapped distribution's `fn_name` method on `x` and psums.

  Unconditionally aggregates the per-shard result over this object's
  `experimental_shard_axis_names`.
  """
  method = getattr(self.distribution, fn_name)
  return distribute_lib.psum(method(x), self.experimental_shard_axis_names)
def _sum_event_part(x, shard_axes=None):
  """Sums `x` over its event dims, then psums across `shard_axes`.

  Event dims are every axis from `batch_ndims` (captured from the
  enclosing scope) through the rank of `x`. Note: psum is applied even
  when `shard_axes` is None; `distribute_lib.psum` is assumed to treat
  that as a no-op — behavior preserved from the original.
  """
  event_axes = ps.range(batch_ndims, ps.rank(x))
  event_total = tf.reduce_sum(x, axis=event_axes)
  return distribute_lib.psum(event_total, shard_axes)