def test_cat_empty_tensor(self):
  t = _gen_tensor(2, 5, 3)
  empty_tensor = torch.Tensor()
  x = t.to(xm.xla_device())
  empty_tensor_xla = empty_tensor.to(xm.xla_device())
  t_cat = torch.cat([t, empty_tensor], 0)
  x_cat = torch.cat([x, empty_tensor_xla], 0)
  self.assertEqual(t_cat.data, x_cat.data.cpu())
def test_masked_fill_with_tensor(self):
  input = _gen_tensor(2, 5, 4, 3)
  mask = torch.randint(0, 2, input.size(), dtype=torch.bool)
  value = torch.tensor(42)
  xla_input = input.to(xm.xla_device())
  xla_mask = mask.to(xm.xla_device())
  xla_value = value.to(xm.xla_device())
  result = torch.masked_fill(input, mask, value)
  xla_result = torch.masked_fill(xla_input, xla_mask, xla_value)
  self.assertEqual(input.data, xla_input.data.cpu())
  self.assertEqual(result.data, xla_result.data.cpu())
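# Double backward through a grouped Conv1d: compares second-order gradients
# computed on CPU against the same computation on the XLA device.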
def test_empty_strided(self):
  xla_device = xm.xla_device()
  m = nn.Conv1d(4, 6, kernel_size=3, groups=2)
  a = torch.rand(2, 4, 6, requires_grad=True)
  xla_m = copy.deepcopy(m).to(xla_device)
  xla_a = a.clone().to(xla_device).detach()
  xla_a.requires_grad = True
  output = m(a)
  grad_input = torch.autograd.grad(
      output, (a, ) + tuple(m.parameters()), output, create_graph=True)
  grad_grad_input = torch.autograd.grad(
      output.sum() + sum(map(lambda x: x.sum(), grad_input)),
      (a, output) + tuple(m.parameters()),
      retain_graph=True)
  xla_output = xla_m(xla_a)
  xla_grad_input = torch.autograd.grad(
      xla_output, (xla_a, ) + tuple(xla_m.parameters()),
      xla_output,
      create_graph=True)
  xla_grad_grad_input = torch.autograd.grad(
      xla_output.sum() + sum(map(lambda x: x.sum(), xla_grad_input)),
      (xla_a, xla_output) + tuple(xla_m.parameters()),
      retain_graph=True)
  self.assertEqual(grad_grad_input, xla_grad_grad_input)
def test_empty_advanced_indexing(self):
  xla_device = xm.xla_device()
  base = torch.randn(2, 3, 4, 5)
  xla_base = base.to(device=xla_device)
  result = base[:, torch.empty(0, 6, dtype=torch.int64)]
  xla_result = xla_base[:, torch.empty(0, 6, dtype=torch.int64)]
  self.assertEqual(result, xla_result)
def test_slice_stepped_other_assign(self):
  a = torch.ones((10, 4))
  xla_device = xm.xla_device()
  xla_a = a.to(xla_device)
  a[:, 1::4] = 2
  xla_a[:, 1::4] = 2
  self.assertEqual(a.data, xla_a.data.cpu())
def test_frac_negative(self):
  xla_device = xm.xla_device()
  a = torch.tensor(-3.2)
  b = a.frac()
  xla_a = a.to(xla_device)
  xla_b = xla_a.frac()
  self.assertEqual(b, xla_b)
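# Checks that xm.ToXlaTensorArena moves every tensor nested inside dicts,
# tuples, lists and sets to the XLA device while preserving the container
# structure.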
def test(self):
  xla_device = xm.xla_device()
  kdata = [_gen_tensor(2, 3), _gen_tensor(3, 4)]
  kdata.append([_gen_tensor(2, 5), _gen_tensor(3, 6)])
  data = dict()
  data[_gen_tensor(2, 2)] = tuple(kdata)
  data[_gen_tensor(2, 4)] = set([12.0, _gen_tensor(3, 7)])
  data['ABC'] = _gen_tensor(4, 3)

  def select_fn(v):
    return type(v) == torch.Tensor

  def convert_fn(tensors):
    devices = [str(xla_device)] * len(tensors)
    return torch_xla._XLAC._xla_tensors_from_aten(tensors, devices)

  def check_fn(v):
    if select_fn(v):
      return xm.is_xla_tensor(v)
    elif isinstance(v, (list, tuple, set)):
      for x in v:
        if not check_fn(x):
          return False
    elif isinstance(v, dict):
      for k, x in v.items():
        if not check_fn(k) or not check_fn(x):
          return False
    return True

  xla_data = xm.ToXlaTensorArena(convert_fn, select_fn).transform(data)
  self.assertTrue(check_fn(xla_data))
def test_index_put(self):
  xla_device = xm.xla_device()
  a = torch.tensor([1, 1, 1, 1]).to(xla_device).to(dtype=torch.float32)
  b = torch.rand(4) > 0.1
  a[b] = 10
  vset = b.sum().item()
  self.assertEqual(a.sum().item(), 10.0 * vset + (4.0 - vset))
def test_norm_p0(self):
  # p = 0 is equivalent to nonzero
  xla_device = xm.xla_device()
  a = torch.randn(3, 2)
  xla_a = a.to(xla_device)
  norm = a.norm(p=0)
  xla_norm = xla_a.norm(p=0)
  self.assertEqual(norm, xla_norm)
def test_ailing_slice(self):
  xla_device = xm.xla_device()
  a = torch.ones((1000, 324))
  xla_a = a.to(xla_device)
  w = a[:, 2::4]
  xla_w = xla_a[:, 2::4]
  dw = torch.clamp(w, max=3.1)
  xla_dw = torch.clamp(xla_w, max=3.1)
  self.assertEqual(w.data, xla_w.data.cpu())
  self.assertEqual(dw.data, xla_dw.data.cpu())
def test_max_broadcast(self):
  xla_device = xm.xla_device()
  a = torch.rand(3, 1, 2)
  b = torch.rand(4, 2)
  c = torch.max(a, b)
  xla_a = a.to(xla_device)
  xla_b = b.to(xla_device)
  xla_c = torch.max(xla_a, xla_b)
  self.assertEqual(c.data, xla_c.data.cpu())
def test_slice_rnd_stepped_assign(self):
  xla_device = xm.xla_device()
  size = 10
  for s in range(0, size - 1):
    for e in range(1, size - s):
      a = torch.ones((3, size))
      xla_a = a.to(xla_device)
      a[:, s::e] = 2
      xla_a[:, s::e] = 2
      self.assertEqual(a.data, xla_a.data.cpu())
def test_slice_assign(self):
  a = torch.rand(3, 3, 3)
  xla_device = xm.xla_device()
  xla_a = a.to(xla_device)
  shape = (4, 4, 4)
  b = a.new(*shape).zero_()
  xla_b = xla_a.new(*shape).zero_()
  b[0, :, :] = 1
  xla_b[0, :, :] = 1
  self.assertEqual(b.data, xla_b.data.cpu())
def test_slice_copy(self):
  a = torch.rand(3, 3, 3)
  xla_device = xm.xla_device()
  xla_a = a.to(xla_device)
  shape = (4, 4, 4)
  b = a.new(*shape).zero_()
  xla_b = xla_a.new(*shape).zero_()
  b[:a.shape[0], :a.shape[1], :a.shape[2]].copy_(a)
  xla_b[:a.shape[0], :a.shape[1], :a.shape[2]].copy_(xla_a)
  self.assertEqual(b.data, xla_b.data.cpu())
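# Helper: runs `fn` on the CPU tensors and on their XLA copies, then compares
# each pair of results within the given relative/absolute tolerances.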
def runAtenTest(self, tensors, fn, device=None, rel_err=1e-2, abs_err=1e-5):
  if device is None:
    device = xm.xla_device()
  tensors = xu.as_list(tensors)
  xla_tensors = [x.to(device) for x in tensors]
  results = xu.as_list(fn(*tensors))
  xla_results = xu.as_list(fn(*xla_tensors))
  for at, xt in zip(results, xla_results):
    self.assertEqualRel(
        self.makeComparable(xt), at, rel_err=rel_err, abs_err=abs_err)
def test_save(self):
  xla_device = xm.xla_device()
  x = torch.randn(5, device=xla_device)
  x_file = tempfile.mktemp()
  try:
    torch.save(x, x_file)
    x_loaded = torch.load(x_file)
    self.assertEqual(x, x_loaded)
  finally:
    os.remove(x_file)
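# Bitwise ops do not promote across dtypes (long & uint8 raises on both
# backends), but work elementwise when both operands share a dtype.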
def test_bitwise_type(self):
  xla_device = xm.xla_device()
  a = torch.randint(255, (4, ), dtype=torch.long)
  xla_a = a.to(xla_device)
  self.assertRaises(RuntimeError, lambda: a & a.byte())
  self.assertRaises(RuntimeError, lambda: xla_a & xla_a.byte())

  def test_fn(a):
    return a & (~a)

  self.runAtenTest(a, test_fn)
def test_scatter_add_bool(self):
  xla_device = xm.xla_device()
  a = torch.tensor([[True, True, True, True, True],
                    [True, True, True, True, True]])
  b = torch.zeros(3, 5, dtype=torch.bool)
  index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
  b.scatter_add_(0, index, a)
  xla_a = a.to(xla_device)
  xla_b = b.to(xla_device)
  xla_index = index.to(xla_device)
  xla_b.scatter_add_(0, xla_index, xla_a)
  self.assertEqual(b, xla_b)
def test(self):
  device = xm.xla_device()
  orig_x = torch.Tensor([[1, 2], [3, 4]])
  orig_y = torch.Tensor([[0.1, 0.2], [0.3, 0.4]])
  x = orig_x
  y = orig_y
  xla_x = orig_x.to(device)
  xla_y = orig_y.to(device)
  for i in range(0, 2000):
    x = x + 2 * y
    xla_x = xla_x + 2 * xla_y
  self.assertEqualRel(x, xla_x.cpu(), rel_err=1e-3, abs_err=5)
def test_save_view_alias_check(self):

  class Nested(object):

    def __init__(self, x, y):
      self.x = x
      self.y = y

  a = torch.rand(16, device=xm.xla_device())
  b = a[:10]
  c = a[6:]
  self.assertRaises(RuntimeError, lambda: xm.check_view_sharing([b, c]))

  nested = Nested(b, c)
  self.assertRaises(RuntimeError, lambda: xm.check_view_sharing(nested))
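# Per-process entry point: verifies that a cross-replica sum over a tensor of
# ones yields ones * world_size on each TPU replica, and skips with a message
# on non-TPU devices.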
def _mp_fn(index):
  device = xm.xla_device()
  real_device = xm.xla_real_devices([str(device)])[0]
  if real_device.startswith('TPU:'):
    ones = torch.ones((2, 3))
    xones = ones.to(device)
    torch_xla._XLAC._xla_cross_replica_sum([xones], 1.0, [])
    if not xones.cpu().allclose(ones * float(xm.xrt_world_size())):
      print('CrossReplicaSum produced wrong reductions')
      print(xones, file=sys.stderr)
      sys.exit(1)
  else:
    print(
        'Default device {} is not a TPU device'.format(real_device),
        file=sys.stderr)
def test_rrelu_module(self):
  xla_device = xm.xla_device()
  a = torch.rand(1, 2, 2, requires_grad=True)
  xla_a = a.to(xla_device).detach()
  xla_a.requires_grad = True
  m = nn.RReLU()
  xla_m = m.to(xla_device)
  output = m(a)
  xla_output = xla_m(xla_a)
  self.assertEqual(output, xla_output.cpu())
  output.sum().backward()
  xla_output.sum().backward()
  self.assertEqual(a.grad, xla_a.grad.cpu())
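# Reductions over tensors with a zero-sized dimension: all/any/sum/mean/prod
# are defined, while min/max raise on both CPU and XLA.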
def test_reduction_0dim(self):
  self.runAtenTest(torch.rand(2, 0, 4).bool(), lambda x: torch.all(x))
  self.runAtenTest(torch.rand(2, 0, 4).bool(), lambda x: torch.any(x))
  self.runAtenTest(torch.rand(2, 0, 4), lambda x: torch.sum(x))
  self.runAtenTest(torch.rand(2, 0, 4), lambda x: torch.mean(x))
  self.runAtenTest(torch.rand(2, 0, 4), lambda x: torch.prod(x))
  # min & max throw
  xla_device = xm.xla_device()
  a = torch.rand(2, 0, 4)
  xla_a = a.to(xla_device)
  self.assertRaises(RuntimeError, lambda: torch.max(a, dim=1))
  self.assertRaises(RuntimeError, lambda: torch.max(a))
  self.assertRaises(RuntimeError, lambda: torch.min(a, dim=1))
  self.assertRaises(RuntimeError, lambda: torch.min(a))
  self.assertRaises(RuntimeError, lambda: torch.max(xla_a, dim=1))
  self.assertRaises(RuntimeError, lambda: torch.max(xla_a))
  self.assertRaises(RuntimeError, lambda: torch.min(xla_a, dim=1))
  self.assertRaises(RuntimeError, lambda: torch.min(xla_a))
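# Runs the MNIST comparator model on CPU and on the XLA device with identical
# seeds, recording per-layer values into two save directories, then asserts
# that the comparison report is empty.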
def test(self):
  xla_device = xm.xla_device()
  x = _gen_tensor(8, 1, 28, 28)

  torch.manual_seed(42)
  model = MNISTComparator()
  save_dir1 = xu.TmpFolder()
  mc.configure(save_dir1.name)
  model(x)

  save_dir2 = xu.TmpFolder()
  mc.configure(save_dir2.name)
  torch.manual_seed(42)
  xla_model = MNISTComparator().to(xla_device)
  xla_x = x.to(xla_device)
  xla_model(xla_x)

  report = mc.compare(save_dir1.name, save_dir2.name, rtol=1e-03, atol=1e-04)
  if report:
    print(report)
  self.assertEqual(len(report), 0)
def test_pred_type(self):
  xla_device = xm.xla_device()
  a = torch.rand(4)
  b = torch.rand(4)
  xla_a = a.to(xla_device)
  xla_b = b.to(xla_device)
  c = (a >= 0.25)
  d = (b >= 0.5)
  xla_c = (xla_a >= 0.25)
  xla_d = (xla_b >= 0.5)
  e = torch.cat([a, b], dim=0)
  xla_e = torch.cat([xla_a, xla_b], dim=0)
  f = e.sum().item()
  xla_f = xla_e.sum().item()
  self.assertEqual(f, xla_f)

  # PRED can be automatically promoted in arithmetic ops.
  self.runAtenTest(c, lambda x: x + x.byte())
  # PRED cannot be automatically promoted to other dtypes in bitwise ops.
  # This is not aligned with numpy behavior, which means it might change
  # in the future.
  self.assertRaises(RuntimeError, lambda: c & c.byte())
  self.assertRaises(RuntimeError, lambda: xla_c & xla_c.byte())
def test_print(self):
  xla_device = xm.xla_device()
  x = torch.tensor([5], device=xla_device)
  expected_str = 'tensor([5], device=\'' + str(xla_device) + '\')'
  self.assertExpectedInline(str(x), expected_str)
def test_copy(self):
  xla_device = xm.xla_device()
  x = torch.rand(5, device=xla_device)
  y = copy.copy(x)
  self.assertEqual(x, y)
def test_byte_dtype(self):
  xla_device = xm.xla_device()
  x = torch.ByteTensor([0, 1]).to(xla_device)
  y = torch.ByteTensor([0, 1]).to(xla_device)
  z = x + y
  self.assertEqual(z.dtype, torch.uint8)
def test_slice_zero_sized_dim(self):
  xla_device = xm.xla_device()
  v = torch.randn(2, 3, 4, 5).to(xla_device)
  y = v[:, :, :, 1]
  z = y[:, 1:1, :]
  self.assertEqual(z.size()[1], 0)