예제 #1
0
def test_compute_exception():
    """A compute that raises is reported as ``(False, exc_info)``.

    The task's compute divides by zero; the worker's output queue is
    expected to yield a falsy flag plus a tuple whose first element is
    the ``ZeroDivisionError`` class.
    """
    def divide_by_zero():
        0 / 0

    cpu = torch.device("cpu")

    with spawn_workers([cpu]) as (in_queues, out_queues):
        failing_task = Task(CPUStream, compute=divide_by_zero, finalize=None)
        in_queues[0].put(failing_task)

        # The worker answers on the queue paired with the input queue.
        result = out_queues[0].get()
        ok, exc_info = result

        assert not ok
        assert isinstance(exc_info, tuple)
        assert issubclass(exc_info[0], ZeroDivisionError)
예제 #2
0
def test_grad_mode(grad_mode):
    """Check the grad mode observed by a task's compute in a worker.

    Workers are spawned while ``torch.set_grad_enabled(grad_mode)`` is in
    effect; the tensor created by the compute must have
    ``requires_grad`` equal to ``grad_mode``.
    """
    def make_batch():
        # requires_grad mirrors whatever grad mode this call observes.
        grad_enabled = torch.is_grad_enabled()
        return Batch(torch.rand(1, requires_grad=grad_enabled))

    with torch.set_grad_enabled(grad_mode):
        cpu = torch.device("cpu")
        with spawn_workers([cpu]) as (in_queues, out_queues):
            probe = Task(CPUStream, compute=make_batch, finalize=None)
            in_queues[0].put(probe)

            # Unpack before asserting so a failed task surfaces exactly
            # as it would with a single nested unpacking statement.
            ok, payload = out_queues[0].get()
            _, batch = payload

            assert ok
            assert batch[0].requires_grad == grad_mode
예제 #3
0
def test_compute_success():
    """A compute that succeeds is reported as ``(True, (task, batch))``.

    The returned task must be the very object that was submitted, and the
    batch must wrap the tensor produced by the compute.
    """
    def make_answer():
        return Batch(torch.tensor(42))

    cpu = torch.device("cpu")

    with spawn_workers([cpu]) as (in_queues, out_queues):
        submitted = Task(CPUStream, compute=make_answer, finalize=None)
        in_queues[0].put(submitted)

        # Split the result in two steps; both happen before any assertion.
        ok, payload = out_queues[0].get()
        returned_task, batch = payload

        assert ok
        assert returned_task is submitted
        assert isinstance(batch, Batch)
        assert batch[0].item() == 42
예제 #4
0
def test_compute_multithreading():
    """Tasks sent to two workers should run on two distinct threads."""
    seen_thread_ids = set()

    def record_thread_id():
        # Runs inside whichever thread executes the task.
        seen_thread_ids.add(threading.current_thread().ident)
        return Batch(())

    worker_count = 2
    devices = [fake_device() for _ in range(worker_count)]

    with spawn_workers(devices) as (in_queues, out_queues):
        # Submit one task per worker first, then wait for every result.
        for index in range(worker_count):
            task = Task(CPUStream, compute=record_thread_id, finalize=None)
            in_queues[index].put(task)

        for index in range(worker_count):
            out_queues[index].get()

    assert len(seen_thread_ids) == 2
예제 #5
0
 def call_in_worker(i, f):
     task = Task(CPUStream, compute=f, finalize=None)
     in_queues[i].put(task)