import threading
import time

import pytest
import torch

# Import paths below assume torchgpipe's module layout.
from torchgpipe.checkpoint import Checkpointing
from torchgpipe.microbatch import Batch
from torchgpipe.skip.layout import SkipLayout
from torchgpipe.skip.tracker import SkipTrackerThroughPotals


def test_batch_atomic():
    x = torch.tensor(42)
    b = Batch(x)

    assert b.atomic

    assert b.tensor is x
    with pytest.raises(AttributeError):
        b.tensors

    assert list(b) == [x]
    assert len(b) == 1
    assert b[0] is x

def test_tensor_life_without_checkpointing():
    skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)})
    skip_tracker = SkipTrackerThroughPotals(skip_layout)

    batch = Batch(torch.tensor([1.0]))
    tensor = torch.tensor([2.0])

    skip_tracker.save(batch, None, "test", tensor)
    assert skip_tracker.portals[(None, "test")].tensor_life == 1

    skip_tracker.load(batch, None, "test")
    assert skip_tracker.portals[(None, "test")].tensor_life == 0

def test_batch_non_atomic():
    x, y = torch.tensor(42), torch.tensor(21)
    b = Batch((x, y))

    assert not b.atomic

    with pytest.raises(AttributeError):
        b.tensor
    assert b.tensors == (x, y)

    assert list(b) == [x, y]
    assert len(b) == 2
    assert b[0] is x
    assert b[1] is y

def test_reuse_portal():
    skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)})
    skip_tracker = SkipTrackerThroughPotals(skip_layout)

    batch = Batch(torch.tensor([1.0]))
    a = torch.tensor([2.0])
    b = torch.tensor([2.0])

    skip_tracker.save(batch, None, "test", a)
    portal = skip_tracker.portals[(None, "test")]

    skip_tracker.save(batch, None, "test", b)
    assert portal is skip_tracker.portals[(None, "test")]

def test_batch_call():
    a = Batch(torch.tensor(42))
    b = Batch((torch.tensor(42), torch.tensor(21)))

    def f(x):
        return x

    assert a.call(f).atomic
    assert not b.call(f).atomic

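# Minimal usage sketch (not one of the original tests): assuming Batch.call passes
# the wrapped value to the function, as the identity-function test above suggests,
# the result is re-wrapped in a Batch with the same atomicity. The leading
# underscore keeps pytest from collecting it as a test.
def _example_batch_call():
    doubled = Batch(torch.tensor(21)).call(lambda t: t * 2)
    assert doubled.atomic
    assert doubled.tensor.item() == 42
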
def test_not_requires_grad():
    x = Batch(torch.rand(1, requires_grad=False))
    assert not x[0].requires_grad

    def f(x):
        return x * 2

    chk = Checkpointing(f, x)
    x = chk.checkpoint()
    assert x[0].requires_grad

    chk.recompute(x)
    assert x[0].requires_grad

    x.tensor.backward()

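# Hedged follow-up sketch (an assumption modeled on the Checkpointing API used
# above, not an original test): a parameter captured by the checkpointed function
# is expected to receive a gradient, since recompute() rebuilds the forward graph
# before backward runs.
def _example_checkpoint_grad_to_parameter():
    weight = torch.rand(1, requires_grad=True)

    def scale(x):
        return x * weight

    batch = Batch(torch.rand(1, requires_grad=False))
    chk = Checkpointing(scale, batch)

    out = chk.checkpoint()
    chk.recompute(out)
    out.tensor.backward()

    # Expected under the assumption above.
    assert weight.grad is not None
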
def test_no_copy_no_portal():
    skip_layout = SkipLayout(num_partitions=2, skip_routes={
        (None, "copy"): (0, 1),
        (None, "not_copy"): (0, 0),
    })
    skip_tracker = SkipTrackerThroughPotals(skip_layout)

    batch = Batch(torch.tensor([1.0]))
    a = torch.tensor([2.0])
    b = torch.tensor([2.0])

    skip_tracker.save(batch, None, "copy", a)
    skip_tracker.save(batch, None, "not_copy", b)

    assert (None, "copy") in skip_tracker.portals
    assert (None, "copy") not in skip_tracker.tensors
    assert (None, "not_copy") in skip_tracker.tensors
    assert (None, "not_copy") not in skip_tracker.portals

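# Hedged sketch building on the tests above (assumed behavior, not an original
# test): a route whose source and destination partitions coincide bypasses the
# portal machinery entirely, and load() is expected to hand back the stashed
# tensor directly.
def _example_same_partition_roundtrip():
    skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "local"): (0, 0)})
    skip_tracker = SkipTrackerThroughPotals(skip_layout)

    batch = Batch(torch.tensor([1.0]))
    stashed = torch.tensor([2.0])

    skip_tracker.save(batch, None, "local", stashed)
    assert (None, "local") not in skip_tracker.portals

    loaded = skip_tracker.load(batch, None, "local")
    assert loaded is stashed
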
def log_thread_id():
    # Nested compute helper: records the worker thread it runs on in the
    # enclosing test's `thread_ids` set.
    thread_id = threading.current_thread().ident
    thread_ids.add(thread_id)
    return Batch(())

def counter():
    # Nested compute helper: increments the enclosing test's `count` after a
    # short delay (hence the `nonlocal`, which requires an enclosing function).
    nonlocal count
    time.sleep(0.1)
    count += 1
    return Batch(())

def detect_grad_enabled():
    # Creates a tensor whose requires_grad mirrors the surrounding autograd mode.
    x = torch.rand(1, requires_grad=torch.is_grad_enabled())
    return Batch(x)

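# Hedged usage sketch (an illustration, not an original test, assuming
# detect_grad_enabled is in scope): the helper mirrors the surrounding autograd
# mode, which is what lets a test tell the no-grad checkpoint pass apart from the
# grad-enabled recompute pass.
def _example_detect_grad_enabled():
    with torch.no_grad():
        assert not detect_grad_enabled()[0].requires_grad
    with torch.enable_grad():
        assert detect_grad_enabled()[0].requires_grad
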
def _42():
    return Batch(torch.tensor(42))