def _run_basic_test(self, backend, checkpoint, find_unused_parameters=False, static_graph=False): dist.init_process_group( backend="nccl", init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name), world_size=self.world_size, rank=self.rank, ) # Use 4 GPUs, two replicas of a pipe across GPU 0 and 1 and another # pipe between GPU 2 and 3. Both replicas are replicated via DDP. fc1 = nn.Linear(16, 8, bias=False).cuda(2 * self.rank) class MyModule(nn.Module): def __init__(self, device): super(MyModule, self).__init__() self.fc2 = nn.Linear(8, 4, bias=False).cuda(device) self.fc3 = nn.Linear(4, 2, bias=False).cuda(device) def forward(self, inp): if find_unused_parameters: return self.fc2(inp) else: return self.fc3(self.fc2(inp)) layer2 = MyModule(2 * self.rank + 1) model = nn.Sequential( fc1, layer2 ) model = Pipe(model, chunks=2, checkpoint=checkpoint) model = DistributedDataParallel(model, find_unused_parameters=find_unused_parameters) if static_graph: model._set_static_graph() out = model(torch.rand(16, 16).cuda(2 * self.rank)).local_value() out.sum().backward() # Run forward again for find_unused_parameters to trigger any potential errors. if find_unused_parameters: model(torch.rand(16, 16).cuda(2 * self.rank)) # Check grads output = [torch.empty_like(fc1.weight.grad), torch.empty_like(fc1.weight.grad)] dist.all_gather(output, fc1.weight.grad) self.assertEqual(output[0], output[1]) output = [torch.empty_like(layer2.fc2.weight.grad), torch.empty_like(layer2.fc2.weight.grad)] dist.all_gather(output, layer2.fc2.weight.grad) self.assertEqual(output[0], output[1]) if not find_unused_parameters: output = [torch.empty_like(layer2.fc3.weight.grad), torch.empty_like(layer2.fc3.weight.grad)] dist.all_gather(output, layer2.fc3.weight.grad) self.assertEqual(output[0], output[1])
def _run_basic_test(self, backend, checkpoint, find_unused_parameters=False, static_graph=False):
    dist.init_process_group(
        backend=backend,
        init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
        world_size=self.world_size,
        rank=self.rank,
    )
    # Use 4 GPUs, two replicas of a pipe across GPU 0 and 1 and another
    # pipe between GPU 2 and 3. Both replicas are replicated via DDP.
    fc1 = nn.Linear(16, 8, bias=False).cuda(2 * self.rank)

    class MyModule(nn.Module):
        def __init__(self, device):
            super(MyModule, self).__init__()
            self.fc2 = nn.Linear(8, 4, bias=False).cuda(device)
            self.fc3 = nn.Linear(4, 2, bias=False).cuda(device)

        def forward(self, inp):
            if find_unused_parameters:
                # Skip fc3 so it is reported as an unused parameter.
                return self.fc2(inp)
            else:
                return self.fc3(self.fc2(inp))

    layer2 = MyModule(2 * self.rank + 1)
    model = nn.Sequential(fc1, layer2)
    model = Pipe(model, chunks=2, checkpoint=checkpoint)
    model = DistributedDataParallel(
        model, find_unused_parameters=find_unused_parameters)
    if static_graph:
        model._set_static_graph()

    # Ensure inputs are different across ranks to verify that gradient
    # sync indeed occurs.
    model_input = torch.rand(16, 16).cuda(2 * self.rank) * (self.rank + 1)
    out = model(model_input).local_value()
    out.sum().backward()

    # Run forward again for find_unused_parameters to trigger any potential errors.
    if find_unused_parameters:
        # Ensure inputs are different across ranks to verify that gradient
        # sync indeed occurs.
        unused_param_input = torch.rand(16, 16).cuda(2 * self.rank) * (self.rank + 1)
        model(unused_param_input).local_value().sum().backward()

    # Run a few more iterations of fwd + bwd to ensure gradient synchronization
    # occurs properly across iterations via delay_all_reduce/bucketized allreduce.
    for _ in range(3):
        model_input = torch.rand(16, 16).cuda(2 * self.rank) * (self.rank + 1)
        out = model(model_input).local_value()
        out.sum().backward()

    # Check grads
    output = [
        torch.empty_like(fc1.weight.grad),
        torch.empty_like(fc1.weight.grad),
    ]
    dist.all_gather(output, fc1.weight.grad)
    self.assertEqual(output[0], output[1])

    output = [
        torch.empty_like(layer2.fc2.weight.grad),
        torch.empty_like(layer2.fc2.weight.grad),
    ]
    dist.all_gather(output, layer2.fc2.weight.grad)
    self.assertEqual(output[0], output[1])

    if not find_unused_parameters:
        output = [
            torch.empty_like(layer2.fc3.weight.grad),
            torch.empty_like(layer2.fc3.weight.grad),
        ]
        dist.all_gather(output, layer2.fc3.weight.grad)
        self.assertEqual(output[0], output[1])
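
# Illustrative sketch only, not part of the original file: a helper like
# _run_basic_test is typically exercised by thin per-configuration test methods.
# The test method names below are assumptions; skip_if_lt_x_gpu is assumed to be
# imported from torch.testing._internal.common_distributed to guard the 4-GPU
# requirement noted above, and "never"/"except_last" are standard Pipe checkpoint
# modes.
@skip_if_lt_x_gpu(4)
def test_basic_nccl_ckpt_never(self):
    self._run_basic_test("nccl", "never")

@skip_if_lt_x_gpu(4)
def test_basic_nccl_find_unused_parameters(self):
    self._run_basic_test("nccl", "except_last", find_unused_parameters=True)

@skip_if_lt_x_gpu(4)
def test_basic_nccl_static_graph(self):
    self._run_basic_test("nccl", "except_last", static_graph=True)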