Example #1
def main(environment, file_out, weight_file, action_value, f_duration, watch,
         save):
    use_CNN = True
    env = gym.make(environment)
    if use_CNN:
        state_size = (88, 80, 1)
    else:
        state_size = env.observation_space.shape[0]

    action_size = env.action_space.n

    # Stack group_size number of atari images
    group_size = 4

    # The following are hard-coded for now: the original image is scaled
    # by preprocessing down to (88, 80, 1) and we combine 4 of them to
    # get a batch of images. Note that the "1" argument below is the
    # number of copies of the environment to run simultaneously.
    runner = Runner(environment, 1, group_size)
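    # runner.step() takes a list of actions (one per environment copy) and
    # returns vectorized results, hence the [0] indexing on reward and
    # info below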

    online_dqn = DQAgent(state_size,
                         action_size,
                         loss="huber_loss",
                         action=action_value,
                         use_CNN=True)
    target_dqn = DQAgent(state_size,
                         action_size,
                         loss="huber_loss",
                         action=action_value,
                         use_CNN=True)
    online_dqn.model.load_weights(weight_file)
    target_dqn.update_target_weights(online_dqn.model)

    print("Playing {} using weights {} and action {}").format(
        environment, weight_file, action_value)

    # Keep a small amount of exploration even when just playing
    epsilon_max = .1
    online_dqn.epsilon = epsilon_max
    done = False

    # done_flags = True forces the FIRE restart branch on the first step;
    # lives assumes a game that starts with 5 lives (e.g. Breakout)
    done_flags = True
    lives = 5

    state = runner.reset_all()
    cumulative_reward = 0
    global_step = 0
    if save:
        images = []
    while not done:
        global_step += 1

        # predict() returns a batch of Q-value rows, so [0] selects the
        # single environment's Q-values
        q_values = online_dqn.model.predict(state)[0]

        # If we lost a life (or are just starting), take a few random
        # FIRE actions to restart play; random to avoid learning a
        # fixed sequence of actions
        if done_flags is False:
            action = online_dqn.action(q_values, online_dqn.epsilon)
        else:
            random_fire_actions = np.random.randint(1, 3)
            for i in range(random_fire_actions):
                action = 1  # FIRE
                next_state, reward, done, info = runner.step([action])
            state = next_state
            done_flags = False
            continue

        next_state, reward, done, info = runner.step([action])
        if watch:
            runner.render()
            sleep(.05)
        if save:
            images.append(runner.render(mode="rgb_array"))
        cumulative_reward += reward[0]

        # Losing a life is bad, so say so
        remaining_lives = info[0]["ale.lives"]
        life_lost_flag = bool(lives - remaining_lives)
        lives = remaining_lives

        done_flags = False
        if life_lost_flag or done:
            done_flags = True

        state = next_state

        if done:
            print("Score {}, Total steps {}").format(cumulative_reward,
                                                     global_step)
            break
    if save:
        imageio.mimsave(file_out, images, duration=f_duration)
    return 0
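
The comments in this example assume raw Atari frames are scaled by
preprocessing down to (88, 80, 1) and stacked in groups of four. The
preprocessing helper is not shown on this page; a minimal sketch of what
such a step could look like (hypothetical, not the project's actual code):

import numpy as np

def preprocess_frame(frame):
    """Reduce a raw (210, 160, 3) Atari RGB frame to (88, 80, 1) grayscale.

    Sketch only; the crop margins here are illustrative.
    """
    gray = frame.mean(axis=2)     # collapse RGB -> (210, 160) grayscale
    cropped = gray[17:193, :]     # drop the score bar -> (176, 160)
    small = cropped[::2, ::2]     # keep every other pixel -> (88, 80)
    return small.reshape(88, 80, 1).astype(np.uint8)
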
Example #2
def main(environment, loss_function, action_value, use_CNN, total_games,
         burn_in, training_interval, target_update_interval, save_interval,
         num_epochs, batch_size, learning_rate, epsilon_max, epsilon_min,
         epsilon_decay_steps, gamma, memory_size, log_interval):
    # Set up logging
    start_time = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
    log_dir, parameter_file, score_file = setup_logs(environment, start_time)

    ################################
    # Save our training parameters #
    line = "loss_function: {}\nactionvalue: {}\ntotal_games: {}\ntraining_interval: {}\ntarget_update_interval: {}\nsave_interval: {}\nnum_epochs: {}\nbatch_size: {}\nlearning_rate: {}\nepsilon_max: {}\nepsilon_min: {}\nepsilon_decay_steps: {}\ngamma: {}\nmemory_size: {}\nlog_interval: {}\n".format(
        loss_function, action_value, total_games, training_interval,
        target_update_interval, save_interval, num_epochs, batch_size,
        learning_rate, epsilon_max, epsilon_min, epsilon_decay_steps, gamma,
        memory_size, log_interval)
    os.write(parameter_file, line)
    ################################

    # Set up our environment
    env = gym.make(environment)
    if use_CNN:
        state_size = (88, 80, 1)
    else:
        state_size = env.observation_space.shape[0]

    action_size = env.action_space.n

    # Stack group_size number of atari images
    group_size = 4

    # The following are hard-coded for now: the original image is scaled
    # by preprocessing down to (88, 80, 1) and we combine 4 of them to
    # get a batch of images. Note that the "1" argument below is the
    # number of copies of the environment to train simultaneously.
    runner = Runner(environment, 1, group_size)

    # Note that if use_CNN = True, then the state_size is ignored!
    online_dqn = DQAgent(state_size,
                         action_size,
                         loss=loss_function,
                         action=action_value,
                         learning_rate=learning_rate,
                         epsilon=epsilon_max,
                         gamma=gamma,
                         memory_size=memory_size,
                         use_CNN=use_CNN)
    target_dqn = DQAgent(state_size,
                         action_size,
                         loss=loss_function,
                         action=action_value,
                         learning_rate=learning_rate,
                         epsilon=epsilon_max,
                         gamma=gamma,
                         memory_size=memory_size,
                         use_CNN=use_CNN)
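
    # online_dqn selects actions and learns; target_dqn only provides
    # stable bootstrap targets and is refreshed by periodic hard copies
    # of the online weights (see target_update_interval below)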

    target_dqn.update_target_weights(online_dqn.model)

    # Average-score threshold at which we declare the game solved
    # and stop training early
    solved_thresh = 500

    print("Playing {} using loss {} and action {}").format(
        environment, loss_function, action_value)

    done = False
    score_history = deque([], maxlen=log_interval)
    max_score = 0
    global_step = 0
    game_num = 1

    state = runner.reset_all()
    cumulative_reward = 0
    lives = 5
    done_flags = True

    while game_num < total_games:
        # online_dqn produces the Q-values and then takes an
        # epsilon-greedy action based on them
        global_step += 1

        q_values = online_dqn.model.predict(state)[0]

        # If we lose a life, start with a few FIRE actions
        # to get started again. Random to avoid learning
        # fixed sequence of actions
        if done_flags is False:
            action = online_dqn.action(q_values, online_dqn.epsilon)
        else:
            random_fire_actions = np.random.randint(1, 3)
            for i in range(random_fire_actions):
                action = FIRE_ACTION_NUMBER
                next_state, reward, done, info = runner.step([action])
            state = next_state
            done_flags = False
            continue

        next_state, reward, done, info = runner.step([action])
        cumulative_reward += reward[0]

        # Losing a life is bad, so say so
        remaining_lives = info[0]["ale.lives"]
        life_lost_flag = bool(lives - remaining_lives)
        lives = remaining_lives

        done_flags = False
        if life_lost_flag or done:
            done_flags = True

        # Store the result in memory so we can replay later
        online_dqn.remember(state, action, reward, next_state, done_flags)
        state = next_state

        if done:
            score_history.append(cumulative_reward)

            if cumulative_reward > max_score:
                max_score = cumulative_reward

            if game_num % log_interval == 0:
                os.write(score_file,
                         (str(list(score_history)) + '\n').encode())
                print(
                    "Completed game {}/{}, global step {}, last {} games average: {:.3f}, max: {}, min: {}. Best so far {}. Epsilon: {:.3f}"
                    .format(game_num, total_games, global_step, log_interval,
                            np.average(score_history), np.max(score_history),
                            np.min(score_history), max_score,
                            online_dqn.epsilon))

            game_num += 1
            cumulative_reward = 0
            lives = 5
            state = runner.reset_all()

            # Once more than 100 games have been played, stop if the
            # average over the recent score history exceeds solved_thresh
            if game_num > 100:
                avg_recent = np.average(score_history)

                if avg_recent > solved_thresh:
                    stop_time = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
                    print("Congratulations! {} has been solved after {} games."
                          ).format(environment, game_num)
                    online_dqn.model.save(
                        os.path.join(
                            log_dir,
                            "online_dqn_{}_solved.h5".format(environment)))
                    line = "Training start: {}\nTraining ends:  {}\n".format(
                        start_time, stop_time)
                    os.write(parameter_file, line.encode())
                    os.write(score_file,
                             (str(list(score_history)) + '\n').encode())
                    os.close(parameter_file)
                    os.close(score_file)
                    return 0

        # For the first burn_in steps, just populate memory
        if global_step < burn_in:
            continue
        # Once we are past the burn_in exploration period, we start to train.
        # Epsilon decays linearly from epsilon_max to epsilon_min over
        # epsilon_decay_steps steps, then stays at epsilon_min
        online_dqn.epsilon = max(
            epsilon_max +
            ((global_step - burn_in) / float(epsilon_decay_steps)) *
            (epsilon_min - epsilon_max), epsilon_min)
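        # Worked example (illustrative numbers): with epsilon_max = 1.0,
        # epsilon_min = 0.1 and epsilon_decay_steps = 1e6, epsilon is 1.0
        # at the end of burn-in, 0.55 halfway through the decay window,
        # and is clamped at 0.1 from then on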

        if (global_step % training_interval == 0):
            replay_from_memory(online_dqn, target_dqn, batch_size, num_epochs)

        if (global_step % target_update_interval == 0):
            target_dqn.update_target_weights(online_dqn.model)

        if global_step % save_interval == 0:
            online_dqn.model.save(os.path.join(log_dir, "online_dqn.h5"))

    ################################################################
    # If we're here, then we finished training without a solution. #
    # Save the most recent model and the remaining scores anyway.  #
    ################################################################
    stop_time = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
    online_dqn.model.save(
        os.path.join(log_dir, "online_dqn_" + str(global_step) + ".h5"))

    print("Done! Completed game {}/{}, global_step {}".format(
        game_num, total_games, global_step))
    line = "\n \nTraining start: {}\nTraining ends:  {}\n \n".format(
        start_time, stop_time)
    os.write(parameter_file, line.encode())
    if game_num % log_interval != 0:
        # Write the scores from the final, partial logging interval
        # (the newest entries in the deque)
        os.write(score_file,
                 (str(list(score_history)[-(game_num % log_interval):]) +
                  '\n').encode())
    os.close(parameter_file)
    os.close(score_file)
    return 0
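
Neither DQAgent nor replay_from_memory is shown on this page. Under the
standard DQN setup the code implies (an online network fit against targets
bootstrapped from a frozen target network), the replay step might look
roughly like the sketch below; the sample() helper and the gamma attribute
are assumptions about the agent's interface, not the project's actual API:

import numpy as np

def replay_from_memory(online_dqn, target_dqn, batch_size, num_epochs):
    """Sample past transitions and fit the online network on Q-targets
    built from the frozen target network (a sketch, not the real code)."""
    for state, action, reward, next_state, done in \
            online_dqn.sample(batch_size):       # hypothetical sampler
        target = online_dqn.model.predict(state)
        if done:
            target[0][action] = reward           # terminal: no bootstrap
        else:
            next_q = target_dqn.model.predict(next_state)[0]
            target[0][action] = reward + online_dqn.gamma * np.max(next_q)
        online_dqn.model.fit(state, target, epochs=num_epochs, verbose=0)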