Ejemplo n.º 1
0
    def test_then(self):
        fut = Future()
        then_fut = fut.then(lambda x: x.wait() + 1)

        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)
Ejemplo n.º 2
0
    def _test_error(self, cb, errMsg):
        fut = Future()
        then_fut = fut.then(cb)

        fut.set_result(5)
        self.assertEqual(5, fut.wait())
        with self.assertRaisesRegex(RuntimeError, errMsg):
            then_fut.wait()
Ejemplo n.º 3
0
 def test_mark_future_twice(self):
     fut = Future()
     fut.set_result(1)
     with self.assertRaisesRegex(
         RuntimeError,
         "Future can only be marked completed once"
     ):
         fut.set_result(1)
Ejemplo n.º 4
0
 def _fail_rank_async(self, name):
     ranks = self._get_ranks(name)
     fut = Future()
     if ranks is not None and self.rank in ranks:
         fut.set_exception(ValueError(f"async rank fail {self.rank} for {name}"))
     else:
         fut.set_result(None)
     return fut
Ejemplo n.º 5
0
    def test_chained_then(self):
        # Build a chain of 20 futures, each attached via `then(add_one)` to the
        # previous link, then complete only the root. The i-th link is expected
        # to hold ones(2, 2) + i + 1.
        # NOTE(review): assumes `add_one` (defined elsewhere) returns its
        # parent future's result plus one -- confirm.
        fut = Future()
        futs = []
        last_fut = fut
        for _ in range(20):
            last_fut = last_fut.then(add_one)
            futs.append(last_fut)

        fut.set_result(torch.ones(2, 2))

        for i, chained in enumerate(futs):
            self.assertEqual(chained.wait(), torch.ones(2, 2) + i + 1)
Ejemplo n.º 6
0
    def test_collect_all(self):
        fut1 = Future()
        fut2 = Future()
        fut_all = torch.futures.collect_all([fut1, fut2])

        def slow_in_thread(fut, value):
            time.sleep(0.1)
            fut.set_result(value)

        t = threading.Thread(target=slow_in_thread, args=(fut1, 1))
        fut2.set_result(2)
        t.start()

        res = fut_all.wait()
        self.assertEqual(res[0].wait(), 1)
        self.assertEqual(res[1].wait(), 2)
        t.join()
Ejemplo n.º 7
0
    def test_wait_all(self):
        fut1 = Future()
        fut2 = Future()

        # No error version
        fut1.set_result(1)
        fut2.set_result(2)
        res = torch.futures.wait_all([fut1, fut2])
        print(res)
        self.assertEqual(res, [1, 2])

        # Version with an exception
        def raise_in_fut(fut):
            raise ValueError("Expected error")

        fut3 = fut1.then(raise_in_fut)
        with self.assertRaisesRegex(RuntimeError, "Expected error"):
            torch.futures.wait_all([fut3, fut2])
Ejemplo n.º 8
0
    def test_wait(self):
        f = Future()
        f.set_result(torch.ones(2, 2))

        self.assertEqual(f.wait(), torch.ones(2, 2))
Ejemplo n.º 9
0
def create_work(result):
    """Wrap ``result`` in an already-completed Future and hand it to
    ``_create_work_from_future``, returning whatever that call produces.

    NOTE(review): presumably the return value is a distributed Work handle;
    confirm against ``_create_work_from_future``.
    """
    completed = Future()
    completed.set_result(result)
    work = _create_work_from_future(completed)
    return work