Exemplo n.º 1
0
def plot_trial(mdp_data, max_steps=None):
    """Simulate one trial with the learned policy and save it as a GIF.

    The cart-pole starts from the all-zero state ``(0., 0., 0., 0.)`` and is
    stepped with the action returned by ``choose_action`` until its state
    number equals ``mdp_data['num_states'] - 1``. Each step is rendered by
    ``CartPole.plot_cart``. The frames ``frames/frame{t}.png`` for t >= 1 are
    then stitched, in order, into ``simulation.gif`` in the current working
    directory. Nothing is returned.

    The temporary ``frames`` directory is created before the first frame is
    plotted. It is always removed afterwards, even if simulation or GIF
    encoding fails, so a crashed run does not make the next call fail on
    ``os.mkdir``.

    Args:
        mdp_data: Learned MDP data passed to ``choose_action``; must contain
            the key ``'num_states'``.
        max_steps: Optional upper bound on the number of simulated steps.
            ``None`` (the default) keeps the original behaviour of running
            until the terminal state is reached. That never terminates if the
            policy keeps the pole balanced forever.

    Raises:
        FileExistsError: if a ``frames`` directory already exists. It is left
            untouched so that unrelated files are never deleted.
    """
    frames_dir = 'frames'
    # Create the directory before the first plot_cart call. The frames are
    # read back from here below (plot_cart presumably writes
    # frames/frame{time}.png — TODO confirm).
    os.mkdir(frames_dir)
    try:
        time = 0
        cart_pole = CartPole(Physics())
        state_tuple = (0., 0., 0., 0.)
        state = cart_pole.get_state(state_tuple)
        # NOTE(review): the time-0 frame is rendered but not added to the
        # GIF below — confirm whether it should be included.
        cart_pole.plot_cart(state_tuple, time)
        files = []
        # State number at which the trial ends.
        terminal_state = mdp_data['num_states'] - 1
        # Simulate a trial, one rendered frame per step.
        while max_steps is None or time < max_steps:
            time += 1
            action = choose_action(state, mdp_data)
            state_tuple = cart_pole.simulate(action, state_tuple)
            new_state = cart_pole.get_state(state_tuple)
            cart_pole.plot_cart(state_tuple, time)
            files.append(f'frame{time}.png')
            if new_state == terminal_state:
                break
            state = new_state
        # Stitch the recorded frames, in simulation order, into the GIF.
        with imageio.get_writer('simulation.gif', mode='I') as writer:
            for filename in files:
                image = imageio.imread(os.path.join(frames_dir, filename))
                writer.append_data(image)
    finally:
        # Always clean up the frames we created, even on failure.
        shutil.rmtree(frames_dir)
Exemplo n.º 2
0
def main() -> None:
    """Train a cart-pole controller and save its learning curve.

    The cart-pole is simulated step by step:
    - an action is chosen with ``choose_action``;
    - the dynamics are advanced with ``CartPole.simulate``;
    - the observed (state, action, new_state, reward) is passed to
      ``update_mdp_transition_counts_reward_counts``.

    Whenever the pole falls (the state number equals ``NUM_STATES - 1``):
    - the MDP model is recomputed via ``update_mdp_transition_probs_reward``;
    - the value function is recomputed via ``update_mdp_value``;
    - the cart is reset to a random horizontal position.

    Training stops once ``NO_LEARNING_THRESHOLD`` consecutive
    ``update_mdp_value`` calls return a truthy result, or once
    ``max_failures`` failures have occurred. A truthy result means
    "converged in one iteration", per the variable that receives it.

    Side effects:
    - prints an ``[INFO]`` line per recorded failure;
    - saves a plot of log(steps to failure) per trial, plus a
      ``window``-trial moving average, to ``./control.pdf``.
    """
    # Simulation parameters
    # pause_time and display_started are only consumed by the commented-out
    # cart_pole.show_cart(...) calls below; they have no effect as written.
    pause_time = 0.0001
    min_trial_length_to_start_display = 100
    display_started = min_trial_length_to_start_display == 0

    # Number of state numbers. Index NUM_STATES - 1 is the state the loop
    # below treats as "pole fell" (reward -1, model update, reset).
    NUM_STATES = 163
    # GAMMA and TOLERANCE are forwarded to update_mdp_value (presumably the
    # discount factor and the value-iteration convergence tolerance).
    GAMMA = 0.995
    TOLERANCE = 0.01
    # Consecutive one-iteration convergences required to stop training.
    NO_LEARNING_THRESHOLD = 20

    # Time cycle of the simulation (total steps across all trials)
    time = 0

    # These variables perform bookkeeping (how many cycles was the pole
    # balanced for before it fell). Useful for plotting learning curves.
    time_steps_to_failure = []
    num_failures = 0
    time_at_start_of_current_trial = 0

    # Hard cap on the number of failures (trials).
    # You should reach convergence well before this
    max_failures = 500

    # Initialize a cart pole
    cart_pole = CartPole(Physics())

    # Starting `state_tuple` is (0, 0, 0, 0)
    # x, x_dot, theta, theta_dot represents the actual continuous state vector
    x, x_dot, theta, theta_dot = 0.0, 0.0, 0.0, 0.0
    state_tuple = (x, x_dot, theta, theta_dot)

    # `state` is the number given to this state, you only need to consider
    # this representation of the state
    state = cart_pole.get_state(state_tuple)
    # if min_trial_length_to_start_display == 0 or display_started == 1:
    #     cart_pole.show_cart(state_tuple, pause_time)

    mdp_data = initialize_mdp_data(NUM_STATES)

    # Termination criterion: stop once the previous NO_LEARNING_THRESHOLD
    # consecutive value-function computations all converged within one
    # value-function iteration. Intuitively there is little learning left
    # after that, so the overall algorithm is considered converged.

    consecutive_no_learning_trials = 0
    while consecutive_no_learning_trials < NO_LEARNING_THRESHOLD:

        action = choose_action(state, mdp_data)

        # Get the next state by simulating the dynamics
        state_tuple = cart_pole.simulate(action, state_tuple)
        # x, x_dot, theta, theta_dot = state_tuple

        # Increment simulation time
        time = time + 1

        # Get the state number corresponding to new state vector
        new_state = cart_pole.get_state(state_tuple)
        # if display_started == 1:
        #     cart_pole.show_cart(state_tuple, pause_time)

        # reward function to use - do not change this!
        if new_state == NUM_STATES - 1:
            R = -1
        else:
            R = 0

        # Record this observed transition and its reward in mdp_data.
        update_mdp_transition_counts_reward_counts(mdp_data, state, action,
                                                   new_state, R)

        # Recompute MDP model whenever pole falls
        # Compute the value function V for the new model
        if new_state == NUM_STATES - 1:

            update_mdp_transition_probs_reward(mdp_data)

            # Truthy when value iteration converged within one iteration,
            # i.e. the refreshed model barely changed V.
            converged_in_one_iteration = update_mdp_value(
                mdp_data, TOLERANCE, GAMMA)

            # Count consecutive "no learning" recomputations; any
            # recomputation that needed more iterations resets the streak.
            if converged_in_one_iteration:
                consecutive_no_learning_trials = consecutive_no_learning_trials + 1
            else:
                consecutive_no_learning_trials = 0

        # Do NOT change this code: Controls the simulation, and handles the case
        # when the pole fell and the state must be reinitialized.
        if new_state == NUM_STATES - 1:
            num_failures += 1
            # NOTE(review): the failure that reaches max_failures breaks out
            # before it is printed or appended to time_steps_to_failure, so
            # that last trial is missing from the learning curve.
            if num_failures >= max_failures:
                break
            print('[INFO] Failure number {}'.format(num_failures))
            # Length (in steps) of the trial that just ended.
            time_steps_to_failure.append(time - time_at_start_of_current_trial)
            # time_steps_to_failure[num_failures] = time - time_at_start_of_current_trial
            time_at_start_of_current_trial = time

            # Only affects the commented-out display code.
            if time_steps_to_failure[num_failures -
                                     1] > min_trial_length_to_start_display:
                display_started = 1

            # Reinitialize state
            # x = 0.0
            # New trial starts at rest with x drawn uniformly from [-1.1, 1.1).
            x = -1.1 + np.random.uniform() * 2.2
            x_dot, theta, theta_dot = 0.0, 0.0, 0.0
            state_tuple = (x, x_dot, theta, theta_dot)
            state = cart_pole.get_state(state_tuple)
        else:
            state = new_state

    # plot the learning curve (time balanced vs. trial)
    # Every entry is >= 1 (time advances before each failure check), so the
    # log is finite.
    log_tstf = np.log(np.array(time_steps_to_failure))
    plt.plot(np.arange(len(time_steps_to_failure)), log_tstf, 'k')
    # Causal moving average over the last `window` trials: an FIR filter
    # with `window` equal taps of 1/window (presumably scipy.signal.lfilter).
    window = 30
    w = np.array([1 / window for _ in range(window)])
    weights = lfilter(w, 1, log_tstf)
    # Drop the first `window` outputs, which include the partial,
    # zero-padded windows. Shift the rest back by about window/2 so each
    # average is drawn near the centre of its window.
    # NOTE(review): x and weights[window:] have equal length only because
    # `window` is even; an odd window would make plt.plot raise.
    x = np.arange(window // 2, len(log_tstf) - window // 2)
    plt.plot(x, weights[window:len(log_tstf)], 'r--')
    plt.xlabel('Num failures')
    plt.ylabel('Log of num steps to failure')
    plt.savefig('./control.pdf')