예제 #1
0
    def policy_evaluate(self):
        """Run a single MCTS move search from ``self.mol`` with the current network.

        Only used to monitor training progress. A Molecule environment is
        built from ``self.mol`` (which is also parsed with
        ``Chem.MolFromSmiles``, so it is expected to be a SMILES string), its
        ``init_qed`` is set to the QED score of that molecule, and an
        MCTSPlayer with 30 playouts is asked for one action.

        Returns:
            tuple: ``(moves, _S_P, _Qs)`` -- the 1st, 3rd and 4th values
            returned by ``MCTSPlayer.get_action``; the 2nd value is discarded.
        """
        mcts = MCTSPlayer(self.policy_value_net.policy_value,
                          c_puct=self.c_puct,
                          n_playout=30)

        env = Molecule(
            ["C", "O", "N"],
            init_mol=self.mol,
            allow_removal=True,
            allow_no_modification=False,
            allow_bonds_between_rings=False,
            allowed_ring_sizes=[5, 6],
            max_steps=10,
            target_fn=None,
            record_path=False,
        )
        env.initialize()
        start_mol = Chem.MolFromSmiles(self.mol)
        env.init_qed = QED.qed(start_mol)

        moves, _, s_p, q_values = mcts.get_action(
            env, temp=self.temp, return_prob=1, rand=False)
        return moves, s_p, q_values
예제 #2
0
def main(debug=False):
    """Serve FiveChess move requests over a ZeroMQ REP socket on port 5555.

    Each request is a UTF-8 JSON list of already-played moves; the first two
    items of every move are passed as a tuple to ``FiveChess.step_nocheck``.
    The reply is ``{"action": ..., "value": ...}`` as returned by
    ``MCTSPlayer.get_action(game, return_value=1)``, or ``{"error": "<msg>"}``
    when handling the request raised.

    The socket and context are released when the loop exits (e.g. on
    KeyboardInterrupt or a failed bind).

    Args:
        debug: currently unused; kept for interface compatibility.
    """
    model_file = os.path.join(curr_dir, "../model/best_model_15_5.pth")
    policy_value_net = PolicyValueNet(size, model_file=model_file)

    context = zmq.Context()
    socket = context.socket(zmq.REP)
    try:
        socket.bind("tcp://*:5555")
        print("Server start on 5555 port")
        while True:
            message = socket.recv()
            try:
                message = message.decode('utf-8')
                # Log before parsing so malformed requests are still visible.
                print("Received: %s" % message)
                actions = json.loads(message)

                start = datetime.now()
                # A fresh player per request, so no search tree is reused
                # between unrelated positions.
                mcts_player = MCTSPlayer(policy_value_net.policy_value_fn,
                                         c_puct=c_puct,
                                         n_playout=n_playout,
                                         is_selfplay=0)
                game = FiveChess(size=size, n_in_row=n_in_row)
                for act in actions:
                    game.step_nocheck((act[0], act[1]))

                action, value = mcts_player.get_action(game, return_value=1)
                result = {"action": action, "value": value}
                print(result)

                print('time used: {} sec'.format(
                    (datetime.now() - start).total_seconds()))
                socket.send_string(json.dumps(result, ensure_ascii=False))
            except Exception as e:
                traceback.print_exc()
                # A REP socket must answer every request before the next
                # recv(), so always send an error reply.
                socket.send_string(
                    json.dumps({"error": str(e)}, ensure_ascii=False))
    finally:
        socket.close(linger=0)
        context.term()
예제 #3
0
def run():
    """Play one Tetris game in which player 0 is driven by MCTS.

    The policy/value network is loaded from ``./model/model.pth`` next to
    this file. On player 0's turns the MCTSPlayer chooses the action; on any
    other turn the fixed action ``4`` is played (its meaning depends on the
    Agent's action encoding). The final board is printed; Ctrl-C prints
    'quit' and stops.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    model_file = os.path.join(base_dir, './model/', 'model.pth')

    try:
        agent = Agent()
        env = TetrominoEnv(agent.tetromino)
        # Policy/value network that guides the tree search.
        network = PolicyValueNet(10, 20, 5, model_file=model_file)
        ai_player = MCTSPlayer(network.policy_value_fn, c_puct=1, n_playout=64)

        while not agent.terminal:
            act = ai_player.get_action(agent) if agent.curr_player == 0 else 4
            agent.step(act, env)

        agent.print()
    except KeyboardInterrupt:
        print('quit')
예제 #4
0
파일: test.py 프로젝트: one-leaf/pytorch
def run():
    """Let an MCTS player play a full Tetris game, printing after every move.

    The CNN policy/value network is loaded from ``./model/model-cnn.pth``
    next to this file. The agent is configured with
    ``limit_piece_count = 0`` and ``limit_max_height = 10``. After each step
    the available actions and the board (``print2(True)``) are printed.
    Ctrl-C prints 'quit' and stops.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    model_file = os.path.join(base_dir, './model/', 'model-cnn.pth')

    try:
        agent = Agent()
        agent.limit_piece_count = 0
        agent.limit_max_height = 10
        # Policy/value network that guides the tree search.
        network = PolicyValueNet(10, 20, 5, model_file=model_file)
        ai_player = MCTSPlayer(network.policy_value_fn, c_puct=1, n_playout=64)

        while not agent.terminal:
            agent.step(ai_player.get_action(agent))
            print(agent.get_availables())
            agent.print2(True)
    except KeyboardInterrupt:
        print('quit')
예제 #5
0
class BoardCanvas(tk.Canvas):
    """Tkinter Canvas that draws a 9x9 board and runs a human-vs-MCTS game on it.

    The human plays by left-clicking an intersection (``gameLoop_human``) and
    the program answers through ``gameLoop_robot``. Intersection (row, col)
    is drawn at pixel ((row + 1) * 30, (col + 1) * 30): ``row`` selects the
    x coordinate and ``col`` the y coordinate.
    """

    # Pixel spacing between neighbouring grid lines; line i sits at (i + 1) * _CELL.
    _CELL = 30
    # Number of grid lines drawn in each direction by draw_gameBoard.
    _LINES = 9

    def __init__(self, master=None, height=0, width=0):
        tk.Canvas.__init__(self, master, height=height, width=width)
        self.draw_gameBoard()
        self.turn = BLACK
        self.undo = False
        self.depth = 2
        self.prev_exist = False
        self.prev_row = 0
        self.prev_col = 0

        self.initPlayers()

    def initPlayers(self):
        """Create a fresh 9x9, five-in-a-row Board plus the human and MCTS players.

        The human receives the board's first player id and the MCTS player
        the second; ``self.players`` maps each id back to its player object.
        """
        self.width = 9
        self.height = 9
        self.board = Board(width=self.width, height=self.height, n_in_row=5)
        self.mcts_player = MCTSPlayer(c_puct=5, n_playout=1000)
        self.human_player = HumanPlayer()

        self.start_player = 0  # 0 - human, 1 - mcts_player

        self.board.init_board(self.start_player)
        p1, p2 = self.board.players
        self.human_player.set_player_id(p1)
        self.mcts_player.set_player_id(p2)
        self.players = {p2: self.mcts_player, p1: self.human_player}
        self.board.show(self.human_player.playerId, self.mcts_player.playerId)

    def _center(self, row, col):
        """Return the pixel (x, y) of intersection (row, col)."""
        return (row + 1) * self._CELL, (col + 1) * self._CELL

    def _disc(self, row, col, radius, color):
        """Draw a filled circle of ``radius`` pixels centred on (row, col)."""
        x, y = self._center(row, col)
        self.create_oval(x - radius, y - radius, x + radius, y + radius, fill=color)

    def draw_gameBoard(self):
        """Plot the grid lines and the five star points."""
        near = self._CELL               # pixel position of the first line
        far = self._LINES * self._CELL  # pixel position of the last line

        # Vertical lines: x is fixed at the line position, y spans the board.
        for i in range(self._LINES):
            x = (i + 1) * self._CELL
            self.create_line(x, near, x, far)

        # Horizontal lines: y is fixed at the line position, x spans the board.
        for j in range(self._LINES):
            y = (j + 1) * self._CELL
            self.create_line(near, y, far, y)

        # Place a "star" on particular intersections.
        for row, col in ((2, 2), (6, 2), (4, 4), (2, 6), (6, 6)):
            self.draw_star(row, col)

    def draw_star(self, row, col):
        """Draw a small black "star" dot on intersection (row, col)."""
        self._disc(row, col, 2, GoBoardUtil.color_string(BLACK))

    def draw_stone(self, row, col):
        """Draw a stone of the current turn's colour, marked with a ring.

        The ring (in the opponent's colour) denotes the latest move.
        """
        own = GoBoardUtil.color_string(self.turn)
        other = GoBoardUtil.color_string(GoBoardUtil.opponent(self.turn))
        self._disc(row, col, 10, own)
        self._disc(row, col, 6, other)
        self._disc(row, col, 4, own)

    def draw_plain_stone(self, row, col):
        """Draw a plain stone of the current turn's colour on (row, col)."""
        self._disc(row, col, 10, GoBoardUtil.color_string(self.turn))

    def draw_prev_stone(self, row, col):
        """Draw a plain stone of the opponent's colour on (row, col)."""
        self._disc(row, col, 10,
                   GoBoardUtil.color_string(GoBoardUtil.opponent(self.turn)))

    def isValidClickPos(self, event, row, col):
        """Return truthy if the click is on intersection (row, col) and that move is legal.

        A click counts as "on" an intersection when it lies within 9 px of
        it. Intersections are 30 px apart, so at most one intersection can
        match; there is no need to search for the nearest one.
        """
        x, y = self._center(row, col)
        if math.hypot(event.x - x, event.y - y) >= 9:
            return False
        move = self.board.location_to_move([row, col])
        return self.board and self.board.valid_move(move)

    def _clicked_intersection(self, event):
        """Return (row, col) of the legal intersection under the click, or None."""
        for row in range(self.height):
            for col in range(self.width):
                if self.isValidClickPos(event, row, col):
                    return row, col
        return None

    def check_win(self):
        """End the game and unbind the mouse if the board reports it is over.

        On a win the banner names the colour in ``self.turn`` (the side that
        just moved); on a draw (winner == -1) it shows 'DRAW'.

        Returns:
            tuple: ``(end, winner)`` as returned by ``Board.game_end()``.
        """
        end, winner = self.board.game_end()
        if end:
            if winner != -1:
                message = GoBoardUtil.color_string(self.turn).upper() + " WINS"
                print("{} WINS".format(self.players[winner]))
                self.create_text(150, 320, text=message)
            else:
                print("DRAW")
                self.create_text(150, 320, text='DRAW')
            self.unbind(LEFTBUTTON)
        return end, winner

    def gameLoop_human(self, event, turn=False):
        """Handle one human click: place the stone, then let the program reply.

        Progress messages and a text rendering of the board are also printed
        to the terminal; they are informational only.

        Args:
            event: the tkinter mouse event of the click.
            turn: when truthy the human plays WHITE, otherwise BLACK.

        Returns:
            The winner id if the human's move ended the game, otherwise None.
        """
        self.turn = WHITE if turn else BLACK

        print('Your turn now...\n')
        clicked = self._clicked_intersection(event)
        if clicked is None:
            print('Invalid position.\n')
            print('Please re-select a point\n')
            self.bind(LEFTBUTTON, lambda event, arg=turn: self.gameLoop_human(event, arg))
            return None

        row, col = clicked
        self.draw_plain_stone(row, col)
        # Unbind so the user cannot click anywhere until the program has
        # placed its stone.
        self.unbind(LEFTBUTTON)

        move = self.board.location_to_move([row, col])
        self.board.do_move(move)
        self.board.show(self.human_player.playerId, self.mcts_player.playerId)
        print('\n')

        end, winner = self.check_win()
        if end:
            return winner
        self.update()
        self.gameLoop_robot(turn)

    def gameLoop_robot(self, turn=False, start=False):
        """Let the MCTS player move, draw it, and re-enable human input.

        Args:
            turn: same flag as in ``gameLoop_human``; the program plays BLACK
                when it is truthy and WHITE otherwise.
            start: forwarded unchanged to ``MCTSPlayer.get_action``.

        Returns:
            The winner id if this move ended the game, otherwise None.
        """
        self.turn = BLACK if turn else WHITE
        print('Program is thinking now...')
        move = self.mcts_player.get_action(self.board, start)
        self.board.do_move(move)
        row, col = self.board.move_to_location(move)
        coord = '%s%s' % (chr(ord('A') + row), col + 1)
        print('Program has moved to {}\n'.format(coord))
        self.draw_plain_stone(row, col)
        self.board.show(self.human_player.playerId, self.mcts_player.playerId)
        print('\n')
        # Bind only after the program has moved so the user can play next;
        # check_win unbinds again if this move ended the game.
        self.bind(LEFTBUTTON, lambda event, arg=turn: self.gameLoop_human(event, arg))

        end, winner = self.check_win()
        if end:
            return winner
예제 #6
0
    def collect_selfplay_data(self):
        """Collect self-play data for training.

        Plays up to ``max_game_num`` games. Each game uses an MCTSPlayer built
        on a freshly loaded PolicyValueNet (``model_file``). Running statistics
        are read from and written back to ``<data_dir>/result.json`` (via
        ``self.read_status_file``). Every object yielded by
        ``self.get_equi_data`` is pickled to its own file in ``data_wait_dir``.

        c_puct alternates between the two candidates stored in
        ``result["cpuct"]``. Once more than 100 games have accumulated in
        ``result["curr"]["agent100"]``, the candidate pair is shifted by 0.01
        toward the better-scoring one, and history lists are appended.
        """
        print("TRAIN Self Play starting ...")

        jsonfile = os.path.join(data_dir, "result.json")

        def save_status(status):
            # Close the file deterministically instead of relying on GC.
            with open(jsonfile, "w") as f:
                json.dump(status, f, ensure_ascii=False)

        # Only used for the board dimensions when printing the final boards.
        agent = Agent()

        max_game_num = 1
        agentcount = 0

        boards = []
        game_num = 0

        cpuct_first_flag = random.random() > 0.5

        # Keys of positions already played; try not to repeat positions.
        game_keys = []
        game_datas = []
        # Play games.
        for _ in count():
            start_time = time.time()
            game_num += 1
            print('start game :', game_num, 'time:',
                  datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))

            result = self.read_status_file(jsonfile)
            print("QVal:", result["QVal"])

            # The two c_puct candidates (keys are numeric strings). Sort them
            # numerically: a plain string sort would put "10.5" before "9.5".
            cpuct_list = sorted(list(result["cpuct"])[:2], key=float)

            print("cpuct:", result["cpuct"])

            cpuct = float(cpuct_list[0] if cpuct_first_flag else cpuct_list[1])
            cpuct_first_flag = not cpuct_first_flag

            print("c_puct:", cpuct, "n_playout:", self.n_playout)
            policy_value_net = PolicyValueNet(GAME_WIDTH,
                                              GAME_HEIGHT,
                                              GAME_ACTIONS_NUM,
                                              model_file=model_file)
            player = MCTSPlayer(policy_value_net.policy_value_fn,
                                c_puct=cpuct,
                                n_playout=self.n_playout)

            _data = {
                "steps": [],
                "shapes": [],
                "last_state": 0,
                "score": 0,
                "piece_count": 0
            }
            game = Agent(isRandomNextPiece=False)

            if game_num == 1 or game_num == max_game_num:
                game.show_mcts_process = True

            # Indices of the steps after which game.state != 0.
            piece_idx = []

            for i in count():
                _step = {"step": i}
                _step["state"] = game.current_state()
                _step["piece_count"] = game.piececount
                _step["shape"] = game.fallpiece["shape"]
                _step["piece_height"] = game.pieceheight

                action, move_probs = player.get_action(game,
                                                       temp=self.temp,
                                                       return_prob=1,
                                                       need_random=False)
                # After the first game, play a random move instead when this
                # position was already seen in an earlier game.
                if game_num != 1 and game.get_key() in game_keys:
                    print(game.steps, game.piececount,
                          game.fallpiece["shape"], game.piecesteps, "key:",
                          game.get_key(), "key_len:", len(game_keys))
                    action = random.choice(game.get_availables())

                _, reward = game.step(action)

                _step["key"] = game.get_key()
                # Multi-line clears are deliberately not rewarded more.
                _step["reward"] = 1 if reward > 0 else 0
                _step["action"] = action
                _step["move_probs"] = move_probs

                _data["shapes"].append(_step["shape"])
                _data["steps"].append(_step)

                # reward is the number of lines cleared by this step.
                if reward > 0:
                    result = self.read_status_file(jsonfile)
                    if result["curr"]["height"] == 0:
                        result["curr"]["height"] = game.pieceheight
                    else:
                        result["curr"]["height"] = round(
                            result["curr"]["height"] * 0.99 +
                            game.pieceheight * 0.01, 2)
                    result["shapes"][_step["shape"]] += reward

                    # First reward of the game: track at which piece it came.
                    if game.score == reward:
                        if result["first_reward"] == 0:
                            result["first_reward"] = game.piececount
                        else:
                            result["first_reward"] = result[
                                "first_reward"] * 0.99 + game.piececount * 0.01

                        # If the first reward came earlier than the running
                        # average, also partially reward the earlier pieces.
                        if game.piececount < result["first_reward"]:
                            for idx in piece_idx:
                                _data["steps"][idx]["reward"] = 0.5

                    save_status(result)
                    # max(i, 1) avoids a ZeroDivisionError if step 0 scores.
                    print("#" * 40, 'score:', game.score, 'height:', game.pieceheight,
                          'piece:', game.piececount, "shape:", game.fallpiece["shape"],
                          'step:', i, "step time:",
                          round((time.time() - start_time) / max(i, 1), 3), "#" * 40)

                # Record the step index where a piece was placed.
                if game.state != 0:
                    piece_idx.append(i)

                # The more pieces the better; stop at game over, or after a
                # line clear once game.pieceheight exceeds 8.
                if game.terminal or (reward > 0 and game.pieceheight > 8):
                    _game_last_reward = 0  # game.getNoEmptyCount()/200.
                    _data["reward"] = _game_last_reward
                    _data["score"] = game.score
                    _data["piece_count"] = game.piececount

                    # Update the running statistics.
                    game_reward = _game_last_reward + game.score

                    result = self.read_status_file(jsonfile)
                    if result["QVal"] == 0:
                        result["QVal"] = game_reward
                    else:
                        result["QVal"] = result[
                            "QVal"] * 0.999 + game_reward * 0.001
                    paytime = time.time() - start_time
                    steptime = paytime / max(game.steps, 1)
                    if result["time"]["agent_time"] == 0:
                        result["time"]["agent_time"] = paytime
                        result["time"]["step_time"] = steptime
                    else:
                        result["time"]["agent_time"] = round(
                            result["time"]["agent_time"] * 0.99 +
                            paytime * 0.01, 3)
                        d = game.steps / 10000.0
                        if d > 1: d = 0.99
                        result["time"]["step_time"] = round(
                            result["time"]["step_time"] * (1 - d) +
                            steptime * d, 3)

                    # Running score of the c_puct value used in this game.
                    if str(cpuct) in result["cpuct"]:
                        result["cpuct"][str(cpuct)] = result["cpuct"][str(
                            cpuct)] * 0.99 + game_reward * 0.01

                    if game_reward > result["best"]["reward"]:
                        result["best"]["reward"] = game_reward
                        result["best"]["pieces"] = game.piececount
                        result["best"]["score"] = game.score
                        result["best"]["agent"] = result["agent"] + agentcount

                    result["agent"] += 1
                    result["curr"]["reward"] += game.score
                    result["curr"]["pieces"] += game.piececount
                    result["curr"]["agent1000"] += 1
                    result["curr"]["agent100"] += 1
                    save_status(result)

                    game.print()
                    print(game_num, 'reward:', game.score, "Qval:",
                          game_reward, 'len:', i, "piececount:",
                          game.piececount, "time:",
                          time.time() - start_time)
                    print("pay:", time.time() - start_time, "s\n")
                    agentcount += 1

                    break

            for step in _data["steps"]:
                if step["key"] not in game_keys:
                    game_keys.append(step["key"])

            game_datas.append(_data)

            boards.append(game.board)

            # Stop once more than 10000 distinct positions were collected.
            if len(game_keys) > 10000: break

            # Stop once the maximum number of games has been played.
            if game_num >= max_game_num: break

        # Print the final boards side by side.
        from game import blank
        for y in range(agent.height):
            line = ""
            for b in boards:
                line += "| "
                for x in range(agent.width):
                    if b[x][y] == blank:
                        line += "  "
                    else:
                        line += "%s " % b[x][y]
            print(line)
        print((" " + " -" * agent.width + " ") * len(boards))

        # Give every step of a piece the (capped at 1) reward of that piece's
        # last step, walking each game backwards.
        for data in game_datas:
            piece_count = -1
            v = 0
            vlist = []
            for step in reversed(data["steps"]):
                if piece_count != step["piece_count"]:
                    piece_count = step["piece_count"]
                    v = min(step["reward"], 1)
                    vlist.append(v)
                step["reward"] = v
            vlist.reverse()
            print(vlist)

        # state, MCTS probabilities, per-step value, game score
        states, mcts_probs, values, score = [], [], [], []

        for data in game_datas:
            for step in data["steps"]:
                states.append(step["state"])
                mcts_probs.append(step["move_probs"])
                values.append(step["reward"])
                score.append(data["score"])

        assert len(states) > 0
        assert len(states) == len(values)
        assert len(states) == len(mcts_probs)

        print("TRAIN Self Play end. length:%s value sum:%s saving ..." %
              (len(states), sum(values)))

        # Save the self-play samples for the training data buffer.
        for obj in self.get_equi_data(states, mcts_probs, values, score):
            filename = "{}.pkl".format(uuid.uuid1())
            savefile = os.path.join(data_wait_dir, filename)
            with open(savefile, "wb") as f:
                pickle.dump(obj, f)

        result = self.read_status_file(jsonfile)
        if result["curr"]["agent100"] > 100:
            result["reward"].append(
                round(result["curr"]["reward"] / result["curr"]["agent1000"],
                      2))
            result["pieces"].append(
                round(result["curr"]["pieces"] / result["curr"]["agent1000"],
                      2))
            result["qvals"].append(round(result["QVal"], 2))
            result["height"].append(result["curr"]["height"])
            result["time"]["step_times"].append(result["time"]["step_time"])
            result["curr"]["agent100"] -= 100
            # Keep only the most recent 200 entries of each history list.
            for history in (result["reward"], result["pieces"],
                            result["qvals"], result["height"],
                            result["time"]["step_times"]):
                del history[:-200]

            # Shift the c_puct pair toward the better-scoring candidate.
            qval = result["QVal"]
            # c_puct controls how much the policy prior is trusted.
            if result["cpuct"][cpuct_list[0]] > result["cpuct"][cpuct_list[1]]:
                cpuct = round(float(cpuct_list[0]) - 0.01, 2)
                if cpuct <= 0.01:
                    result["cpuct"] = {"0.01": qval, "1.01": qval}
                else:
                    result["cpuct"] = {
                        str(cpuct): qval,
                        str(round(cpuct + 1, 2)): qval
                    }
            else:
                cpuct = round(float(cpuct_list[0]) + 0.01, 2)
                result["cpuct"] = {
                    str(cpuct): qval,
                    str(round(cpuct + 1, 2)): qval
                }

            if max(result["reward"]) == result["reward"][-1]:
                newmodelfile = model_file + "_reward_" + str(
                    result["reward"][-1])
                if not os.path.exists(newmodelfile):
                    policy_value_net.save_model(newmodelfile)

        if result["curr"]["agent1000"] > 1000:
            result["curr"] = {
                "reward": 0,
                "pieces": 0,
                "agent1000": 0,
                "agent100": 0,
                "height": 0
            }

            newmodelfile = model_file + "_" + str(result["agent"])
            if not os.path.exists(newmodelfile):
                policy_value_net.save_model(newmodelfile)
        result["lastupdate"] = datetime.datetime.now().strftime(
            '%Y-%m-%d %H:%M:%S')
        save_status(result)