Exemplo n.º 1
0
 def testRecordingSelectively(self):
   """Ops re-recorded for one recorder are invisible to the other.

   Two products of `primal` are computed while recording is paused, so at
   first neither the ForwardAccumulator nor the persistent GradientTape
   knows about them. `d` is then registered with
   `record_operation_forwardprop_only` and `e` with
   `record_operation_backprop_only`. Afterwards only the accumulator has a
   JVP for `d` and only the tape has a gradient for `e`.
   """
   primal = constant_op.constant(1.)
   primal_tangent = constant_op.constant(2.)
   with forwardprop.ForwardAccumulator(primal, primal_tangent) as acc:
     with backprop.GradientTape(persistent=True) as tape:
       tape.watch(primal)
       # Products built here are checked below to be unknown to both
       # `acc` and `tape`.
       with tape_lib.stop_recording():
         factor_d = constant_op.constant(2.)
         d = primal * factor_d
         factor_e = constant_op.constant(3.)
         e = primal * factor_e
       for unrecorded in (d, e):
         self.assertIsNone(acc.jvp(unrecorded))
       for unrecorded in (d, e):
         self.assertIsNone(tape.gradient(unrecorded, primal))

       def d_backward(grad_d):
         # d = primal * factor_d: partials are (factor_d, primal).
         return (factor_d * grad_d, primal * grad_d)

       def e_backward(grad_e):
         # e = primal * factor_e: partials are (factor_e, primal).
         return (factor_e * grad_e, primal * grad_e)

       # `d` goes to the forward accumulator only.
       tape_lib.record_operation_forwardprop_only(
           "CustomForwardMul", [d], [primal, factor_d], d_backward, None)
       # `e` goes to the gradient tape only.
       tape_lib.record_operation_backprop_only(
           "CustomBackwardMul", [e], [primal, factor_e], e_backward)
       # JVP of d = primal_tangent * factor_d = 2 * 2.
       self.assertAllClose(4., acc.jvp(d))
       self.assertIsNone(acc.jvp(e))
       self.assertIsNone(tape.gradient(d, primal))
       # Gradient of e with respect to primal = factor_e = 3.
       self.assertAllClose(3., tape.gradient(e, primal))
Exemplo n.º 2
0
 def testRecordingWithJVPIndices(self):
   """A forward-only recorded op can supply its own JVP as an extra output.

   The op is recorded with outputs `[d, d_tangent]` and with the packed input
   tangents appended to its inputs. The index pair (0, 1) pairs output 0
   (`d`) with output 1 (`d_tangent`). The accumulator is then expected to
   report `d_tangent` (3.) as the JVP of `d`.
   """
   primal = constant_op.constant(1.)
   with forwardprop.ForwardAccumulator(primal, 10.) as acc:
     packed = forwardprop_util.pack_tangents([primal])
     input_tangents = packed.tangents
     self.assertAllClose([10.], input_tangents)
     d = constant_op.constant(2.)
     d_tangent = constant_op.constant(3.)
     jvp_output_indices = ((0, 1),)
     tape_lib.record_operation_forwardprop_only(
         "FunctionWithInlineJVPs",
         [d, d_tangent],
         [primal] + input_tangents,
         None,
         jvp_output_indices)
     self.assertAllClose(3., acc.jvp(d))
Exemplo n.º 3
0
 def testRecordingWithJVPIndices(self):
   """A forward-only recorded op can supply its own JVP as an extra output.

   This variant uses `ForwardGradientAccumulator.watch` and
   `TFE_Py_PackForwardGradients`. The op is recorded with outputs
   `[d, d_tangent]` plus the packed input tangents as extra inputs. The index
   pair (0, 1) pairs output 0 (`d`) with output 1 (`d_tangent`), and the
   accumulator is expected to report 3. as the JVP of `d`.
   """
   with forwardprop.ForwardGradientAccumulator() as acc:
     primal = constant_op.constant(1.)
     acc.watch(primal, 10.)
     # Only the second element (the packed tangents) is used.
     _, input_tangents = pywrap_tensorflow.TFE_Py_PackForwardGradients(
         [primal])
     self.assertAllClose([10.], input_tangents)
     d = constant_op.constant(2.)
     d_tangent = constant_op.constant(3.)
     jvp_output_indices = ((0, 1),)
     tape_lib.record_operation_forwardprop_only(
         "FunctionWithInlineJVPs",
         [d, d_tangent],
         [primal] + input_tangents,
         None,
         jvp_output_indices)
     self.assertAllClose(3., acc.jvp(d))