def test(self):
    """Check that ToXlaTensorArena reaches tensors nested anywhere in ``data``.

    The input nests tensors as dict keys and dict values, and inside tuples,
    lists and sets. It also has non-tensor leaves (a float in a set and a
    string dict key). After the transform, every object whose type is
    exactly ``torch.Tensor`` must be reported by ``xm.is_xla_tensor``.
    """
    xla_device = xm.xla_device()

    nested = [
        _gen_tensor(2, 3),
        _gen_tensor(3, 4),
        [_gen_tensor(2, 5), _gen_tensor(3, 6)],
    ]
    # Subscript assignment (rather than a dict display) evaluates each value
    # before its key, which fixes the order of the _gen_tensor calls.
    data = {}
    data[_gen_tensor(2, 2)] = tuple(nested)
    data[_gen_tensor(2, 4)] = {12.0, _gen_tensor(3, 7)}
    data['ABC'] = _gen_tensor(4, 3)

    def is_plain_tensor(obj):
        # Exact type match, so tensor subclasses are not selected.
        return type(obj) == torch.Tensor

    def to_xla(tensors):
        targets = [str(xla_device)] * len(tensors)
        return torch_xla._XLAC._xla_tensors_from_aten(tensors, targets)

    def all_on_xla(obj):
        if is_plain_tensor(obj):
            return xm.is_xla_tensor(obj)
        if isinstance(obj, (list, tuple, set)):
            return all(all_on_xla(item) for item in obj)
        if isinstance(obj, dict):
            # Both keys and values must pass; a failing key skips its value.
            return all(
                all_on_xla(key) and all_on_xla(value)
                for key, value in obj.items())
        # Any other leaf (e.g. the float or the string key) imposes no check.
        return True

    converted = xm.ToXlaTensorArena(to_xla, is_plain_tensor).transform(data)
    self.assertTrue(all_on_xla(converted))
def _send_data_to(self, data, device):
    """Convert the plain tensors found in ``data`` to tensors on ``device``.

    Builds an ``xm.ToXlaTensorArena``. Its selector matches objects whose
    type is exactly ``torch.Tensor``. Its converter sends each selected batch
    through ``torch_xla._XLAC._xla_tensors_from_aten`` with ``str(device)``
    as the target device for every tensor.

    Args:
        data: Arbitrary (possibly nested) structure handed to the arena's
            ``transform``.
        device: Target device; only its ``str()`` form is used.

    Returns:
        Whatever ``ToXlaTensorArena.transform(data)`` returns.
    """

    def is_plain_tensor(obj):
        # Exact type match, so tensor subclasses are left untouched.
        return type(obj) == torch.Tensor

    def to_device(tensors):
        return torch_xla._XLAC._xla_tensors_from_aten(
            tensors, [str(device)] * len(tensors))

    arena = xm.ToXlaTensorArena(to_device, is_plain_tensor)
    return arena.transform(data)