def __init__(self):
    """Initialise hyper-parameters, win counters and the policy-value net.

    Reads the module-level ``model_file`` and ``size``. The network is
    restored from ``model_file`` when that path exists; otherwise a new
    network is created.
    """
    # --- evaluation and sample bookkeeping ---
    # Number of simulated games used when evaluating the policy win rate.
    self.policy_evaluate_size = 100
    # Length of one training batch.
    self.batch_size = 512
    # Number of most recent self-play samples to keep. A game averages about
    # 400~600 samples, so this covers roughly the latest 1000 games.
    self.max_keep_size = 500000

    # --- training parameters ---
    self.learn_rate = 1e-4
    # Adaptive learning-rate multiplier driven by the KL divergence.
    self.lr_multiplier = 1.0
    # Probability scaling: 0.01 for real prediction, 1 for training.
    self.temp = 1
    # Number of simulations per move.
    self.n_playout = 1000
    # Number of self-play games per iteration.
    self.play_batch_size = 1
    # Number of repeated training passes (5 is recommended).
    self.epochs = 1
    # Target KL value for the policy-value network.
    self.kl_targ = 0.02
    # Pure-MCTS simulation count, used to evaluate the policy model: the
    # number of random moves pure MCTS uses to build its initial tree.
    self.pure_mcts_playout_num = 1000
    # MCTS child weight; tunes exploration vs. optimism (default 5).
    self.c_puct = 4

    # --- win/loss tallies, each stored as [wins, losses] ---
    self.mcts_win = [0, 0]  # against pure MCTS
    self.best_win = [0, 0]  # against the historical best model

    # Restore a trained network when a checkpoint exists, else start fresh.
    restore_kwargs = {"model_file": model_file} if os.path.exists(model_file) else {}
    self.policy_value_net = PolicyValueNet(size, **restore_kwargs)
    self.best_policy_value_net = None

    # Wins recorded per historical best model; the more a model wins, the
    # more it should keep being played against.
    self.best_model_files_win = {}
def __init__(self):
    """Initialise hyper-parameters, the policy-value net and the dataset.

    Reads the module-level ``model_file``, ``size`` and ``data_dir``. The
    network is restored from ``model_file`` when that path exists; otherwise
    a new network is created.
    """
    # --- evaluation and sample bookkeeping ---
    # Number of simulated games used when evaluating the policy win rate.
    self.policy_evaluate_size = 20
    # Length of one training batch.
    self.batch_size = 256
    # Number of most recent self-play samples to keep. A game averages about
    # 400~600 samples, so this covers roughly the latest 1000 games.
    self.max_keep_size = 500000

    # --- training parameters ---
    self.learn_rate = 1e-5
    # Adaptive learning-rate multiplier driven by the KL divergence.
    self.lr_multiplier = 1.0
    # Probability scaling: 0.01 for real prediction, 1 for training.
    self.temp = 1
    # Number of simulations per move.
    self.n_playout = 600
    # Number of self-play games per iteration.
    self.play_batch_size = 1
    # Number of repeated training passes (5 is recommended).
    self.epochs = 1
    # Target KL value for the policy-value network.
    self.kl_targ = 0.02
    # Pure-MCTS simulation count, used to evaluate the policy model: the
    # number of random moves pure MCTS uses to build its initial tree.
    self.pure_mcts_playout_num = 4000
    # MCTS child weight; tunes exploration vs. optimism (default 5).
    self.c_puct = 4

    # Restore a trained network when a checkpoint exists, else start fresh.
    restore_kwargs = {"model_file": model_file} if os.path.exists(model_file) else {}
    self.policy_value_net = PolicyValueNet(size, **restore_kwargs)

    # Load the stored samples and report how many were found.
    print("start data loader")
    self.dataset = Dataset(data_dir, self.max_keep_size)
    print("dataset len:",len(self.dataset),"index:",self.dataset.index)
    print("end data loader")
def testPlay():
    """Restore the latest checkpoint and play ``TEST_EPISODE`` self-play games.

    Prints the winner and the final board state of every game, then the
    number of games won by player 1 and player 2.

    Raises:
        FileNotFoundError: if the ``model`` directory holds no checkpoint.
            (Previously this surfaced as an ``AttributeError`` on ``None``.)
    """
    print('Test Mode')
    sess = tf.Session()
    try:
        brain = PolicyValueNet(sess, HEIGHT, WIDTH)
        game = Sixmok(WIDTH, HEIGHT, brain)
        saver = tf.train.Saver()

        # get_checkpoint_state returns None when no checkpoint file exists.
        ckpt = tf.train.get_checkpoint_state('model')
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise FileNotFoundError("no checkpoint found in directory 'model'")
        saver.restore(sess, ckpt.model_checkpoint_path)

        one = 0
        two = 0
        for episode in range(TEST_EPISODE):
            game.reset()
            winner, turns, states, actions, winners = game.runSelfPlay()
            print('%d play : winner is %d' %(episode+1, winner))
            # Show the last recorded state of the finished game.
            print(np.array(states[-1]))
            if winner == 1:
                one += 1
            elif winner == 2:
                two += 1
        print("player 1 win : ", one)
        print("player 2 win : ", two)
    finally:
        # Release the TensorFlow session even when a game or restore fails.
        sess.close()
def run():
    """Play one game by always taking the network's highest-probability action.

    Loads the model from ``./model/vit/model.pth`` next to this file, then
    steps the agent until it is terminal, redrawing the board after each
    move. Ctrl-C ends the run with a ``quit`` message.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    model_file = os.path.join(here, './model/', 'vit/model.pth')
    try:
        agent = Agent()
        # Policy-value network that scores the available actions.
        net_policy = PolicyValueNet(10, 20, 5, model_file=model_file)
        while True:
            if agent.terminal:
                break
            act_probs, v = net_policy.policy_value_fn(agent)
            # Greedy pick: the first action whose probability beats every
            # earlier one; action 0 is used when none is above zero.
            best_action, best_prob = 0, 0
            for candidate, prob in act_probs:
                if prob > best_prob:
                    best_action, best_prob = candidate, prob
            agent.step(best_action)
            os.system("cls")
            print(act_probs, v)
            agent.print2()
            time.sleep(0.1)
    except KeyboardInterrupt:
        print('quit')
def run():
    """Run one human-vs-AI game on a 15x15 board, five in a row to win.

    The AI is an MCTS player guided by the policy-value network loaded from
    ``./model/model_<size>_<n_in_row>.pth`` next to this file, and it moves
    first. Ctrl-C ends the run with a ``quit`` message.
    """
    size = 15      # board size
    n_in_row = 5   # stones in a row needed to win
    here = os.path.dirname(os.path.abspath(__file__))
    model_file = os.path.join(here, './model/', 'model_%s_%s.pth' % (size, n_in_row))
    try:
        agent = Agent(size=size, n_in_row=n_in_row)

        # AI side: MCTS guided by the policy-value network.
        net_policy = PolicyValueNet(size, model_file=model_file)
        ai_player = MCTSPlayer(net_policy.policy_value_fn,
                               c_puct=4,
                               n_playout=500,
                               is_selfplay=0)

        # Human side.
        human_player = Human(agent, is_show=1)

        # start_player=0 lets the AI move first.
        agent.start_play(ai_player, human_player, start_player=0)
        agent.game.print()
        agent.env.close()
    except KeyboardInterrupt:
        print('quit')
def main(debug=False):
    """Serve MCTS move predictions over a ZeroMQ REP socket on port 5555.

    Each request is a UTF-8 JSON list of moves; the first two items of every
    move are replayed onto a fresh board. The reply is a JSON object with
    ``action`` and ``value``, or ``{"error": ...}`` when handling fails.
    The loop never returns.

    Args:
        debug: accepted for compatibility; not used by this function.
    """
    model_file = os.path.join(curr_dir, "../model/best_model_15_5.pth")
    policy_value_net = PolicyValueNet(size, model_file=model_file)

    context = zmq.Context()
    socket = context.socket(zmq.REP)
    socket.bind("tcp://*:5555")
    print("Server start on 5555 port")

    while True:
        raw = socket.recv()
        try:
            text = raw.decode('utf-8')
            moves = json.loads(text)
            print("Received: %s" % text)
            started = datetime.now()

            # A fresh MCTS player is built for every request.
            mcts_player = MCTSPlayer(policy_value_net.policy_value_fn,
                                     c_puct=c_puct,
                                     n_playout=n_playout,
                                     is_selfplay=0)

            # Rebuild the position by replaying the received moves.
            game = FiveChess(size=size, n_in_row=n_in_row)
            for move in moves:
                game.step_nocheck((move[0], move[1]))

            action, value = mcts_player.get_action(game, return_value=1)
            result = dict(action=action, value=value)
            print(result)
            elapsed = (datetime.now() - started).total_seconds()
            print('time used: {} sec'.format(elapsed))
            socket.send_string(json.dumps(result, ensure_ascii=False))
        except Exception as e:
            # A REP socket must answer every request, so report the failure.
            traceback.print_exc()
            socket.send_string(json.dumps({"error":str(e)}, ensure_ascii=False))
def run():
    """Play one game in a ``TetrominoEnv`` with an MCTS player on turn 0.

    Loads the model from ``./model/model.pth`` next to this file. While the
    agent is not terminal, player 0 asks MCTS for its action and any other
    player always takes action 4. Ctrl-C ends the run with a ``quit``
    message.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    model_file = os.path.join(here, './model/', 'model.pth')
    try:
        agent = Agent()
        env = TetrominoEnv(agent.tetromino)
        # Policy-value network that guides the tree search.
        net_policy = PolicyValueNet(10, 20, 5, model_file=model_file)
        mcts_ai_player = MCTSPlayer(net_policy.policy_value_fn,
                                    c_puct=1,
                                    n_playout=64)
        while True:
            if agent.terminal:
                break
            if agent.curr_player == 0:
                chosen = mcts_ai_player.get_action(agent)
            else:
                chosen = 4
            agent.step(chosen, env)
            # NOTE(review): printing once per step is reconstructed from a
            # whitespace-collapsed source; confirm it belongs inside the loop.
            agent.print()
    except KeyboardInterrupt:
        print('quit')
def __init__(self,args,share_model,opti,board_max,param,is_selfplay=True):
    """Initialise hyper-parameters, the policy-value net and its MCTS.

    Args:
        args: object providing ``batch_size``, ``discount``, ``epsilon``,
            ``action_space``, ``hidden_size`` and ``state_space``.
        share_model: not used by this constructor.
        opti: not used by this constructor.
        board_max: passed twice to ``PolicyValueNet`` as its first two
            positional arguments.
        param: forwarded to ``PolicyValueNet`` as ``net_params``.
        is_selfplay: stored as ``self._is_selfplay``.
    """
    super().__init__()
    self._is_selfplay=is_selfplay

    # --- search settings ---
    self.n_playout = 100  # num of simulations for each move
    self.c_puct = 5
    self.temp = 1.0  # the temperature param

    # --- optimisation settings ---
    self.learn_rate = 5e-3
    self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
    # NOTE: this default mini-batch size is replaced by args.batch_size below.
    self.batch_size = 32
    self.epochs = 5  # num of train_steps for each update
    self.kl_targ = 0.025

    # --- schedule and bookkeeping ---
    self.play_batch_size = 1
    self.check_freq = 50
    self.game_batch_num = 1500
    self.best_win_ratio = 0.0
    # num of simulations used for the pure mcts, which is used as the
    # opponent to evaluate the trained policy
    self.pure_mcts_playout_num = 1000

    net = PolicyValueNet(board_max,board_max,net_params = param)
    self.policy_value_net = net
    self.mcts = MCTS(self.policy_value_net.policy_value_fn, self.c_puct, self.n_playout)

    # Copy the externally supplied settings in their original order.
    for field in ("batch_size", "discount", "epsilon",
                  "action_space", "hidden_size", "state_space"):
        setattr(self, field, getattr(args, field))
def collect_selfplay_data(self, i):
    """Play one self-play game and store its samples for training.

    On even ``i`` the current network plays itself. On odd ``i`` it plays
    an opponent that uses ``c_puct + 0.5`` and the best model (the current
    network when ``best_model_file`` does not exist); the result of that
    game is tallied in ``self.c_puct_win`` as ``[current wins, opponent
    wins]``. Only decisive results are tallied, which matches how
    ``policy_evaluate`` compares the winner against each player.

    Args:
        i: index of the training round; its parity selects the game type.

    Returns:
        None. Samples are handed to ``self.save_wait_data``. The game is
        discarded when the agent returns no winner or no play data.
    """
    # Self-play driven by Monte Carlo tree search.
    logging.info("TRAIN Self Play starting ...")
    agent = Agent(size, n_in_row, is_shown=0)

    # MCTS player guided by the current policy-value network.
    mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                             c_puct=self.c_puct,
                             n_playout=self.n_playout,
                             is_selfplay=1)
    mcts_player.mcts._limit_max_var = False

    opponent = None
    if i % 2 != 0:
        if os.path.exists(best_model_file):
            best_policy_value_net = PolicyValueNet(size, model_file=best_model_file)
        else:
            best_policy_value_net = self.policy_value_net
        opponent = MCTSPlayer(best_policy_value_net.policy_value_fn,
                              c_puct=self.c_puct + 0.5,
                              n_playout=self.n_playout,
                              is_selfplay=1)
        opponent.mcts._limit_max_var = False

    # Play the game.
    winner, play_data = agent.start_self_play(mcts_player, opponent, temp=self.temp)
    agent.game.print()
    if winner is None or play_data is None:
        print("give up this agent")
        return

    if opponent is not None:
        if winner == mcts_player.player:
            self.c_puct_win[0] += 1
        elif winner == opponent.player:
            self.c_puct_win[1] += 1

    # Augment the samples with the equivalent (flipped) board positions.
    play_data = self.get_equi_data(list(play_data))
    logging.info("Self Play end. length:%s saving ..." % len(play_data))
    logging.info("c_puct:{}/{} = {}/{}".format(self.c_puct, self.c_puct + 0.5,
                                               self.c_puct_win[0],
                                               self.c_puct_win[1]))
    # Persist the training samples.
    for obj in play_data:
        self.save_wait_data(obj)
def run(self):
    """Run the evaluation loop that promotes the current model to best model.

    Every round plays one evaluation game via ``policy_evaluate``; with
    probability 0.2 it first plays a game against a historical best model.
    A verdict is reached every ``policy_evaluate_size`` rounds, or as soon
    as the losses exceed what a 0.6 win rate allows. If the win rate
    reaches 0.6, the previous best model file is renamed with a timestamp
    suffix and the current model is saved in its place. Afterwards the
    tally is reset and the current network is reloaded from ``model_file``.

    Ctrl-C stops the loop and logs ``quit``.
    """
    # Win rate the current model needs in order to replace the best model.
    rate_of_winning = 0.6
    try:
        # Train on 100000 games first.
        for i in range(100000):
            logging.info(
                "TRAIN Batch:{} starting, Size:{}, n_in_row:{}".format(
                    i, size, n_in_row))

            # With probability 0.2, insert a game against a historical best
            # model and use it as extra samples.
            if random.random() > 0.8:
                state, mcts_porb, winner = self.collect_selfplay_data()
                # Dump one sample on the very first round. This stays inside
                # the branch that binds the three names.
                if i == 0:
                    for title, value in (("state", state),
                                         ("mcts_porb", mcts_porb),
                                         ("winner", winner)):
                        print("-" * 50, title, "-" * 50)
                        print(value)

            self.policy_evaluate()

            round_complete = (i + 1) % self.policy_evaluate_size == 0
            too_many_losses = self.best_win[1] > (
                self.policy_evaluate_size * (1 - rate_of_winning))
            if not (round_complete or too_many_losses):
                continue

            # Keep the current model as best model when its win rate is at
            # least the threshold.
            v = 1.0 * self.best_win[0] / self.policy_evaluate_size
            if v >= rate_of_winning:
                # Archive the old best model under its creation timestamp.
                t = os.path.getctime(best_model_file)
                timestr = time.strftime('%Y_%m_%d_%H_%M', time.localtime(t))
                os.rename(best_model_file, best_model_file + "." + timestr)
                self.policy_value_net.save_model(best_model_file)
                self.best_policy_value_net = None
                print("save curr modle to best model")
            else:
                # Report the threshold actually used (it used to say 0.65).
                print("curr:", v, "< %s, keep best model" % rate_of_winning)

            self.best_win = [0, 0]
            # NOTE(review): reloading only after a verdict is reconstructed
            # from a whitespace-collapsed source; confirm it is not meant
            # to run on every round.
            self.policy_value_net = PolicyValueNet(size, model_file=model_file)
    except KeyboardInterrupt:
        logging.info('quit')
def __init__(self, mol=None, init_model=None):
    """Initialise training settings, the policy-value net and the MCTS player.

    Args:
        mol: stored as ``self.mol``.
        init_model: when truthy, forwarded to ``PolicyValueNet`` as
            ``model_file`` so training starts from that network; otherwise
            a new network is created.
    """
    self.mol = mol
    self.output_smi = []
    self.output_qed = []

    # --- training params ---
    self.learn_rate = 2e-3
    self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
    self.temp = 1.0           # the temperature param
    self.n_playout = 30       # num of simulations for each move
    self.c_puct = 1
    self.buffer_size = 200
    self.batch_size = 200     # mini-batch size for training
    self.data_buffer = deque(maxlen=self.buffer_size)
    self.epochs = 50          # num of train_steps for each update
    self.kl_targ = 0.2
    self.check_freq = 5
    self.play_batch_size = 1
    self.game_batch_num = 15

    # --- network dimensions ---
    self.in_dim = 1024
    self.n_hidden_1 = 1024
    self.n_hidden_2 = 1024
    self.out_dim = 1

    # Start from an initial policy-value net when one is given, else from
    # a new one.
    extra = {"model_file": init_model} if init_model else {}
    self.policy_value_net = PolicyValueNet(self.in_dim,
                                           self.n_hidden_1,
                                           self.n_hidden_2,
                                           self.out_dim,
                                           **extra)

    self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value,
                                  c_puct=self.c_puct,
                                  n_playout=self.n_playout,
                                  is_selfplay=1)
def policy_evaluate(self):
    """Play one evaluation game: current model against the best model.

    When ``best_model_file`` does not exist yet, the current model is saved
    there and no game is played. Otherwise the result is added to
    ``self.best_win`` (``[wins, losses]`` of the current model) and the
    game's samples are stored through ``self.save_wait_data``.
    """
    # No best model yet: the current model becomes the best model.
    if not os.path.exists(best_model_file):
        self.policy_value_net.save_model(best_model_file)
        return

    # Load the best model once and reuse it on later calls.
    if self.best_policy_value_net is None:
        self.best_policy_value_net = PolicyValueNet(size, model_file=best_model_file)

    # Build both contenders the same way: current model first, best second.
    contenders = []
    for net in (self.policy_value_net, self.best_policy_value_net):
        contender = MCTSPlayer(net.policy_value_fn,
                               c_puct=self.c_puct,
                               n_playout=self.n_playout)
        contender.mcts._limit_max_var = False
        contenders.append(contender)
    current_mcts_player, best_mcts_player = contenders

    agent = Agent(size, n_in_row, is_shown=0)
    # The starting side alternates with the number of games tallied so far.
    first_mover = sum(self.best_win) % 2
    winner, play_data = agent.start_self_evaluate(current_mcts_player,
                                                  best_mcts_player,
                                                  temp=self.temp,
                                                  start_player=first_mover)

    outcomes = (
        (current_mcts_player, 0, "Curr Model Win!"),
        (best_mcts_player, 1, "Curr Model Lost!"),
    )
    for contender, slot, headline in outcomes:
        if winner == contender.player:
            self.best_win[slot] += 1
            print(headline, "win:", self.best_win[0], "lost", self.best_win[1])

    agent.game.print()

    # Store the training samples, augmented via get_equi_data.
    samples = self.get_equi_data(list(play_data))
    logging.info("Eval Play end. length:%s saving ..." % len(samples))
    for sample in samples:
        self.save_wait_data(sample)
def __init__(self):
    """Initialise training hyper-parameters and load the policy-value net.

    Reads the module-level ``model_file``, ``GAME_WIDTH``, ``GAME_HEIGHT``
    and ``GAME_ACTIONS_NUM``.
    """
    # --- schedule ---
    # Number of self-play games.
    self.game_batch_num = 1000000
    # Training starts once data_buffer holds more than this many entries.
    self.batch_size = 512
    # Number of self-play games per iteration.
    self.play_batch_size = 5
    # Cache size.
    self.buffer_size = 500000

    # --- training params ---
    self.learn_rate = 1e-4
    # Adaptive learning-rate multiplier driven by the KL divergence.
    self.lr_multiplier = 1.0
    # Training steps per policy-value net update (5 is recommended).
    self.epochs = 2
    # Target KL value for the policy-value network.
    self.kl_targ = 0.02
    self.best_win_ratio = 0.0

    # --- search params ---
    # MCTS probability parameter; larger means less certain. Use 1 for
    # training and 1e-3 for prediction.
    self.temp = 1
    # Number of simulated playouts recorded per move.
    self.n_playout = 500
    # MCTS child weight; tunes exploration vs. optimism (default 5).
    self.c_puct = 5

    net = PolicyValueNet(GAME_WIDTH,
                         GAME_HEIGHT,
                         GAME_ACTIONS_NUM,
                         model_file=model_file)
    self.policy_value_net = net
def train():
    """Train the network from self-play games for ``LEARNING_EPISODE`` episodes.

    Each episode plays one self-play game, counts the winner and feeds the
    game to ``brain.train``. Every ``OBSERVE`` episodes the win counts are
    printed, every 10 episodes the mean of the collected ``winners[0]``
    values is written to the ``logs`` summary, and every 100 episodes a
    checkpoint is saved to ``model/dqn.ckpt``.

    The summary writer and the session are closed on exit, including on
    errors, so buffered summaries are flushed.
    """
    print('Train Mode')
    sess = tf.Session()
    writer = None
    try:
        brain = PolicyValueNet(sess, HEIGHT, WIDTH)
        game = Sixmok(HEIGHT, WIDTH, brain)

        # Scalar summary fed with the values collected since the last write.
        rewards = tf.placeholder(tf.float32, [None])
        tf.summary.scalar('avg.reward/ep.', tf.reduce_mean(rewards))

        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        writer = tf.summary.FileWriter('logs', sess.graph)
        summary_merged = tf.summary.merge_all()

        total_reward_list = []
        one = 0
        two = 0
        for episode in range(LEARNING_EPISODE):
            game.reset()
            winner, turns, states, actions, winners = game.runSelfPlay()
            print('%d play : winner is %d' %(episode+1, winner))
            total_reward_list.append(winners[0])
            if winner == 1:
                one += 1
            elif winner == 2:
                two += 1

            if (episode+1) % OBSERVE == 0:
                print("player 1 : ", one)
                print("player 2 : ", two)

            # NOTE(review): training once per episode is reconstructed from
            # a whitespace-collapsed source; confirm it is not meant to run
            # only inside the OBSERVE branch above.
            brain.train(states, actions, winners, 0.01)

            if (episode+1) % 10 == 0:
                summary = sess.run(summary_merged,
                                   feed_dict={rewards: total_reward_list})
                writer.add_summary(summary, (episode+1))
                total_reward_list = []

            if (episode+1) % 100 == 0:
                saver.save(sess, 'model/dqn.ckpt', global_step = (episode+1))
    finally:
        if writer is not None:
            writer.close()
        sess.close()
def __init__(self, args, c_puct, n_playout, self_play, shared_lr_mul, shared_g_cnt):
    """Initialise the search settings, the policy-value net and its MCTS.

    Args:
        args: object providing ``cuda`` and ``board_max``.
        c_puct: stored as ``self.c_puct`` and passed to ``MCTS``.
        n_playout: num of simulations for each move.
        self_play: stored as ``self._is_selfplay``.
        shared_lr_mul: not used by this constructor.
        shared_g_cnt: not used by this constructor.
    """
    super().__init__()
    self._is_selfplay = self_play
    self.cuda = args.cuda

    # Search settings.
    self.c_puct = c_puct
    self.n_playout = n_playout

    # Schedule.
    self.play_batch_size = 1
    self.check_freq = 50
    self.game_batch_num = 1500

    self.policy_value_net = PolicyValueNet(
        args.board_max,
        args.board_max,
        use_gpu=args.cuda,
    )
    self.mcts = MCTS(
        self.policy_value_net.policy_value_fn,
        self.c_puct,
        self.n_playout,
    )
def main():
    """Run a 6x6, four-in-a-row test game chosen from the command line.

    ``--self_play 1`` lets a single AlphaGo player play against itself.
    Otherwise ``--player1`` and ``--player2`` each select ``human``,
    ``MCTS`` or ``AlphaGo``. Any other value is rejected by argparse
    instead of failing later with an unbound player. The saved weights in
    ``model/current_best.mdl`` are loaded whenever either side is
    ``AlphaGo``; previously they were loaded only when player1 was.

    Ctrl-C ends the game with a ``quit`` message.
    """
    player_kinds = ('human', 'MCTS', 'AlphaGo')
    parser = argparse.ArgumentParser(description='Test')
    parser.add_argument('--player1', default='AlphaGo', choices=player_kinds,
                        help='player1 type')
    parser.add_argument('--player2', default='MCTS', choices=player_kinds,
                        help='player2 type')
    parser.add_argument('--self_play', default=0, type=int,
                        help='1 means self play, 0 means not')
    args = parser.parse_args()

    n = 4
    width, height = 6, 6
    AlphaGoNet = PolicyValueNet(width, height)

    def make_player(kind, **alphago_kwargs):
        # Build the player for one side; the extra kwargs apply to AlphaGo only.
        if kind == 'human':
            return Human()
        if kind == 'MCTS':
            return MCTSPlayer()
        return AlphaGoPlayer(NN_fn=AlphaGoNet.policy_value_fn, **alphago_kwargs)

    try:
        board = Board(width=width, height=height, n_in_row=n)
        game = Game(board)
        if args.self_play:
            player = AlphaGoPlayer(NN_fn=AlphaGoNet.policy_value_fn)
            game.AlphaGo_self_play(player, is_shown=1)
        else:
            # Both AlphaGo players share one network, so load it once.
            if 'AlphaGo' in (args.player1, args.player2):
                AlphaGoNet.policy_value_net.load_state_dict(
                    torch.load('model/current_best.mdl'))
            player1 = make_player(args.player1, n_iteration=1000)
            player2 = make_player(args.player2)
            # set start_player=0 for human first
            game.start_play(player1, player2, start_player=0, is_shown=1)
    except KeyboardInterrupt:
        print('\n\rquit')
def run():
    """Let an MCTS player play one game, printing the board after each move.

    Loads the model from ``./model/model-cnn.pth`` next to this file. After
    every step the available actions are printed and the agent draws
    itself. Ctrl-C ends the run with a ``quit`` message.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    model_file = os.path.join(here, './model/', 'model-cnn.pth')
    try:
        agent = Agent()
        agent.limit_piece_count = 0
        agent.limit_max_height = 10

        # Policy-value network that guides the tree search.
        net_policy = PolicyValueNet(10, 20, 5, model_file=model_file)
        mcts_ai_player = MCTSPlayer(net_policy.policy_value_fn,
                                    c_puct=1,
                                    n_playout=64)
        while True:
            if agent.terminal:
                break
            agent.step(mcts_ai_player.get_action(agent))
            print(agent.get_availables())
            agent.print2(True)
    except KeyboardInterrupt:
        print('quit')
def run(self):
    """Run self-play rounds and adjust ``c_puct`` from the running tally.

    After each round, once ``self.c_puct_win`` holds at least 10 results,
    ``c_puct`` moves by 0.1 (down when the first counter leads, up when the
    second leads) and is clamped to ``[0.2, 10]``. The network is then
    reloaded from ``model_file`` and the tally is reset. Ctrl-C stops the
    loop and logs ``quit``.
    """
    try:
        for round_idx in range(10000):
            logging.info(
                "TRAIN Batch:{} starting, Size:{}, n_in_row:{}".format(
                    round_idx, size, n_in_row))
            self.collect_selfplay_data(round_idx)

            # Wait until enough head-to-head results have accumulated.
            if sum(self.c_puct_win) < 10:
                continue

            if self.c_puct_win[0] > self.c_puct_win[1]:
                self.c_puct = self.c_puct - 0.1
            elif self.c_puct_win[0] < self.c_puct_win[1]:
                self.c_puct = self.c_puct + 0.1
            # Clamp to the range [0.2, 10].
            self.c_puct = min(10, max(0.2, self.c_puct))

            self.policy_value_net = PolicyValueNet(size, model_file=model_file)
            self.c_puct_win = [0, 0]
    except KeyboardInterrupt:
        logging.info('quit')
def collect_selfplay_data(self):
    """Play the current model against a randomly drawn archived best model.

    Archived models are the files in ``model_dir`` whose names start with
    ``best_model_15_5.pth.``. Each one is drawn with a weight equal to its
    min-max normalised score in ``self.best_model_files_win``; all weights
    are equal when the scores are equal. The result updates
    ``self.mcts_win``, ``self.pure_mcts_playout_num`` and the score of the
    drawn model, and the game's samples are stored through
    ``self.save_wait_data``.

    Returns:
        The last sample of the augmented play data.

    Raises:
        ValueError: if ``model_dir`` contains no archived best model.
            (Previously an opaque ``min()`` error of the same type.)
    """
    # Self-play driven by Monte Carlo tree search.
    logging.info("TRAIN Self Play starting ...")
    agent = Agent(size, n_in_row, is_shown=0)

    # MCTS player guided by the current policy-value network.
    mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                             c_puct=self.c_puct,
                             n_playout=self.n_playout,
                             is_selfplay=0)

    # Collect the archived best models and their current scores.
    his_best_model_files = [
        name for name in os.listdir(model_dir)
        if name.startswith("best_model_15_5.pth.")
    ]
    if not his_best_model_files:
        raise ValueError(
            "no archived best model (best_model_15_5.pth.*) found in %s"
            % model_dir)
    his_best_model_weights = []
    for name in his_best_model_files:
        self.best_model_files_win.setdefault(name, 0)
        his_best_model_weights.append(self.best_model_files_win[name])

    # Min-max normalise the scores into sampling weights.
    weights_min = min(his_best_model_weights)
    weights_max = max(his_best_model_weights)
    if weights_max == weights_min:
        uniform = 1. / len(his_best_model_weights)
        his_best_model_weights = [uniform] * len(his_best_model_weights)
    else:
        spread = weights_max - weights_min
        his_best_model_weights = [
            1.0 * (w - weights_min) / spread for w in his_best_model_weights
        ]

    curr_best_model_file = random.choices(
        his_best_model_files, weights=his_best_model_weights)[0]
    print(self.best_model_files_win)
    print("loading", curr_best_model_file)
    curr_best_policy_value_net = PolicyValueNet(
        size, model_file=os.path.join(model_dir, curr_best_model_file))
    his_best_mcts_player = MCTSPlayer(
        curr_best_policy_value_net.policy_value_fn,
        c_puct=self.c_puct,
        n_playout=self.n_playout,
        is_selfplay=0)
    his_best_mcts_player.mcts._limit_max_var = False
    mcts_player.mcts._limit_max_var = False

    # Play the game.
    winner, play_data = agent.start_self_play(mcts_player,
                                              his_best_mcts_player,
                                              temp=self.temp)

    if winner == mcts_player.player:
        self.mcts_win[0] += 1
        # The archived model lost, so lower its score.
        self.best_model_files_win[curr_best_model_file] -= 1
        print("Curr Model Win!", "win:", self.mcts_win[0], "lost",
              self.mcts_win[1], "playout_num", self.pure_mcts_playout_num)
    if winner == his_best_mcts_player.player:
        self.mcts_win[1] += 1
        self.pure_mcts_playout_num = max(500, self.pure_mcts_playout_num - 100)
        # The archived model won, so raise its score.
        self.best_model_files_win[curr_best_model_file] += 1
        print("Curr Model Lost!", "win:", self.mcts_win[0], "lost",
              self.mcts_win[1], "playout_num", self.pure_mcts_playout_num)

    agent.game.print()

    # Augment the samples with the equivalent (flipped) board positions.
    play_data = self.get_equi_data(list(play_data))
    logging.info("Self Play end. length:%s saving ..." % len(play_data))
    # Persist the training samples.
    for obj in play_data:
        self.save_wait_data(obj)
    return play_data[-1]
def run(self):
    """Train the policy-value network for one pass over the stored dataset.

    Loads the dataset and the network, writes a ``.bak`` copy of the model,
    then calls ``policy_update`` on every batch. A KL divergence between the
    policy outputs on a held batch is used to scale ``self.lr_multiplier``.
    The model is saved to ``model_file`` and a win/lost summary derived from
    the accumulated ``totle_value`` is printed. Ctrl-C logs ``quit``.

    NOTE(review): the nesting below was reconstructed from a source whose
    whitespace had been collapsed. Please confirm that the KL block belongs
    under the ``(i + 1) % 100`` check and that the final save and summary
    run after the loop.
    """
    try:
        print("start data loader")
        self.dataset = Dataset(data_dir, self.buffer_size)
        print("end data loader")
        self.policy_value_net = PolicyValueNet(GAME_WIDTH,
                                               GAME_HEIGHT,
                                               GAME_ACTIONS_NUM,
                                               model_file=model_file)
        # Keep a backup of the model as loaded, before any update.
        self.policy_value_net.save_model(model_file + ".bak")
        # (A disabled warm-up used to collect self-play data here when the
        # dataset was less than half full.)
        dataset_len = len(self.dataset)
        training_loader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=1,
        )
        old_probs = None
        test_batch = None
        totle = 0
        for i, data in enumerate(training_loader):  # planned training batches
            if i == 0:
                # Show the first element of every field of the first batch.
                for obj in data:
                    print(obj[0])
            # Retrain the policy-value network on the self-play data.
            totle_value, b_loss, v_loss, p_loss, entropy = self.policy_update(
                data, self.epochs)
            totle = totle + totle_value
            if (i + 1) % 100 == 0:
                logging.info(("TRAIN idx {} : {} / {} b_loss:{:.5f}, v_loss:{:.5f}, p_loss:{:.5f}, entropy:{:.5f}")\
                    .format(i, i*self.batch_size, dataset_len, b_loss, v_loss, p_loss, entropy))
                # Adapt the learning rate. The two branches alternate: first
                # record the outputs on a batch, then compare against the
                # outputs on that same batch the next time round.
                if old_probs is None:
                    test_batch, _, _, _ = data
                    old_probs, old_value, old_bvalue = self.policy_value_net.policy_value(
                        test_batch)
                else:
                    new_probs, new_value, new_bvalue = self.policy_value_net.policy_value(
                        test_batch)
                    kl = np.mean(
                        np.sum(old_probs * (np.log(old_probs + 1e-10) -
                                            np.log(new_probs + 1e-10)),
                               axis=1))
                    # Reset so the next check records a fresh reference.
                    old_probs = None
                    if kl > self.kl_targ * 2:
                        self.lr_multiplier /= 1.5
                    elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
                        self.lr_multiplier *= 1.5
                    else:
                        # Multiplier unchanged: skip the log line below.
                        continue
                    logging.info("kl:{} lr_multiplier:{} lr:{}".format(
                        kl, self.lr_multiplier,
                        self.learn_rate * self.lr_multiplier))
        self.policy_value_net.save_model(model_file)
        # (Self-play data collection after training is disabled here.)
        # presumably totle is wins minus losses over dataset_len // 2 games
        # -- TODO confirm against policy_update
        win = (totle + dataset_len // 2) // 2
        print("win:", win, "lost:", dataset_len // 2 - win, "prop:",
              win / (dataset_len // 2))
    except KeyboardInterrupt:
        logging.info('quit')
def run(self):
    """Train the policy-value network for one pass over the stored dataset.

    Loads the training and test datasets and the network, writes a ``.bak``
    copy of the model, then calls ``policy_update`` on every batch. A KL
    divergence between the policy outputs on a test batch is used to scale
    ``self.lr_multiplier``. The model is saved to ``model_file`` and a
    win/lost summary is printed. Ctrl-C logs ``quit``.

    NOTE(review): the nesting below was reconstructed from a source whose
    whitespace had been collapsed. Please confirm that the KL block belongs
    under the ``i%10`` check and that the final save and summary run after
    the loop.
    """
    try:
        print("start data loader")
        self.dataset = Dataset(data_dir, self.buffer_size)
        newsample=self.dataset.newsample
        self.testdataset = TestDataset(data_dir, 10, newsample)
        print("end data loader")
        self.policy_value_net = PolicyValueNet(GAME_WIDTH, GAME_HEIGHT, GAME_ACTIONS_NUM, model_file=model_file)
        # Keep a backup of the model as loaded, before any update.
        self.policy_value_net.save_model(model_file+".bak")
        # (A disabled warm-up used to collect self-play data here when the
        # dataset was less than half full.)
        dataset_len = len(self.dataset)
        training_loader = torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True, num_workers=1,)
        testing_loader = torch.utils.data.DataLoader(self.testdataset, batch_size=self.batch_size, shuffle=True, num_workers=1,)
        old_probs = None
        test_batch = None
        totle = 0
        for i, data in enumerate(training_loader):  # planned training batches
            if i==0:
                # Show the first sample of the first batch.
                _batch, _probs, _win = data
                print(_batch[0][0])
                print(_batch[0][1])
                print(_probs[0])
                print(_win[0])
            # Retrain the policy-value network on the self-play data.
            totle_value, v_loss, p_loss, entropy = self.policy_update(data, self.epochs)
            totle = totle + totle_value
            if i%10 == 0:
                logging.info(("TRAIN idx {} : {} / {} v_loss:{:.5f}, p_loss:{:.5f}, entropy:{:.5f}")\
                    .format(i, i*self.batch_size, dataset_len, v_loss, p_loss, entropy))
                # Adapt the learning rate. The two branches alternate: first
                # record the outputs on a test batch, then compare against
                # the outputs on that same batch the next time round.
                if old_probs is None:
                    test_batch, test_probs, test_win = next(iter(testing_loader))
                    old_probs, old_value = self.policy_value_net.policy_value(test_batch)
                else:
                    new_probs, new_value = self.policy_value_net.policy_value(test_batch)
                    kl = np.mean(np.sum(old_probs * (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), axis=1))
                    if i % 50 == 0:
                        # Periodic dump: old, new and target outputs.
                        logging.info("probs[0] old:{}".format(old_probs[0]))
                        logging.info("probs[0] new:{}".format(new_probs[0]))
                        logging.info("probs[0] tg: {}".format(test_probs[0]))
                        maxlen = min(10, len(test_win))
                        for j in range(maxlen):
                            logging.info("value[0] old:{} new:{} tg:{}".format(old_value[j][0], new_value[j][0], test_win[j]))
                    # Reset so the next check records a fresh reference.
                    old_probs = None
                    if kl > self.kl_targ * 2:
                        self.lr_multiplier /= 1.5
                    elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
                        self.lr_multiplier *= 1.5
                    else:
                        # Multiplier unchanged: skip the log line below.
                        continue
                    logging.info("kl:{} lr_multiplier:{} lr:{}".format(kl, self.lr_multiplier, self.learn_rate*self.lr_multiplier))
        self.policy_value_net.save_model(model_file)
        # (Self-play data collection after training is disabled here.)
        # With x wins and y losses: x - y = totle and x + y = dataset_len.
        win = (totle+dataset_len)//2
        print("win:", win, "lost:", dataset_len-win, "prop:", win/dataset_len)
    except KeyboardInterrupt:
        logging.info('quit')
def run(self):
    """Train the policy-value network for one pass over the stored dataset.

    Loads the dataset, derives a test dataset from a shallow copy with
    ``test=True``, loads the network and writes a ``.bak`` copy of the
    model. It then calls ``policy_update`` on every batch. A KL divergence
    between the policy outputs on a test batch is used to scale
    ``self.lr_multiplier``. The model is saved to ``model_file``. Ctrl-C
    prints ``quit``.

    NOTE(review): the nesting below was reconstructed from a source whose
    whitespace had been collapsed. Please confirm that the KL block belongs
    under the ``i%10`` check and that the final save runs after the loop.
    """
    try:
        print("start data loader")
        self.dataset = Dataset(data_dir, self.buffer_size)
        # Shallow copy of the training dataset, switched to test mode.
        self.testdataset = copy.copy(self.dataset)
        self.testdataset.test=True
        print("end data loader")
        self.policy_value_net = PolicyValueNet(GAME_WIDTH, GAME_HEIGHT, GAME_ACTIONS_NUM, model_file=model_file)
        # Keep a backup of the model as loaded, before any update.
        self.policy_value_net.save_model(model_file+".bak")
        dataset_len = len(self.dataset)
        training_loader = torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)
        testing_loader = torch.utils.data.DataLoader(self.testdataset, batch_size=self.batch_size, shuffle=False,num_workers=0)
        old_probs = None
        test_batch = None
        for i, data in enumerate(training_loader):  # planned training batches
            if i==0:
                # Show the first sample of the first batch.
                _batch, _qvals, _actions = data
                for j in range(len(_batch[0])):
                    print(_batch[0][j])
                print(_qvals[0])
                print(_actions[0])
            # Retrain the policy-value network on the self-play data.
            _, v_loss, p_loss, entropy = self.policy_update(data, self.epochs)
            if i%10 == 0:
                print(("TRAIN idx {} : {} / {} v_loss:{:.5f}, p_loss:{:.5f}, entropy:{:.5f}")\
                    .format(i, i*self.batch_size, dataset_len, v_loss, p_loss, entropy))
                # Adapt the learning rate. The two branches alternate: first
                # record the outputs on a test batch, then compare against
                # the outputs on that same batch the next time round.
                if old_probs is None:
                    test_batch, test_probs, test_valus = next(iter(testing_loader))
                    old_probs, old_value = self.policy_value_net.policy_value(test_batch)
                else:
                    new_probs, new_value = self.policy_value_net.policy_value(test_batch)
                    kl = np.mean(np.sum(old_probs * (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), axis=1))
                    if i % 50 == 0:
                        # Periodic dump: old, new and target outputs.
                        print("probs[0] old:{}".format(old_probs[0]))
                        print("probs[0] new:{}".format(new_probs[0]))
                        print("probs[0] dst:{}".format(test_probs[0]))
                        maxlen = min(10, len(test_batch))
                        for j in range(maxlen):
                            print("value[0] old:{} new:{} tg:{}".format(old_value[j][0], new_value[j][0], test_valus[j]))
                    # Reset so the next check records a fresh reference.
                    old_probs = None
                    if kl > self.kl_targ * 2:
                        self.lr_multiplier /= 1.5
                    elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
                        self.lr_multiplier *= 1.5
                    else:
                        # Multiplier unchanged: skip the line printed below.
                        continue
                    print("kl:{} lr_multiplier:{} lr:{}".format(kl, self.lr_multiplier, self.learn_rate*self.lr_multiplier))
        self.policy_value_net.save_model(model_file)
    except KeyboardInterrupt:
        print('quit')
def collect_selfplay_data(self):
    """Collect self-play data for training.

    Plays up to ``max_game_num`` games with an MCTS player driven by the
    policy-value network loaded from ``model_file``.  While playing, running
    statistics are read from and written back to ``result.json`` in
    ``data_dir``.  After the games finish, the per-step rewards are spread
    over every step of the piece that earned them, the samples returned by
    ``self.get_equi_data`` are pickled into ``data_wait_dir`` (one file per
    sample), and the periodic statistics / ``cpuct`` search parameters in the
    status file are updated.

    Every write closes its file handle before the next read of the status
    file, so a later ``read_status_file`` never sees a partially flushed
    document.

    Raises:
        AssertionError: if no samples were produced or the sample lists have
            different lengths.
    """
    print("TRAIN Self Play starting ...")

    jsonfile = os.path.join(data_dir, "result.json")

    def save_status(status):
        """Write the status dict to ``jsonfile`` and close the handle at once."""
        with open(jsonfile, "w") as fh:
            json.dump(status, fh, ensure_ascii=False)

    # Reference agent; only its board dimensions are read (for printing).
    agent = Agent()

    max_game_num = 1
    agentcount = 0

    boards = []
    game_num = 0

    # Randomly choose which of the two candidate c_puct values goes first.
    cpuct_first_flag = random.random() > 0.5

    # Keys of positions already seen, used to avoid repeating a position.
    game_keys = []
    game_datas = []

    # Play games until one of the stop conditions at the bottom is met.
    for _ in count():
        start_time = time.time()
        game_num += 1
        print('start game :', game_num, 'time:',
              datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))

        result = self.read_status_file(jsonfile)
        print("QVal:", result["QVal"])

        # Take the first two c_puct candidates.  The keys are numeric strings,
        # so order them by value: a plain string sort puts "10.0" before "9.0".
        cpuct_list = list(result["cpuct"])[:2]
        cpuct_list.sort(key=float)

        print("cpuct:", result["cpuct"])
        cpuct = float(cpuct_list[0] if cpuct_first_flag else cpuct_list[1])
        cpuct_first_flag = not cpuct_first_flag

        print("c_puct:", cpuct, "n_playout:", self.n_playout)
        policy_value_net = PolicyValueNet(GAME_WIDTH, GAME_HEIGHT,
                                          GAME_ACTIONS_NUM,
                                          model_file=model_file)
        player = MCTSPlayer(policy_value_net.policy_value_fn,
                            c_puct=cpuct, n_playout=self.n_playout)

        _data = {
            "steps": [],
            "shapes": [],
            "last_state": 0,
            "score": 0,
            "piece_count": 0
        }
        game = Agent(isRandomNextPiece=False)

        if game_num == 1 or game_num == max_game_num:
            game.show_mcts_process = True

        # Indices of the steps at which ``game.state`` was non-zero.
        piece_idx = []

        for i in count():
            _step = {"step": i}
            _step["state"] = game.current_state()
            _step["piece_count"] = game.piececount
            _step["shape"] = game.fallpiece["shape"]
            _step["piece_height"] = game.pieceheight

            action, move_probs = player.get_action(
                game, temp=self.temp, return_prob=1, need_random=False)

            # Position already seen: play a random legal move instead.
            if game.get_key() in game_keys:
                print(game.steps, game.piececount, game.fallpiece["shape"],
                      game.piecesteps, "key:", game.get_key(),
                      "key_len:", len(game_keys))
                action = random.choice(game.get_availables())

            _, reward = game.step(action)

            _step["key"] = game.get_key()
            # Multi-line clears are not rewarded more than single ones.
            _step["reward"] = 1 if reward > 0 else 0
            _step["action"] = action
            _step["move_probs"] = move_probs

            _data["shapes"].append(_step["shape"])
            _data["steps"].append(_step)

            # ``reward`` here is the number of cleared lines.
            if reward > 0:
                result = self.read_status_file(jsonfile)
                if result["curr"]["height"] == 0:
                    result["curr"]["height"] = game.pieceheight
                else:
                    result["curr"]["height"] = round(
                        result["curr"]["height"] * 0.99
                        + game.pieceheight * 0.01, 2)
                result["shapes"][_step["shape"]] += reward

                # First reward of this game: track at which piece it came.
                if game.score == reward:
                    if result["first_reward"] == 0:
                        result["first_reward"] = game.piececount
                    else:
                        result["first_reward"] = (
                            result["first_reward"] * 0.99
                            + game.piececount * 0.01)

                    # Earlier than the running average: also give the
                    # recorded earlier steps a partial reward.
                    if game.piececount < result["first_reward"]:
                        for idx in piece_idx:
                            _data["steps"][idx]["reward"] = 0.5

                save_status(result)

                # max(i, 1) avoids dividing by zero on the very first step.
                print("#"*40, 'score:', game.score,
                      'height:', game.pieceheight,
                      'piece:', game.piececount,
                      "shape:", game.fallpiece["shape"],
                      'step:', i, "step time:",
                      round((time.time()-start_time)/max(i, 1), 3), "#"*40)

            # Remember the index of the current step.
            if game.state != 0:
                piece_idx.append(i)

            # The game ends when it is terminal, or right after a line clear
            # if the stack is higher than 8.
            if game.terminal or (reward > 0 and game.pieceheight > 8):
                _game_last_reward = 0

                _data["reward"] = _game_last_reward
                _data["score"] = game.score
                _data["piece_count"] = game.piececount

                # Update the running statistics.
                game_reward = _game_last_reward + game.score

                result = self.read_status_file(jsonfile)
                if result["QVal"] == 0:
                    result["QVal"] = game_reward
                else:
                    result["QVal"] = (result["QVal"] * 0.999
                                      + game_reward * 0.001)

                paytime = time.time() - start_time
                steptime = paytime / game.steps
                if result["time"]["agent_time"] == 0:
                    result["time"]["agent_time"] = paytime
                    result["time"]["step_time"] = steptime
                else:
                    result["time"]["agent_time"] = round(
                        result["time"]["agent_time"] * 0.99
                        + paytime * 0.01, 3)
                    # Longer games weigh more in the step-time average.
                    d = game.steps / 10000.0
                    if d > 1:
                        d = 0.99
                    result["time"]["step_time"] = round(
                        result["time"]["step_time"] * (1 - d)
                        + steptime * d, 3)

                # Running statistic of the c_puct value used for this game.
                cpuct_key = str(cpuct)
                if cpuct_key in result["cpuct"]:
                    result["cpuct"][cpuct_key] = (
                        result["cpuct"][cpuct_key] * 0.99
                        + game_reward * 0.01)

                if game_reward > result["best"]["reward"]:
                    result["best"]["reward"] = game_reward
                    result["best"]["pieces"] = game.piececount
                    result["best"]["score"] = game.score
                    result["best"]["agent"] = result["agent"] + agentcount

                result["agent"] += 1
                result["curr"]["reward"] += game.score
                result["curr"]["pieces"] += game.piececount
                result["curr"]["agent1000"] += 1
                result["curr"]["agent100"] += 1
                save_status(result)

                game.print()
                print(game_num, 'reward:', game.score, "Qval:", game_reward,
                      'len:', i, "piececount:", game.piececount,
                      "time:", time.time() - start_time)
                print("pay:", time.time() - start_time, "s\n")
                agentcount += 1
                break

        for step in _data["steps"]:
            if step["key"] not in game_keys:
                game_keys.append(step["key"])

        game_datas.append(_data)
        boards.append(game.board)

        # Stop once more than 10000 distinct positions were collected.
        if len(game_keys) > 10000:
            break

        # Stop once the maximum number of games was played.
        if game_num >= max_game_num:
            break

    # Print the final boards side by side.
    from game import blank
    for y in range(agent.height):
        line = ""
        for b in boards:
            line += "| "
            for x in range(agent.width):
                if b[x][y] == blank:
                    line += " "
                else:
                    line += "%s " % b[x][y]
        print(line)
    print((" " + " -" * agent.width + " ") * len(boards))

    # Walk every game backwards: the reward found on the last step of a piece
    # (capped at 1) is assigned to all steps of that piece.
    for data in game_datas:
        piece_count = -1
        v = 0
        vlist = []
        for step in reversed(data["steps"]):
            if piece_count != step["piece_count"]:
                piece_count = step["piece_count"]
                v = step["reward"]
                if v > 1:
                    v = 1
                vlist.append(v)
            step["reward"] = v
        vlist.reverse()
        print(vlist)

    # state, move probabilities, per-step value, per-game score
    states, mcts_probs, values, score = [], [], [], []
    for data in game_datas:
        for step in data["steps"]:
            states.append(step["state"])
            mcts_probs.append(step["move_probs"])
            values.append(step["reward"])
            score.append(data["score"])

    assert len(states) > 0
    assert len(states) == len(values)
    assert len(states) == len(mcts_probs)

    print("TRAIN Self Play end. length:%s value sum:%s saving ..." %
          (len(states), sum(values)))

    # Store every sample as its own pickle file in the waiting directory.
    for obj in self.get_equi_data(states, mcts_probs, values, score):
        filename = "{}.pkl".format(uuid.uuid1())
        savefile = os.path.join(data_wait_dir, filename)
        with open(savefile, "wb") as fh:
            pickle.dump(obj, fh)

    result = self.read_status_file(jsonfile)
    if result["curr"]["agent100"] > 100:
        result["reward"].append(
            round(result["curr"]["reward"] / result["curr"]["agent1000"], 2))
        result["pieces"].append(
            round(result["curr"]["pieces"] / result["curr"]["agent1000"], 2))
        result["qvals"].append(round(result["QVal"], 2))
        result["height"].append(result["curr"]["height"])
        result["time"]["step_times"].append(result["time"]["step_time"])
        result["curr"]["agent100"] -= 100

        # Keep only the 200 most recent entries of every history list.
        for history in (result["reward"], result["pieces"], result["qvals"],
                        result["height"], result["time"]["step_times"]):
            del history[:-200]

        # Every 100 games move the pair of c_puct candidates by 0.01 towards
        # whichever of the two scored better.
        qval = result["QVal"]
        if result["cpuct"][cpuct_list[0]] > result["cpuct"][cpuct_list[1]]:
            cpuct = round(float(cpuct_list[0]) - 0.01, 2)
            if cpuct <= 0.01:
                result["cpuct"] = {"0.01": qval, "1.01": qval}
            else:
                result["cpuct"] = {
                    str(cpuct): qval,
                    str(round(cpuct + 1, 2)): qval
                }
        else:
            cpuct = round(float(cpuct_list[0]) + 0.01, 2)
            result["cpuct"] = {
                str(cpuct): qval,
                str(round(cpuct + 1, 2)): qval
            }

        # New best average reward: keep a copy of the model.
        if max(result["reward"]) == result["reward"][-1]:
            newmodelfile = model_file + "_reward_" + str(
                result["reward"][-1])
            if not os.path.exists(newmodelfile):
                policy_value_net.save_model(newmodelfile)

    if result["curr"]["agent1000"] > 1000:
        result["curr"] = {
            "reward": 0,
            "pieces": 0,
            "agent1000": 0,
            "agent100": 0,
            "height": 0
        }

        newmodelfile = model_file + "_" + str(result["agent"])
        if not os.path.exists(newmodelfile):
            policy_value_net.save_model(newmodelfile)

    result["lastupdate"] = datetime.datetime.now().strftime(
        '%Y-%m-%d %H:%M:%S')
    save_status(result)
def main():
    """Train the policy-value network by self-play and evaluate it.

    Parses the command-line options, builds the board/game and the network
    (initialised from ``model/current.mdl``), fills the replay memory with
    ``--n_burn_in`` games and then runs ``--n_iteration`` rounds of
    ``train_one_iteration`` followed by ``evaluate`` against a ``MCTSPlayer``.
    After each round the weights are saved to ``model/current_best.mdl`` when
    the win ratio improved on the best seen so far, otherwise to
    ``model/current.mdl``.
    """
    parser = argparse.ArgumentParser(description='Test')
    parser.add_argument('--replay_memory_size',
                        default=50000,
                        type=int,
                        help='replayMemory_size to store training data')
    parser.add_argument('--batch_size',
                        default=512,
                        type=int,
                        help='batch size')
    parser.add_argument('--learning_rate',
                        default=1e-3,
                        type=float,
                        help='learning_rate')
    parser.add_argument('--evaluate_freq',
                        default=50,
                        type=int,
                        help='evaluate once every #evaluate_freq games')
    parser.add_argument(
        '--train_freq',
        default=1,
        type=int,
        help='train #train_epoch times replay mempry within each train')
    parser.add_argument('--n_eval_game',
                        default=10,
                        type=int,
                        help='number of games during one evaluation')
    parser.add_argument('--n_burn_in',
                        default=10,
                        type=int,
                        help='number of games to burn in the replay memory')
    parser.add_argument('--n_iteration',
                        default=20,
                        type=int,
                        help='number of train iteration')
    parser.add_argument('--width', default=6, type=int)
    parser.add_argument('--height', default=6, type=int)
    parser.add_argument('--n_in_row', default=4, type=int)
    args = parser.parse_args()

    width, height = args.width, args.height
    board = Board(width=width, height=height, n_in_row=args.n_in_row)
    game = Game(board)

    # Network being trained, initialised from the last saved weights.
    AlphaGoNet_train = PolicyValueNet(width, height)
    AlphaGoNet_train.policy_value_net.load_state_dict(
        torch.load('model/current.mdl'))

    # Replay memory holding the training data.
    ReplayMemory = deque(maxlen=args.replay_memory_size)

    player = AlphaGoPlayer(NN_fn=AlphaGoNet_train.policy_value_fn)
    # Opponent used during evaluation.
    eval_player = MCTSPlayer()
    max_win_ratio = .0

    # Burn in
    burn_in(game, player, ReplayMemory, args.n_burn_in)

    for i in range(args.n_iteration):
        # Formatted explicitly so the line prints identically under
        # Python 2 and 3 (the old print statement is a SyntaxError on 3).
        print('Iteration NO.: {}'.format(i))
        train_one_iteration(game, player, ReplayMemory, AlphaGoNet_train,
                            args.batch_size, args.learning_rate,
                            args.train_freq, args.evaluate_freq)
        win_ratio = evaluate(game, player, eval_player, args.n_eval_game)
        if win_ratio > max_win_ratio:
            print('Get current_best model!')
            max_win_ratio = win_ratio
            torch.save(AlphaGoNet_train.policy_value_net.state_dict(),
                       'model/current_best.mdl')
        else:
            print('Save current model')
            torch.save(AlphaGoNet_train.policy_value_net.state_dict(),
                       'model/current.mdl')
def learn_process(args, share_model, shared_lr_mul, shared_g_cnt, shared_q, lock):
    """Learner loop: train on queued self-play data and publish the weights.

    Args:
        args: options object; reads ``lr``, ``batch_size``, ``board_max``,
            ``cuda``, ``memory_capacity`` and ``learn_start``.
        share_model: model whose ``state_dict`` seeds the local network and
            which receives the trained weights (under ``lock``).
        shared_lr_mul: shared value (``.value``) holding the learning-rate
            multiplier, adapted from the KL divergence.
        shared_g_cnt: shared value (``.value``) counting training steps.
        shared_q: queue delivering batches of samples to append to the buffer.
        lock: context manager guarding writes to ``share_model``.

    The replay buffer is restored from ``qqq.dat`` when that file can be
    read.  When the loop stops for any reason (including Ctrl-C) the reason
    is printed, and the network weights are written to ``./net_param`` and
    the buffer to ``qqq.dat``.
    """
    from model import PolicyValueNet

    epochs = 1
    kl_targ = 0.025
    lr_multiplier = shared_lr_mul  # adaptively adjust the learning rate based on KL
    g_cnt = shared_g_cnt
    learn_rate = args.lr
    batch_size = args.batch_size  # mini-batch size for training
    policy_value_net = PolicyValueNet(args.board_max, args.board_max, use_gpu=args.cuda)
    policy_value_net.policy_value_net.load_state_dict(share_model.state_dict())
    data_buffer = deque(maxlen=args.memory_capacity)

    # Best-effort restore of the buffer written by a previous run.
    # NOTE: pickle executes code on load; qqq.dat must only ever be the file
    # this function wrote itself, never untrusted data.
    try:
        with open('qqq.dat', 'rb') as qq:
            temp_buffer = pickle.load(qq)
        print('load buffer length: ', len(temp_buffer))
        data_buffer.extend(temp_buffer)
    except FileNotFoundError:
        # No previous buffer: start empty.
        pass
    except Exception as err:
        # Still start with an empty buffer, but say why the restore failed
        # instead of hiding it (Ctrl-C is no longer swallowed here).
        print('could not restore buffer from qqq.dat:', repr(err))

    def learn(policy_value_net, rank, data_buffer):
        """update the policy-value net"""
        mini_batch = random.sample(data_buffer, batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = policy_value_net.policy_value(state_batch)
        for _ in range(epochs):
            loss, entropy = policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                learn_rate * lr_multiplier.value)
            new_probs, new_v = policy_value_net.policy_value(state_batch)
            kl = np.mean(
                np.sum(old_probs *
                       (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                       axis=1))
            if kl > kl_targ * 4:  # early stopping if D_KL diverges badly
                break
        # adaptively adjust the learning rate
        if kl > kl_targ * 2 and lr_multiplier.value > 0.1:
            lr_multiplier.value /= 1.5
        elif kl < kl_targ / 2 and lr_multiplier.value < 10:
            lr_multiplier.value *= 1.5

        ss = "r:{} c:{} kl:{:.5f},lr_mul:{:.3f},loss:{},ent:{} {} ".format(
            rank, g_cnt.value, kl, lr_multiplier.value, loss, entropy,
            datetime.datetime.now())
        g_cnt.value += 1
        # Every 100th step is also appended to the log file.
        if g_cnt.value % 100 == 0:
            with open('log.txt', 'a') as f:
                f.write(ss + '\n')
        print('\r' + ss, end='', flush=True)
        return loss, entropy

    print('learner')
    try:
        while True:
            if not shared_q.empty():
                print('extend', len(data_buffer), '>', args.learn_start)
                data_buffer.extend(shared_q.get())
            if len(data_buffer) > args.learn_start:
                for _ in range(7):
                    loss, entropy = learn(policy_value_net, 0, data_buffer)
                # Publish the freshly trained weights to the shared model.
                with lock:
                    share_model.load_state_dict(
                        policy_value_net.policy_value_net.state_dict())
            else:
                print('buffer fill :', len(data_buffer), '>', args.learn_start, end='\r')
            if shared_g_cnt.value % 1000 == 0:
                print('leanr_save')
                torch.save(policy_value_net.policy_value_net.state_dict(), './net_param')
            time.sleep(1)
    except BaseException as err:
        # Deliberately catches everything, Ctrl-C included, so the weights and
        # the buffer are persisted on the way out; the cause is reported
        # rather than silently dropped.
        print('learner stopped:', repr(err))
        torch.save(policy_value_net.policy_value_net.state_dict(), './net_param')
        with open('qqq.dat', 'wb') as qq:
            pickle.dump(data_buffer, qq, -1)
        print('pickle dump')
        print(999, 'except save')