def test(self):
    """Check that ToXlaTensorArena converts tensors inside nested containers.

    Builds a dict whose keys and values mix tensors from `_gen_tensor`,
    a tuple holding a nested list, a set containing a float, and a plain
    string key, runs it through `xm.ToXlaTensorArena(...).transform`, and
    asserts that every value selected as a tensor is reported by
    `xm.is_xla_tensor` as an XLA tensor.
    """
    target_device = xm.xla_device()

    # Nested payload: two tensors followed by an inner list of two tensors.
    nested = [
        _gen_tensor(2, 3),
        _gen_tensor(3, 4),
        [_gen_tensor(2, 5), _gen_tensor(3, 6)],
    ]

    payload = {}
    payload[_gen_tensor(2, 2)] = tuple(nested)
    payload[_gen_tensor(2, 4)] = {12.0, _gen_tensor(3, 7)}
    payload['ABC'] = _gen_tensor(4, 3)

    def is_plain_tensor(value):
        # Exact type match: subclasses of torch.Tensor are not selected.
        return type(value) == torch.Tensor

    def to_xla(tensors):
        device_name = str(target_device)
        device_names = [device_name] * len(tensors)
        return torch_xla._XLAC._xla_tensors_from_aten(tensors, device_names)

    def all_converted(value):
        # Recursively walk containers; non-tensor leaves count as OK.
        if is_plain_tensor(value):
            return xm.is_xla_tensor(value)
        if isinstance(value, (list, tuple, set)):
            return all(all_converted(item) for item in value)
        if isinstance(value, dict):
            return all(
                all_converted(key) and all_converted(item)
                for key, item in value.items())
        return True

    arena = xm.ToXlaTensorArena(to_xla, is_plain_tensor)
    converted = arena.transform(payload)
    self.assertTrue(all_converted(converted))
def _send_data_to(self, data, device):
    """Transform `data` so its plain tensors are rebuilt for `device`.

    Args:
      data: The structure handed to `xm.ToXlaTensorArena(...).transform`.
      device: The target device; its `str()` form is passed for every
        tensor to `torch_xla._XLAC._xla_tensors_from_aten`.

    Returns:
      Whatever `ToXlaTensorArena.transform` returns for `data`.
    """

    def is_plain_tensor(value):
        # Exact type match: subclasses of torch.Tensor are not selected.
        return type(value) == torch.Tensor

    def to_device(tensors):
        device_name = str(device)
        device_names = [device_name] * len(tensors)
        return torch_xla._XLAC._xla_tensors_from_aten(tensors, device_names)

    arena = xm.ToXlaTensorArena(to_device, is_plain_tensor)
    return arena.transform(data)
def _rewrite_data(path, data, save_tensors):
    """Replace the XLA tensors in `data` with `TensorReference` objects.

    The directory `path` is created with `os.mkdir` before the transform
    runs (so an already existing `path` raises). Each selected tensor is
    replaced by `TensorReference(i)`, where `i` is its position in the
    list handed to the conversion callback.

    Args:
      path: Directory to create; tensor files are located through
        `_get_tensor_file(path, i)`.
      data: The structure handed to `xm.ToXlaTensorArena(...).transform`.
      save_tensors: When truthy, the CPU copy of each selected tensor is
        written with `torch.save`; otherwise only the references are built.

    Returns:
      Whatever `ToXlaTensorArena.transform` returns for `data`.
    """

    def is_xla_tensor(value):
        # Exact type match first; subclasses of torch.Tensor are rejected.
        if type(value) == torch.Tensor:
            return xm.is_xla_tensor(value)
        return False

    def to_references(tensors):
        # Sync all tensors in one call before any of them is copied to CPU.
        torch_xla._XLAC._xla_sync_multi(
            tensors, devices=[], wait=True, sync_xla_data=True)
        references = []
        for index, tensor in enumerate(tensors):
            if save_tensors:
                torch.save(tensor.cpu(), _get_tensor_file(path, index))
            references.append(TensorReference(index))
        return references

    os.mkdir(path)
    arena = xm.ToXlaTensorArena(to_references, is_xla_tensor)
    return arena.transform(data)
def load(path):
  """Loads data previously saved with the `save()` API.

  The object stored at `path` is read with `torch.load`; every
  `TensorReference` found in it is then replaced by the object loaded
  from the matching file in the folder given by
  `_get_tensors_folder(path)`.

  Args:
    path (str): The path passed to the `save()` API.

  Returns:
    The loaded data.
  """
  # NOTE(review): `torch.load` unpickles its input by default -- confirm
  # that `path` only ever refers to trusted files.
  ref_data = torch.load(path)
  tensor_folder = _get_tensors_folder(path)

  def is_reference(value):
    # Exact type match: subclasses of TensorReference are not selected.
    return type(value) == TensorReference

  def resolve_references(references):
    # Each reference's `tid` identifies the tensor file to load.
    return [
        torch.load(_get_tensor_file(tensor_folder, ref.tid))
        for ref in references
    ]

  arena = xm.ToXlaTensorArena(resolve_references, is_reference)
  return arena.transform(ref_data)