def _wait_for_records(pattern, minimum, poll_interval=60):
    """Poll until at least `minimum` paths match `pattern`, then return them.

    Args:
        pattern: Path pattern handed unchanged to `utils.pai_find_path`.
        minimum: Smallest number of matches that ends the wait.
        poll_interval: Seconds to sleep between two lookups.

    Returns:
        The value returned by `utils.pai_find_path` on the first lookup
        that yields at least `minimum` matches.
    """
    while True:
        found = utils.pai_find_path(pattern)
        if len(found) >= minimum:
            return found
        time.sleep(poll_interval)


def records_sample(model_num, verificate=False):
    """Wait for enough game records of one model and return their paths.

    This call blocks (re-checking every 60 seconds) until enough record
    files exist. Only records of `model_num` are returned; records of
    older models are not mixed in.

    Args:
        model_num: Model number embedded in the record file names.
        verificate: If true, collect 'game-verification-<model>-*' records
            and require `utils.VERIFICATION_GAME_NUM * 2` of them.
            Otherwise collect 'game-<model>-*' records and require
            `utils.TRAIN_EPOCH_GAME_NUM * 2 + 10` of them.

    Returns:
        The matching record paths as returned by `utils.pai_find_path`.
    """
    db_path = utils.PAI_DB_PATH if utils.USE_PAI else utils.DB_PATH
    # NOTE(review): the "* 2" presumably reflects one record file per player
    # colour per game -- confirm against the writer side.
    if verificate:
        prefix = 'game-verification'
        minimum = utils.VERIFICATION_GAME_NUM * 2
    else:
        prefix = 'game'
        minimum = utils.TRAIN_EPOCH_GAME_NUM * 2 + 10
    # The '-' after the model number matters: without it the pattern for
    # model 1 would also match the files of models 10, 11, ...
    pattern = os.path.join(db_path, '{}-{}-*'.format(prefix, model_num))
    return _wait_for_records(pattern, minimum)
def save_history_to_tfrecord(self, reward):
    """Write this player's move history of the finished game to a TFRecord.

    Does nothing unless `utils.SAVE_RECORD` is truthy. The file goes to
    the verification set ('game-verification-...') while fewer than
    `2 * utils.VERIFICATION_GAME_NUM` verification files exist for the
    current model, and to the training set ('game-...') afterwards.

    One example is written per entry of `self.prob_history`, pairing the
    board feature at that step with the recorded expectation and `reward`.
    Rotation-based augmentation is not applied here.

    Args:
        reward: Value stored in every example of this game.
    """
    if not utils.SAVE_RECORD:
        return

    net_model_num = self.mct.net.get_model_num()
    db_path = utils.PAI_DB_PATH if utils.USE_PAI else utils.DB_PATH

    # The '-' after the model number keeps model 1 from also counting the
    # verification files of models 10, 11, ...
    verification_pattern = os.path.join(
        db_path, 'game-verification-{}-*'.format(net_model_num))
    verification_count = len(utils.pai_find_path(verification_pattern))
    if verification_count < 2 * utils.VERIFICATION_GAME_NUM:
        name_template = 'game-verification-{}-{}-{}.tfrecord'
    else:
        name_template = 'game-{}-{}-{}.tfrecord'

    # The timestamp keeps file names unique across games.
    tfr_name = name_template.format(
        net_model_num, time.time(), self.color_str)
    tfr_writer = generate_writer(os.path.join(db_path, tfr_name))
    try:
        for i, expect in enumerate(self.prob_history):
            feature = self.game.board.get_feature(self.color, i)
            example = generate_example(feature, expect, reward)
            tfr_writer.write(example.SerializeToString())
    finally:
        # Always release the writer, even if building an example fails.
        tfr_writer.close()
def compare(compare_model_num=None, default_model_num=None):
    """Play a candidate model against the default model and record the best.

    Games are played until the stored comparison record holds more than
    `utils.COMPARE_TIME` games. Afterwards the candidate is written as the
    'compare' best model if `win / total` exceeds
    `utils.COMPARE_WIN_RATE`; otherwise the default model is written.

    When `utils.USE_PAI` is set, missing model numbers are looked up from
    storage, and the function sleeps 120 seconds and returns early if the
    comparison is already complete or both numbers are equal.

    Args:
        compare_model_num: Candidate model number, or None to look it up.
        default_model_num: Reference model number, or None to look it up.
    """
    if utils.USE_PAI:
        if compare_model_num is None:
            compare_model_num = utils.pai_read_best()
        if default_model_num is None:
            marker_pattern = os.path.join(
                utils.PAI_RECORD_PATH,
                'compare-{}'.format(compare_model_num))
            markers = utils.pai_find_path(marker_pattern)
            if markers:
                # The model number is the text after the last '-'.
                default_model_num = int(markers[0].split('-')[-1])
            else:
                default_model_num = utils.pai_read_best('compare')

        if default_model_num != compare_model_num:
            win, total = utils.pai_read_compare_record(
                default_model_num, compare_model_num)
        else:
            # Comparing a model with itself is pointless; this sentinel
            # forces the early exit below.
            total = 10000

        if total > utils.COMPARE_TIME:
            time.sleep(120)
            return

    game = Game(utils.MCTS, utils.MCTS, utils.SIZE,
                compare_model_num, default_model_num)
    game.logger.info('Compare model {} with default model {}'.format(
        compare_model_num, default_model_num))

    while True:
        win, total = utils.pai_read_compare_record(
            default_model_num, compare_model_num)
        game.logger.info('Now compare result: {}-{}'.format(win, total))
        if total > utils.COMPARE_TIME:
            break

        # Randomise which colour the default ("best") model plays.
        black_as_best = np.random.choice([True, False])
        if black_as_best:
            seat_msg = 'Black as best'
            black_num, white_num = default_model_num, compare_model_num
        else:
            seat_msg = 'White as best'
            black_num, white_num = compare_model_num, default_model_num
        game.logger.info(seat_msg)
        game.black_player.mct.reset_net(black_num)
        game.white_player.mct.reset_net(white_num)

        game.start()
        winner = game.board.winner
        # A game whose winner is EMPTY is not recorded at all.
        if winner is not utils.EMPTY:
            best_won = ((winner is utils.BLACK and black_as_best)
                        or (winner is utils.WHITE and not black_as_best))
            # The flag is True when the default model did not win.
            utils.pai_write_compare_record(
                default_model_num, compare_model_num, not best_won)
        game.reset()

    # NOTE(review): with integer counts under Python 2 this is an integer
    # division -- confirm the interpreter / record types.
    if win / total > utils.COMPARE_WIN_RATE:
        utils.pai_change_best(compare_model_num, 'compare')
        game.logger.info(
            'Change best model to {}'.format(compare_model_num))
    else:
        utils.pai_change_best(default_model_num, 'compare')
        game.logger.info('Best model does not change')
def main(_):
    """Replay every stored '.psq' record through a TRANS-vs-TRANS game.

    Args:
        _: Unused positional argument supplied by the caller.
    """
    pai_constant_init()

    psq_pattern = os.path.join(utils.PAI_RECORD_PATH, '*x*-*.psq')
    psq_paths = utils.pai_find_path(psq_pattern)

    game = Game(utils.TRANS, utils.TRANS)
    for psq_path in psq_paths:
        # Both players get the same record before the game is run.
        game.black_player.add_record(psq_path)
        game.white_player.add_record(psq_path)
        game.start()
        game.reset()
def main():
    """Run self-play: one local game, or games until an epoch is filled.

    Without `utils.USE_PAI` a single default `Game` is started.

    With `utils.USE_PAI` the best model number is read (falling back to 0
    and storing that as best if the read fails). Self-play games are then
    played with that model on both sides until half the number of
    'game-<model>-*' files reaches `utils.TRAIN_EPOCH_GAME_NUM`. The
    function then sleeps 120 seconds before returning.
    """
    if not utils.USE_PAI:
        game = Game()
        game.start()
        return

    try:
        model_num = utils.pai_read_best()
    except Exception:
        # No readable best-model entry: start from model 0 and persist
        # that choice. Only ordinary errors are caught, so Ctrl-C and
        # SystemExit still propagate.
        model_num = 0
        utils.pai_change_best(model_num)

    db_pattern = os.path.join(
        utils.PAI_DB_PATH, 'game-{}-*'.format(model_num))

    def recorded_games():
        # NOTE(review): "/ 2" presumably reflects two record files per
        # game (one per colour) -- confirm against the writer side.
        return len(utils.pai_find_path(db_pattern)) / 2

    if recorded_games() < utils.TRAIN_EPOCH_GAME_NUM:
        game = Game(utils.MCTS, utils.MCTS,
                    black_net_model_num=model_num,
                    white_net_model_num=model_num)
        while True:
            # One listing per iteration serves both the check and the log.
            count = recorded_games()
            if count >= utils.TRAIN_EPOCH_GAME_NUM:
                break
            game.logger.info('There are {} records now'.format(count))
            game.start()
            utils.pai_win_rate_record(model_num, game.board.winner)
            game.reset()
        del game

    time.sleep(120)
def save_history_to_tfrecord(self, reward):
    """Write this player's move history of the finished game to a TFRecord.

    Does nothing unless `utils.SAVE_RECORD` is truthy. Records are always
    filed under model number 0. The file goes to the verification set
    ('game-verification-...') while fewer than
    `2 * utils.VERIFICATION_GAME_NUM` verification files exist, and to the
    training set ('game-...') afterwards.

    One example is written per entry of `self.prob_history`, pairing the
    board feature at that step with the recorded expectation and `reward`.

    Args:
        reward: Value stored in every example of this game.
    """
    if not utils.SAVE_RECORD:
        return

    # This player has no network of its own to ask; model number is fixed.
    net_model_num = 0
    db_path = utils.PAI_DB_PATH if utils.USE_PAI else utils.DB_PATH

    # The '-' after the model number restricts the match to model 0 files.
    verification_pattern = os.path.join(
        db_path, 'game-verification-{}-*'.format(net_model_num))
    verification_count = len(utils.pai_find_path(verification_pattern))
    if verification_count < 2 * utils.VERIFICATION_GAME_NUM:
        name_template = 'game-verification-{}-{}-{}.tfrecord'
    else:
        name_template = 'game-{}-{}-{}.tfrecord'

    # The timestamp keeps file names unique across games.
    tfr_name = name_template.format(
        net_model_num, time.time(), self.color_str)
    tfr_writer = generate_writer(os.path.join(db_path, tfr_name))
    try:
        for i, expect in enumerate(self.prob_history):
            feature = self.game.board.get_feature(self.color, i)
            example = generate_example(feature, expect, reward)
            tfr_writer.write(example.SerializeToString())
    finally:
        # Always release the writer, even if building an example fails.
        tfr_writer.close()