Ejemplo n.º 1
0
  def test(self):
    """Checks that ToXlaTensorArena moves nested plain tensors to the XLA device.

    Builds a nested structure and transforms it with a converter that uploads
    every exact ``torch.Tensor`` to ``xm.xla_device()``. The structure is a
    dict with tensor keys and a string key, a tuple holding a list, and a set
    mixing a float with a tensor. The test then walks the result, keys
    included, and asserts that each exact ``torch.Tensor`` it finds is an XLA
    tensor.
    """
    device = xm.xla_device()

    nested = [_gen_tensor(2, 3), _gen_tensor(3, 4)]
    nested.append([_gen_tensor(2, 5), _gen_tensor(3, 6)])
    payload = {}
    payload[_gen_tensor(2, 2)] = tuple(nested)
    payload[_gen_tensor(2, 4)] = {12.0, _gen_tensor(3, 7)}
    payload['ABC'] = _gen_tensor(4, 3)

    def is_plain_tensor(obj):
      # Exact type match: subclasses of torch.Tensor are not selected.
      return type(obj) == torch.Tensor

    def to_xla(cpu_tensors):
      targets = [str(device)] * len(cpu_tensors)
      return torch_xla._XLAC._xla_tensors_from_aten(cpu_tensors, targets)

    def all_on_xla(node):
      # Recursively verify tensors inside sequences, sets and dict keys/values;
      # any other leaf (float, str, ...) is accepted as-is.
      if is_plain_tensor(node):
        return xm.is_xla_tensor(node)
      if isinstance(node, (list, tuple, set)):
        return all(all_on_xla(item) for item in node)
      if isinstance(node, dict):
        return all(
            all_on_xla(key) and all_on_xla(value)
            for key, value in node.items())
      return True

    arena = xm.ToXlaTensorArena(to_xla, is_plain_tensor)
    self.assertTrue(all_on_xla(arena.transform(payload)))
Ejemplo n.º 2
0
    def _send_data_to(self, data, device):
        """Sends every plain ``torch.Tensor`` inside ``data`` to ``device``.

        Objects whose exact type is ``torch.Tensor`` are selected by
        ``xm.ToXlaTensorArena`` and converted in one batch with
        ``_xla_tensors_from_aten``, all targeting ``str(device)``.

        Args:
            data: Arbitrarily nested data holding the tensors to send.
            device: Target device; only its string form is used.

        Returns:
            The result of ``ToXlaTensorArena.transform(data)``, presumably
            ``data`` with the selected tensors replaced by device tensors.
        """
        def is_plain_tensor(value):
            # Exact type match: subclasses of torch.Tensor are left untouched.
            return type(value) == torch.Tensor

        def upload(batch):
            target = str(device)
            return torch_xla._XLAC._xla_tensors_from_aten(
                batch, [target] * len(batch))

        arena = xm.ToXlaTensorArena(upload, is_plain_tensor)
        return arena.transform(data)
Ejemplo n.º 3
0
def _rewrite_data(path, data, save_tensors):
    """Replaces every XLA tensor in ``data`` with a ``TensorReference``.

    The selected XLA tensors are synced with ``_xla_sync_multi``
    (``wait=True``, ``sync_xla_data=True``). When ``save_tensors`` is true,
    each tensor is copied to CPU and written with ``torch.save`` to
    ``_get_tensor_file(path, i)``, where ``i`` is its index in the batch.
    Either way, each tensor is replaced by ``TensorReference(i)``.

    Args:
      path (str): Folder that receives the tensor files. It is created only
        when ``save_tensors`` is true.
      data: Arbitrarily nested data to rewrite.
      save_tensors (bool): Whether to write the tensor files. When false, the
        tensors are still synced and replaced by references, but nothing is
        written and no folder is created.

    Returns:
      The result of ``ToXlaTensorArena.transform(data)``: ``data`` with each
      XLA tensor replaced by a ``TensorReference``.

    Raises:
      FileExistsError: If ``save_tensors`` is true and ``path`` already exists
        (``os.mkdir`` semantics).
    """
    def convert_fn(tensors):
        # NOTE(review): the sync runs even when nothing is saved, presumably so
        # every caller issues the same device computation. Confirm before
        # making it conditional on save_tensors.
        torch_xla._XLAC._xla_sync_multi(tensors,
                                        devices=[],
                                        wait=True,
                                        sync_xla_data=True)
        if save_tensors:
            for i, t in enumerate(tensors):
                torch.save(t.cpu(), _get_tensor_file(path, i))
        return [TensorReference(i) for i in range(len(tensors))]

    def select_fn(v):
        return type(v) == torch.Tensor and xm.is_xla_tensor(v)

    # Only a writer needs the folder. Creating it unconditionally left an empty
    # folder behind for callers that write nothing, and made them fail with
    # FileExistsError when the folder already existed.
    if save_tensors:
        os.mkdir(path)
    return xm.ToXlaTensorArena(convert_fn, select_fn).transform(data)
Ejemplo n.º 4
0
def load(path):
    """Load data that was written by the `save()` API.

    The file at ``path`` is read with ``torch.load``. Every ``TensorReference``
    inside it is then replaced by the tensor loaded from
    ``_get_tensor_file(_get_tensors_folder(path), ref.tid)``.

    Args:
      path (str): The same path that was given to `save()`.

    Returns:
      The saved data, with tensor references resolved back to tensors.
    """
    # NOTE(review): torch.load unpickles its input; only load files from
    # trusted sources. On PyTorch versions where torch.load defaults to
    # weights_only=True, unpickling TensorReference objects may be rejected.
    # Confirm against the torch version in use.
    ref_data = torch.load(path)
    tensors_dir = _get_tensors_folder(path)

    def is_reference(obj):
        return type(obj) == TensorReference

    def resolve(refs):
        return [
            torch.load(_get_tensor_file(tensors_dir, ref.tid)) for ref in refs
        ]

    return xm.ToXlaTensorArena(resolve, is_reference).transform(ref_data)