def check_view_sharing(obj):
    """Raise if two distinct tensors reachable from ``obj`` share a view alias.

    Every value inside ``obj`` whose type is exactly ``torch.Tensor`` (as
    selected by the predicate passed to ``xu.for_each_instance``) is checked
    once. XLA tensors are identified and aliased via the ``torch_xla._XLAC``
    tensor-id / view-alias-id helpers; other tensors are identified by
    ``id()`` and aliased by their storage data pointer.

    Args:
      obj: A (possibly nested) Python object that may contain tensors.

    Raises:
      RuntimeError: if two distinct tensors map to the same alias id.
    """
    tensors = set()
    aliases = dict()

    def tensor_info(t):
        # Short human-readable description used in the error message.
        return '{}{}'.format(t.dtype, list(t.size()))

    def tensor_id(t):
        if is_xla_tensor(t):
            return torch_xla._XLAC._xla_get_tensor_id(t), 'xla'
        return id(t), 'torch'

    def alias_id(t):
        # Returns (alias, kind); alias is None when the tensor cannot alias
        # anything, so it is never recorded or compared.
        if is_xla_tensor(t):
            aid = torch_xla._XLAC._xla_get_tensor_view_alias_id(t)
            return None if aid == 0 else aid, 'xla'
        # A zero data pointer means the storage has no allocation (e.g. an
        # empty tensor). Treat it like the XLA "0" case; otherwise every pair
        # of such tensors would collide on (0, 'torch') and falsely raise.
        ptr = t.storage().data_ptr()
        return None if ptr == 0 else ptr, 'torch'

    def check_object(t):
        tid = tensor_id(t)
        if tid in tensors:
            # Same tensor seen again through another reference: not aliasing.
            return
        tensors.add(tid)
        aid = alias_id(t)
        if aid[0] is None:
            return
        if aid in aliases:
            oobj = aliases[aid]
            raise RuntimeError(
                'Tensor ID {} ({}) is sharing a view with tensor ID {} ({})'
                .format(tid, tensor_info(t), tensor_id(oobj), tensor_info(oobj)))
        aliases[aid] = t

    xu.for_each_instance(obj, lambda x: type(x) == torch.Tensor, check_object)
def test_util_foreach_api(self):
    """for_each_instance and for_each_instance_rewrite visit the same leaves.

    Builds a nested structure of dicts, lists, tuples and plain objects (one
    object referenced twice) and checks that both traversals report the same
    sequence of scalar leaves, 17 in total.
    """

    class ForTest(object):

        def __init__(self):
            self.a = {'k': [1, 2, 3], 4.9: 'y', 5: {'a': 'n'}}
            self.b = ('f', 17)

    def is_scalar(x):
        return isinstance(x, (int, str, float))

    shared = ForTest()
    data = {
        2.3: 11,
        21: ForTest(),
        'w': [12, ForTest(), shared],
        123: shared,
    }

    visited = []
    xu.for_each_instance(data, is_scalar, lambda x: visited.append(id(x)))

    rewritten = []

    def identity(x):
        # Record the leaf and hand it back unchanged.
        rewritten.append(id(x))
        return x

    xu.for_each_instance_rewrite(data, is_scalar, identity)

    self.assertEqual(len(visited), 17)
    self.assertEqual(visited, rewritten)
def _get_batch_size(self, data, dim):
    """Return the size along ``dim`` shared by every tensor in ``data``.

    Each ``torch.Tensor`` selected by ``xu.for_each_instance`` in ``data`` is
    inspected; the first one fixes the expected size and all later ones must
    match it.

    Args:
      data: A (possibly nested) object containing tensors.
      dim: The dimension index whose size is compared.

    Returns:
      The common size along ``dim``, or None if ``data`` contains no tensors.

    Raises:
      ValueError: if two tensors disagree on their size along ``dim``.
    """
    # Single-element list used as a mutable cell the closure can update.
    size = []

    def fn(v):
        csize = v.size()[dim]
        if not size:
            size.append(csize)
        elif csize != size[0]:
            # Explicit raise (not ``assert``) so the check survives ``python -O``.
            raise ValueError(
                'Inconsistent batch size along dim {}: got {}, expected {}'.format(
                    dim, csize, size[0]))

    xu.for_each_instance(data, torch.Tensor, fn)
    return size[0] if size else None
def check_view_sharing(obj):
    """Raise if two distinct XLA tensors inside ``obj`` share a view alias.

    Only XLA tensors are examined; each is identified by its XLA tensor id
    and grouped by its XLA view-alias id (0 meaning "no alias").

    Args:
      obj: A (possibly nested) Python object that may contain tensors.

    Raises:
      RuntimeError: if two distinct XLA tensors report the same alias id.
    """
    seen_ids = set()
    alias_owner = dict()

    def visit(tensor):
        if not is_xla_tensor(tensor):
            return
        tid = torch_xla._XLAC._xla_get_tensor_id(tensor)
        if tid in seen_ids:
            # Same tensor reached through another reference.
            return
        seen_ids.add(tid)
        aid = torch_xla._XLAC._xla_get_tensor_view_alias_id(tensor)
        if aid == 0:
            return
        if aid in alias_owner:
            other = alias_owner[aid]
            raise RuntimeError(
                'Tensor ID {} is sharing a view with tensor ID {}'.format(
                    tid, torch_xla._XLAC._xla_get_tensor_id(other)))
        alias_owner[aid] = tensor

    xu.for_each_instance(obj, torch.Tensor, visit)
def _collect_tensors(self, inputs):
    """Pass every value in ``inputs`` accepted by ``self._select_fn`` to ``self._add``.

    The traversal itself is delegated to ``xu.for_each_instance``.
    """

    def is_selected(value):
        # Look up ``self._select_fn`` on each call, as the original lambda did.
        return self._select_fn(value)

    def add_value(value):
        self._add(value)

    xu.for_each_instance(inputs, is_selected, add_value)