Example #1
 def testVariableReadInFunction(self):
   v = variables.Variable(1.)
   with forwardprop.ForwardAccumulator(v, 11.) as acc:
     @def_function.function
     def f():
       return v.read_value(), 2. * v.read_value()
     result = f()
     self.assertAllClose((1.0, 2.), result)
     self.assertAllClose((11., 22.), acc.jvp(result))
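These snippets are drawn from TensorFlow's forward-mode autodiff tests and use TF-internal modules rather than the public API. A sketch of the imports they assume (exact module paths can vary between TensorFlow versions):

 import weakref

 from tensorflow.python import pywrap_tfe
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import forwardprop
 from tensorflow.python.eager import forwardprop_util
 from tensorflow.python.eager import tape as tape_lib
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients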
Example #2
 def testVariableWatched(self):
   v = variables.Variable([1., 2., 3.])
   with forwardprop.ForwardAccumulator(
       v, constant_op.constant([.1, -.2, .3])) as acc:
     self.assertAllClose([.1, -.2, .3], acc.jvp(v))
     x = v * 2.
     self.assertAllClose([.2, -.4, .6], acc.jvp(x))
     x2 = v + .1
     self.assertAllClose([.1, -.2, .3], acc.jvp(x2))
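The accumulator used above is exposed publicly as tf.autodiff.ForwardAccumulator; a minimal standalone sketch of the same watch-and-read pattern:

 import tensorflow as tf

 v = tf.Variable([1., 2., 3.])
 with tf.autodiff.ForwardAccumulator(
     primals=v, tangents=tf.constant([.1, -.2, .3])) as acc:
   y = v * 2.
 # acc.jvp(y) is the Jacobian of y w.r.t. v applied to the tangent vector.
 print(acc.jvp(y))  # tf.Tensor([ 0.2 -0.4  0.6], ...)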
Example #3
  def testGradPureForward(self, decorator):
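    # `decorator` is supplied by the test's parameterization; think
    # def_function.function or a pass-through lambda.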

    @decorator
    def f(x):
      return x ** 3.5

    primal = constant_op.constant(1.1)
    with forwardprop.ForwardAccumulator(
        primal, constant_op.constant(1.)) as outer_acc:
      with forwardprop.ForwardAccumulator(
          primal, constant_op.constant(1.)) as acc:
        primal_out = f(primal)
    inner_jvp = acc.jvp(primal_out)
    outer_jvp = outer_acc.jvp(inner_jvp)
    self.assertAllClose(1.1 ** 3.5, primal_out)
    self.assertAllClose(3.5 * 1.1 ** 2.5, inner_jvp)
    self.assertAllClose(3.5 * 2.5 * 1.1 ** 1.5, outer_jvp)
    self.assertIsNone(acc.jvp(outer_acc.jvp(primal_out)))
Example #4
  def testIndexSlicesGradInFunction(self):
    @def_function.function
    def f(a):
      return array_ops.gather(a, 0)

    x = constant_op.constant([1.])

    with forwardprop.ForwardAccumulator(x, constant_op.constant([3.])) as acc:
      y = f(x)
    self.assertAllClose(3., acc.jvp(y))
Example #5
 def testRunFunctionsEagerly(self):
   try:
     original_setting = def_function.functions_run_eagerly()
     def_function.run_functions_eagerly(True)
     x = constant_op.constant(1.)
     with forwardprop.ForwardAccumulator(x, 2.) as acc:
       y = x * 3.
     self.assertAllClose(6., acc.jvp(y))
   finally:
     def_function.run_functions_eagerly(original_setting)
Example #6
 def testFunctionCacheLimited(self):
      # Each iteration of this loop creates a slightly larger Tensor and
      # pushes it through Add's gradient.
      # We run _TRACE_COUNT_LIMIT * 2 iterations so the trace cache is
      # exercised with experimental_relax_shapes both on and off.
     for execution_count in range(forwardprop._TRACE_COUNT_LIMIT * 2):
         x = array_ops.zeros([execution_count])
         with forwardprop.ForwardAccumulator(x,
                                             array_ops.ones_like(x)) as acc:
             y = x + x
         self.assertAllClose(2. * array_ops.ones_like(x), acc.jvp(y))
Example #7
 def testMultipleWatchesAdd(self):
     x = constant_op.constant(-2.)
     with forwardprop.ForwardAccumulator([x, x], [1., 2.]) as acc:
         self.assertAllClose(3., acc.jvp(x))
         acc._watch(x, constant_op.constant(10.))
         self.assertAllClose(13., acc.jvp(x))
         acc._watch(x, constant_op.constant(11.))
         self.assertAllClose(24., acc.jvp(x))
         y = constant_op.constant(3.) * x
     self.assertAllClose(24., acc.jvp(x))
     self.assertAllClose(24. * 3., acc.jvp(y))
Example #8
 def testRecordingWithJVPIndices(self):
     c = constant_op.constant(1.)
     with forwardprop.ForwardAccumulator(c, 10.) as acc:
          packed_input_tangents = forwardprop_util.pack_tangents([c]).tangents
         self.assertAllClose([10.], packed_input_tangents)
         d = constant_op.constant(2.)
         d_tangent = constant_op.constant(3.)
         tape_lib.record_operation_forwardprop_only(
             "FunctionWithInlineJVPs", [d] + [d_tangent],
             [c] + packed_input_tangents, None, (((0, 1), ), ))
         self.assertAllClose(3., acc.jvp(d))
Example #9
    def testFunctionReturnsResource(self):
        v = variables.Variable([[1.]])
        x = constant_op.constant(1.)
        xt = constant_op.constant(2.)

        @def_function.function
        def f(a):
            return a, v.handle

        with forwardprop.ForwardAccumulator(x, xt) as acc:
            y, _ = f(x)
        self.assertAllClose(2., acc.jvp(y))
Example #10
  def testArgumentUnused(self):
    v = constant_op.constant(1.)
    with forwardprop.ForwardAccumulator(v, 11.) as acc:

      @def_function.function
      def _f(x):
        del x
        return constant_op.constant(1.)

      result = _f(v)
      self.assertAllClose(1.0, result)
      self.assertIsNone(acc.jvp(result))
Example #11
 def testSpecialForwardFunctionUsed(self):
     c = constant_op.constant(1.)
     d = constant_op.constant(2.)
     e = constant_op.constant(3.)
     with forwardprop.ForwardAccumulator(c, 10.) as acc:
         tape_lib.record_operation("ForwardIsSpecial", [d], [c], None,
                                   lambda jvp: [-2. * jvp])
         self.assertAllClose(-20., acc.jvp(d))
         tape_lib.record_operation("ForwardIsSpecial2", [], [], None,
                                   lambda: [])
         tape_lib.record_operation("ForwardIsSpecial3", [e], [d], None,
                                   lambda x: [x])
         self.assertAllClose(-20., acc.jvp(e))
Example #12
 def testFunctionCacheLimited(self):
   # Every time this test is executed, it will create a slightly larger Tensor
   # and push it through Add's gradient. Since we check for new pyobjects after
   # the warmup, retracing each time without cleaning up old traces fails the
   # test. It works because of experimental_relax_shapes.
   for _ in range(forwardprop._TRACE_COUNT_LIMIT):
     execution_count = getattr(self, "_execution_count", 0)
     self._execution_count = execution_count + 1
     x = array_ops.zeros([execution_count])
     with forwardprop.ForwardAccumulator(
         x, array_ops.ones_like(x)) as acc:
       y = x + x
     self.assertAllClose(2. * array_ops.ones_like(x), acc.jvp(y))
Example #13
  def testExceptionInCustomGradientNotSwallowed(self):

    @custom_gradient.custom_gradient
    def f(unused_x):
      def grad(unused_dy):
        raise ValueError("test_error_string")
      return 1., grad

    c = constant_op.constant(1.)
    d = constant_op.constant(2.)
    with forwardprop.ForwardAccumulator(c, d):
      with self.assertRaisesRegex(ValueError, "test_error_string"):
        f(c)
Example #14
 def testJVPFunctionUsedByAccumulatorForOps(self):
   previous_fn = forwardprop._jvp_dispatch
   try:
     x = constant_op.constant(1.)
     with forwardprop.ForwardAccumulator(x, 2.) as acc:
       y = x + x
       pywrap_tfe.TFE_Py_RegisterJVPFunction(
           lambda *args, **kwargs: [constant_op.constant(-15.)])
       z = x + x
     self.assertAllClose(4., acc.jvp(y))
     self.assertAllClose(-15., acc.jvp(z))
   finally:
     pywrap_tfe.TFE_Py_RegisterJVPFunction(previous_fn)
Example #15
 def testReenter(self):
   x = constant_op.constant(-2.)
   with forwardprop.ForwardAccumulator(x, 1.5) as acc:
     self.assertAllClose(1.5, acc.jvp(x))
     y = 4. * x
     self.assertAllClose(6., acc.jvp(y))
     with self.assertRaisesRegex(ValueError, "already recording"):
       with acc:
         pass
   z = 4. * x
   self.assertIsNone(acc.jvp(z))
   with acc:
     yy = y * y
   self.assertAllClose(6. * -8. * 2., acc.jvp(yy))
Example #16
  def testReusingJVP(self):
    m1 = random_ops.random_uniform((256, 2096))
    m2 = array_ops.identity(m1)
    tangent1 = random_ops.random_uniform((256, 2096))
    tangent2 = random_ops.random_uniform((256, 2096))
    matmul = def_function.function(math_ops.matmul)

    with forwardprop.ForwardAccumulator(
        primals=[m1, m2], tangents=[tangent1, tangent2]) as acc:
      result1 = matmul(m1, m1, transpose_b=True)
      result2 = matmul(m2, m2, transpose_b=True)

    def _expected(mat, tangent):
      return (math_ops.matmul(tangent, mat, transpose_b=True)
              + math_ops.matmul(mat, tangent, transpose_b=True))

    self.assertAllClose(result1, result2)
    self.assertAllClose(_expected(m1, tangent1), acc.jvp(result1))
    self.assertAllClose(_expected(m2, tangent2), acc.jvp(result2))
Example #17
 def testDeadTensorsJVPCleared(self):
   x = array_ops.ones([100])
   x_weak = weakref.ref(x)
   grad_tensor = constant_op.constant(array_ops.zeros([100]))
   grad_tensor_weak = weakref.ref(grad_tensor)
   with forwardprop.ForwardAccumulator(x, grad_tensor) as acc:
     derived_tensor = constant_op.constant(2.) * x
     del grad_tensor
     self.assertAllClose(array_ops.zeros([100]), acc.jvp(x))
     del x
     self.assertIsNone(x_weak())
     self.assertIsNone(grad_tensor_weak())
     derived_tensor_weak = weakref.ref(derived_tensor)
     derived_tensor_grad = acc.jvp(derived_tensor)
     derived_tensor_grad_weak = weakref.ref(derived_tensor_grad)
     del derived_tensor
     del derived_tensor_grad
     self.assertIsNone(derived_tensor_weak())
     self.assertIsNone(derived_tensor_grad_weak())
Example #18
 def testShouldRecordAndStopRecord(self):
   c = constant_op.constant(1.)
   c_tangent = constant_op.constant(2.)
   with forwardprop.ForwardAccumulator(c, c_tangent) as acc:
     with backprop.GradientTape() as tape:
       self.assertFalse(tape_lib.should_record_backprop([c]))
       self.assertEqual(1,
                        pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
       tape.watch(c)
       self.assertEqual(2,
                        pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
       self.assertTrue(tape_lib.should_record_backprop([c]))
       with tape_lib.stop_recording():
         self.assertEqual(0,
                          pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
         self.assertFalse(tape_lib.should_record_backprop([c]))
         d = c * 2.
       self.assertEqual(2,
                        pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
       self.assertTrue(tape_lib.should_record_backprop([c]))
       self.assertFalse(tape_lib.should_record_backprop([d]))
       self.assertIsNone(acc.jvp(d))
     self.assertIsNone(tape.gradient(d, c))
Example #19
  def testPushPopAccumulatorState(self):
    # Note that this example is somewhat contrived. push_forwardprop_state is
    # probably only useful in practice for building functions that compute jvps
    # alongside their usual outputs.
    c = constant_op.constant(1.)
    d = constant_op.constant(2.)
    with forwardprop.ForwardAccumulator(c, d) as acc:

      @custom_gradient.custom_gradient
      def f(x):
        y = math_ops.sin(x.numpy())

        def grad(dy):
          with forwardprop_util.push_forwardprop_state():
            x_copy = constant_op.constant(x.numpy())
            acc._watch(x_copy, dy)
            y_copy = math_ops.sin(x_copy)
          return dy * acc.jvp(y_copy)

        return y, grad

      output = f(c)
      self.assertAllClose(d * math_ops.cos(c), acc.jvp(output))
Example #20
def _jvp(f, primals, tangents):
  """Compute the jacobian of `f` at `primals` multiplied by `tangents`."""
  with forwardprop.ForwardAccumulator(primals, tangents) as acc:
    primals_out = f(*primals)
  return primals_out, acc.jvp(primals_out)
Example #21
 def compiled_function(x, tangent):
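     # Assumed context from the enclosing code: `matmul` is a compiled
     # matmul, e.g. matmul = def_function.function(math_ops.matmul).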
     with forwardprop.ForwardAccumulator(x, tangent) as acc:
         result = matmul(x, x, transpose_b=True)
     return result, acc.jvp(result)
Example #22
  def testIndexSlicesGrad(self):
    x = constant_op.constant([1.])

    with forwardprop.ForwardAccumulator(x, constant_op.constant([3.])) as acc:
      y = array_ops.gather(x, 0)
    self.assertAllClose(3., acc.jvp(y))
Example #23
 def testOpWithNoTrainableOutputs(self):
   v = variables.Variable(1.)
   with forwardprop.ForwardAccumulator(v, 11.):
     v.assign_sub(0.5)
     self.assertAllClose(0.5, self.evaluate(v))
Example #24
 def _compute_forwardgrad(primal):
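   # `f` (the function being differentiated) is captured from the
   # enclosing scope.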
   tangent = constant_op.constant(1.)
   with forwardprop.ForwardAccumulator(primal, tangent) as acc:
     primal_out = f(primal)
   return acc.jvp(primal_out)
Example #25
 def func():
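     # `m`, `tangent`, and `matmul` are captured from the enclosing scope;
     # e.g. matmul = def_function.function(math_ops.matmul).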
     with forwardprop.ForwardAccumulator(m, tangent) as acc:
         result = matmul(m, m, transpose_b=True)
     return result, acc.jvp(result)
Example #26
def _fprop_while(iters, y):
  with forwardprop.ForwardAccumulator(y, 1.) as acc:
    ret = 0.
    for i in math_ops.range(iters):
      ret += y * math_ops.cast(i, dtypes.float32)
  return acc.jvp(ret)
Example #27
def _jvp(f, primals, tangents):
  """Compute the jacobian of `f` at `primals` multiplied by `tangents`."""
  with forwardprop.ForwardAccumulator(primals, tangents) as acc:
    primals_out = f(*primals)
  return primals_out, acc.jvp(
      primals_out, unconnected_gradients=UnconnectedGradients.ZERO)
Example #28
 def testOfFunctionWhile(self):
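   # `_has_loop` is assumed to be a def_function.function along the lines of
   # _fprop_while in Example #26: it accumulates y * i for i in range(iters),
   # so its JVP w.r.t. y at iters=5 is 0+1+2+3+4 = 10.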
   y = constant_op.constant(1.)
   with forwardprop.ForwardAccumulator(y, 1.) as acc:
     self.assertAllClose(
         10., acc.jvp(_has_loop(constant_op.constant(5), y)))
Example #29
 def _single_jvp(param_mask):
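   # `f`, `params`, and `argnums` come from the enclosing forward-mode
   # gradient helper; the JVP is taken w.r.t. params[argnums] only.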
   with forwardprop.ForwardAccumulator(primals=[params[argnums]],
                                       tangents=param_mask) as acc:
     primals_out = f(*params)
   return acc.jvp(primals_out)