def get_similar_rzn_ids(query_cs1_a_s2, g, vel_field_data, nmodes):
    """Return the training realizations that reproduce a queried transition.

    For every realization id in the module-level ``train_id_list``, the grid
    ``g`` is placed at the queried start state and exact coordinates, the
    queried action is applied with that realization's velocity at the start
    cell, and the realization is kept if the grid lands in the queried next
    state.

    Args:
        query_cs1_a_s2: tuple ``(cs1, a, s2)`` where ``cs1`` is
            ``(t, x, y, i, j)`` (time index, coordinates, cell indices),
            ``a`` is the action and ``s2`` the observed next state.
        g: grid object; its state is modified by the simulation.
        vel_field_data: velocity data passed through to ``extract_velocity``.
        nmodes: unused; kept for interface compatibility.

    Returns:
        set of realization ids whose simulated step ends in ``s2``.
    """
    (t, x, y, i, j), action, target_state = query_cs1_a_s2
    start_state = (t, i, j)

    def _reaches_target(rzn):
        # Simulate one step of the queried action under realization ``rzn``.
        vel_x, vel_y = extract_velocity(vel_field_data, t, i, j, rzn)
        g.set_state(start_state, xcoord=x, ycoord=y)
        g.move_exact(action, vel_x, vel_y)
        return g.current_state() == target_state

    return {rzn for rzn in train_id_list if _reaches_target(rzn)}
def update_Q_in_future_kth_rzn(g, Q, N, vel_field_data, nmodes, s1, rzn, eps):
    """Run one epsilon-greedy Q-learning rollout in realization ``rzn``.

    Almost the same as ``Run_Q_learning_episode()``: starting from ``s1`` the
    grid is stepped with epsilon-greedy actions chosen from ``Q``. Each
    visited (state, action) entry receives an undiscounted TD(0) update
    ``Q += alpha * (r + max_a' Q[s2][a'] - Q)``.

    Args:
        g: grid object; it is reset to ``s1`` and moved in place.
        Q: ``Q[state][action]`` value table; updated in place.
        N: ``N[state][action]`` visit-count table that sets the step size
            ``ALPHA / N``; updated in place (incremented by ``N_inc``).
        vel_field_data: velocity data passed to ``extract_velocity``.
        nmodes: unused; kept for interface compatibility.
        s1: ``(t, i, j)`` state at which the rollout starts.
        rzn: realization index whose velocity field drives the rollout.
        eps: exploration rate for ``stochastic_action_eps_greedy``.

    Returns:
        ``(Q, N)``: the same objects that were passed in, after the updates.
    """
    g.set_state(s1)
    # stochastic_action_eps_greedy() picks actions from Q, so its policy
    # argument is ignored; pass a placeholder.
    dummy_policy = None

    # NOTE(review): if_edge_state receives the full (t, i, j) state here,
    # while the plotting routines pass (i, j) -- confirm both are accepted.
    while (not g.is_terminal(s1) and not g.if_edge_state(s1)
           and g.if_within_actionable_time()):
        t, i, j = s1
        a1 = stochastic_action_eps_greedy(dummy_policy, s1, g, eps, Q=Q)
        vx, vy = extract_velocity(vel_field_data, t, i, j, rzn)
        r = g.move_exact(a1, vx, vy, rzn)
        s2 = g.current_state()

        # Step size shrinks with the number of visits to (s1, a1).
        alpha = ALPHA / N[s1][a1]
        N[s1][a1] += N_inc

        # Bootstrap value is 0 when s2 is terminal, an edge state, or outside
        # the actionable time window.
        max_q_s2_a2 = 0
        if (not g.is_terminal(s2) and not g.if_edge_state(s2)
                and g.if_within_actionable_time()):
            _, max_q_s2_a2 = max_dict(Q[s2])

        Q[s1][a1] += alpha * (r + max_q_s2_a2 - Q[s1][a1])
        s1 = s2
    return Q, N
def _sample_trajectory(grid, trajectory, sampling_interval):
    """Convert a coordinate trajectory into time-indexed state/coord lists.

    Every ``sampling_interval``-th point (starting with the first) is kept,
    and the final point is always appended. The final point therefore
    appears twice when ``len(trajectory) - 1`` is a multiple of
    ``sampling_interval``. The n-th kept point gets time index ``n``.

    Returns:
        ``(state_traj, coord_traj)`` in REVERSE order (final point first).
        ``state_traj`` holds ``(t, i, j)`` cells from ``compute_cell``, and
        ``coord_traj`` holds ``(t, x, y)`` coordinates.
    """
    indices = list(range(0, len(trajectory), sampling_interval))
    indices.append(-1)  # always include the final point
    state_traj = []
    coord_traj = []
    for s_t, idx in enumerate(indices):
        point = trajectory[idx]
        coord_traj.append((s_t, point[0], point[1]))
        s_i, s_j = compute_cell(grid, point)
        state_traj.append((s_t, s_i, s_j))
    # Reversed here so the learner does not have to reverse them later.
    state_traj.reverse()
    coord_traj.reverse()
    return state_traj, coord_traj


def build_experience_buffer(grid, vel_field_data, nmodes, paths, sampling_interval, train_path_ids, num_actions):
    """Build per-trajectory ``[s, a, r, s']`` experience lists from sampled paths.

    Each path id ``k`` in ``train_path_ids`` is processed as follows:

    * The trajectory ``paths[k, 0]`` is sampled into grid states (see
      ``_sample_trajectory``).
    * For every pair of consecutive samples, the action linking them is
      recovered with ``Calculate_action``, using realization ``k``'s
      velocity at the origin cell.
    * The grid is then stepped with that action to obtain the reward and
      the state actually reached.

    Transitions are skipped in two cases: when the origin state is already
    terminal (counted in ``double_count``), and when the velocity lookup
    fails (counted in ``outbound_count``).

    Args:
        grid: grid object; its state is overwritten while simulating.
        vel_field_data: velocity data passed to ``extract_velocity``.
        nmodes: unused; kept for interface compatibility.
        paths: indexable so that ``paths[k, 0]`` is a sequence of (x, y)
            points for path ``k``.
        sampling_interval: keep every n-th trajectory point.
        train_path_ids: path / realization ids to process.
        num_actions: unused; kept for interface compatibility.

    Returns:
        list with one entry per id in ``train_path_ids``. Each entry is a
        list of ``[s1, a1, r1, s_next]`` transitions, ordered from the end
        of the trajectory towards its start.
    """
    exp_buffer_all_trajs = []
    s_next_badcount = 0  # final transitions that do NOT land in a terminal state
    double_count = 0     # transitions skipped because the origin is terminal
    outbound_count = 0   # transitions skipped because the velocity lookup failed

    for k in train_path_ids:
        state_traj, coord_traj = _sample_trajectory(grid, paths[k, 0], sampling_interval)
        exp_buffer_kth_traj = []
        # Lists are end-first: index idx + 1 is the origin, idx the destination.
        for idx in range(len(state_traj) - 1):
            s1 = state_traj[idx + 1]
            s2 = state_traj[idx]
            if grid.is_terminal(s1):
                double_count += 1
                continue

            t, m, n = s1
            try:
                vx, vy = extract_velocity(vel_field_data, t, m, n, k)
            except Exception:
                # Without a velocity there is no valid action for this step;
                # skip it instead of reusing the previous step's velocity.
                print("$$$ CHECK t,i,j: ", t, m, n)
                outbound_count += 1
                continue

            p1 = coord_traj[idx + 1]
            p2 = coord_traj[idx]
            grid.set_state(s1, xcoord=p1[1], ycoord=p1[2])
            a1 = Calculate_action(s1, s2, p1, p2, vx, vy, grid, coord_traj_theta=False)
            r1 = grid.move_exact(a1, vx, vy, k)
            s_next = grid.current_state()
            # idx == 0 is the step into the trajectory's final point, which
            # should end in the terminal state.
            if idx == 0 and not grid.is_terminal(s_next):
                s_next_badcount += 1
            exp_buffer_kth_traj.append([s1, a1, r1, s_next])
        exp_buffer_all_trajs.append(exp_buffer_kth_traj)

    print("$$$$$$$ s_next_badcount: ", s_next_badcount)
    print("$$$$$$$ double_count: ", double_count)
    print("$$$$$ dlen(train_path_ids) ", len(train_path_ids))
    print('$$$$$ outbound_count= ', outbound_count)
    return exp_buffer_all_trajs
def run_and_plot_onboard_routing_episodes(setup_grid_params, Q, N, fpath, fname):
    """Run onboard-routing episodes that keep refining Q while moving.

    The episodes cover ``test_id_list[k]`` for ``k`` in
    ``range(n_test_paths_range[0], n_test_paths_range[1])``; both names are
    module-level globals. For each one:

    * Fresh deep copies of ``Q`` and ``N`` are made.
    * The vehicle follows the greedy action of the Q copy.
    * Every ``rollout_interval`` steps, the copy is refined with
      ``update_Q_in_future_rollouts`` using the latest ``rollout_interval``
      transitions.

    All trajectories are drawn on one figure.

    Args:
        setup_grid_params: 11-tuple ``(g, xs, ys, X, Y, vel_field_data,
            nmodes, useful_num_rzns, paths, params, param_str)``.
        Q, N: action-value table and visit counts; only deep copies are
            modified, the arguments themselves are left untouched.
        fpath: output directory.
        fname: base file name for the figure, time list and pickled
            trajectories; when ``None`` nothing is written to disk.

    Returns:
        ``(t_list, bad_count)``:

        * ``t_list`` holds the final time index of each episode, or ``None``
          for a failed episode.
        * ``bad_count`` is the number of failed episodes, i.e. those that
          left the domain or ran out of actionable time.
    """
    g, xs, ys, X, Y, vel_field_data, nmodes, _, paths, _, _ = setup_grid_params
    g.make_bcrumb_dict(paths, train_id_list)
    # Separate grid for the rollouts so they do not disturb ``g``.
    gcopy = copy.deepcopy(g)
    msize = 15

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    minor_ticks = [i / 100.0 for i in range(101) if i % 20 != 0]
    major_ticks = [i / 100.0 for i in range(0, 120, 20)]
    ax.set_xticks(minor_ticks, minor=True)
    ax.set_xticks(major_ticks, minor=False)
    ax.set_yticks(major_ticks, minor=False)
    ax.set_yticks(minor_ticks, minor=True)
    # Visibility flag passed positionally: the ``b=`` keyword was removed in
    # newer Matplotlib releases (renamed ``visible``).
    ax.grid(True, which='both', color='#CCCCCC', axis='both', linestyle='-', alpha=0.5)
    ax.tick_params(axis='both', which='both', labelsize=6)
    ax.set_xlabel('X (Non-Dim)')
    ax.set_ylabel('Y (Non-Dim)')

    # start_state is (t, i, j): column index j -> x, row index i -> y.
    st_point = g.start_state
    plt.scatter(g.xs[st_point[2]], g.ys[g.ni - 1 - st_point[1]], marker='o', s=msize, color='k', zorder=1e5)
    plt.scatter(g.xs[g.endpos[1]], g.ys[g.ni - 1 - g.endpos[0]], marker='*', s=msize * 2, color='k', zorder=1e5)
    plt.gca().set_aspect('equal', adjustable='box')

    t_list = []
    traj_list = []
    bad_count = 0
    for k in range(n_test_paths_range[0], n_test_paths_range[1]):
        # Every episode starts from the same offline-trained tables.
        Qcopy = copy.deepcopy(Q)
        Ncopy = copy.deepcopy(N)
        rzn = test_id_list[k]
        # Fixed-length window of the latest (cs1, a, s2) transitions, newest first.
        cs1as2_list = deque([None] * rollout_interval)
        print("-------- In rzn ", rzn, " of test_id_list ---------")
        g.set_state(g.start_state)
        bad_flag = False
        s1 = g.start_state
        t, i, j = s1
        cs1 = (t, g.x, g.y, i, j)
        a, _ = max_dict(Qcopy[s1])
        xtr = [g.x]
        ytr = [g.y]
        loop_count = 0
        while True:
            loop_count += 1
            vx, vy = extract_velocity(vel_field_data, t, i, j, rzn)
            g.move_exact(a, vx, vy)
            s2 = g.current_state()
            t, i, j = s2
            cs1as2_list.pop()
            cs1as2_list.appendleft((cs1, a, s2))
            xtr.append(g.x)
            ytr.append(g.y)

            if g.if_edge_state((i, j)):
                # Left the domain: the episode failed, so no travel time.
                bad_count += 1
                bad_flag = True
                break
            if g.is_terminal(almost=True):
                break
            if not g.if_within_actionable_time():
                # Ran out of actionable time before reaching the target.
                bad_count += 1
                bad_flag = True
                break

            if loop_count % rollout_interval == 0:
                print("------------loopcount/mission_time =", loop_count)
                Qcopy, Ncopy = update_Q_in_future_rollouts(gcopy, Qcopy, Ncopy, cs1as2_list, vel_field_data, nmodes, loop_count)
            s1 = s2
            cs1 = (t, g.x, g.y, i, j)
            a, _ = max_dict(Qcopy[s1])

        plt.plot(xtr, ytr)
        # Failed episodes are recorded as None so they can be counted later.
        if bad_flag:
            traj_list.append(None)
            t_list.append(None)
        else:
            traj_list.append((xtr, ytr))
            t_list.append(t)

    if fname is not None:
        plt.savefig(join(fpath, fname), bbox_inches="tight", dpi=200)
        writePolicytoFile(t_list, join(fpath, fname + 'tlist'))
        picklePolicy(traj_list, join(fpath, fname + '_coord_traj'))
        print("*** pickled phase2 traj_list ***")
    plt.cla()
    plt.close(fig)
    return t_list, bad_count
def plot_and_return_exact_trajectory_set_train_data(g, policy, X, Y, vel_field_data, nmodes, test_id_list, n_test_paths_range, fpath, fname='Trajectories'):
    """Roll out ``policy`` in selected test realizations and plot the paths.

    NOTE: a second function with this same name is defined later in this
    module and replaces this one at import time.

    Args:
        g: grid object; reset and moved for every realization.
        policy: mapping from grid state to action.
        X, Y, nmodes: unused; kept for interface compatibility.
        vel_field_data: velocity data passed to ``extract_velocity``.
        test_id_list: sequence of realization ids.
        n_test_paths_range: ``(start, stop)`` slice bounds into
            ``test_id_list``.
        fpath: output directory.
        fname: base file name for the figure and pickled trajectories; when
            ``None`` nothing is written to disk.

    Returns:
        ``(t_list, G_list, bad_count)``:

        * ``t_list`` and ``G_list`` hold, per realization, the final time
          index and the accumulated reward, or ``None`` for a failed rollout.
        * ``bad_count`` is the number of failed rollouts, i.e. those that
          left the domain, ran out of actionable time, or hit the step cap.
    """
    print("--- in plot_functions.plot_exact_trajectory_set---")
    msize = 15
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    minor_ticks = [i / 100.0 for i in range(101) if i % 20 != 0]
    major_ticks = [i / 100.0 for i in range(0, 120, 20)]
    ax.set_xticks(minor_ticks, minor=True)
    ax.set_xticks(major_ticks, minor=False)
    ax.set_yticks(major_ticks, minor=False)
    ax.set_yticks(minor_ticks, minor=True)
    # Visibility flag passed positionally: the ``b=`` keyword was removed in
    # newer Matplotlib releases (renamed ``visible``).
    ax.grid(True, which='both', color='#CCCCCC', axis='both', linestyle='-', alpha=0.5)
    ax.tick_params(axis='both', which='both', labelsize=6)
    ax.set_xlabel('X (Non-Dim)')
    ax.set_ylabel('Y (Non-Dim)')

    # start_state is (t, i, j): column index j -> x, row index i -> y.
    st_point = g.start_state
    plt.scatter(g.xs[st_point[2]], g.ys[g.ni - 1 - st_point[1]], marker='o', s=msize, color='k', zorder=1e5)
    plt.scatter(g.xs[g.endpos[1]], g.ys[g.ni - 1 - g.endpos[0]], marker='*', s=msize * 2, color='k', zorder=1e5)
    plt.gca().set_aspect('equal', adjustable='box')

    bad_count = 0
    t_list = []
    G_list = []
    traj_list = []
    for rzn in test_id_list[n_test_paths_range[0]:n_test_paths_range[1]]:
        g.set_state(g.start_state)
        bad_flag = False
        G = 0
        t, i, j = g.start_state
        a = policy[g.current_state()]
        xtr = [g.x]
        ytr = [g.y]
        loop_count = 0
        while True:
            loop_count += 1
            vx, vy = extract_velocity(vel_field_data, t, i, j, rzn)
            r = g.move_exact(a, vx, vy)
            G = G + r
            t, i, j = g.current_state()
            xtr.append(g.x)
            ytr.append(g.y)

            if g.if_edge_state((i, j)):
                # Left the domain: a failed rollout, so no travel time.
                bad_count += 1
                bad_flag = True
                break
            if g.is_terminal(almost=True):
                break
            if not g.if_within_actionable_time():
                # Ran out of actionable time before reaching the target.
                bad_count += 1
                bad_flag = True
                break
            a = policy[g.current_state()]

            # Safety cap: the rollout sometimes gets stuck; a rollout stopped
            # here has not reached the target, so it counts as a failure.
            if loop_count > g.ni * c_ni:
                print("t: ", t)
                print("g.current_state: ", g.current_state())
                print("xtr: ", xtr)
                print("ytr: ", ytr)
                bad_count += 1
                bad_flag = True
                break

        plt.plot(xtr, ytr, '--')
        # Failed rollouts are recorded as None so they can be counted later.
        if bad_flag:
            traj_list.append(None)
            t_list.append(None)
            G_list.append(None)
        else:
            traj_list.append((xtr, ytr))
            t_list.append(t)
            G_list.append(G)

    if fname is not None:
        plt.savefig(join(fpath, fname), bbox_inches="tight", dpi=200)
        print("*** pickling traj_list ***")
        picklePolicy(traj_list, join(fpath, fname + 'coord_traj'))
        print("*** pickled ***")
    plt.cla()
    plt.close(fig)
    return t_list, G_list, bad_count


# g, xs, ys, X, Y, vel_field_data, nmodes, useful_num_rzns, paths, params, param_str = setup_grid(num_actions=16)
# rel_path = 'Experiments/26/DP'
# exp_num_case_dir = join(ROOT_DIR, rel_path)
# policy = read_pickled_File(join(exp_num_case_dir, 'policy'))
# # tlist_file = join(exp_num_case_dir, 'TrajTimes2.txt')
# # with open(tlist_file, 'r') as f:
# #     phase1_tlist = ast.literal_eval(f.read())
# test_id_rel_path ='Experiments/104/QL/num_passes_50/QL_Iter_x1/dt_size_2500/ALPHA_0.05/eps_0_0.1'
# test_id_list = read_pickled_File(join(test_id_rel_path, 'test_id_list'))
# global n_test_paths_range
# n_test_paths_range = [0, len(test_id_list)]
# t_list, G_list, bad_count = plot_and_return_exact_trajectory_set_train_data(g, policy, X, Y, vel_field_data, nmodes, test_id_list, n_test_paths_range, exp_num_case_dir, fname='Explicit_plot_DPpolicy_104_testid')
# phase1_results = calc_mean_and_std(t_list)
# avg_time_ph1, std_time_ph1, cnt_ph1 , none_cnt_ph1 = phase1_results
# print("avg_time_ph1", avg_time_ph1,'\n',
#         "std_time_ph1", std_time_ph1, '\n',
#         "cnt_ph1",cnt_ph1 , '\n',
#         "none_cnt_ph1", none_cnt_ph1)
# print(t_list)
# print("stats from explicit plot: ")
# print(calc_mean_and_std(t_list))
# summ = 0
# cnt = 0
# print(len(phase1_tlist))
# print(test_id_list[:n_test_paths])
# for i in range(n_test_paths):
#     rzn = test_id_list[i]
#     t = phase1_tlist[rzn]
#     print(t)
#     if t != None:
#         summ += phase1_tlist[rzn]
#         cnt += 1
# print("----- phase 1 data and explict plot -----")
# print('n_test_paths= ',n_test_paths)
# print("mean= ", summ/cnt)
# print("cnt = ", cnt)
# print("pfail or badcount% = ", cnt/n_test_paths)
def plot_and_return_exact_trajectory_set_train_data(g, policy, X, Y, vel_field_data, nmodes, test_id_list, n_test_paths, fpath, fname='Trajectories', t_color_range=(50, 60)):
    """Roll out ``policy`` in selected test realizations and plot the paths.

    Each trajectory is colored by its final time index using the 'jet'
    colormap normalized to ``t_color_range``.

    Args:
        g: grid object; reset and moved for every realization.
        policy: mapping from grid state to action.
        X, Y, nmodes: unused; kept for interface compatibility.
        vel_field_data: velocity data passed to ``extract_velocity``.
        test_id_list: sequence of realization ids.
        n_test_paths: which realizations to roll out:

            * an int ``n`` selects the first ``n`` ids;
            * a ``(start, stop)`` pair selects ``test_id_list[start:stop]``;
            * ``None`` selects all of them.
        fpath: output directory.
        fname: base file name for the figure and pickled trajectories; when
            ``None`` nothing is written to disk.
        t_color_range: ``(vmin, vmax)`` of final time indices mapped onto the
            colormap.

    Returns:
        ``(t_list, G_list, bad_count)``:

        * ``t_list`` and ``G_list`` hold, per realization, the final time
          index and the accumulated reward, or ``None`` for a failed rollout.
        * ``bad_count`` is the number of failed rollouts, i.e. those that
          left the domain, ran out of actionable time, or hit the step cap.
    """
    print("--- in plot_functions.plot_exact_trajectory_set---")
    if n_test_paths is None:
        selected_rzns = test_id_list
    elif isinstance(n_test_paths, (int, np.integer)):
        selected_rzns = test_id_list[:n_test_paths]
    else:
        start, stop = n_test_paths
        selected_rzns = test_id_list[start:stop]

    msize = 15
    fsize = 3
    # ---------------------------- beautify plot ---------------------------
    fig = plt.figure(figsize=(fsize, fsize))
    ax = fig.add_subplot(1, 1, 1)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    minor_ticks = [i / 100 for i in range(101) if i % 20 != 0]
    major_ticks = [i / 100 for i in range(0, 120, 20)]
    ax.set_xticks(minor_ticks, minor=True)
    ax.set_xticks(major_ticks, minor=False)
    ax.set_yticks(major_ticks, minor=False)
    ax.set_yticks(minor_ticks, minor=True)
    # Visibility flag passed positionally: the ``b=`` keyword was removed in
    # newer Matplotlib releases (renamed ``visible``).
    ax.grid(True, which='both', color='#CCCCCC', axis='both', linestyle='-', alpha=0.5)
    ax.tick_params(axis='both', which='both', labelsize=6)
    ax.set_xlabel('X (Non-Dim)')
    ax.set_ylabel('Y (Non-Dim)')

    # start_state is (t, i, j): column index j -> x, row index i -> y.
    st_point = g.start_state
    plt.scatter(g.xs[st_point[2]], g.ys[g.ni - 1 - st_point[1]], marker='o', s=msize, color='k', zorder=1e5)
    plt.scatter(g.xs[g.endpos[1]], g.ys[g.ni - 1 - g.endpos[0]], marker='*', s=msize * 2, color='k', zorder=1e5)
    plt.gca().set_aspect('equal', adjustable='box')

    jet = plt.get_cmap('jet')
    c_norm = colors.Normalize(vmin=t_color_range[0], vmax=t_color_range[1])
    scalar_map = cmx.ScalarMappable(norm=c_norm, cmap=jet)
    scalar_map.set_array([])

    bad_count = 0
    t_list = []
    G_list = []
    traj_list = []
    for rzn in selected_rzns:
        g.set_state(g.start_state)
        bad_flag = False
        G = 0
        t, i, j = g.start_state
        a = policy[g.current_state()]
        xtr = [g.x]
        ytr = [g.y]
        loop_count = 0
        while True:
            loop_count += 1
            vx, vy = extract_velocity(vel_field_data, t, i, j, rzn)
            r = g.move_exact(a, vx, vy)
            G = G + r
            t, i, j = g.current_state()
            xtr.append(g.x)
            ytr.append(g.y)

            if g.if_edge_state((i, j)):
                # Left the domain: a failed rollout, so no travel time.
                bad_count += 1
                bad_flag = True
                break
            if g.is_terminal(almost=True):
                break
            if not g.if_within_actionable_time():
                # Ran out of actionable time before reaching the target.
                bad_count += 1
                bad_flag = True
                break
            a = policy[g.current_state()]

            # Safety cap: the rollout sometimes gets stuck; a rollout stopped
            # here has not reached the target, so it counts as a failure.
            if loop_count > g.ni * c_ni:
                print("t: ", t)
                print("g.current_state: ", g.current_state())
                print("xtr: ", xtr)
                print("ytr: ", ytr)
                bad_count += 1
                bad_flag = True
                break

        plt.plot(xtr, ytr, color=scalar_map.to_rgba(t), alpha=0.6)
        # Failed rollouts are recorded as None so they can be counted later.
        if bad_flag:
            traj_list.append(None)
            t_list.append(None)
            G_list.append(None)
        else:
            traj_list.append((xtr, ytr))
            t_list.append(t)
            G_list.append(G)

    if fname is not None:
        plt.savefig(join(fpath, fname), bbox_inches="tight", dpi=200)
        print("*** pickling traj_list ***")
        picklePolicy(traj_list, join(fpath, fname + 'coord_traj'))
        print("*** pickled ***")
    plt.cla()
    plt.close(fig)
    return t_list, G_list, bad_count
def plot_exact_trajectory_set_DP(g, policy, X, Y, vel_field_data, test_id_list, fpath, fname='Trajectories'):
    """Roll out ``policy`` in every test realization and plot the paths.

    Args:
        g: grid object; reset and moved for every realization.
        policy: mapping from grid state to action.
        X, Y: unused; kept for interface compatibility.
        vel_field_data: velocity data passed to ``extract_velocity``.
        test_id_list: sequence of realization ids to roll out.
        fpath: output directory.
        fname: file name for the figure and pickled trajectories; when
            ``None`` nothing is written and the figure is left open.

    Returns:
        ``(t_list_all, t_list_reached, G_list, bad_count_tuple)``:

        * ``t_list_all`` and ``G_list`` hold, per realization, the final
          time index and the accumulated reward, or ``None`` for a failure.
        * ``t_list_reached`` holds only the successful times.
        * ``bad_count_tuple`` is ``(bad_count, "<percent>%")``; the
          percentage is 0 for an empty ``test_id_list``.
    """
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)

    # Grid lines on cell boundaries (minor) and every 5 cells (major).
    minor_xticks = np.arange(g.xs[0] - 0.5 * g.dj, g.xs[-1] + 2 * g.dj, g.dj)
    minor_yticks = np.arange(g.ys[0] - 0.5 * g.di, g.ys[-1] + 2 * g.di, g.di)
    major_xticks = np.arange(g.xs[0], g.xs[-1] + 2 * g.dj, 5 * g.dj)
    major_yticks = np.arange(g.ys[0], g.ys[-1] + 2 * g.di, 5 * g.di)
    ax.set_xticks(minor_xticks, minor=True)
    ax.set_yticks(minor_yticks, minor=True)
    ax.set_xticks(major_xticks)
    ax.set_yticks(major_yticks)
    ax.grid(which='major', color='#CCCCCC', linestyle='')
    ax.grid(which='minor', color='#CCCCCC', linestyle='--')

    # start_state is (t, i, j): column index j -> x, row index i -> y.
    st_point = g.start_state
    plt.scatter(g.xs[st_point[2]], g.ys[g.ni - 1 - st_point[1]], c='g')
    plt.scatter(g.xs[g.endpos[1]], g.ys[g.ni - 1 - g.endpos[0]], c='r')
    # With no arguments plt.grid() toggles the major grid lines; kept to
    # preserve the figure's appearance.
    plt.grid()
    plt.gca().set_aspect('equal', adjustable='box')

    bad_count = 0
    t_list_all = []
    t_list_reached = []
    G_list = []
    traj_list = []
    for rzn in test_id_list:
        g.set_state(g.start_state)
        bad_flag = False
        G = 0
        t, i, j = g.start_state
        a = policy[g.current_state()]
        xtr = [g.x]
        ytr = [g.y]
        while True:
            vx, vy = extract_velocity(vel_field_data, t, i, j, rzn)
            r = g.move_exact(a, vx, vy)
            G = G + r
            t, i, j = g.current_state()
            xtr.append(g.x)
            ytr.append(g.y)

            if g.if_edge_state((i, j)):
                # Left the domain: a failure, so it must not count as reached.
                bad_count += 1
                bad_flag = True
                break
            if g.is_terminal(almost=True):
                break
            if not g.if_within_actionable_time():
                # Ran out of actionable time before reaching the target.
                bad_count += 1
                bad_flag = True
                break
            a = policy[g.current_state()]

        plt.plot(xtr, ytr)
        if bad_flag:
            traj_list.append(None)
            t_list_all.append(None)
            G_list.append(None)
        else:
            traj_list.append((xtr, ytr))
            t_list_all.append(t)
            G_list.append(G)
            t_list_reached.append(t)

    if fname is not None:
        plt.savefig(join(fpath, fname), dpi=300)
        print("*** pickling traj_list ***")
        picklePolicy(traj_list, join(fpath, fname))
        print("*** pickled ***")
        # Release the figure once it is on disk so repeated calls do not
        # accumulate open figures.
        plt.close(fig)

    n_rzns = len(test_id_list)
    bad_pct = bad_count * 100 / n_rzns if n_rzns else 0.0
    bad_count_tuple = (bad_count, str(bad_pct) + '%')
    return t_list_all, t_list_reached, G_list, bad_count_tuple
def plot_exact_trajectory_set(g, policy, X, Y, vel_field_data, nmodes, train_id_set, test_id_set, goodlist, fpath, fname='Trajectories'):
    """Roll out ``policy`` in each realization of ``goodlist`` and plot the paths.

    Trajectories are colored by membership:

    * blue for ids in ``train_id_set``;
    * green for ids in ``test_id_set``, drawn on top;
    * red for any other id.

    Args:
        g: grid object; reset and moved for every realization.
        policy: mapping from grid state to action.
        X, Y, nmodes: unused; kept for interface compatibility.
        vel_field_data: velocity data passed to ``extract_velocity``.
        train_id_set, test_id_set: containers used for membership tests.
        goodlist: realization ids to roll out.
        fpath: output directory.
        fname: file name for the figure and pickled trajectories; when
            ``None`` nothing is written to disk.

    Returns:
        ``(t_list, G_list, bad_count)``:

        * ``t_list`` and ``G_list`` hold, per realization, the final time
          index and the accumulated reward, or ``None`` for a failed rollout.
        * ``bad_count`` is the number of failed rollouts, i.e. those that
          left the domain, ran out of actionable time, or hit the step cap.
    """
    print("--- in plot_functions.plot_exact_trajectory_set---")
    msize = 15
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    minor_ticks = [i / 100 for i in range(101) if i % 20 != 0]
    major_ticks = [i / 100 for i in range(0, 120, 20)]
    ax.set_xticks(minor_ticks, minor=True)
    ax.set_xticks(major_ticks, minor=False)
    ax.set_yticks(major_ticks, minor=False)
    ax.set_yticks(minor_ticks, minor=True)
    # Visibility flag passed positionally: the ``b=`` keyword was removed in
    # newer Matplotlib releases (renamed ``visible``).
    ax.grid(True, which='both', color='#CCCCCC', axis='both', linestyle='-', alpha=0.5)
    ax.tick_params(axis='both', which='both', labelsize=6)
    ax.set_xlabel('X (Non-Dim)')
    ax.set_ylabel('Y (Non-Dim)')

    # start_state is (t, i, j): column index j -> x, row index i -> y.
    st_point = g.start_state
    plt.scatter(g.xs[st_point[2]], g.ys[g.ni - 1 - st_point[1]], marker='o', s=msize, color='k', zorder=1e5)
    plt.scatter(g.xs[g.endpos[1]], g.ys[g.ni - 1 - g.endpos[0]], marker='*', s=msize * 2, color='k', zorder=1e5)
    plt.gca().set_aspect('equal', adjustable='box')

    bad_count = 0
    t_list = []
    G_list = []
    traj_list = []
    for rzn in goodlist:
        if rzn in train_id_set:
            color = 'b'
        elif rzn in test_id_set:
            color = 'g'
        else:
            color = 'r'

        g.set_state(g.start_state)
        bad_flag = False
        G = 0
        t, i, j = g.start_state
        a = policy[g.current_state()]
        xtr = [g.x]
        ytr = [g.y]
        loop_count = 0
        while True:
            loop_count += 1
            vx, vy = extract_velocity(vel_field_data, t, i, j, rzn)
            r = g.move_exact(a, vx, vy)
            G = G + r
            t, i, j = g.current_state()
            xtr.append(g.x)
            ytr.append(g.y)

            if g.if_edge_state((i, j)):
                # Left the domain: a failed rollout, so no travel time.
                bad_count += 1
                bad_flag = True
                break
            if g.is_terminal(almost=True):
                break
            if not g.if_within_actionable_time():
                # Ran out of actionable time before reaching the target.
                bad_count += 1
                bad_flag = True
                break
            a = policy[g.current_state()]

            # Safety cap: the rollout sometimes gets stuck; a rollout stopped
            # here has not reached the target, so it counts as a failure.
            if loop_count > g.ni * c_ni:
                print("t: ", t)
                print("g.current_state: ", g.current_state())
                print("xtr: ", xtr)
                print("ytr: ", ytr)
                bad_count += 1
                bad_flag = True
                break

        if color == 'g':
            plt.plot(xtr, ytr, color=color, zorder=1e5)
        else:
            plt.plot(xtr, ytr, color=color)
        # Failed rollouts are recorded as None so they can be counted later.
        if bad_flag:
            traj_list.append(None)
            t_list.append(None)
            G_list.append(None)
        else:
            traj_list.append((xtr, ytr))
            t_list.append(t)
            G_list.append(G)

    if fname is not None:
        plt.savefig(join(fpath, fname), bbox_inches="tight", dpi=200)
        print("*** pickling traj_list ***")
        picklePolicy(traj_list, join(fpath, fname))
        print("*** pickled ***")
    plt.cla()
    plt.close(fig)
    return t_list, G_list, bad_count