def test_reentrant(self):
    """A forward call nested inside a backward call is recorded separately.

    ``g``'s forward runs inside ``f``'s backward, and 512 bytes are added to
    the memory counters both inside ``g`` and after ``g`` returns. The outer
    call ``f`` must therefore be recorded at depth 0 with more bytes than the
    nested call ``g`` at depth 1.

    In/grad data are random; they do not simulate cases that can actually
    occur.
    """
    f = self.f
    g = functions.Identity()  # any function other than f works (e.g. Exp)

    self.h.backward_preprocess(f, (self.x, ), (self.gy, ))
    self.h.forward_preprocess(g, (self.x, ))
    # Simulate allocations by bumping the wrapped memory hook's counters.
    self.h._memory_hook.used_bytes += 512
    self.h._memory_hook.acquired_bytes += 512
    self.h.forward_postprocess(g, (self.x, ))
    self.h._memory_hook.used_bytes += 512
    self.h._memory_hook.acquired_bytes += 512
    self.h.backward_postprocess(f, (self.x, ), (self.gy, ))

    # Each call_history entry unpacks as (name, used, acquired, depth); the
    # keys are compared against ``_impl_name`` below. The loop variable is
    # deliberately not named ``f`` so it does not shadow the outer ``f``.
    history = {
        name: (used, acquired, depth)
        for (name, used, acquired, depth) in self.h.call_history
    }
    self.assertEqual(len(history), 2)
    self.assertIn(f._impl_name, history)
    self.assertIn(g._impl_name, history)
    f_used_bytes, f_acquired_bytes, f_depth = history[f._impl_name]
    g_used_bytes, g_acquired_bytes, g_depth = history[g._impl_name]
    self.assertEqual(f_depth, 0)
    self.assertEqual(g_depth, 1)
    self.assertGreater(f_used_bytes, g_used_bytes)
    self.assertGreater(f_acquired_bytes, g_acquired_bytes)
def test_reentrant_total_time(self):
    """total_time() is bounded by timestamps taken around a nested call.

    A forward call of ``inner`` is made between the backward pre/post
    processing of ``self.f``. The hook's reported total time must not exceed
    the whole span (``end - begin``). It must also be at least the span
    around the nested call (``after_inner - before_inner``). The upper bound
    presumably guards against the nested call being double counted; confirm
    against the hook's implementation.
    """
    hook = self.h
    inner = functions.Identity()
    xs = (self.x, )
    gys = (self.gy, )

    begin = time.time()
    hook.backward_preprocess(self.f, xs, gys)
    before_inner = time.time()
    hook.forward_preprocess(inner, xs)
    time.sleep(0.001)
    hook.forward_postprocess(inner, xs)
    after_inner = time.time()
    hook.backward_postprocess(self.f, xs, gys)
    end = time.time()

    whole_span = end - begin
    inner_span = after_inner - before_inner
    self.assertLessEqual(hook.total_time(), whole_span)
    self.assertGreaterEqual(hook.total_time(), inner_span)
def test_reentrant_total_bytes(self):
    """Nested calls contribute exactly the bytes added to the counters.

    512 bytes are added while ``inner``'s forward is active, and another 512
    after it finishes but before ``outer``'s backward ends. The hook's totals
    must come to 1024 for both used and acquired bytes.
    """
    hook = self.h
    outer = self.f
    inner = functions.Identity()
    xs, gys = (self.x, ), (self.gy, )

    def allocate(nbytes):
        # Simulate an allocation by bumping the wrapped memory hook's
        # counters directly.
        hook._memory_hook.used_bytes += nbytes
        hook._memory_hook.acquired_bytes += nbytes

    hook.backward_preprocess(outer, xs, gys)
    hook.forward_preprocess(inner, xs)
    allocate(512)
    hook.forward_postprocess(inner, xs)
    allocate(512)
    hook.backward_postprocess(outer, xs, gys)

    self.assertEqual(hook.total_used_bytes(), 1024)
    self.assertEqual(hook.total_acquired_bytes(), 1024)
def test_reentrant(self):
    """Timer hook records an outer and a nested call with sensible times.

    ``inner``'s forward runs inside ``outer``'s backward. The nested call's
    recorded time must fit inside the measured window. The outer call's
    recorded time must cover that window, which includes a deliberate sleep.

    In/grad data are random; they do not simulate cases that can actually
    occur.
    """
    inner = functions.Identity()  # any function other than Exp works
    outer = self.f
    xs, gys = (self.x, ), (self.gy, )

    self.h.backward_preprocess(outer, xs, gys)
    window_start = time.time()
    time.sleep(0.001)  # longer than each hook call
    self.h.forward_preprocess(inner, xs)
    self.h.forward_postprocess(inner, xs)
    window_end = time.time()
    self.h.backward_postprocess(outer, xs, gys)

    window = window_end - window_start
    history = dict(self.h.call_history)
    self.assertEqual(len(history), 2)
    self.assertIn(outer, history)
    self.assertIn(inner, history)
    inner_time, outer_time = history[inner], history[outer]
    self.assertLessEqual(inner_time, window)
    self.assertGreaterEqual(outer_time, window)