コード例 #1
0
def test_batch_atomic():
    """A Batch built from a single tensor reports itself as atomic.

    It exposes that tensor through ``.tensor``, rejects ``.tensors`` with
    ``AttributeError``, and iterates/indexes like a one-element sequence.
    """
    tensor = torch.tensor(42)
    batch = Batch(tensor)

    assert batch.atomic

    # The singular accessor must hand back the very same object.
    assert batch.tensor is tensor

    # The plural accessor is not available on an atomic batch.
    with pytest.raises(AttributeError):
        batch.tensors

    # Sequence protocol: one element, which is the wrapped tensor.
    assert list(batch) == [tensor]
    assert len(batch) == 1
    assert batch[0] is tensor
コード例 #2
0
def test_tensor_life_without_checkpointing():
    """Check the portal's ``tensor_life`` after one save and one load.

    The skip route crosses partitions (0 -> 1); the test expects the
    counter to read 1 right after ``save`` and 0 right after ``load``.
    """
    route_key = (None, "test")
    layout = SkipLayout(num_partitions=2, skip_routes={route_key: (0, 1)})
    tracker = SkipTrackerThroughPotals(layout)

    batch = Batch(torch.tensor([1.0]))
    skip_tensor = torch.tensor([2.0])

    tracker.save(batch, None, "test", skip_tensor)
    life_after_save = tracker.portals[route_key].tensor_life
    assert life_after_save == 1

    tracker.load(batch, None, "test")
    life_after_load = tracker.portals[route_key].tensor_life
    assert life_after_load == 0
コード例 #3
0
def test_batch_non_atomic():
    """A Batch built from a tuple of tensors is not atomic.

    ``.tensor`` must raise ``AttributeError`` while ``.tensors`` returns the
    original tuple; iteration, ``len`` and indexing follow the tuple.
    """
    first = torch.tensor(42)
    second = torch.tensor(21)
    batch = Batch((first, second))

    assert not batch.atomic

    # The singular accessor is not available on a multi-tensor batch.
    with pytest.raises(AttributeError):
        batch.tensor
    assert batch.tensors == (first, second)

    # Sequence protocol mirrors the tuple passed to the constructor.
    assert list(batch) == [first, second]
    assert len(batch) == 2
    assert batch[0] is first
    assert batch[1] is second
コード例 #4
0
def test_reuse_portal():
    """Saving twice under the same skip key must keep the same portal object."""
    route_key = (None, "test")
    layout = SkipLayout(num_partitions=2, skip_routes={route_key: (0, 1)})
    tracker = SkipTrackerThroughPotals(layout)

    batch = Batch(torch.tensor([1.0]))
    first_skip = torch.tensor([2.0])
    second_skip = torch.tensor([2.0])

    # The first save registers a portal for the key; remember it.
    tracker.save(batch, None, "test", first_skip)
    original_portal = tracker.portals[route_key]

    # A second save under the same key must not swap in a new portal.
    tracker.save(batch, None, "test", second_skip)
    assert original_portal is tracker.portals[route_key]
コード例 #5
0
def test_batch_call():
    """``Batch.call`` with an identity function keeps the batch's atomicity."""
    atomic_batch = Batch(torch.tensor(42))
    tuple_batch = Batch((torch.tensor(42), torch.tensor(21)))

    def identity(x):
        # Pass the payload through untouched.
        return x

    atomic_result = atomic_batch.call(identity)
    assert atomic_result.atomic

    tuple_result = tuple_batch.call(identity)
    assert not tuple_result.atomic
コード例 #6
0
ファイル: test_checkpoint.py プロジェクト: shiftlab/pytorch-1
def test_not_requires_grad():
    """Checkpointing an input without grad yields an output that requires grad.

    The test asserts ``requires_grad`` is False on the input, True on the
    checkpointed output, still True after ``recompute``, and finally runs
    ``backward`` on the output tensor.
    """
    batch = Batch(torch.rand(1, requires_grad=False))
    assert not batch[0].requires_grad

    def double(x):
        return x * 2

    checkpointing = Checkpointing(double, batch)

    # Rebind to the checkpointed output, as the remaining checks target it.
    batch = checkpointing.checkpoint()
    assert batch[0].requires_grad

    checkpointing.recompute(batch)
    assert batch[0].requires_grad

    # Backward must run without raising.
    batch.tensor.backward()
コード例 #7
0
def test_no_copy_no_portal():
    """Only a route between different partitions is stored as a portal.

    ``"copy"`` is routed 0 -> 1 and is expected in ``portals``;
    ``"not_copy"`` is routed 0 -> 0 and is expected in ``tensors`` instead.
    """
    copy_key = (None, "copy")
    not_copy_key = (None, "not_copy")

    routes = {
        copy_key: (0, 1),
        not_copy_key: (0, 0),
    }
    layout = SkipLayout(num_partitions=2, skip_routes=routes)
    tracker = SkipTrackerThroughPotals(layout)

    batch = Batch(torch.tensor([1.0]))
    copied = torch.tensor([2.0])
    kept_local = torch.tensor([2.0])

    tracker.save(batch, None, "copy", copied)
    tracker.save(batch, None, "not_copy", kept_local)

    # Cross-partition skip: portal only.
    assert copy_key in tracker.portals
    assert copy_key not in tracker.tensors

    # Same-partition skip: plain tensor storage only.
    assert not_copy_key in tracker.tensors
    assert not_copy_key not in tracker.portals
コード例 #8
0
 def log_thread_id():
     """Add the current thread's ident to ``thread_ids``; return an empty Batch."""
     thread_ids.add(threading.current_thread().ident)
     return Batch(())
コード例 #9
0
 def counter():
     """Sleep 0.1 s, increment the enclosing ``count``, and return an empty Batch."""
     # ``count`` lives in the enclosing function's scope, not shown in this excerpt.
     nonlocal count
     # The pause happens before the increment, so the count only changes
     # once the sleep has finished.
     time.sleep(0.1)
     count += 1
     return Batch(())
コード例 #10
0
 def detect_grad_enabled():
     """Return a Batch whose tensor's ``requires_grad`` mirrors the grad mode."""
     grad_mode = torch.is_grad_enabled()
     probe = torch.rand(1, requires_grad=grad_mode)
     return Batch(probe)
コード例 #11
0
 def _42():
     """Return an atomic Batch wrapping the scalar tensor 42."""
     answer = torch.tensor(42)
     return Batch(answer)