def evaluate_state(self, env: MancalaEnv) -> (np.ndarray, float):
    # Present the board from the side to move's perspective: flip it when North is to move.
    flip_board = env.side_to_move == Side.NORTH
    state = env.board.get_board_image(flipped=flip_board)
    mask = env.get_action_mask_with_no_pie()
    dist, _, value = self.network.evaluate_move(state=state, mask=mask)
    return dist, float(value)
def mcts_main():
    mcts = TreesFactory.standard_mcts()
    state = MancalaEnv()
    try:
        _run_game(mcts, state)
    except Exception as e:
        logging.error("Uncaught exception in main: %s", e)
def test_mcts_doesnt_mutate_state(self):
    state = MancalaEnv()
    initial_board = Board.clone(state.board)
    mcts = MCTSFactory.test_mcts()
    mcts.search(state)
    self.assertEqual(initial_board.board, state.board.board,
                     "Expect MCTS doesn't mutate the initial board")
def start():
    _state = MancalaEnv()
    _player = RandomPlayer()
    try:
        _run_game(_player, _state)
    except Exception as e:
        print("Uncaught exception in main: " + str(e))
def __init__(self, state: MancalaEnv, move: Move = None, parent=None):
    self.visits = 0
    self.reward = 0
    self.state = state
    self.children = []
    self.parent = parent
    self.move = move
    self.value = -1
    self.unexplored_moves = set(state.get_legal_moves())
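# The Node methods used elsewhere in this section (put_child in the expand
# functions, update in the backpropagate functions) are not shown here. A
# minimal sketch, assuming put_child also marks the move as explored and
# update accumulates one simulation outcome; the bodies are assumptions, not
# the repository's actual implementation:

def put_child(self, child: 'Node'):
    # Attach an expanded child and stop treating its move as unexplored.
    self.children.append(child)
    self.unexplored_moves.discard(child.move)

def update(self, reward: float):
    # Record one simulation result passing through this node.
    self.visits += 1
    self.reward += reward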
def main(_):
    with tf.Session() as sess:
        with tf.variable_scope("global"):
            a3client = A3Client(sess)
            mcts = TreesFactory.alpha_mcts(a3client)
            state = MancalaEnv()
            try:
                _run_game(mcts, state)
            except Exception as e:
                logging.error("Uncaught exception in main: %s", e)
def search(self, state: MancalaEnv) -> Move:
    # Short-circuit when only one move is available.
    if len(state.get_legal_moves()) == 1:
        return state.get_legal_moves()[0]

    game_state_root = Node(state=MancalaEnv.clone(state))
    start_time = datetime.datetime.utcnow()
    games_played = 0
    while datetime.datetime.utcnow() - start_time < self.calculation_time:
        node = self.tree_policy.select(game_state_root)
        final_state = self.default_policy.simulate(node)
        self.rollout_policy.backpropagate(node, final_state)
        # Debugging information
        games_played += 1
        logging.debug("%s; Games played: %i", node, games_played)
        logging.debug("%s", game_state_root)

    chosen_child = node_utils.select_robust_child(game_state_root)
    logging.info("Choosing: %s", chosen_child)
    return chosen_child.move
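# node_utils.select_robust_child is not shown in this section. In MCTS
# terminology the "robust child" is the most-visited child of the root, so a
# minimal sketch under that assumption looks like this (the body is an
# assumption, not the repository's actual implementation):

def select_robust_child(root: Node) -> Node:
    # Visit count is a more stable final-move criterion than mean reward,
    # since a high mean from few visits may just be simulation noise.
    return max(root.children, key=lambda child: child.visits)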
@staticmethod
def expand(parent: Node) -> Node:
    child_expansion_move = choice(tuple(parent.unexplored_moves))
    child_state = MancalaEnv.clone(parent.state)
    child_state.perform_move(child_expansion_move)
    child_node = Node(state=child_state, move=child_expansion_move, parent=parent)
    parent.put_child(child_node)
    # Go down the tree.
    MonteCarloTreePolicy._rave_expand(child_node)
    return child_node
@staticmethod
def _alpha_beta_search(game: MancalaEnv, alpha=-np.inf, beta=np.inf, depth=5):
    """Search game to determine the best action, using alpha-beta pruning.

    This version cuts off the search at a fixed depth and uses an
    evaluation function at the horizon.
    """
    if depth == 0 or game.is_game_over():
        return game.get_player_utility()
    # Branching on side_to_move (rather than strict alternation) keeps the
    # search correct when a move grants another turn to the same player.
    if game.side_to_move == Side.SOUTH:
        v = -np.inf
        for (_, new_s) in game.next_states():
            v = max(v, AlphaBeta._alpha_beta_search(new_s, alpha, beta, depth - 1))
            alpha = max(alpha, v)
            if beta <= alpha:
                break  # beta cut-off: North will never allow this line
    else:
        v = np.inf
        for (_, new_s) in game.next_states():
            v = min(v, AlphaBeta._alpha_beta_search(new_s, alpha, beta, depth - 1))
            beta = min(beta, v)
            if beta <= alpha:
                break  # alpha cut-off: South will never allow this line
    return v
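# get_player_utility is the horizon evaluation referenced above but not shown
# in this section. A minimal sketch, assuming the common Mancala heuristic of
# the store-seed difference, positive when South (the maximiser) is ahead;
# the body is an assumption, not the repository's actual evaluation:

def get_player_utility(self) -> int:
    return (self.board.get_seeds_in_store(Side.SOUTH)
            - self.board.get_seeds_in_store(Side.NORTH))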
@staticmethod
def _rave_expand(parent: Node):
    # Unreachable holes keep a huge negative score so the softmax assigns
    # them (numerically) zero probability.
    moves = [-1e80 for _ in range(parent.state.board.holes + 1)]
    for unexplored_move in parent.unexplored_moves.copy():
        child_state = MancalaEnv.clone(parent.state)
        child_state.perform_move(unexplored_move)
        moves[unexplored_move.index] = evaluation.get_score(
            state=child_state, parent_side=parent.state.side_to_move)
    # Numerically stable softmax over the move scores.
    moves_dist = np.asarray(moves, dtype=np.float64).flatten()
    exp = np.exp(moves_dist - np.max(moves_dist))
    dist = exp / np.sum(exp)
    parent.value = max(dist)
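# Subtracting np.max above is the standard softmax stabilisation trick:
# softmax(x) == softmax(x - c) for any constant c, and shifting by the max
# keeps np.exp from overflowing. A small standalone illustration:

import numpy as np

scores = np.array([1000.0, 1001.0, 1002.0])
# Naive np.exp(scores) overflows to inf, and inf / inf yields nan.
stable = np.exp(scores - scores.max())  # [e^-2, e^-1, e^0]
probs = stable / stable.sum()           # ~[0.090, 0.245, 0.665]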
def backpropagate(self, root: Node, final_state: MancalaEnv):
    """Push the simulation reward up from the given node through its parents to the root.

    :param root: node to start backpropagating from
    :param final_state: the final game state (holds the reward from the simulation)
    """
    node = root
    # Propagate the node reward up through the parents.
    while node is not None:
        # Reward each node from the perspective of the player who moved into
        # it; the root has no parent, so use its own side to move.
        side = node.parent.state.side_to_move if node.parent is not None else node.state.side_to_move
        node.update(final_state.compute_end_game_reward(side))
        node = node.parent
def expand(self, node: AlphaNode):
    # Tactical workaround for the pie move: the network does not predict it,
    # so drop it from the unexplored set.
    if Move(node.state.side_to_move, 0) in node.unexplored_moves:
        node.unexplored_moves.remove(Move(node.state.side_to_move, 0))
    dist, _ = self.network.evaluate_state(node.state)
    for index, prior in enumerate(dist):
        # The network distribution covers holes 1..7, so index 0 maps to hole 1.
        expansion_move = Move(node.state.side_to_move, index + 1)
        if node.state.is_legal(expansion_move):
            child_state = MancalaEnv.clone(node.state)
            child_state.perform_move(expansion_move)
            child_node = AlphaNode(state=child_state, prior=prior,
                                   move=expansion_move, parent=node)
            node.put_child(child_node)
    # Go down the tree.
    return node_utils.select_child_with_maximum_action_value(node)
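# node_utils.select_child_with_maximum_action_value is not shown in this
# section. Given that each AlphaNode stores a network prior, an
# AlphaZero-style PUCT rule is the natural fit; a minimal sketch under that
# assumption (c_puct is a hypothetical exploration constant, not a name from
# the repository):

import math

def select_child_with_maximum_action_value(node: AlphaNode, c_puct: float = 1.0) -> AlphaNode:
    def action_value(child: AlphaNode) -> float:
        # Exploitation: mean simulation reward observed through this child.
        q = child.reward / child.visits if child.visits > 0 else 0.0
        # Exploration: prior-weighted bonus that decays with child visits.
        u = c_puct * child.prior * math.sqrt(node.visits) / (1 + child.visits)
        return q + u
    return max(node.children, key=action_value)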
def backpropagate(self, root: Node, final_state: MancalaEnv, lmbd=1):
    """Push the simulation reward along the path from the root down to the given node.

    :param root: node to start backpropagating from
    :param final_state: the final game state (holds the reward from the simulation)
    :param lmbd: a parameter to control the weight of the value network
    """
    path_stack = []
    node = root
    # Collect the path from this node up to the root.
    while node is not None:
        path_stack.append(node)
        node = node.parent
    # Update from the root downward so the exploration bonus calculation is correct.
    while len(path_stack) > 0:
        node = path_stack.pop()
        # Reward each node from the perspective of the player who moved into
        # it; the root has no parent, so use its own side to move.
        side = node.parent.state.side_to_move if node.parent is not None else node.state.side_to_move
        game_reward = final_state.compute_end_game_reward(side)
        # _, value = self.network.evaluate_state(final_state)
        # game_reward = (1 - lmbd) * value + (lmbd * side_final_reward)  # blend network value and game reward
        node.update(game_reward)
def _run_game(player: Player, state: MancalaEnv):
    our_agent_states = []
    their_agent_states = []
    both_agent_states = []
    our_side = Side.SOUTH
    while True:
        msg = protocol.read_msg()
        try:
            msg_type = protocol.get_msg_type(msg)
            if msg_type == MsgType.START:
                first = protocol.interpret_start_msg(msg)
                if first:
                    move = player.get_play(state)
                    protocol.send_msg(protocol.create_move_msg(move.index))
                else:
                    our_side = Side.NORTH
            elif msg_type == MsgType.STATE:
                move_turn = protocol.interpret_state_msg(msg)
                # Move 0 is the pie-rule swap, which switches the sides.
                if move_turn.move == 0:
                    our_side = Side.opposite(our_side)
                move_to_perform = Move(state.side_to_move, move_turn.move)
                observed_state = ObservedState(state=state, action_taken=move_to_perform)
                both_agent_states.append(observed_state)
                if state.side_to_move == our_side:
                    our_agent_states.append(observed_state)
                else:
                    their_agent_states.append(observed_state)
                state.perform_move(move_to_perform)
                if not move_turn.end:
                    if move_turn.again:
                        move = player.get_play(state)
                        # Pie rule: the optimal move is to swap.
                        if move.index == 0:
                            protocol.send_msg(protocol.create_swap_msg())
                        else:
                            protocol.send_msg(protocol.create_move_msg(move.index))
            elif msg_type == MsgType.END:
                # Persist the observed states for later training.
                args = parser.parse_args()
                run_id = '%06d' % args.run_number
                run_category = args.category
                _our_agent_file_path = _checkpoint_file_path + "/our-agent/" + run_category + run_id
                _their_agent_file_path = _checkpoint_file_path + "/their-agent/" + run_category + run_id
                _both_agent_file_path = _checkpoint_file_path + "/both-agent/" + run_category + run_id
                np.save(file=_our_agent_file_path, arr=np.array(our_agent_states))
                np.save(file=_their_agent_file_path, arr=np.array(their_agent_states))
                np.save(file=_both_agent_file_path, arr=np.array(both_agent_states))
                break
            else:
                print("Not sure what I got " + str(msg_type))
        except InvalidMessageException as _e:
            print(str(_e))
def test_mcts_test_game(self):
    state = MancalaEnv()
    mcts = MCTSFactory.long_test_mcts(sec=0)  # Tweak this to test MCTS manually
    move = mcts.search(state)
    print(move)
def test_mcts_generate_legal_move(self):
    state = MancalaEnv()
    mcts = MCTSFactory.test_mcts()
    move = mcts.search(state)
    self.assertTrue(state.is_legal(move), "Expect the move generated by MCTS to be legal")
def sample_state(self, env: MancalaEnv) -> (int, float):
    flip_board = env.side_to_move == Side.NORTH
    state = env.board.get_board_image(flipped=flip_board)
    mask = env.get_action_mask_with_no_pie()
    return self.network.sample(state=state, mask=mask)
def search(self, game: MancalaEnv) -> Move:
    values = [(a, self._alpha_beta_search(game=state, depth=self.depth))
              for a, state in game.next_states()]
    # Shuffle so that ties between equally-valued moves are broken randomly.
    np.random.shuffle(values)
    if game.side_to_move == Side.SOUTH:
        action, _ = max(values, key=lambda x: x[1])
    else:
        action, _ = min(values, key=lambda x: x[1])
    return action
def run(args, server):
    env = MancalaEnv()
    trainer = A3C(env, args.task)

    # Variable names that start with "local" are not saved in checkpoints.
    variables_to_save = [
        v for v in tf.global_variables() if not v.name.startswith("local")
    ]
    init_op = tf.variables_initializer(variables_to_save)
    init_all_op = tf.global_variables_initializer()
    saver = FastSaver(variables_to_save)
    var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 tf.get_variable_scope().name)
    logger.info('Trainable vars:')
    for v in var_list:
        logger.info('  %s %s', v.name, v.get_shape())

    def init_fn(ses):
        logger.info("Initializing all parameters.")
        ses.run(init_all_op)

    config = tf.ConfigProto(device_filters=[
        "/job:ps", "/job:worker/task:{}/cpu:0".format(args.task)
    ])
    logdir = os.path.join(args.log_dir, 'train')
    summary_writer = tf.summary.FileWriter(logdir + "_%d" % args.task)
    logger.info("Events directory: %s_%s", logdir, args.task)
    sv = tf.train.Supervisor(
        is_chief=(args.task == 0),
        logdir=logdir,
        saver=saver,
        summary_op=None,
        init_op=init_op,
        init_fn=init_fn,
        summary_writer=summary_writer,
        ready_op=tf.report_uninitialized_variables(variables_to_save),
        global_step=trainer.global_step,
        save_model_secs=30,
        save_summaries_secs=30)

    num_global_steps = 100000000
    logger.info(
        "Starting session. If this hangs, we're most likely waiting to connect to the parameter server. " +
        "One common cause is that the parameter server DNS name isn't resolving yet, or is misspecified.")
    with sv.managed_session(server.target, config=config) as sess, sess.as_default():
        sess.run(trainer.down_sync)
        global_step = sess.run(trainer.global_step)
        logger.info("Starting training at step=%d", global_step)
        while not sv.should_stop() and (not num_global_steps or global_step < num_global_steps):
            trainer.play(sess, RandomAgent(), summary_writer)
            global_step = sess.run(trainer.global_step)

    # Ask for all the services to stop.
    sv.stop()
    logger.info('reached %s steps. worker stopped.', global_step)
class TestMancalaGameState(unittest.TestCase):
    def setUp(self):
        self.game = MancalaEnv()

    def test_initial_state_is_correct(self):
        self.assertEqual(self.game.side_to_move, Side.SOUTH)
        self.assertFalse(self.game.north_moved)
        for hole in range(1, self.game.board.holes + 1):
            self.assertEqual(self.game.board.get_seeds(Side.SOUTH, hole), 7)
            self.assertEqual(self.game.board.get_seeds(Side.NORTH, hole), 7)
        self.assertEqual(self.game.board.get_seeds_in_store(Side.SOUTH), 0)
        self.assertEqual(self.game.board.get_seeds_in_store(Side.NORTH), 0)

    def test_cloning_immutability(self):
        clone = MancalaEnv.clone(self.game)
        self.game.perform_move(Move(Side.SOUTH, 3))
        self.assertEqual(clone.board.get_seeds(Side.SOUTH, 3), 7)
        self.assertEqual(clone.side_to_move, Side.SOUTH)

    def test_move_has_required_effects(self):
        self.game.perform_move(Move(Side.SOUTH, 5))
        self.assertEqual(self.game.board.get_seeds(Side.SOUTH, 5), 0)
        self.assertEqual(self.game.board.get_seeds(Side.SOUTH, 6), 8)
        self.assertEqual(self.game.board.get_seeds(Side.SOUTH, 7), 8)
        self.assertEqual(self.game.board.get_seeds_in_store(Side.SOUTH), 1)
        self.assertEqual(self.game.board.get_seeds(Side.NORTH, 1), 8)
        self.assertEqual(self.game.board.get_seeds(Side.NORTH, 2), 8)
        self.assertEqual(self.game.board.get_seeds(Side.NORTH, 3), 8)
        self.assertEqual(self.game.board.get_seeds(Side.NORTH, 4), 8)
        self.game.perform_move(Move(Side.NORTH, 4))
        self.assertEqual(self.game.board.get_seeds(Side.NORTH, 4), 0)
        self.assertEqual(self.game.board.get_seeds(Side.NORTH, 5), 8)
        self.assertEqual(self.game.board.get_seeds(Side.NORTH, 6), 8)
        self.assertEqual(self.game.board.get_seeds(Side.NORTH, 7), 8)
        self.assertEqual(self.game.board.get_seeds_in_store(Side.NORTH), 1)
        self.assertEqual(self.game.board.get_seeds(Side.SOUTH, 1), 8)
        self.assertEqual(self.game.board.get_seeds(Side.SOUTH, 2), 8)
        self.assertEqual(self.game.board.get_seeds(Side.SOUTH, 3), 8)

    def test_game_is_over_returns_false(self):
        self.assertFalse(self.game.is_game_over())

    def test_game_is_over_returns_true(self):
        board = self.game.board
        for hole in range(1, board.holes + 1):
            board.set_seeds(Side.SOUTH, hole, 0)
        board.set_seeds_in_store(Side.SOUTH, 49)
        self.assertTrue(self.game.is_game_over())

    def test_game_returns_winner_the_player_with_most_seeds(self):
        board = self.game.board
        for hole in range(1, self.game.board.holes + 1):
            board.set_seeds(Side.SOUTH, hole, 0)
            board.set_seeds(Side.NORTH, hole, 0)
        board.set_seeds_in_store(Side.SOUTH, 23)
        board.set_seeds_in_store(Side.NORTH, 21)
        self.assertEqual(self.game.get_winner(), Side.SOUTH)

    def test_game_returns_no_winner_if_players_have_equal_number_of_seeds(self):
        board = self.game.board
        for hole in range(1, self.game.board.holes + 1):
            board.set_seeds(Side.SOUTH, hole, 0)
            board.set_seeds(Side.NORTH, hole, 0)
        board.set_seeds_in_store(Side.SOUTH, 30)
        board.set_seeds_in_store(Side.NORTH, 30)
        self.assertEqual(self.game.get_winner(), None)

    def test_is_legal_move_returns_true_for_the_pie_rule(self):
        board = self.game.board
        MancalaEnv.make_move(board, Move(Side.SOUTH, 6), False)
        self.assertTrue(MancalaEnv.is_legal_action(board, Move(Side.NORTH, 0), False))

    def test_is_legal_move_returns_true_for_the_pie_rule2(self):
        env = MancalaEnv()
        env.perform_move(Move(Side.SOUTH, 5))
        self.assertTrue(env.is_legal(Move(Side.NORTH, 0)))

    def test_legal_moves_contains_all_moves(self):
        self.assertEqual(len(set(self.game.get_legal_moves())), 7)
        self.game.perform_move(Move(Side.SOUTH, 3))
        # 7 holes plus the pie move available to North after South's first move.
        self.assertEqual(len(set(self.game.get_legal_moves())), 8)

    def test_side_to_move_doesnt_change(self):
        self.game.perform_move(Move(Side.SOUTH, 1))
        self.assertEqual(self.game.side_to_move, Side.NORTH)

    def test_alphabeta(self):
        board = self.game.board
        for hole in range(1, self.game.board.holes + 1):
            board.set_seeds(Side.SOUTH, hole, 0)
            board.set_seeds(Side.NORTH, hole, 0)
        board.set_seeds_in_store(Side.SOUTH, 0)
        board.set_seeds_in_store(Side.NORTH, 0)
        board.set_seeds(Side.SOUTH, 3, 1)
        board.set_seeds(Side.SOUTH, 2, 4)
        board.set_seeds_op(Side.SOUTH, 4, 5)
        print(self.game)
        print(search_action(self.game))
def __init__(self, env: MancalaEnv, task: int):
    self.env = env
    self.task = task
    # Performance statistics
    self.episodes_reward = []
    self.episodes_length = []
    self.episodes_mean_value = []
    self.wins = 0
    self.games = 0

    worker_device = "/job:worker/task:{}/cpu:0".format(task)
    with tf.device(tf.train.replica_device_setter(1, worker_device=worker_device)):
        with tf.variable_scope("global"):
            # The input board is a 2 x 8 x 1 tensor. The last dimension is added
            # so that convolutional layers can be applied to the input.
            self.network = ACNetwork(state_shape=[2, 8, 1], num_act=7)
            self.global_step = tf.get_variable(
                "global_step", [], tf.int32,
                initializer=tf.constant_initializer(0, dtype=tf.int32),
                trainable=False)

    with tf.device(worker_device):
        with tf.variable_scope("local"):
            self.local_network = pi = self.network
            pi.global_step = self.global_step

            self.action = tf.placeholder(shape=[None], dtype=tf.int32)
            self.action_one_hot = tf.one_hot(self.action, 7, dtype=tf.float32)
            self.target_v = tf.placeholder(shape=[None], dtype=tf.float32)
            self.advantage = tf.placeholder(shape=[None], dtype=tf.float32)

            log_prob = tf.nn.log_softmax(pi.logits)
            prob = tf.nn.softmax(pi.logits)
            act_log_prob = tf.reduce_sum(log_prob * self.action_one_hot, [1])

            # Loss functions
            self.value_loss = 0.5 * tf.reduce_sum(
                tf.square(self.target_v - tf.reshape(pi.value, [-1])))
            self.entropy = -tf.reduce_sum(prob * log_prob)
            self.policy_loss = -tf.reduce_sum(act_log_prob * self.advantage)
            # self.reg_loss = tf.add_n([tf.nn.l2_loss(v) for v in local_vars])
            self.loss = 0.5 * self.value_loss + self.policy_loss - self.entropy * 0.01  # + 0.002 * self.reg_loss

            # Get gradients from the local network using the local losses and
            # clip them to avoid exploding gradients.
            self.gradients = tf.gradients(self.loss, pi.vars)
            grads, self.grad_norms = tf.clip_by_global_norm(self.gradients, 100.0)

            # Define the operation for downloading the weights from the
            # parameter server (ps) onto the local model of the worker.
            self.down_sync = tf.group(
                *[v1.assign(v2) for v1, v2 in zip(pi.vars, self.network.vars)])

            # Define the training operation which applies the gradients on the
            # parameter server network (up sync).
            optimiser = tf.train.RMSPropOptimizer(learning_rate=0.0007)
            grads_and_global_vars = list(zip(grads, self.network.vars))
            inc_step = self.global_step.assign_add(tf.shape(self.action)[0])
            self.train_op = tf.group(
                *[optimiser.apply_gradients(grads_and_global_vars), inc_step])

            # Define an environment runner for this network.
            self.env_runner = EnvironmentRunner(MancalaEnv(), pi)

            episode_size = tf.to_float(tf.shape(pi.value)[0])
            # Define summaries for TensorBoard.
            tf.summary.scalar("Model/PolicyLoss", self.policy_loss / episode_size)
            tf.summary.scalar("Model/ValueLoss", self.value_loss / episode_size)
            tf.summary.scalar("Model/Entropy", self.entropy / episode_size)
            tf.summary.scalar("Model/GradientsGlobalNorm", self.grad_norms)
            tf.summary.scalar("Model/VarGlobalNorm", tf.global_norm(pi.vars))
            self.summary_op = tf.summary.merge_all()

    self.summary_writer = None
    self.local_steps = 0
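# The target_v and advantage placeholders above are fed from outside the
# graph. A3C conventionally uses discounted returns as the value target and
# (return - critic estimate) as the advantage; a minimal numpy sketch under
# that assumption (rewards, values, and gamma below are hypothetical inputs,
# not names from the repository):

import numpy as np

def discounted_returns(rewards: np.ndarray, gamma: float = 0.99) -> np.ndarray:
    # Walk the episode backwards, accumulating gamma-discounted reward.
    returns = np.zeros_like(rewards, dtype=np.float64)
    acc = 0.0
    for t in reversed(range(len(rewards))):
        acc = rewards[t] + gamma * acc
        returns[t] = acc
    return returns

rewards = np.array([0.0, 0.0, 1.0])   # e.g. a win reward on the final step
values = np.array([0.2, 0.4, 0.7])    # critic estimates along the episode
target_v = discounted_returns(rewards)  # feed dict value for self.target_v
advantage = target_v - values           # feed dict value for self.advantage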
def __init__(self, state: MancalaEnv, action_taken: Move):
    self.state = MancalaEnv.clone(state)
    self.action_taken = Move.clone(action_taken)
@staticmethod
def get_play(state: MancalaEnv) -> Move:
    return choice(state.get_legal_moves())
def _make_temp_child(parent: Node, move: Move) -> MancalaEnv:
    child_state = MancalaEnv.clone(parent.state)
    child_state.perform_move(move)
    return child_state