Example #1
 def log_prob_ratio_parts_fn(x_y):
     # Each part of `x_y` pairs a value for `p` (index 0) with the
     # corresponding value for `q` (index 1); split them back apart.
     x = tf.nest.map_structure(lambda part: part[0], x_y)
     y = tf.nest.map_structure(lambda part: part[1], x_y)
     # Rebuild each model's component distributions pinned at those values.
     # A dummy seed is passed because the sampling path always splits a seed.
     p_dists = p.sample_distributions(value=x, seed=jd_lib.dummy_seed())[0]
     q_dists = q.sample_distributions(value=y, seed=jd_lib.dummy_seed())[0]
     # Per-part differences p_dist.log_prob(x_part) - q_dist.log_prob(y_part).
     lp_diffs = tf.nest.map_structure(log_prob_ratio.log_prob_ratio,
                                      p_dists, x, q_dists, y)
     return lp_diffs
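The same per-part comparison can be reproduced outside the library internals with TensorFlow Probability's public API. Below is a minimal sketch, assuming two toy `JointDistributionSequential` models and the public `tfp.experimental.distributions.log_prob_ratio` helper; the models and sampled values are illustrative, not taken from the code above.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Two toy joint models to compare (assumed for illustration).
p = tfd.JointDistributionSequential([
    tfd.Normal(0., 1.),
    lambda z: tfd.Normal(z, 1.),
])
q = tfd.JointDistributionSequential([
    tfd.Normal(0.5, 1.),
    lambda z: tfd.Normal(z, 2.),
])

x = p.sample(seed=42)  # values for p
y = q.sample(seed=43)  # values for q

# Rebuild each model's component distributions pinned at the given values,
# then take the per-part log-prob difference, as in log_prob_ratio_parts_fn.
p_dists, _ = p.sample_distributions(value=x)
q_dists, _ = q.sample_distributions(value=y)
lp_diffs = [
    tfp.experimental.distributions.log_prob_ratio(pd, xp, qd, yq)
    for pd, xp, qd, yq in zip(p_dists, x, q_dists, y)
]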
Example #2
 def log_prob_ratio_parts_fn(x_y):
     x = tf.nest.map_structure(lambda part: part[0], x_y)
     y = tf.nest.map_structure(lambda part: part[1], x_y)
     p_dists = p.sample_distributions(value=x, seed=jd_lib.dummy_seed())[0]
     q_dists = q.sample_distributions(value=y, seed=jd_lib.dummy_seed())[0]
     # Ensure sharded distributions defer reductions.
     kwds = lambda s: {'reduce_over_shards': False} if s else {}
     return tf.nest.map_structure(
         lambda p, x, q, y, s: lp_ratio.log_prob_ratio(
             p, x, q, y, **kwds(s)), p_dists, x, q_dists, y, is_sharded)
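The `kwds` lambda is a compact way to pass `reduce_over_shards=False` only to the parts flagged as sharded. A minimal, self-contained sketch of that dispatch pattern follows; `per_part_log_prob_ratios` and its arguments are hypothetical stand-ins, not library objects.

def per_part_log_prob_ratios(p_dists, x, q_dists, y, is_sharded,
                             log_prob_ratio_fn):
  # Only sharded parts get the extra keyword; unsharded parts keep the default.
  kwds = lambda sharded: {'reduce_over_shards': False} if sharded else {}
  return [
      log_prob_ratio_fn(pd, xp, qd, yq, **kwds(s))
      for pd, xp, qd, yq, s in zip(p_dists, x, q_dists, y, is_sharded)
  ]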
Example #3
 def map_measure_fn(value):
   # We always provide a seed, since _flat_sample_distributions will
   # unconditionally split the seed.
   with tf.name_scope('map_measure_fn'):
     constant_seed = joint_distribution_lib.dummy_seed()
     return [getattr(d, attr)(x) for (d, x) in zip(
         *self._flat_sample_distributions(value=value, seed=constant_seed))]
Example #4
 def _call_attr(self, attr):
   if any(self._dist_fn_args):
      # Use a constant seed so repeated traces are identical and can
      # potentially be deduplicated via common subexpression elimination (CSE).
     ds, _ = self._flat_sample_distributions(
         seed=joint_distribution_lib.dummy_seed())
   else:
     ds = tuple(d() for d in self._dist_fn_wrapped)
   return (getattr(d, attr)() for d in ds)
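Both helpers above pass a constant dummy seed because the flat sampling path unconditionally splits a seed; with every part pinned by `value`, any fixed seed keeps the rebuilt distributions deterministic. A minimal sketch of the same "pin the values, then query each component distribution" idea using only the public API, with an assumed toy model:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

jd = tfd.JointDistributionSequential([
    tfd.Normal(0., 1.),
    lambda z: tfd.Normal(z, 1.),
])
value = [tf.constant(0.3), tf.constant(-1.2)]

def map_measure(attr, value, seed=42):
  # With every part pinned by `value`, the seed only needs to be constant;
  # the returned component distributions can then be queried part by part.
  dists, xs = jd.sample_distributions(value=value, seed=seed)
  return [getattr(d, attr)(x) for d, x in zip(dists, xs)]

per_part_log_probs = map_measure('log_prob', value)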
Example #5
 def inner_log_prob_parts(flat_value):
     unflat_value = self._model_unflatten(flat_value)
     ds, xs = self._call_flat_sample_distributions(
         value=unflat_value, seed=jd_lib.dummy_seed())
     # For sharded distributions, we need to make sure not to do an
     # all-reduce.
     flat_sharded = self._model_flatten(
         self.get_sharded_distributions())
     log_prob_fns = [
         functools.partial(d.log_prob, reduce_over_shards=False)
         if s else d.log_prob for d, s in zip(ds, flat_sharded)
     ]
     # We need to flatten and unflatten here to ensure the output structure
     # matches `flat_sharded_distributions`.
     vals = self._model_unflatten([
         log_prob_fn(x) for log_prob_fn, x in zip(log_prob_fns, xs)
     ])
     return self._model_flatten(vals)
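The flatten/unflatten round trip at the end of `inner_log_prob_parts` is what keeps the returned list aligned, part for part, with the flat list of sharded flags. A generic sketch of that round trip, using `tf.nest` as a stand-in for the model's `_model_flatten`/`_model_unflatten` (the structure and values are illustrative):

import tensorflow as tf

structure = {'w': None, 'z': None}      # stand-in for the model's structure
flat_value = [1.0, 2.0]                 # flat inputs, one entry per part

# Unflatten, compute per part, then flatten again so the output ordering
# matches the flat ordering of the inputs (and of the sharded flags).
unflat = tf.nest.pack_sequence_as(structure, flat_value)
per_part = {k: 2.0 * v for k, v in unflat.items()}
flat_out = tf.nest.flatten(per_part)    # [2.0, 4.0], same order as flat_value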
Example #6
    def _map_measure_over_dists(self, attr, value):
        """Override the default implementation to shard its log_prob calculation."""
        if any(x is None for x in tf.nest.flatten(value)):
            raise ValueError(
                'No `value` part can be `None`; saw: {}.'.format(value))
        if attr == 'log_prob' and any(self.get_sharded_distributions()):

            def inner_log_prob_parts(flat_value):
                unflat_value = self._model_unflatten(flat_value)
                ds, xs = self._call_flat_sample_distributions(
                    value=unflat_value, seed=jd_lib.dummy_seed())
                # For sharded distributions, we need to make sure not to do an
                # all-reduce.
                flat_sharded = self._model_flatten(
                    self.get_sharded_distributions())
                log_prob_fns = [
                    functools.partial(d.log_prob, reduce_over_shards=False)
                    if s else d.log_prob for d, s in zip(ds, flat_sharded)
                ]
                # We need to flatten and unflatten here to ensure the output structure
                # matches `flat_sharded_distributions`.
                vals = self._model_unflatten([
                    log_prob_fn(x) for log_prob_fn, x in zip(log_prob_fns, xs)
                ])
                return self._model_flatten(vals)

            flat_value = self._model_flatten(value)
            flat_sharded_distributions = self._model_flatten(
                self.get_sharded_distributions())
            flat_xs = distribute_lib.make_sharded_log_prob_parts(
                inner_log_prob_parts,
                flat_sharded_distributions,
                axis_name=self.shard_axis_name)(flat_value)
            return iter(flat_xs)
        ds, xs = self._call_flat_sample_distributions(value=value,
                                                      seed=jd_lib.dummy_seed())
        return (getattr(d, attr)(x) for d, x in zip(ds, xs))
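Conceptually, the sharded branch keeps each shard's likelihood contribution local and performs the cross-shard reduction exactly once, which is what `make_sharded_log_prob_parts` arranges across devices. A single-process sketch of that bookkeeping with an assumed toy model, where plain Python lists stand in for device shards and an ordinary sum stands in for the all-reduce:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

theta = tf.constant(0.1)
# Two "shards" of observed data; in the real setting these live on separate devices.
data_shards = [tf.constant([1.0, -0.5]), tf.constant([0.3, 2.0])]

prior_lp = tfd.Normal(0., 1.).log_prob(theta)
# Each shard evaluates only its own slice, with no cross-shard reduction yet
# (the analogue of calling log_prob with reduce_over_shards=False).
shard_lps = [tf.reduce_sum(tfd.Normal(theta, 1.).log_prob(d)) for d in data_shards]
# The reduction over shards happens once, at the end (the analogue of the all-reduce).
total_lp = prior_lp + tf.add_n(shard_lps)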