Example #1
    def test_chained_then(self):
        fut = Future()
        futs = []
        last_fut = fut
        for _ in range(20):
            last_fut = last_fut.then(add_one)
            futs.append(last_fut)

        fut.set_result(torch.ones(2, 2))

        for i in range(len(futs)):
            self.assertEqual(futs[i].wait(), torch.ones(2, 2) + i + 1)
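The add_one helper is defined elsewhere in the test file; judging from the assertions (each link in the chain adds 1 to the tensor), it presumably looks like this:

    def add_one(fut):
        # Callback for `then`: receives the completed Future and
        # returns its value incremented by one.
        return fut.wait() + 1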
Example #2
    def test_wait_multi_thread(self):

        def slow_set_future(fut, value):
            time.sleep(0.5)
            fut.set_result(value)

        f = Future()

        t = threading.Thread(target=slow_set_future, args=(f, torch.ones(2, 2)))
        t.start()

        self.assertEqual(f.wait(), torch.ones(2, 2))
        t.join()
Example #3
    def __init__(self,
                 rank: int,
                 num_callees: int = 1,
                 num_callers: int = 1,
                 threads_process: int = 1,
                 caller_class: object = None,
                 caller_args=None,
                 future_keys: list = None):

        # ASSERTIONS
        assert num_callees > 0
        assert num_callers > 0

        # caller_class must be given
        assert caller_class is not None

        # caller_class must be a subclass of RpcCaller
        # import locally to avoid a circular import
        # pylint: disable=import-outside-toplevel
        from ..agents.rpc_caller import RpcCaller
        assert issubclass(caller_class, RpcCaller)
        assert isinstance(future_keys, list)

        # ATTRIBUTES

        # RPC
        self.rank = rank
        # pylint: disable=invalid-name
        self.id = rpc.get_worker_info().id
        self.name = rpc.get_worker_info().name
        self.rref = RRef(self)

        self.shutdown = False
        self._shutdown_done = False

        # COUNTERS
        self._t_start = time.time()
        self._loop_iteration = 0

        # STORAGE
        self._caller_rrefs = []
        self._pending_rpcs = deque()
        self._future_answers = {k: Future() for k in future_keys}
        self._current_futures = deque(maxlen=len(future_keys))

        # THREADS
        self.lock_batching = mp.Lock()
        self._processing_threads = [
            Thread(target=self._process_batch,
                   daemon=True,
                   name='processing_thread_%d' % i)
            for i in range(threads_process)
        ]

        for thread in self._processing_threads:
            thread.start()

        # spawn actors
        self._spawn_callers(caller_class, num_callees, num_callers,
                            *caller_args)
Example #4
    def test_collect_all(self):
        fut1 = Future()
        fut2 = Future()
        fut_all = torch.futures.collect_all([fut1, fut2])

        def slow_in_thread(fut, value):
            time.sleep(0.1)
            fut.set_result(value)

        t = threading.Thread(target=slow_in_thread, args=(fut1, 1))
        fut2.set_result(2)
        t.start()

        res = fut_all.wait()
        self.assertEqual(res[0].wait(), 1)
        self.assertEqual(res[1].wait(), 2)
        t.join()
Example #5
    def write_bytes(self, requests: List[BytesWriteRequest]) -> Future[None]:
        for req in requests:
            (self.path / req.storage_key).write_bytes(
                req.bytes.getbuffer()
            )
        fut: Future[None] = Future()
        fut.set_result(None)
        return fut
Example #6
    def read_bytes(self, requests: List[BytesReadRequest]) -> Future[None]:
        for req in requests:
            with (self.path / req.storage_key).open("rb") as storage:
                req.bytes.write(storage.read())

        fut: Future[None] = Future()
        fut.set_result(None)
        return fut
Example #7
    def write_bytes(self, requests: List[BytesWriteRequest]) -> Future[None]:
        for req in requests:
            with (self.path / req.storage_key).open("wb") as w:
                w.write(req.bytes.getbuffer())
                os.fsync(w.fileno())

        fut: Future[None] = Future()
        fut.set_result(None)
        return fut
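Examples #5 through #7 end with the same idiom: the I/O happens synchronously, and an already-completed Future[None] is returned to satisfy an asynchronous interface. A minimal sketch of that idiom as a standalone helper (the name _completed_future is illustrative, not from the source):

    from torch.futures import Future

    def _completed_future(value=None) -> Future:
        # Return a Future that is already marked complete, so callers
        # can uniformly call .wait() on the result.
        fut: Future = Future()
        fut.set_result(value)
        return fut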
Example #8
    def test_then(self):
        fut = Future()
        then_fut = fut.then(lambda x: x.wait() + 1)

        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)
Example #9
    def test_wait_all(self):
        fut1 = Future()
        fut2 = Future()

        # No error version
        fut1.set_result(1)
        fut2.set_result(2)
        res = torch.futures.wait_all([fut1, fut2])
        self.assertEqual(res, [1, 2])

        # Version with an exception
        def raise_in_fut(fut):
            raise ValueError("Expected error")

        fut3 = fut1.then(raise_in_fut)
        with self.assertRaisesRegex(RuntimeError, "Expected error"):
            torch.futures.wait_all([fut3, fut2])
Example #10
    def _fail_rank_async(self, name):
        ranks = self._get_ranks(name)
        fut = Future()
        if ranks is not None and self.rank in ranks:
            fut.set_exception(ValueError(f"async rank fail {self.rank} for {name}"))
        else:
            fut.set_result(None)
        return fut
Example #11
    def _test_error(self, cb, errMsg):
        fut = Future()
        then_fut = fut.then(cb)

        fut.set_result(5)
        self.assertEqual(5, fut.wait())
        with self.assertRaisesRegex(RuntimeError, errMsg):
            then_fut.wait()
Example #12
    def test_mark_future_twice(self):
        fut = Future()
        fut.set_result(1)
        with self.assertRaisesRegex(
            RuntimeError,
            "Future can only be marked completed once"
        ):
            fut.set_result(1)
Example #13
def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs):
    def _rref_type_cont(rref_fut):
        rref_type = rref_fut.value()

        _invoke_func = _local_invoke
        # Bypass ScriptModules when checking for async function attribute.
        bypass_type = issubclass(rref_type,
                                 torch.jit.ScriptModule) or issubclass(
                                     rref_type, torch._C.ScriptModule)
        if not bypass_type:
            func = getattr(rref_type, func_name)
            if hasattr(func, "_wrapped_async_rpc_function"):
                _invoke_func = _local_invoke_async_execution

        return rpc_api(rref.owner(),
                       _invoke_func,
                       args=(rref, func_name, args, kwargs),
                       timeout=timeout)

    rref_fut = rref._get_type(timeout=timeout, blocking=False)

    if rpc_api != rpc_async:
        rref_fut.wait()
        return _rref_type_cont(rref_fut)
    else:
        # A little explanation of what happens here:
        # rpc_async returns a Future pointing to the return value of
        # `func_name`, i.e. a `Future[T]`. Calling _rref_type_cont from the
        # `then` lambda would wrap that Future again; in other words, `then`
        # returns a `Future[Future[T]]`. To address that, we return a
        # separate Future that is completed manually with the result of the
        # async call.
        result: Future = Future()

        def _wrap_rref_type_cont(fut):
            try:
                _rref_type_cont(fut).then(_complete_op)
            except BaseException as ex:
                result.set_exception(ex)

        def _complete_op(fut):
            try:
                result.set_result(fut.value())
            except BaseException as ex:
                result.set_exception(ex)

        rref_fut.then(_wrap_rref_type_cont)
        return result
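The comments above describe the Future[Future[T]] problem in the abstract. A minimal standalone sketch of the same flattening trick (all names here are illustrative, not from the source):

    from torch.futures import Future

    def flatten(source_fut, make_inner):
        # `then` would wrap the inner Future again, so we complete an
        # outer Future by hand once the inner one resolves.
        result: Future = Future()

        def _complete(inner):
            try:
                result.set_result(inner.value())
            except BaseException as ex:
                result.set_exception(ex)

        def _outer(fut):
            try:
                make_inner(fut).then(_complete)
            except BaseException as ex:
                result.set_exception(ex)

        source_fut.then(_outer)
        return result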
Example #14
    def write_tensors(self, requests: List[TensorWriteRequest]) -> Future[None]:
        for req in requests:
            # The following couple of lines are a simple implementation to
            # get things going.
            #
            # At load time, to enable resharding, we use a (sub)view of the
            # tensor. Since the storage of the tensor might not be contiguous,
            # we need to preserve the original view to calculate the correct
            # subview at load time.
            #
            # `torch.save` saves both the view and the storage, so it is a
            # good option for unblocking. There are two drawbacks:
            # 1. `torch.save` is pickle based, and pickle is not known for
            #    its compatibility; we should consider replacing it with a
            #    more stable option.
            # 2. pickle is not streamable.
            with (self.path / req.storage_key).open("wb") as w:
                torch.save(req.tensor, w)

        fut: Future[None] = Future()
        fut.set_result(None)
        return fut
Example #15
    def read_tensors(self, requests: List[TensorReadRequest]) -> Future[None]:
        """
        Very basic implementation that read from file system.
        """
        # Sort the the requests by storage key and try to reuse the loaded tensors
        requests.sort(key=operator.attrgetter("storage_key"))

        cached_storage_key = None
        view_cached: Optional[Tensor] = None

        for req in requests:
            if cached_storage_key != req.storage_key or \
                    (view_cached is not None and view_cached.device != req.tensor.device):

                with (self.path / req.storage_key).open("rb") as storage:
                    view_cached = cast(Tensor, torch.load(storage, map_location=req.tensor.device))
                    cached_storage_key = req.storage_key

            view_to_copy: Tensor = cast(Tensor, view_cached)
            # FileSystemWrite writes the tensor as-is during save.
            # At load time, we load the tensor (with its original view),
            # narrow it along all dimensions, and copy_ it into the
            # target tensor, which has the same size.
            for dim, (start, length) in enumerate(zip(req.offsets, req.lengths)):
                view_to_copy = torch.narrow(view_to_copy, dim, start, length)

            assert (
                view_to_copy.size() == req.tensor.size()
            ), f"The {req.storage_key} src/dst size does not match."


            assert (
                view_to_copy.device == req.tensor.device
            ), f"cannot load across devices {view_to_copy.device} vs {req.tensor.device}"

            req.tensor.copy_(view_to_copy)

        fut: Future[None] = Future()
        fut.set_result(None)
        return fut
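The narrowing loop above generalizes to tensors of any rank; in isolation it is just repeated torch.narrow along each dimension. A self-contained illustration (not from the source):

    import torch

    view = torch.arange(16).reshape(4, 4)
    sub = view
    for dim, (start, length) in enumerate(zip((1, 0), (2, 4))):
        sub = torch.narrow(sub, dim, start, length)
    # `sub` is now the 2x4 slice view[1:3, 0:4] and still shares
    # storage with `view`, so no data was copied.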
Example #16
    def _process_batch(self, waiting_time: float = 0.001):
        """Prepares batched data held by :py:attr:`self._pending_rpcs` and
        invokes :py:meth:`process_batch()` on this data.
        Sets :py:class:`Future` with according results.

        Parameters
        ----------
        waiting_time: `float`
            Waiting time between each iteration.
        """
        while not self.shutdown:
            # sleep between iterations (1 ms by default)
            time.sleep(waiting_time)
            with self.lock_batching:
                if not self._pending_rpcs:
                    # nothing to do if no RPCs are pending
                    continue
                pending_rpcs = [
                    self._pending_rpcs.popleft()
                    for _ in range(len(self._pending_rpcs))
                ]

            # transform rpc data
            caller_ids, *args, kwargs = zip(*pending_rpcs)
            args = [listdict_to_dictlist(b) for b in args]
            kwargs = listdict_to_dictlist(kwargs)

            # run actual internal process
            process_output = self.process_batch(caller_ids, *args, **kwargs)

            # answer futures
            for caller_id, result in process_output.items():
                f_answers = self._future_answers[caller_id]
                self._future_answers[caller_id] = Future()
                f_answers.set_result(
                    (result, self.shutdown, caller_id, dict()))
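listdict_to_dictlist is imported from elsewhere in the project. From its use here (merging per-caller RPC payloads into batches), it presumably transposes a list of dicts into a dict of lists, roughly:

    def listdict_to_dictlist(list_of_dicts):
        # Hypothetical reconstruction: gather each key's values across
        # all entries so they can be processed as one batch.
        if not list_of_dicts:
            return {}
        return {key: [d[key] for d in list_of_dicts]
                for key in list_of_dicts[0]}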
Example #17
    def test_set_exception(self) -> None:
        # This test is to ensure errors can propagate across futures.
        error_msg = "Intentional Value Error"
        value_error = ValueError(error_msg)

        f = Future[T]()  # T is a TypeVar defined elsewhere in the test module
        # Set exception
        f.set_exception(value_error)
        # Exception should throw on wait
        with self.assertRaisesRegex(ValueError, "Intentional"):
            f.wait()

        # Exception should also throw on value
        f = Future()
        f.set_exception(value_error)
        with self.assertRaisesRegex(ValueError, "Intentional"):
            f.value()

        def cb(fut):
            fut.value()

        f = Future()
        f.set_exception(value_error)

        with self.assertRaisesRegex(RuntimeError, "Got the following error"):
            cb_fut = f.then(cb)
            cb_fut.wait()
Example #18
    def test_pickle_future(self):
        fut = Future()
        errMsg = "Can not pickle torch.futures.Future"
        with TemporaryFileName() as fname:
            with self.assertRaisesRegex(RuntimeError, errMsg):
                torch.save(fut, fname)
Example #19
def create_work(result):
    future = Future()
    future.set_result(result)
    return _create_work_from_future(future)
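_create_work_from_future is a private c10d helper. Assuming it behaves as its name suggests, the returned object can then be used like any other Work handle:

    # Hypothetical usage: wrap a precomputed tensor as an
    # already-finished collective, then block on it as usual.
    work = create_work(torch.ones(2, 2))
    work.wait()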
Example #20
    def test_wait(self):
        f = Future()
        f.set_result(torch.ones(2, 2))

        self.assertEqual(f.wait(), torch.ones(2, 2))