def test_psum_binary_function_applies_psum_to_outputs(self):

  def f(x, y):
    return x + y

  # Sharded inputs, sharded output: the output is psum-ed across the 4
  # replicas, so each replica sees sum_i (x_i + y_i) = 4 * 3 = 12.
  f_psum = distribute_lib.make_psum_function(
      f, (self.axis_name, self.axis_name), self.axis_name,
      out_dtype=tf.float32)

  x = self.shard_values(tf.ones(4))
  y = self.shard_values(2 * tf.ones(4))
  out_parts = self.per_replica_to_tensor(self.strategy_run(f_psum, (x, y)))
  self.assertAllEqual(self.evaluate(out_parts),
                      self.evaluate(12 * tf.ones(4)))

  # With out_axes=None the output is left unreduced, so each replica
  # returns its local value x + y = 3.
  f_psum = distribute_lib.make_psum_function(
      f, (self.axis_name, self.axis_name), None, out_dtype=tf.float32)

  x = self.shard_values(tf.ones(4))
  y = self.shard_values(2 * tf.ones(4))
  out_parts = self.per_replica_to_tensor(self.strategy_run(f_psum, (x, y)))
  self.assertAllEqual(self.evaluate(out_parts),
                      self.evaluate(3 * tf.ones(4)))

  # Same as above, but with in_axes passed explicitly to strategy_run.
  f_psum = distribute_lib.make_psum_function(
      f, (self.axis_name, self.axis_name), None, out_dtype=tf.float32)

  x = self.shard_values(tf.ones(4))
  y = self.shard_values(2 * tf.ones(4))
  out_parts = self.per_replica_to_tensor(
      self.strategy_run(f_psum, (x, y), in_axes=(0, 0)))
  self.assertAllEqual(self.evaluate(out_parts),
                      self.evaluate(3 * tf.ones(4)))

  # An unsharded second argument (y = 2.) is broadcast to every replica.
  f_psum = distribute_lib.make_psum_function(
      f, (self.axis_name, None), None, out_dtype=tf.float32)

  x = self.shard_values(tf.ones(4))
  y = 2.
  out_parts = self.per_replica_to_tensor(
      self.strategy_run(f_psum, (x, y), in_axes=(0, None)))
  self.assertAllEqual(self.evaluate(out_parts),
                      self.evaluate(3 * tf.ones(4)))
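
# A minimal single-process sketch (plain TensorFlow, no tf.distribute
# strategy; `num_replicas`, `x_shards`, and `y_shards` are illustrative
# names, not part of the library) of the arithmetic the first assertion
# above checks: each of 4 replicas computes x + y = 3 locally, and the
# psum all-reduce sums those values to 3 * 4 = 12.
import tensorflow as tf

num_replicas = 4
x_shards = tf.ones(num_replicas)        # one shard of x per replica
y_shards = 2 * tf.ones(num_replicas)    # one shard of y per replica
per_replica_out = x_shards + y_shards   # f(x, y) evaluated on each replica
psum_out = tf.reduce_sum(per_replica_out)  # stand-in for the psum all-reduce
assert float(psum_out) == 12.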

def _dist_jd_log_prob_ratio(p, x, q, y, name=None):
  """Distributed log-prob ratio for JDs."""
  with tf.name_scope(name or 'dist_jd_log_prob_ratio'):
    tf.nest.assert_same_structure(x, y)
    p_axis_names = p.experimental_shard_axis_names
    q_axis_names = q.experimental_shard_axis_names
    if p_axis_names != q_axis_names:
      raise ValueError(
          'p and q must use the same sharding. '
          f'Saw: p: {p}, {p_axis_names}, q: {q}, {q_axis_names}')

    def log_prob_ratio_parts_fn(x, y):
      p_dists = p.sample_distributions(value=x, seed=samplers.zeros_seed())[0]
      q_dists = q.sample_distributions(value=y, seed=samplers.zeros_seed())[0]
      # Ensure sharded distributions defer reductions.
      kwds = lambda a: {'reduce_over_shards': False} if a else {}
      return nest.map_structure_up_to(
          p_dists,
          lambda p, x, q, y, s: lp_ratio.log_prob_ratio(p, x, q, y, **kwds(s)),
          p_dists, x, q_dists, y, p_axis_names)

    return tf.add_n(
        tf.nest.flatten(
            distribute_lib.make_psum_function(
                log_prob_ratio_parts_fn,
                in_axes=(p_axis_names, p_axis_names),
                out_axes=p_axis_names,
                out_dtype=x)(x, y)))
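
# Hedged sketch of the identity _dist_jd_log_prob_ratio accumulates: the
# joint ratio log p(x) - log q(y) decomposes into a sum of per-part
# ratios, which is why the parts are flattened and combined with
# tf.add_n. Shown here with ordinary (unsharded) Normal parts; `tfd` is
# the standard tensorflow_probability.distributions alias, and the part
# lists are made-up example values.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

p_parts = [tfd.Normal(0., 1.), tfd.Normal(1., 2.)]
q_parts = [tfd.Normal(0.5, 1.), tfd.Normal(1., 3.)]
x = [tf.constant(0.3), tf.constant(-1.2)]
y = [tf.constant(0.1), tf.constant(2.0)]

part_ratios = [
    p.log_prob(xi) - q.log_prob(yi)
    for p, xi, q, yi in zip(p_parts, x, q_parts, y)
]
joint_ratio = tf.add_n(part_ratios)  # sum of per-part log-ratios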

def test_psum_unary_function_applies_psum_to_outputs(self):

  def f(x):
    return x

  f = distribute_lib.make_psum_function(
      f, self.axis_name, self.axis_name, out_dtype=tf.float32)

  x = self.shard_values(tf.ones(4))
  out_parts = self.per_replica_to_tensor(self.strategy_run(f, (x,)))
  self.assertAllEqual(self.evaluate(out_parts),
                      self.evaluate(4 * tf.ones(4)))
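
# A hypothetical, much-simplified analogue of make_psum_function for the
# unary case above, written against the real tf.distribute replica-context
# API. This is a sketch only: the actual implementation also threads named
# axes, structured out_dtype, and a custom gradient.
def naive_psum_function(fn):
  def wrapped(*args):
    ctx = tf.distribute.get_replica_context()
    # Sum the per-replica outputs across all replicas, as psum does.
    return ctx.all_reduce(tf.distribute.ReduceOp.SUM, fn(*args))
  return wrapped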

def _sharded_log_prob_ratio(p, x, q, y, name=None):
  """Distributed log-prob ratio for Sharded."""
  with tf.name_scope(name or 'sharded_log_prob_ratio'):
    if p.experimental_shard_axis_names != q.experimental_shard_axis_names:
      raise ValueError(
          'Mismatched axis names '
          f'"{p.experimental_shard_axis_names}" vs '
          f'"{q.experimental_shard_axis_names}"')

    def log_prob_ratio_fn(x, y):
      return log_prob_ratio.log_prob_ratio(
          p.distribution, x, q.distribution, y)

    # Compute the ratio on each shard, then psum the per-shard values.
    axes = p.experimental_shard_axis_names
    return distribute_lib.make_psum_function(
        log_prob_ratio_fn, in_axes=(axes, axes), out_axes=axes,
        out_dtype=x)(x, y)
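
# Numeric check of the decomposition that justifies psum-ing per-shard
# ratios: for independent shards, log p(x) - log q(y) over the full event
# equals the sum over shards of the local ratios. Four scalar "shards" of
# Normal components stand in for a Sharded distribution here; all names
# and values are illustrative.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

p = tfd.Normal(0., 1.)
q = tfd.Normal(1., 1.)
x = tf.range(4.)   # shard i holds x[i]
y = tf.ones(4)

local_ratios = p.log_prob(x) - q.log_prob(y)  # one ratio per shard
total_ratio = tf.reduce_sum(local_ratios)     # the psum across shards

# Identical to evaluating the ratio on the full (independent) event:
full_p = tfd.Independent(tfd.Normal(tf.zeros(4), 1.), 1)
full_q = tfd.Independent(tfd.Normal(tf.ones(4), 1.), 1)
expected = full_p.log_prob(x) - full_q.log_prob(y)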

def test_psum_binary_function_corrects_gradients_to_inputs(self):

  def f(x, y):
    return x * y

  # Sharded inputs, psum-ed output: each replica's gradient is local,
  # so df/dx = y and df/dy = x per replica.
  f_psum = distribute_lib.make_psum_function(
      f, (self.axis_name, self.axis_name), self.axis_name,
      out_dtype=tf.float32)

  def f_grad(x, y):
    return tfp.math.value_and_gradient(f_psum, (x, y))[1]

  x = self.shard_values(tf.ones(4))
  y = self.shard_values(2. * tf.range(4.))
  out_grads = self.per_replica_to_tensor(self.strategy_run(f_grad, (x, y)))
  self.assertAllEqual(self.evaluate(out_grads[0]), 2. * tf.range(4.))
  self.assertAllEqual(self.evaluate(out_grads[1]), tf.ones(4))

  # Unsharded y: its gradient is summed over replicas (sum(x) = 6).
  f_psum = distribute_lib.make_psum_function(
      f, (self.axis_name, None), self.axis_name, out_dtype=tf.float32)

  def f_grad2(x, y):
    return tfp.math.value_and_gradient(f_psum, (x, y))[1]

  x = self.shard_values(tf.range(4.))
  y = 2.
  out_grads = self.per_replica_to_tensor(
      self.strategy_run(f_grad2, (x, y), in_axes=(0, None)))
  self.assertAllEqual(self.evaluate(out_grads[0]), 2 * tf.ones(4))
  self.assertAllEqual(self.evaluate(out_grads[1]), 6 * tf.ones(4))

  # Unreduced output (out_axes=None) with both inputs sharded.
  f_psum = distribute_lib.make_psum_function(
      f, (self.axis_name, self.axis_name), None, out_dtype=tf.float32)

  def f_grad3(x, y):
    return tfp.math.value_and_gradient(f_psum, (x, y))[1]

  x = self.shard_values(tf.range(4.))
  y = self.shard_values(tf.ones(4))
  out_grads = self.per_replica_to_tensor(self.strategy_run(f_grad3, (x, y)))
  self.assertAllEqual(self.evaluate(out_grads[0]), tf.ones(4))
  self.assertAllEqual(self.evaluate(out_grads[1]), tf.range(4.))

  # Unreduced output with unsharded y.
  f_psum = distribute_lib.make_psum_function(
      f, (self.axis_name, None), None, out_dtype=tf.float32)

  def f_grad4(x, y):
    return tfp.math.value_and_gradient(f_psum, (x, y))[1]

  x = self.shard_values(tf.range(4.))
  y = 2.
  out_grads = self.per_replica_to_tensor(
      self.strategy_run(f_grad4, (x, y), in_axes=(0, None)))
  self.assertAllEqual(self.evaluate(out_grads[0]), 2 * tf.ones(4))
  self.assertAllEqual(self.evaluate(out_grads[1]), tf.range(4.))
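
# Single-replica sanity check (plain GradientTape, no sharding) of the
# gradients the first block above asserts: for f(x, y) = x * y,
# df/dx = y and df/dy = x.
import tensorflow as tf

x = tf.ones(4)
y = 2. * tf.range(4.)
with tf.GradientTape() as tape:
  tape.watch([x, y])
  out = tf.reduce_sum(x * y)  # scalar target, like the psum-ed output
dx, dy = tape.gradient(out, [x, y])
# dx == 2. * tf.range(4.) and dy == tf.ones(4), matching the assertions.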