def test_intersection(self):
    """Check Sparsity.intersect and the ``*`` operator against a set intersection.

    Two 4x5 patterns are built from explicit (row, col) nonzero sets. Both
    ``a.intersect(b)`` and ``a * b`` must have exactly the nonzeros that occur
    in both input sets: the same count and the same positions.
    """
    self.message("Sparsity intersection")
    nza = set([(0, 0), (0, 1), (2, 0), (3, 1), (2, 3)])
    nzb = set([(0, 2), (0, 0), (2, 2), (2, 3)])
    expected = nza & nzb

    a = Sparsity(4, 5)
    for row, col in nza:
        a.add_nz(row, col)
    b = Sparsity(4, 5)
    for row, col in nzb:
        b.add_nz(row, col)

    # Check the named method and the operator with identical, strict
    # assertions. Previously only the operator result had its nnz checked.
    for c in (a.intersect(b), a * b):
        self.assertEqual(c.nnz(), len(expected))
        # Fetch the column-index list once rather than once per nonzero.
        cols = c.get_col()
        found = set((c.row(k), cols[k]) for k in range(c.nnz()))
        self.assertEqual(found, expected)
def test_union(self):
    """Check Sparsity.unite and the ``+`` operator against a set union.

    Two 4x5 patterns are built from explicit (row, col) nonzero sets. Both
    ``a.unite(b)`` and ``a + b`` must have exactly the nonzeros that occur in
    either input set: the same count and the same positions.
    """
    # NOTE(review): another method named test_union is defined right after
    # this one. If both live in the same class, that later definition
    # replaces this one and this body never runs. Consider renaming or
    # removing one of them.
    self.message("Sparsity union")
    nza = set([(0, 0), (0, 1), (2, 0), (3, 1)])
    nzb = set([(0, 2), (0, 0), (2, 2)])
    expected = nza | nzb

    a = Sparsity(4, 5)
    for row, col in nza:
        a.add_nz(row, col)
    b = Sparsity(4, 5)
    for row, col in nzb:
        b.add_nz(row, col)

    # Check both the named method and the operator, so the unite() result is
    # actually asserted on instead of being discarded.
    for c in (a.unite(b), a + b):
        # assertEqual rather than the deprecated assertEquals alias.
        self.assertEqual(c.nnz(), len(expected))
        # Fetch the column-index list once rather than once per nonzero.
        cols = c.get_col()
        found = set((c.row(k), cols[k]) for k in range(c.nnz()))
        self.assertEqual(found, expected)
def test_union(self):
    """Check Sparsity.unite and the ``+`` operator against a set union.

    Two 4x5 patterns are built from explicit (row, col) nonzero sets. Both
    ``a.unite(b)`` and ``a + b`` must have exactly the nonzeros that occur in
    either input set: the same count and the same positions.
    """
    self.message("Sparsity union")
    nza = set([(0, 0), (0, 1), (2, 0), (3, 1)])
    nzb = set([(0, 2), (0, 0), (2, 2)])
    expected = nza | nzb

    a = Sparsity(4, 5)
    for row, col in nza:
        a.add_nz(row, col)
    b = Sparsity(4, 5)
    for row, col in nzb:
        b.add_nz(row, col)

    # Check both the named method and the operator, so the unite() result is
    # actually asserted on instead of being discarded.
    for c in (a.unite(b), a + b):
        # assertEqual rather than the deprecated assertEquals alias
        # (removed in Python 3.12).
        self.assertEqual(c.nnz(), len(expected))
        # Fetch the column-index list once rather than once per nonzero.
        cols = c.get_col()
        found = set((c.row(k), cols[k]) for k in range(c.nnz()))
        self.assertEqual(found, expected)