def _dist_jd_log_prob_ratio(p, x, q, y, name=None):
  """Distributed log-prob ratio for JDs."""
  with tf.name_scope(name or 'dist_jd_log_prob_ratio'):
    tf.nest.assert_same_structure(x, y)

    p_axis_names = p.experimental_shard_axis_names
    q_axis_names = q.experimental_shard_axis_names
    if p_axis_names != q_axis_names:
      raise ValueError('p and q must use the same sharding. '
                       f'Saw: p: {p}, {p_axis_names}, q: {q}, {q_axis_names}')

    def log_prob_ratio_parts_fn(x_y):
      x = tf.nest.map_structure(lambda part: part[0], x_y)
      y = tf.nest.map_structure(lambda part: part[1], x_y)
      p_dists = p.sample_distributions(value=x, seed=samplers.zeros_seed())[0]
      q_dists = q.sample_distributions(value=y, seed=samplers.zeros_seed())[0]
      # Ensure sharded distributions defer reductions.
      kwds = lambda a: {'reduce_over_shards': False} if a else {}
      return nest.map_structure_up_to(
          p_dists,
          lambda p, x, q, y, s: lp_ratio.log_prob_ratio(p, x, q, y, **kwds(s)),
          p_dists, x, q_dists, y, p_axis_names)

    return tf.add_n(
        tf.nest.flatten(
            distribute_lib.make_sharded_log_prob_parts(
                log_prob_ratio_parts_fn,
                # Stack, because make_sharded_log_prob_parts expects
                # inputs/outputs to be 1 to 1. TODO(b/175084455): revisit this
                # after the distributed bijectors are done, as it is likely that
                # make_sharded_log_prob_parts will be adjusted then to not have
                # this limitation.
                p_axis_names)(tf.nest.map_structure(
                    lambda x, y: tf.stack([x, y], axis=0), x, y))))
def _dist_jd_log_prob_ratio(p, x, q, y):
    """Distributed log-prob ratio for JDs."""
    tf.nest.assert_same_structure(x, y)
    if p.shard_axis_name != q.shard_axis_name:
        raise ValueError(
            'p and q must have the same shard_axis_name. '
            f'Saw: p: {p}, {p.shard_axis_name}, q: {q}, {q.shard_axis_name}')

    def log_prob_ratio_parts_fn(x_y):
        x = tf.nest.map_structure(lambda part: part[0], x_y)
        y = tf.nest.map_structure(lambda part: part[1], x_y)
        p_dists = p.sample_distributions(value=x, seed=jd_lib.dummy_seed())[0]
        q_dists = q.sample_distributions(value=y, seed=jd_lib.dummy_seed())[0]
        lp_diffs = tf.nest.map_structure(log_prob_ratio.log_prob_ratio,
                                         p_dists, x, q_dists, y)
        return lp_diffs

    return tf.add_n(
        tf.nest.flatten(
            distribute_lib.make_sharded_log_prob_parts(
                log_prob_ratio_parts_fn,
                # Stack, because make_sharded_log_prob_parts expects
                # inputs/outputs to be 1 to 1. TODO(b/175084455): revisit this
                # after the distributed bijectors are done, as it is likely that
                # make_sharded_log_prob_parts will be adjusted then to not have
                # this limitation.
                p.get_sharded_distributions(),
                axis_name=p.shard_axis_name)(tf.nest.map_structure(
                    lambda x, y: tf.stack([x, y], axis=0), x, y))))
 def log_prob(*value):
   w, x = value
   sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
       log_prob_parts, {'w': False, 'x': True, 'data': True},
       axis_name=self.axis_name)
   parts = sharded_log_prob_parts({'w': w, 'x': x, 'data': data})
   return tf.add_n(tf.nest.flatten(parts))
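The snippet above closes over a log_prob_parts callable, a data tensor, and self.axis_name. A minimal sketch of the kind of dict-keyed parts function it assumes is shown below; the model (a global weight w, a sharded local x, and sharded observations data) is hypothetical, chosen only to match the keys used above.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def log_prob_parts(value):
  w, x, data = value['w'], value['x'], value['data']
  return {
      'w': tfd.Normal(0., 1.).log_prob(w),                      # global weight
      'x': tfd.Normal(w, 1.).log_prob(x),                       # sharded local
      'data': tf.reduce_sum(tfd.Normal(x, 1.).log_prob(data)),  # sharded observations
  }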
  def _map_measure_over_dists(self, attr, value):
    """Override the default implementation to shard its log_prob calculation."""
    if any(x is None for x in tf.nest.flatten(value)):
      raise ValueError('No `value` part can be `None`; saw: {}.'.format(value))
    if attr == 'log_prob' and any(self.experimental_shard_axis_names):

      def inner_log_prob_parts(flat_value):
        unflat_value = self._model_unflatten(flat_value)
        ds, xs = self._call_flat_sample_distributions(
            value=unflat_value, seed=samplers.zeros_seed())
        # For sharded distributions, we need to make sure not to do an
        # all-reduce.
        axis_names = self._model_flatten(self.experimental_shard_axis_names)
        log_prob_fns = [
            functools.partial(d.log_prob, reduce_over_shards=False)
            if axis_name else d.log_prob
            for d, axis_name in zip(ds, axis_names)
        ]
        # We need to flatten and unflatten here to ensure the output structure
        # matches the flattened structure expected by
        # `make_sharded_log_prob_parts`.
        vals = self._model_unflatten(
            [log_prob_fn(x) for log_prob_fn, x in zip(log_prob_fns, xs)])
        return self._model_flatten(vals)

      flat_value = self._model_flatten(value)
      flat_axis_names = self._model_flatten(self.experimental_shard_axis_names)
      flat_xs = distribute_lib.make_sharded_log_prob_parts(
          inner_log_prob_parts, flat_axis_names)(
              flat_value)
      return iter(flat_xs)
    ds, xs = self._call_flat_sample_distributions(
        value=value, seed=samplers.zeros_seed())
    return (getattr(d, attr)(x) for d, x in zip(ds, xs))
Example #5
    def _map_measure_over_dists(self, attr, value):
        """Overrides the default implementation to shard its log_prob calculation."""
        if any(x is None for x in tf.nest.flatten(value)):
            raise ValueError(
                'No `value` part can be `None`; saw: {}.'.format(value))
        if attr == 'log_prob' and any(self.get_sharded_distributions()):

            def inner_log_prob_parts(flat_value):
                unflat_value = self._model_unflatten(flat_value)
                ds, xs = self._call_flat_sample_distributions(
                    value=unflat_value, seed=42)
                # We need to flatten and unflatten here to ensure the output structure
                # matches `flat_sharded_distributions`.
                vals = self._model_unflatten(
                    [getattr(d, attr)(x) for d, x in zip(ds, xs)])
                return self._model_flatten(vals)

            flat_value = self._model_flatten(value)
            flat_sharded_distributions = self._model_flatten(
                self.get_sharded_distributions())
            flat_xs = distribute_lib.make_sharded_log_prob_parts(
                inner_log_prob_parts, flat_sharded_distributions)(flat_value)
            return iter(flat_xs)
        ds, xs = self._call_flat_sample_distributions(value=value, seed=42)
        return (getattr(d, attr)(x) for d, x in zip(ds, xs))
Example #6
 def log_prob(*value):
     w, x = value
     sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
         log_prob_parts, [False, True, True],
         axis_name=self.axis_name)
     parts = sharded_log_prob_parts([w, x, data])
     return tf.add_n(parts)
Example #7
    def _log_prob(self, x, reduce_over_shards=True, **kwargs):
        def log_prob_fn(value):
            return self.distribution.log_prob(value, **kwargs)

        if reduce_over_shards:
            log_prob_fn = distribute_lib.make_sharded_log_prob_parts(
                log_prob_fn, is_sharded=True, axis_name=self.shard_axis_name)
        return log_prob_fn(x)
Example #8
 def log_prob(x, y, z):
     sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
         log_prob_parts, [
             self.axis_name, other_axis_name,
             [self.axis_name, other_axis_name]
         ])
     parts = sharded_log_prob_parts([x, y, z])
     return tf.add_n(parts)
Example #9
    def _log_prob(self, x, reduce_over_shards=True, **kwargs):
        def log_prob_fn(value):
            new_kwargs = dict(kwargs)
            if self.distribution.experimental_shard_axis_names:
                new_kwargs['reduce_over_shards'] = reduce_over_shards
            return self.distribution.log_prob(value, **new_kwargs)

        if reduce_over_shards:
            log_prob_fn = distribute_lib.make_sharded_log_prob_parts(
                log_prob_fn, self.experimental_shard_axis_names)
        return log_prob_fn(x)
Example #10
    def lp_fn(self, x, reduce_over_shards=True, **kwargs):
        def impl(value):
            new_kwargs = dict(kwargs)
            if self.distribution.experimental_shard_axis_names:
                new_kwargs['reduce_over_shards'] = reduce_over_shards
            return getattr(self.distribution, fn_name)(value, **new_kwargs)

        if reduce_over_shards:
            impl = distribute_lib.make_sharded_log_prob_parts(
                impl, self.experimental_shard_axis_names)
        return impl(x)
Example #11
        def run(x, data):
            def log_prob_parts(value):
                x, data = value
                return [
                    tfd.Normal(0., 1.).log_prob(x),
                    tf.reduce_sum(tfd.Normal(x, 1.).log_prob(data))
                ]

            sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
                log_prob_parts, [False, True], axis_name=self.axis_name)

            return sharded_log_prob_parts([x, data])
Example #12
        def run(w, x, data):
            def log_prob_parts(values):
                w, x, data = values
                return [
                    tfd.Normal(0., 1.).log_prob(w),
                    tfd.Normal(w, 1.).log_prob(x),
                    tf.reduce_sum(tfd.Normal(x, 1.).log_prob(data))
                ]

            sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
                log_prob_parts, [None, self.axis_name, self.axis_name])

            return sharded_log_prob_parts([w, x, data])
Example #13
  def unnormalized_log_prob(self, x, reduce_over_shards=True, **kwargs):

    def unnormalized_log_prob_fn(value):
      new_kwargs = dict(kwargs)
      if self.distribution.experimental_shard_axis_names:
        new_kwargs['reduce_over_shards'] = reduce_over_shards
      if hasattr(self.distribution, 'unnormalized_log_prob'):
        return self.distribution.unnormalized_log_prob(value, **new_kwargs)
      return self.distribution.log_prob(value, **new_kwargs)

    if reduce_over_shards:
      unnormalized_log_prob_fn = distribute_lib.make_sharded_log_prob_parts(
          unnormalized_log_prob_fn, self.experimental_shard_axis_names)
    return unnormalized_log_prob_fn(x)
Example #14
  def test_correct_log_prob_for_global_variable_no_strategy(self):
    data = tf.ones(4)

    def log_prob_parts(value):
      x, data = value
      return [
          tfd.Normal(0., 1.).log_prob(x),
          tf.reduce_sum(tfd.Normal(x, 1.).log_prob(data))
      ]

    sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
        log_prob_parts, [False, True], axis_name=None)
    self.assertAllEqualNested(
        self.evaluate(sharded_log_prob_parts([tf.constant(0.), data])),
        self.evaluate([
            tfd.Normal(0., 1.).log_prob(0.),
            tf.reduce_sum(tfd.Normal(0., 1.).log_prob(data))
        ]))
Example #15
    def test_correct_log_prob_for_local_variable_no_strategy(self):

        data = tf.ones(4)

        def log_prob_parts(value):
            x, data = value
            return [
                tf.reduce_sum(tfd.Normal(0., 1.).log_prob(x)),
                tf.reduce_sum(tfd.Normal(x, 1.).log_prob(data))
            ]

        sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
            log_prob_parts, [True, True])
        self.assertAllEqualNested(
            self.evaluate(sharded_log_prob_parts([tf.ones(4), data])),
            self.evaluate([
                tf.reduce_sum(tfd.Normal(0., 1.).log_prob(tf.ones(4))),
                tf.reduce_sum(tfd.Normal(tf.ones(4), 1.).log_prob(data))
            ]))
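Both tests above exercise the no-strategy path. The identity that sharding relies on can be checked with plain TensorFlow: splitting an i.i.d. data likelihood across shards and summing the per-shard contributions recovers the full log prob, which is the reduction make_sharded_log_prob_parts performs (via an all-reduce over the named axis) when run under a strategy. A minimal sketch with made-up values:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

x = tf.constant(0.5)
data = tf.ones(4)
shards = tf.split(data, 2)  # pretend each half lives on a different device

full = tf.reduce_sum(tfd.Normal(x, 1.).log_prob(data))
per_shard = [tf.reduce_sum(tfd.Normal(x, 1.).log_prob(s)) for s in shards]
# `full` equals `tf.add_n(per_shard)` up to floating-point error.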
Example #16
def _sharded_log_prob_ratio(p, x, q, y, name=None, reduce_over_shards=True):
    """Distributed log-prob ratio for Sharded."""
    with tf.name_scope(name or 'sharded_log_prob_ratio'):
        if p.shard_axis_name != q.shard_axis_name:
            raise ValueError('Mismatched axis names '
                             f'"{p.shard_axis_name}" vs "{q.shard_axis_name}"')

        def log_prob_ratio_fn(x_y):
            return log_prob_ratio.log_prob_ratio(p.distribution, x_y[0],
                                                 q.distribution, x_y[1])

        if reduce_over_shards:
            return distribute_lib.make_sharded_log_prob_parts(
                # Stack, because make_sharded_log_prob_parts expects inputs/outputs to
                # be 1 to 1. TODO(b/175084455): revisit this after the distributed
                # bijectors are done, as it is likely that make_sharded_log_prob_parts
                # will be adjusted then to not have this limitation.
                log_prob_ratio_fn,
                is_sharded=True,
                axis_name=p.shard_axis_name)(tf.stack([x, y], axis=0))
        return log_prob_ratio_fn([x, y])
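For context, log_prob_ratio computes log p(x) - log q(y); TFP dispatches it as a single operation partly so that shared terms can cancel with better numerical precision than two independently computed log probs. A naive, non-sharded version of the same quantity, with made-up distributions and values:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

p = tfd.Normal(0., 1.)
q = tfd.Normal(0.5, 1.)
x = tf.constant(1.0)
y = tf.constant(1.5)

# Reference value: the plain difference of the two log densities.
naive_ratio = p.log_prob(x) - q.log_prob(y)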
 def log_prob(x, data):
     sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
         log_prob_parts, [None, {self.axis_name, other_axis_name}])
     parts = sharded_log_prob_parts([x, data])
     return tf.add_n(parts)
 def log_prob(*value):
     w, x = value
     sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
         log_prob_parts, [None, self.axis_name, self.axis_name])
     parts = sharded_log_prob_parts([w, x, data])
     return tf.add_n(parts)
 def log_prob(x):
   sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
       log_prob_parts, [True, True], axis_name=self.axis_name)
   parts = sharded_log_prob_parts([x, data])
   return tf.add_n(parts)
 def log_prob(x, y):
     sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
         log_prob_parts, [None, self.axis_name])
     parts = sharded_log_prob_parts([x, y])
     return tf.add_n(parts)
 def log_prob(x):
     sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
         log_prob_parts, [False, True])
     parts = sharded_log_prob_parts([x, data])
     return tf.add_n(parts)