Example #1
0
def test_batch_call():
    """Check the ``atomic`` flag of the Batch returned by ``Batch.call``.

    A Batch built from a single tensor yields an atomic result. A Batch built
    from a tuple of tensors yields a non-atomic result. The callable used is
    the identity, so it returns its argument unchanged.
    """
    def identity(value):
        return value

    single = Batch(torch.tensor(42))
    pair = Batch((torch.tensor(42), torch.tensor(21)))

    assert single.call(identity).atomic
    assert not pair.call(identity).atomic
Example #2
0
 def compute(
     batch: Batch = batch,
     partition: nn.Sequential = partition,
     chunk_id: int = i,
     part_id: int = j,
 ) -> Batch:
     """Apply ``partition`` to ``batch`` inside a named profiler range.

     Default values are evaluated once, when this ``def`` executes. Each
     definition therefore captures the enclosing scope's current ``batch``,
     ``partition``, ``i`` and ``j`` instead of looking them up at call time.
     NOTE(review): these are presumably loop variables, bound this way to
     avoid late-binding closures. Confirm in the enclosing scope.

     Args:
         batch: the Batch whose ``call`` method is invoked.
         partition: the module passed to ``batch.call``.
         chunk_id: used only in the profiler label.
         part_id: used only in the profiler label.

     Returns:
         Whatever ``batch.call(partition)`` returns (annotated as ``Batch``).
     """
     # Profiler range label, e.g. "chunk0-part1".
     label = "chunk%d-part%d" % (chunk_id, part_id)
     with record_function(label):
         # Return from inside the context so the profiler range closes after the call.
         return batch.call(partition)