def test_getRow(self):
        """Get a row: check getRow() returns single rows as 1 x ncols matrices.

        NOTE(review): this ``def`` sits at module level (column 0) yet takes
        ``self`` and relies on ``self.obj`` and ``self.assert*``. Its body is
        identical to ``SparseMatTests.test_getRow`` later in this file. It
        looks like a stray copy, or a method whose enclosing class header is
        outside this view. Confirm, then remove it or re-indent it. As a
        module-level ``test_*`` function, pytest would likely try to resolve
        ``self`` as a fixture.
        """
        # Assumes self.obj holds a 3 at (1, 2) and nothing in row 4, as
        # SparseMatTests.setUp provides -- TODO confirm for this copy.
        exp = SparseMat(1,3)
        exp.update({(0,2):3})
        obs = self.obj.getRow(1)
        self.assertEqual(obs, exp)

        # A row with no stored entries comes back as an empty 1x3 matrix.
        exp = SparseMat(1,3)
        obs = self.obj.getRow(4)
        self.assertEqual(obs,exp)

        # 4x1 column: every row holds a value, so each getRow() is a 1x1.
        obj = SparseMat(4,1)
        obj[0,0] = 5
        obj[1,0] = 6
        obj[2,0] = 7
        obj[3,0] = 8
        exp1 = SparseMat(1,1)
        exp1[0,0] = 5
        exp2 = SparseMat(1,1)
        exp2[0,0] = 6
        exp3 = SparseMat(1,1)
        exp3[0,0] = 7
        exp4 = SparseMat(1,1)
        exp4[0,0] = 8
        obs1 = obj.getRow(0)
        obs2 = obj.getRow(1)
        obs3 = obj.getRow(2)
        obs4 = obj.getRow(3)
        # One past the last row raises IndexError ...
        self.assertRaises(IndexError,obj.getRow, 4)

        self.assertEqual(obs1, exp1)
        self.assertEqual(obs2, exp2)
        self.assertEqual(obs3, exp3)
        self.assertEqual(obs4, exp4)

        # ... while a negative row index raises KeyError (per this assertion).
        self.assertRaises(KeyError, self.obj.getRow, -1)
class SparseMatTests(TestCase):
    """Unit tests for SparseMat.

    Covers element access, slicing, row/column extraction, update,
    transpose and size, plus the internal row/column index bookkeeping
    (``_index_rows`` / ``_index_cols``).
    """

    def setUp(self):
        """Build a 6x3 fixture with non-zeros at (1,2)=3 and (5,2)=6."""
        self.obj = SparseMat(6,3)
        self.obj.update({(1,2):3,(5,2):6})

    def test_copy(self):
        """copy thy self"""
        obs = self.obj.copy()
        self.assertEqual(obs, self.obj)
        self.assertIsNot(obs, self.obj)
        # Mutating the copy must not leak back into the original.
        obs[1,2] = 10
        self.assertNotEqual(obs, self.obj)
        obs[1,2] = 3
        obs[0,0] = 5
        self.assertNotEqual(obs, self.obj)
        # The index structures must be copied as well, not shared.
        self.assertIsNot(self.obj._index_rows, obs._index_rows)
        self.assertIsNot(self.obj._index_cols, obs._index_cols)

    def test_rebuildIndices(self):
        """rebuild table indices"""
        sm = SparseMat(3,4,enable_indices=False)
        sm[0,0] = 10
        # An explicit 0; the expected indices below show it is not tracked.
        sm[0,2] = 0
        sm[1,3] = -5
        sm[1,2] = 1
        # With indices disabled, nothing is tracked until rebuildIndices().
        self.assertIsNone(sm._index_rows)
        self.assertIsNone(sm._index_cols)
        sm.rebuildIndices()
        self.assertEqual(sm._index_rows, [set([(0,0)]),set([(1,3), (1,2)]),
                                          set([])])
        self.assertEqual(sm._index_cols, [set([(0,0)]), set([]),
                                          set([(1,2)]), set([(1,3)])])

    def test_setitem(self):
        """Setting items stores them and keeps the row/col indices in sync."""
        self.obj[(2,2)] = 10
        exp = sorted([((1,2),3),((5,2),6),((2,2),10)])
        self.assertEqual(sorted(self.obj.items()), exp)
        # Writes outside the 6x3 shape raise KeyError.
        self.assertRaises(KeyError, self.obj.__setitem__, (100,50), 10)

        self.assertEqual(self.obj._index_rows, [set(),set([(1,2)]),
                                                set([(2,2)]), set(), set(),
                                                set([(5,2)])])
        self.assertEqual(self.obj._index_cols, [set(),set(),
                                                set([(1,2),(2,2),(5,2)])])

    def test_getitem_simple(self):
        """Tests simple getitem"""
        self.assertEqual(self.obj[(1,2)], 3)
        self.assertEqual(self.obj[1,2], 3)
        # Unset but in-bounds entries read as 0.
        self.assertEqual(self.obj[1,1], 0)
        # Out-of-bounds and negative coordinates raise KeyError ...
        self.assertRaises(KeyError, self.obj.__getitem__, (3,3))
        self.assertRaises(KeyError, self.obj.__getitem__, (-1,2))
        # ... and a scalar (non-tuple) key raises IndexError.
        self.assertRaises(IndexError, self.obj.__getitem__, 1)

    def test_getitem_slice(self):
        """Tests for slices on getitem"""
        exp = SparseMat(1,3)
        exp[0,2] = 3
        self.assertEqual(exp, self.obj[1,:])

        exp = SparseMat(6,1)
        exp[1,0] = 3
        exp[5,0] = 6
        self.assertEqual(exp, self.obj[:,2])

        self.assertRaises(IndexError, self.obj.__getitem__, (10,slice(None)))
        # NOTE(review): a stepped slice is expected to raise AttributeError;
        # this pins current behavior -- confirm it is the intended contract.
        self.assertRaises(AttributeError, self.obj.__getitem__, (3, slice(1,2,3)))

    def test_contains(self):
        """Make sure we can check things exist"""
        sm1 = SparseMat(3,4)
        for r in range(3):
            for c in range(4):
                self.assertNotIn((r,c), sm1)
        # Storing an explicit zero does not make the entry "present".
        sm1[1,2] = 0
        self.assertNotIn((1,2), sm1)
        sm1[1,2] = 10
        self.assertIn((1,2), sm1)
        sm1.erase(1,2)
        self.assertNotIn((1,2), sm1)

    def test_erase(self):
        """Make sure we can get rid of elements"""
        sm1 = SparseMat(4,6)
        self.assertEqual(sm1[2,3], 0.0)
        # Erasing an entry that was never set is allowed.
        sm1.erase(2,3)
        self.assertEqual(sm1[2,3], 0.0)
        sm1[2,3] = 10
        self.assertEqual(sm1[2,3], 10.0)
        self.assertEqual(sm1._index_rows, [set([]), set([]), set([(2,3)]), set([])])
        self.assertEqual(sm1._index_cols, [set([]), set([]), set([]), set([(2,3)]), set([]), set([])])
        sm1.erase(2,3)
        self.assertEqual(sm1._index_rows, [set([]), set([]), set([]), set([])])
        self.assertEqual(sm1._index_cols, [set([]), set([]), set([]), set([]), set([]), set([])])
        self.assertEqual(sm1[2,3], 0.0)
        # Re-check after the read above: looking up a missing entry must not
        # re-populate the indices.
        self.assertEqual(sm1._index_rows, [set([]), set([]), set([]), set([])])
        self.assertEqual(sm1._index_cols, [set([]), set([]), set([]), set([]), set([]), set([])])

    def test_eq(self):
        """Tests for equality"""
        sm1 = SparseMat(4,6)
        sm2 = SparseMat(4,6)
        sm3 = SparseMat(6,4)

        self.assertEqual(sm1, sm2)
        # Same (empty) content but a different shape is not equal.
        self.assertNotEqual(sm1,sm3)

        sm1[0,1] = 10
        sm2[0,1] = 5
        self.assertNotEqual(sm1, sm2)
        sm2[0,1] = 10
        self.assertEqual(sm1, sm2)

    def test_update_internal_indices(self):
        """Update internal indices"""
        sd = SparseMat(2,3)
        self.assertEqual(sd._index_rows, [set(),set()])
        self.assertEqual(sd._index_cols, [set(),set(),set()])

        sd[(1,2)] = 5
        self.assertEqual(sd._index_rows, [set(),set([(1,2)])])
        self.assertEqual(sd._index_cols, [set(),set(),set([(1,2)])])

        # Overwriting a stored value with 0 drops it from both indices.
        sd[(1,2)] = 0
        self.assertEqual(sd._index_rows, [set(),set()])
        self.assertEqual(sd._index_cols, [set(),set(),set()])

        # Writing 0 to an unset entry leaves the indices untouched.
        sd[(1,1)] = 0
        self.assertEqual(sd._index_rows, [set(),set()])
        self.assertEqual(sd._index_cols, [set(),set(),set()])

    def test_getRow(self):
        """Get a row"""
        exp = SparseMat(1,3)
        exp.update({(0,2):3})
        obs = self.obj.getRow(1)
        self.assertEqual(obs, exp)

        # A row with no stored entries comes back as an empty 1x3 matrix.
        exp = SparseMat(1,3)
        obs = self.obj.getRow(4)
        self.assertEqual(obs,exp)

        obj = SparseMat(4,1)
        obj[0,0] = 5
        obj[1,0] = 6
        obj[2,0] = 7
        obj[3,0] = 8
        exp1 = SparseMat(1,1)
        exp1[0,0] = 5
        exp2 = SparseMat(1,1)
        exp2[0,0] = 6
        exp3 = SparseMat(1,1)
        exp3[0,0] = 7
        exp4 = SparseMat(1,1)
        exp4[0,0] = 8
        obs1 = obj.getRow(0)
        obs2 = obj.getRow(1)
        obs3 = obj.getRow(2)
        obs4 = obj.getRow(3)
        # One past the last row raises IndexError ...
        self.assertRaises(IndexError,obj.getRow, 4)

        self.assertEqual(obs1, exp1)
        self.assertEqual(obs2, exp2)
        self.assertEqual(obs3, exp3)
        self.assertEqual(obs4, exp4)

        # ... while a negative row index raises KeyError.
        self.assertRaises(KeyError, self.obj.getRow, -1)

    def test_getCol(self):
        """Get a col"""
        exp = SparseMat(6,1)
        exp.update({(1,0):3,(5,0):6})
        obs = self.obj.getCol(2)
        self.assertEqual(obs,exp)

        # A column with no stored entries comes back as an empty 6x1 matrix.
        exp = SparseMat(6,1)
        obs = self.obj.getCol(1)
        self.assertEqual(obs,exp)

        self.assertRaises(KeyError, self.obj.getCol, -1)

    def test_update(self):
        """updates should work and update indexes"""
        # Snapshot by value. items() may be a live view (Python 3 dicts),
        # and _index_rows/_index_cols are the very list objects compared
        # below. Holding references would make the "unchanged" check
        # vacuously true if update() mutates them in place.
        items = sorted(self.obj.items())
        indexes = ([set(s) for s in self.obj._index_rows],
                   [set(s) for s in self.obj._index_cols])
        # Re-applying identical values must be a no-op.
        self.obj.update({(1,2):3,(5,2):6})
        self.assertEqual(items, sorted(self.obj.items()))
        self.assertEqual(indexes, (self.obj._index_rows, self.obj._index_cols))

        # Updating an entry to 0 removes it from storage and the indices.
        self.obj.update({(1,2):0,(5,2):6})
        self.assertEqual(sorted(self.obj.items()), [((5,2),6)])
        self.assertEqual(self.obj._index_rows, [set(),set(),set(),set(),
                                                set(), set([(5,2)])])
        self.assertEqual(self.obj._index_cols, [set(),set(),set([(5,2)])])

        self.obj.update({(1,2):1,(2,2):0,(1,1):10})

        self.assertEqual(sorted(self.obj.items()), sorted({(1,2):1.0,(1,1):10.0,(5,2):6.0}.items()))
        self.assertEqual(self.obj._index_rows, [set(),set([(1,2),(1,1)]),
                                                set(),set(),
                                                set(), set([(5,2)])])
        self.assertEqual(self.obj._index_cols, [set(),set([(1,1)]),set([(1,2),(5,2)])])

    def test_T(self):
        """test transpose"""
        exp = SparseMat(3,6)
        exp.update({(2,1):3,(2,5):6})
        obs = self.obj.T
        self.assertEqual(obs, exp)

    def test_get_size(self):
        """test getting the number of nonzero elements"""
        self.assertEqual(self.obj.size, 2)

        # Test with setting an element explicitly to zero.
        sm = SparseMat(2,4)
        sm.update({(0,1):3,(1,2):7,(0,0):0})
        self.assertEqual(sm.size, 2)

        # Test with an empty matrix.
        sm = SparseMat(2,4)
        self.assertEqual(sm.size, 0)
        sm = SparseMat(0,0)
        self.assertEqual(sm.size, 0)