def _dist_jd_log_prob_ratio(p, x, q, y, name=None):
  """Computes the distributed log-prob ratio of two sharded JDs.

  Args:
    p: Joint distribution whose component distributions are evaluated at `x`.
    x: Nested value; must have the same structure as `y`.
    q: Joint distribution whose component distributions are evaluated at `y`.
    y: Nested value; must have the same structure as `x`.
    name: Optional name scope; defaults to 'dist_jd_log_prob_ratio'.

  Returns:
    `tf.add_n` over the flattened per-part ratios produced by the function
    returned from `distribute_lib.make_sharded_log_prob_parts`.

  Raises:
    ValueError: If `p.experimental_shard_axis_names` differs from
      `q.experimental_shard_axis_names`.
  """
  with tf.name_scope(name or 'dist_jd_log_prob_ratio'):
    tf.nest.assert_same_structure(x, y)

    p_axis_names = p.experimental_shard_axis_names
    q_axis_names = q.experimental_shard_axis_names
    if p_axis_names != q_axis_names:
      raise ValueError('p and q must use the same sharding. '
                       f'Saw: p: {p}, {p_axis_names}, q: {q}, {q_axis_names}')

    def part_ratio(p_dist, x_part, q_dist, y_part, part_axis_names):
      # Ensure sharded distributions defer reductions: only parts that carry
      # shard axis names receive the `reduce_over_shards=False` keyword.
      extra = {'reduce_over_shards': False} if part_axis_names else {}
      return lp_ratio.log_prob_ratio(p_dist, x_part, q_dist, y_part, **extra)

    def log_prob_ratio_parts_fn(stacked_parts):
      # Undo the stacking done below: index 0 holds the `x` part and index 1
      # holds the `y` part.
      x_parts = tf.nest.map_structure(lambda t: t[0], stacked_parts)
      y_parts = tf.nest.map_structure(lambda t: t[1], stacked_parts)
      p_dists = p.sample_distributions(
          value=x_parts, seed=samplers.zeros_seed())[0]
      q_dists = q.sample_distributions(
          value=y_parts, seed=samplers.zeros_seed())[0]
      return nest.map_structure_up_to(
          p_dists, part_ratio, p_dists, x_parts, q_dists, y_parts,
          p_axis_names)

    sharded_parts_fn = distribute_lib.make_sharded_log_prob_parts(
        log_prob_ratio_parts_fn, p_axis_names)
    # Stack, because make_sharded_log_prob_parts expects inputs/outputs to be
    # 1 to 1. TODO(b/175084455): revisit this after the distributed bijectors
    # are done, as it is likely that make_sharded_log_prob_parts will be
    # adjusted then to not have this limitation.
    stacked = tf.nest.map_structure(
        lambda x_part, y_part: tf.stack([x_part, y_part], axis=0), x, y)
    return tf.add_n(tf.nest.flatten(sharded_parts_fn(stacked)))
def _dist_jd_log_prob_ratio(p, x, q, y):
  """Computes the distributed log-prob ratio of two sharded JDs.

  Args:
    p: Joint distribution whose component distributions are evaluated at `x`.
    x: Nested value; must have the same structure as `y`.
    q: Joint distribution whose component distributions are evaluated at `y`.
    y: Nested value; must have the same structure as `x`.

  Returns:
    `tf.add_n` over the flattened per-part ratios produced by the function
    returned from `distribute_lib.make_sharded_log_prob_parts`.

  Raises:
    ValueError: If `p.shard_axis_name` differs from `q.shard_axis_name`.
  """
  tf.nest.assert_same_structure(x, y)
  if p.shard_axis_name != q.shard_axis_name:
    raise ValueError(
        'p and q must have the same shard_axis_name. '
        f'Saw: p: {p}, {p.shard_axis_name}, q: {q}, {q.shard_axis_name}')

  def log_prob_ratio_parts_fn(stacked_parts):
    # Undo the stacking done below: index 0 holds the `x` part and index 1
    # holds the `y` part.
    x_parts = tf.nest.map_structure(lambda t: t[0], stacked_parts)
    y_parts = tf.nest.map_structure(lambda t: t[1], stacked_parts)
    p_dists = p.sample_distributions(
        value=x_parts, seed=jd_lib.dummy_seed())[0]
    q_dists = q.sample_distributions(
        value=y_parts, seed=jd_lib.dummy_seed())[0]
    return tf.nest.map_structure(
        log_prob_ratio.log_prob_ratio, p_dists, x_parts, q_dists, y_parts)

  sharded_parts_fn = distribute_lib.make_sharded_log_prob_parts(
      log_prob_ratio_parts_fn,
      p.get_sharded_distributions(),
      axis_name=p.shard_axis_name)
  # Stack, because make_sharded_log_prob_parts expects inputs/outputs to be
  # 1 to 1. TODO(b/175084455): revisit this after the distributed bijectors
  # are done, as it is likely that make_sharded_log_prob_parts will be
  # adjusted then to not have this limitation.
  stacked = tf.nest.map_structure(
      lambda x_part, y_part: tf.stack([x_part, y_part], axis=0), x, y)
  return tf.add_n(tf.nest.flatten(sharded_parts_fn(stacked)))
def log_prob(*value):
  """Sums the sharded log-prob parts keyed by 'w', 'x' and 'data'.

  `log_prob_parts`, `self` and `data` are free variables resolved from the
  surrounding scope.

  Args:
    *value: Exactly two positional values, unpacked as `(w, x)`.

  Returns:
    `tf.add_n` over the flattened parts.
  """
  w, x = value
  # 'w' is flagged False; 'x' and 'data' are flagged True.
  sharded_flags = dict(w=False, x=True, data=True)
  parts_fn = distribute_lib.make_sharded_log_prob_parts(
      log_prob_parts, sharded_flags, axis_name=self.axis_name)
  inputs = dict(w=w, x=x, data=data)
  return tf.add_n(tf.nest.flatten(parts_fn(inputs)))
def _map_measure_over_dists(self, attr, value):
  """Maps the measure `attr` over component distributions, sharding log_prob.

  Args:
    attr: Name of the per-distribution method to call (e.g. 'log_prob').
    value: Nested value; no flattened part may be `None`.

  Returns:
    An iterator over per-distribution results. For `attr == 'log_prob'` with
    any truthy entry in `self.experimental_shard_axis_names`, the results come
    from the `distribute_lib.make_sharded_log_prob_parts` wrapper; otherwise a
    lazy generator calling `getattr(d, attr)(x)` for each distribution.

  Raises:
    ValueError: If any flattened part of `value` is `None`.
  """
  if any(part is None for part in tf.nest.flatten(value)):
    raise ValueError('No `value` part can be `None`; saw: {}.'.format(value))

  # Unsharded path: every measure other than log_prob, or no shard axis names.
  if attr != 'log_prob' or not any(self.experimental_shard_axis_names):
    dists, parts = self._call_flat_sample_distributions(
        value=value, seed=samplers.zeros_seed())
    return (getattr(dist, attr)(part) for dist, part in zip(dists, parts))

  def local_log_prob_fn(dist, axis_name):
    # For sharded distributions, we need to make sure not to do an
    # all-reduce.
    if axis_name:
      return functools.partial(dist.log_prob, reduce_over_shards=False)
    return dist.log_prob

  def inner_log_prob_parts(flat_value):
    dists, parts = self._call_flat_sample_distributions(
        value=self._model_unflatten(flat_value), seed=samplers.zeros_seed())
    part_axis_names = self._model_flatten(self.experimental_shard_axis_names)
    log_prob_fns = [
        local_log_prob_fn(dist, axis_name)
        for dist, axis_name in zip(dists, part_axis_names)
    ]
    log_probs = [fn(part) for fn, part in zip(log_prob_fns, parts)]
    # Round-trip through the model structure so the output is flattened the
    # same way as the flat axis names handed to the wrapper below.
    return self._model_flatten(self._model_unflatten(log_probs))

  flat_value = self._model_flatten(value)
  flat_axis_names = self._model_flatten(self.experimental_shard_axis_names)
  sharded_parts_fn = distribute_lib.make_sharded_log_prob_parts(
      inner_log_prob_parts, flat_axis_names)
  return iter(sharded_parts_fn(flat_value))
def _map_measure_over_dists(self, attr, value):
  """Maps the measure `attr` over component distributions, sharding log_prob.

  Args:
    attr: Name of the per-distribution method to call (e.g. 'log_prob').
    value: Nested value; no flattened part may be `None`.

  Returns:
    An iterator over per-distribution results. For `attr == 'log_prob'` with
    any truthy entry in `self.get_sharded_distributions()`, the results come
    from the `distribute_lib.make_sharded_log_prob_parts` wrapper; otherwise a
    lazy generator calling `getattr(d, attr)(x)` for each distribution.

  Raises:
    ValueError: If any flattened part of `value` is `None`.
  """
  if any(part is None for part in tf.nest.flatten(value)):
    raise ValueError(
        'No `value` part can be `None`; saw: {}.'.format(value))

  # Unsharded path: every measure other than log_prob, or nothing is sharded.
  if attr != 'log_prob' or not any(self.get_sharded_distributions()):
    dists, parts = self._call_flat_sample_distributions(value=value, seed=42)
    return (getattr(dist, attr)(part) for dist, part in zip(dists, parts))

  def inner_log_prob_parts(flat_value):
    dists, parts = self._call_flat_sample_distributions(
        value=self._model_unflatten(flat_value), seed=42)
    per_dist = [getattr(dist, attr)(part) for dist, part in zip(dists, parts)]
    # We need to unflatten and flatten here to ensure the output structure
    # matches `flat_sharded_distributions`.
    return self._model_flatten(self._model_unflatten(per_dist))

  flat_value = self._model_flatten(value)
  flat_sharded_distributions = self._model_flatten(
      self.get_sharded_distributions())
  sharded_parts_fn = distribute_lib.make_sharded_log_prob_parts(
      inner_log_prob_parts, flat_sharded_distributions)
  return iter(sharded_parts_fn(flat_value))
def log_prob(*value):
  """Sums the sharded log-prob parts for `[w, x, data]`.

  `log_prob_parts`, `self` and `data` are free variables resolved from the
  surrounding scope.

  Args:
    *value: Exactly two positional values, unpacked as `(w, x)`.

  Returns:
    `tf.add_n` over the list of parts.
  """
  w, x = value
  # Flags line up positionally with [w, x, data].
  sharded_flags = [False, True, True]
  parts_fn = distribute_lib.make_sharded_log_prob_parts(
      log_prob_parts, sharded_flags, axis_name=self.axis_name)
  return tf.add_n(parts_fn([w, x, data]))
def _log_prob(self, x, reduce_over_shards=True, **kwargs): def log_prob_fn(value): return self.distribution.log_prob(value, **kwargs) if reduce_over_shards: log_prob_fn = distribute_lib.make_sharded_log_prob_parts( log_prob_fn, is_sharded=True, axis_name=self.shard_axis_name) return log_prob_fn(x)
def log_prob(x, y, z):
  """Sums the sharded log-prob parts for `[x, y, z]`.

  `log_prob_parts`, `self` and `other_axis_name` are free variables resolved
  from the surrounding scope.

  Returns:
    `tf.add_n` over the list of parts.
  """
  # Axis names line up positionally with [x, y, z]; the last entry lists both
  # axis names.
  axis_names = [
      self.axis_name,
      other_axis_name,
      [self.axis_name, other_axis_name],
  ]
  parts_fn = distribute_lib.make_sharded_log_prob_parts(
      log_prob_parts, axis_names)
  return tf.add_n(parts_fn([x, y, z]))
def _log_prob(self, x, reduce_over_shards=True, **kwargs): def log_prob_fn(value): new_kwargs = dict(kwargs) if self.distribution.experimental_shard_axis_names: new_kwargs['reduce_over_shards'] = reduce_over_shards return self.distribution.log_prob(value, **new_kwargs) if reduce_over_shards: log_prob_fn = distribute_lib.make_sharded_log_prob_parts( log_prob_fn, self.experimental_shard_axis_names) return log_prob_fn(x)
def lp_fn(self, x, reduce_over_shards=True, **kwargs):
  """Evaluates the wrapped distribution's `fn_name` method at `x`.

  `fn_name` is a free variable resolved from the surrounding scope.

  Args:
    x: Value forwarded to the wrapped distribution's method.
    reduce_over_shards: If truthy, the callable is first wrapped with
      `distribute_lib.make_sharded_log_prob_parts` using
      `self.experimental_shard_axis_names`; otherwise it is called directly.
      The flag is also forwarded to the wrapped distribution when that
      distribution reports shard axis names of its own.
    **kwargs: Extra keyword arguments forwarded to the wrapped method.

  Returns:
    The result of the (possibly wrapped) callable applied to `x`.
  """

  def local_fn(value):
    forwarded = {**kwargs}
    # Only a wrapped distribution that has shard axis names is handed the
    # `reduce_over_shards` keyword.
    if self.distribution.experimental_shard_axis_names:
      forwarded.update(reduce_over_shards=reduce_over_shards)
    method = getattr(self.distribution, fn_name)
    return method(value, **forwarded)

  if not reduce_over_shards:
    return local_fn(x)
  sharded_fn = distribute_lib.make_sharded_log_prob_parts(
      local_fn, self.experimental_shard_axis_names)
  return sharded_fn(x)
def run(x, data):
  """Returns the sharded log-prob parts for `[x, data]`.

  `self` is a free variable resolved from the surrounding scope.

  Returns:
    The output of the `make_sharded_log_prob_parts` wrapper applied to
    `[x, data]`.
  """

  def log_prob_parts(value):
    latent, observed = value
    # Standard-normal term for the first entry.
    prior_lp = tfd.Normal(0., 1.).log_prob(latent)
    # Summed Normal(latent, 1) term for the second entry.
    likelihood_lp = tf.reduce_sum(tfd.Normal(latent, 1.).log_prob(observed))
    return [prior_lp, likelihood_lp]

  parts_fn = distribute_lib.make_sharded_log_prob_parts(
      log_prob_parts, [False, True], axis_name=self.axis_name)
  return parts_fn([x, data])
def run(w, x, data):
  """Returns the sharded log-prob parts for `[w, x, data]`.

  `self` is a free variable resolved from the surrounding scope.

  Returns:
    The output of the `make_sharded_log_prob_parts` wrapper applied to
    `[w, x, data]`.
  """

  def log_prob_parts(values):
    top, middle, observed = values
    # Standard-normal term for the first entry.
    top_lp = tfd.Normal(0., 1.).log_prob(top)
    # Normal(top, 1) term for the second entry.
    middle_lp = tfd.Normal(top, 1.).log_prob(middle)
    # Summed Normal(middle, 1) term for the third entry.
    observed_lp = tf.reduce_sum(tfd.Normal(middle, 1.).log_prob(observed))
    return [top_lp, middle_lp, observed_lp]

  # Axis names line up positionally with [w, x, data]; `w` has none.
  axis_names = [None, self.axis_name, self.axis_name]
  parts_fn = distribute_lib.make_sharded_log_prob_parts(
      log_prob_parts, axis_names)
  return parts_fn([w, x, data])
def unnormalized_log_prob(self, x, reduce_over_shards=True, **kwargs):
  """Evaluates the wrapped distribution's unnormalized log_prob at `x`.

  Falls back to `self.distribution.log_prob` when the wrapped distribution
  has no `unnormalized_log_prob` attribute.

  Args:
    x: Value forwarded to the wrapped distribution's method.
    reduce_over_shards: If truthy, the callable is first wrapped with
      `distribute_lib.make_sharded_log_prob_parts` using
      `self.experimental_shard_axis_names`; otherwise it is called directly.
      The flag is also forwarded to the wrapped distribution when that
      distribution reports shard axis names of its own.
    **kwargs: Extra keyword arguments forwarded to the wrapped method.

  Returns:
    The result of the (possibly wrapped) callable applied to `x`.
  """

  def local_fn(value):
    forwarded = {**kwargs}
    # Only a wrapped distribution that has shard axis names is handed the
    # `reduce_over_shards` keyword.
    if self.distribution.experimental_shard_axis_names:
      forwarded.update(reduce_over_shards=reduce_over_shards)
    if hasattr(self.distribution, 'unnormalized_log_prob'):
      method_name = 'unnormalized_log_prob'
    else:
      method_name = 'log_prob'
    return getattr(self.distribution, method_name)(value, **forwarded)

  if not reduce_over_shards:
    return local_fn(x)
  sharded_fn = distribute_lib.make_sharded_log_prob_parts(
      local_fn, self.experimental_shard_axis_names)
  return sharded_fn(x)
def test_correct_log_prob_for_global_variable_no_strategy(self):
  """Checks the wrapper's output against directly computed parts.

  The wrapper is built with flags `[False, True]` and `axis_name=None`, then
  evaluated at `x = 0.` with `data = tf.ones(4)`.
  """
  data = tf.ones(4)

  def log_prob_parts(value):
    latent, observed = value
    prior_lp = tfd.Normal(0., 1.).log_prob(latent)
    likelihood_lp = tf.reduce_sum(tfd.Normal(latent, 1.).log_prob(observed))
    return [prior_lp, likelihood_lp]

  parts_fn = distribute_lib.make_sharded_log_prob_parts(
      log_prob_parts, [False, True], axis_name=None)
  actual = self.evaluate(parts_fn([tf.constant(0.), data]))
  expected = self.evaluate([
      tfd.Normal(0., 1.).log_prob(0.),
      tf.reduce_sum(tfd.Normal(0., 1.).log_prob(data)),
  ])
  self.assertAllEqualNested(actual, expected)
def test_correct_log_prob_for_local_variable_no_strategy(self):
  """Checks the wrapper's output against directly computed parts.

  The wrapper is built with flags `[True, True]` and evaluated at
  `x = tf.ones(4)` with `data = tf.ones(4)`.
  """
  data = tf.ones(4)

  def log_prob_parts(value):
    latent, observed = value
    prior_lp = tf.reduce_sum(tfd.Normal(0., 1.).log_prob(latent))
    likelihood_lp = tf.reduce_sum(tfd.Normal(latent, 1.).log_prob(observed))
    return [prior_lp, likelihood_lp]

  parts_fn = distribute_lib.make_sharded_log_prob_parts(
      log_prob_parts, [True, True])
  actual = self.evaluate(parts_fn([tf.ones(4), data]))
  expected = self.evaluate([
      tf.reduce_sum(tfd.Normal(0., 1.).log_prob(tf.ones(4))),
      tf.reduce_sum(tfd.Normal(tf.ones(4), 1.).log_prob(data)),
  ])
  self.assertAllEqualNested(actual, expected)
def _sharded_log_prob_ratio(p, x, q, y, name=None, reduce_over_shards=True):
  """Computes the distributed log-prob ratio for Sharded distributions.

  Args:
    p: Sharded distribution; `p.distribution` is evaluated at `x`.
    x: Value for `p`.
    q: Sharded distribution; `q.distribution` is evaluated at `y`.
    y: Value for `q`.
    name: Optional name scope; defaults to 'sharded_log_prob_ratio'.
    reduce_over_shards: If truthy, the ratio callable is wrapped with
      `distribute_lib.make_sharded_log_prob_parts` and applied to `x` and `y`
      stacked along axis 0; otherwise it is called directly on `[x, y]`.

  Returns:
    The log-prob ratio produced by `log_prob_ratio.log_prob_ratio`, possibly
    routed through the sharded wrapper.

  Raises:
    ValueError: If `p.shard_axis_name` differs from `q.shard_axis_name`.
  """
  with tf.name_scope(name or 'sharded_log_prob_ratio'):
    if p.shard_axis_name != q.shard_axis_name:
      raise ValueError('Mismatched axis names '
                       f'"{p.shard_axis_name}" vs "{q.shard_axis_name}"')

    def log_prob_ratio_fn(x_y):
      # Index 0 holds the `x` value and index 1 the `y` value.
      return log_prob_ratio.log_prob_ratio(
          p.distribution, x_y[0], q.distribution, x_y[1])

    if not reduce_over_shards:
      return log_prob_ratio_fn([x, y])

    # Stack, because make_sharded_log_prob_parts expects inputs/outputs to
    # be 1 to 1. TODO(b/175084455): revisit this after the distributed
    # bijectors are done, as it is likely that make_sharded_log_prob_parts
    # will be adjusted then to not have this limitation.
    sharded_ratio_fn = distribute_lib.make_sharded_log_prob_parts(
        log_prob_ratio_fn, is_sharded=True, axis_name=p.shard_axis_name)
    return sharded_ratio_fn(tf.stack([x, y], axis=0))
def log_prob(x, data):
  """Sums the sharded log-prob parts for `[x, data]`.

  `log_prob_parts`, `self` and `other_axis_name` are free variables resolved
  from the surrounding scope.

  Returns:
    `tf.add_n` over the list of parts.
  """
  # `x` has no axis name; `data` is given the set of both axis names.
  axis_names = [None, {self.axis_name, other_axis_name}]
  parts_fn = distribute_lib.make_sharded_log_prob_parts(
      log_prob_parts, axis_names)
  return tf.add_n(parts_fn([x, data]))
def log_prob(*value):
  """Sums the sharded log-prob parts for `[w, x, data]`.

  `log_prob_parts`, `self` and `data` are free variables resolved from the
  surrounding scope.

  Args:
    *value: Exactly two positional values, unpacked as `(w, x)`.

  Returns:
    `tf.add_n` over the list of parts.
  """
  w, x = value
  # Axis names line up positionally with [w, x, data]; `w` has none.
  axis_names = [None, self.axis_name, self.axis_name]
  parts_fn = distribute_lib.make_sharded_log_prob_parts(
      log_prob_parts, axis_names)
  return tf.add_n(parts_fn([w, x, data]))
def log_prob(x):
  """Sums the sharded log-prob parts for `[x, data]`.

  `log_prob_parts`, `self` and `data` are free variables resolved from the
  surrounding scope.

  Returns:
    `tf.add_n` over the list of parts.
  """
  # Both `x` and `data` are flagged True.
  sharded_flags = [True, True]
  parts_fn = distribute_lib.make_sharded_log_prob_parts(
      log_prob_parts, sharded_flags, axis_name=self.axis_name)
  return tf.add_n(parts_fn([x, data]))
def log_prob(x, y):
  """Sums the sharded log-prob parts for `[x, y]`.

  `log_prob_parts` and `self` are free variables resolved from the
  surrounding scope.

  Returns:
    `tf.add_n` over the list of parts.
  """
  # `x` has no axis name; `y` uses `self.axis_name`.
  axis_names = [None, self.axis_name]
  parts_fn = distribute_lib.make_sharded_log_prob_parts(
      log_prob_parts, axis_names)
  return tf.add_n(parts_fn([x, y]))
def log_prob(x):
  """Sums the sharded log-prob parts for `[x, data]`.

  `log_prob_parts` and `data` are free variables resolved from the
  surrounding scope.

  Returns:
    `tf.add_n` over the list of parts.
  """
  # `x` is flagged False; `data` is flagged True.
  sharded_flags = [False, True]
  parts_fn = distribute_lib.make_sharded_log_prob_parts(
      log_prob_parts, sharded_flags)
  return tf.add_n(parts_fn([x, data]))