def test_property_values(self):
    """Check ``values``: readable, not rebindable, but mutable element-wise."""
    for device in self.devices:
        for dtype in self.dtypes:
            ragged = k2r.RaggedTensor(
                [[1], [2], [], [3, 4]], device=device, dtype=dtype
            )
            want_values = torch.tensor(
                [1, 2, 3, 4], dtype=dtype, device=device
            )
            assert torch.all(torch.eq(ragged.values, want_values))

            # Rebinding the property itself must be rejected.
            with self.assertRaises(AttributeError):
                ragged.values = 10

            # Element-wise writes through `values` are reflected in the
            # ragged tensor (checked by the equality below).
            ragged.values[0] = 10
            ragged.values[-1] *= 2
            modified = k2r.RaggedTensor(
                [[10], [2], [], [3, 8]], dtype=dtype, device=device
            )
            assert ragged == modified

            ragged.values[0] = 1
            assert ragged != modified
def test_getitem_scalar(self):
    """Integer-indexing a 3-axis ragged tensor yields the i-th 2-axis slice."""
    for device in self.devices:
        for dtype in self.dtypes:
            src = k2r.RaggedTensor(
                "[ [[1 2] [] [10]] [[3] [5]] ]", dtype=dtype
            ).to(device)
            slices = ["[[1 2] [] [10]]", "[[3] [5]]"]
            for idx, text in enumerate(slices):
                want = k2r.RaggedTensor(text, dtype=dtype).to(device)
                assert src[idx] == want
def test_tot_size_2axes(self):
    """tot_size on a 2-axis tensor: 3 sublists on axis 0, 5 elements on axis 1."""
    for device in self.devices:
        for dtype in self.dtypes:
            rt = k2r.RaggedTensor(
                "[ [1 2 3] [] [5 8] ]", dtype=dtype
            ).to(device)
            for axis, size in ((0, 3), (1, 5)):
                assert rt.tot_size(axis) == size
def test_clone(self):
    """clone() yields an equal copy that does not share values storage."""
    original = k2r.RaggedTensor([[1, 2], [], [3]])
    duplicate = original.clone()
    assert original == duplicate

    # Writing into the original must not show up in the clone.
    original.values[0] = 10
    assert original != duplicate
def test_sum_no_grad(self):
    """Per-sublist sum without autograd; an empty sublist sums to 0."""
    for device in self.devices:
        for dtype in self.dtypes:
            rt = k2r.RaggedTensor([[1, 2], [], [5]], dtype=dtype).to(device)
            sums = rt.sum()
            want = torch.tensor([3, 0, 5], dtype=dtype, device=device)
            assert torch.all(torch.eq(sums, want))
def test_tot_size_3axes(self):
    """tot_size on every axis of a 3-axis tensor (2 / 8 / 10 entries)."""
    for device in self.devices:
        for dtype in self.dtypes:
            rt = k2r.RaggedTensor(
                "[ [[1 2 3] [] [5 8]] [[] [1 5 9 10 -1] [] [] []] ]",
                dtype=dtype,
            ).to(device)
            expected_sizes = [2, 8, 10]
            for axis, size in enumerate(expected_sizes):
                assert rt.tot_size(axis) == size
def test_setstate_3axes(self):
    """Round-trip a 3-axis ragged tensor through torch.save / torch.load.

    Runs for every device/dtype combination. ``==`` on ragged tensors is
    expected to compare dtype and device as well as values, so the final
    assertion also checks the tensor came back on the device it was saved
    from.
    """
    for device in self.devices:
        for dtype in self.dtypes:
            # Move to `device` so the device loop is actually exercised.
            a = k2r.RaggedTensor(
                "[ [[1] [2 3] []] [[10]]]", dtype=dtype
            ).to(device)
            fid, tmp_filename = tempfile.mkstemp()
            os.close(fid)
            try:
                torch.save(a, tmp_filename)
                # NOTE(review): torch>=2.6 defaults torch.load to
                # weights_only=True, which may reject this custom object;
                # confirm against the torch versions used in CI.
                b = torch.load(tmp_filename)
            finally:
                # Always clean up, even if save/load raises.
                os.remove(tmp_filename)
            assert a == b
def test_grad(self):
    """Exercise the requires_grad property and the requires_grad_() method."""
    rt = k2r.RaggedTensor([[1, 2], [10], []], dtype=torch.float32)
    assert rt.grad is None
    assert rt.requires_grad is False

    # Property setter.
    rt.requires_grad = True
    assert rt.requires_grad is True

    # In-place method: switch off, then back on.
    for flag in (False, True):
        rt.requires_grad_(flag)
        assert rt.requires_grad is flag
def test_setstate_2axes(self):
    """Round-trip a 2-axis ragged tensor through torch.save / torch.load.

    Runs for every device/dtype combination. The temporary file is removed
    even when saving or loading fails.
    """
    for device in self.devices:
        for dtype in self.dtypes:
            a = k2r.RaggedTensor([[1], [2, 3], []], dtype=dtype)
            a = a.to(device)
            fid, tmp_filename = tempfile.mkstemp()
            os.close(fid)
            try:
                torch.save(a, tmp_filename)
                # NOTE(review): torch>=2.6 defaults torch.load to
                # weights_only=True, which may reject this custom object;
                # confirm against the torch versions used in CI.
                b = torch.load(tmp_filename)
            finally:
                os.remove(tmp_filename)
            # It checks both dtype and device, not just value
            assert a == b
def test_getstate_2axes(self):
    """__getstate__ of a 2-axis tensor is (row_splits, "row_ids1", values)."""
    for device in self.devices:
        for dtype in self.dtypes:
            rt = k2r.RaggedTensor([[1, 2], [3], []], dtype=dtype).to(device)
            state = rt.__getstate__()
            assert isinstance(state, tuple)
            assert len(state) == 3

            row_splits, row_ids_tag, values = state
            want_splits = torch.tensor(
                [0, 2, 3, 3], dtype=torch.int32, device=device
            )
            assert torch.all(torch.eq(row_splits, want_splits))
            assert row_ids_tag == "row_ids1"
            assert torch.all(torch.eq(values, rt.values))
def test_sum_with_grad(self):
    """Gradients of per-sublist sums flow back to each element of that sublist."""
    for device in self.devices:
        for dtype in (torch.float32, torch.float64):
            rt = k2r.RaggedTensor([[1, 2], [], [5]], dtype=dtype).to(device)
            rt.requires_grad_(True)
            sums = rt.sum()
            want_sums = torch.tensor([3, 0, 5], dtype=dtype, device=device)
            assert torch.all(torch.eq(sums, want_sums))

            # Weight each row sum differently; every element's gradient is
            # the weight of the row it belongs to (the empty row gets none).
            weights = (10, 20, 30)
            loss = (
                sums[0] * weights[0]
                + sums[1] * weights[1]
                + sums[2] * weights[2]
            )
            loss.backward()
            want_grad = torch.tensor([10, 10, 30], device=device, dtype=dtype)
            assert torch.all(torch.eq(rt.grad, want_grad))
def test_create_ragged_tensor_from_string(self):
    """Both constructors accept nested lists and the string format alike.

    Every check runs once per constructor in ``funcs``, so
    ``k2r.create_ragged_tensor`` and ``k2r.RaggedTensor`` get identical
    coverage, including the 3-axis string case.
    """
    funcs = [k2r.create_ragged_tensor, k2r.RaggedTensor]
    for func in funcs:
        for device in self.devices:
            for dtype in self.dtypes:
                a = func([[1], [2, 3, 4, 5], []], dtype=dtype, device=device)
                b = func("[[1] [2 3 4 5] []]", dtype=dtype, device=device)
                assert a == b
                assert b.dim0 == 3
                assert a.dtype == dtype
                assert a.device == device

                # Build with `func`, not a fixed constructor, so the 3-axis
                # string path is exercised for every entry in `funcs`.
                c = func("[[[1] [2 3] []] [[10]]]", dtype=dtype, device=device)
                assert c.num_axes == 3
                assert c.dim0 == 2
                assert c.dtype == dtype
                assert c.device == device
def test_getitem_1d_tensor(self):
    """Index a ragged tensor with 1-D int32 tensors taken from another one.

    Each row ``a[i]`` is itself a 1-D index tensor; ``b[a[i]]`` must select
    the corresponding sublists of ``b`` in the given order.
    """
    for device in self.devices:
        for dtype in self.dtypes:
            # No dtype here: the rows of `a` serve as indexes into `b` and
            # are compared against int32 tensors below.
            a = k2r.RaggedTensor([[1, 2, 0], [0, 1], [2, 3]], device=device)
            b = k2r.RaggedTensor(
                #  0         1      2             3
                [[10, 20], [300], [-10, 0, -1], [-2, 4, 5]],
                dtype=dtype,
                device=device,
            )
            # (row of `a` as a list, sublists of `b` it should select)
            cases = [
                ([1, 2, 0], [[300], [-10, 0, -1], [10, 20]]),
                ([0, 1], [[10, 20], [300]]),
                ([2, 3], [[-10, 0, -1], [-2, 4, 5]]),
            ]
            for row, (index_list, selected) in enumerate(cases):
                index = torch.tensor(
                    index_list, dtype=torch.int32, device=device
                )
                assert torch.all(torch.eq(a[row], index))
                expected = k2r.RaggedTensor(
                    selected, dtype=dtype, device=device
                )
                assert b[a[row]] == expected
def test_getstate_3axes(self):
    """__getstate__ of a 3-axis tensor interleaves row_splits with tags."""
    for device in self.devices:
        for dtype in self.dtypes:
            rt = k2r.RaggedTensor(
                "[[[1 2] [3] []] [[4] [5 6]]]", dtype=dtype
            ).to(device)
            state = rt.__getstate__()
            assert isinstance(state, tuple)
            assert len(state) == 5

            # Layout: (row_splits1, "row_ids1", row_splits2, "row_ids2",
            #          values)
            splits1, tag1, splits2, tag2, values = state
            want_splits1 = torch.tensor(
                [0, 3, 5], dtype=torch.int32, device=device
            )
            want_splits2 = torch.tensor(
                [0, 2, 3, 3, 4, 6], dtype=torch.int32, device=device
            )
            assert torch.all(torch.eq(splits1, want_splits1))
            assert tag1 == "row_ids1"
            assert torch.all(torch.eq(splits2, want_splits2))
            assert tag2 == "row_ids2"
            assert torch.all(torch.eq(values, rt.values))