def plot_trial(mdp_data):
    """Plot a trial given a learned MDP and policy, and save it as a gif file."""
    time = 0
    cart_pole = CartPole(Physics())
    state_tuple = (0., 0., 0., 0.)
    state = cart_pole.get_state(state_tuple)
    os.mkdir('frames')  # directory that holds the individual frame images
    cart_pole.plot_cart(state_tuple, time)
    files = []

    # Simulate a trial under the learned policy, saving one frame per time step
    while True:
        time += 1
        action = choose_action(state, mdp_data)
        state_tuple = cart_pole.simulate(action, state_tuple)
        new_state = cart_pole.get_state(state_tuple)
        cart_pole.plot_cart(state_tuple, time)
        files.append(f'frame{time}.png')
        if new_state == mdp_data['num_states'] - 1:
            break
        state = new_state

    # Assemble the saved frames into a gif file
    with imageio.get_writer('simulation.gif', mode='I') as writer:
        for filename in files:
            image = imageio.imread(f'frames/{filename}')
            writer.append_data(image)

    # Remove the intermediate frame files
    shutil.rmtree("frames")
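# A minimal usage sketch (an assumption, not part of the original script):
# `plot_trial` expects the `mdp_data` dictionary produced by training, e.g.
#
#     mdp_data = initialize_mdp_data(163)
#     # ... run the learning loop in `main()` to fit the model ...
#     plot_trial(mdp_data)
#
# Note that `main()` below does not return `mdp_data`, so wiring this up
# would require a small change (e.g. having `main()` return the learned model).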
def main():
    # Simulation parameters
    pause_time = 0.0001
    min_trial_length_to_start_display = 100
    display_started = min_trial_length_to_start_display == 0

    NUM_STATES = 163
    GAMMA = 0.995
    TOLERANCE = 0.01
    NO_LEARNING_THRESHOLD = 20

    # Time cycle of the simulation
    time = 0

    # These variables perform bookkeeping (how many cycles was the pole
    # balanced for before it fell). Useful for plotting learning curves.
    time_steps_to_failure = []
    num_failures = 0
    time_at_start_of_current_trial = 0

    # You should reach convergence well before this
    max_failures = 500

    # Initialize a cart pole
    cart_pole = CartPole(Physics())

    # Starting `state_tuple` is (0, 0, 0, 0)
    # x, x_dot, theta, theta_dot represent the actual continuous state vector
    x, x_dot, theta, theta_dot = 0.0, 0.0, 0.0, 0.0
    state_tuple = (x, x_dot, theta, theta_dot)

    # `state` is the number given to this state; you only need to consider
    # this representation of the state
    state = cart_pole.get_state(state_tuple)
    # if min_trial_length_to_start_display == 0 or display_started == 1:
    #     cart_pole.show_cart(state_tuple, pause_time)

    mdp_data = initialize_mdp_data(NUM_STATES)

    # This is the criterion to end the simulation.
    # You should change it to terminate when the previous
    # 'NO_LEARNING_THRESHOLD' consecutive value function computations all
    # converged within one value function iteration. Intuitively, it seems
    # like there will be little learning after this, so end the simulation
    # here, and say the overall algorithm has converged.
    consecutive_no_learning_trials = 0
    while consecutive_no_learning_trials < NO_LEARNING_THRESHOLD:
        action = choose_action(state, mdp_data)

        # Get the next state by simulating the dynamics
        state_tuple = cart_pole.simulate(action, state_tuple)
        # x, x_dot, theta, theta_dot = state_tuple

        # Increment simulation time
        time = time + 1

        # Get the state number corresponding to the new state vector
        new_state = cart_pole.get_state(state_tuple)
        # if display_started == 1:
        #     cart_pole.show_cart(state_tuple, pause_time)

        # Reward function to use - do not change this!
        if new_state == NUM_STATES - 1:
            R = -1
        else:
            R = 0

        update_mdp_transition_counts_reward_counts(mdp_data, state, action, new_state, R)

        # Recompute MDP model whenever pole falls.
        # Compute the value function V for the new model.
        if new_state == NUM_STATES - 1:
            update_mdp_transition_probs_reward(mdp_data)
            converged_in_one_iteration = update_mdp_value(
                mdp_data, TOLERANCE, GAMMA)
            if converged_in_one_iteration:
                consecutive_no_learning_trials = consecutive_no_learning_trials + 1
            else:
                consecutive_no_learning_trials = 0

        # Do NOT change this code: Controls the simulation, and handles the case
        # when the pole fell and the state must be reinitialized.
        if new_state == NUM_STATES - 1:
            num_failures += 1
            if num_failures >= max_failures:
                break
            print('[INFO] Failure number {}'.format(num_failures))
            time_steps_to_failure.append(time - time_at_start_of_current_trial)
            # time_steps_to_failure[num_failures] = time - time_at_start_of_current_trial
            time_at_start_of_current_trial = time

            if time_steps_to_failure[num_failures - 1] > min_trial_length_to_start_display:
                display_started = 1

            # Reinitialize state
            # x = 0.0
            x = -1.1 + np.random.uniform() * 2.2
            x_dot, theta, theta_dot = 0.0, 0.0, 0.0
            state_tuple = (x, x_dot, theta, theta_dot)
            state = cart_pole.get_state(state_tuple)
        else:
            state = new_state

    # Plot the learning curve (time balanced vs. trial)
    log_tstf = np.log(np.array(time_steps_to_failure))
    plt.plot(np.arange(len(time_steps_to_failure)), log_tstf, 'k')

    # Smooth the (log) learning curve with a moving-average filter
    window = 30
    w = np.array([1 / window for _ in range(window)])
    weights = lfilter(w, 1, log_tstf)
    x = np.arange(window // 2, len(log_tstf) - window // 2)
    plt.plot(x, weights[window:len(log_tstf)], 'r--')
    plt.xlabel('Num failures')
    plt.ylabel('Log of num steps to failure')
    plt.savefig('./control.pdf')
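# Entry point sketch (assumption: this script is run directly).
if __name__ == '__main__':
    main()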