Example 1
    def test_psum_binary_function_applies_psum_to_outputs(self):
        def f(x, y):
            return x + y

        f_psum = distribute_lib.make_psum_function(
            f, (self.axis_name, self.axis_name),
            self.axis_name,
            out_dtype=tf.float32)

        x = self.shard_values(tf.ones(4))
        y = self.shard_values(2 * tf.ones(4))
        out_parts = self.per_replica_to_tensor(
            self.strategy_run(f_psum, (x, y)))

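        # Each of the 4 replicas computes 1 + 2 = 3; with out_axes set to the
        # shard axis, the output is psummed across replicas, giving 12 everywhere.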
        self.assertAllEqual(self.evaluate(out_parts),
                            self.evaluate(12 * tf.ones(4)))

        f_psum = distribute_lib.make_psum_function(
            f, (self.axis_name, self.axis_name), None, out_dtype=tf.float32)

        x = self.shard_values(tf.ones(4))
        y = self.shard_values(2 * tf.ones(4))
        out_parts = self.per_replica_to_tensor(
            self.strategy_run(f_psum, (x, y)))

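        # With out_axes=None the output is left unreduced, so each replica
        # keeps its local 1 + 2 = 3.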
        self.assertAllEqual(self.evaluate(out_parts),
                            self.evaluate(3 * tf.ones(4)))

        f_psum = distribute_lib.make_psum_function(
            f, (self.axis_name, self.axis_name), None, out_dtype=tf.float32)

        x = self.shard_values(tf.ones(4))
        y = self.shard_values(2 * tf.ones(4))
        out_parts = self.per_replica_to_tensor(
            self.strategy_run(f_psum, (x, y), in_axes=(0, 0)))

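        # Explicitly passing in_axes=(0, 0) to strategy_run shards both inputs
        # as in the previous run; the unreduced result is again 3 per replica.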
        self.assertAllEqual(self.evaluate(out_parts),
                            self.evaluate(3 * tf.ones(4)))

        f_psum = distribute_lib.make_psum_function(f, (self.axis_name, None),
                                                   None,
                                                   out_dtype=tf.float32)

        x = self.shard_values(tf.ones(4))
        y = 2.
        out_parts = self.per_replica_to_tensor(
            self.strategy_run(f_psum, (x, y), in_axes=(0, None)))

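        # y is a replicated scalar (in_axes entry None), so each replica still
        # computes 1 + 2 = 3, and the unreduced output stays 3.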
        self.assertAllEqual(self.evaluate(out_parts),
                            self.evaluate(3 * tf.ones(4)))
Example 2
def _dist_jd_log_prob_ratio(p, x, q, y, name=None):
    """Distributed log-prob ratio for JDs."""
    with tf.name_scope(name or 'dist_jd_log_prob_ratio'):
        tf.nest.assert_same_structure(x, y)

        p_axis_names = p.experimental_shard_axis_names
        q_axis_names = q.experimental_shard_axis_names
        if p_axis_names != q_axis_names:
            raise ValueError(
                'p and q must use the same sharding. '
                f'Saw: p: {p}, {p_axis_names}, q: {q}, {q_axis_names}')

        def log_prob_ratio_parts_fn(x, y):
            p_dists = p.sample_distributions(value=x,
                                             seed=samplers.zeros_seed())[0]
            q_dists = q.sample_distributions(value=y,
                                             seed=samplers.zeros_seed())[0]
            # Ensure sharded distributions defer reductions.
            kwds = lambda a: {'reduce_over_shards': False} if a else {}
            return nest.map_structure_up_to(
                p_dists, lambda p, x, q, y, s: lp_ratio.log_prob_ratio(
                    p, x, q, y, **kwds(s)), p_dists, x, q_dists, y,
                p_axis_names)

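        # psum each part's log-prob ratio over its shard axes, then add the
        # per-part totals together.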
        return tf.add_n(
            tf.nest.flatten(
                distribute_lib.make_psum_function(log_prob_ratio_parts_fn,
                                                  in_axes=(p_axis_names,
                                                           p_axis_names),
                                                  out_axes=p_axis_names,
                                                  out_dtype=x)(x, y)))
Example 3
    def test_psum_unary_function_applies_psum_to_outputs(self):
        def f(x):
            return x

        f = distribute_lib.make_psum_function(f,
                                              self.axis_name,
                                              self.axis_name,
                                              out_dtype=tf.float32)

        x = self.shard_values(tf.ones(4))
        out_parts = self.per_replica_to_tensor(self.strategy_run(f, (x, )))

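        # The identity output (1 on each of the 4 replicas) is psummed over
        # the shard axis, so every replica reports 4.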
        self.assertAllEqual(self.evaluate(out_parts),
                            self.evaluate(4 * tf.ones(4)))
Example 4
def _sharded_log_prob_ratio(p, x, q, y, name=None):
    """Distributed log-prob ratio for Sharded."""
    with tf.name_scope(name or 'sharded_log_prob_ratio'):
        if p.experimental_shard_axis_names != q.experimental_shard_axis_names:
            raise ValueError('Mismatched axis names '
                             f'"{p.experimental_shard_axis_names}" vs '
                             f'"{q.experimental_shard_axis_names}"')

        def log_prob_ratio_fn(x, y):
            return log_prob_ratio.log_prob_ratio(p.distribution, x,
                                                 q.distribution, y)

        axes = p.experimental_shard_axis_names
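        # psum the per-shard log-prob ratios over the Sharded distribution's
        # shard axes so every shard returns the global ratio.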
        return distribute_lib.make_psum_function(log_prob_ratio_fn,
                                                 in_axes=(axes, axes),
                                                 out_axes=axes,
                                                 out_dtype=x)(x, y)
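For orientation, the two log-prob-ratio examples above follow the same pattern: compute a quantity locally on each shard, then psum the per-shard results so every shard holds the same global total. Below is a minimal single-process sketch of that pattern, with shards emulated as list entries and the psum as an ordinary Python sum; local_log_prob_ratio, x_shards, and y_shards are illustrative names, not part of the TFP API.

import numpy as np

def local_log_prob_ratio(x_shard, y_shard):
    # Stand-in for the per-shard computation (log_prob_ratio on one shard).
    return np.log(x_shard) - np.log(y_shard)

# Three emulated shards of the sharded event.
x_shards = [np.array(1.0), np.array(2.0), np.array(4.0)]
y_shards = [np.array(2.0), np.array(2.0), np.array(2.0)]

# Setting out_axes to the shard axis corresponds to psumming the per-shard
# results; here the psum is just a sum over the emulated shards.
global_ratio = sum(local_log_prob_ratio(x, y)
                   for x, y in zip(x_shards, y_shards))
print(global_ratio)  # log(1/2) + log(2/2) + log(4/2) = 0.0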
Example 5
    def test_psum_binary_function_corrects_gradients_to_inputs(self):
        def f(x, y):
            return x * y

        f_psum = distribute_lib.make_psum_function(
            f, (self.axis_name, self.axis_name),
            self.axis_name,
            out_dtype=tf.float32)

        def f_grad(x, y):
            return tfp.math.value_and_gradient(f_psum, (x, y))[1]

        x = self.shard_values(tf.ones(4))
        y = self.shard_values(2. * tf.range(4.))
        out_grads = self.per_replica_to_tensor(
            self.strategy_run(f_grad, (x, y)))

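        # Gradient of the psummed product: d/dx_i = y_i = 2 * i and
        # d/dy_i = x_i = 1 on each replica.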
        self.assertAllEqual(self.evaluate(out_grads[0]), 2. * tf.range(4.))
        self.assertAllEqual(self.evaluate(out_grads[1]), tf.ones(4))

        f_psum = distribute_lib.make_psum_function(f, (self.axis_name, None),
                                                   self.axis_name,
                                                   out_dtype=tf.float32)

        def f_grad2(x, y):
            return tfp.math.value_and_gradient(f_psum, (x, y))[1]

        x = self.shard_values(tf.range(4.))
        y = 2.
        out_grads = self.per_replica_to_tensor(
            self.strategy_run(f_grad2, (x, y), in_axes=(0, None)))

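        # x is sharded and y is replicated: d/dx_i = y = 2, while the gradient
        # with respect to y is psummed over shards, 0 + 1 + 2 + 3 = 6.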
        self.assertAllEqual(self.evaluate(out_grads[0]), 2 * tf.ones(4))
        self.assertAllEqual(self.evaluate(out_grads[1]), 6 * tf.ones(4))

        f_psum = distribute_lib.make_psum_function(
            f, (self.axis_name, self.axis_name), None, out_dtype=tf.float32)

        def f_grad3(x, y):
            return tfp.math.value_and_gradient(f_psum, (x, y))[1]

        x = self.shard_values(tf.range(4.))
        y = self.shard_values(tf.ones(4))
        out_grads = self.per_replica_to_tensor(
            self.strategy_run(f_grad3, (x, y)))

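        # With an unreduced output, each replica keeps its local gradients:
        # d/dx_i = y_i = 1 and d/dy_i = x_i = i.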
        self.assertAllEqual(self.evaluate(out_grads[0]), tf.ones(4))
        self.assertAllEqual(self.evaluate(out_grads[1]), tf.range(4.))

        f_psum = distribute_lib.make_psum_function(f, (self.axis_name, None),
                                                   None,
                                                   out_dtype=tf.float32)

        def f_grad4(x, y):
            return tfp.math.value_and_gradient(f_psum, (x, y))[1]

        x = self.shard_values(tf.range(4.))
        y = 2.
        out_grads = self.per_replica_to_tensor(
            self.strategy_run(f_grad4, (x, y), in_axes=(0, None)))

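        # Unreduced output with a replicated y: d/dx_i = y = 2, and each
        # replica's gradient with respect to y is its local x_i.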
        self.assertAllEqual(self.evaluate(out_grads[0]), 2 * tf.ones(4))
        self.assertAllEqual(self.evaluate(out_grads[1]), tf.range(4.))