Example #1
0
 def _module_runner(self, loop_fn, device, module, loader, context, result):
   """Run ``loop_fn`` for a single device and record its outcome on ``result``.

   Replication is configured for ``device`` (with ``self._replication``)
   before the loop starts; an error raised by that call is NOT caught here.

   Args:
     loop_fn: Callable invoked as
       ``loop_fn(module, loader, torch.device(device), context)``.
     device: Device identifier passed to ``xm.set_replication``,
       ``torch.device`` and ``self._handle_runner_exception``.
     module: Forwarded unchanged to ``loop_fn``.
     loader: Forwarded unchanged to ``loop_fn``.
     context: Forwarded unchanged to ``loop_fn``.
     result: Object whose ``result`` attribute receives either the value
       returned by ``loop_fn`` or the ``Exception`` raised while building
       the torch device or running the loop.
   """
   xm.set_replication(device, self._replication)
   try:
     torch_device = torch.device(device)
     result.result = loop_fn(module, loader, torch_device, context)
   except Exception as exc:
     # The exception object itself is stored rather than re-raised here,
     # presumably so whoever reads ``result`` can see the failure; it is
     # then passed to the per-runner handler.
     result.result = exc
     self._handle_runner_exception(device, exc)
Example #2
0
def _setup_replication():
    """Register the current XLA device as the sole replication target.

    Does nothing unless ``xm.xrt_world_size()`` is greater than 1. Otherwise,
    the device returned by ``xm.xla_device()`` is converted to its string
    name, which is passed both as the local device and as the one-element
    device list to ``xm.set_replication``.
    """
    if xm.xrt_world_size() <= 1:
        return
    device_name = str(xm.xla_device())
    xm.set_replication(device_name, [device_name])