def initialize_distributed_model(self):
    """Prepare ``self.model`` for the configured parallel execution mode.

    Reads ``self.local_rank``, ``self.fp16`` and ``self.n_gpu`` and may
    replace ``self.model`` in place:

    * ``local_rank != -1`` (distributed run):
        - without fp16, ``self.model`` becomes ``DDP(self.model)``;
        - with fp16, the model is left unwrapped and its parameter tensors
          are handed to ``flat_dist_call`` together with
          ``torch.distributed.broadcast`` and the argument tuple ``(0,)`` —
          presumably to sync weights from rank 0 (TODO confirm against
          ``flat_dist_call``).
    * ``local_rank == -1`` and ``n_gpu > 1``: ``self.model`` is wrapped in
      ``nn.DataParallel``.
    * otherwise ``self.model`` is left untouched.

    ``DDP`` and ``flat_dist_call`` are resolved from module scope (their
    imports are not visible here).

    Returns:
        None.
    """
    if self.local_rank == -1:
        # Single-process run: only wrap when more than one GPU is requested.
        if self.n_gpu > 1:
            self.model = nn.DataParallel(self.model)
        return

    if self.fp16:
        param_tensors = [p.data for p in self.model.parameters()]
        flat_dist_call(param_tensors, torch.distributed.broadcast, (0,))
    else:
        self.model = DDP(self.model)
class ForwardRunner:
    """Run a model's forward pass on a batch and unpack its loss and metrics.

    At construction time the positional parameter names of ``model.forward``
    are recorded; on every call only the batch entries whose keys match one
    of those names are passed to the model.
    """

    def __init__(self,
                 model: ProteinModel,
                 device: torch.device = torch.device('cuda:0'),
                 n_gpu: int = 1,
                 fp16: bool = False,
                 local_rank: int = -1):
        """Store the run configuration and inspect ``model.forward``.

        Args:
            model: the model to run; its ``forward`` must accept an
                ``input_ids`` argument.
            device: batches are moved here when its type is ``'cuda'``.
            n_gpu: number of GPUs; values > 1 enable DataParallel wrapping
                (when not distributed) and averaging of loss/metrics.
            fp16: selects the fp16 branch of ``initialize_distributed_model``.
            local_rank: ``-1`` means a non-distributed run.

        Raises:
            ValueError: if ``model.forward`` has no ``input_ids`` parameter.
        """
        self.model = model
        self.device = device
        self.n_gpu = n_gpu
        self.fp16 = fp16
        self.local_rank = local_rank

        # getfullargspec on a bound method still lists the already-bound
        # ``self`` first, so drop it.
        forward_arg_keys = inspect.getfullargspec(model.forward).args
        forward_arg_keys = forward_arg_keys[1:]  # remove self argument
        self._forward_arg_keys = forward_arg_keys

        # Explicit check rather than ``assert`` so it is not stripped
        # when Python runs with -O.
        if 'input_ids' not in self._forward_arg_keys:
            raise ValueError(
                "model.forward must accept an 'input_ids' argument; "
                "found arguments {}".format(self._forward_arg_keys))

    def initialize_distributed_model(self):
        """Wrap or synchronise ``self.model`` for parallel execution.

        * ``local_rank != -1`` without fp16: wrap in ``DDP``.
        * ``local_rank != -1`` with fp16: leave the model unwrapped and pass
          its parameter tensors to ``flat_dist_call`` with
          ``torch.distributed.broadcast`` and ``(0,)`` — presumably a
          broadcast from rank 0 (TODO confirm against ``flat_dist_call``).
        * ``local_rank == -1`` and ``n_gpu > 1``: wrap in ``nn.DataParallel``.
        """
        if self.local_rank != -1:
            if not self.fp16:
                self.model = DDP(self.model)
            else:
                flat_dist_call([param.data for param in self.model.parameters()],
                               torch.distributed.broadcast, (0,))
        elif self.n_gpu > 1:
            self.model = nn.DataParallel(self.model)

    def forward(self,
                batch: typing.Dict[str, torch.Tensor],
                return_outputs: bool = False,
                no_loss: bool = False):
        """Run the model on ``batch`` and extract the loss and metrics.

        Args:
            batch: mapping from argument name to tensor. Keys that are not
                parameters of ``model.forward`` are dropped.
            return_outputs: also return the raw model outputs.
            no_loss: return the raw model outputs immediately, without
                extracting a loss.

        Returns:
            The raw outputs if ``no_loss``; otherwise ``(loss, metrics)``, or
            ``(loss, metrics, outputs)`` when ``return_outputs`` is set.
            ``outputs[0]`` is read either as a ``(loss, metrics)`` tuple or
            as the bare loss (then ``metrics`` is ``{}``).
        """
        # Filter out batch items that aren't used in this model.
        # Requires that dataset keys match the forward args of the model.
        # Useful if some elements of the data are only used by certain models,
        # e.g. PSSMs / MSAs and other evolutionary data.
        batch = {name: tensor for name, tensor in batch.items()
                 if name in self._forward_arg_keys}
        if self.device.type == 'cuda':
            batch = {name: tensor.cuda(device=self.device, non_blocking=True)
                     for name, tensor in batch.items()}

        outputs = self.model(**batch)

        if no_loss:
            return outputs

        if isinstance(outputs[0], tuple):
            # model also returned metrics
            loss, metrics = outputs[0]
        else:
            # no metrics
            loss = outputs[0]
            metrics = {}

        if self.n_gpu > 1:
            # Per the original note, the multi-GPU wrapper does not reduce
            # scalars to a single value, so average them here.
            loss = loss.mean()
            metrics = {name: metric.mean() for name, metric in metrics.items()}

        if return_outputs:
            return loss, metrics, outputs
        else:
            return loss, metrics

    def train(self):
        """Put the model in training mode; return ``self`` for chaining."""
        self.model.train()
        return self

    def eval(self):
        """Put the model in evaluation mode; return ``self`` for chaining."""
        self.model.eval()
        return self