Code Example #1
import numpy as np
import datetime

from catch_ball import CatchBall
from dqn_model import get_dqn_model


if __name__ == "__main__":
    # environment, agent
    env = CatchBall()
    dqn = get_dqn_model(env)

    try:
        dqn.load_weights('dqn_{}_weights.h5f'.format("catch_ball"))
        print("start from saved weights")
    except (IOError, OSError):
        print("start from random weights")

    # Train the agent. Rendering the environment during training slows it down
    # considerably, so visualization is disabled here; training can always be
    # safely aborted with Ctrl + C.
    dqn.fit(env, nb_steps=50000, visualize=False, verbose=1)

    # After training is done, we save the final weights.
    dqn.save_weights('dqn_{}_weights.h5f'.format("catch_ball"), overwrite=True)
    dqn.save_weights('dqn_{}_weights.{}.h5f.bak'.format(
        "catch_ball", datetime.datetime.today().strftime("%Y-%m%d-%H%M")
    ), overwrite=True)
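
A note on get_dqn_model: the helper is imported above but not shown on this page. As a rough sketch only, a keras-rl based version might look like the code below. It assumes CatchBall exposes enable_actions, screen_n_rows and screen_n_cols (as in the other examples) and implements the gym-style step/reset interface that keras-rl's fit() expects; the network shape and hyperparameters are illustrative placeholders, not the project's actual values.

# dqn_model.py -- hypothetical sketch, not the original implementation
from keras.models import Sequential
from keras.layers import Dense, Flatten
from keras.optimizers import Adam
from rl.agents.dqn import DQNAgent
from rl.policy import EpsGreedyQPolicy
from rl.memory import SequentialMemory


def get_dqn_model(env):
    nb_actions = len(env.enable_actions)  # assumed attribute (see the other examples)

    # Small fully connected Q-network over the flattened screen.
    model = Sequential()
    model.add(Flatten(input_shape=(1, env.screen_n_rows, env.screen_n_cols)))
    model.add(Dense(64, activation="relu"))
    model.add(Dense(64, activation="relu"))
    model.add(Dense(nb_actions, activation="linear"))

    # Standard keras-rl DQN wiring: replay memory + epsilon-greedy policy.
    memory = SequentialMemory(limit=50000, window_length=1)
    policy = EpsGreedyQPolicy(eps=0.1)
    dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory,
                   nb_steps_warmup=100, target_model_update=1e-2, policy=policy)
    dqn.compile(Adam(lr=1e-3), metrics=["mae"])
    return dqn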
Code Example #2
File: test.py  Project: Tonyan/tf-dqn-simple
    # animate
    img.set_array(state_t_1)
    plt.axis("off")
    return img,


if __name__ == "__main__":
    # args
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model_path")
    parser.add_argument("-s", "--save", dest="save", action="store_true")
    parser.set_defaults(save=False)
    args = parser.parse_args()

    # environment, agent
    env = CatchBall()
    agent = DQNAgent(env.enable_actions, env.name)
    agent.load_model(args.model_path)

    # variables
    win, lose = 0, 0
    state_t_1, reward_t, terminal = env.observe()

    # animate
    fig = plt.figure(figsize=(env.screen_n_rows / 2, env.screen_n_cols / 2))
    fig.canvas.set_window_title("{}-{}".format(env.name, agent.name))
    img = plt.imshow(state_t_1, interpolation="none", cmap="gray")
    ani = animation.FuncAnimation(fig, animate, init_func=init, interval=(1000 / env.frame_rate), blit=True)

    if args.save:
        # save animation (requires ImageMagick)
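
The excerpt stops at the save branch. A plausible continuation is sketched below, assuming the script uses matplotlib's ImageMagick writer (as the comment suggests); the output file name and frame-rate handling are illustrative only.

    if args.save:
        # save as a GIF through ImageMagick (must be installed); file name is illustrative
        ani.save("{}.gif".format(env.name), writer="imagemagick", fps=env.frame_rate)
    else:
        plt.show()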
Code Example #3
import numpy as np

from catch_ball import CatchBall
from dqn_agent import DQNAgent

if __name__ == "__main__":
    # parameters
    n_epochs = 1000

    # environment, agent
    env = CatchBall()
    agent = DQNAgent(env.enable_actions, env.name)

    # variables
    win = 0

    for e in range(n_epochs):
        # reset
        frame = 0
        loss = 0.0
        Q_max = 0.0
        env.reset()

        state_t, reward_t_none = env.observe()

        # execute action in environment
        action_t = agent.select_action(state_t, agent.exploration)
        env.execute_action(action_t)

        # observe environment
        state_t, reward_t = env.observe()
Code Example #4
File: train.py  Project: maekawatoshiki/dqn-sample
import numpy as np

from catch_ball import CatchBall
from dqn_agent import DQNAgent

if __name__ == "__main__":
    # parameters
    n_epochs = 1000

    # environment, agent
    env = CatchBall()
    agent = DQNAgent(env.enable_actions, env.name)

    # variables
    win = 0

    for e in range(n_epochs):
        # reset
        frame = 0
        loss = 0.0
        Q_max = 0.0
        env.reset()
        state_t_1, reward_t, terminal = env.observe()

        while not terminal:
            state_t = state_t_1

            # execute action in environment
            action_t = agent.select_action(state_t, agent.exploration)
            env.execute_action(action_t)
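
The loop is cut off here. In this family of tf-dqn-simple style trainers the body usually continues by observing the new state, storing the transition, and running an experience-replay update. The sketch below assumes DQNAgent exposes store_experience, experience_replay, Q_values, current_loss and save_model; those names are not shown on this page and are assumptions.

            # observe environment (assumed continuation, see note above)
            state_t_1, reward_t, terminal = env.observe()

            # store the transition and run one experience-replay update
            agent.store_experience(state_t, action_t, reward_t, state_t_1, terminal)
            agent.experience_replay()

            # bookkeeping for the per-epoch log
            frame += 1
            loss += agent.current_loss
            Q_max += np.max(agent.Q_values(state_t))
            if reward_t == 1:
                win += 1

        print("EPOCH: {:03d}/{:03d} | WIN: {:03d} | LOSS: {:.4f} | Q_MAX: {:.4f}".format(
            e, n_epochs - 1, win, loss / frame, Q_max / frame))

    # save the trained model after all epochs
    agent.save_model()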
Code Example #5
File: test.py  Project: yukiB/keras-dqn-test
    plt.axis("off")
    return img,


if __name__ == "__main__":
    # args
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model_path")
    parser.add_argument("-s", "--save", dest="save", action="store_true")
    parser.add_argument("--simple", dest="is_simple", action="store_true", default=False,
                        help='Test simple model without cnn (8 x 8) (default: off)')
    parser.set_defaults(save=False)
    args = parser.parse_args()

    # environment, agent
    env = CatchBall(time_limit=False, simple=args.is_simple)
    agent = DQNAgent(env.enable_actions, env.name)
    agent.load_model(args.model_path, simple=args.is_simple)

    # variables
    n_catched = 0
    state_t_1, reward_t, terminal = env.observe()
    S = deque(maxlen=state_num)

    # animate
    fig = plt.figure(figsize=(env.screen_n_rows / 2, env.screen_n_cols / 2))
    fig.canvas.set_window_title("{}-{}".format(env.name, agent.name))
    img = plt.imshow(state_t_1, interpolation="none", cmap="gray")
    ani = animation.FuncAnimation(fig, animate, init_func=init, interval=(1000 / env.frame_rate), blit=True)

    if args.save:
Code Example #6
File: test.py  Project: syybata/tf_dqn_fx
    # animate
    img.set_array(state_t_1)
    plt.axis("off")
    return img,


if __name__ == "__main__":
    # args
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model_path")
    parser.add_argument("-s", "--save", dest="save", action="store_true")
    parser.set_defaults(save=False)
    args = parser.parse_args()

    # environment, agent
    env = CatchBall()
    agent = DQNAgent(env.enable_actions, env.name)
    agent.load_model(args.model_path)

    # variables
    win, lose = 0, 0
    state_t_1, reward_t, terminal = env.observe()
    balance = 0

    # animate
    fig = plt.figure(figsize=(env.screen_n_rows / 2, env.screen_n_cols / 2))
    fig.canvas.set_window_title("{}-{}".format(env.name, agent.name))
    img = plt.imshow(state_t_1, interpolation="none", cmap="gray")
    ani = animation.FuncAnimation(fig,
                                  animate,
                                  init_func=init,
Code Example #7
File: test.py  Project: knnfm/tf-othello-rd
from __future__ import division

import argparse
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

from catch_ball import CatchBall
from dqn_agent import DQNAgent
from debug_log import DebugLog

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model_path")
    args = parser.parse_args()

    env = CatchBall()
    agent = DQNAgent(env.enable_actions, env.name)
    agent.load_model(args.model_path)

    env.reset_board_status()
    env.set_test_game()

    # start of processing within a single game
    while env.is_playable():
        env.print_board()
        env.test_player_play()

        state = env.observe()

        while True:
            action = agent.select_action(state, 0.0)
Code Example #8
import numpy as np

from catch_ball import CatchBall
from dqn_agent import DQNAgent

n_rows = 16
n_cols = 32
n_playerlength = 3


if __name__ == "__main__":
    # parameters
    n_epochs = 10000

    # environment, agent
    env = CatchBall(n_rows, n_cols, n_playerlength)
    agent = DQNAgent(env.enable_actions, env.name, env.screen_n_rows, env.screen_n_cols, env.player_length)

    # variables
    win = 0

    for e in range(n_epochs):
        # reset
        frame = 0
        loss = 0.0
        Q_max = 0.0
        env.reset(keepPos=True)
        state_t_1, reward_t, terminal = env.observe()

        while not terminal:
            state_t = state_t_1
Code Example #9
from catch_ball import CatchBall

if __name__ == '__main__':
    cb = CatchBall()
    cb.run()
Code Example #10
    log.append(env.get_hand_name(challenger_hand))

    if reward == 0:
        print "AI EVEN " + "".join(log)
    elif reward == 1:
        print "AI WIN " + "".join(log)
    else:
        print "AI LOSE " + "".join(log)


if __name__ == "__main__":
    # args
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model_path")
    parser.add_argument("-a", "--action")
    args = parser.parse_args()

    # environment, agent
    env = CatchBall()
    agent = DQNAgent(env.enable_actions, env.name)
    agent.load_model(args.model_path)
    env.set_card(args.action)
    state_t, reward_t = env.observe()

    # variables
    action_t = agent.select_action(state_t, 0.0)
    env.execute_action(action_t)
    state_t, reward_t = env.observe()

    print_result(env, int(action_t), env.get_hand_number(state_t), reward_t)
Code Example #11
File: test.py  Project: knnfm/tensorflow-sudoku
    log = []
    log.append("AI:")
    log.append(env.get_hand_name(ai_hand))
    log.append(" ")
    log.append("Challenger:")
    log.append(env.get_hand_name(challenger_hand))

    if reward == 0:
        print "AI EVEN " + "".join(log)
    elif reward == 1:
        print "AI WIN " + "".join(log)
    else:
        print "AI LOSE " + "".join(log)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model_path")
    args = parser.parse_args()

    env = CatchBall()
    agent = DQNAgent(env.enable_actions, env.name)
    agent.load_model(args.model_path)

    env.reset()
    env.set_question()
    state_t_1, reward_t, terminal = env.observe()
    action_t = agent.select_action(state_t_1, 0.0)
    print(state_t_1)
    print(action_t)
Code Example #12
File: train.py  Project: Tonyan/tf-dqn-simple
import numpy as np

from catch_ball import CatchBall
from dqn_agent import DQNAgent


if __name__ == "__main__":
    # parameters
    n_epochs = 1000

    # environment, agent
    env = CatchBall()
    agent = DQNAgent(env.enable_actions, env.name)

    # variables
    win = 0

    for e in range(n_epochs):
        # reset
        frame = 0
        loss = 0.0
        Q_max = 0.0
        env.reset()
        state_t_1, reward_t, terminal = env.observe()

        while not terminal:
            state_t = state_t_1

            # execute action in environment
            action_t = agent.select_action(state_t, agent.exploration)
            env.execute_action(action_t)
Code Example #13
File: train.py  Project: knnfm/tf-othello-rd
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import copy
import numpy as np

from catch_ball import CatchBall
from dqn_agent import DQNAgent

if __name__ == "__main__":
    # number of training iterations
    n_epochs = 0
    n_game = 100

    env = CatchBall()
    agent = DQNAgent(env.enable_actions, env.name)
    total_result_log = ""

    for e in range(n_game):
        # loop start point
        frame = 0
        win = 0
        loss = 0.0
        Q_max = 0.0
        env.reset_board_status()
        env.set_new_game()
        state_after = env.observe()

        # start of processing within a single game
        while env.is_playable():
            # print "*********************************************************************************"