def testVariableReadInFunction(self):
  v = variables.Variable(1.)
  with forwardprop.ForwardAccumulator(v, 11.) as acc:

    @def_function.function
    def f():
      return v.read_value(), 2. * v.read_value()

    result = f()
    self.assertAllClose((1.0, 2.), result)
    self.assertAllClose((11., 22.), acc.jvp(result))
def testVariableWatched(self):
  v = variables.Variable([1., 2., 3.])
  with forwardprop.ForwardAccumulator(
      v, constant_op.constant([.1, -.2, .3])) as acc:
    self.assertAllClose([.1, -.2, .3], acc.jvp(v))
    x = v * 2.
    self.assertAllClose([.2, -.4, .6], acc.jvp(x))
    x2 = v + .1
    self.assertAllClose([.1, -.2, .3], acc.jvp(x2))
def testGradPureForward(self, decorator):

  @decorator
  def f(x):
    return x ** 3.5

  primal = constant_op.constant(1.1)
  with forwardprop.ForwardAccumulator(
      primal, constant_op.constant(1.)) as outer_acc:
    with forwardprop.ForwardAccumulator(
        primal, constant_op.constant(1.)) as acc:
      primal_out = f(primal)
      inner_jvp = acc.jvp(primal_out)
      outer_jvp = outer_acc.jvp(inner_jvp)
  self.assertAllClose(1.1 ** 3.5, primal_out)
  self.assertAllClose(3.5 * 1.1 ** 2.5, inner_jvp)
  self.assertAllClose(3.5 * 2.5 * 1.1 ** 1.5, outer_jvp)
  self.assertIsNone(acc.jvp(outer_acc.jvp(primal_out)))
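# A minimal standalone sketch of the nesting pattern exercised above, using
# only modules already imported in this file: stacking two
# ForwardAccumulators over the same primal recovers the second derivative.
# `_second_order_sketch` is a hypothetical helper, not part of the suite.
def _second_order_sketch():
  primal = constant_op.constant(2.)
  with forwardprop.ForwardAccumulator(
      primal, constant_op.constant(1.)) as outer:
    with forwardprop.ForwardAccumulator(
        primal, constant_op.constant(1.)) as inner:
      y = primal ** 3.
      dy = inner.jvp(y)    # 3 * primal**2 == 12.
      d2y = outer.jvp(dy)  # 6 * primal == 12.
  return dy, d2y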
def testIndexSlicesGradInFunction(self):

  @def_function.function
  def f(a):
    return array_ops.gather(a, 0)

  x = constant_op.constant([1.])
  with forwardprop.ForwardAccumulator(x, constant_op.constant([3.])) as acc:
    y = f(x)
  self.assertAllClose(3., acc.jvp(y))
def testRunFunctionsEagerly(self):
  try:
    original_setting = def_function.functions_run_eagerly()
    def_function.run_functions_eagerly(True)
    x = constant_op.constant(1.)
    with forwardprop.ForwardAccumulator(x, 2.) as acc:
      y = x * 3.
    self.assertAllClose(6., acc.jvp(y))
  finally:
    def_function.run_functions_eagerly(original_setting)
def testFunctionCacheLimited(self):
  # Every time this loop is executed, it will create a slightly larger Tensor
  # and push it through Add's gradient.
  # We run TRACE_COUNT_LIMIT x 2 so that it is tested with both
  # experimental_relax_shapes on and off.
  for execution_count in range(forwardprop._TRACE_COUNT_LIMIT * 2):
    x = array_ops.zeros([execution_count])
    with forwardprop.ForwardAccumulator(x, array_ops.ones_like(x)) as acc:
      y = x + x
    self.assertAllClose(2. * array_ops.ones_like(x), acc.jvp(y))
def testMultipleWatchesAdd(self):
  x = constant_op.constant(-2.)
  # Watching the same primal more than once sums the tangents: 1. + 2. == 3.
  with forwardprop.ForwardAccumulator([x, x], [1., 2.]) as acc:
    self.assertAllClose(3., acc.jvp(x))
    acc._watch(x, constant_op.constant(10.))
    self.assertAllClose(13., acc.jvp(x))
    acc._watch(x, constant_op.constant(11.))
    self.assertAllClose(24., acc.jvp(x))
    y = constant_op.constant(3.) * x
  self.assertAllClose(24., acc.jvp(x))
  self.assertAllClose(24. * 3., acc.jvp(y))
def testRecordingWithJVPIndices(self):
  c = constant_op.constant(1.)
  with forwardprop.ForwardAccumulator(c, 10.) as acc:
    packed_input_tangents = forwardprop_util.pack_tangents([c]).tangents
    self.assertAllClose([10.], packed_input_tangents)
    d = constant_op.constant(2.)
    d_tangent = constant_op.constant(3.)
    tape_lib.record_operation_forwardprop_only(
        "FunctionWithInlineJVPs",
        [d] + [d_tangent],
        [c] + packed_input_tangents,
        None, (((0, 1),),))
    self.assertAllClose(3., acc.jvp(d))
def testFunctionReturnsResource(self):
  v = variables.Variable([[1.]])
  x = constant_op.constant(1.)
  xt = constant_op.constant(2.)

  @def_function.function
  def f(a):
    return a, v.handle

  with forwardprop.ForwardAccumulator(x, xt) as acc:
    y, _ = f(x)
  self.assertAllClose(2., acc.jvp(y))
def testArgumentUnused(self):
  v = constant_op.constant(1.)
  with forwardprop.ForwardAccumulator(v, 11.) as acc:

    @def_function.function
    def _f(x):
      del x
      return constant_op.constant(1.)

    result = _f(v)
    self.assertAllClose(1.0, result)
    self.assertIsNone(acc.jvp(result))
def testSpecialForwardFunctionUsed(self):
  c = constant_op.constant(1.)
  d = constant_op.constant(2.)
  e = constant_op.constant(3.)
  with forwardprop.ForwardAccumulator(c, 10.) as acc:
    tape_lib.record_operation(
        "ForwardIsSpecial", [d], [c], None, lambda jvp: [-2. * jvp])
    self.assertAllClose(-20., acc.jvp(d))
    tape_lib.record_operation("ForwardIsSpecial2", [], [], None, lambda: [])
    tape_lib.record_operation(
        "ForwardIsSpecial3", [e], [d], None, lambda x: [x])
    self.assertAllClose(-20., acc.jvp(e))
def testFunctionCacheLimited(self):
  # Every time this test is executed, it will create a slightly larger Tensor
  # and push it through Add's gradient. Since we check for new pyobjects after
  # the warmup, retracing each time without cleaning up old traces fails the
  # test. It works because of experimental_relax_shapes.
  for _ in range(forwardprop._TRACE_COUNT_LIMIT):
    execution_count = getattr(self, "_execution_count", 0)
    self._execution_count = execution_count + 1
    x = array_ops.zeros([execution_count])
    with forwardprop.ForwardAccumulator(x, array_ops.ones_like(x)) as acc:
      y = x + x
    self.assertAllClose(2. * array_ops.ones_like(x), acc.jvp(y))
def testExceptionInCustomGradientNotSwallowed(self):

  @custom_gradient.custom_gradient
  def f(unused_x):

    def grad(unused_dy):
      raise ValueError("test_error_string")

    return 1., grad

  c = constant_op.constant(1.)
  d = constant_op.constant(2.)
  with forwardprop.ForwardAccumulator(c, d):
    with self.assertRaisesRegex(ValueError, "test_error_string"):
      f(c)
def testJVPFunctionUsedByAccumulatorForOps(self):
  previous_fn = forwardprop._jvp_dispatch
  try:
    x = constant_op.constant(1.)
    with forwardprop.ForwardAccumulator(x, 2.) as acc:
      y = x + x
      pywrap_tfe.TFE_Py_RegisterJVPFunction(
          lambda *args, **kwargs: [constant_op.constant(-15.)])
      z = x + x
    self.assertAllClose(4., acc.jvp(y))
    self.assertAllClose(-15., acc.jvp(z))
  finally:
    pywrap_tfe.TFE_Py_RegisterJVPFunction(previous_fn)
def testReenter(self):
  x = constant_op.constant(-2.)
  with forwardprop.ForwardAccumulator(x, 1.5) as acc:
    self.assertAllClose(1.5, acc.jvp(x))
    y = 4. * x
    self.assertAllClose(6., acc.jvp(y))
    with self.assertRaisesRegex(ValueError, "already recording"):
      with acc:
        pass
  z = 4. * x
  self.assertIsNone(acc.jvp(z))
  with acc:
    yy = y * y
  # d(yy)/dx = 2 * y * d(y)/dx = 2 * (-8.) * 6.
  self.assertAllClose(6. * -8. * 2., acc.jvp(yy))
def testReusingJVP(self):
  m1 = random_ops.random_uniform((256, 2096))
  m2 = array_ops.identity(m1)
  tangent1 = random_ops.random_uniform((256, 2096))
  tangent2 = random_ops.random_uniform((256, 2096))
  matmul = def_function.function(math_ops.matmul)
  with forwardprop.ForwardAccumulator(
      primals=[m1, m2], tangents=[tangent1, tangent2]) as acc:
    result1 = matmul(m1, m1, transpose_b=True)
    result2 = matmul(m2, m2, transpose_b=True)

  def _expected(mat, tangent):
    return (math_ops.matmul(tangent, mat, transpose_b=True) +
            math_ops.matmul(mat, tangent, transpose_b=True))

  self.assertAllClose(result1, result2)
  self.assertAllClose(_expected(m1, tangent1), acc.jvp(result1))
  self.assertAllClose(_expected(m2, tangent2), acc.jvp(result2))
def testDeadTensorsJVPCleared(self):
  x = array_ops.ones([100])
  x_weak = weakref.ref(x)
  grad_tensor = constant_op.constant(array_ops.zeros([100]))
  grad_tensor_weak = weakref.ref(grad_tensor)
  with forwardprop.ForwardAccumulator(x, grad_tensor) as acc:
    derived_tensor = constant_op.constant(2.) * x
    del grad_tensor
    self.assertAllClose(array_ops.zeros([100]), acc.jvp(x))
    del x
    self.assertIsNone(x_weak())
    self.assertIsNone(grad_tensor_weak())
    derived_tensor_weak = weakref.ref(derived_tensor)
    derived_tensor_grad = acc.jvp(derived_tensor)
    derived_tensor_grad_weak = weakref.ref(derived_tensor_grad)
    del derived_tensor
    del derived_tensor_grad
    self.assertIsNone(derived_tensor_weak())
    self.assertIsNone(derived_tensor_grad_weak())
def testShouldRecordAndStopRecord(self):
  c = constant_op.constant(1.)
  c_tangent = constant_op.constant(2.)
  with forwardprop.ForwardAccumulator(c, c_tangent) as acc:
    with backprop.GradientTape() as tape:
      self.assertFalse(tape_lib.should_record_backprop([c]))
      self.assertEqual(1,
                       pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
      tape.watch(c)
      self.assertEqual(2,
                       pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
      self.assertTrue(tape_lib.should_record_backprop([c]))
      with tape_lib.stop_recording():
        self.assertEqual(
            0, pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
        self.assertFalse(tape_lib.should_record_backprop([c]))
        d = c * 2.
      self.assertEqual(2,
                       pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([c]))
      self.assertTrue(tape_lib.should_record_backprop([c]))
      self.assertFalse(tape_lib.should_record_backprop([d]))
      self.assertIsNone(acc.jvp(d))
  self.assertIsNone(tape.gradient(d, c))
def testPushPopAccumulatorState(self):
  # Note that this example is somewhat contrived. push_forwardprop_state is
  # probably only useful in practice for building functions that compute jvps
  # alongside their usual outputs.
  c = constant_op.constant(1.)
  d = constant_op.constant(2.)
  with forwardprop.ForwardAccumulator(c, d) as acc:

    @custom_gradient.custom_gradient
    def f(x):
      y = math_ops.sin(x.numpy())

      def grad(dy):
        with forwardprop_util.push_forwardprop_state():
          x_copy = constant_op.constant(x.numpy())
          acc._watch(x_copy, dy)
          y_copy = math_ops.sin(x_copy)
        return dy * acc.jvp(y_copy)

      return y, grad

    output = f(c)
    self.assertAllClose(d * math_ops.cos(c), acc.jvp(output))
def _jvp(f, primals, tangents):
  """Compute the jacobian of `f` at `primals` multiplied by `tangents`."""
  with forwardprop.ForwardAccumulator(primals, tangents) as acc:
    primals_out = f(*primals)
  return primals_out, acc.jvp(primals_out.data)
def compiled_function(x, tangent):
  with forwardprop.ForwardAccumulator(x, tangent) as acc:
    result = matmul(x, x, transpose_b=True)
  return result, acc.jvp(result)
def testIndexSlicesGrad(self):
  x = constant_op.constant([1.])
  with forwardprop.ForwardAccumulator(x, constant_op.constant([3.])) as acc:
    y = array_ops.gather(x, 0)
  self.assertAllClose(3., acc.jvp(y))
def testOpWithNoTrainableOutputs(self):
  v = variables.Variable(1.)
  with forwardprop.ForwardAccumulator(v, 11.):
    v.assign_sub(0.5)
    self.assertAllClose(0.5, self.evaluate(v))
def _compute_forwardgrad(primal):
  tangent = constant_op.constant(1.)
  with forwardprop.ForwardAccumulator(primal, tangent) as acc:
    primal_out = f(primal)
  return acc.jvp(primal_out)
def func():
  with forwardprop.ForwardAccumulator(m, tangent) as acc:
    result = matmul(m, m, transpose_b=True)
  return result, acc.jvp(result)
def _fprop_while(iters, y):
  with forwardprop.ForwardAccumulator(y, 1.) as acc:
    ret = 0.
    for i in math_ops.range(iters):
      ret += y * math_ops.cast(i, dtypes.float32)
  return acc.jvp(ret)
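# A hedged usage sketch for `_fprop_while` above: with a unit tangent on `y`,
# the accumulated JVP is d(ret)/dy = 0 + 1 + ... + (iters - 1), so five
# iterations give 10. `_fprop_while_example` is a hypothetical helper and
# assumes eager execution.
def _fprop_while_example():
  jvp = _fprop_while(constant_op.constant(5), constant_op.constant(1.))
  # jvp == sum(range(5)) == 10.
  return jvp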
def _jvp(f, primals, tangents):
  """Compute the jacobian of `f` at `primals` multiplied by `tangents`."""
  with forwardprop.ForwardAccumulator(primals, tangents) as acc:
    primals_out = f(*primals)
  return primals_out, acc.jvp(
      primals_out, unconnected_gradients=UnconnectedGradients.ZERO)
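# An illustrative call site for the `_jvp` helper above (a sketch; the choice
# of `math_ops.sin` and the helper name are assumptions, not part of the
# original suite).
def _jvp_sin_example():
  # sin evaluated at 0.5 with a unit tangent: the JVP equals cos(0.5).
  primal_out, jvp_out = _jvp(math_ops.sin,
                             [constant_op.constant(0.5)],
                             [constant_op.constant(1.)])
  return primal_out, jvp_out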
def testOfFunctionWhile(self):
  y = constant_op.constant(1.)
  with forwardprop.ForwardAccumulator(y, 1.) as acc:
    self.assertAllClose(
        10., acc.jvp(_has_loop(constant_op.constant(5), y)))
def _single_jvp(param_mask):
  with forwardprop.ForwardAccumulator(
      primals=[params[argnums]], tangents=param_mask) as acc:
    primals_out = f(*params)
  return acc.jvp(primals_out)