def test_recv(self):
  device = xm.xla_device()
  tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
  output_list = [tensor]
  set_world_size(6)
  ranks = [0, 3]
  world_rank = 0
  set_world_rank(world_rank)

  torch_xla.distributed.xla_backend.ProcessGroupXla.make_recv_channel_id = (
      lambda self, src_rank, tag: src_rank * 3)

  with new_group_barrier_disabled():
    pg_xla = dist.new_group(ranks=ranks)

  recv_pattern = r'%recv\.\d+ = .+ recv\(.+\), channel_id=3'
  recvdone_pattern = r'%recv\-done\.\d+ = .+ recv\-done\(.+\), channel_id=3'
  # seeing 'recv is not implemented on CPU' means we have successfully
  # generated `recv` in the HLO.
  with self.assertRaises(RuntimeError) as cm:
    pg_xla.recv(output_list, 1)
    hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_list)
    hlo_matches(hlo, recv_pattern)
    hlo_matches(hlo, recvdone_pattern)
    xm.mark_step()
  assert 'UNIMPLEMENTED: Recv is not implemented on CPU.' in str(
      cm.exception), str(cm.exception)
  # reset token to clean up the mess after the RuntimeError.
  xm.set_replication(device, [])
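The clean-up on the last line relies on `xm.set_replication` accepting an empty device list to clear whatever replication was in effect. A minimal sketch of that idiom in isolation (the helper name `reset_replication` is hypothetical, not part of the test above):

import torch_xla.core.xla_model as xm

def reset_replication(device):
  # Hypothetical helper: an empty device list clears the replication that was
  # previously set up for `device`, so later traces run as plain
  # single-device graphs.
  xm.set_replication(device, [])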
Example #2
def _module_runner(self, loop_fn, device, module, loader, context, result):
    xm.set_replication(device, self._device_ids)
    try:
        result.result = loop_fn(module, loader, torch.device(device),
                                context)
    except Exception as e:
        result.result = e
        self._handle_runner_exception(device, e)
Example #3
def _module_runner(self, loop_fn, device, module, loader, context, result):
  if len(self._device_ids) > 1:
    xm.set_replication(device, self._device_ids)
  else:
    torch_xla._XLAC._xla_set_default_device(device)
  try:
    result.result = loop_fn(module, loader, torch.device(device), context)
  except Exception as e:
    result.result = e
    self._handle_runner_exception(device, e)
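This variant differs from Example #2 only in guarding the call: when a process drives a single device, it just pins the default XLA device instead of enabling replication. The same guard as a standalone sketch (the helper name `maybe_set_replication` is hypothetical):

import torch_xla
import torch_xla.core.xla_model as xm

def maybe_set_replication(device, device_ids):
  # Replication is only meaningful when this process drives several devices;
  # otherwise just make `device` the default XLA device.
  if len(device_ids) > 1:
    xm.set_replication(device, device_ids)
  else:
    torch_xla._XLAC._xla_set_default_device(device)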
Example #4
def run_thread_per_device(rank: int, processes: int,
                          fn: Callable[..., R]) -> Dict[int, R]:
  """Runs `fn` in a separate thread on each visible device.

  Args:
    rank: rank of current process
    processes: number of processes on this host
    fn: Function to run on all devices

  Returns:
    Dict of the form {thread_rank: return_value}, where return_value is the
    result of calling `fn`.
  """
  if device_type() == 'TPU':
    configure_tpu_topology(rank, processes)

  xm.set_replication(xm.xla_device(), xm.get_xla_supported_devices())
  threads = len(xm.get_xla_supported_devices())

  def _thread_fn(fn, device_index):

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
      # Assumes same number of threads per process
      set_global_ordinal(rank * threads + device_index)
      set_local_ordinal(rank * threads + device_index)

      return fn(*args, **kwargs)

    return wrapper

  with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
    futures = {executor.submit(_thread_fn(fn, i)): i for i in range(threads)}

    results = {
        futures[f]: f.result() for f in concurrent.futures.as_completed(futures)
    }

  return results
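A minimal usage sketch for the function above, assuming a single-process host; the `step` function and its body are illustrative rather than taken from the library:

import torch
import torch_xla.core.xla_model as xm

def step():
  device = xm.xla_device()
  x = torch.randn(4, 4, device=device)
  xm.mark_step()  # materialize the lazy graph on this thread's device
  return xm.get_ordinal()

# One thread per visible device; the dict maps thread rank -> return value.
results = run_thread_per_device(rank=0, processes=1, fn=step)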
Example #5
def _setup_replication():
    if xm.xrt_world_size() > 1:
        device = xm.xla_device()
        xm.set_replication(str(device), [str(device)])
Example #6
def _setup_replication():
    # At this point xla_model.py APIs are allowed as the setup is already
    # completed.
    if xm.xrt_world_size() > 1:
        device = xm.xla_device()
        xm.set_replication(device, [device])
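A sketch of a typical call site for a helper like the two above: each process spawned through `xmp.spawn` performs the replication setup once before entering its work loop (the `_mp_fn` body is illustrative):

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
  _setup_replication()  # enable replication when running with multiple replicas
  device = xm.xla_device()
  # ... build the model and run the training loop on `device` ...

if __name__ == '__main__':
  xmp.spawn(_mp_fn)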