Example #1
# np (numpy), DQN, FlappyBird and ACTIONS_DIM are assumed to be
# imported/defined elsewhere in the project this example comes from.
def main():
    '''
    Training loop for the DQN agent.
    '''
    agent = DQN()
    env = FlappyBird()
    env.reset()

    s_t = env.s_t

    while True:
        action_id, action_q = agent.epsilon_greedy(s_t)
        env.process(action_id)

        # one-hot encode the chosen action
        action = np.zeros(ACTIONS_DIM)
        action[action_id] = 1

        s_t1, reward, terminal = (env.s_t1, env.reward, env.terminal)
        agent.perceive(s_t, action, reward, s_t1, terminal)

        if agent.global_t % 10 == 0:
            print 'global_t:', agent.global_t, '/ epsilon:', agent.epsilon, '/ terminal:', terminal, \
                '/ action:', action_id, '/ reward:', reward, '/ q_value:', action_q

        if terminal:
            env.reset()
        s_t = s_t1
        # env.update()  # disabled: calling it here makes the Q values go NaN
        # break
    return  # unreachable; the loop above never exits
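agent.epsilon_greedy above returns both the chosen action id and its Q value.
The DQN class itself is not shown; as a minimal, illustrative sketch (not the
project's actual implementation), epsilon-greedy selection over a vector of
Q values looks like this:

import random

import numpy as np

def epsilon_greedy(q_values, epsilon):
    # with probability epsilon explore, otherwise exploit the best action;
    # returns (action_id, action_q) just like the loop above expects
    if random.random() < epsilon:
        action_id = random.randrange(len(q_values))
    else:
        action_id = int(np.argmax(q_values))
    return action_id, q_values[action_id]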
Example #2
# As in Example #1, np, DQN, FlappyBird, ACTIONS_DIM and MAX_TIME_STEP are
# assumed to come from the surrounding project.
def main():
    '''
    Training loop for the DQN agent, capped at MAX_TIME_STEP steps.
    '''
    agent = DQN()
    game = FlappyBird()
    game.reset()

    s_t = game.s_t

    while agent.global_t < MAX_TIME_STEP:
        action_id, action_q = agent.epsilon_greedy(s_t)
        game.process(action_id)
        # one-hot encode the chosen action
        action = np.zeros(ACTIONS_DIM)
        action[action_id] = 1
        s_t1, reward, terminal = (game.s_t1, game.reward, game.terminal)
        agent.perceive(s_t, action, reward, s_t1, terminal)

        if agent.global_t % 10 == 0:
            print 'global_t:', agent.global_t, '/ epsilon:', agent.epsilon, '/ terminal:', game.terminal, \
                '/ action:', action_id, '/ reward:', game.reward, '/ q_value:', action_q

        if game.terminal:
            game.reset()
        # s_t <- s_t1
        s_t = s_t1
        # game.update()

    return
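perceive is where the standard DQN bookkeeping happens: the transition is
stored and, once enough transitions have accumulated, a training step runs on
a random minibatch. The class is not shown here; the following is only a
hedged sketch of the transition store such a method typically wraps (capacity
and batch size are made-up defaults, not the project's values):

import random
from collections import deque

class ReplayMemory(object):
    # fixed-size FIFO store of (s_t, action, reward, s_t1, terminal) tuples
    def __init__(self, capacity=50000):
        self.buffer = deque(maxlen=capacity)

    def add(self, s_t, action, reward, s_t1, terminal):
        self.buffer.append((s_t, action, reward, s_t1, terminal))

    def sample(self, batch_size=32):
        # uniform random sampling de-correlates consecutive updates
        return random.sample(self.buffer, batch_size)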
Example #3
# DRQN, FlappyBird, ACTIONS_DIM, LSTM_UNITS and LSTM_MAX_STEP are assumed to
# come from the surrounding project; np is numpy.
def main():
    '''
    Training loop for the DRQN agent; the LSTM state is threaded through
    every step of an episode and reset when a new episode starts.
    '''
    agent = DRQN()
    env = FlappyBird()

    while True:
        env.reset()
        episode_buffer = []
        lstm_state = (np.zeros([1, LSTM_UNITS]), np.zeros([1, LSTM_UNITS]))
        s_t = env.s_t
        while not env.terminal:
            # action_id = random.randint(0, 1)
            action_id, action_q, lstm_state = agent.epsilon_greedy(
                s_t, lstm_state)
            env.process(action_id)

            # one-hot encode the chosen action
            action = np.zeros(ACTIONS_DIM)
            action[action_id] = 1
            s_t1, reward, terminal = (env.s_t1, env.reward, env.terminal)
            # frame skip
            episode_buffer.append((s_t, action, reward, s_t1, terminal))
            agent.perceive(s_t, action, reward, s_t1, terminal)
            if agent.global_t % 10 == 0:
                print 'global_t:', agent.global_t, '/ epsilon:', agent.epsilon, '/ terminal:', terminal, \
                    '/ action:', action_id, '/ reward:', reward, '/ q_value:', action_q

            # s_t <- s_t1
            s_t = s_t1
            if len(episode_buffer) >= 50:
                # flush to the replay buffer and start a fresh episode buffer,
                # so a single over-long episode cannot grow without bound
                agent.replay_buffer.add(episode_buffer)
                episode_buffer = []
                print '----------- episode buffer >= 50 ---------'
        # keep the episode tail only if it is long enough to sample a full
        # LSTM trace from it
        if len(episode_buffer) > LSTM_MAX_STEP:
            agent.replay_buffer.add(episode_buffer)
        print 'episode_buffer', len(episode_buffer)
        print 'replay_buffer.size:', agent.replay_buffer.size()
        # break
    return
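The episode-level buffering above exists because a DRQN trains on contiguous
traces rather than on independent transitions: episodes are stored whole, and
the sampler cuts out fixed-length windows so the LSTM can be unrolled over
LSTM_MAX_STEP consecutive steps. The project's sampler is not shown; this is
only a sketch of that idea (function name and signature are invented):

import random

def sample_traces(episodes, batch_size, trace_length):
    # episodes: list of per-episode transition lists, as built above;
    # only episodes long enough for a full trace are eligible
    eligible = [ep for ep in episodes if len(ep) >= trace_length]
    traces = []
    for _ in range(batch_size):
        ep = random.choice(eligible)
        start = random.randint(0, len(ep) - trace_length)
        traces.append(ep[start:start + trace_length])
    return traces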
Example #4
import sys
import cPickle
import math

from game.flappy_bird import FlappyBird
# Python 2 stdlib names: cPickle and SocketServer became pickle and
# socketserver in Python 3
from SocketServer import BaseRequestHandler, UDPServer

flappy_bird = FlappyBird()


class UDPHandler(BaseRequestHandler):
    def handle(self):
        # each incoming UDP packet carries one pickled action
        action = self.request[0]
        action = cPickle.loads(action)
        socket = self.request[1]

        global flappy_bird
        x_t, reward, terminal = flappy_bird.frame_step(action)
        data = cPickle.dumps((x_t, reward, terminal))

        # split the pickled reply into blocks of at most 8192 bytes, so each
        # datagram stays within the UDP transfer size limit (MTU)
        buffer_size = 8192
        total_size = len(data)
        block_num = int(math.ceil(total_size / float(buffer_size)))

        # send a header first so the client knows how many blocks follow
        offset = 0
        header = {
            "buffer_size": buffer_size,
            "total_size": total_size,
            "block_num": block_num
        }
        # (the rest of the handler is truncated in the source excerpt)
Example #5
import argparse
import json
import random

from game.flappy_bird import FlappyBird
from qtable import QLearningTable

def get_args():
    parser = argparse.ArgumentParser("""Implementation of Q-Learning to play Flappy Bird""")
    parser.add_argument("--qtable",
                        type=str,
                        default="qtable/q_table.json",
                        help="Specify the Q table.")
    
    args = parser.parse_args()
    return args

args = get_args()
fb = FlappyBird()
qlt = QLearningTable(e_greedy=1)  # e_greedy=1: always act greedily (pure evaluation)

with open(args.qtable, "r") as f:
    qlt.q_table = json.load(f)
# take one random flap/no-flap action to obtain the initial observation
observation = fb.next_frame(random.choice([0, 1]))[-1]

while 1:
    qlt.check_state_exist(observation)
    action = qlt.choose_action(observation)
    reward, terminal, score, observation_ = fb.next_frame(action, text='')[1:]
    observation = observation_

    if terminal:
        # re-initialize the game in place and grab a fresh first observation
        fb.__init__()
        observation = fb.next_frame(random.choice([0, 1]))[-1]
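choose_action and check_state_exist live on QLearningTable, which is not
shown. As a minimal sketch of the usual dict-backed pattern (attribute names
here are illustrative, not necessarily the project's):

import random

class QLearningTable(object):
    def __init__(self, actions=(0, 1), e_greedy=1.0):
        self.actions = list(actions)
        self.e_greedy = e_greedy   # probability of picking the greedy action
        self.q_table = {}          # state key -> list of per-action Q values

    def check_state_exist(self, state):
        # lazily add unseen states with zero-initialised action values
        self.q_table.setdefault(state, [0.0] * len(self.actions))

    def choose_action(self, state):
        if random.random() < self.e_greedy:
            values = self.q_table[state]
            best = max(values)
            # break ties between equally good actions at random
            return random.choice(
                [a for a, v in zip(self.actions, values) if v == best])
        return random.choice(self.actions)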
Example #6
# Imports needed by this excerpt; Network, image_pre_processing and
# FlappyBird are defined elsewhere in the project.
from random import randint, sample

import numpy as np
import torch
import torch.nn as nn

def train(args):
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Network()
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.MSELoss()
    fb = FlappyBird()

    image, reward, terminal, score = fb.next_frame(randint(0, 1))[:-1]

    img = image_pre_processing(image, args.img_size, args.img_size)
    img = torch.from_numpy(img)
    img = img.to(device)  # .to() is not in-place; keep the returned tensor
    # stack four copies of the first frame to build the initial 4-frame state
    state = torch.cat(tuple(img for _ in range(4)))[None, :, :, :]

    replay_memory = []
    episode = 0

    while 1:
        prediction = model(state)[0]
        # note: epsilon is reset to its initial value on every step here and
        # never decays; see the annealing sketch after this example
        epsilon = args.initial_epsilon
        if np.random.uniform() > epsilon:
            action = torch.argmax(prediction).item()
        else:
            action = randint(0, 1)

        next_image, reward, terminal = fb.next_frame(action, '')[:3]

        if terminal:
            fb.__init__()
            next_image, reward, terminal = fb.next_frame(action, '')[:3]

        next_img = image_pre_processing(next_image, args.img_size,
                                        args.img_size)
        next_img = torch.from_numpy(next_img).to(device)
        # drop the oldest frame and append the new one
        next_state = torch.cat((state[0, 1:, :, :], next_img))[None, :, :, :]
        replay_memory.append([state, action, reward, next_state, terminal])
        if len(replay_memory) > args.replay_memory_size:
            del replay_memory[0]

        batch = sample(replay_memory, min(len(replay_memory), args.batch_size))
        state_batch, action_batch, reward_batch, next_state_batch, terminal_batch = zip(
            *batch)

        state_batch = torch.cat(tuple(state for state in state_batch))
        action_batch = torch.from_numpy(
            np.array([[1, 0] if action == 0 else [0, 1]
                      for action in action_batch],
                     dtype=np.float32))
        reward_batch = torch.from_numpy(
            np.array(reward_batch, dtype=np.float32)[:, None])
        next_state_batch = torch.cat(tuple(state
                                           for state in next_state_batch))

        # .to() returns a new tensor, so the results must be assigned back
        state_batch = state_batch.to(device)
        action_batch = action_batch.to(device)
        reward_batch = reward_batch.to(device)
        next_state_batch = next_state_batch.to(device)

        current_prediction_batch = model(state_batch)
        next_prediction_batch = model(next_state_batch)

        # Bellman targets: r for terminal transitions, r + gamma * max Q otherwise
        y_batch = torch.cat(tuple(
            reward if terminal else reward + args.gamma * torch.max(prediction)
            for reward, terminal, prediction in
            zip(reward_batch, terminal_batch, next_prediction_batch)))
        q_value = torch.sum(current_prediction_batch * action_batch, dim=1)
        optimizer.zero_grad()
        loss = criterion(q_value, y_batch)
        loss.backward()
        optimizer.step()

        state = next_state
        episode += 1
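Because epsilon is re-assigned from args.initial_epsilon on every iteration,
exploration never decays in this excerpt. A common remedy (a sketch only;
args.final_epsilon and args.num_iters are assumed hyperparameters, not
necessarily present in this project) is to anneal it linearly inside the loop:

        # linear epsilon annealing from initial_epsilon down to final_epsilon
        epsilon = args.final_epsilon + (
            max(args.num_iters - episode, 0)
            * (args.initial_epsilon - args.final_epsilon) / args.num_iters)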