예제 #1
0
파일: sparse.py 프로젝트: kevinwkc/jittor
 def t(self):
     """Return a new SparseVar with its last two dimensions exchanged.

     The coordinate rows of ``self.indices`` for the final two axes are
     swapped, and ``self.shape`` is permuted the same way; ``self.values`` is
     passed through as-is. Raises IndexError if there are fewer than two
     index rows.
     """
     # Break the index matrix into single-row pieces along axis 0 so the
     # rows can be reordered (assumes one row per sparse dimension).
     rows = list(self.indices.split(1, dim=0))
     rows[-2:] = [rows[-1], rows[-2]]
     swapped_indices = jt.contrib.concat(rows, dim=0)

     # Keep the dense shape consistent with the reordered coordinates.
     dims = list(self.shape)
     dims[-2:] = [dims[-1], dims[-2]]
     return SparseVar(swapped_indices, self.values, jt.NanoVector(dims))
예제 #2
0
 def test_sparse_var(self):
     """Check that a jittor sparse array densifies to the same matrix as torch.

     Both libraries receive identical COO data: a 2 x nnz index matrix
     (first row = row coordinates, second row = column coordinates), the
     matching float32 values, and the dense shape [2, 3].
     """
     indices = np.array([[0, 1, 1], [2, 0, 2]])
     values = np.array([3, 4, 5]).astype(np.float32)
     shape = [2, 3]
     jt_array = jt.sparse.sparse_array(jt.array(indices), jt.array(values), jt.NanoVector(shape))
     # torch.sparse.FloatTensor is a deprecated legacy constructor;
     # torch.sparse_coo_tensor is its replacement. COO indices must be int64,
     # but np.array's default integer dtype is platform dependent (int32 on
     # Windows with NumPy < 2), so cast explicitly for torch.
     torch_tensor = torch.sparse_coo_tensor(
         torch.from_numpy(indices).long(),
         torch.from_numpy(values),
         torch.Size(shape),
     )
     jt_numpy = jt_array.to_dense().numpy()
     torch_numpy = torch_tensor.to_dense().numpy()
     # np.allclose broadcasts its operands, so check the shapes explicitly too.
     assert jt_numpy.shape == torch_numpy.shape == tuple(shape)
     assert np.allclose(jt_numpy, torch_numpy)
예제 #3
0
 def test_slice_bug(self):
     """Check that slicing a jt.NanoVector yields the expected elements.

     Slice results are compared with ``==`` against plain Python lists.
     Judging by the name, this presumably guards against a previously fixed
     slicing bug (unverified).
     """
     a = jt.NanoVector([2, 3, 4, 5])
     # A full slice should reproduce every element in order.
     assert a[:] == [2, 3, 4, 5]
     # Slicing from index 1 should drop only the first element.
     assert a[1:] == [3, 4, 5]