def testVjp(self, has_aux):
  # `has_aux` is expected to be supplied by a parameterized-test decorator.
  x_shape = (tf.TensorShape([10]), tf.TensorShape([1, 10]))
  y_shape = tf.TensorShape([])
  dtype = np.float32

  def f(a, b):
    y = tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)
    if has_aux:
      return y, tf_np.asarray(1)
    else:
      return y

  rng = tf.random.Generator.from_seed(1234)
  # Draw random inputs `x` and two output cotangents `dy_list`.
  x, dy_list = tf.nest.map_structure(
      lambda shape: uniform(rng, shape, dtype), [x_shape, [y_shape] * 2])
  tf_x = to_tf(x)
  outputs = extensions.vjp(f, *x, has_aux=has_aux)
  if has_aux:
    y, vjp, aux = outputs
  else:
    y, vjp = outputs
  # Compute reference values with raw TensorFlow autodiff.
  with tf.GradientTape(persistent=True) as tape:
    tape.watch(tf_x)
    outputs = f(*x)
  if has_aux:
    expected_y, expected_aux = outputs
    self.assertAllClose(to_tf(expected_aux), to_tf(aux))
  else:
    expected_y = outputs
  self.assertAllClose(to_tf(expected_y), to_tf(y))
  for dy in dy_list:
    expected_dx = tape.gradient(
        to_tf(expected_y), tf_x, output_gradients=to_tf(dy))
    self.assertAllClose(expected_dx, to_tf(vjp(dy)))
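# For reference: f(a, b) = sum(sqrt(exp(a)) + b) has analytic cotangents
# df/da = dy * 0.5 * sqrt(exp(a)) and df/db = dy * ones_like(b), so the
# vjp(dy) checked above must match tape.gradient with `dy` passed as
# output_gradients.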
def testCustomGrad(self):
  """Test for custom_grad."""
  x_shape = (tf.TensorShape([10]), tf.TensorShape([1, 10]))
  y_shape = tf.TensorShape([])
  dtype = np.float32
  scale1 = 5.0
  scale2 = 6.0

  def fwd(a, b):
    return tf_np.sum(tf_np.sqrt(tf_np.exp(a)) + b)

  @extensions.custom_grad
  def f(a, b):
    y = fwd(a, b)

    def vjp(dy):
      # Deliberately not the true gradient of `fwd`; the test checks that
      # this user-supplied rule is what `extensions.vjp` returns.
      return dy * scale1 * a, dy * scale2 * b

    return y, vjp

  rng = tf.random.Generator.from_seed(1234)
  x, dy = tf.nest.map_structure(
      lambda shape: uniform(rng, shape, dtype), [x_shape, y_shape])
  expected_y = fwd(*x)
  expected_dx = (dy * scale1 * x[0], dy * scale2 * x[1])
  y, vjp = extensions.vjp(f, *x)
  dx = vjp(dy)
  self.assertAllClose(to_tf(expected_y), to_tf(y))
  self.assertAllClose(to_tf(expected_dx), to_tf(dx))
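# Contract illustrated above (mirroring JAX-style custom gradients): a
# function decorated with `extensions.custom_grad` returns the pair
# (y, vjp_fn), and `extensions.vjp` then hands back that user-supplied
# vjp_fn instead of one derived by autodiff.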
def run(x, dy):
  # Evaluates `f` and applies its VJP in one call, so that the forward and
  # backward passes can be traced or compiled together as a single unit.
  y, vjp = extensions.vjp(f, *x)
  dx = vjp(dy)
  return dx, y
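# Typical use, mirroring testCustomGrad above (wrapping in extensions.jit
# is an assumption; `run` may also be called eagerly):
#   dx, y = run(x, dy)
#   dx, y = extensions.jit(run)(x, dy)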