def test_gather_tensors():
    """Gathering two atomic (1, 1) batches yields a single (2, 1) tensor."""
    first = Batch(torch.zeros(1, 1))
    second = Batch(torch.zeros(1, 1))

    merged = gather([first, second])

    assert merged.size() == (2, 1)
def test_gather_tuples():
    """Gathering tuple-backed batches concatenates element-wise along dim 0."""
    left = (torch.zeros(1, 1), torch.zeros(2, 2))
    right = (torch.zeros(1, 1), torch.zeros(2, 2))

    merged = gather([Batch(left), Batch(right)])

    assert isinstance(merged, tuple)
    assert merged[0].size() == (2, 1)
    assert merged[1].size() == (4, 2)
def test_batch_call():
    """Batch.call preserves the atomic/non-atomic flavor of the batch."""
    atomic = Batch(torch.tensor(42))
    non_atomic = Batch((torch.tensor(42), torch.tensor(21)))

    def identity(x):
        return x

    assert atomic.call(identity).atomic
    assert not non_atomic.call(identity).atomic
def test_batch_setitem_by_slice():
    """Full-slice assignment replaces the payload; atomicity is kept as-is."""
    atomic = Batch(torch.tensor(42))
    non_atomic = Batch((torch.tensor(42), torch.tensor(21)))

    atomic[:] = (torch.tensor(0),)
    non_atomic[:] = (torch.tensor(0),)

    # A single-element tuple assigned into an atomic batch keeps it atomic.
    assert atomic.atomic
    assert atomic[0].item() == 0

    # A non-atomic batch stays non-atomic, now holding one tensor.
    assert not non_atomic.atomic
    assert len(non_atomic) == 1
    assert non_atomic[0].item() == 0
def test_batch_setitem_by_index():
    """Index assignment swaps one element without touching atomicity or length."""
    atomic = Batch(torch.tensor(42))
    non_atomic = Batch((torch.tensor(42), torch.tensor(21)))

    atomic[0] = torch.tensor(0)
    non_atomic[0] = torch.tensor(0)

    assert atomic.atomic
    assert atomic[0].item() == 0

    # Only index 0 was replaced; index 1 keeps its original value.
    assert not non_atomic.atomic
    assert len(non_atomic) == 2
    assert non_atomic[0].item() == 0
    assert non_atomic[1].item() == 21
def test_tensor_life_without_checkpointing():
    """Without checkpointing, a saved skip tensor's life drops from 1 to 0 on load."""
    layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)})
    tracker = SkipTrackerThroughPotals(layout)
    batch = Batch(torch.tensor([1.0]))

    tracker.save(batch, None, "test", torch.tensor([2.0]))
    # One pending consumer: the portal keeps the tensor alive for one load.
    assert tracker.portals[(None, "test")].tensor_life == 1

    tracker.load(batch, None, "test")
    # The single load consumed the remaining life.
    assert tracker.portals[(None, "test")].tensor_life == 0
def test_batch_non_atomic():
    """A tuple-backed batch is non-atomic and behaves like a sequence of tensors."""
    x, y = torch.tensor(42), torch.tensor(21)
    batch = Batch((x, y))

    assert not batch.atomic

    # `.tensor` is only defined for atomic batches.
    with pytest.raises(AttributeError):
        batch.tensor

    assert list(batch) == [x, y]
    assert len(batch) == 2
    assert batch[0] is x
    assert batch[1] is y
def test_batch_atomic():
    """A single-tensor batch is atomic and exposes it as a one-element sequence."""
    x = torch.tensor(42)
    batch = Batch(x)

    assert batch.atomic
    assert batch.tensor is x

    # `.tensors` is only defined for non-atomic batches.
    with pytest.raises(AttributeError):
        batch.tensors

    assert list(batch) == [x]
    assert len(batch) == 1
    assert batch[0] is x
def test_reuse_portal():
    """Saving the same skip name twice reuses the existing portal object."""
    layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)})
    tracker = SkipTrackerThroughPotals(layout)
    batch = Batch(torch.tensor([1.0]))

    first_tensor = torch.tensor([2.0])
    second_tensor = torch.tensor([2.0])

    tracker.save(batch, None, "test", first_tensor)
    portal = tracker.portals[(None, "test")]

    tracker.save(batch, None, "test", second_tensor)
    # The second save must not allocate a new portal for the same route.
    assert portal is tracker.portals[(None, "test")]
def test_not_requires_grad():
    """Checkpointing turns on requires_grad even for inputs that started without it."""
    batch = Batch(torch.rand(1, requires_grad=False))
    assert not batch[0].requires_grad

    def double(x):
        return x * 2

    chk = Checkpointing(double, batch)

    batch = chk.checkpoint()
    assert batch[0].requires_grad

    chk.recompute(batch)
    assert batch[0].requires_grad

    # Backward must succeed through the recomputed graph.
    batch.tensor.backward()
def test_no_copy_no_portal():
    """Cross-partition skips get portals; same-partition skips are stored as plain tensors."""
    layout = SkipLayout(
        num_partitions=2,
        skip_routes={(None, "copy"): (0, 1), (None, "not_copy"): (0, 0)},
    )
    tracker = SkipTrackerThroughPotals(layout)
    batch = Batch(torch.tensor([1.0]))

    tracker.save(batch, None, "copy", torch.tensor([2.0]))
    tracker.save(batch, None, "not_copy", torch.tensor([2.0]))

    # "copy" crosses partitions (0 -> 1), so it goes through a portal only.
    assert (None, "copy") in tracker.portals
    assert (None, "copy") not in tracker.tensors

    # "not_copy" stays on partition 0, so it is stashed directly.
    assert (None, "not_copy") in tracker.tensors
    assert (None, "not_copy") not in tracker.portals
def detect_grad_enabled():
    """Return a Batch whose tensor's requires_grad mirrors the current grad mode."""
    grad_mode = torch.is_grad_enabled()
    return Batch(torch.rand(1, requires_grad=grad_mode))
def _42():
    """Return a batch wrapping the scalar tensor 42."""
    value = torch.tensor(42)
    return Batch(value)
def log_thread_id():
    """Record the current thread's id into the enclosing ``thread_ids`` set.

    Returns an empty (tuple-backed) Batch so it can serve as a micro-batch task.
    NOTE(review): ``thread_ids`` is a free variable from an enclosing scope —
    this function only works as a closure; confirm against its definition site.
    """
    thread_id = threading.current_thread().ident
    thread_ids.add(thread_id)
    return Batch(())
def counter():
    """Sleep briefly, then increment the enclosing ``count`` by one.

    Returns an empty (tuple-backed) Batch so it can serve as a micro-batch task.
    NOTE(review): ``nonlocal count`` means this must be nested inside a function
    that binds ``count``; the 0.1 s sleep presumably spaces out concurrent
    increments for the surrounding test — confirm at the definition site.
    """
    nonlocal count
    time.sleep(0.1)
    count += 1
    return Batch(())