Esempio n. 1
0
 def testShouldRecordAndStopRecord(self):
   # Walks one tensor through three phases -- unwatched by the tape, watched,
   # and inside stop_recording() -- checking both tape queries at each step.
   primal = constant_op.constant(1.)
   tangent = constant_op.constant(2.)

   def gradient_types():
     # Raw code reported for `primal` by the active tape set.
     return pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes([primal])

   def records(tensor):
     # Whether backprop recording is currently requested for `tensor`.
     return tape_lib.should_record_backprop([tensor])

   with forwardprop.ForwardAccumulator(primal, tangent) as accumulator:
     with backprop.GradientTape() as grad_tape:
       # The accumulator was given `primal`, but the tape is not watching it.
       self.assertFalse(records(primal))
       self.assertEqual(1, gradient_types())

       # Once the tape watches `primal`, backprop recording is requested too.
       grad_tape.watch(primal)
       self.assertEqual(2, gradient_types())
       self.assertTrue(records(primal))

       with tape_lib.stop_recording():
         # Both queries report "nothing to record" while recording is stopped.
         self.assertEqual(0, gradient_types())
         self.assertFalse(records(primal))
         doubled = primal * 2.

       # Leaving stop_recording() restores the previous answers for `primal`.
       self.assertEqual(2, gradient_types())
       self.assertTrue(records(primal))

       # `doubled` was produced while recording was stopped, so neither the
       # tape nor the accumulator has anything for it.
       self.assertFalse(records(doubled))
       self.assertIsNone(accumulator.jvp(doubled))
     self.assertIsNone(grad_tape.gradient(doubled, primal))
Esempio n. 2
0
def PossibleTapeGradientTypes(tensors):
    """Report whether and how `tensors` may require tape gradients.

    Thin wrapper: `tensors` is forwarded unchanged to
    `pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes` and that call's result
    is returned as-is.

    Args:
        tensors: The tensors to query; passed straight through.

    Returns:
        Whatever the underlying pywrap call returns (presumably an integer
        code distinguishing none / first-order / higher-order gradients --
        confirm against the pywrap_tfe bindings).
    """
    gradient_types = pywrap_tfe.TFE_Py_TapeSetPossibleGradientTypes(tensors)
    return gradient_types