def test_then(self):
    """Check that ``Future.then`` runs its callback once the parent completes.

    The callback is handed the parent future itself (not its value), so it
    calls ``wait()`` on it before adding one.
    """
    parent = Future()
    child = parent.then(lambda done: done.wait() + 1)

    parent.set_result(torch.ones(2, 2))

    parent_value = parent.wait()
    self.assertEqual(parent_value, torch.ones(2, 2))
    child_value = child.wait()
    self.assertEqual(child_value, torch.ones(2, 2) + 1)
def _test_error(self, cb, errMsg):
    """Assert that an exception raised inside a ``then`` callback surfaces on wait.

    Args:
        cb: callback passed to ``Future.then``; it is invoked with a future
            holding the integer 5 and is expected to raise.
        errMsg: regex that must match the ``RuntimeError`` raised when
            waiting on the chained future.
    """
    source = Future()
    downstream = source.then(cb)
    source.set_result(5)
    # The parent completed normally; only the chained future carries the error.
    self.assertEqual(5, source.wait())
    with self.assertRaisesRegex(RuntimeError, errMsg):
        downstream.wait()
def test_mark_future_twice(self):
    """A second ``set_result`` on an already-completed future must raise."""
    already_done = Future()
    already_done.set_result(1)
    expected_msg = "Future can only be marked completed once"
    with self.assertRaisesRegex(RuntimeError, expected_msg):
        already_done.set_result(1)
def _fail_rank_async(self, name):
    """Return an already-completed Future that fails on selected ranks.

    ``self._get_ranks(name)`` supplies the ranks that should fail for *name*,
    or ``None`` when no rank should fail. If ``self.rank`` is among them, the
    returned future holds a ``ValueError``; otherwise it resolves to ``None``.
    """
    failing_ranks = self._get_ranks(name)
    outcome = Future()
    should_fail = failing_ranks is not None and self.rank in failing_ranks
    if not should_fail:
        outcome.set_result(None)
        return outcome
    outcome.set_exception(ValueError(f"async rank fail {self.rank} for {name}"))
    return outcome
def test_chained_then(self):
    """Each link in a 20-deep ``then`` chain adds one to its predecessor.

    ``add_one`` is a callback defined elsewhere in this file; the assertions
    below expect the future at 1-based depth ``d`` to resolve to ``ones + d``.
    """
    fut = Future()
    futs = []
    last_fut = fut
    for _ in range(20):
        last_fut = last_fut.then(add_one)
        futs.append(last_fut)

    fut.set_result(torch.ones(2, 2))
    # enumerate(start=1) gives the chain depth directly (replaces range(len(...))).
    for depth, chained in enumerate(futs, start=1):
        self.assertEqual(chained.wait(), torch.ones(2, 2) + depth)
def test_collect_all(self):
    """``collect_all`` resolves to the input futures once all of them complete.

    ``fut1`` is completed from a background thread after a short delay, while
    ``fut2`` is completed immediately. ``fut_all.wait()`` is then expected to
    return both futures with their values set.
    """
    fut1 = Future()
    fut2 = Future()
    fut_all = torch.futures.collect_all([fut1, fut2])

    def slow_in_thread(fut, value):
        time.sleep(0.1)
        fut.set_result(value)

    t = threading.Thread(target=slow_in_thread, args=(fut1, 1))
    fut2.set_result(2)
    t.start()
    try:
        res = fut_all.wait()
        self.assertEqual(res[0].wait(), 1)
        self.assertEqual(res[1].wait(), 2)
    finally:
        # Join even when an assertion fails so the helper thread never
        # outlives the test.
        t.join()
def test_wait_all(self):
    """``wait_all`` returns every value and surfaces a callback failure.

    A ``ValueError`` raised inside a ``then`` callback is expected to be
    reported by ``wait_all`` as a ``RuntimeError`` containing its message.
    """
    fut1 = Future()
    fut2 = Future()

    # Success path: every input future completes normally.
    fut1.set_result(1)
    fut2.set_result(2)
    res = torch.futures.wait_all([fut1, fut2])
    self.assertEqual(res, [1, 2])

    # Failure path: a chained future errors inside its callback.
    def raise_in_fut(fut):
        raise ValueError("Expected error")

    fut3 = fut1.then(raise_in_fut)
    with self.assertRaisesRegex(RuntimeError, "Expected error"):
        torch.futures.wait_all([fut3, fut2])
def test_wait(self):
    """``wait`` on a completed future returns the value it was given."""
    completed = Future()
    completed.set_result(torch.ones(2, 2))
    result = completed.wait()
    self.assertEqual(result, torch.ones(2, 2))
def create_work(result):
    """Wrap an already-resolved future holding *result* via ``_create_work_from_future``.

    ``_create_work_from_future`` comes from elsewhere in this file/module;
    presumably it returns a distributed work handle. NOTE(review): confirm.
    """
    resolved = Future()
    resolved.set_result(result)
    return _create_work_from_future(resolved)