# Excerpted test method; requires `import torch` and `import torch.distributed as c10d`,
# plus a TestCase harness providing self.file_name, self.rank, and self.world_size.
def test_broadcast_checks(self):
    store = c10d.FileStore(self.file_name, self.world_size)
    pg = c10d.ProcessGroupCCL(store, self.rank, self.world_size)
    t1 = torch.zeros([1], dtype=torch.float32)
    t2 = torch.zeros([1], dtype=torch.float64)  # same shape as t1, different dtype
    t3 = torch.zeros([2], dtype=torch.float32)  # same dtype as t1, different shape

    # A negative root rank must be rejected before any communication happens.
    with self.assertRaisesRegex(ValueError, "unexpected rank"):
        opts = c10d.BroadcastOptions()
        opts.rootRank = -1
        opts.rootTensor = 0
        pg.broadcast([t1], opts)
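    # The fragment above exercises only the negative-rank case and leaves t2
    # and t3 unused. Below is a sketch of plausible follow-on checks in the
    # same style; the error regexes are assumptions and may differ for the
    # CCL backend.
    with self.assertRaisesRegex(ValueError, "unexpected rank"):
        opts = c10d.BroadcastOptions()
        opts.rootRank = self.world_size  # one past the last valid rank
        opts.rootTensor = 0
        pg.broadcast([t1], opts)

    with self.assertRaisesRegex(ValueError, "invalid tensor"):
        opts = c10d.BroadcastOptions()
        opts.rootRank = self.rank
        opts.rootTensor = 0
        pg.broadcast([t1, t2], opts)  # mixed float32/float64 dtypes

    with self.assertRaisesRegex(ValueError, "invalid tensor"):
        opts = c10d.BroadcastOptions()
        opts.rootRank = self.rank
        opts.rootTensor = 0
        pg.broadcast([t1, t3], opts)  # mismatched shapes
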
# Excerpted torch_xla test method; requires `import torch`,
# `import torch.distributed as dist`, `import torch_xla`, and
# `import torch_xla.core.xla_model as xm`.
def test_broadcast(self):
    device = xm.xla_device()
    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
    pg_xla = get_process_group_xla(rank=0, size=8)
    opts = dist.BroadcastOptions()
    opts.rootRank = 0
    opts.rootTensor = 0
    # xla doesn't have broadcast. We use all_reduce to implement broadcast,
    # so the lowered HLO is expected to contain an all-reduce op.
    all_reduce_pattern = r'%all\-reduce\.\d+ = .+ all\-reduce\('
    with xm_cc_op_intercepted('all_reduce'):
        pg_xla.broadcast([tensor], opts)
        hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
        hlo_matches(hlo, all_reduce_pattern)
    # Purge all computations attached to the device.
    xm.mark_step()
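
# The comment above notes that broadcast is implemented via all_reduce. A
# minimal sketch of that trick in eager torch.distributed terms (the helper
# name is illustrative; this is not the torch_xla lowering itself): every
# non-root rank zeroes its tensor, so a SUM all_reduce leaves all ranks
# holding exactly the root's values.
import torch
import torch.distributed as dist

def broadcast_via_all_reduce(tensor: torch.Tensor, root_rank: int) -> torch.Tensor:
    if dist.get_rank() != root_rank:
        tensor.zero_()  # non-root ranks contribute zeros to the sum
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)  # sum equals the root's tensor
    return tensor
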
# Helper used by the collective tests; expects a process group `pg` (created
# as in test_broadcast_checks above) to be in scope.
def broadcast(xs, rootRank, rootTensor):
    opts = c10d.BroadcastOptions()
    opts.rootRank = rootRank
    opts.rootTensor = rootTensor
    work = pg.broadcast(xs, opts)
    work.wait()  # block until the collective completes
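
# Sketch usage of the helper above, assuming `pg` and `torch` are in scope;
# the tensor values are illustrative.
rank = pg.rank()
x = torch.full([4], float(rank))  # each rank starts with a distinct value
broadcast([x], rootRank=0, rootTensor=0)
# After the collective, every rank holds rank 0's tensor (all zeros here).
assert torch.equal(x, torch.full([4], 0.0))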