def log_prob(x, y, z):
  sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
      log_prob_parts,
      [self.axis_name, other_axis_name, [self.axis_name, other_axis_name]])
  parts = sharded_log_prob_parts([x, y, z])
  return tf.add_n(parts)

def lp_grad(x):
  untransformed_log_prob = distribute_lib.make_sharded_log_prob_parts(
      log_prob, self.axis_name)
  transformed_log_prob = transform_log_prob(
      untransformed_log_prob,
      sharded.Sharded(tfb.Sigmoid(), shard_axis_name=self.axis_name))
  lp, g = tfp.math.value_and_gradient(transformed_log_prob, (x,))
  return lp, g

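# `transform_log_prob` above is a test-local helper that is not shown here.
# A plausible sketch of what it computes (hypothetical, not the library's
# definition): pull the log prob back through the bijector with the usual
# change-of-variables correction, so the gradient in `lp_grad` is taken in
# the unconstrained space.
def transform_log_prob_sketch(lp_fn, bij):

  def transformed(x):
    # Hypothetical: map unconstrained x through `bij`, then correct with the
    # forward log-det-Jacobian (scalar event assumed).
    return lp_fn(bij.forward(x)) + bij.forward_log_det_jacobian(
        x, event_ndims=0)

  return transformed
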
def log_prob(*value):
  w, x = value
  sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
      log_prob_parts, {
          'w': None,
          'x': self.axis_name,
          'data': self.axis_name,
      })
  parts = sharded_log_prob_parts({'w': w, 'x': x, 'data': data})
  return tf.add_n(tf.nest.flatten(parts))

def lp_fn(self, x, reduce_over_shards=True, **kwargs):
  # `fn_name` is closed over from the enclosing scope (e.g. 'log_prob').

  def impl(value):
    new_kwargs = dict(kwargs)
    if self.distribution.experimental_shard_axis_names:
      new_kwargs['reduce_over_shards'] = reduce_over_shards
    return getattr(self.distribution, fn_name)(value, **new_kwargs)

  if reduce_over_shards:
    impl = distribute_lib.make_sharded_log_prob_parts(
        impl, self.experimental_shard_axis_names)
  return impl(x)

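# Usage note for `lp_fn` above: with `reduce_over_shards=True` (the default),
# the wrapped `impl` is run through `make_sharded_log_prob_parts`, so the
# per-shard log probs are summed across `experimental_shard_axis_names`; with
# `reduce_over_shards=False`, callers (such as the JointDistribution override
# further below) get the local, unreduced term. A hypothetical call pattern,
# with `dist` standing in for a Sharded distribution instance:
#
#   lp = dist.log_prob(x)                                   # all-reduced
#   lp_local = dist.log_prob(x, reduce_over_shards=False)   # per-shard
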
def run(x, data):

  def log_prob_parts(value):
    x, data = value
    return [
        tfd.Normal(0., 1.).log_prob(x),
        tf.reduce_sum(tfd.Normal(x, 1.).log_prob(data)),
    ]

  sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
      log_prob_parts, [self.axis_name, self.axis_name])
  return sharded_log_prob_parts([x, data])

def test_correct_log_prob_for_global_variable_no_strategy(self):
  data = tf.ones(4)

  def log_prob_parts(value):
    x, data = value
    return [
        tfd.Normal(0., 1.).log_prob(x),
        tf.reduce_sum(tfd.Normal(x, 1.).log_prob(data)),
    ]

  # `None` marks the first part as unsharded; the truthy second entry marks
  # `data` as sharded. With no distribution strategy in effect, the sharded
  # result should match the plain per-part log probs.
  sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
      log_prob_parts, [None, True])
  self.assertAllEqualNested(
      self.evaluate(sharded_log_prob_parts([tf.constant(0.), data])),
      self.evaluate([
          tfd.Normal(0., 1.).log_prob(0.),
          tf.reduce_sum(tfd.Normal(0., 1.).log_prob(data)),
      ]))

def _sharded_log_prob_ratio(p, x, q, y, name=None, reduce_over_shards=True):
  """Distributed log-prob ratio for Sharded."""
  with tf.name_scope(name or 'sharded_log_prob_ratio'):
    if p.experimental_shard_axis_names != q.experimental_shard_axis_names:
      raise ValueError('Mismatched axis names '
                       f'"{p.experimental_shard_axis_names}" vs '
                       f'"{q.experimental_shard_axis_names}"')

    def log_prob_ratio_fn(x_y):
      return log_prob_ratio.log_prob_ratio(p.distribution, x_y[0],
                                           q.distribution, x_y[1])

    if reduce_over_shards:
      # Stack, because make_sharded_log_prob_parts expects inputs/outputs to
      # be 1 to 1. TODO(b/175084455): revisit this after the distributed
      # bijectors are done, as it is likely that make_sharded_log_prob_parts
      # will be adjusted then to not have this limitation.
      return distribute_lib.make_sharded_log_prob_parts(
          log_prob_ratio_fn,
          p.experimental_shard_axis_names)(tf.stack([x, y], axis=0))
    return log_prob_ratio_fn([x, y])

def _map_measure_over_dists(self, attr, value):
  """Override the default implementation to shard its log_prob calculation."""
  if any(x is None for x in tf.nest.flatten(value)):
    raise ValueError('No `value` part can be `None`; saw: {}.'.format(value))

  if (attr in ('log_prob', 'unnormalized_log_prob')) and any(
      self.experimental_shard_axis_names):

    def inner_log_prob_parts(flat_value):
      unflat_value = self._model_unflatten(flat_value)
      ds, xs = self._call_flat_sample_distributions(
          value=unflat_value, seed=samplers.zeros_seed())
      # For sharded distributions, we need to make sure not to do an
      # all-reduce.
      axis_names = self._model_flatten(self.experimental_shard_axis_names)
      log_prob_fns = [
          functools.partial(getattr(d, attr), reduce_over_shards=False)
          if axis_name else getattr(d, attr)
          for d, axis_name in zip(ds, axis_names)
      ]
      # We need to flatten and unflatten here to ensure the output structure
      # matches `flat_sharded_distributions`.
      vals = self._model_unflatten(
          [log_prob_fn(x) for log_prob_fn, x in zip(log_prob_fns, xs)])
      return self._model_flatten(vals)

    flat_value = self._model_flatten(value)
    flat_axis_names = self._model_flatten(self.experimental_shard_axis_names)
    flat_xs = distribute_lib.make_sharded_log_prob_parts(
        inner_log_prob_parts, flat_axis_names)(flat_value)
    return iter(flat_xs)

  ds, xs = self._call_flat_sample_distributions(
      value=value, seed=samplers.zeros_seed())
  return (getattr(d, attr)(x) for d, x in zip(ds, xs))

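# Note on the override above: the inner per-distribution calls pass
# `reduce_over_shards=False` so that each sharded component returns only its
# local term; the single outer `make_sharded_log_prob_parts` call then
# performs the cross-shard reduction exactly once, avoiding a double
# all-reduce.
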
def _dist_jd_log_prob_ratio(p, x, q, y, name=None):
  """Distributed log-prob ratio for JDs."""
  with tf.name_scope(name or 'dist_jd_log_prob_ratio'):
    tf.nest.assert_same_structure(x, y)
    p_axis_names = p.experimental_shard_axis_names
    q_axis_names = q.experimental_shard_axis_names
    if p_axis_names != q_axis_names:
      raise ValueError('p and q must use the same sharding. '
                       f'Saw: p: {p}, {p_axis_names}, q: {q}, {q_axis_names}')

    def log_prob_ratio_parts_fn(x_y):
      x = tf.nest.map_structure(lambda part: part[0], x_y)
      y = tf.nest.map_structure(lambda part: part[1], x_y)
      p_dists = p.sample_distributions(value=x, seed=samplers.zeros_seed())[0]
      q_dists = q.sample_distributions(value=y, seed=samplers.zeros_seed())[0]
      # Ensure sharded distributions defer reductions.
      kwds = lambda a: {'reduce_over_shards': False} if a else {}
      return nest.map_structure_up_to(
          p_dists,
          lambda p, x, q, y, s: lp_ratio.log_prob_ratio(p, x, q, y, **kwds(s)),
          p_dists, x, q_dists, y, p_axis_names)

    return tf.add_n(
        tf.nest.flatten(
            distribute_lib.make_sharded_log_prob_parts(
                # Stack, because make_sharded_log_prob_parts expects
                # inputs/outputs to be 1 to 1. TODO(b/175084455): revisit
                # this after the distributed bijectors are done, as it is
                # likely that make_sharded_log_prob_parts will be adjusted
                # then to not have this limitation.
                log_prob_ratio_parts_fn,
                p_axis_names)(tf.nest.map_structure(
                    lambda x, y: tf.stack([x, y], axis=0), x, y))))

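# For reference, `log_prob_ratio(p, x, q, y)` computes
# `p.log_prob(x) - q.log_prob(y)`, possibly more accurately than the naive
# subtraction. A minimal check of that identity (assuming
# `tfp.experimental.distributions.log_prob_ratio` is the public entry point
# for the `lp_ratio` module used above):
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
p, q = tfd.Normal(0., 1.), tfd.Normal(0.1, 1.)
x, y = tf.constant(0.5), tf.constant(0.6)
ratio = tfp.experimental.distributions.log_prob_ratio(p, x, q, y)
# `ratio` should be close to `p.log_prob(x) - q.log_prob(y)`.
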
def log_prob(x, data):
  sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
      log_prob_parts, [None, {self.axis_name, other_axis_name}])
  parts = sharded_log_prob_parts([x, data])
  return tf.add_n(parts)

def log_prob(x, y):
  sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
      log_prob_parts, [None, self.axis_name])
  parts = sharded_log_prob_parts([x, y])
  return tf.add_n(parts)

def log_prob(*value):
  w, x = value
  sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
      log_prob_parts, [None, self.axis_name, self.axis_name])
  parts = sharded_log_prob_parts([w, x, data])
  return tf.add_n(parts)

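# End-to-end usage sketch of `make_sharded_log_prob_parts`, mirroring the
# snippets above. Assumptions: eager TF with no distribution strategy, and
# that `distribute_lib` is TFP's internal
# `tensorflow_probability.python.internal.distribute_lib` module. Part 0 is
# a global (unsharded) prior, marked `None`; part 1 is a per-shard
# likelihood, marked with a truthy axis entry as in the no-strategy test
# above.
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import distribute_lib

tfd = tfp.distributions


def log_prob_parts(value):
  x, data = value
  return [
      tfd.Normal(0., 1.).log_prob(x),                    # global prior
      tf.reduce_sum(tfd.Normal(x, 1.).log_prob(data)),   # sharded likelihood
  ]


sharded_lp = distribute_lib.make_sharded_log_prob_parts(
    log_prob_parts, [None, True])
parts = sharded_lp([tf.constant(0.), tf.ones(4)])
total_lp = tf.add_n(parts)  # joint log prob: prior + likelihood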