Code Example #1
    def test(self):
        xla_device = xm.xla_device()

        # Build a nested structure (list / tuple / set / dict) that mixes
        # tensors with non-tensor values.
        kdata = [_gen_tensor(2, 3), _gen_tensor(3, 4)]
        kdata.append([_gen_tensor(2, 5), _gen_tensor(3, 6)])
        data = dict()
        data[_gen_tensor(2, 2)] = tuple(kdata)
        data[_gen_tensor(2, 4)] = set([12.0, _gen_tensor(3, 7)])
        data['ABC'] = _gen_tensor(4, 3)

        def select_fn(v):
            # Select only plain torch.Tensor values for conversion.
            return type(v) == torch.Tensor

        def convert_fn(tensors):
            # Convert every selected tensor into an XLA tensor on the XLA device.
            devices = [str(xla_device)] * len(tensors)
            return torch_xla._XLAC._xla_tensors_from_aten(tensors, devices)

        def check_fn(v):
            # Recursively verify that every selected tensor was converted.
            if select_fn(v):
                return xm.is_xla_tensor(v)
            elif isinstance(v, (list, tuple, set)):
                for x in v:
                    if not check_fn(x):
                        return False
            elif isinstance(v, dict):
                for k, x in v.items():
                    if not check_fn(k) or not check_fn(x):
                        return False
            return True

        # Walk the nested structure and replace the selected tensors with
        # their XLA counterparts.
        xla_data = xm.ToXlaTensorArena(convert_fn, select_fn).transform(data)
        self.assertTrue(check_fn(xla_data))
Code Example #2
File: data_parallel.py  Project: Saiuz/xla
    def _send_data_to(self, data, device):
        # Move every tensor in an arbitrarily nested data structure to `device`.
        def convert_fn(tensors):
            devices = [str(device)] * len(tensors)
            return torch_xla._XLAC._xla_tensors_from_aten(tensors, devices)

        def select_fn(v):
            return type(v) == torch.Tensor

        return xm.ToXlaTensorArena(convert_fn, select_fn).transform(data)
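Both snippets above are methods lifted out of larger files, so they depend on surrounding context (the test class and the `_gen_tensor` helper, which is not shown). Below is a minimal standalone sketch of the same pattern, using only the calls that appear in the examples above and plain `torch.randn` tensors in place of `_gen_tensor`; it is an illustrative sketch, not code taken from the project.

import torch
import torch_xla
import torch_xla.core.xla_model as xm


def select_fn(v):
    # Pick out plain torch.Tensor values for conversion.
    return type(v) == torch.Tensor


def convert_fn(tensors):
    # Convert all selected tensors to XLA tensors on the default XLA device.
    device = xm.xla_device()
    devices = [str(device)] * len(tensors)
    return torch_xla._XLAC._xla_tensors_from_aten(tensors, devices)


# Nested structure mixing tensors and non-tensor values.
data = {'w': torch.randn(2, 3), 'items': [torch.randn(3, 4), 1.0]}

xla_data = xm.ToXlaTensorArena(convert_fn, select_fn).transform(data)
print(xm.is_xla_tensor(xla_data['w']))  # expected: True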