def test_then(self): fut = Future() then_fut = fut.then(lambda x: x.wait() + 1) fut.set_result(torch.ones(2, 2)) self.assertEqual(fut.wait(), torch.ones(2, 2)) self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)
def test_set_exception(self) -> None:
    """Check that an exception stored on a future surfaces to its consumers.

    Covers three paths: wait(), value(), and a then() callback that reads
    the failed future's value.
    """
    message = "Intentional Value Error"
    failure = ValueError(message)

    # wait() must re-raise the stored exception with its original type.
    via_wait = Future[T]()
    via_wait.set_exception(failure)
    with self.assertRaisesRegex(ValueError, "Intentional"):
        via_wait.wait()

    # value() must re-raise it as well.
    via_value = Future()
    via_value.set_exception(failure)
    with self.assertRaisesRegex(ValueError, "Intentional"):
        via_value.value()

    def read_value(done):
        done.value()

    # When a callback reads the failed value, the chained future is
    # expected to report a RuntimeError that wraps the callback failure.
    via_then = Future()
    via_then.set_exception(failure)
    with self.assertRaisesRegex(RuntimeError, "Got the following error"):
        chained = via_then.then(read_value)
        chained.wait()
def _test_error(self, cb, errMsg): fut = Future() then_fut = fut.then(cb) fut.set_result(5) self.assertEqual(5, fut.wait()) with self.assertRaisesRegex(RuntimeError, errMsg): then_fut.wait()
def test_wait_all(self): fut1 = Future() fut2 = Future() # No error version fut1.set_result(1) fut2.set_result(2) res = torch.futures.wait_all([fut1, fut2]) print(res) self.assertEqual(res, [1, 2]) # Version with an exception def raise_in_fut(fut): raise ValueError("Expected error") fut3 = fut1.then(raise_in_fut) with self.assertRaisesRegex(RuntimeError, "Expected error"): torch.futures.wait_all([fut3, fut2])