def test_sync_params_no_buffers(self):
    """Check that ``c10d._sync_params`` replicates parameters across devices.

    On this rank's first device the parameters are ``target`` (0..9 split
    into five chunks); every other local device starts with zeros. Buffers
    are empty and ``broadcast_buffers=False``, so only parameters are
    synchronized. Afterwards every chunk on every device must equal the
    corresponding chunk of ``target``.
    """
    # Set up process group. A FileStore (the same store the with-buffers
    # test uses) avoids depending on a free, fixed TCP port.
    store = c10d.FileStore(self.file.name)
    options = c10d.ProcessGroupGloo.Options()
    options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
    process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)

    # Use every GPU assigned to this rank (data is small, so should be fine).
    devices = gpus_for_rank(self.world_size)[self.rank]
    dtype = torch.float64

    # parameter_data[i] must live on devices[i]; hard-coding 'cuda:0' breaks
    # that whenever this rank's first GPU is not GPU 0.
    target = torch.arange(
        10, dtype=dtype, device=torch.device('cuda', devices[0])).chunk(5)
    parameter_data = [target]
    # Replicas use the same dtype as target so the synced values and the
    # equality check below compare like with like.
    parameter_data += [
        torch.zeros(10, dtype=dtype, device=torch.device('cuda', d)).chunk(5)
        for d in devices[1:]
    ]
    # One (empty) buffer list per device; nothing to broadcast.
    buffer_data = [[] for _ in parameter_data]

    c10d._sync_params(
        process_group,
        parameter_data=parameter_data,
        buffer_data=buffer_data,
        devices=devices,
        broadcast_bucket_size=10,
        broadcast_buffers=False)

    for device_data in parameter_data:
        for i, parameter in enumerate(device_data):
            self.assertEqual(parameter, target[i])
def test_sync_params_with_buffers(self):
    """Check that ``c10d._sync_params`` also broadcasts buffers when asked.

    Parameters are set up as ``target`` (0..9 split into five chunks) on this
    rank's first device and zeros elsewhere. Only the master process fills
    its first-device buffers with ``target``; all other buffers start as
    zeros. With ``broadcast_buffers=True``, every parameter and buffer chunk
    on every device of every process must end up equal to ``target``.
    """
    store = c10d.FileStore(self.file.name)
    options = c10d.ProcessGroupGloo.Options()
    options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
    process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)

    devices = gpus_for_rank(self.world_size)[self.rank]
    # A single dtype for every tensor on every rank: the master's buffers are
    # broadcast into the other processes' zero buffers, which presumably
    # requires identically typed tensors on both ends.
    dtype = torch.float64

    # parameter_data[i] must live on devices[i]; hard-coding 'cuda:0' breaks
    # that whenever this rank's first GPU is not GPU 0.
    target = torch.arange(
        10, dtype=dtype, device=torch.device('cuda', devices[0])).chunk(5)
    parameter_data = [target]
    parameter_data += [
        torch.zeros(10, dtype=dtype, device=torch.device('cuda', d)).chunk(5)
        for d in devices[1:]
    ]

    # sync_params should broadcast buffers across processes, so only the
    # master populates its buffers; the check below verifies that every other
    # process's buffers end up matching.
    if self.is_master:
        buffer_data = [target]
        buffer_data += [
            torch.zeros(10, dtype=dtype, device=torch.device('cuda', d)).chunk(5)
            for d in devices[1:]
        ]
    else:
        buffer_data = [
            torch.zeros(10, dtype=dtype, device=torch.device('cuda', d)).chunk(5)
            for d in devices
        ]

    c10d._sync_params(
        process_group,
        parameter_data=parameter_data,
        buffer_data=buffer_data,
        devices=devices,
        broadcast_bucket_size=10,
        broadcast_buffers=True)

    for device_data in parameter_data:
        for i, parameter in enumerate(device_data):
            self.assertEqual(parameter, target[i])

    for device_data in buffer_data:
        for i, buffer in enumerate(device_data):
            self.assertEqual(buffer, target[i])