Exemplo n.º 1
0
    def test_then(self):
        fut = Future()
        then_fut = fut.then(lambda x: x.wait() + 1)

        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)
Exemplo n.º 2
0
    def test_set_exception(self) -> None:
        # This test is to ensure errors can propagate across futures.
        error_msg = "Intentional Value Error"
        value_error = ValueError(error_msg)

        f = Future[T]()
        # Set exception
        f.set_exception(value_error)
        # Exception should throw on wait
        with self.assertRaisesRegex(ValueError, "Intentional"):
            f.wait()

        # Exception should also throw on value
        f = Future()
        f.set_exception(value_error)
        with self.assertRaisesRegex(ValueError, "Intentional"):
            f.value()

        def cb(fut):
            fut.value()

        f = Future()
        f.set_exception(value_error)

        with self.assertRaisesRegex(RuntimeError, "Got the following error"):
            cb_fut = f.then(cb)
            cb_fut.wait()
Exemplo n.º 3
0
    def _test_error(self, cb, errMsg):
        fut = Future()
        then_fut = fut.then(cb)

        fut.set_result(5)
        self.assertEqual(5, fut.wait())
        with self.assertRaisesRegex(RuntimeError, errMsg):
            then_fut.wait()
Exemplo n.º 4
0
    def test_wait_all(self):
        fut1 = Future()
        fut2 = Future()

        # No error version
        fut1.set_result(1)
        fut2.set_result(2)
        res = torch.futures.wait_all([fut1, fut2])
        print(res)
        self.assertEqual(res, [1, 2])

        # Version with an exception
        def raise_in_fut(fut):
            raise ValueError("Expected error")

        fut3 = fut1.then(raise_in_fut)
        with self.assertRaisesRegex(RuntimeError, "Expected error"):
            torch.futures.wait_all([fut3, fut2])