Example #1
    def test_broadcast_checks(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = c10d.ProcessGroupCCL(store, self.rank, self.world_size)

        t1 = torch.zeros([1], dtype=torch.float32)
        t2 = torch.zeros([1], dtype=torch.float64)
        t3 = torch.zeros([2], dtype=torch.float32)

        with self.assertRaisesRegex(ValueError, "unexpected rank"):
            opts = c10d.BroadcastOptions()
            opts.rootRank = -1
            opts.rootTensor = 0
            pg.broadcast([t1], opts)
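For contrast with the failing call above, here is a minimal sketch of a valid broadcast (assuming a Gloo-backed PyTorch build; the store path, world size, and tensor values are illustrative only):

import torch
import torch.distributed as c10d

# Hypothetical single-process setup: rank 0 in a world of size 1.
store = c10d.FileStore("/tmp/example_store", 1)
pg = c10d.ProcessGroupGloo(store, 0, 1)

opts = c10d.BroadcastOptions()
opts.rootRank = 0    # must be a valid rank in [0, world_size)
opts.rootTensor = 0  # index into the tensor list passed to broadcast
work = pg.broadcast([torch.ones(1)], opts)
work.wait()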
Example #2
  def test_broadcast(self):
   device = xm.xla_device()
   tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
   pg_xla = get_process_group_xla(rank=0, size=8)
   opts = dist.BroadcastOptions()
   opts.rootRank = 0
   opts.rootTensor = 0
   # xla doesn't have broadcast. We use all_reduce to implement broadcast.
   all_reduce_pattern = r'%all\-reduce\.\d+ = .+ all\-reduce\('
   with xm_cc_op_intercepted('all_reduce'):
     pg_xla.broadcast([tensor], opts)
   hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
   hlo_matches(hlo, all_reduce_pattern)
    # Purge all computations attached to the device.
   xm.mark_step()
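The comment above refers to a standard trick; the sketch below shows broadcast built from all_reduce using the generic torch.distributed API (this is the general technique, not torch_xla's actual lowering, and the helper name is made up):

import torch
import torch.distributed as dist

def broadcast_via_all_reduce(tensor, root_rank):
    # Non-root ranks zero their copy; a sum all-reduce then leaves the
    # root's values on every rank, which is exactly a broadcast.
    if dist.get_rank() != root_rank:
        tensor.zero_()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor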
Example #3
def broadcast(xs, rootRank, rootTensor):
    # Broadcast the tensor list `xs` from `rootRank`, using the tensor at
    # index `rootTensor` as the source, and block until the work completes.
    opts = c10d.BroadcastOptions()
    opts.rootRank = rootRank
    opts.rootTensor = rootTensor
    work = pg.broadcast(xs, opts)
    work.wait()
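A possible call site for this helper, assuming `pg` has already been constructed (as in Example #1) and torch is imported; the tensor values are illustrative:

xs = [torch.arange(2.0) if pg.rank() == 0 else torch.zeros(2)]
broadcast(xs, 0, 0)
# xs[0] now holds rank 0's data ([0., 1.]) on every rank.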