Example #1
0
def test_current_microbatch():
    """Check current_microbatch() inside and outside a GPipe pipeline.

    Outside any partition it must return None. Inside the pipeline, the
    last stage ignores its input and returns current_microbatch(); the
    test expects the reassembled output to equal the original (undoubled)
    input, i.e. not the output of the preceding Twice stage.
    """

    class Twice(nn.Module):
        def forward(self, x):
            return x * 2

    class CurrentMicrobatch(nn.Module):
        # Discards its argument and reports what current_microbatch()
        # returns while this partition is executing.
        def forward(self, _):
            return current_microbatch()

    def assert_outside_partition():
        # No pipeline is running at this point, so nothing should be reported.
        assert current_microbatch() is None

    assert_outside_partition()

    sample = torch.tensor([1., 2., 3.])

    pipeline = GPipe(
        nn.Sequential(Twice(), CurrentMicrobatch()),
        balance=[1, 1],
        devices=['cpu', 'cpu'],
        chunks=3,
    )

    result = pipeline(sample)

    # Twice doubled the values, but the final stage is expected to see the
    # original micro-batches, so the output should match the input exactly.
    expected = torch.tensor([1., 2., 3.])
    assert torch.allclose(result, expected)

    assert_outside_partition()
Example #2
0
 def forward(self, _):
     """Ignore the forward input and return whatever current_microbatch() reports."""
     microbatch = current_microbatch()
     return microbatch