def log_prob_ratio_parts_fn(x_y):
  """Computes the per-component log-prob ratio of `p` at x vs `q` at y.

  Args:
    x_y: Structure whose parts are `(x_part, y_part)` pairs.

  Returns:
    Structure of per-component log-prob ratios, matching `x_y`'s structure.
  """
  # Split each paired part into its x half and its y half.
  xs = tf.nest.map_structure(lambda pair: pair[0], x_y)
  ys = tf.nest.map_structure(lambda pair: pair[1], x_y)
  # dummy_seed(): a fixed seed so repeated calls are identical (see the
  # "Const seed for maybe CSE" convention used elsewhere in this file).
  p_dists = p.sample_distributions(value=xs, seed=jd_lib.dummy_seed())[0]
  q_dists = q.sample_distributions(value=ys, seed=jd_lib.dummy_seed())[0]
  return tf.nest.map_structure(
      log_prob_ratio.log_prob_ratio, p_dists, xs, q_dists, ys)
def log_prob_ratio_parts_fn(x_y):
  """Computes per-component log-prob ratios, deferring sharded reductions.

  Args:
    x_y: Structure whose parts are `(x_part, y_part)` pairs.

  Returns:
    Structure of per-component log-prob ratios, matching `x_y`'s structure.
  """
  xs = tf.nest.map_structure(lambda pair: pair[0], x_y)
  ys = tf.nest.map_structure(lambda pair: pair[1], x_y)
  p_dists = p.sample_distributions(value=xs, seed=jd_lib.dummy_seed())[0]
  q_dists = q.sample_distributions(value=ys, seed=jd_lib.dummy_seed())[0]

  def one_part_ratio(pd, x_part, qd, y_part, sharded):
    # Ensure sharded distributions defer reductions (no all-reduce here).
    if sharded:
      return lp_ratio.log_prob_ratio(
          pd, x_part, qd, y_part, reduce_over_shards=False)
    return lp_ratio.log_prob_ratio(pd, x_part, qd, y_part)

  return tf.nest.map_structure(
      one_part_ratio, p_dists, xs, q_dists, ys, is_sharded)
def map_measure_fn(value):
  """Evaluates `attr` of every component distribution at its value part."""
  with tf.name_scope('map_measure_fn'):
    # We always provide a seed, since _flat_sample_distributions will
    # unconditionally split the seed.
    seed = joint_distribution_lib.dummy_seed()
    dists, parts = self._flat_sample_distributions(value=value, seed=seed)
    return [getattr(dist, attr)(part) for dist, part in zip(dists, parts)]
def _call_attr(self, attr):
  """Calls no-arg method `attr` on each component distribution.

  Args:
    attr: Name of the zero-argument method to invoke on each distribution.

  Returns:
    Generator yielding `getattr(d, attr)()` for each component distribution.
  """
  if any(self._dist_fn_args):
    # Some factories need sampled args; materialize the distributions.
    # Const seed for maybe CSE.
    dists, _ = self._flat_sample_distributions(
        seed=joint_distribution_lib.dummy_seed())
  else:
    # No factory takes arguments: each wrapped fn builds its dist directly.
    dists = tuple(make_dist() for make_dist in self._dist_fn_wrapped)
  # Return a genexpr (not a generator function) so the branch above runs
  # eagerly at call time, as in the original.
  return (getattr(dist, attr)() for dist in dists)
def inner_log_prob_parts(flat_value):
  """Computes flat per-distribution log-probs, deferring sharded reductions."""
  unflat_value = self._model_unflatten(flat_value)
  dists, parts = self._call_flat_sample_distributions(
      value=unflat_value, seed=jd_lib.dummy_seed())
  # For sharded distributions, we need to make sure not to do an
  # all-reduce.
  flat_sharded = self._model_flatten(self.get_sharded_distributions())
  log_probs = []
  for dist, part, sharded in zip(dists, parts, flat_sharded):
    if sharded:
      log_probs.append(dist.log_prob(part, reduce_over_shards=False))
    else:
      log_probs.append(dist.log_prob(part))
  # We need to flatten and unflatten here to ensure the output structure
  # matches `flat_sharded_distributions`.
  return self._model_flatten(self._model_unflatten(log_probs))
def _map_measure_over_dists(self, attr, value):
  """Override the default implementation to shard its log_prob calculation.

  Args:
    attr: Name of the measure method to call on each component distribution
      (e.g. `'log_prob'`).
    value: Structured value at which to evaluate `attr`; no part may be
      `None`.

  Returns:
    An iterator/generator of per-component results of `getattr(d, attr)(x)`.

  Raises:
    ValueError: If any part of `value` is `None`.
  """
  if any(x is None for x in tf.nest.flatten(value)):
    raise ValueError(
        'No `value` part can be `None`; saw: {}.'.format(value))
  # Only log_prob needs the sharded all-reduce machinery; everything else
  # falls through to the plain per-distribution evaluation below.
  if attr == 'log_prob' and any(self.get_sharded_distributions()):

    def inner_log_prob_parts(flat_value):
      # Computes flat per-distribution log-probs; sharded components skip
      # their own reduction so make_sharded_log_prob_parts can do it.
      unflat_value = self._model_unflatten(flat_value)
      ds, xs = self._call_flat_sample_distributions(
          value=unflat_value, seed=jd_lib.dummy_seed())
      # For sharded distributions, we need to make sure not to do an
      # all-reduce.
      flat_sharded = self._model_flatten(
          self.get_sharded_distributions())
      log_prob_fns = [
          functools.partial(d.log_prob, reduce_over_shards=False)
          if s else d.log_prob
          for d, s in zip(ds, flat_sharded)
      ]
      # We need to flatten and unflatten here to ensure the output structure
      # matches `flat_sharded_distributions`.
      vals = self._model_unflatten([
          log_prob_fn(x) for log_prob_fn, x in zip(log_prob_fns, xs)
      ])
      return self._model_flatten(vals)

    flat_value = self._model_flatten(value)
    flat_sharded_distributions = self._model_flatten(
        self.get_sharded_distributions())
    # Wrap the parts fn so sharded components get their cross-shard
    # reduction applied with the correct axis name, then evaluate.
    flat_xs = distribute_lib.make_sharded_log_prob_parts(
        inner_log_prob_parts,
        flat_sharded_distributions,
        axis_name=self.shard_axis_name)(flat_value)
    # Return an iterator to match the generator returned by the
    # non-sharded path below.
    return iter(flat_xs)
  # Non-sharded (or non-log_prob) path: evaluate each component directly.
  ds, xs = self._call_flat_sample_distributions(
      value=value, seed=jd_lib.dummy_seed())
  return (getattr(d, attr)(x) for d, x in zip(ds, xs))