Пример #1
0
class TestMoveGenBoard(unittest.TestCase):
    """Check the per-square move tables precomputed by MoveGeneratorBoard.

    ``self.movegen.squares`` is indexed with ``Position.get_index_row_col``.
    Each entry reports its own ``row``/``column`` and holds lists of target
    squares (objects with ``row`` and ``column``) for sliding rays, knight
    jumps, king steps and pawn moves.  Rows and columns run 0..7.

    The expected targets come from the small static helpers below, which
    model board geometry directly.  Every test therefore compares the
    generator against one independent model rather than repeating
    hand-written counter loops.
    """

    # Append custom failure messages to (rather than replace) the default
    # assertion diff; Python 2.7 defaults this to False.
    longMessage = True

    # 0-based rows a pawn may advance two squares from.  In the initial
    # Position row 1 holds the white pawns and row 6 the black pawns
    # (ranks 2 and 7 in chess notation).
    _WHITE_PAWN_START_ROW = 1
    _BLACK_PAWN_START_ROW = 6

    def setUp(self):
        self.pos = Position()
        self.movegen = MoveGeneratorBoard(self.pos)

    # ------------------------------------------------------------------
    # Expected-value helpers: pure functions of board geometry.
    # ------------------------------------------------------------------

    @staticmethod
    def _ray(row, col, drow, dcol):
        """Return the squares reached by stepping (drow, dcol) from (row, col).

        Squares are (row, column) pairs, nearest first, up to the board edge.
        The starting square is not included.
        """
        squares = []
        r, c = row + drow, col + dcol
        while 0 <= r <= 7 and 0 <= c <= 7:
            squares.append((r, c))
            r += drow
            c += dcol
        return squares

    @staticmethod
    def _jumps(row, col, deltas):
        """Return (row + dr, col + dc) for each (dr, dc) in *deltas* that
        lands on the board, keeping the order of *deltas*."""
        return [(row + dr, col + dc) for dr, dc in deltas
                if 0 <= row + dr <= 7 and 0 <= col + dc <= 7]

    @staticmethod
    def _pawn_advance_rows(row, step, start_row):
        """Return the rows a pawn on *row* advances to.

        *step* is +1 for white and -1 for black.  The pawn gets one step
        while that stays on the board, plus a second step when *row* equals
        *start_row*.
        """
        rows = []
        if 0 <= row + step <= 7:
            rows.append(row + step)
            if row == start_row:
                rows.append(row + 2 * step)
        return rows

    @staticmethod
    def _pawn_capture_columns(col):
        """Return the set of on-board columns diagonally adjacent to *col*."""
        return set(c for c in (col - 1, col + 1) if 0 <= c <= 7)

    # ------------------------------------------------------------------
    # Assertion helpers.
    # ------------------------------------------------------------------

    def _square(self, row, col):
        """Return the generator entry for (row, col).

        Also checks that the entry reports those coordinates.
        """
        mgsq = self.movegen.squares[self.pos.get_index_row_col(row, col)]
        self.assertEqual(mgsq.row, row)
        self.assertEqual(mgsq.column, col)
        return mgsq

    def _assert_targets(self, moves, expected, label):
        """Assert *moves* hits exactly the (row, column) pairs in *expected*,
        in the same order."""
        self.assertEqual([(m.row, m.column) for m in moves], expected, label)

    def _assert_rays(self, directions):
        """Check each ray attribute on every square against _ray().

        *directions* is a sequence of (attribute, drow, dcol) tuples.
        """
        for i in range(8):
            for j in range(8):
                mgsq = self._square(i, j)
                for name, drow, dcol in directions:
                    self._assert_targets(getattr(mgsq, name),
                                         self._ray(i, j, drow, dcol),
                                         '%s from (%d, %d)' % (name, i, j))

    def _assert_jumps(self, name, deltas):
        """Check attribute *name* on every square against the on-board
        subset of *deltas*, in delta order."""
        for i in range(8):
            for j in range(8):
                mgsq = self._square(i, j)
                self._assert_targets(getattr(mgsq, name),
                                     self._jumps(i, j, deltas),
                                     '%s from (%d, %d)' % (name, i, j))

    def _assert_pawn_captures(self, moves, target_row, col, label):
        """Assert *moves* are the diagonal captures onto *target_row* from
        column *col*.

        There are none when *target_row* is off the board.  The order of
        the captures is not checked.
        """
        if not 0 <= target_row <= 7:
            self.assertEqual(len(moves), 0, label)
            return
        columns = self._pawn_capture_columns(col)
        self.assertEqual(len(moves), len(columns), label)
        for m in moves:
            self.assertEqual(m.row, target_row, label)
        self.assertEqual(set(m.column for m in moves), columns, label)

    # ------------------------------------------------------------------
    # Tests.
    # ------------------------------------------------------------------

    def test_move_column(self):
        self._assert_rays((('column_moves1', 1, 0),
                           ('column_moves2', -1, 0)))

    def test_move_row(self):
        self._assert_rays((('row_moves1', 0, 1),
                           ('row_moves2', 0, -1)))

    def test_move_diag(self):
        self._assert_rays((('diag_moves1', 1, 1),
                           ('diag_moves2', -1, 1),
                           ('diag_moves3', 1, -1),
                           ('diag_moves4', -1, -1)))

    def test_move_knights(self):
        self._assert_jumps('knight_moves', self.movegen.knight_deltas)

    def test_move_king(self):
        self._assert_jumps('king_moves', self.movegen.king_deltas)

    def test_pawn_advances(self):
        # Index 0 of each list is the single step and index 1 (only on the
        # starting row) the double step.  Starting rows are 0-based: 1 for
        # white, 6 for black.
        for i in range(8):
            for j in range(8):
                mgsq = self._square(i, j)
                white_rows = self._pawn_advance_rows(
                    i, 1, self._WHITE_PAWN_START_ROW)
                black_rows = self._pawn_advance_rows(
                    i, -1, self._BLACK_PAWN_START_ROW)
                self._assert_targets(mgsq.pawn_advance_white,
                                     [(r, j) for r in white_rows],
                                     'pawn_advance_white from (%d, %d)' % (i, j))
                self._assert_targets(mgsq.pawn_advance_black,
                                     [(r, j) for r in black_rows],
                                     'pawn_advance_black from (%d, %d)' % (i, j))

    def test_pawn_captures(self):
        # White captures one row up (i + 1), black one row down (i - 1).
        for i in range(8):
            for j in range(8):
                mgsq = self._square(i, j)
                self._assert_pawn_captures(
                    mgsq.pawn_captures_white, i + 1, j,
                    'pawn_captures_white from (%d, %d)' % (i, j))
                self._assert_pawn_captures(
                    mgsq.pawn_captures_black, i - 1, j,
                    'pawn_captures_black from (%d, %d)' % (i, j))

    def test_move_sets(self):
        """Each move list must agree with its precomputed ``set_<name>``
        of target square indices."""
        move_types = ['column_moves1',
                      'column_moves2',
                      'row_moves1',
                      'row_moves2',
                      'diag_moves1',
                      'diag_moves2',
                      'diag_moves3',
                      'diag_moves4',
                      'knight_moves',
                      'king_moves',
                      'pawn_advance_white',
                      'pawn_advance_black',
                      'pawn_captures_white',
                      'pawn_captures_black',
                      ]
        # All 64 squares, indices 0..63.
        for index in range(64):
            sq = self.movegen.squares[index]
            for t in move_types:
                self.assertEqual(set(x.square for x in getattr(sq, t)),
                                 getattr(sq, 'set_' + t),
                                 '%s on square %d' % (t, index))
Пример #2
0
class Test_Position_PosInicial(unittest.TestCase):
    """Tests for a freshly constructed Position (the initial position).

    Squares are indexed 0..63.  Index 0 is row 7 / column 0 and index 63 is
    row 0 / column 7, so index == (7 - row) * 8 + column (pinned down by
    test_get_index_row_col).  ``self.b`` lists the expected piece code on
    every index.

    NOTE(review): the out-of-range tests expect AssertionError.  If
    Position validates with ``assert`` statements (presumably, TODO
    confirm), those tests fail when the suite runs under ``python -O``.
    """

    # Append custom failure messages to (rather than replace) the default
    # assertion diff; Python 2.7 defaults this to False.
    longMessage = True

    def setUp(self):
        self.pos = Position()
        self.b = [bR, bN, bB, bQ, bK, bB, bN, bR,  # 00..07 (row 7)
                  bP, bP, bP, bP, bP, bP, bP, bP,  # 08..15 (row 6)
                  xx, xx, xx, xx, xx, xx, xx, xx,  # 16..23 (row 5)
                  xx, xx, xx, xx, xx, xx, xx, xx,  # 24..31 (row 4)
                  xx, xx, xx, xx, xx, xx, xx, xx,  # 32..39 (row 3)
                  xx, xx, xx, xx, xx, xx, xx, xx,  # 40..47 (row 2)
                  wP, wP, wP, wP, wP, wP, wP, wP,  # 48..55 (row 1)
                  wR, wN, wB, wQ, wK, wB, wN, wR,  # 56..63 (row 0)
                  ]

    def test_initial_values(self):
        # Compare whole lists so that a board yielding too few (or too
        # many) squares fails, not only a wrong piece on some square.
        self.assertEqual(list(self.pos), self.b)
        self.assertEqual(self.pos.w_castle_qs, True)
        self.assertEqual(self.pos.w_castle_ks, True)
        self.assertEqual(self.pos.b_castle_qs, True)
        self.assertEqual(self.pos.b_castle_ks, True)
        self.assertIsNone(self.pos.ep_sq)
        self.assertEqual(self.pos.turn, WHITE)
        self.assertEqual(self.pos.check, False)
        self.assertEqual(self.pos.pieces[xx], set(range(16, 48)))

    def test_set_get_squares_out_of_range(self):
        # Bad indices first, then bad piece codes on a valid index.
        for index, piece in ((-1, wB), (64, wB), (0, -1), (0, 1000)):
            self.assertRaises(AssertionError, self.pos.set_square, index, piece)
        for index in (-1, 64):
            self.assertRaises(AssertionError, self.pos.get_square, index)

    def test_set_get_square(self):
        self.assertEqual(self.pos.get_square(48), wP)
        self.pos.set_square(48, xx)
        self.pos.set_square(40, wP)
        self.assertEqual(self.pos.get_square(48), xx)
        self.assertEqual(self.pos.get_square(40), wP)
        self.assertEqual(self.pos.get_square(8), bP)
        self.pos.set_square(8, xx)
        self.pos.set_square(16, wP)
        self.assertEqual(self.pos.get_square(8), xx)
        self.assertEqual(self.pos.get_square(16), wP)
        self.assertEqual(self.pos.get_square(63), wR)
        self.pos.set_square(63, xx)
        self.assertEqual(self.pos.get_square(63), xx)

    def test_get_row_col_out_of_range(self):
        for index in (-1, -2, 64, 65):
            self.assertRaises(AssertionError, self.pos.get_row_col, index)

    def test_get_row_col(self):
        for index, expected in ((0, (7, 0)), (3, (7, 3)), (56, (0, 0)),
                                (63, (0, 7)), (50, (1, 2))):
            row, col = self.pos.get_row_col(index)
            self.assertEqual((row, col), expected, 'index %d' % index)

    def test_get_row_col_value_out_of_range(self):
        for row, col in ((-2, 0), (-1, 0), (0, -1), (0, -2),
                         (8, 0), (9, 0), (0, 8), (0, 9)):
            self.assertRaises(AssertionError,
                              self.pos.get_row_col_value, row, col)

    def test_get_row_col_value(self):
        # Check all 64 squares against the reference board, including
        # every column of the empty middle rows.
        for row in range(8):
            for col in range(8):
                self.assertEqual(self.pos.get_row_col_value(row, col),
                                 self.b[(7 - row) * 8 + col],
                                 'row %d, column %d' % (row, col))

    def test_set_row_col_value_out_of_range(self):
        # Coordinates off the board (with otherwise plausible pieces).
        for row, col, piece in ((-2, 0, xx), (-1, 0, wB), (0, -1, bB),
                                (0, -2, wK), (8, 0, bK), (9, 0, wQ),
                                (0, 8, bQ), (0, 9, bP)):
            self.assertRaises(AssertionError,
                              self.pos.set_row_col_value, row, col, piece)
        # Valid squares with piece codes Position is expected to reject.
        for row, col, piece in ((7, 0, -4), (7, 7, -1), (0, 7, -2),
                                (5, 6, -3), (7, 0, 0), (7, 0, 14),
                                (0, 7, -1), (0, 7, 14)):
            self.assertRaises(AssertionError,
                              self.pos.set_row_col_value, row, col, piece)

    def test_set_row_col_value(self):
        # (row, col, piece, linear index); applied in order, so each step
        # also confirms that (row, col) and the index address one square.
        for row, col, piece, index in ((0, 0, xx, 56), (0, 7, wK, 63),
                                       (1, 7, wQ, 55), (1, 3, wB, 51),
                                       (7, 7, xx, 7), (6, 7, bK, 15),
                                       (6, 0, bQ, 8), (6, 3, bB, 11)):
            self.pos.set_row_col_value(row, col, piece)
            self.assertEqual(self.pos.get_row_col_value(row, col), piece)
            self.assertEqual(self.pos.get_square(index), piece)

    def test_get_index_row_col(self):
        for row, col in ((-1, 0), (0, -1), (-2, 0), (0, -2),
                         (8, 0), (0, 8), (9, 0), (0, 9)):
            self.assertRaises(AssertionError,
                              self.pos.get_index_row_col, row, col)
        for row, col, index in ((7, 0, 0), (7, 7, 7), (0, 0, 56),
                                (0, 7, 63), (5, 0, 16)):
            self.assertEqual(self.pos.get_index_row_col(row, col), index)

    def test_piece_sets(self):
        # pieces[code] is the set of indices currently holding that code;
        # set_square must move indices between those sets.
        initial_empty = set(range(16, 48))
        self.assertEqual(self.pos.pieces[xx], initial_empty)

        self.assertEqual(self.pos.pieces[bR], set([0, 7]))
        self.assertEqual(self.pos.pieces[bN], set([1, 6]))
        self.assertEqual(self.pos.pieces[bB], set([2, 5]))
        self.assertEqual(self.pos.pieces[bQ], set([3]))
        self.assertEqual(self.pos.pieces[bK], set([4]))
        self.assertEqual(self.pos.pieces[bP], set(range(8, 16)))

        self.assertEqual(self.pos.pieces[wR], set([56, 63]))
        self.assertEqual(self.pos.pieces[wN], set([57, 62]))
        self.assertEqual(self.pos.pieces[wB], set([58, 61]))
        self.assertEqual(self.pos.pieces[wQ], set([59]))
        self.assertEqual(self.pos.pieces[wK], set([60]))
        initial_wP = set(range(48, 56))
        self.assertEqual(self.pos.pieces[wP], initial_wP)

        # Move a white pawn from 52 to 36.
        self.pos.set_square(36, wP)
        self.pos.set_square(52, xx)
        self.assertEqual(self.pos.pieces[xx],
                         (initial_empty - set([36])) | set([52]))
        self.assertEqual(self.pos.pieces[wP],
                         (initial_wP - set([52])) | set([36]))

        # Drop a black rook on the empty square 16.
        self.pos.set_square(16, bR)
        self.assertEqual(self.pos.pieces[xx],
                         (initial_empty - set([36]) - set([16])) | set([52]))
        self.assertEqual(self.pos.pieces[bR], set([0, 7, 16]))