Example #1
def training_main():

    # reproducibility: fix the random seeds
    torch.manual_seed(123)
    torch.cuda.manual_seed(123)
    random.seed(0)
    np.random.seed(0)
    # --- 1. Environment --- #
    game_state = FlappyBird()
    # --- 2. training --- #
    agent_dqn = AgentDQN()
    # TensorBoard logging
    td = SummaryWriter()
    episode = 0       # episode counter
    total_step = 0
    while episode < MAX_EPISODE:
        episode += 1
        # start a new episode: take a no-op action to get the initial frame
        image, reward, terminal = game_state.next_frame(0)
        image = pre_processing(image[:game_state.screen_width, :int(game_state.base_y)], IMAGE_SIZE, IMAGE_SIZE)
        # 288*512*3 --> 1*84*84
        image = torch.from_numpy(image).to(device=DEVICE)
        # 1*4*84*84   batch*frame*width*height
        state = torch.cat(tuple(image for _ in range(4)))[None, :, :, :]
        epsilon = update_epsilon(episode, MAX_EPISODE)
        episode_reward = 0
        episode_step = 0
        terminal = False
        start_time = time.time()
        while not terminal:
            action = agent_dqn.choose_action(state, epsilon)
            # interact with the environment
            next_image, reward, terminal = game_state.next_frame(action)
            episode_reward += reward
            episode_step += 1
            total_step += 1
            next_image = pre_processing(next_image[:game_state.screen_width, :int(game_state.base_y)], IMAGE_SIZE, IMAGE_SIZE)
            next_image = torch.from_numpy(next_image).to(device=DEVICE)
            # slide the frame stack: drop the oldest frame, append the newest one
            next_state = torch.cat((state[0, 1:, :, :], next_image))[None, :, :, :]
            agent_dqn.replay_buffer.add(state, action, reward, next_state, terminal)
            if len(agent_dqn.replay_buffer) < 2 * BATCH_SIZE:
                # warm-up: keep collecting experience but do not train yet,
                # and do not count these steps in the statistics
                episode = 1
                episode_reward = 0
                episode_step = 0
                total_step = 0
                state = next_state
                continue
            # --- update the model weight --- #
            loss = agent_dqn.update_model(total_step)
            # update the states
            state = next_state
            if terminal:
                print("Iteration: {}/{}, Epsilon {}, Reward: {}, Step: {}, time: {}"
                      .format(iter, MAX_EPISODE, round(epsilon,3), round(episode_reward,2), episode_step, round(time.time()-start_time, 2)))
                td.add_scalar('Train/Reward', episode_reward, iter)
                td.add_scalar('Train/Step', episode_step, iter)
                if iter % 5000 == 0:
                    torch.save(agent_dqn.q_model, "{}/flappy_bird_{}".format('./model', iter))
    torch.save(agent_dqn.q_model, "{}/flappy_bird".format('./model'))
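
The update_epsilon helper used above is not shown. A minimal sketch, assuming a simple linear decay from 1.0 down to 0.01 over the first 10% of episodes (the constants and schedule are illustrative, not taken from the original project):

EPSILON_START = 1.0   # assumed starting exploration rate
EPSILON_END = 0.01    # assumed final exploration rate

def update_epsilon(episode, max_episode):
    # decay linearly over the first 10% of episodes, then stay at EPSILON_END
    decay_episodes = max(1, int(0.1 * max_episode))
    fraction = min(1.0, episode / decay_episodes)
    return EPSILON_START + fraction * (EPSILON_END - EPSILON_START)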
Example #2
def main():
    clock = pygame.time.Clock()

    N_EPOCHS = 1000
    GAMMA = 0.99
    N_BIRD = 64
    S_BATCH = 256

    env = FlappyBird(N_BIRD)

    main_model = Model()
    target_model = Model()

    memory = Memory()
    agent = Agent()

    for epoch in range(1, N_EPOCHS + 1):
        print('Epoch: {}'.format(epoch))

        env.reset()
        states, rewards, finished = env.random_step()
        target_model.model.set_weights(main_model.model.get_weights())

        running = True
        while running:
            clock.tick(60)

            actions = []
            for state in states:
                actions.append(agent.get_action(state, epoch, main_model))

            next_states, rewards, finished = env.step(actions)
            for state, reward, action, next_state in zip(
                    states, rewards, actions, next_states):
                memory.add((state, action, reward, next_state))

            states = next_states

            # train on a batch once enough new transitions have accumulated
            if len(memory.buffer) % S_BATCH == 0:
                main_model.replay(memory, env.n_bird, GAMMA, target_model)

            # keep the target network in sync with the main network
            target_model.model.set_weights(main_model.model.get_weights())

            if not len(env.birds):
                running = False
                break

            env.draw()

            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False

        print('\tScore: {}'.format(env.score))

    pygame.quit()

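
Model.replay is not shown above. A rough sketch of what such a method could look like, assuming a Keras model under self.model, an assumed Memory.sample helper, and one-hot actions as in the other examples; terminal handling is omitted because the stored tuples carry no done flag (illustrative only, not the original implementation):

import numpy as np

def replay(self, memory, batch_size, gamma, target_model):
    # move Q(s, a) towards r + gamma * max_a' Q_target(s', a') on a sampled batch
    batch = memory.sample(batch_size)                  # assumed sampling helper
    states = np.array([b[0] for b in batch])
    next_states = np.array([b[3] for b in batch])
    q_values = self.model.predict(states, verbose=0)
    next_q = target_model.model.predict(next_states, verbose=0)
    for i, (_, action, reward, _) in enumerate(batch):
        a = int(np.argmax(action))                     # one-hot action -> index
        q_values[i, a] = reward + gamma * np.max(next_q[i])
    self.model.fit(states, q_values, epochs=1, verbose=0)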
Example #3
def main():
    clock = pygame.time.Clock()

    game = FlappyBird(1)

    running = True
    while running:
        clock.tick(60)

        actions = []
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            if event.type == pygame.KEYDOWN:
                # any key press makes the bird flap
                actions.append(array([0, 1]))

        if len(actions) == 0:
            # default action: do nothing
            actions.append(array([1, 0]))

        game.step(actions)

        if not len(game.birds):
            running = False
            break

        game.draw()

    print('\tScore: {}'.format(game.score))
    
    pygame.quit()
Example #4
def main():
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.connect((HOST, PORT))
    game = FlappyBird()
    while True:
        data = {
            'state': game.s_t,
            'reward': game.reward,
            'terminal': game.terminal
        }
        data_str = cPickle.dumps(data)
        send_msg(sock, data_str)
        # the server replies with a single-digit action id
        action_id = int(sock.recv(1))
        game.process(action_id)
        game.update()
        if data['terminal']:
            game.reset()
    return
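
The send_msg helper is not defined in this snippet. A common pattern is a length-prefixed frame so the receiver knows how many bytes make up one pickled message; a minimal sketch (the original helper may differ):

import struct

def send_msg(sock, data):
    # prefix the payload with its 4-byte big-endian length, then send everything
    sock.sendall(struct.pack('>I', len(data)) + data)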
Example #5
import sys
sys.path.append("/home/fedmag/Projects/Q-bird/src/game/")
sys.path.append("/home/fedmag/Projects/Q-bird/src/ai/")

# from agent import Agent
from flappy_bird import FlappyBird

# agent = Agent()

game = FlappyBird()
game.run()

print(sys.path)
Example #6
from model import model
from qlearning4k import Agent
from flappy_bird import FlappyBird

game = FlappyBird(frame_rate=10000, sounds=False)
agent = Agent(model, memory_size=100000)
agent.train(game,
            epsilon=[0.01, 0.00001],
            epsilon_rate=0.3,
            gamma=0.99,
            nb_epoch=1000000,
            batch_size=32,
            checkpoint=250)
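
The model imported from model.py is not shown. qlearning4k expects a compiled Keras model whose output size matches the number of actions; a minimal sketch, assuming four stacked 84x84 grayscale frames (channels-last) and two actions, which may not match the real architecture or the input shape the game provides:

from keras.models import Sequential
from keras.layers import Conv2D, Flatten, Dense

# assumed architecture; model.py in the original project may differ
model = Sequential([
    Conv2D(32, (8, 8), strides=4, activation='relu', input_shape=(84, 84, 4)),
    Conv2D(64, (4, 4), strides=2, activation='relu'),
    Flatten(),
    Dense(256, activation='relu'),
    Dense(2),   # one Q-value per action (flap / do nothing)
])
model.compile(optimizer='rmsprop', loss='mse')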
Example #7
from model import model
from qlearning4k import Agent
from flappy_bird import FlappyBird

game = FlappyBird(frame_rate=30, sounds=True)
model.load_weights('weights.dat')
agent = Agent(model)
agent.play(game, nb_epoch=100, epsilon=0.01, visualize=False)
Example #8
import neat
import pygame

from flappy_bird import FlappyBird

CFG_PATH = 'neat_config.txt'
NUM_BIRD = 50

env = FlappyBird(NUM_BIRD)


def gen(genomes, config):
    clock = pygame.time.Clock()
    env.reset()

    nets = []
    ge = []
    for _, g in genomes:
        # one feed-forward network and one fitness counter per genome/bird
        nets.append(neat.nn.FeedForwardNetwork.create(g, config))
        g.fitness = 0
        ge.append(g)

    while len(env.birds) > 0:
        clock.tick(60)
        for pipe in env.pipes:
            if pipe.move():
                env.score += 1

                for g in ge:
                    g.fitness += 3
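
The snippet ends before the driver code that runs this fitness function. With neat-python, the standard setup would look roughly like the following (the original harness is not shown, so this is an assumption):

def run():
    config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
                         neat.DefaultSpeciesSet, neat.DefaultStagnation,
                         CFG_PATH)
    population = neat.Population(config)
    population.add_reporter(neat.StdOutReporter(True))
    # evolve for up to 50 generations, calling gen() once per generation
    population.run(gen, 50)


if __name__ == '__main__':
    run()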