def test_current_microbatch():
    """Exercise ``current_microbatch()`` outside and inside a GPipe run.

    Outside of a model call, ``current_microbatch()`` is asserted to be
    ``None`` (checked both before and after the forward pass).  During
    the forward pass, the second stage ignores whatever the first stage
    produced and returns ``current_microbatch()`` instead; the final
    assertion expects the pipeline output to match the original input
    values rather than the doubled values emitted by the first stage.
    """

    class Twice(nn.Module):
        """First stage: multiplies its input by two."""

        def forward(self, x):
            return x * 2

    class CurrentMicrobatch(nn.Module):
        """Second stage: discards its input, returns ``current_microbatch()``."""

        def forward(self, _):
            return current_microbatch()

    # Not in a partition.
    assert current_microbatch() is None

    sample = torch.tensor([1., 2., 3.])

    # Two stages, one layer each, both on CPU; the input is split into
    # three chunks.
    stages = nn.Sequential(Twice(), CurrentMicrobatch())
    pipeline = GPipe(stages, balance=[1, 1], devices=['cpu', 'cpu'], chunks=3)

    result = pipeline(sample)

    # The expected tensor equals the un-doubled input values.
    assert torch.allclose(result, torch.tensor([1., 2., 3.]))

    # Not in a partition.
    assert current_microbatch() is None
def forward(self, _):
    """Ignore the positional input and return ``current_microbatch()``.

    NOTE(review): this definition duplicates the ``forward`` method of the
    ``CurrentMicrobatch`` helper class inside ``test_current_microbatch``;
    it looks like a stray fragment -- confirm whether it is needed.
    """
    microbatch = current_microbatch()
    return microbatch