def async_send_inner(self, batch: Batch, index: int) -> Tuple[Batch, PipeMessage]:
    """Run the first local partition on ``batch`` and build the message for the next stage.

    A task (created without skip trackers) executes ``self.partitions[0].module``
    for micro-batch ``index``. The task is finalized, and its output is wrapped in
    a ``PipeMessage``. The message is addressed to the rank that follows this
    process in ``get_pipeline_parallel_ranks()``.

    Args:
        batch: Input micro-batch for the local partition.
        index: Micro-batch index. It is forwarded to the task and recorded in the
            message body.

    Returns:
        ``(output, message)``, where ``output`` is the computed result and
        ``message`` carries its tensors on ``EVENT_LOOP_ACTIVATIONS_QUEUE``.
    """
    pipeline_task = create_task_without_skip_trackers(
        self.checkpoint_stop,
        index,
        self.group.rank(),
        batch,
        self.partitions[0].module,
    )
    output = pipeline_task.compute()
    pipeline_task.finalize(output)

    pipeline_ranks = get_pipeline_parallel_ranks()
    my_rank = torch.distributed.get_rank()
    # The destination is the entry right after this process in the pipeline rank list.
    # NOTE(review): this raises IndexError when this rank is the last stage; presumably
    # callers only invoke it on non-final stages — confirm.
    next_rank = pipeline_ranks[pipeline_ranks.index(my_rank) + 1]

    message_body = AsyncMessageBody(
        AsyncMessageType.Activations,
        index,
        Location(my_rank, 0),
        Location(next_rank, 0),
        0,
    )
    outgoing = PipeMessage(
        my_rank,
        next_rank,
        queue_name=EVENT_LOOP_ACTIVATIONS_QUEUE,
        args=message_body,
        tensors=tuple(output),
    )
    return output, outgoing
def check_partitions(model, balance, expected_order, expected_ranks):
    """Assert that partitioning ``model`` according to ``balance`` yields the expected layout.

    A partition is instantiated for every rank in ``range(len(balance))`` using
    ``Pipe.AsyncSchedule``. The invocations from all ranks are then sorted by
    ``order``, and three properties are checked:

    * they form an unbroken ``source -> this -> dest`` chain starting at
      ``Location(0, 0)``;
    * the modules run at each invocation are the expected entries of ``model``;
    * each invocation runs on the expected stage.

    Args:
        model: A list of modules or an ``nn.Sequential``.
        balance: The ``balance`` argument that would be given to ``Pipe``.
        expected_order: For each invocation in execution order, the indices into
            ``model`` of the modules grouped into that ``nn.Sequential``.
        expected_ranks: For each invocation in execution order, the stage it is
            expected to run on.

    Raises:
        AssertionError: If any of the checks above fails.
    """
    world_size = len(balance)
    wrapper_of = {}
    all_invocations = []

    # Gather every Invocation from every rank's instantiated partitions, and remember
    # which ModuleWrapper it belongs to.
    for rank in range(world_size):
        partitions = instantiate_partition(
            model, balance, FakeGroup(rank, world_size), Pipe.AsyncSchedule
        )
        for wrapper in partitions:
            assert isinstance(wrapper.module, nn.Sequential)
            for invocation in wrapper.invocations:
                all_invocations.append(invocation)
                wrapper_of[invocation] = wrapper

    grouped_modules = []
    actual_ranks = []
    expected_source = None
    expected_this = Location(0, 0)
    ordered = sorted(all_invocations, key=lambda inv: inv.order)
    for position, invocation in enumerate(ordered):
        # Orders must be contiguous from 0. Each invocation's source must be the
        # previous invocation's location, and its location the previous one's dest.
        assert invocation.order == position
        assert invocation.source == expected_source
        assert invocation.this == expected_this
        expected_source, expected_this = invocation.this, invocation.dest
        grouped_modules.append(list(wrapper_of[invocation].module.children()))
        actual_ranks.append(invocation.this.stage)

    # NOTE(review): zip() stops at the shorter sequence, so extra or missing module
    # groups are not detected here. A length check
    # (len(grouped_modules) == len(expected_order)) was previously commented out —
    # confirm whether that omission is intentional.
    for actual, expected in zip(grouped_modules, expected_order):
        assert len(actual) == len(expected), f"{expected}"
        # Compare by identity: the partitions must reuse the very module objects from `model`.
        assert [id(m) for m in actual] == [id(model[i]) for i in expected], f"{expected}"

    assert actual_ranks == expected_ranks