def test_reentrant(self):
        """Check per-call memory stats when a forward call is nested in a backward.

        In/grad data are random; these do not simulate the actually possible
        cases.  ``f``'s backward hook brackets ``g``'s forward hook.  512 bytes
        are charged while ``g`` is running and another 512 after ``g`` returns
        but before ``f`` finishes.  The test expects ``f`` at depth 0, ``g`` at
        depth 1, and ``f``'s byte counts to exceed ``g``'s.
        """
        f = self.f
        # Any function whose ``_impl_name`` differs from f's works here: the
        # history below is keyed by that name, and two distinct entries are
        # required.
        g = functions.Identity()

        self.h.backward_preprocess(f, (self.x, ), (self.gy, ))
        self.h.forward_preprocess(g, (self.x, ))
        # Charge bytes directly on the hook's counters while g is active.
        self.h._memory_hook.used_bytes += 512
        self.h._memory_hook.acquired_bytes += 512
        self.h.forward_postprocess(g, (self.x, ))
        # Charge more bytes after g has finished but while f is still active.
        self.h._memory_hook.used_bytes += 512
        self.h._memory_hook.acquired_bytes += 512
        self.h.backward_postprocess(f, (self.x, ), (self.gy, ))

        # call_history records are unpacked here as
        # (name, used_bytes, acquired_bytes, depth); index them by name.
        # The loop variables are named distinctly so the comprehension does
        # not shadow the outer ``f``.
        history = {
            name: (used, acquired, depth)
            for name, used, acquired, depth in self.h.call_history
        }
        self.assertEqual(len(history), 2)
        self.assertIn(f._impl_name, history)
        self.assertIn(g._impl_name, history)
        f_used_bytes, f_acquired_bytes, f_depth = history[f._impl_name]
        g_used_bytes, g_acquired_bytes, g_depth = history[g._impl_name]
        self.assertEqual(f_depth, 0)
        self.assertEqual(g_depth, 1)
        self.assertGreater(f_used_bytes, g_used_bytes)
        self.assertGreater(f_acquired_bytes, g_acquired_bytes)
    # Example n. 2
    # 0
    def test_reentrant_total_time(self):
        """Check that total_time() stays within wall-clock bounds for a reentrant call.

        ``g``'s forward hooks run inside ``f``'s backward hooks.  The reported
        total must not exceed the wall time around the whole sequence, which
        presumably guards against the nested time being counted twice.  It must
        also be at least the wall time around the nested ``g`` call.
        """
        inner = functions.Identity()

        outer_begin = time.time()
        self.h.backward_preprocess(self.f, (self.x, ), (self.gy, ))
        inner_begin = time.time()
        self.h.forward_preprocess(inner, (self.x, ))
        time.sleep(0.001)
        self.h.forward_postprocess(inner, (self.x, ))
        inner_end = time.time()
        self.h.backward_postprocess(self.f, (self.x, ), (self.gy, ))
        outer_end = time.time()

        upper_bound = outer_end - outer_begin
        lower_bound = inner_end - inner_begin
        self.assertLessEqual(self.h.total_time(), upper_bound)
        self.assertGreaterEqual(self.h.total_time(), lower_bound)
    def test_reentrant_total_bytes(self):
        """Check the total byte counters when a forward call is nested in a backward.

        512 bytes are charged while the nested ``g`` is running and another
        512 after it returns but before ``f`` finishes.  Both totals are
        expected to be exactly 1024.
        """
        outer = self.f
        inner = functions.Identity()

        def charge(nbytes):
            # Bump both of the hook's byte counters directly, standing in for
            # real allocations.  Re-read ``self.h._memory_hook`` on each call.
            self.h._memory_hook.used_bytes += nbytes
            self.h._memory_hook.acquired_bytes += nbytes

        self.h.backward_preprocess(outer, (self.x, ), (self.gy, ))
        self.h.forward_preprocess(inner, (self.x, ))
        charge(512)
        self.h.forward_postprocess(inner, (self.x, ))
        charge(512)
        self.h.backward_postprocess(outer, (self.x, ), (self.gy, ))

        expected_total = 1024
        self.assertEqual(self.h.total_used_bytes(), expected_total)
        self.assertEqual(self.h.total_acquired_bytes(), expected_total)
    # Example n. 4
    # 0
    def test_reentrant(self):
        """Check per-call timing when ``g``'s forward runs inside ``f``'s backward.

        In/grad data are random; these do not simulate the actually possible
        cases.  ``g`` may be any function other than Exp (presumably the type
        of ``self.f``).
        """
        outer = self.f
        inner = functions.Identity()

        self.h.backward_preprocess(outer, (self.x, ), (self.gy, ))
        window_start = time.time()
        # Sleep longer than each hook call takes.
        time.sleep(0.001)
        self.h.forward_preprocess(inner, (self.x, ))
        self.h.forward_postprocess(inner, (self.x, ))
        window_end = time.time()
        self.h.backward_postprocess(outer, (self.x, ), (self.gy, ))
        window = window_end - window_start

        history = dict(self.h.call_history)
        self.assertEqual(len(history), 2)
        self.assertIn(outer, history)
        self.assertIn(inner, history)
        outer_time, inner_time = history[outer], history[inner]
        # g ran entirely inside the measured window, while f enclosed it.
        self.assertLessEqual(inner_time, window)
        self.assertGreaterEqual(outer_time, window)