def testMakeVJP(self):
  """A non-persistent make_vjp yields f(3.0) and a VJP that can be called once."""

  def square(x):
    return x * x

  # make_vjp(...) returns a wrapper; calling it gives (output, vjp).
  result, vjp = backprop.make_vjp(square, persistent=False)(
      constant_op.constant(3.0))
  self.assertAllEqual(result, 9.0)
  # Upstream gradient 2.0 times d(x*x)/dx at x=3.0 (i.e. 6.0) is 12.0.
  self.assertAllEqual(vjp(2.0)[0], 12.0)
def testPersistentMakeVJP(self):
  """A VJP built with persistent=True can be called twice, giving 12.0 each time."""

  def f(x):
    return x * x

  wrapped_fn = backprop.make_vjp(f, persistent=True)
  _, vjp = wrapped_fn(constant_op.constant(3.0))
  vjp_result1 = vjp(2.0)[0]
  # Second call on the same VJP: this is what the persistent flag is exercised for.
  vjp_result2 = vjp(2.0)[0]
  # assertAllEqual(a, b, msg=None) compares only two values; a third positional
  # argument is treated as the failure message. Check each result against the
  # expected gradient (2.0 * d(x*x)/dx at 3.0 = 12.0) explicitly.
  self.assertAllEqual(vjp_result1, 12.0)
  self.assertAllEqual(vjp_result2, 12.0)
def trace_grad(fn, args):
  """Trace `fn` applied to `args` and return `(output, vjp)`.

  Args:
    fn: the function to trace.
    args: a sequence of positional arguments, unpacked into the call.

  Returns:
    A 2-tuple of the function's output and its VJP callable, as produced by
    `make_vjp(fn)(*args)`.
  """
  # NOTE(review): `make_vjp` is looked up in module scope here; its import is
  # not visible in this chunk. This chunk also redefines `trace_grad` further
  # down, and the later definition replaces this one at import time.
  traced = make_vjp(fn)
  output, vjp_fn = traced(*args)
  return output, vjp_fn
def trace_grad(fn, args):
  """Trace `fn` applied to `args` and return `(output, vjp)`.

  Args:
    fn: the function to trace.
    args: a sequence of positional arguments, unpacked into the call.

  Returns:
    A 2-tuple of the function's output and its VJP callable, as produced by
    `make_vjp(fn)(*args)`.
  """
  # Import is deferred to call time; presumably this avoids an import cycle
  # with the backprop module. Confirm before hoisting it to the top level.
  from tensorflow.python.eager.backprop import make_vjp
  traced = make_vjp(fn)
  output, vjp_fn = traced(*args)
  return output, vjp_fn