def test_recv(self):
  """Check that ProcessGroupXla.recv lowers to recv / recv-done HLO ops.

  The recv channel id hook is patched to ``src_rank * 3`` so receiving from
  rank 1 must yield ``channel_id=3`` in the HLO. The test expects executing
  the graph to raise a RuntimeError whose message contains
  'UNIMPLEMENTED: Recv is not implemented on CPU.'; that error is taken as
  evidence that `recv` was generated.
  """
  import unittest.mock  # local import: only this test patches the class

  device = xm.xla_device()
  tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
  output_list = [tensor]

  set_world_size(6)
  ranks = [0, 3]
  world_rank = 0
  set_world_rank(world_rank)

  # Scope the channel-id override to this test. A bare class-attribute
  # assignment would leave the lambda installed for every later test.
  channel_id_patch = unittest.mock.patch.object(
      torch_xla.distributed.xla_backend.ProcessGroupXla,
      'make_recv_channel_id',
      new=lambda self, src_rank, tag: src_rank * 3)
  with channel_id_patch:
    try:
      with new_group_barrier_disabled():
        pg_xla = dist.new_group(ranks=ranks)

      recv_pattern = r'%recv\.\d+ = .+ recv\(.+\), channel_id=3'
      recvdone_pattern = r'%recv\-done\.\d+ = .+ recv\-done\(.+\), channel_id=3'
      # seeing 'recv is not implemented on CPU' means we have successfully
      # generated `recv` in the HLO.
      with self.assertRaises(RuntimeError) as cm:
        pg_xla.recv(output_list, 1)
        hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_list)
        hlo_matches(hlo, recv_pattern)
        hlo_matches(hlo, recvdone_pattern)
        xm.mark_step()
      assert 'UNIMPLEMENTED: Recv is not implemented on CPU.' in str(
          cm.exception), str(cm.exception)
    finally:
      # reset token to clean up the mess after the RuntimeError; done in
      # `finally` so a failing assertion does not skip the cleanup.
      xm.set_replication(device, [])
def _module_runner(self, loop_fn, device, module, loader, context, result):
  """Run `loop_fn` for one device and record its outcome in `result`.

  Calls `xm.set_replication(device, self._device_ids)` first. On success the
  return value of `loop_fn(module, loader, torch.device(device), context)` is
  stored in `result.result`. If building the device or running the loop
  raises, the exception object is stored in `result.result` instead and
  passed to `self._handle_runner_exception`.
  """
  xm.set_replication(device, self._device_ids)
  try:
    torch_device = torch.device(device)
    result.result = loop_fn(module, loader, torch_device, context)
  except Exception as exc:  # record any failure, then delegate handling
    result.result = exc
    self._handle_runner_exception(device, exc)
def _module_runner(self, loop_fn, device, module, loader, context, result):
  """Run `loop_fn` for one device and record its outcome in `result`.

  With at most one entry in `self._device_ids`, `device` is made the default
  XLA device via `torch_xla._XLAC._xla_set_default_device`. Otherwise
  replication is set up across `self._device_ids`. The loop's return value,
  or any exception it raises, is stored in `result.result`. Exceptions are
  also passed to `self._handle_runner_exception`.
  """
  if len(self._device_ids) <= 1:
    torch_xla._XLAC._xla_set_default_device(device)
  else:
    xm.set_replication(device, self._device_ids)
  try:
    torch_device = torch.device(device)
    result.result = loop_fn(module, loader, torch_device, context)
  except Exception as exc:  # record any failure, then delegate handling
    result.result = exc
    self._handle_runner_exception(device, exc)
def run_thread_per_device(rank: int, processes: int,
                          fn: Callable[..., R]) -> Dict[int, R]:
  """Runs `fn` in a separate thread on each visible device.

  Args:
    rank: rank of current process
    processes: number of processes on this host
    fn: Function to run on all devices

  Returns:
    Dict of the form {thread_rank: return_value}, where return_value is the
    result of calling `fn`. Keys are in ascending thread-rank order.

  Raises:
    Whatever `fn` raises in any thread is re-raised via `Future.result()`.
  """
  if device_type() == 'TPU':
    configure_tpu_topology(rank, processes)

  # Query the supported devices once, so the replication group and the
  # number of worker threads come from the same snapshot.
  default_device = xm.xla_device()
  devices = xm.get_xla_supported_devices()
  xm.set_replication(default_device, devices)
  threads = len(devices)

  def _thread_fn(target, device_index):
    """Wrap `target` so it first records this thread's ordinals."""
    # Assumes same number of threads per process.
    # NOTE(review): the same value is used as global and local ordinal; this
    # is only a correct global ordinal on a single host — confirm for
    # multi-host setups.
    ordinal = rank * threads + device_index

    @functools.wraps(target)
    def wrapper(*args, **kwargs):
      set_global_ordinal(ordinal)
      set_local_ordinal(ordinal)
      return target(*args, **kwargs)

    return wrapper

  with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
    futures = {executor.submit(_thread_fn(fn, i)): i for i in range(threads)}
    results = {
        futures[f]: f.result()
        for f in concurrent.futures.as_completed(futures)
    }
  # as_completed yields in completion order; sort so the returned dict has a
  # stable, thread-rank ordering.
  return dict(sorted(results.items()))
def _setup_replication():
  """Set up replication for this process's XLA device in a multi-worker run.

  Does nothing unless `xm.xrt_world_size()` exceeds one. Otherwise it calls
  `xm.set_replication` with the current XLA device's string name and a
  one-element list holding that same name.
  """
  if xm.xrt_world_size() <= 1:
    return
  device_name = str(xm.xla_device())
  xm.set_replication(device_name, [device_name])
def _setup_replication():
  """Set up replication for this process's XLA device in a multi-worker run.

  Does nothing unless `xm.xrt_world_size()` exceeds one. Otherwise it calls
  `xm.set_replication` with the current XLA device and a one-element list
  holding that same device.
  """
  # At this point xla_model.py APIs are allowed as the setup is already
  # completed.
  if xm.xrt_world_size() <= 1:
    return
  local_device = xm.xla_device()
  xm.set_replication(local_device, [local_device])