コード例 #1
0
ファイル: eval_ac_bot.py プロジェクト: LJQCN101/alphago
def main():
    """Pit two saved actor-critic agents against each other and print the score.

    Command-line arguments:
        --agent1: path to the HDF5 file holding the first agent.
        --agent2: path to the HDF5 file holding the second agent.
        --num-games / -n: number of games to simulate (default 10).

    Agent 1 plays black in the first game and switches colour every game.
    A game counts as a win for agent 1 only when ``game_record.winner`` equals
    agent 1's colour; every other outcome is counted as a loss.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--agent1', required=True)
    parser.add_argument('--agent2', required=True)
    parser.add_argument('--num-games', '-n', type=int, default=10)

    args = parser.parse_args()
    # Open each model file read-only and close it as soon as the agent has
    # been built; a bare ``h5py.File(path)`` has no explicit mode (older h5py
    # releases default to a writable one) and the handle is never closed.
    with h5py.File(args.agent1, 'r') as agent1_file:
        agent1 = rl.load_ac_agent(agent1_file)
    with h5py.File(args.agent2, 'r') as agent2_file:
        agent2 = rl.load_ac_agent(agent2_file)
    num_games = args.num_games

    wins = 0
    losses = 0
    color1 = Player.black
    for i in range(num_games):
        print('Simulating game %d/%d...' % (i + 1, num_games))
        # Seat agent 1 on its colour for this game, agent 2 on the other.
        if color1 == Player.black:
            black_player, white_player = agent1, agent2
        else:
            white_player, black_player = agent1, agent2
        game_record = simulate_game(black_player, white_player)
        if game_record.winner == color1:
            wins += 1
        else:
            losses += 1
        # Alternate colours so neither agent always moves first.
        color1 = color1.other
    print('Agent 1 record: %d/%d' % (wins, wins + losses))
コード例 #2
0
def main():
    """Generate actor-critic self-play experience and write it to an HDF5 file.

    Command-line arguments:
        --board-size: stored in the module-level ``BOARD_SIZE`` (default 19).
        --learning-agent: HDF5 file holding the agent; two copies of it are
            loaded and play each other.
        --num-games / -n: number of self-play games (default 10).
        --experience-out: HDF5 file the combined experience is written to
            (opened with mode 'w', so an existing file is overwritten).
        --temperature: value passed to ``set_temperature`` on both copies
            (default 0.0).

    Agent 1 plays black in the first game and switches colour every game.
    The winner's collector completes the episode with reward 1 and the
    loser's with reward -1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--board-size', type=int, default=19)
    parser.add_argument('--learning-agent', required=True)
    parser.add_argument('--num-games', '-n', type=int, default=10)
    parser.add_argument('--experience-out', required=True)
    parser.add_argument('--temperature', type=float, default=0.0)

    args = parser.parse_args()
    global BOARD_SIZE
    BOARD_SIZE = args.board_size

    # Two independent agent objects are needed because each one is given its
    # own experience collector below. Open the file read-only and close it
    # once both have been built instead of leaking two open handles.
    with h5py.File(args.learning_agent, 'r') as agent_file:
        agent1 = rl.load_ac_agent(agent_file)
        agent2 = rl.load_ac_agent(agent_file)
    agent1.set_temperature(args.temperature)
    agent2.set_temperature(args.temperature)

    collector1 = rl.ExperienceCollector()
    collector2 = rl.ExperienceCollector()

    color1 = Player.black
    for i in range(args.num_games):
        print('Simulating game %d/%d...' % (i + 1, args.num_games))
        collector1.begin_episode()
        agent1.set_collector(collector1)
        collector2.begin_episode()
        agent2.set_collector(collector2)

        # Seat agent 1 on its colour for this game, agent 2 on the other.
        if color1 == Player.black:
            black_player, white_player = agent1, agent2
        else:
            white_player, black_player = agent1, agent2
        game_record = simulate_game(black_player, white_player)
        if game_record.winner == color1:
            print('Agent 1 wins.')
            collector1.complete_episode(reward=1)
            collector2.complete_episode(reward=-1)
        else:
            print('Agent 2 wins.')
            collector2.complete_episode(reward=1)
            collector1.complete_episode(reward=-1)
        # Alternate colours so both sides of the board are represented.
        color1 = color1.other

    experience = rl.combine_experience([collector1, collector2])
    with h5py.File(args.experience_out, 'w') as experience_outf:
        experience.serialize(experience_outf)
コード例 #3
0
def main():
    """Serve an MCTS bot, plus any trained bots given on the command line, over HTTP.

    Command-line arguments:
        --bind-address: host passed to ``web_app.run`` (default 127.0.0.1).
        --port / -p: port passed to ``web_app.run`` (default 5000).
        --pg-agent: optional HDF5 file for a policy agent, served as 'pg'.
        --predict-agent: optional HDF5 file for a prediction agent, served
            as 'predict'.
        --q-agent: optional HDF5 file for a Q-learning agent, served as 'q'
            with temperature 0.01.
        --ac-agent: optional HDF5 file for an actor-critic agent, served as
            'ac' with temperature 0.05.

    An MCTS agent (800 rounds, temperature 0.7) is always served as 'mcts'.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--bind-address', default='127.0.0.1')
    parser.add_argument('--port', '-p', type=int, default=5000)
    parser.add_argument('--pg-agent')
    parser.add_argument('--predict-agent')
    parser.add_argument('--q-agent')
    parser.add_argument('--ac-agent')

    args = parser.parse_args()

    bots = {'mcts': mcts.MCTSAgent(800, temperature=0.7)}
    # Each model file is opened read-only and closed once its bot is built;
    # a bare ``h5py.File(path)`` has no explicit mode and is never closed.
    if args.pg_agent:
        with h5py.File(args.pg_agent, 'r') as pg_file:
            bots['pg'] = agent.load_policy_agent(pg_file)
    if args.predict_agent:
        with h5py.File(args.predict_agent, 'r') as predict_file:
            bots['predict'] = agent.load_prediction_agent(predict_file)
    if args.q_agent:
        with h5py.File(args.q_agent, 'r') as q_file:
            q_bot = rl.load_q_agent(q_file)
        q_bot.set_temperature(0.01)
        bots['q'] = q_bot
    if args.ac_agent:
        with h5py.File(args.ac_agent, 'r') as ac_file:
            ac_bot = rl.load_ac_agent(ac_file)
        ac_bot.set_temperature(0.05)
        bots['ac'] = ac_bot

    web_app = httpfrontend.get_web_app(bots)
    web_app.run(host=args.bind_address, port=args.port, threaded=False)
コード例 #4
0
ファイル: train_ac.py プロジェクト: AndrewNomura/GammaGo_3
def main():
    """Train an actor-critic agent on one or more experience files and save it.

    Command-line arguments:
        --learning-agent: HDF5 file holding the agent to start from.
        --agent-out: HDF5 file the updated agent is written to (opened with
            mode 'w', so an existing file is overwritten).
        --lr: learning rate passed to ``train`` (default 0.0001).
        --bs: batch size passed to ``train`` (default 512).
        experience: one or more experience HDF5 files, trained on in the
            order given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--learning-agent', required=True)
    parser.add_argument('--agent-out', required=True)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--bs', type=int, default=512)
    parser.add_argument('experience', nargs='+')

    args = parser.parse_args()
    learning_agent_filename = args.learning_agent
    experience_files = args.experience
    updated_agent_filename = args.agent_out
    learning_rate = args.lr
    batch_size = args.bs

    # Read-only, and closed as soon as the agent is built, so no handle on the
    # input model is still held when the output file is opened for writing.
    with h5py.File(learning_agent_filename, 'r') as learning_agent_inf:
        learning_agent = rl.load_ac_agent(learning_agent_inf)

    for exp_filename in experience_files:
        # Keep each experience file open only while training on it, so the
        # number of open handles does not grow with the number of files.
        with h5py.File(exp_filename, 'r') as exp_inf:
            exp_buffer = rl.load_experience(exp_inf)
            learning_agent.train(
                exp_buffer,
                lr=learning_rate,
                batch_size=batch_size)

    with h5py.File(updated_agent_filename, 'w') as updated_agent_outf:
        learning_agent.serialize(updated_agent_outf)
コード例 #5
0
def main():
    """Interactively pick a trained bot and serve it next to an MCTS bot over HTTP.

    The user is asked (prompts are in Russian) for an agent type: 'pg',
    'predict', 'q' or 'ac', compared case-insensitively. The user is then
    asked for a checkpoint file name without directory or extension. The
    file ``workdir + name + '.h5'`` is loaded and served under the
    chosen type key. Q bots get temperature 0.01 and actor-critic bots get
    0.05. An MCTS agent (800 rounds, temperature 0.7) is always served as
    'mcts'. Any other agent type leaves only the MCTS bot. The app is run on
    127.0.0.1:5000.
    """
    # NOTE(review): hard-coded, machine-specific checkpoint directory.
    workdir = '//home/nail//Code_Go//checkpoints//'
    os.chdir(workdir)
    bind_address = '127.0.0.1'
    port = 5000

    # File-name prompt for each supported agent type.
    file_prompts = {
        'pg': 'Введите имя файла для игры с ботом политика градиентов =',
        'predict': 'Введите имя файла для игры с ботом предсказания хода =',
        'q': 'Введите имя файла для игры с ботом ценность действия =',
        'ac': 'Введите имя файла для игры с ботом актор-критик =',
    }

    agent_type = input('Агент(pg/predict/q/ac = ').lower()
    checkpoint_path = None
    if agent_type in file_prompts:
        checkpoint_path = workdir + input(file_prompts[agent_type]) + '.h5'

    bots = {'mcts': mcts.MCTSAgent(800, temperature=0.7)}
    if checkpoint_path is not None:
        # Open read-only and close the file once the bot has been built.
        with h5py.File(checkpoint_path, 'r') as h5file:
            if agent_type == 'pg':
                bots['pg'] = agent.load_policy_agent(h5file)
            elif agent_type == 'predict':
                bots['predict'] = agent.load_prediction_agent(h5file)
            elif agent_type == 'q':
                q_bot = rl.load_q_agent(h5file)
                q_bot.set_temperature(0.01)
                bots['q'] = q_bot
            else:  # 'ac' -- the only remaining key of file_prompts
                ac_bot = rl.load_ac_agent(h5file)
                ac_bot.set_temperature(0.05)
                bots['ac'] = ac_bot

    web_app = httpfrontend.get_web_app(bots)
    web_app.run(host=bind_address, port=port, threaded=False)
コード例 #6
0
def load_agent(filename):
    """Build an actor-critic agent from the HDF5 file at ``filename``.

    The file is opened read-only and is closed again before the agent is
    returned.
    """
    with h5py.File(filename, mode='r') as agent_file:
        loaded = rl.load_ac_agent(agent_file)
    return loaded