Code example #1
    def train(self):
        """
        Perform a single self-play game and then update the neural network
        """
        p1_images = []
        p2_images = []
        p1_pi = []
        p2_pi = []

        g = game.Game()
        mcts = MonteCarloTreeSearch(
            simulations=self.mcts_simulations,
            model=self.model,
        )
        player = 1
        while not g.game_over():
            pi = list(mcts.mcts(g))
            best_prob = 0
            best_moves = []
            for i, prob in enumerate(pi):
                if prob > best_prob:
                    best_prob = prob
                    best_moves = [i]
                elif prob == best_prob:
                    best_moves.append(i)
                else:
                    continue

            images, pi = mcts.get_training_data()
            assert len(images) == len(pi)

            if player == 1:
                p1_images += images
                p1_pi += pi
            else:
                p2_images += images
                p2_pi += pi

            g.make_move_index(random.choice(best_moves))
            player *= -1

        if g.result == 0:
            labels = [0 for _ in range(len(p1_pi) + len(p2_pi))]
        else:
            if player == 1:
                labels = [-1. for _ in range(len(p1_images))]
                labels += [1. for _ in range(len(p2_images))]
            else:
                labels = [1. for _ in range(len(p1_images))]
                labels += [-1. for _ in range(len(p2_images))]

        p1_images = np.array(p1_images)
        p2_images = np.array(p2_images)
        data = np.vstack((p1_images, p2_images))
        p1_pi = np.array(p1_pi)
        p2_pi = np.array(p2_pi)
        policy = np.vstack((p1_pi, p2_pi))
        labels = np.array(labels)

        self.model.train(data, policy, labels)
Code example #2
File: Game.py Project: Chaoste/hexgame-ml
    def start(self, firstPlayer):

        self.GameState = 0

        EventManager.notify("GameStarting")

        # move counter init
        self.moveCounter = 0

        # generate fresh state
        self.HexBoard = HexBoard(self.size[0], self.size[1])
        self.HexBoard.setReferenceToGame(self)

        # current player depending on decision
        self._currentPlayer = firstPlayer

        if self.mode == "ki":
            self.KI = MonteCarloTreeSearch(self, 2)  # TODO: ki_player_id == 2?
            # self.KI = HexKI(self.size[0], self.size[1])

        if self.mode == "inter" or self.mode == "machine":

            self.KI = []
            self.KI.append(HexKI(self.size[0], self.size[1]))
            self.KI.append(HexKI(self.size[0], self.size[1]))

            self._currentPlayerType = "ki"

        # if random number wanted, generate one
        if firstPlayer == 0:
            self.chooseFirst()

        EventManager.notify("GameStarted")
Code example #3
File: main.py Project: sandernordeide/AI-programming
def run_batch(G, M, M_decay, N, K, B, P, game_mode, verbose):
    wins = 0
    verbose_message = "\n"
    MCTS = MonteCarloTreeSearch()
    for i in tqdm(range(G)):
        starting_player = set_starting_player(P)
        if game_mode == 0:
            env = NIMBoard(N, K, starting_player)
        else:
            env = LedgeBoard(B, starting_player)
        verbose_message += "Initial state: {}\n".format(env.get_state()[1])
        MCTS.init_tree(env)
        iteration = 1
        simulations = M
        while not env.is_game_over():
            action = MCTS.search(env, simulations)  # find best move
            verbose_message += "{}: {}".format(iteration,
                                               env.print_move(action))
            env.move(action)
            iteration += 1
            simulations = math.ceil(simulations *
                                    M_decay)  # (optional) speed up for mcts
        if starting_player == env.player:
            wins += 1
        verbose_message += print_last_move(iteration, env)
    if verbose:
        print(verbose_message)
    print("Starting player won {}/{} ({}%)".format(wins, G, 100 * wins / G))
Code example #4
    def play(self):
        datas, node = [], TreeNode()
        mc = MonteCarloTreeSearch(self.net)
        move_count = 0

        while True:
            if move_count < TEMPTRIG:
                pi, next_node = mc.search(self.board, node, temperature=1)
            else:
                pi, next_node = mc.search(self.board, node)

            datas.append([self.board.gen_state(), pi, self.board.c_player])

            self.board.move(next_node.action)
            next_node.parent = None
            node = next_node

            if self.board.is_draw():
                reward = 0.
                break

            if self.board.is_game_over():
                reward = 1.
                break

            self.board.trigger()
            move_count += 1

        datas = np.asarray(datas)
        datas[:, 2][datas[:, 2] == self.board.c_player] = reward
        datas[:, 2][datas[:, 2] != self.board.c_player] = -reward

        return datas
Code example #5
    def play_game(self, game, training_data):
        """Loop for each self-play game.

        Runs MCTS for each game state and plays a move based on the MCTS output.
        Stops when the game is over and prints out a winner.

        Args:
            game: An object containing the game state.
            training_data: A list to store self play states, pis and vs.
        """
        mcts = MonteCarloTreeSearch(self.net)

        game_over = False
        value = 0
        self_play_data = []
        count = 0

        node = TreeNode()

        # Keep playing until the game is in a terminal state.
        while not game_over:
            # MCTS simulations to get the best child node.
            if count < CFG.temp_thresh:
                best_child, prob_vector = mcts.search(game, node,
                                                      CFG.temp_init)
            else:
                best_child, prob_vector = mcts.search(game, node,
                                                      CFG.temp_final)

            # Store state, prob and v for training.
            if best_child != None:
                self_play_data.append(
                    [deepcopy(game.state),
                     deepcopy(prob_vector), 0])

                action = best_child.action
                game.play_action(action)  # Play the child node's action.
                count += 1
                # print('Next player is', game.current_player)

                game_over, value = game.check_game_over(game.current_player)

                best_child.parent = None
                node = best_child  # Make the child node the root node.
            else:
                self_play_data.append(
                    [deepcopy(game.state),
                     deepcopy(prob_vector), 0])
                game.current_player *= -1
                # print('NO ACTION TAKEN, Next player is', game.current_player)

        # Update v as the value of the game result.
        print('FINAL SCORES ARE ', game.score)
        for game_state in self_play_data:
            value = -value
            game_state[2] = value
            self.augment_data(game_state, training_data, game.row, game.column)
Code example #6
 def __init__(self, model=None, thinking_depth=1024, config=dict()):
     if model is None:
         pvn = PolicyValueNet(None, config)
         self.mcts = MonteCarloTreeSearch(pvn)
     elif isinstance(model, str):
         self.load_model(model)
     else:
         self.mcts = MonteCarloTreeSearch(model)
     self.thinking_depth = thinking_depth
Code example #7
    def play(self):
        """Function to play a game vs the AI."""
        print("Start Human vs AI\n")

        mcts = MonteCarloTreeSearch(self.net)
        game = self.game.clone()  # Create a fresh clone for each game.
        game_over = False
        value = 0
        node = TreeNode()

        print("Enter your move in the form: row, column. Eg: 1,1")
        go_first = input("Do you want to go first: y/n?")

        if go_first.lower().strip() == 'y':
            print("You play as X")
            human_value = 1

            game.print_board()
        else:
            print("You play as O")
            human_value = -1

        # Keep playing until the game is in a terminal state.
        while not game_over:
            # MCTS simulations to get the best child node.
            # If player_to_eval is 1 play as the Human.
            # Else play as the AI.
            if game.current_player == human_value:
                action = input("Enter your move: ")
                if isinstance(action, str):
                    action = [int(n, 10) for n in action.split(",")]
                    action = (1, action[0], action[1])

                best_child = TreeNode()
                best_child.action = action
            else:
                best_child = mcts.search(game, node,
                                         CFG.temp_final)

            action = best_child.action
            game.play_action(action)  # Play the child node's action.

            game.print_board()

            game_over, value = game.check_game_over(game.current_player)

            best_child.parent = None
            node = best_child  # Make the child node the root node.

        if value == human_value * game.current_player:
            print("You won!")
        elif value == -human_value * game.current_player:
            print("You lost.")
        else:
            print("Draw Match")
        print("\n")
Code example #8
 def setUp(self):
     # tree init
     self.game = Game()
     self.tree = MonteCarloTreeSearch(game=self.game)
     # 1st level (complete)
     self.make_level(self.tree.root, 9)
     # 2nd level (complete)
     self.make_level(self.tree.root.children[3], 8)
     # 3rd level (incomplete)
     self.make_level(self.tree.root.children[3].children[4], 3)
Code example #9
File: server.py Project: hsperr/StackIt
 def __init__(self, iid: str, board: Board, ai: str, thinking_time: int):
     self.iid = iid
     self.board = board
     self.alphaBeta = AlphaBeta()
     self.mcts = MonteCarloTreeSearch()
     self.player_names = ["HUMAN", str(ai)]
     self.set_ai(ai)
     self.thinking_time = thinking_time
     self.max_depth = 30
     self.last_move = []
     self.last_move_score = []
Code example #10
    def start(self):
        """Main training loop."""
        for i in range(self.num_iterations):
            print("Iteration", i + 1)

            training_data = []  # list to store self play states, pis and vs

            for j in range(self.num_games):
                print("Start Training Self-Play Game", j + 1)
                game = self.game.clone()  # Create a fresh clone for each game.
                self.play_game(game, training_data)
                print(game.evaluate())
                self.scores.append(game.evaluate())

            # Save the current neural network model.
            self.net.save_model()

            # Load the recently saved model into the evaluator network.
            self.eval_net.load_model()

            # Train the network using self play values.
            self.net.train(training_data)

            # Initialize MonteCarloTreeSearch objects for both networks.
            current_mcts = MonteCarloTreeSearch(self.net)
            eval_mcts = MonteCarloTreeSearch(self.eval_net)
            ''' TO BE COMPLETED ! '''
            evaluator = Evaluate(current_mcts=current_mcts,
                                 eval_mcts=eval_mcts,
                                 game=self.game)
            wins, losses = evaluator.evaluate()
            #             wins, losses = 1, 1

            print("wins:", wins)
            print("losses:", losses)

            num_games = wins + losses

            if num_games == 0:
                win_rate = 0
            else:
                win_rate = wins / num_games

            print("win rate:", win_rate)

            if win_rate > self.eval_win_rate:
                # Save current model as the best model.
                print("New model saved as best model.")
                self.net.save_model("best_model")
            else:
                print("New model discarded and previous model loaded.")
                # Discard current model and use previous best model.
                self.net.load_model()
Code example #11
def run_game_batched(initial_game: Game,
                     games: int,
                     simulations: int,
                     verbose=False):
    mcts = MonteCarloTreeSearch()
    # Various stats for the batch
    p1_starts = 0
    p2_starts = 0
    p1_wins = 0
    p2_wins = 0
    win_count = 0

    for i in tqdm(range(games)):
        if verbose:
            print("Starting game:", i)
        game = copy.deepcopy(initial_game)
        mcts.clear()
        game.start()  # Select starting player
        if game.player1_starts:
            p1_starts += 1
        else:
            p2_starts += 1

        while not game.completed:
            # NOTE: simulations can decay over time to speed up, that's why they're not part of constructor
            move = mcts.search(game, simulations)
            game.apply_move(move)
            if verbose:
                game.print_move(move)

        if game.starting_player_won():
            win_count += 1

        if game.player1s_turn:
            p2_wins += 1
        else:
            p1_wins += 1

    print(
        f"Player 1 started {p1_starts} out of {games} games ({p1_starts * 100 / games}%)"
    )
    print(
        f"Player 2 started {p2_starts} out of {games} games ({p2_starts * 100 / games}%)"
    )
    print(
        f"Starting player won {win_count} out of {games} games ({win_count * 100 / games}%)"
    )
    print(
        f"Player 1 won {p1_wins} out of {games} games ({p1_wins * 100 / games}%)"
    )
    print(
        f"Player 2 won {p2_wins} out of {games} games ({p2_wins * 100 / games}%)"
    )
Code example #12
 def __init__(self, network_path, faktor, exponent):
     self.nnet = keras.models.load_model(network_path, compile=False)
     self.exponent = exponent
     self.faktor = faktor
     self.env = MillEnv()
     self.graphics = MillDisplayer(self.env)
     self.graphics.reloadEnv()
     self.root: State = State(np.zeros((1, 24)), 0, -self.env.isPlaying,
                              self.env)
     val = self.root.setValAndPriors(self.nnet)
     self.root.backpropagate(val)
     self.mcts = MonteCarloTreeSearch(self.root)
Code example #13
 def resetMonteCarlo(self):
     self.env.reset()
     self.graphics.millEnv.reset()
     self.root: State = State(np.zeros((1, 24)), 0, -self.env.isPlaying,
                              self.env)
     self.env.setFullState(self.mcts.root.state[0], self.mcts.root.state[1],
                           self.mcts.root.state[2], self.mcts.root.state[3],
                           self.mcts.root.state[4], self.mcts.root.state[5],
                           self.mcts.root.state[6], self.mcts.root.state[7],
                           self.mcts.root.state[8], self.mcts.root.state[9])
     self.mcts = MonteCarloTreeSearch(self.root)
     val = self.mcts.root.setValAndPriors(self.nnet)
     self.mcts.root.backpropagate(val)
     self.graphics.reloadEnv()
Code example #14
File: play.py Project: jbofill10/Magma-Chess
def game():

    board = ChessGame()
    board.show_board()
    model = CovNet()
    player_color = 1 if sys.argv[1] else 0
    while not board.is_game_over():
        if board.curr_turn() == player_color:
            move = input(
                f"Enter a move as {'White' if board.curr_turn() else 'Black'}: "
            )

            try:
                print(chr(27) + "[2J")
                board.make_move(move)
                board.show_board()
            except Exception as error:
                print("Unable to make move!")
                print(error)
        else:
            move = MonteCarloTreeSearch(board=board, model=model,
                                        depth=30).run_mcts(200)
            print(chr(27) + "[2J")
            board.make_move(move, format='uci')
            board.show_board()
Code example #15
    def play_game(self, game, training_data):
        """Loop for each self-play game.

        Runs MCTS for each game state and plays a move based on the MCTS output.
        Stops when the game is over and prints out a winner.

        Args:
            game: An object containing the game state.
            training_data: A list to store self play states, pis and vs.
        """
        mcts = MonteCarloTreeSearch(self.net)

        game_over = False
        value = 0
        self_play_data = []
        count = 0

        node = TreeNode()

        # Keep playing until the game is in a terminal state.
        while not game_over:
            # MCTS simulations to get the best child node.
            if count < self.temp_thresh:
                best_child = mcts.search(game, node, self.temp_init)
            else:
                best_child = mcts.search(game, node, self.temp_final)

            # Store state, prob and v for training.
            self_play_data.append([
                deepcopy(game.state['state']),
                deepcopy(best_child.parent.child_psas), 0
            ])

            action = best_child.action
            game.play_action(action)  # Play the child node's action.
            count += 1
            ''' TO BE COMPLETED !! '''
            game_over, value = game.check_game_over()

            best_child.parent = None
            node = best_child  # Make the child node the root node.

        # Update v as the value of the game result.
        for game_state in self_play_data:
            value = -value
            game_state[2] = value
            self.augment_data(game_state, training_data, game.row, game.column)
Code example #16
    def choose_move(self, game):
        if self.mcts is None:
            self.mcts = MonteCarloTreeSearch(game, RandomAI,
                                             game.get_player_num(),
                                             self.turnTime)
        else:
            self.mcts = self.mcts.childByState(str(game))
        self.mcts.search()

        # print number of visits (out of curiosity)
        if not game.quiet:
            print "Number of Visits: " + str(self.mcts.getVisits())
            print "Player: " + str(self.mcts.playerNum)

        stateMovePair = self.mcts.bestMove()
        self.mcts = self.mcts.childByState(stateMovePair[0])
        return stateMovePair[1]
Code example #17
File: train.py Project: rnovesteras/rubiks_cube
    def play_game(self, game, training_data):
        """Loop for each self-play game.

        Runs MCTS for each game state and plays a move based on the MCTS output.
        Stops when the game is over and prints out a winner.

        Args:
            game: An object containing the game state.
            training_data: A list to store self play states, pis and vs.
        """
        mcts = MonteCarloTreeSearch(self.net)

        game_over = False
        value = 0
        self_play_data = []
        count = 0

        node = TreeNode()

        # Keep playing until the game is in a terminal state.
        while not game_over:
            # MCTS simulations to get the best child node.
            if count < CFG.temp_thresh:
                best_child = mcts.search(game, node, CFG.temp_init)
            else:
                best_child = mcts.search(game, node, CFG.temp_final)

            # Store state, prob and v for training.
            self_play_data.append([deepcopy(game.state),
                                   deepcopy(best_child.parent.child_psas),
                                   0])

            action = best_child.action
            game.play_action(action)  # Play the child node's action.
            count += 1

            game_over, value = game.check_game_over(game.current_player)

            best_child.parent = None
            node = best_child  # Make the child node the root node.

        # Update v as the value of the game result.
        for game_state in self_play_data:
            value = -value
            game_state[2] = value
            self.augment_data(game_state, training_data, game.row, game.column)
Code example #18
File: launch_game.py Project: RomainSa/mcts
def main():
    """
    Run an interactive game with MCTS advice
    """
    logging.basicConfig(level=logging.CRITICAL)
    game = Game()
    # print init version of game
    game.show_board()
    while game.legal_plays():
        # run Monte Carlo Tree Search and show recommended move
        tree = MonteCarloTreeSearch(game=game)
        tree.search(max_iterations=10000, max_runtime=3, n_simulations=1)
        tree.show_tree(level=1)
        print("MCTS recommends: {}".format(tree.recommended_play()))
        # ask user for move to play and play it
        move = tuple([
            int(s) for s in input("Move to play (format: .,.) : ").split(',')
        ])
        game.play(move)
        print('You played:')
        game.show_board()
        # a random  move is selected for opponent
        if game.legal_plays():
            game.play()
        print('Opponent played:')
        game.show_board()
    if game.winner() is None:
        print("It's a tie! :|")
    elif game.winner() == game.players[0]:
        print("You win! :)")
    elif game.winner() == game.players[1]:
        print("You lose! :(")
    else:
        raise ValueError('Unknown game status')
Code example #19
class MCTS_Player(Player):
    def __init__(self, turnTime=30):
        self.turnTime = turnTime
        self.human = False
        self.mcts = None

    def choose_move(self, game):
        if self.mcts is None:
            self.mcts = MonteCarloTreeSearch(game, RandomAI,
                                             game.get_player_num(),
                                             self.turnTime)
        else:
            self.mcts = self.mcts.childByState(str(game))
        self.mcts.search()

        # print number of visits (out of curiosity)
        if not game.quiet:
            print "Number of Visits: " + str(self.mcts.getVisits())
            print "Player: " + str(self.mcts.playerNum)

        stateMovePair = self.mcts.bestMove()
        self.mcts = self.mcts.childByState(stateMovePair[0])
        return stateMovePair[1]

    def reset(self):
        self.mcts = None
Code example #20
def play():
    tree = MonteCarloTreeSearch()
    board = newBoard()
    print(board.toString())
    while True:
        rowColumn = input("enter in this format row,col: ")
        row, col = map(int, rowColumn.split(","))
        index = 3 * (row - 1) + (col - 1)
        if board.tuple[index] is not None:
            raise RuntimeError("Invalid move")
        board = board.makeMove(index)
        print(board.toString())
        if board.leaf:
            break
        # You can train at every round, or handle training only at the beginning.
        # We currently train as we play, running 2500 rollouts each round.
        # More rollouts make the AI stronger; dropping to 30 rollouts makes it far easier to beat.
        for _ in range(2500):
            tree.makeRollout(board)
        board = tree.choose(board)
        print(board.toString())
        if board.leaf:
            break
Code example #21
    def evaluate(self, result):
        self.net.eval()
        self.evl_net.eval()

        if random.randint(0, 1) == 1:
            players = {
                BLACK: (MonteCarloTreeSearch(self.net), "net"),
                WHITE: (MonteCarloTreeSearch(self.evl_net), "eval"),
            }
        else:
            players = {
                WHITE: (MonteCarloTreeSearch(self.net), "net"),
                BLACK: (MonteCarloTreeSearch(self.evl_net), "eval"),
            }
        node = TreeNode()

        while True:
            _, next_node = players[self.board.c_player][0].search(
                self.board, node)

            self.board.move(next_node.action)

            if self.board.is_draw():
                result[0] += 1
                return

            if self.board.is_game_over():
                if players[self.board.c_player][1] == "net":
                    result[1] += 1
                else:
                    result[2] += 1
                return

            self.board.trigger()

            next_node.parent = None
            node = next_node
Code example #22
File: play.py Project: xdcesc/torch_light
    def go(self):
        print("One rule:\r\n Move piece form 'x,y' \r\n eg 1,3\r\n")
        print("-" * 60)
        print("Ready Go")

        mc = MonteCarloTreeSearch(self.net, 1000)
        node = TreeNode()
        board = Board()

        while True:
            if board.c_player == BLACK:
                action = input(f"Your piece is 'O' and move: ")
                action = [int(n, 10) for n in action.split(",")]
                action = action[0] * board.size + action[1]
                next_node = TreeNode(action=action)
            else:
                _, next_node = mc.search(board, node)

            board.move(next_node.action)
            board.show()

            next_node.parent = None
            node = next_node

            if board.is_draw():
                print("-" * 28 + "Draw" + "-" * 28)
                return

            if board.is_game_over():
                if board.c_player == BLACK:
                    print("-" * 28 + "Win" + "-" * 28)
                else:
                    print("-" * 28 + "Loss" + "-" * 28)
                return

            board.trigger()
Code example #23
File: star.py Project: JackFurby/STAR
def update():
	"""Text interface for user"""
	print()
	for playerNum in range(len(game.players)):
		player = game.players[playerNum]
		print("Player " + str(playerNum + 1) + "	-	" + str(player.score))
	print()
	action = input("What do you want to do? Enter 'help' for more: ")

	if action == "\q":
		pygame.quit()
		sys.exit()
	elif action == "isAccepted":
		inputLetters = input("Enter a word to check: ")

		# Makes sure input is in lower case
		inputLetters = inputLetters.lower()

		if game.trie.hasWord(inputLetters):
			print("Yes")
		else:
			print("No")
	elif action == "findWords":
		inputLetters = input("Enter letters ('?' is a wildcard): ").lower()
		start = time.time()
		wordList = game.trie.wordSearch(list(inputLetters))
		wordList.sort(key=lambda tup: -tup[1])
		print(*wordList, sep='\n')
		end = time.time()
		print("Completed search in", end - start, 'seconds')
	elif action == "findWordsPrefix":
		prefixLetters = input("Enter prefix (in order): ").lower()
		inputLetters = input("Enter letters ('?' is a wildcard): ").lower()
		start = time.time()
		wordList = game.trie.prefix(list(inputLetters), prefixLetters)
		wordList.sort(key=lambda tup: -tup[1])
		print(*wordList, sep='\n')
		end = time.time()
		print("Completed search in", end - start, 'seconds')
	elif action == "findWordsSuffix":
		suffixLetters = input("Enter suffix (in order): ").lower()
		inputLetters = input("Enter letters ('?' is a wildcard): ").lower()
		start = time.time()
		wordList = game.trie.wordSearch(list(inputLetters), suffix=suffixLetters)
		wordList.sort(key=lambda tup: -tup[1])
		print(*wordList, sep='\n')
		end = time.time()
		print("Completed search in", end - start, 'seconds')
	elif action == "findWordsContains":
		suffixLetters = input("Enter string words must contain (in order): ").lower()
		inputLetters = input("Enter letters ('?' is a wildcard): ").lower()
		start = time.time()
		wordList = game.trie.contains(list(inputLetters), suffixLetters)
		wordList.sort(key=lambda tup: -tup[1])
		print(*wordList, sep='\n')
		end = time.time()
		print("Completed search in", end - start, 'seconds')
	elif action == "findMoves":
		player = game.getPlayer(numInput(input("Enter player number: ")) - 1)
		if player != False:
			start = time.time()
			wordList = game.board.possibleMoves(player.letters, game.trie)
			wordList.sort(key=lambda tup: -tup[1])
			print(*wordList, sep='\n')
			end = time.time()
			print("Completed search in", end - start, 'seconds')
		else:
			print("Player not created")
	elif action == "lookAhead":
		if len(game.players) > 0:
			start = time.time()

			updatedBoard = copy.deepcopy(game.board)
			updatedTiles = copy.deepcopy(game.tiles)
			currentPlayer = game.players[game.active]

			players = []
			newTiles, updatedRemainingTiles = updatedTiles.getProbableTiles(updatedBoard, 0, currentPlayer.letters)
			for i in game.players:
				# we know our tiles but not other player tiles
				if i is not currentPlayer:
					newTiles, updatedRemainingTiles = updatedTiles.getProbableTiles(updatedBoard, len(i.letters), currentPlayer.letters)
					players.append([newTiles, i.score])
				else:
					players.append([currentPlayer.letters, currentPlayer.score])

			print('players:', players)
			mcts = MonteCarloTreeSearch(updatedBoard, updatedTiles, players, game.active, game.trie, game.active, game.over, updatedRemainingTiles)
			bestMove = mcts.run(180)  # run for 3 minutes
			#bestMove = mcts.run(600)  # run for 10 minutes
			end = time.time()
			print("Completed search in", end - start, 'seconds')
			print("Best move is:", bestMove.state.moveMade)
			print("Node score:", bestMove.score)
			print("Node visits:", bestMove.visits)
			print("Node average:", bestMove.score / bestMove.visits)
		else:
			print("No players created")
	elif action == "board":
		game.board.printBoard()
	elif action == "letters":
		game.tiles.printLetters()
	elif action == "makePlayer":
		playerIndex = game.newPlayer()
		if playerIndex is not False:
			print("Player " + str(playerIndex + 1) + " created")
			# add 7 tiles to player
			game.players[playerIndex].takeLetters(game.tiles)
		else:
			print("Max player limit reached")
	elif action == "playerLetters":
		# Get the player index in array
		player = game.players[numInput(input("Enter player number: ")) - 1]
		if player != False:
			player.printLetters()
		else:
			print("Player not created")
	elif action == "playTurn":
		if len(game.players) == 0:
			print("No current players")
		else:
			print("Player " + str(game.active + 1) + " it's your turn")
			player = game.players[game.active]

			turn = True

			while turn:
				print()
				player.printLetters()
				print()

				playOption = input("Enter 1 to skip turn, 2 to swap some letters or 3 to place tile(s): ")
				# player selects what they want to do
				if playOption == '1':
					# Go skipped
					print("Go skipped")
					turn = False
				elif playOption == '2':
					# Player replaces 0 or more tiles
					player.printLetters()
					swapTiles = []  # list of tiles to swap
					stillEntering = True
					while stillEntering:
						selectedTile = numInput(input("Input a tile to replace (0 is the first tile, 6 is the last). Enter 7 to stop selection: "))
						if selectedTile == 7:  # stop swapping tiles. All tiles in swapTiles are replaced
							player.swapLetters(game, swapTiles)
							stillEntering = False
						elif selectedTile >= 0 and selectedTile <= 6:
							if player.letters[selectedTile] is None:
								print("Tile already selected")
							else:
								swapTiles.append(player.letters[selectedTile])
								player.letters[selectedTile] = None
								print("letters left: " + str(player.letters))
								print("letters removed: " + str(swapTiles))
								if len(swapTiles) == 7:  # swap all tiles
									player.takeLetters(game.tiles)
									game.tiles.returnTiles(swapTiles)
									stillEntering = False
						else:
							print("Input out of range")
					turn = False
				elif playOption == '3':
					# Player places 1 or more tiles
					word = []  # List of tiles to play with score (in order)
					playerBackup = copy.deepcopy(player.letters)  # backup of player tiles in case input is not accepted
					stillEntering = True
					while stillEntering:
						selectedTile = numInput(input("Input a tile to use in order (0 is the first tile, 6 is the last). Enter 7 to stop selection: "))
						if selectedTile == 7:  # Stop adding tiles to play
							print(word)
							stillEntering = False
						elif selectedTile >= 0 and selectedTile <= 6:  # Add tile to play
							if player.letters[selectedTile] is None:
								print("Tile already selected")
							else:
								if player.letters[selectedTile] == '?':  # If tile is a blank then provide character and add score of 0
									while True:
										char = input("input letter: ")
										if char.isalpha():
											break
										print("Please enter characters a-z only")
									word.append([char.lower(), 0])
								else:  # Add tile to word list
									word.append([player.letters[selectedTile], letterScore(player.letters[selectedTile])])
								player.letters[selectedTile] = None
								print("letters left: " + str(player.letters))
								print("letters used: " + str(word))
								if len(word) == 7:  # swap all tiles
									stillEntering = False
						else:
							print("Input out of range")

					if len(word) > 0:
						# Add tiles start position and direction
						x = numInput(input("Enter x of first tile (starting from 0 in top left): "))
						y = numInput(input("Enter y of first tile (starting from 0 in top left): "))

						while True:
							direction = input("enter direction of word (right/down): ")
							if direction == 'right' or direction == 'down':
								break
							print("Input not recognised")

						# Verify word placement is valid and play it if it is
						tilesAdded, score = game.board.addWord(word, x, y, direction, game.trie, player)
						if tilesAdded:
							# Word accepted. End turn and update score
							turn = False
							player.score = player.score + score
							print("Player " + str(game.active + 1) + " you scored " + str(score))
						else:
							# Word not accepted. Reset player and try again
							player.letters = playerBackup
							print("Input not accepted")
					else:
						print("No tiles selected. Turn skipped")
						turn = False

				else:
					print("input not recognised")
			# refill player letters
			player.takeLetters(game.tiles)

			# If the player still has an empty slot after refilling, there are no tiles left. Game ends
			for tile in player.letters:
				if tile is None:
					game.over = True

			if game.over:
				print("")
				print("=== Game Over ===")
				print("")
				for playerNum in range(len(game.players)):
					player = game.players[playerNum]
					print("Player " + str(playerNum + 1) + "	-	" + str(player.score))

				pygame.quit()
				sys.exit()

			# change player
			game.nextPlayer()
	elif action == "activePlayer":
		if len(game.players) == 0:
			print("No current players")
		else:
			print("Player " + str(game.active + 1) + " it's your turn")
	elif action == "probableTiles":
		tileProbabilities = game.tiles.probableTiles(game.board)
		for i in tileProbabilities:
			print(i)
	elif action == "probableTilesWithPlayer":
		player = game.getPlayer(numInput(input("Enter player number: ")) - 1)
		tileProbabilities = game.tiles.probableTilesWithPlayer(game.board, player)
		for i in tileProbabilities:
			print(i)
	elif action == "nextPlayerTiles":
		player = game.getPlayer(numInput(input("Enter player number: ")) - 1)
		probablePlayerTiles = game.tiles.nextProbablePlayer(game.board, player)
		print(probablePlayerTiles)


	elif action == "help":
		print("")
		print("=== STAR help ===")
		print("")
		print("\q			-	Exit STAR")
		print("isAccepted		-	Enter a single word to find out if it is accepted or not")
		print("findWords		-	Find all words you can make with a given set of characters")
		print("findWordsPrefix		-	Find all words you can make with a given set of characters + a prefix")
		print("findWordsSuffix		-	Find all words you can make with a given set of characters + a suffix")
		print("findWordsContains	-	Find all words you can make with a given set of characters + a set string")
		print("findMoves		-	Find all words you can make with a given player and the board")
		print("lookAhead		-	Return the best moves to make to win a game (Not working)")
		print("board			-	Display the current state of the board")
		print("letters			-	Display the current letters available to take")
		print("probableTiles		-	Print a list of tiles not on the board with the probability of picking that tile")
		print("probableTilesWithPlayer	-	Print a list of tiles not on the board or on players rack with the probability of picking that tile")
		print("nextPlayerTiles		- 	Print a list of tiles that are most probable for the next player to have")
		print("makePlayer		-	Makes a new player (max 4)")
		print("playerLetters		-	Prints the letters a given player has")
		print("playTurn		-	Make a move for the current players turn")
		print("activePlayer		-	Print the current active player")
		print("")
	else:
		print("Input not recognised")
Code example #24
from mcts import MonteCarloTreeSearch
import game
from board import Board
import numpy as np
import random
import model

import ipdb

model_name = 'models/m1'
m = model.NN(existing=model_name)
alpha_iters = 100
perfect_iters = 1000

alphazero = MonteCarloTreeSearch(simulations=alpha_iters, model=m)
perfect = MonteCarloTreeSearch(simulations=perfect_iters)
vanilla = MonteCarloTreeSearch(simulations=alpha_iters)


def test_alphazero(g):
    g.board.print_pretty()
    alpha_pi = np.array(alphazero.mcts(g))
    perfect_pi = np.array(perfect.mcts(g))
    vanilla_pi = np.array(vanilla.mcts(g))
    print("alpha_pi =\n {}".format(np.array(alpha_pi).reshape((3, 3))))
    print("p =\n {}".format(
        np.array(m.predict_policy(g.board)).reshape((3, 3))))
    print("v = {}".format(m.predict_score(g.board)))


print("Testing obvious draw")
Code example #25
class ModeratedGraphics(object):
    def __init__(self, network_path, faktor, exponent):
        self.nnet = keras.models.load_model(network_path, compile=False)
        self.exponent = exponent
        self.faktor = faktor
        self.env = MillEnv()
        self.graphics = MillDisplayer(self.env)
        self.graphics.reloadEnv()
        self.root: State = State(np.zeros((1, 24)), 0, -self.env.isPlaying,
                                 self.env)
        val = self.root.setValAndPriors(self.nnet)
        self.root.backpropagate(val)
        self.mcts = MonteCarloTreeSearch(self.root)

    def agentPlay(self):
        self.resetMonteCarlo()
        self.graphics.deactivateClick()
        finished = 0
        while finished == 0:
            self.graphics.reloadEnv()
            pi = self.mcts.search(self.nnet, self.faktor, self.exponent)
            if self.mcts.depth < 5:
                choices_pi = np.where(pi == -1, np.zeros(pi.shape), pi)
                pos = np.random.choice(np.arange(24), p=choices_pi)
            else:
                pos = np.argmax(pi)
            self.mcts.goToMoveNode(pos)
            self.env.setFullState(
                self.mcts.root.state[0], self.mcts.root.state[1],
                self.mcts.root.state[2], self.mcts.root.state[3],
                self.mcts.root.state[4], self.mcts.root.state[5],
                self.mcts.root.state[6], self.mcts.root.state[7],
                self.mcts.root.state[8], self.mcts.root.state[9])
            event, values = self.graphics.read(True)
            if self.eventHandler(event):
                return
            finished = self.env.isFinished()
        self.graphics.reloadEnv()
        if not finished == 2:
            self.graphics.setStatus("player " +
                                    self.graphics.getPlayerName(finished) +
                                    " won")
        else:
            self.graphics.setStatus("The game ended in a draw")

    def playersVSPlayer(self):
        self.graphics.activateClick()
        self.graphics.reset()
        finished = 0
        while finished == 0:
            event, values = self.graphics.read()
            if self.eventHandler(event):
                return
            self.graphics.reloadEnv()
            finished = self.env.isFinished()
        self.graphics.reloadEnv()
        if not finished == 2:
            self.graphics.setStatus("player " +
                                    self.graphics.getPlayerName(finished) +
                                    " won")
        else:
            self.graphics.setStatus("The game ended in a draw")
        self.graphics.deactivateClick()

    def playerVSAgent(self):
        self.graphics.activateClick()
        self.resetMonteCarlo()
        finished = 0
        while finished == 0:
            event, values = self.graphics.read(True)
            if self.eventHandler(event):
                return
            while len(self.graphics.last_move) > 0:
                self.mcts.goToMoveNode(self.graphics.last_move.pop())
                self.env.setFullState(
                    self.mcts.root.state[0], self.mcts.root.state[1],
                    self.mcts.root.state[2], self.mcts.root.state[3],
                    self.mcts.root.state[4], self.mcts.root.state[5],
                    self.mcts.root.state[6], self.mcts.root.state[7],
                    self.mcts.root.state[8], self.mcts.root.state[9])
            if self.env.isPlaying == 1:
                self.graphics.activateClick()
            else:
                if self.mcts.root.priors is None:
                    val = self.mcts.root.setValAndPriors(self.nnet)
                    self.mcts.root.backpropagate(val)
                    generate_empty_nodes(self.mcts.root)
                self.graphics.deactivateClick()
                pi = self.mcts.search(self.nnet, self.faktor, self.exponent)
                pos = np.argmax(pi)
                self.mcts.goToMoveNode(pos)
                self.env.setFullState(
                    self.mcts.root.state[0], self.mcts.root.state[1],
                    self.mcts.root.state[2], self.mcts.root.state[3],
                    self.mcts.root.state[4], self.mcts.root.state[5],
                    self.mcts.root.state[6], self.mcts.root.state[7],
                    self.mcts.root.state[8], self.mcts.root.state[9])
            self.graphics.reloadEnv()
            finished = self.env.isFinished()
        self.graphics.reloadEnv()
        if not finished == 2:
            self.graphics.setStatus("player " +
                                    self.graphics.getPlayerName(finished) +
                                    " won")
        else:
            self.graphics.setStatus("The game ended in a draw")
        self.graphics.deactivateClick()

    def playLoop(self):
        self.graphics.deactivateClick()
        self.playerVSAgent()
        finished = False
        while not finished:
            event, values = self.graphics.read()
            finished = self.eventHandler(event)
            if event != psGui.WIN_CLOSED and event != 'Close':
                finished = False

    def eventHandler(self, event) -> bool:
        if event == psGui.WIN_CLOSED or event == 'Close':  # if user closes window or clicks cancel
            self.graphics.close()
            return True
        elif event == "Agent vs. Agent":
            self.agentPlay()
            return True
        elif event == "Player vs. Player":
            self.playersVSPlayer()
            return True
        elif event == "Player vs. Agent":
            self.playerVSAgent()
            return True
        return False

    def resetMonteCarlo(self):
        self.env.reset()
        self.graphics.millEnv.reset()
        self.root: State = State(np.zeros((1, 24)), 0, -self.env.isPlaying,
                                 self.env)
        self.env.setFullState(self.mcts.root.state[0], self.mcts.root.state[1],
                              self.mcts.root.state[2], self.mcts.root.state[3],
                              self.mcts.root.state[4], self.mcts.root.state[5],
                              self.mcts.root.state[6], self.mcts.root.state[7],
                              self.mcts.root.state[8], self.mcts.root.state[9])
        self.mcts = MonteCarloTreeSearch(self.root)
        val = self.mcts.root.setValAndPriors(self.nnet)
        self.mcts.root.backpropagate(val)
        self.graphics.reloadEnv()
Code example #26
    def play(self):

        mcts = MonteCarloTreeSearch(self.net)
        game = deepcopy(self.game)
        game_over = False
        value = 0
        node = TreeNode()
        valid = 0
        # self.game.colorBoard()
        game.print_board()

        while not game_over:

            if game.current_player == self.human_player:
                valid = False
                while valid == False:
                    piece, refpt, rot, flip = self.get_input(game)
                    piece.create(0, (refpt[0], refpt[1]))

                    f = 'None'
                    if flip == 0:
                        f = 'None'
                    else:
                        f = 'h'

                    piece.flip(f)
                    piece.rotate(90 * rot)

                    valid = game.valid_move(piece.points, self.human_player)

                    if valid == False:
                        print('You selected an illegal move, please reselect')
                        # print('attempting', piece.points)
                        # print('corners are ', game.corners[self.human_player])

                    if piece.ID not in ['I5', 'I4', 'I3', 'I2']:
                        encoding = (refpt[0] * 14 +
                                    refpt[1]) * 91 + piece.shift + (
                                        rot // 90) * 2 + flip
                    else:
                        encoding = (refpt[0] * 14 +
                                    refpt[1]) * 91 + piece.shift + (
                                        rot // 90) * 1 + flip

                best_child = TreeNode()
                best_child.action = encoding
                print('CHOICE WAS MADE BY A HUMAN TO PLAY', piece.ID, '@',
                      refpt)

            else:
                best_child = mcts.search(game, node, CFG.temp_final)

            action = best_child.action
            game.play_action(action)

            game.print_board()
            # game.colorBoard()

            game_over, value = game.check_game_over(game.current_player)

            best_child.parent = None
            node = best_child

        if value == self.human_player * game.current_player:
            print("You won!")
        elif value == -self.human_player * game.current_player:
            print("You lost.")
        else:
            print("Draw Match")
Code example #27
 def load_model(self, name='latest'):
     model = PolicyValueNet('../model/{}.model'.format(name))
     self.mcts = MonteCarloTreeSearch(model)
Code example #28
class Player:
    def __init__(self, model=None, thinking_depth=1024, config=dict()):
        if model is None:
            pvn = PolicyValueNet(None, config)
            self.mcts = MonteCarloTreeSearch(pvn)
        elif isinstance(model, str):
            self.load_model(model)
        else:
            self.mcts = MonteCarloTreeSearch(model)
        self.thinking_depth = thinking_depth

    def self_play(self, show_board=True):
        self.new_game()
        while not self.is_game_end():
            if show_board:
                self.show()
            x, y = self.pick()
            self.set_move(x, y)
        if show_board:
            self.show()

    def vs_user(self, first_hand=False):
        self.new_game()
        while not self.is_game_end():
            self.show()
            if first_hand:
                cmd = self.get_user_input()
                if cmd == 'exit':
                    return
                else:
                    self.set_move(*cmd)
            else:
                x, y = self.pick()
                self.set_move(x, y)
            first_hand = not first_hand
        self.show()
        if self.get_game_state() == 'draw':
            print('Draw game.')
        elif first_hand:
            print('AI win.')
        else:
            print('User win.')

    def get_user_input(self):
        while True:
            try:
                cmd = input("move: ")
                if cmd.lower() == 'exit':
                    return cmd
                x, y = eval(cmd)
                return (x, y)
            except KeyboardInterrupt as e:
                raise e
            except:
                print("Bad input, try again.")

    def vs(self, opponent, show_board=True):
        current = self
        current.new_game()
        opponent.new_game()
        while not current.is_game_end():
            if show_board:
                current.show()
            x, y = current.pick()
            current.set_move(x, y)
            opponent.set_move(x, y)
            current, opponent = opponent, current
        if show_board:
            current.show()
        if current.get_game_state() == 'draw':
            return 'Draw'
        elif current != self:
            return 'Win'
        else:
            return 'Lose'

    def set_thinking_depth(self, depth):
        self.thinking_depth = depth

    def pick(self):
        x, y = self.mcts.pick(self.thinking_depth)
        return (x + 1, y + 1)

    def get_game_state(self):
        return self.mcts.get_game_state()

    def is_game_end(self):
        return self.mcts.is_game_end()

    def show(self):
        self.mcts.show()

    def new_game(self):
        self.mcts.new_game()

    def set_move(self, x, y):
        move = (x - 1, y - 1)
        self.mcts.set_move(move)

    def save(self, name, override=False):
        self.save_data(name, override)
        self.save_model(name, override)

    def save_data(self, name, override=False):
        self.mcts.save_data('../data/{}.data'.format(name), override)

    def save_model(self, name, override=False):
        self.mcts.save_model('../model/{}.model'.format(name), override)

    def load_data(self, name='latest', merge=True):
        self.mcts.load_data('../data/{}.data'.format(name), merge)

    def load_model(self, name='latest'):
        model = PolicyValueNet('../model/{}.model'.format(name))
        self.mcts = MonteCarloTreeSearch(model)

    def train(self, epochs=5, batch_size=128):
        if self.mcts.train_data[0].shape[0] == 0:
            raise RuntimeError('player: Data is unloaded.')
        self.mcts.train(epochs, batch_size)

    def clear_train_data(self):
        self.mcts.clear_train_data()

    def evolve(self, epochs=10):
        for i in range(epochs):
            if i > 0:
                self.clear_train_data()
            timestamp = time.strftime('%Y%m%d%H%M%S')
            self.save(timestamp)
            opponent = Player(timestamp)
            opponent.disable_guide()
            n = 16
            while True:
                print('Self playing: ', end='')
                for i in range(n):
                    # self.enable_guide()
                    opponent.set_thinking_depth(64)
                    self.set_thinking_depth(64)
                    self.self_play(show_board=False)
                    print(i + 1, end=' ')
                    if (i + 1) % n == 0:
                        print('')
                        timestamp = time.strftime('%Y%m%d%H%M%S')
                        self.save(timestamp, override=True)
                        print('')
                print('done.')
                self.train(4, 256)
                score = dict(Win=0, Lose=0, Draw=0)
                opponent.set_thinking_depth(128)
                self.set_thinking_depth(128)
                self.disable_guide()
                for i in range(10):
                    score[self.vs(opponent, show_board=True)] += 1
                for i in range(10):
                    result = opponent.vs(self, show_board=True)
                    result = \
                        'Win' if result == 'Lose' else \
                        'Lose' if result == 'Win' else \
                        'Draw'
                    score[result] += 1
                print(score)
                if score['Lose'] + 4 <= score['Win']:
                    print('Evolve succeeded.')
                    break
                else:
                    self.load_model(timestamp)
                    self.load_data(timestamp)
                n *= 2

    def benchmark(self, opponent):
        self.disable_guide()
        opponent.disable_guide()
        score = dict(Win=0, Lose=0, Draw=0)
        for i in range(10):
            score[self.vs(opponent, show_board=True)] += 1
        for i in range(10):
            result = opponent.vs(self, show_board=True)
            result = \
                'Win' if result == 'Lose' else \
                'Lose' if result == 'Win' else \
                'Draw'
            score[result] += 1
        print(score)

    def enable_guide(self):
        self.mcts.enable_guide()

    def disable_guide(self):
        self.mcts.disable_guide()
Code example #29
File: RL.py Project: sandernordeide/AI-programming
    M = 500
    save_interval = 50
    buffer_size = 2000
    batch_size = 1000

    # ANN parameters
    activation_functions = ["sigmoid", "tanh", "relu"]
    optimizers = ["Adagrad", "SGD", "RMSprop", "Adam"]
    alpha = 0.005  # learning rate
    H_dims = [128, 128, 64, 64]
    io_dim = board_size**2  # input and output layer sizes
    activation = activation_functions[2]
    optimizer = optimizers[3]
    epochs = 10

    ANN = ANN(io_dim, H_dims, alpha, optimizer, activation, epochs)
    MCTS = MonteCarloTreeSearch(ANN, c=1., eps=1, stoch_policy=True)
    env = HexGame(board_size)
    RL = RL(G, M, env, ANN, MCTS, save_interval, buffer_size, batch_size)

    # Run RL algorithm and plot results
    RL.run()
    RL.play_game()

    # Generate training cases
    #RL.generate_cases()

    # Plot model accuracies and losses
    #levels = np.arange(0, 251, 50)
    #RL.plot_level_accuracies(levels)
Code example #30
from mcts import MonteCarloTreeSearch
import game
import numpy as np
import random
import model

import ipdb

print("type the number of games each AI should play going first")
games = int(input())
print("number of iterations the vanilla mcts can search for")
vanilla_iters = int(input())

vanilla_score = 0
alphazero_score = 0
total = 0

alphazero = MonteCarloTreeSearch(simulations=vanilla_iters)
mcts = MonteCarloTreeSearch(simulations=vanilla_iters)

for i in range(2 * games):
    if i >= games:
        v = False
    else:
        v = True
    g = game.Game()
    while not g.game_over():
        if v:
            pi = mcts.mcts(g)
            best_prob = 0
            best_moves = []
            for i, prob in enumerate(pi):
                if prob > best_prob:
Code example #31
def train(model_class, env):
    '''
    Train a model of instance `model_class` on environment `env` (`GridDrivingEnv`).

    It runs the model for `max_episodes` times to collect experiences (`Transition`)
    and store it in the `ReplayBuffer`. It collects an experience by selecting an action
    using the `model.act` function and apply it to the environment, through `env.step`.
    After every episode, it will train the model for `train_steps` times using the
    `optimize` function.

    Output: `model`: the trained model.
    '''

    # Initialize model and target network
    model = model_class(env.world.tensor_space().shape, env.action_space.n).to(device)
    target = model_class(env.world.tensor_space().shape, env.action_space.n).to(device)
    target.load_state_dict(model.state_dict())
    target.eval()

    # Initialize replay buffer
    memory = ReplayBuffer()

    print(model)

    # Initialize rewards, losses, and optimizer
    rewards = []
    losses = []
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    numiters = 15
    explorationParam = 1.
    random_seed = 10
    mcts = MonteCarloTreeSearch(env=env, numiters=numiters, explorationParam=explorationParam, random_seed=random_seed)

    for episode in range(max_episodes):
        epsilon = compute_epsilon(episode)
        state = env.reset()
        episode_rewards = 0.0

        for t in range(t_max):
            # Model takes action
            state = GridWorldState(state, is_done=env.done)
            state_tensor = np.copy(env.world.tensor_state)

            action = mcts.buildTreeAndReturnBestAction(initialState=state)
            print(action)
            # done = env.step(state=deepcopy(state.state), action=action)[2]
            # action = torch.from_numpy(action).float().unsqueeze(0).to(device)

            # action = model.act(state, epsilon)

            # print(action)
            # env.render()
            # Apply the action to the environment
            next_state, reward, done, info = env.step(state=deepcopy(state.state), action=action)

            # Save transition to replay buffer
            next_state_tensor = np.copy(env.world.tensor_state)
            memory.push(Transition(state_tensor, [env.actions.index(action)], [reward], next_state_tensor, [done]))

            state = next_state
            episode_rewards += reward
            if done:
                print("episode done"+ str(episode_rewards))
                break
        rewards.append(episode_rewards)

        # Train the model if memory is sufficient
        if len(memory) > min_buffer:
            if np.mean(rewards[print_interval:]) < 0.001:
                print('Bad initialization. Please restart the training.')
                exit()
            for i in range(train_steps):
                loss = optimize(model, target, memory, optimizer)
                losses.append(loss.item())

        # Update target network every once in a while
        if episode % target_update == 0:
            target.load_state_dict(model.state_dict())

        if episode % print_interval == 0 and episode > 0:
            print(
                "[Episode {}]\tavg rewards : {:.3f},\tavg loss: : {:.6f},\tbuffer size : {},\tepsilon : {:.1f}%".format(
                    episode, np.mean(rewards[print_interval:]), np.mean(losses[print_interval * 10:]), len(memory),
                    epsilon * 100))
    return model