def test_then(self):
    """Check that ``then`` chains a callback onto a future.

    The callback is invoked with the completed source future. The future
    returned by ``then`` holds the callback's return value.
    """
    base = Future()
    # The callback receives the completed future and reads its value via wait().
    plus_one = base.then(lambda completed: completed.wait() + 1)
    base.set_result(torch.ones(2, 2))

    expected = torch.ones(2, 2)
    # The source future still yields its original value...
    self.assertEqual(base.wait(), expected)
    # ...while the chained future yields the callback's result.
    self.assertEqual(plus_one.wait(), expected + 1)
def _test_error(self, cb, errMsg):
    """Check that a failing ``then`` callback surfaces on the chained future.

    Args:
        cb: callback attached to the source future with ``then``. It is
            expected to fail once the source future completes.
        errMsg: regex that the ``RuntimeError`` raised when waiting on the
            chained future must match.
    """
    source = Future()
    chained = source.then(cb)
    source.set_result(5)

    # The callback's failure must not affect the source future's own value.
    self.assertEqual(5, source.wait())

    # Waiting on the chained future must raise a RuntimeError matching errMsg.
    with self.assertRaisesRegex(RuntimeError, errMsg):
        chained.wait()
def test_wait_multi_thread(self):
    """Check that ``wait`` returns a value set from another thread.

    A worker thread sleeps before calling ``set_result``, so the main
    thread is likely already blocked in ``wait()`` when the value arrives.
    """
    def complete_later(target, payload):
        # Delay completion so that the main thread reaches wait() first.
        time.sleep(0.5)
        target.set_result(payload)

    pending = Future()
    worker = threading.Thread(
        target=complete_later,
        args=(pending, torch.ones(2, 2)),
    )
    worker.start()

    # Compare against a freshly built tensor, not the one handed to the worker.
    self.assertEqual(pending.wait(), torch.ones(2, 2))
    worker.join()
def test_wait(self):
    """Check that ``wait`` on an already-completed future returns its value."""
    completed = Future()
    completed.set_result(torch.ones(2, 2))

    result = completed.wait()
    self.assertEqual(result, torch.ones(2, 2))