Example #1
from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test

class MakeVJPTest(test.TestCase):
    def testMakeVJP(self):
        def f(x):
            return x * x

        # The wrapper returns f's output and a vector-Jacobian product fn.
        wrapped_fn = backprop.make_vjp(f, persistent=False)
        result, vjp = wrapped_fn(constant_op.constant(3.0))
        self.assertAllEqual(result, 9.0)  # f(3.0) == 9.0
        self.assertAllEqual(vjp(2.0)[0], 12.0)  # seed 2.0 * f'(3.0) == 12.0
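For reference, the same vector-Jacobian product can be reproduced with the public tf.GradientTape API. The sketch below is illustrative and not part of the original test; output_gradients plays the role of the VJP seed:

import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape() as tape:
    tape.watch(x)  # constants are not watched automatically
    y = x * x
# seed 2.0 times dy/dx at x == 3.0 gives 12.0, matching vjp(2.0)[0] above
print(tape.gradient(y, x, output_gradients=tf.constant(2.0)))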
Example #2
from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test

class PersistentMakeVJPTest(test.TestCase):
    def testPersistentMakeVJP(self):
        def f(x):
            return x * x

        # persistent=True keeps the tape alive across repeated vjp calls.
        wrapped_fn = backprop.make_vjp(f, persistent=True)
        _, vjp = wrapped_fn(constant_op.constant(3.0))
        vjp_result1 = vjp(2.0)[0]
        vjp_result2 = vjp(2.0)[0]
        # Each call yields 2.0 * f'(3.0) == 12.0. (The original snippet passed
        # 12.0 as assertAllEqual's msg argument, so the value was never checked.)
        self.assertAllEqual(vjp_result1, 12.0)
        self.assertAllEqual(vjp_result2, 12.0)
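The persistent flag mirrors tf.GradientTape(persistent=True), which permits multiple gradient queries from a single trace. A minimal public-API sketch, illustrative rather than taken from the original test suite:

import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)
    y = x * x
# A persistent tape may be queried more than once.
g1 = tape.gradient(y, x, output_gradients=tf.constant(2.0))  # 12.0
g2 = tape.gradient(y, x, output_gradients=tf.constant(2.0))  # 12.0
del tape  # release the tape's resources once finished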
Example #3
from tensorflow.python.eager.backprop import make_vjp


def trace_grad(fn, args):
    """Trace a function; return its output and a VJP function."""
    # make_vjp(fn) builds a wrapper that returns (output, vjp) when called.
    result, vjp = make_vjp(fn)(*args)
    return result, vjp
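A hypothetical usage sketch for trace_grad; the square function and the 1.0 seed are illustrative assumptions, not part of the original snippet:

import tensorflow as tf

def square(x):
    return x * x

result, vjp = trace_grad(square, (tf.constant(3.0),))
print(result)                    # 9.0
print(vjp(tf.constant(1.0))[0])  # 6.0, i.e. d(x*x)/dx at x == 3.0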