Exemplo n.º 1
0
def check_view_sharing(obj):
    """Raise if two distinct tensors reachable from ``obj`` alias each other.

    Every object inside ``obj`` that ``xu.for_each_instance`` selects with an
    exact ``type(x) == torch.Tensor`` check is visited once. Tensors are
    deduplicated by identity:
      * XLA tensors use the id returned by ``_xla_get_tensor_id``.
      * Other tensors use ``id()``.

    Each tensor is then mapped to an alias key:
      * XLA tensors use their XLA view alias id; an alias id of 0 is skipped.
      * Other tensors use the data pointer of their storage; a null (0)
        pointer is skipped.

    The first tensor seen for a key is remembered. A second, distinct tensor
    with the same key triggers an error.

    Args:
      obj: Arbitrary (possibly nested) object to scan for tensors.

    Raises:
      RuntimeError: If two distinct tensors map to the same alias key.
    """
    tensors = set()
    aliases = dict()

    def tensor_info(t):
        # Human-readable "dtype[sizes]" used only in the error message.
        return '{}{}'.format(t.dtype, list(t.size()))

    def tensor_id(t):
        # The backend tag keeps XLA ids and Python id() values from colliding.
        if is_xla_tensor(t):
            return torch_xla._XLAC._xla_get_tensor_id(t), 'xla'
        return id(t), 'torch'

    def storage_ptr(t):
        # Prefer untyped_storage(): Tensor.storage() goes through the
        # deprecated TypedStorage wrapper on recent PyTorch. Both report the
        # same underlying data pointer.
        if hasattr(t, 'untyped_storage'):
            return t.untyped_storage().data_ptr()
        return t.storage().data_ptr()

    def alias_id(t):
        # Returns (key, backend); key is None when the tensor cannot alias.
        if is_xla_tensor(t):
            aid = torch_xla._XLAC._xla_get_tensor_view_alias_id(t)
            return None if aid == 0 else aid, 'xla'
        # Zero-element tensors can have a null data pointer. Treating 0 as a
        # real address would report every pair of empty tensors as aliases.
        ptr = storage_ptr(t)
        return None if ptr == 0 else ptr, 'torch'

    def check_object(obj):
        tid = tensor_id(obj)
        if tid not in tensors:
            tensors.add(tid)
            aid = alias_id(obj)
            if aid[0] is not None:
                if aid in aliases:
                    oobj = aliases[aid]
                    raise RuntimeError(
                        'Tensor ID {} ({}) is sharing a view with tensor ID {} ({})'
                        .format(tid, tensor_info(obj), tensor_id(oobj),
                                tensor_info(oobj)))
                aliases[aid] = obj

    xu.for_each_instance(obj, lambda x: type(x) == torch.Tensor, check_object)
Exemplo n.º 2
0
  def test_util_foreach_api(self):
    # Visiting and rewriting must reach the same scalar leaves, in the same
    # order. `shared` is reachable through two paths in `data`.

    class ForTest(object):

      def __init__(self):
        self.a = {'k': [1, 2, 3], 4.9: 'y', 5: {'a': 'n'}}
        self.b = ('f', 17)

    def is_scalar(x):
      return isinstance(x, (int, str, float))

    shared = ForTest()
    data = {
        2.3: 11,
        21: ForTest(),
        'w': [12, ForTest(), shared],
        123: shared,
    }

    visited_ids = []
    rewritten_ids = []

    def record(x):
      visited_ids.append(id(x))

    def identity(x):
      # Rewrite callback that records the leaf and leaves it unchanged.
      rewritten_ids.append(id(x))
      return x

    xu.for_each_instance(data, is_scalar, record)
    xu.for_each_instance_rewrite(data, is_scalar, identity)

    self.assertEqual(len(visited_ids), 17)
    self.assertEqual(visited_ids, rewritten_ids)
Exemplo n.º 3
0
    def _get_batch_size(self, data, dim):
        """Return the size along ``dim`` shared by every tensor in ``data``.

        Args:
          data: Object (possibly nested) scanned by ``xu.for_each_instance``
            for ``torch.Tensor`` instances.
          dim: Dimension index whose size is compared across all tensors.

        Returns:
          The common size along ``dim``, or ``None`` if no tensor was found.

        Raises:
          ValueError: If two tensors disagree on their size along ``dim``.
        """
        size = []

        def fn(v):
            csize = v.size()[dim]
            if not size:
                size.append(csize)
            elif csize != size[0]:
                # An explicit raise (not ``assert``) keeps the check active
                # when Python runs with -O.
                raise ValueError(
                    'Inconsistent batch size along dim {}: got {}, expected {}'
                    .format(dim, csize, size[0]))

        xu.for_each_instance(data, torch.Tensor, fn)
        return size[0] if size else None
Exemplo n.º 4
0
def check_view_sharing(obj):
    """Raise if two distinct XLA tensors reachable from ``obj`` share a view.

    Tensors found by ``xu.for_each_instance`` that are not XLA tensors are
    ignored. Each XLA tensor is visited once, keyed by its XLA tensor id, and
    grouped by its XLA view alias id; an alias id of 0 is skipped.

    Raises:
      RuntimeError: If a second distinct XLA tensor has an alias id already
        seen.
    """
    seen_ids = set()
    first_by_alias = dict()

    def visit(tensor):
        if not is_xla_tensor(tensor):
            return
        tensor_id = torch_xla._XLAC._xla_get_tensor_id(tensor)
        if tensor_id in seen_ids:
            return
        seen_ids.add(tensor_id)

        alias_id = torch_xla._XLAC._xla_get_tensor_view_alias_id(tensor)
        if alias_id == 0:
            return
        if alias_id in first_by_alias:
            other = first_by_alias[alias_id]
            other_id = torch_xla._XLAC._xla_get_tensor_id(other)
            message = 'Tensor ID {} is sharing a view with tensor ID {}'.format(
                tensor_id, other_id)
            raise RuntimeError(message)
        first_by_alias[alias_id] = tensor

    xu.for_each_instance(obj, torch.Tensor, visit)
Exemplo n.º 5
0
    def _collect_tensors(self, inputs):
        """Pass every value in ``inputs`` accepted by ``self._select_fn`` to ``self._add``."""

        def selector(candidate):
            # Looked up on each call, so the current self._select_fn is used.
            return self._select_fn(candidate)

        def collector(value):
            self._add(value)

        xu.for_each_instance(inputs, selector, collector)