Example 1
def log_prob(x, y, z):
    sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
        log_prob_parts, [
            self.axis_name, other_axis_name,
            [self.axis_name, other_axis_name]
        ])
    parts = sharded_log_prob_parts([x, y, z])
    return tf.add_n(parts)
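
The snippet above assumes a surrounding test fixture: `log_prob_parts`, `self.axis_name`, and `other_axis_name` are defined elsewhere. As a rough orientation, a three-part `log_prob_parts` matching the axis-name structure could look like the following sketch (the distributions and shapes are illustrative assumptions, not the original fixture):

    import tensorflow as tf
    import tensorflow_probability as tfp

    tfd = tfp.distributions

    def log_prob_parts(value):
        x, y, z = value
        return [
            tfd.Normal(0., 1.).log_prob(x),  # part sharded over `self.axis_name`
            tfd.Normal(0., 1.).log_prob(y),  # part sharded over `other_axis_name`
            # the data term is sharded over both axes, matching the nested entry
            # [self.axis_name, other_axis_name] in the axis-name list
            tf.reduce_sum(tfd.Normal(x + y, 1.).log_prob(z)),
        ]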
Example 2
def lp_grad(x):
    untransformed_log_prob = distribute_lib.make_sharded_log_prob_parts(
        log_prob, self.axis_name)
    transformed_log_prob = transform_log_prob(
        untransformed_log_prob,
        sharded.Sharded(tfb.Sigmoid(), shard_axis_name=self.axis_name))
    lp, g = tfp.math.value_and_gradient(transformed_log_prob, (x,))
    return lp, g
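
For the gradient step, `tfp.math.value_and_gradient` returns the function value together with gradients with respect to the supplied arguments, mirroring their structure. A tiny self-contained illustration, independent of the sharded setup above:

    import tensorflow as tf
    import tensorflow_probability as tfp

    lp, grads = tfp.math.value_and_gradient(
        lambda x: tfp.distributions.Normal(0., 1.).log_prob(x),
        (tf.constant(0.5),))
    # `grads` mirrors the argument structure: a one-element tuple here.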
Example 3
def log_prob(*value):
    w, x = value
    sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
        log_prob_parts, {
            'w': None,
            'x': self.axis_name,
            'data': self.axis_name
        })
    parts = sharded_log_prob_parts({'w': w, 'x': x, 'data': data})
    return tf.add_n(tf.nest.flatten(parts))
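
Here the values, parts, and axis names are keyed by name rather than position, so `log_prob_parts` must return a dict with the same keys; `'w'` is global while `'x'` and `'data'` are sharded. A sketch of a matching parts function (the concrete distributions are assumptions for illustration):

    import tensorflow as tf
    import tensorflow_probability as tfp

    tfd = tfp.distributions

    def log_prob_parts(value):
        return {
            'w': tfd.Normal(0., 1.).log_prob(value['w']),             # global part
            'x': tfd.Normal(value['w'], 1.).log_prob(value['x']),     # sharded part
            'data': tf.reduce_sum(
                tfd.Normal(value['x'], 1.).log_prob(value['data'])),  # sharded part
        }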
Example 4
    def lp_fn(self, x, reduce_over_shards=True, **kwargs):
        def impl(value):
            new_kwargs = dict(kwargs)
            if self.distribution.experimental_shard_axis_names:
                new_kwargs['reduce_over_shards'] = reduce_over_shards
            # `fn_name` (e.g. 'log_prob') is bound in the enclosing scope that builds this method.
            return getattr(self.distribution, fn_name)(value, **new_kwargs)

        if reduce_over_shards:
            impl = distribute_lib.make_sharded_log_prob_parts(
                impl, self.experimental_shard_axis_names)
        return impl(x)
Example 5
        def run(x, data):
            def log_prob_parts(value):
                x, data = value
                return [
                    tfd.Normal(0., 1.).log_prob(x),
                    tf.reduce_sum(tfd.Normal(x, 1.).log_prob(data))
                ]

            sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
                log_prob_parts, [self.axis_name, self.axis_name])

            return sharded_log_prob_parts([x, data])
Example 6
    def test_correct_log_prob_for_global_variable_no_strategy(self):
        data = tf.ones(4)

        def log_prob_parts(value):
            x, data = value
            return [
                tfd.Normal(0., 1.).log_prob(x),
                tf.reduce_sum(tfd.Normal(x, 1.).log_prob(data))
            ]

        sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
            log_prob_parts, [None, True])
        self.assertAllEqualNested(
            self.evaluate(sharded_log_prob_parts([tf.constant(0.), data])),
            self.evaluate([
                tfd.Normal(0., 1.).log_prob(0.),
                tf.reduce_sum(tfd.Normal(0., 1.).log_prob(data))
            ]))
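
With no distribution strategy active, the cross-shard reduction is a no-op, so the sharded wrapper reproduces the plain log-prob parts; that equality is exactly what the assertion checks. A standalone sketch of the same computation outside the test class (the `distribute_lib` import path is an assumption based on TFP's internal layout; `None` marks the unsharded part and a truthy placeholder stands in for an axis name since no strategy is running):

    import tensorflow as tf
    import tensorflow_probability as tfp
    from tensorflow_probability.python.internal import distribute_lib

    tfd = tfp.distributions
    data = tf.ones(4)

    def log_prob_parts(value):
        x, data = value
        return [
            tfd.Normal(0., 1.).log_prob(x),
            tf.reduce_sum(tfd.Normal(x, 1.).log_prob(data)),
        ]

    sharded_lp = distribute_lib.make_sharded_log_prob_parts(
        log_prob_parts, [None, True])
    print(sharded_lp([tf.constant(0.), data]))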
Example 7
def _sharded_log_prob_ratio(p, x, q, y, name=None, reduce_over_shards=True):
    """Distributed log-prob ratio for Sharded."""
    with tf.name_scope(name or 'sharded_log_prob_ratio'):
        if p.experimental_shard_axis_names != q.experimental_shard_axis_names:
            raise ValueError('Mismatched axis names '
                             f'"{p.experimental_shard_axis_names}" vs "'
                             f'"{q.experimental_shard_axis_names}"')

        def log_prob_ratio_fn(x_y):
            return log_prob_ratio.log_prob_ratio(p.distribution, x_y[0],
                                                 q.distribution, x_y[1])

        if reduce_over_shards:
            return distribute_lib.make_sharded_log_prob_parts(
                # Stack, because make_sharded_log_prob_parts expects inputs/outputs to
                # be 1 to 1. TODO(b/175084455): revisit this after the distributed
                # bijectors are done, as it is likely that make_sharded_log_prob_parts
                # will be adjusted then to not have this limitation.
                log_prob_ratio_fn,
                p.experimental_shard_axis_names)(tf.stack([x, y], axis=0))
        return log_prob_ratio_fn([x, y])
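
The stacking deserves a note: because `make_sharded_log_prob_parts` pairs each input part with one output part, the two operands `x` and `y` are packed along a new leading axis and unpacked again inside `log_prob_ratio_fn`. The packing in isolation, with illustrative values:

    import tensorflow as tf

    x = tf.constant([1., 2.])
    y = tf.constant([3., 4.])
    x_y = tf.stack([x, y], axis=0)  # shape [2, 2]
    # inside the wrapped function the operands are recovered as x_y[0] and x_y[1]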
Example 8
    def _map_measure_over_dists(self, attr, value):
        """Override the default implementation to shard its log_prob calculation."""
        if any(x is None for x in tf.nest.flatten(value)):
            raise ValueError(
                'No `value` part can be `None`; saw: {}.'.format(value))
        if (attr in ('log_prob', 'unnormalized_log_prob')) and any(
                self.experimental_shard_axis_names):

            def inner_log_prob_parts(flat_value):
                unflat_value = self._model_unflatten(flat_value)
                ds, xs = self._call_flat_sample_distributions(
                    value=unflat_value, seed=samplers.zeros_seed())
                # For sharded distributions, we need to make sure not to do an
                # all-reduce.
                axis_names = self._model_flatten(
                    self.experimental_shard_axis_names)
                log_prob_fns = [
                    functools.partial(getattr(d, attr),
                                      reduce_over_shards=False)
                    if axis_name else getattr(d, attr)
                    for d, axis_name in zip(ds, axis_names)
                ]
                # We need to flatten and unflatten here to ensure the output structure
                # matches `flat_sharded_distributions`.
                vals = self._model_unflatten([
                    log_prob_fn(x) for log_prob_fn, x in zip(log_prob_fns, xs)
                ])
                return self._model_flatten(vals)

            flat_value = self._model_flatten(value)
            flat_axis_names = self._model_flatten(
                self.experimental_shard_axis_names)
            flat_xs = distribute_lib.make_sharded_log_prob_parts(
                inner_log_prob_parts, flat_axis_names)(flat_value)
            return iter(flat_xs)
        ds, xs = self._call_flat_sample_distributions(
            value=value, seed=samplers.zeros_seed())
        return (getattr(d, attr)(x) for d, x in zip(ds, xs))
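
The flatten/unflatten round trip keeps the computed parts aligned with the model's named structure; `_model_flatten` and `_model_unflatten` are JointDistribution internals, but the underlying idea is the ordinary `tf.nest` round trip:

    import tensorflow as tf

    structure = {'w': 0., 'x': 0.}
    flat = tf.nest.flatten(structure)                    # deterministic flat order
    rebuilt = tf.nest.pack_sequence_as(structure, flat)  # back to the named layout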
Example 9
def _dist_jd_log_prob_ratio(p, x, q, y, name=None):
    """Distributed log-prob ratio for JDs."""
    with tf.name_scope(name or 'dist_jd_log_prob_ratio'):
        tf.nest.assert_same_structure(x, y)

        p_axis_names = p.experimental_shard_axis_names
        q_axis_names = q.experimental_shard_axis_names
        if p_axis_names != q_axis_names:
            raise ValueError(
                'p and q must use the same sharding. '
                f'Saw: p: {p}, {p_axis_names}, q: {q}, {q_axis_names}')

        def log_prob_ratio_parts_fn(x_y):
            x = tf.nest.map_structure(lambda part: part[0], x_y)
            y = tf.nest.map_structure(lambda part: part[1], x_y)
            p_dists = p.sample_distributions(value=x,
                                             seed=samplers.zeros_seed())[0]
            q_dists = q.sample_distributions(value=y,
                                             seed=samplers.zeros_seed())[0]
            # Ensure sharded distributions defer reductions.
            kwds = lambda a: {'reduce_over_shards': False} if a else {}
            return nest.map_structure_up_to(
                p_dists, lambda p, x, q, y, s: lp_ratio.log_prob_ratio(
                    p, x, q, y, **kwds(s)), p_dists, x, q_dists, y,
                p_axis_names)

        return tf.add_n(
            tf.nest.flatten(
                distribute_lib.make_sharded_log_prob_parts(
                    log_prob_ratio_parts_fn,
                    # Stack, because make_sharded_log_prob_parts expects
                    # inputs/outputs to be 1 to 1. TODO(b/175084455): revisit this
                    # after the distributed bijectors are done, as it is likely that
                    # make_sharded_log_prob_parts will be adjusted then to not have
                    # this limitation.
                    p_axis_names)(tf.nest.map_structure(
                        lambda x, y: tf.stack([x, y], axis=0), x, y))))
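
As in Example 7, the two operands are stacked part-wise before the sharded call, but here `x` and `y` are arbitrary structures, so the stacking is mapped over the structure and undone per part (`part[0]`, `part[1]`) inside `log_prob_ratio_parts_fn`. The packing in isolation, with an illustrative structure:

    import tensorflow as tf

    x = {'w': tf.constant(1.), 'z': tf.constant([2., 3.])}
    y = {'w': tf.constant(4.), 'z': tf.constant([5., 6.])}
    x_y = tf.nest.map_structure(lambda a, b: tf.stack([a, b], axis=0), x, y)
    # each part now has a leading axis of size 2: part[0] recovers x, part[1] recovers y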
Example 10
def log_prob(x, data):
    sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
        log_prob_parts, [None, {self.axis_name, other_axis_name}])
    parts = sharded_log_prob_parts([x, data])
    return tf.add_n(parts)
Example 11
def log_prob(x, y):
    sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
        log_prob_parts, [None, self.axis_name])
    parts = sharded_log_prob_parts([x, y])
    return tf.add_n(parts)
Example 12
def log_prob(*value):
    w, x = value
    sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
        log_prob_parts, [None, self.axis_name, self.axis_name])
    parts = sharded_log_prob_parts([w, x, data])
    return tf.add_n(parts)