import time

import pytest
import torch
from torch import nn

# Assumed import path: these tests exercise torchgpipe's balance profiler.
from torchgpipe.balance import balance_by_time

# Profile on CPU always, and on CUDA when available.
devices = ["cpu"]
if torch.cuda.is_available():
    devices.append("cuda")


def test_already_has_grad():
    model = nn.Sequential(nn.Conv2d(3, 3, 1))
    sample = torch.rand(1, 3, 32, 32)
    model(sample).norm().backward()

    # Profiling would corrupt the existing gradients, so it must be rejected.
    with pytest.raises(ValueError, match="some parameter already has gradient"):
        balance_by_time(1, model, sample, device="cpu")

def test_balance_by_time_tuple():
    class Twin(nn.Module):
        def forward(self, x):
            return x, x.detach()

    class Add(nn.Module):
        def forward(self, a, b):
            return a + b

    model = nn.Sequential(Twin(), Add())
    sample = torch.rand(1, requires_grad=True)
    balance_by_time(1, model, sample, device="cpu")

@pytest.mark.parametrize("device", devices)
def test_sandbox_during_profiling(device):
    model = nn.Sequential(nn.BatchNorm2d(3))
    before = {k: v.clone() for k, v in model.state_dict().items()}

    sample = torch.rand(1, 3, 10, 10)
    balance_by_time(1, model, sample, device=device)

    # Profiling must run in a sandbox: BatchNorm running stats updated while
    # timing have to be rolled back afterwards.
    after = model.state_dict()
    assert before.keys() == after.keys()
    for key, value in before.items():
        assert torch.allclose(after[key], value), key

def test_not_training():
    class AssertTraining(nn.Module):
        def forward(self, x):
            assert self.training
            return x

    model = nn.Sequential(AssertTraining())

    model.eval()
    assert not model.training

    # Profiling runs the model in training mode, then restores eval mode.
    sample = torch.rand(1)
    balance_by_time(1, model, sample, device="cpu")

    assert not model.training

def test_balance_by_time_loop_resets_input():
    # nn.Flatten was introduced at PyTorch 1.2.0.
    class Flatten(nn.Module):
        def forward(self, x):
            return x.flatten(1)

    # Each timing round must restart from the original sample; otherwise a
    # later round would feed Conv2d the model's own flattened output.
    model = nn.Sequential(nn.Conv2d(3, 2, 1), Flatten(), nn.Linear(128, 10))
    sample = torch.rand(10, 3, 8, 8)

    balance = balance_by_time(2, model, sample, device="cpu")
    assert balance == [1, 2]

@pytest.mark.parametrize("device", devices)
def test_balance_by_time(device):
    class Delay(nn.Module):
        def __init__(self, seconds):
            super().__init__()
            self.seconds = seconds

        def forward(self, x):
            time.sleep(self.seconds)
            return x

    # Layers sleep 0.1s..0.6s. Grouping the first four (1.0s) against the
    # last two (1.1s) is the most even 2-way split, hence [4, 2].
    model = nn.Sequential(*[Delay(i / 10) for i in [1, 2, 3, 4, 5, 6]])
    sample = torch.rand(1)

    balance = balance_by_time(2, model, sample, device=device)
    assert balance == [4, 2]
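

# A usage sketch, not part of the original suite: the list returned by
# balance_by_time() is typically handed to GPipe, which splits the Sequential
# into partitions of the given sizes. The torchgpipe.GPipe import, the
# CPU-only `devices` argument, and `chunks=4` are assumptions for
# illustration, not asserted behavior of this test module.
def test_balance_feeds_gpipe_sketch():
    from torchgpipe import GPipe

    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4))
    sample = torch.rand(16, 4)

    balance = balance_by_time(2, model, sample, device="cpu")
    # Every layer is assigned to exactly one of the two partitions.
    assert sum(balance) == 3

    model = GPipe(model, balance=balance, devices=["cpu", "cpu"], chunks=4)
    output = model(sample)
    assert output.shape == (16, 4)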