def _module_runner(self, loop_fn, device, module, loader, context, result):
    """Run ``loop_fn`` for a single device and record the outcome in ``result``.

    Args:
      loop_fn: Callable invoked as
        ``loop_fn(module, loader, torch.device(device), context)``.
      device: Device identifier. It is passed unchanged to
        ``xm.set_replication`` and wrapped with ``torch.device`` for
        ``loop_fn``. Presumably a device-name string; confirm against callers.
      module: Model object forwarded to ``loop_fn``.
      loader: Data source forwarded to ``loop_fn``.
      context: Per-device context object forwarded to ``loop_fn``.
      result: Object whose ``result`` attribute is set to one of two values:
        the return value of ``loop_fn`` on success, or the exception raised
        while setting up replication or running ``loop_fn``.
    """
    try:
        # Replication setup is inside the ``try`` so that a failure here gets
        # the same handling as a failure in ``loop_fn``. The exception is
        # stored in ``result.result`` and passed to
        # ``_handle_runner_exception``, instead of escaping the handler and
        # leaving ``result.result`` unset.
        xm.set_replication(device, self._replication)
        result.result = loop_fn(module, loader, torch.device(device), context)
    except Exception as e:
        # Store the exception rather than re-raising it, so whoever reads
        # ``result`` can tell that this device failed.
        result.result = e
        self._handle_runner_exception(device, e)
def _setup_replication():
    """Register this process's XLA device for replication, if needed.

    This does nothing unless ``xm.xrt_world_size()`` reports more than one.
    Otherwise it obtains the local device via ``xm.xla_device()``. It then
    calls ``xm.set_replication`` with that device's string form, both as the
    device and as the only member of the device list. This presumably means
    one device per process; confirm against the ``xm.set_replication``
    contract.
    """
    if xm.xrt_world_size() <= 1:
        return
    device_name = str(xm.xla_device())
    xm.set_replication(device_name, [device_name])