Example #1
# Module-level imports assumed by this excerpt (module names follow Example #2).
# `latest` is a project helper that returns the newest saved weights file and
# the frame counter it was saved at (not shown here).
import theano

import teacher as q
import ale_game as ag
import dqn


def main(**kargs):
    initial_weights_file, initial_i_frame = latest(kargs['weights_dir'])

    print("Continuing using weights from file: ", initial_weights_file, "from", initial_i_frame)

    if kargs['theano_verbose']:
        theano.config.compute_test_value = 'warn'
        theano.config.exception_verbosity = 'high'
        theano.config.optimizer = 'fast_compile'

    ale = ag.init(display_screen=(kargs['visualize'] == 'ale'), record_dir=kargs['record_dir'])
    game = ag.SpaceInvadersGame(ale)


    def new_game():
        game.ale.reset_game()
        game.finished = False
        game.cum_reward = 0
        game.lives = 4
        return game

    replay_memory = dqn.ReplayMemory(size=kargs['dqn.replay_memory_size']) if not kargs['dqn.no_replay'] else None
    # dqn_algo = q.ConstAlgo([3])
    dqn_algo = dqn.DQNAlgo(game.n_actions(),
                           replay_memory=replay_memory,
                           initial_weights_file=initial_weights_file,
                           build_network=kargs['dqn.network'],
                           updates=kargs['dqn.updates'])

    dqn_algo.replay_start_size = kargs['dqn.replay_start_size']
    dqn_algo.final_epsilon = kargs['dqn.final_epsilon']
    dqn_algo.initial_epsilon = kargs['dqn.initial_epsilon']
    dqn_algo.i_frames = initial_i_frame

    dqn_algo.log_frequency = kargs['dqn.log_frequency']


    import Queue
    dqn_algo.mood_q = Queue.Queue() if kargs['show_mood'] else None

    if kargs['show_mood'] is not None:
        plot = kargs['show_mood']()

        def worker():
            while True:
                item = dqn_algo.mood_q.get()
                plot.show(item)
                dqn_algo.mood_q.task_done()

        import threading
        t = threading.Thread(target=worker)
        t.daemon = True
        t.start()

    print(str(dqn_algo))

    visualizer = ag.SpaceInvadersGameCombined2Visualizer() if kargs['visualize'] == 'q' else q.GameNoVisualizer()
    teacher = q.Teacher(new_game, dqn_algo, visualizer,
                        ag.Phi(skip_every=4), repeat_action=4, sleep_seconds=0)
    teacher.teach(500000)
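
The `latest` helper called at the top of main is not shown in this example. A minimal sketch, assuming checkpoints are written to weights_dir as numbered files; the naming scheme and return convention here are assumptions, not the project's actual code:

# Hypothetical sketch of latest(): pick the checkpoint whose file name embeds
# the highest frame counter, or (None, 0) if the directory holds none.
import os
import re

def latest(weights_dir):
    best_file, best_i = None, 0
    for name in os.listdir(weights_dir):
        m = re.search(r'(\d+)', name)
        if m and int(m.group(1)) >= best_i:
            best_i = int(m.group(1))
            best_file = os.path.join(weights_dir, name)
    return best_file, best_i
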
Example #2
def const_on_space_invaders():
    import teacher as q
    import ale_game as ag
    import dqn
    reload(q)
    reload(ag)
    reload(dqn)

    ale = ag.init()
    game = ag.SpaceInvadersGame(ale)

    def new_game():
        game.ale.reset_game()
        game.finished = False
        game.cum_reward = 0
        return game

    const_algo = q.ConstAlgo([2, 2, 2, 2, 2, 0, 0, 0, 0])
    teacher = q.Teacher(new_game, const_algo, ag.SpaceInvadersGameCombined2Visualizer(),
                        ag.Phi(skip_every=6), repeat_action=6)
    teacher.teach(1)
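
q.ConstAlgo is not shown either; judging from its name and the action list passed to it, it presumably plays back a fixed action sequence regardless of the game state. A minimal sketch of such a constant policy, assuming the Teacher asks the algorithm for an action each step and reports feedback it can ignore (the method names here are assumptions):

# Hypothetical sketch of a constant-action policy in the spirit of q.ConstAlgo.
import itertools

class ConstAlgoSketch(object):
    def __init__(self, actions):
        self.actions = itertools.cycle(actions)  # repeat the fixed sequence forever

    def action(self, state):
        return next(self.actions)

    def feedback(self, exploring, action, reward, terminal, next_state):
        pass  # a constant policy does not learn
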
Example #3
# This excerpt assumes a Lasagne network `n`, a Theano input variable `s0_var`,
# a saved-weights path `weights_file`, and a frame directory `directory` are
# already defined earlier in the script.
import os

import numpy as np
import theano
import lasagne
from PIL import Image

import ale_game as ag

# Symbolic outputs of every layer in the network, for inspection.
outs = [
    lasagne.layers.get_output(layer, s0_var)
    for layer in lasagne.layers.get_all_layers(n)
]

with np.load(weights_file) as initial_weights:
    param_values = [
        initial_weights['arr_%d' % i]
        for i in range(len(initial_weights.files))
    ]
    lasagne.layers.set_all_param_values(n, param_values)

ff = theano.function([s0_var], outs)

import matplotlib.pyplot as plt

# Build network inputs from consecutive recorded frames: convert each RGB frame
# to grayscale luminance, preprocess with Phi, and stack the result into the
# (1, 4, 84, 84) tensor the network expects.
inputs = []
for frame in range(1200, 1400, 20):
    files = [
        os.path.join(directory, "%06d.png" % i)
        for i in range(frame, frame + 16)
    ]
    frames = [np.array(Image.open(f)) for f in files]
    gray_frames = [(np.dot(f, np.array([0.2126, 0.7152,
                                        0.0722]))).astype(np.float32)
                   for f in frames]

    phi = ag.Phi(method="resize")
    inputs.append(np.stack(phi(gray_frames), axis=0).reshape(1, 4, 84, 84))

output = plt.imshow(ff(inputs[0])[1][0][0], cmap='Greys_r')
print(output)
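
The final line above displays a single feature map (layer 1, batch element 0, map 0). A short follow-up sketch that shows several feature maps of the same layer side by side; the layer index and the number of maps plotted are arbitrary choices, not taken from the original code:

# Hypothetical follow-up: compare the first four feature maps of one layer.
activations = ff(inputs[0])   # one activation array per network layer
conv1 = activations[1][0]     # feature maps of layer 1 for the single input

fig, axes = plt.subplots(1, 4, figsize=(12, 3))
for i, ax in enumerate(axes):
    ax.imshow(conv1[i], cmap='Greys_r')
    ax.set_title('map %d' % i)
    ax.axis('off')
plt.show()
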
Example #4
# Module-level imports assumed by this excerpt (module names follow Example #2;
# the simple_breakout module name is inferred from its use below). `latest` is
# the project helper that returns the newest weights file and its action counter.
import theano

import teacher as q
import ale_game as ag
import dqn
import simple_breakout


def main(**kargs):
    initial_weights_file, i_total_action = latest(kargs['weights_dir'])

    print("Continuing using weights from file: ", initial_weights_file, "from",
          i_total_action)

    if kargs['theano_verbose']:
        theano.config.compute_test_value = 'warn'
        theano.config.exception_verbosity = 'high'
        theano.config.optimizer = 'fast_compile'

    if kargs['game'] == 'simple_breakout':
        game = simple_breakout.SimpleBreakout()

        class P(object):
            def __init__(self):
                self.screen_size = 12

            def __call__(self, frames):
                return frames

        phi = P()
    else:
        ale = ag.init(game=kargs['game'],
                      display_screen=(kargs['visualize'] == 'ale'),
                      record_dir=kargs['record_dir'])
        game = ag.ALEGame(ale)
        phi = ag.Phi(method=kargs["phi_method"])

    replay_memory = (dqn.ReplayMemory(size=kargs['dqn.replay_memory_size'])
                     if not kargs['dqn.no_replay'] else None)
    algo = dqn.DQNAlgo(game.n_actions(),
                       replay_memory=replay_memory,
                       initial_weights_file=initial_weights_file,
                       build_network=kargs['dqn.network'],
                       updates=kargs['dqn.updates'],
                       screen_size=phi.screen_size)

    algo.replay_start_size = kargs['dqn.replay_start_size']
    algo.final_epsilon = kargs['dqn.final_epsilon']
    algo.initial_epsilon = kargs['dqn.initial_epsilon']
    algo.i_action = i_total_action

    algo.log_frequency = kargs['dqn.log_frequency']
    algo.target_network_update_frequency = kargs[
        'target_network_update_frequency']
    algo.final_exploration_frame = kargs['final_exploration_frame']

    import Queue
    algo.mood_q = Queue.Queue() if kargs['show_mood'] else None

    if kargs['show_mood'] is not None:
        plot = kargs['show_mood']()

        def worker():
            while True:
                item = algo.mood_q.get()
                plot.show(item)
                algo.mood_q.task_done()

        import threading
        t = threading.Thread(target=worker)
        t.daemon = True
        t.start()

    print(str(algo))

    if kargs['visualize'] != 'q':
        visualizer = q.GameNoVisualizer()
    else:
        if kargs['game'] == 'simple_breakout':
            visualizer = simple_breakout.SimpleBreakoutVisualizer(algo)
        else:
            visualizer = ag.ALEGameVisualizer(phi.screen_size)

    teacher = q.Teacher(
        game=game,
        algo=algo,
        game_visualizer=visualizer,
        phi=phi,
        repeat_action=kargs['repeat_action'],
        i_total_action=i_total_action,
        total_n_actions=50000000,
        max_actions_per_game=10000,
        skip_n_frames_after_lol=kargs['skip_n_frames_after_lol'],
        run_test_every_n=kargs['run_test_every_n'])
    teacher.teach()
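
The object created from kargs['show_mood'] only needs a show(item) method; the worker thread drains mood_q and hands each diagnostic item to it. A minimal console-logging consumer, assuming each item is a dict of diagnostics (the item structure is an assumption):

# Hypothetical show_mood consumer; pass the class itself, since the example
# instantiates it with kargs['show_mood']().
class LogMood(object):
    def show(self, item):
        print(", ".join("%s=%s" % (k, v) for k, v in sorted(item.items())))
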
Example #5
# Module-level imports assumed by this excerpt: theano, lasagne, the project
# modules (teacher as q, ale_game as ag, dqn, simple_breakout, network, an
# updates module aliased as u), and the Plot/Log mood visualizers.
def main(game_name, network_type, updates_method,
         target_network_update_frequency, initial_epsilon, final_epsilon,
         test_epsilon, final_exploration_frame, replay_start_size,
         deepmind_rmsprop_epsilon, deepmind_rmsprop_learning_rate,
         deepmind_rmsprop_rho, rmsprop_epsilon, rmsprop_learning_rate,
         rmsprop_rho, phi_type, phi_method, epoch_size, n_training_epochs,
         n_test_epochs, visualize, record_dir, show_mood, replay_memory_size,
         no_replay, repeat_action, skip_n_frames_after_lol,
         max_actions_per_game, weights_dir, algo_initial_state_file,
         log_frequency, theano_verbose):
    args = locals()

    if theano_verbose:
        theano.config.compute_test_value = 'warn'
        theano.config.exception_verbosity = 'high'
        theano.config.optimizer = 'fast_compile'

    if game_name == 'simple_breakout':
        game = simple_breakout.SimpleBreakout()

        class P(object):
            def __init__(self):
                self.screen_size = (12, 12)

            def __call__(self, frames):
                return frames

        phi = P()
    else:
        ale = ag.init(game=game_name,
                      display_screen=(visualize == 'ale'),
                      record_dir=record_dir)
        game = ag.ALEGame(ale)
        if phi_type == '4':
            phi = ag.Phi4(method=phi_method)
        elif phi_type == '1':
            phi = ag.Phi(method=phi_method)
        else:
            raise RuntimeError("Unknown phi: {phi}".format(phi=phi_type))

    if network_type == 'nature':
        build_network = network.build_nature
    elif network_type == 'nature_with_pad':
        build_network = network.build_nature_with_pad
    elif network_type == 'nips':
        build_network = network.build_nips
    elif network_type == 'nature_with_pad_he':
        build_network = network.build_nature_with_pad_he
    elif hasattr(network_type, '__call__'):
        build_network = network_type
    else:
        raise RuntimeError(
            "Unknown network: {network}".format(network=network_type))

    if updates_method == 'deepmind_rmsprop':
        updates = lambda loss, params: u.deepmind_rmsprop(
            loss, params,
            learning_rate=deepmind_rmsprop_learning_rate,
            rho=deepmind_rmsprop_rho,
            epsilon=deepmind_rmsprop_epsilon)
    elif updates_method == 'rmsprop':
        updates = lambda loss, params: lasagne.updates.rmsprop(
            loss, params,
            learning_rate=rmsprop_learning_rate,
            rho=rmsprop_rho,
            epsilon=rmsprop_epsilon)
    else:
        raise RuntimeError(
            "Unknown updates: {updates}".format(updates=updates_method))

    replay_memory = dqn.ReplayMemory(
        size=replay_memory_size) if not no_replay else None

    def create_algo():
        algo = dqn.DQNAlgo(game.n_actions(),
                           replay_memory=replay_memory,
                           build_network=build_network,
                           updates=updates,
                           screen_size=phi.screen_size)

        algo.replay_start_size = replay_start_size
        algo.final_epsilon = final_epsilon
        algo.initial_epsilon = initial_epsilon

        algo.log_frequency = log_frequency
        algo.target_network_update_frequency = target_network_update_frequency
        algo.final_exploration_frame = final_exploration_frame
        return algo

    algo_train = create_algo()
    algo_test = create_algo()
    algo_test.final_epsilon = test_epsilon
    algo_test.initial_epsilon = test_epsilon
    algo_test.epsilon = test_epsilon

    import Queue
    algo_train.mood_q = Queue.Queue() if show_mood else None

    if show_mood is not None:
        if show_mood == 'plot':
            plot = Plot()
        elif show_mood == 'log':
            plot = Log()
        else:
            raise RuntimeError(
                "Unknown show_mood: {show_mood}".format(show_mood=show_mood))

        def worker():
            while True:
                item = algo_train.mood_q.get()
                plot.show(item)
                algo_train.mood_q.task_done()

        import threading
        t = threading.Thread(target=worker)
        t.daemon = True
        t.start()

    print(str(algo_train))

    if visualize != 'q':
        visualizer = q.GameNoVisualizer()
    else:
        if game_name == 'simple_breakout':
            visualizer = simple_breakout.SimpleBreakoutVisualizer(algo_train)
        else:
            visualizer = ag.ALEGameVisualizer(phi.screen_size)

    teacher = q.Teacher(game=game,
                        algo=algo_train,
                        game_visualizer=visualizer,
                        phi=phi,
                        repeat_action=repeat_action,
                        max_actions_per_game=max_actions_per_game,
                        skip_n_frames_after_lol=skip_n_frames_after_lol,
                        tester=False)

    tester = q.Teacher(game=game,
                       algo=algo_test,
                       game_visualizer=visualizer,
                       phi=phi,
                       repeat_action=repeat_action,
                       max_actions_per_game=max_actions_per_game,
                       skip_n_frames_after_lol=skip_n_frames_after_lol,
                       tester=True)

    q.teach_and_test(teacher,
                     tester,
                     n_epochs=n_training_epochs,
                     frames_to_test_on=n_test_epochs * epoch_size,
                     epoch_size=epoch_size,
                     state_dir=weights_dir,
                     algo_initial_state_file=algo_initial_state_file)
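
DQNAlgo's internals are not shown, but the parameters above (initial_epsilon, final_epsilon, final_exploration_frame) suggest the usual linearly annealed exploration schedule. A sketch of that schedule, under the assumption that the annealing is linear; the default values below are illustrative, not the project's:

# Hypothetical epsilon schedule: anneal linearly from initial_epsilon to
# final_epsilon over the first final_exploration_frame actions, then hold.
def epsilon_at(i_action, initial_epsilon=1.0, final_epsilon=0.1,
               final_exploration_frame=1000000):
    if i_action >= final_exploration_frame:
        return final_epsilon
    fraction = float(i_action) / final_exploration_frame
    return initial_epsilon + fraction * (final_epsilon - initial_epsilon)
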