from typing import List

import torch
import torch.distributed as torch_distrib
from torch import nn

from fairscale.nn import PipeRPCWrapper
from fairscale.nn.pipe.pipeline import PipelineStyle


class LightningPipeModule(nn.Module):
    """
    This class wraps Fairscale Pipe and the PipeRPCWrapper class.
    """

    def __init__(self, module: nn.Sequential, balance: List[int], microbatches: int = 8, checkpoint: str = 'never'):
        super().__init__()
        self.module = module
        self.balance = balance
        self.microbatches = microbatches
        self.checkpoint = checkpoint
        self._init_pipe()

    def _init_pipe(self):
        # Each rank owns one pipeline stage, placed on its local CUDA device.
        device = torch.device("cuda", torch_distrib.get_rank())

        self.module = PipeRPCWrapper(
            module=self.module,
            balance=self.balance,
            chunks=self.microbatches,
            style=PipelineStyle.MultiProcess,
            input_device=device,
            worker_map=self.get_worker_map(),
            checkpoint=self.checkpoint,
        )

    def foreach_worker(self, *args, **kwargs):
        # Run a callable on every pipeline worker via RPC.
        self.module.foreach_worker(*args, **kwargs)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def get_worker_map(self):
        # TODO: is this correct with multi-node? We also assume "worker" is the same name as defined in the RPCPlugin.
        return {rank: f"worker{rank}" for rank in range(torch_distrib.get_world_size())}
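
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not from the original source). It assumes the
# surrounding RPC plugin has already initialized torch.distributed and
# torch.distributed.rpc with worker names matching get_worker_map()
# ("worker0", "worker1", ...). The toy model and balance below are
# placeholder assumptions: a 3-layer nn.Sequential split as [2, 1]
# across two ranks.
# ---------------------------------------------------------------------------
def example_usage():
    model = nn.Sequential(
        nn.Linear(32, 64),
        nn.ReLU(),
        nn.Linear(64, 2),
    )
    # Wrap the sequential model: rank 0 holds the first two layers, rank 1 the last.
    pipe = LightningPipeModule(module=model, balance=[2, 1], microbatches=4)

    # Inputs are fed on the driving rank and land on that rank's device.
    x = torch.randn(16, 32, device=torch.device("cuda", torch_distrib.get_rank()))
    return pipe(x)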