Ejemplo n.º 1
0
def test_already_has_grad():
    """Expect a ValueError when a parameter already holds a gradient."""
    net = nn.Sequential(nn.Conv2d(3, 3, 1))
    batch = torch.rand(1, 3, 32, 32)

    # One backward pass populates ``.grad`` on the convolution parameters
    # before the model is handed to the balancer.
    loss = net(batch).norm()
    loss.backward()

    expected_message = "some parameter already has gradient"
    with pytest.raises(ValueError, match=expected_message):
        balance_by_time(1, net, batch, device="cpu")
Ejemplo n.º 2
0
def test_balance_by_time_tuple():
    """Smoke test: balancing a model whose first layer returns a tuple.

    The test makes no assertion on the result; it only requires that the
    call finishes without raising.
    """

    class Twin(nn.Module):
        """Return the input together with a detached copy of it."""

        def forward(self, x):
            detached = x.detach()
            return x, detached

    class Add(nn.Module):
        """Sum the two tensors produced by ``Twin``."""

        def forward(self, a, b):
            total = a + b
            return total

    net = nn.Sequential(Twin(), Add())
    seed = torch.rand(1, requires_grad=True)
    balance_by_time(1, net, seed, device="cpu")
Ejemplo n.º 3
0
def test_sandbox_during_profiling(device):
    """Check that balancing leaves the model's state dict unchanged."""
    net = nn.Sequential(nn.BatchNorm2d(3))

    # Clone every entry so later in-place updates to the live tensors
    # cannot alter the snapshot.
    snapshot = {}
    for name, tensor in net.state_dict().items():
        snapshot[name] = tensor.clone()

    batch = torch.rand(1, 3, 10, 10)
    balance_by_time(1, net, batch, device=device)

    current = net.state_dict()

    # Same set of entries, and each entry numerically equal to its snapshot.
    assert snapshot.keys() == current.keys()
    for name, saved in snapshot.items():
        assert torch.allclose(current[name], saved), name
Ejemplo n.º 4
0
def test_not_training():
    """Check the model's eval mode around a balancing call.

    The probe layer asserts it is in training mode whenever its forward
    runs, while the outer model must still report eval mode afterwards.
    """

    class AssertTraining(nn.Module):
        """Identity layer that fails if executed outside training mode."""

        def forward(self, x):
            assert self.training
            return x

    # ``eval()`` returns the module itself, so it can be chained.
    net = nn.Sequential(AssertTraining()).eval()
    assert not net.training

    probe = torch.rand(1)
    balance_by_time(1, net, probe, device="cpu")

    # The caller-visible mode must be unchanged by the call.
    assert not net.training
Ejemplo n.º 5
0
def test_balance_by_time_loop_resets_input():
    """Balance a conv -> flatten -> linear stack into two partitions.

    NOTE(review): judging by the test name, this presumably guards against
    the profiling loop feeding a previous output back in as input; confirm
    against the ``balance_by_time`` implementation.
    """

    # nn.Flatten was introduced at PyTorch 1.2.0.
    class Flatten(nn.Module):
        """Collapse every dimension after the first into one."""

        def forward(self, x):
            return x.flatten(start_dim=1)

    # Conv output is (10, 2, 8, 8); flattened it has 2 * 8 * 8 = 128
    # features per row, matching the linear layer's input width.
    layers = [nn.Conv2d(3, 2, 1), Flatten(), nn.Linear(128, 10)]
    net = nn.Sequential(*layers)
    batch = torch.rand(10, 3, 8, 8)

    partition_sizes = balance_by_time(2, net, batch, device="cpu")
    assert partition_sizes == [1, 2]
Ejemplo n.º 6
0
def test_balance_by_time(device):
    """Balance six sleeping layers into two partitions by elapsed time."""

    class Delay(nn.Module):
        """Identity layer that sleeps for a fixed duration on each forward."""

        def __init__(self, seconds):
            super().__init__()
            self.seconds = seconds

        def forward(self, x):
            time.sleep(self.seconds)
            return x

    # Sleep durations of 0.1 s through 0.6 s, one layer per duration.
    tenths = [1, 2, 3, 4, 5, 6]
    layers = [Delay(tenth / 10) for tenth in tenths]
    net = nn.Sequential(*layers)
    probe = torch.rand(1)

    partition_sizes = balance_by_time(2, net, probe, device=device)

    # The first four layers sleep 1.0 s in total and the last two 1.1 s.
    assert partition_sizes == [4, 2]