def create_valid_moves_given_squares(pos: Position, start_sq: Square,
                                     end_sq: Square) -> List[Move]:
    """Build every valid move of the piece on ``start_sq`` to ``end_sq``.

    One move is created per promotion option yielded by the piece type's
    promotion constrainer (from ``MOVEGEN_FUNCTIONS``), so the result may
    hold both a promoting and a non-promoting variant.

    Returns an empty list when ``start_sq`` holds no movable piece, when the
    piece cannot reach ``end_sq``, or when the piece does not belong to the
    side to move (``pos.turn``). Whether the move leaves the king in check
    is not examined here.
    """
    is_reachable = (
        _is_move_from_square_available(pos, start_sq)
        and _is_move_between_squares_valid(pos, start_sq, end_sq)
    )
    if not is_reachable:
        return []
    piece = pos.get_koma(start_sq)
    mover = pos.turn
    if piece.side() != mover:
        return []
    _, constrain_promotion = MOVEGEN_FUNCTIONS[KomaType.get(piece)]
    moves: List[Move] = []
    for promote in constrain_promotion(mover, start_sq, end_sq):
        moves.append(pos.create_move(start_sq, end_sq, promote))
    return moves
def _is_move_from_square_available(pos: Position, start_sq: Square) -> bool:
    """Report whether a move may start from ``start_sq``.

    Hand squares are rejected outright. For board squares, the occupant must
    be something other than ``Koma.NONE`` or ``Koma.INVALID``.
    """
    if start_sq.is_hand():
        return False
    occupant = pos.get_koma(start_sq)
    return not (occupant == Koma.NONE or occupant == Koma.INVALID)
def generate_drop_moves(pos: Position, side: Side,
                        ktype: KomaType) -> List[Move]:
    """Generate every drop of a ``ktype`` piece from ``side``'s hand.

    Returns an empty list when ``_is_drop_available`` reports no such piece
    in hand. Otherwise, each empty board square is tried, and squares
    rejected by ``_is_drop_innately_illegal`` are skipped. Whether the drop
    leaves the king in check is not examined here.
    """
    if not _is_drop_available(pos, side, ktype):
        return []
    drops: List[Move] = []
    for target in _idxs_to_squares(pos.board.empty_idxs):
        if _is_drop_innately_illegal(pos, side, ktype, target):
            continue
        drops.append(pos.create_drop_move(side, ktype, target))
    return drops
def _is_move_between_squares_valid(pos: Position, start_sq: Square,
                                   end_sq: Square) -> bool:
    """Check whether the piece on ``start_sq`` can move to ``end_sq``.

    The piece type's destination generator (from ``MOVEGEN_FUNCTIONS``) is
    run on the board from the piece owner's perspective. The board index of
    ``end_sq`` is then looked up among the returned destinations. Whose turn
    it is, and whether the move leaves a king in check, are not considered.
    """
    piece = pos.get_koma(start_sq)
    piece_type = KomaType.get(piece)
    grid = pos.board
    owner = piece.side()
    from_idx = MailboxBoard.sq_to_idx(start_sq)
    to_idx = MailboxBoard.sq_to_idx(end_sq)
    generate_destinations, _ = MOVEGEN_FUNCTIONS[piece_type]
    reachable = generate_destinations(grid, from_idx, owner)
    return to_idx in reachable
def get_ambiguous_moves(pos: Position, move: Move) -> List[Move]:
    """Find other legal moves that could be confused with ``move``.

    A candidate is generated for the same piece type and side as the piece
    currently on ``move.start_sq``. It is kept when it:

    - lands on ``move.end_sq``,
    - starts from a different square, and
    - passes ``is_legal``.

    Returns an empty list when ``move.start_sq`` does not hold a movable
    piece.
    """
    origin = move.start_sq
    if not _is_move_from_square_available(pos, origin):
        return []
    piece = pos.get_koma(origin)
    owner = piece.side()
    candidates = generate_valid_moves(pos, owner, KomaType.get(piece))
    ambiguous: List[Move] = []
    for candidate in candidates:
        same_target = candidate.end_sq == move.end_sq
        if same_target and candidate.start_sq != move.start_sq \
                and is_legal(candidate, pos):
            ambiguous.append(candidate)
    return ambiguous
def _is_drop_innately_illegal(pos: Position, side: Side, ktype: KomaType,
                              end_sq: Square) -> bool:
    """Report whether dropping ``ktype`` on ``end_sq`` is ruled out outright.

    The drop is rejected when the target square is occupied, or when any of
    these helpers rejects it:

    - FU: ``_is_drop_illegal_ky`` or ``_is_drop_nifu``
    - KY: ``_is_drop_illegal_ky``
    - KE: ``_is_drop_illegal_ke``

    Other piece types only get the occupancy test. Check-related rules are
    not handled here.
    """
    if pos.get_koma(end_sq) != Koma.NONE:
        return True
    if ktype == KomaType.FU:
        # FU drops reuse the KY placement check and additionally get the
        # nifu check.
        return bool(_is_drop_illegal_ky(side, end_sq)
                    or _is_drop_nifu(pos.board, side, end_sq))
    if ktype == KomaType.KY:
        return bool(_is_drop_illegal_ky(side, end_sq))
    if ktype == KomaType.KE:
        return bool(_is_drop_illegal_ke(side, end_sq))
    return False
def generate_valid_moves(pos: Position, side: Side,
                         ktype: KomaType) -> List[Move]:
    """Generate board moves for every ``ktype`` piece belonging to ``side``.

    For each such piece, the type's destination generator (from
    ``MOVEGEN_FUNCTIONS``) supplies the target squares. The promotion
    constrainer then supplies the promotion options for each target, and
    one move is created per option. No check-related legality test is
    applied; ``is_legal`` exists for that.
    """
    generate_destinations, constrain_promotion = MOVEGEN_FUNCTIONS[ktype]
    grid = pos.board
    origins = grid.koma_sets[Koma.make(side, ktype)]
    moves: List[Move] = []
    for from_idx in origins:
        raw_targets = generate_destinations(grid, from_idx, side)
        from_sq = MailboxBoard.idx_to_sq(from_idx)
        moves.extend(
            pos.create_move(from_sq, to_sq, promote)
            for to_sq in _idxs_to_squares(raw_targets)
            for promote in constrain_promotion(side, from_sq, to_sq)
        )
    return moves
class TestJapaneseNotation(unittest.TestCase):
    """Tests for ``JapaneseMoveWriter`` configured with the Japanese format."""

    def setUp(self):
        self.position = Position()
        self.move_writer = JapaneseMoveWriter(JAPANESE_MOVE_FORMAT)

    def _parse_move_test(self, line1, line2, line3):
        """Turn one test-case record into ``(name, move, expected_movestr)``.

        The record's lines are used as follows:

        - ``line1`` is the case name.
        - ``line2`` is an SFEN string loaded into ``self.position``.
        - ``line3`` holds four space-separated fields: start coordinate,
          end coordinate, ``"+"`` for a promotion (anything else means no
          promotion), and the expected notation string.
        """
        case_name = line1.rstrip()
        sfen = line2.rstrip()
        fields = line3.rstrip().split(" ")
        start_coord, end_coord, promotion, expected_movestr = fields
        self.position.from_sfen(sfen)
        move = self.position.create_move(
            Square.from_coord(int(start_coord)),
            Square.from_coord(int(end_coord)),
            promotion == "+",
        )
        return case_name, move, expected_movestr

    def test_moves(self):
        test_data_file = r"tsumemi/test/test_cases_japanese_notation.txt"
        with open(test_data_file, encoding="utf8") as fh:
            # Each case spans four lines: name, SFEN, move spec, and a fourth
            # line that is ignored. Read the file in groups of four.
            records = itertools.zip_longest(*[iter(fh)] * 4)
            for name_line, sfen_line, move_line, _ in records:
                case_name, move, expected = self._parse_move_test(
                    name_line, sfen_line, move_line)
                actual = self.move_writer.write_move(move, self.position)
                with self.subTest(msg=case_name, answer=expected):
                    self.assertEqual(expected, actual)

    def test_drop(self):
        sfen = r"9/7+R1/sS1+p1+P2+R/2p+P1+P3/SS1n+pn3/9/9/9/3N5 b 2S2s 1"
        self.position.from_sfen(sfen)
        drop = self.position.create_drop_move(
            Side.SENTE, KomaType.GI, Square.b84)
        self.assertEqual(
            "8四銀打", self.move_writer.write_move(drop, self.position))

    def test_termination_move_mate(self):
        mate = TerminationMove(GameTermination.MATE)
        self.assertEqual(
            "詰み", self.move_writer.write_move(mate, self.position))
def get_move_preview(self) -> str:
    """Render a sample move with the currently selected move writer.

    A fresh ``Position`` gets a single ``Koma.FU`` on square 77. The move
    77 -> 76 is then written out by the writer returned from
    ``self.model.get_item()``.
    """
    writer = self.model.get_item()
    preview_pos = Position()
    from_sq = Square.from_coord(77)
    preview_pos.set_koma(Koma.FU, from_sq)
    sample_move = preview_pos.create_move(from_sq, Square.from_coord(76))
    return writer.write_move(sample_move, preview_pos)
def setUp(self):
    """Give each test a fresh ``Position`` and a Japanese-format writer."""
    fresh_position = Position()
    self.position = fresh_position
    self.move_writer = JapaneseMoveWriter(JAPANESE_MOVE_FORMAT)
def setUp(self):
    """Give each test a fresh ``Position`` and a Western-format writer."""
    fresh_position = Position()
    self.position = fresh_position
    self.move_writer = WesternMoveWriter(WESTERN_MOVE_FORMAT)
def setUp(self):
    """Give each test its own ``Position``, reset via ``Position.reset``."""
    position = Position()
    self.position = position
    position.reset()
class TestPositionMethods(unittest.TestCase):
    """Unit tests for ``Position`` and the hand representation."""

    def setUp(self):
        self.position = Position()
        self.position.reset()

    def _assert_sfen_round_trip(self, sfen):
        """Load ``sfen`` and check that ``to_sfen`` reproduces it exactly."""
        self.position.from_sfen(sfen)
        self.assertEqual(self.position.to_sfen(), sfen)

    def test_set_hand_komatype_count(self):
        self.hand = HandRepresentation()
        self.hand.set_komatype_count(KomaType.KE, 4)
        stored = self.hand.get_komatype_count(KomaType.KE)
        self.assertEqual(stored, 4)

    def test_king_is_not_promoted(self):
        king_type = KomaType.get(Koma.OU)
        self.assertFalse(king_type.is_promoted())

    def test_set_koma(self):
        target = Square.b63
        self.position.set_koma(Koma.vTO, target)
        self.assertEqual(self.position.get_koma(sq=target), Koma.vTO)

    def test_set_hand_koma_count(self):
        # NOTE(review): the count is written via Side.SHITATE but read back
        # via Side.SENTE, so this relies on both naming the same hand.
        self.position.set_hand_koma_count(Side.SHITATE, Koma.KE, 4)
        sente_hand = self.position.get_hand_of_side(Side.SENTE)
        self.assertEqual(sente_hand.mochigoma_dict[Koma.KE], 4)

    def test_make_move(self):
        # A promoting capture by vGI of the UM on 87 should leave a vNG on
        # 87 and add one KA to GOTE's hand.
        self.position.set_koma(Koma.vGI, Square.b76)
        self.position.set_koma(Koma.UM, Square.b87)
        capture = Move(start_sq=Square.b76, end_sq=Square.b87,
                       is_promotion=True, koma=Koma.vGI, captured=Koma.UM)
        self.position.make_move(capture)
        self.assertEqual(self.position.get_koma(sq=Square.b87), Koma.vNG)
        gote_hand = self.position.get_hand_of_side(Side.GOTE)
        self.assertEqual(gote_hand.mochigoma_dict[Koma.KA], 1)

    def test_from_to_sfen(self):
        self._assert_sfen_round_trip(
            "nk1n5/1g3g3/p8/2BP5/3+r5/9/9/9/9 b RBGg4s2n4l16p 17")

    def test_starting_position(self):
        self._assert_sfen_round_trip(
            "lnsgkgsnl/1r5b1/ppppppppp/9/9/9/PPPPPPPPP/1B5R1/LNSGKGSNL b - 1")
class TestMoveGeneration(unittest.TestCase):
    """Tests for ``rules.generate_valid_moves`` per piece type.

    Each test loads an SFEN position and compares the generated moves for
    both sides against hand-written answer keys in ``to_latin()`` notation.
    """

    def setUp(self):
        self.position = Position()
        self.position.reset()

    def _assert_generated_moves(self, sfen, ktype, sente_moves, gote_moves):
        """Check ``generate_valid_moves`` for both sides against answer keys.

        Loads ``sfen`` and generates the moves of every ``ktype`` piece for
        SENTE and for GOTE. Both results are compared with the expected
        move strings as sets of ``to_latin()`` strings, so ordering and
        duplicates do not matter.
        """
        self.position.from_sfen(sfen)
        mvlist_sente = rules.generate_valid_moves(
            pos=self.position, side=Side.SENTE, ktype=ktype)
        mvlist_gote = rules.generate_valid_moves(
            pos=self.position, side=Side.GOTE, ktype=ktype)
        self.assertEqual(
            {move.to_latin() for move in mvlist_sente}, set(sente_moves))
        self.assertEqual(
            {move.to_latin() for move in mvlist_gote}, set(gote_moves))

    def test_pawn_moves(self):
        sfen = "p1p4P1/2P5P/3p3P1/3P4P/9/p4p3/1p3P3/p5p2/1p4P1P b - 1"
        sente_moves = [
            "P18(19)", "P13(14)", "P13(14)+", "P11(12)+", "P22(23)",
            "P22(23)+", "P38(39)", "P46(47)", "P63(64)", "P63(64)+",
            "P71(72)+",
        ]
        gote_moves = [
            "P39(38)+", "P47(46)", "P47(46)+", "P64(63)", "P72(71)",
            "P88(87)", "P88(87)+", "P92(91)", "P97(96)", "P97(96)+",
            "P99(98)+",
        ]
        self._assert_generated_moves(
            sfen, KomaType.FU, sente_moves, gote_moves)

    def test_lance_moves(self):
        sfen = "4p4/4L2P1/5L2l/6l2/9/7L1/5l2L/4l1p2/4P4 b - 1"
        sente_moves = [
            "L16(17)", "L15(17)", "L14(17)", "L13(17)", "L13(17)+",
            "L25(26)", "L24(26)", "L23(26)", "L23(26)+",
            "L42(43)", "L42(43)+", "L41(43)+", "L51(52)+",
        ]
        gote_moves = [
            "L14(13)", "L15(13)", "L16(13)", "L17(13)", "L17(13)+",
            "L35(34)", "L36(34)", "L37(34)", "L37(34)+",
            "L48(47)", "L48(47)+", "L49(47)+", "L59(58)+",
        ]
        self._assert_generated_moves(
            sfen, KomaType.KY, sente_moves, gote_moves)

    def test_knight_moves(self):
        sfen = "N7N/4N1n2/9/5N3/n2n1N2N/3n5/9/2N1n4/n7n b - 1"
        sente_moves = [
            "N23(15)", "N23(15)+", "N32(44)+", "N53(45)", "N53(45)+",
            "N33(45)", "N33(45)+", "N66(78)", "N86(78)",
        ]
        gote_moves = [
            "N24(32)", "N44(32)", "N57(65)", "N57(65)+", "N77(65)",
            "N77(65)+", "N78(66)+", "N87(95)", "N87(95)+",
        ]
        self._assert_generated_moves(
            sfen, KomaType.KE, sente_moves, gote_moves)

    def test_silver_moves(self):
        sfen = "8S/7s1/7S1/6S2/9/6s2/7s1/7S1/8s b - 1"
        sente_moves = [
            "S22(11)", "S22(11)+",
            "S22(23)", "S22(23)+", "S12(23)", "S12(23)+", "S32(23)",
            "S32(23)+", "S14(23)", "S14(23)+",
            "S33(34)", "S33(34)+", "S43(34)", "S43(34)+", "S25(34)",
            "S45(34)",
            "S17(28)", "S27(28)", "S37(28)", "S19(28)", "S39(28)",
        ]
        gote_moves = [
            "S28(19)", "S28(19)+",
            "S28(27)", "S28(27)+", "S18(27)", "S18(27)+", "S38(27)",
            "S38(27)+", "S16(27)", "S16(27)+",
            "S37(36)", "S37(36)+", "S47(36)", "S47(36)+", "S25(36)",
            "S45(36)",
            "S13(22)", "S23(22)", "S33(22)", "S11(22)", "S31(22)",
        ]
        self._assert_generated_moves(
            sfen, KomaType.GI, sente_moves, gote_moves)

    def test_gold_moves(self):
        sfen = "8G/7G1/7g1/9/9/9/7G1/7g1/8g b - 1"
        sente_moves = [
            "G12(11)", "G21(11)",
            "G31(22)", "G21(22)", "G32(22)", "G12(22)", "G23(22)",
            "G36(27)", "G26(27)", "G16(27)", "G37(27)", "G17(27)",
            "G28(27)",
        ]
        gote_moves = [
            "G18(19)", "G29(19)",
            "G39(28)", "G29(28)", "G38(28)", "G18(28)", "G27(28)",
            "G34(23)", "G24(23)", "G14(23)", "G33(23)", "G13(23)",
            "G22(23)",
        ]
        self._assert_generated_moves(
            sfen, KomaType.KI, sente_moves, gote_moves)

    def test_bishop_moves(self):
        sfen = "9/7P1/9/5B3/9/3b5/9/1p7/9 b - 1"
        sente_moves = [
            "B71(44)", "B71(44)+", "B62(44)", "B62(44)+", "B53(44)",
            "B53(44)+", "B33(44)", "B33(44)+", "B55(44)", "B66(44)",
            "B35(44)", "B26(44)", "B17(44)",
        ]
        gote_moves = [
            "B39(66)", "B39(66)+", "B48(66)", "B48(66)+", "B57(66)",
            "B57(66)+", "B77(66)", "B77(66)+", "B55(66)", "B44(66)",
            "B75(66)", "B84(66)", "B93(66)",
        ]
        self._assert_generated_moves(
            sfen, KomaType.KA, sente_moves, gote_moves)

    def test_rook_moves(self):
        sfen = "9/3p5/9/9/2Pr1Rp2/9/9/5P3/9 b - 1"
        sente_moves = [
            "R41(45)", "R41(45)+", "R42(45)", "R42(45)+", "R43(45)",
            "R43(45)+", "R44(45)", "R55(45)", "R65(45)", "R35(45)",
            "R46(45)", "R47(45)",
        ]
        gote_moves = [
            "R69(65)", "R69(65)+", "R68(65)", "R68(65)+", "R67(65)",
            "R67(65)+", "R66(65)", "R55(65)", "R45(65)", "R75(65)",
            "R64(65)", "R63(65)",
        ]
        self._assert_generated_moves(
            sfen, KomaType.HI, sente_moves, gote_moves)

    def manual_test_drop_moves(self):
        # NOT automated test, needs manual verification.
        # There are no answer keys here: the drop lists are only generated
        # (e.g. for inspection in a debugger) and are not asserted on.
        sfen = "l1sgk1snl/6g2/p2ppp2p/2p6/9/9/P1SPPPP1P/2G6/LN2KGSNL b RBN3Prb3p 1"
        self.position.from_sfen(sfen)
        droplist_fu = rules.generate_drop_moves(pos=self.position,
                                                side=Side.SENTE,
                                                ktype=KomaType.FU)
        droplist_ke = rules.generate_drop_moves(pos=self.position,
                                                side=Side.SENTE,
                                                ktype=KomaType.KE)
        droplist_ka = rules.generate_drop_moves(pos=self.position,
                                                side=Side.SENTE,
                                                ktype=KomaType.KA)
def __init__(self) -> None:
    """Start with an empty move tree whose root is the current node."""
    root = GameNode()
    self.movetree = root
    self.curr_node: MoveNode = root
    self.position = Position()
class Game:
    """Representation of a shogi game.

    Holds three things:

    - ``movetree``: the root of the move tree.
    - ``curr_node``: the node the game is currently at.
    - ``position``: the ``Position`` corresponding to that node.
    """

    def __init__(self) -> None:
        root = GameNode()
        self.movetree = root
        self.curr_node: MoveNode = root
        self.position = Position()

    def reset(self) -> None:
        """Reset self to a new empty game.

        A fresh move tree is created and ``position`` is reset in place.
        """
        root = GameNode()
        self.movetree = root
        self.curr_node = root
        self.position.reset()

    def add_move(self, move: Move) -> None:
        """Play ``move`` and record it in the move tree under the current node.

        Reusing an existing child node when ``move`` is already present is
        delegated to ``curr_node.add_move`` (TODO confirm in MoveNode).
        """
        self.position.make_move(move)  # should check for exceptions
        self.curr_node = self.curr_node.add_move(move)

    def make_move(self, move: Move) -> bool:
        """Play ``move`` only if it already exists in the move tree.

        Returns the result of ``curr_node.has_as_next_move(move)``. The game
        advances only when that result is truthy.
        """
        found = self.curr_node.has_as_next_move(move)
        if not found:
            return found
        self.position.make_move(move)  # should check for exceptions
        self.curr_node = self.curr_node.get_variation_node(move)
        return found

    def is_mainline(self, move: Move) -> bool:
        """Return whether ``move`` is the mainline continuation from here.

        Always False at a leaf node.
        """
        return (not self.curr_node.is_leaf()
                and self.curr_node.next().move == move)

    def is_end(self) -> bool:
        """Return whether the current node has no following move."""
        return self.curr_node.is_leaf()

    def get_mainline_move(self) -> Move:
        """Return the move stored on the current node's mainline child."""
        return self.curr_node.next().move

    def go_next_move(self) -> bool:
        """Go one move further into the game, following the mainline.

        Returns False, leaving the game unchanged, when the mainline child
        is a null node. Otherwise returns True.
        """
        upcoming = self.curr_node.next()
        if upcoming.is_null():
            return False
        self.position.make_move(upcoming.move)
        self.curr_node = upcoming
        return True

    def go_prev_move(self) -> None:
        """Step one move back in the game.

        Does nothing when the current node's parent is a null node.
        """
        parent = self.curr_node.prev()
        if not parent.is_null():
            self.position.unmake_move(self.curr_node.move)
            self.curr_node = parent

    def go_to_start(self) -> None:
        """Go to the start of the game.

        Does nothing when ``movetree.start_pos`` is empty/falsy. Otherwise
        the position is reloaded from that SFEN and the root becomes the
        current node.
        """
        start_sfen = self.movetree.start_pos
        if start_sfen:
            self.position.from_sfen(start_sfen)
            self.curr_node = self.movetree

    def go_to_end(self) -> None:
        """Go to the end of the current branch by repeatedly following the
        mainline until no further move exists.
        """
        while self.go_next_move():
            pass

    def get_current_sfen(self) -> str:
        """Return the SFEN string of the current position."""
        return self.position.to_sfen()

    def get_mainline_notation(self,
                              move_writer: AbstractMoveWriter) -> List[str]:
        """Write out every mainline move from the start of the game.

        Each move is rendered by ``move_writer.write_move``, which is passed
        the position after the move has been played. It also receives a
        flag telling whether the move lands on the same square as the
        previous move.
        """
        # Walk a separate Game that shares the move tree, so that this
        # game's current node and position are left untouched.
        # NOTE(review): if movetree.start_pos is empty, go_to_start() does
        # nothing and the walk starts from a freshly constructed Position.
        walker = Game()
        walker.movetree = self.movetree
        walker.curr_node = self.movetree
        walker.go_to_start()
        notation: List[str] = []
        while not walker.curr_node.is_leaf():
            previous = walker.curr_node.move
            walker.go_next_move()
            current: Move = walker.curr_node.move
            same_destination = ((not previous.is_null())
                                and (current.end_sq == previous.end_sq))
            notation.append(move_writer.write_move(
                current, walker.position, same_destination))
        return notation
def create_valid_drop_given_square(pos: Position, side: Side, ktype: KomaType,
                                   end_sq: Square) -> Move:
    """Create the drop of ``ktype`` by ``side`` onto ``end_sq`` if it is valid.

    Validity is decided by ``exists_valid_drop_given_square``. When the drop
    is not valid, ``NullMove()`` is returned instead of a drop move.
    """
    if not exists_valid_drop_given_square(pos, side, ktype, end_sq):
        return NullMove()
    return pos.create_drop_move(side, ktype, end_sq)
def _is_drop_available(pos: Position, side: Side, ktype: KomaType) -> bool:
    """Report whether ``side``'s hand has a nonzero count of ``ktype``.

    Raises:
        ValueError: if ``ktype`` is not one of ``HAND_TYPES``.
    """
    if ktype in HAND_TYPES:
        return pos.get_hand_koma_count(side, ktype) != 0
    raise ValueError(f"{ktype} is not a valid KomaType for a drop move")
def is_legal(mv: Move, pos: Position) -> bool:
    """Return whether ``mv`` does not leave the moving side in check.

    The moving side is ``pos.turn`` as read before the move is made.
    ``mv`` is played on ``pos`` temporarily. The result is
    ``not is_in_check(pos, side)`` evaluated in the resulting position.

    ``pos`` is restored with ``unmake_move`` before returning. This happens
    even if ``is_in_check`` raises, so that a shared position is never left
    with the trial move applied.
    """
    side = pos.turn
    # make_move stays outside the try: if it fails, there is no move to undo.
    pos.make_move(mv)
    try:
        return not is_in_check(pos, side)
    finally:
        pos.unmake_move(mv)