def get_similar_rzn_ids(query_cs1_a_s2, g, vel_field_data, nmodes):
    """Return the training realisation ids that reproduce a given transition.

    The query is a triple ``(cs1, a, s2)`` where ``cs1`` unpacks as
    ``(t, x, y, i, j)``.  For every id in the module-level ``train_id_list``
    the grid ``g`` is placed at state ``(t, i, j)`` with coordinates
    ``(x, y)``, action ``a`` is applied with that realisation's velocity, and
    the id is kept when the resulting ``g.current_state()`` equals ``s2``.

    ``g`` is mutated (its state is overwritten on every trial).
    ``nmodes`` is accepted for interface compatibility and is not used here.
    """
    (t, x, y, i, j), action, observed_next = query_cs1_a_s2
    origin = (t, i, j)

    def _reproduces_transition(rzn_id):
        # Replay the single step under this realisation's velocity field.
        vel_x, vel_y = extract_velocity(vel_field_data, t, i, j, rzn_id)
        g.set_state(origin, xcoord=x, ycoord=y)
        g.move_exact(action, vel_x, vel_y)
        return g.current_state() == observed_next

    return {rzn_id for rzn_id in train_id_list if _reproduces_transition(rzn_id)}
def update_Q_in_future_kth_rzn(g, Q, N, vel_field_data, nmodes, s1, rzn, eps):
    """Roll out one episode from ``s1`` in realisation ``rzn``, updating Q and N in place.

    Starting at ``s1`` (unpacked as ``(t, i, j)``), actions are drawn with
    ``stochastic_action_eps_greedy`` and applied with ``g.move_exact`` until
    the current state is terminal, is an edge state, or the grid is no longer
    within actionable time.  After every step the entry ``Q[s1][a1]`` receives
    a Q-learning style update with step size ``ALPHA / N[s1][a1]``, and
    ``N[s1][a1]`` is increased by ``N_inc``.

    Relies on the module-level names ``ALPHA``, ``N_inc``, ``max_dict``,
    ``extract_velocity`` and ``stochastic_action_eps_greedy``.

    Args:
        g: grid/environment object; its state is overwritten by the rollout.
        Q: nested mapping ``Q[state][action]`` -> value, modified in place.
        N: nested mapping ``N[state][action]`` -> visit weight, modified in place.
        vel_field_data: velocity data handed to ``extract_velocity``.
        nmodes: unused here; kept for interface compatibility.
        s1: starting state ``(t, i, j)``.
        rzn: realisation id used for the velocity lookup and the move.
        eps: exploration parameter forwarded to the action selector.

    Returns:
        The same ``(Q, N)`` objects that were passed in.
    """
    g.set_state(s1)
    # Per the original author's note, the action selector works from Q here,
    # so no policy table is supplied.
    dummy_policy = None

    # NOTE (original author): the loop condition will have to change for general time.
    while not g.is_terminal(s1) and not g.if_edge_state(s1) and g.if_within_actionable_time():
        t, i, j = s1
        a1 = stochastic_action_eps_greedy(dummy_policy, s1, g, eps, Q=Q)
        vx, vy = extract_velocity(vel_field_data, t, i, j, rzn)
        r = g.move_exact(a1, vx, vy, rzn)
        s2 = g.current_state()

        # Step size is computed from the visit weight *before* it is incremented.
        alpha = ALPHA / N[s1][a1]
        N[s1][a1] += N_inc

        # The bootstrap value is 0 when the next state is terminal, an edge
        # state, or outside actionable time.
        max_q_s2_a2 = 0
        if not g.is_terminal(s2) and not g.if_edge_state(s2) and g.if_within_actionable_time():
            _, max_q_s2_a2 = max_dict(Q[s2])

        Q[s1][a1] = Q[s1][a1] + alpha * (r + max_q_s2_a2 - Q[s1][a1])

        s1 = s2

    return Q, N
def build_experience_buffer(grid, vel_field_data, nmodes, paths, sampling_interval, train_path_ids, num_actions):
    """Turn stored trajectories into per-trajectory lists of [s1, a1, r1, s_next].

    For each id ``k`` in ``train_path_ids`` the trajectory ``paths[k, 0]`` is
    sampled every ``sampling_interval`` points (the final point is always
    appended as well), converted to ``(t, i, j)`` states with
    ``compute_cell``, and reversed.  Consecutive pairs are then replayed on
    ``grid``: the action is recovered with ``Calculate_action`` and applied
    with ``grid.move_exact`` to obtain the reward and the next state.

    Transitions are stored in reverse time order (last step of the trajectory
    first).  Transitions whose source state is terminal are skipped and
    counted in ``double_count``.  Transitions whose velocity lookup fails are
    skipped and counted in ``outbound_count``; previously the failure was
    swallowed and the transition was replayed with a stale or undefined
    velocity.

    Args:
        grid: environment object; its state is overwritten while replaying.
        vel_field_data: velocity data handed to ``extract_velocity``.
        nmodes: unused here; kept for interface compatibility.
        paths: indexable as ``paths[k, 0]`` giving a sequence of (x, y) points.
        sampling_interval: stride used when sampling trajectory points.
        train_path_ids: iterable of trajectory / realisation ids to process.
        num_actions: unused here; kept for interface compatibility.

    Returns:
        A list with one entry per id in ``train_path_ids`` (same order); each
        entry is a list of ``[s1, a1, r1, s_next]`` transitions.
    """
    exp_buffer_all_trajs = []

    # Diagnostic counters, printed once at the end.
    s_next_badcount = 0
    double_count = 0
    outbound_count = 0

    for k in train_path_ids:
        exp_buffer_kth_traj = []
        trajectory = paths[k, 0]  # kth trajectory, sequence of (x, y) points

        # Sample the trajectory; the time index is the sample number.
        coord_traj = []
        state_traj = []
        sample_indices = range(0, len(trajectory), sampling_interval)
        for s_t, idx in enumerate(sample_indices):
            point = trajectory[idx]
            coord_traj.append((s_t, point[0], point[1]))
            s_i, s_j = compute_cell(grid, point)
            state_traj.append((s_t, s_i, s_j))

        # Always append the final point, with the next time index.
        s_t = len(sample_indices)
        last_point = trajectory[-1]
        coord_traj.append((s_t, last_point[0], last_point[1]))
        s_i, s_j = compute_cell(grid, last_point)
        state_traj.append((s_t, s_i, s_j))

        # Reverse once here so the learner does not have to (per the original
        # author's note about the "reverse" method).
        state_traj.reverse()
        coord_traj.reverse()

        # After the reversal, index i + 1 is the earlier state and i the later one.
        for i in range(len(state_traj) - 1):
            s1 = state_traj[i + 1]
            s2 = state_traj[i]  # original note: perhaps only used as dummy

            if grid.is_terminal(s1):
                double_count += 1
                continue

            t, m, n = s1
            # Original TODO: the lookup is only valid because this is a
            # stationary velocity field, so the t index is not really needed.
            try:
                vx, vy = extract_velocity(vel_field_data, t, m, n, k)
            except Exception:
                # Without a velocity for this cell the step cannot be
                # replayed; record it and move on.
                print("$$$ CHECK t,i,j: ", t, m, n)
                outbound_count += 1
                continue

            p1 = coord_traj[i + 1]
            p2 = coord_traj[i]
            grid.set_state(s1, xcoord=p1[1], ycoord=p1[2])

            a1 = Calculate_action(s1, s2, p1, p2, vx, vy, grid, coord_traj_theta= False)
            r1 = grid.move_exact(a1, vx, vy, k)
            s_next = grid.current_state()

            # i == 0 is the final step of the trajectory in forward time.
            # NOTE(review): the original comment described this counter as
            # "how many times we dont reach terminal state", yet the condition
            # counts replays that DO land in a terminal state - confirm intent.
            if i == 0 and grid.is_terminal(s_next):
                s_next_badcount += 1

            exp_buffer_kth_traj.append([s1, a1, r1, s_next])

        # append kth-traj-list to master list
        exp_buffer_all_trajs.append(exp_buffer_kth_traj)

    print("$$$$$$$ s_next_badcount: ", s_next_badcount)
    print("$$$$$$$ double_count: ", double_count)
    print("$$$$$ dlen(train_path_ids) ", len(train_path_ids))
    print('$$$$$ outbound_count= ', outbound_count)

    return exp_buffer_all_trajs
def run_and_plot_onboard_routing_episodes(setup_grid_params, Q, N, fpath, fname):
    """Simulate greedy on-board routing over test realisations and plot the paths.

    For every index ``k`` in ``[n_test_paths_range[0], n_test_paths_range[1])``
    the realisation ``test_id_list[k]`` is simulated from ``g.start_state``,
    always taking the best action according to a per-realisation deep copy of
    ``Q``.  Every ``rollout_interval`` steps the copies of ``Q`` and ``N`` are
    handed to ``update_Q_in_future_rollouts`` together with the most recent
    transitions, and replaced by what it returns.

    Relies on the module-level names ``train_id_list``, ``test_id_list``,
    ``n_test_paths_range`` and ``rollout_interval``.

    Args:
        setup_grid_params: 11-tuple laid out as (g, xs, ys, X, Y,
            vel_field_data, nmodes, useful_num_rzns, paths, params,
            param_str); only g, vel_field_data, nmodes and paths are used.
        Q: action-value table; deep-copied for each realisation.
        N: visit-weight table; deep-copied for each realisation.
        fpath: output directory for the figure and result files.
        fname: base file name; when None nothing is saved.

    Returns:
        ``(t_list, bad_count)``.  ``t_list`` has one entry per simulated
        realisation: the time index of the last state reached, or None when
        the run left actionable time without terminating.  ``bad_count``
        counts runs that hit an edge state or ran out of actionable time.
    """
    g, _, _, _, _, vel_field_data, nmodes, _, paths, _, _ = setup_grid_params
    g.make_bcrumb_dict(paths, train_id_list)

    # Separate grid object passed to update_Q_in_future_rollouts.
    gcopy = copy.deepcopy(g)

    msize = 15

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    minor_ticks = [i / 100.0 for i in range(101) if i % 20 != 0]
    major_ticks = [i / 100.0 for i in range(0, 120, 20)]

    ax.set_xticks(minor_ticks, minor=True)
    ax.set_xticks(major_ticks, minor=False)
    ax.set_yticks(major_ticks, minor=False)
    ax.set_yticks(minor_ticks, minor=True)

    # Visibility is passed positionally: the keyword was renamed from `b` to
    # `visible` in newer matplotlib, and the positional form works in both.
    ax.grid(True, which='both', color='#CCCCCC', axis='both', linestyle='-', alpha=0.5)
    ax.tick_params(axis='both', which='both', labelsize=6)

    ax.set_xlabel('X (Non-Dim)')
    ax.set_ylabel('Y (Non-Dim)')

    # start_state is (t, i, j) (it is unpacked that way below), so the column
    # index is element 2 and the row index is element 1.
    st_point = g.start_state
    plt.scatter(g.xs[st_point[2]], g.ys[g.ni - 1 - st_point[1]], marker='o', s=msize, color='k', zorder=1e5)
    plt.scatter(g.xs[g.endpos[1]], g.ys[g.ni - 1 - g.endpos[0]], marker='*', s=msize * 2, color='k', zorder=1e5)
    plt.gca().set_aspect('equal', adjustable='box')

    t_list = []
    traj_list = []
    bad_count = 0

    for k in range(n_test_paths_range[0], n_test_paths_range[1]):
        Qcopy = copy.deepcopy(Q)
        Ncopy = copy.deepcopy(N)
        rzn = test_id_list[k]

        # Fixed-length window of the latest transitions, newest at the left.
        cs1as2_list = deque([None] * rollout_interval)

        print("-------- In rzn ", rzn, " of test_id_list ---------")
        g.set_state(g.start_state)
        bad_flag = False

        s1 = g.start_state
        t, i, j = s1
        cs1 = (t, g.x, g.y, i, j)
        a, _ = max_dict(Qcopy[s1])

        xtr = [g.x]
        ytr = [g.y]
        loop_count = 0

        while True:
            loop_count += 1
            vx, vy = extract_velocity(vel_field_data, t, i, j, rzn)
            g.move_exact(a, vx, vy)
            s2 = g.current_state()
            t, i, j = s2

            # Drop the oldest transition and record the newest.
            cs1as2_list.pop()
            cs1as2_list.appendleft((cs1, a, s2))

            xtr.append(g.x)
            ytr.append(g.y)

            if g.if_edge_state((i, j)):
                # NOTE(review): counted as bad, but bad_flag stays False, so the
                # run is still stored as a normal trajectory - confirm intent.
                bad_count += 1
                break
            if g.is_terminal(almost=True):
                break
            if not g.if_within_actionable_time():
                # Not terminal and out of actionable time: a failed run.
                bad_count += 1
                bad_flag = True
                break

            if loop_count % rollout_interval == 0:
                print("------------loopcount/mission_time =", loop_count)
                Qcopy, Ncopy = update_Q_in_future_rollouts(gcopy, Qcopy, Ncopy, cs1as2_list, vel_field_data, nmodes, loop_count)

            # Prepare the next step from the (possibly updated) Q copy.
            s1 = s2
            cs1 = (t, g.x, g.y, i, j)
            a, _ = max_dict(Qcopy[s1])

        plt.plot(xtr, ytr)

        # Failed runs are stored as None so they can be counted later.
        if bad_flag:
            traj_list.append(None)
            t_list.append(None)
        else:
            traj_list.append((xtr, ytr))
            t_list.append(t)

    if fname is not None:
        plt.savefig(join(fpath, fname), bbox_inches="tight", dpi=200)
        plt.cla()
        plt.close(fig)
        writePolicytoFile(t_list, join(fpath, fname + 'tlist'))
        picklePolicy(traj_list, join(fpath, fname + '_coord_traj'))
        print("*** pickled phase2 traj_list ***")

    return t_list, bad_count
# Example #5
# 0
def plot_and_return_exact_trajectory_set_train_data(g,
                                                    policy,
                                                    X,
                                                    Y,
                                                    vel_field_data,
                                                    nmodes,
                                                    test_id_list,
                                                    n_test_paths_range,
                                                    fpath,
                                                    fname='Trajectories'):
    """Simulate ``policy`` on a slice of the test realisations and plot each path.

    The realisations ``test_id_list[n_test_paths_range[0]:n_test_paths_range[1]]``
    are simulated from ``g.start_state``; every path is drawn as a dashed
    line.  Relies on the module-level name ``c_ni`` for the loop guard.

    Args:
        g: grid/environment object; its state is overwritten by the simulation.
        policy: mapping from state to action.
        X, Y: unused here; kept for interface compatibility.
        vel_field_data: velocity data handed to ``extract_velocity``.
        nmodes: unused here; kept for interface compatibility.
        test_id_list: sequence of realisation ids.
        n_test_paths_range: ``(start, stop)`` slice bounds into test_id_list.
        fpath: output directory.
        fname: base file name; when None nothing is saved.

    Returns:
        ``(t_list, G_list, bad_count)``.  The lists have one entry per
        simulated realisation: the last time index / accumulated reward, or
        None when the run left actionable time without terminating.
        ``bad_count`` counts runs that hit an edge state or ran out of
        actionable time.
    """
    print("--- in plot_functions.plot_exact_trajectory_set---")

    msize = 15

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    minor_ticks = [i / 100.0 for i in range(101) if i % 20 != 0]
    major_ticks = [i / 100.0 for i in range(0, 120, 20)]

    ax.set_xticks(minor_ticks, minor=True)
    ax.set_xticks(major_ticks, minor=False)
    ax.set_yticks(major_ticks, minor=False)
    ax.set_yticks(minor_ticks, minor=True)

    # Visibility is passed positionally: the keyword was renamed from `b` to
    # `visible` in newer matplotlib, and the positional form works in both.
    ax.grid(True,
            which='both',
            color='#CCCCCC',
            axis='both',
            linestyle='-',
            alpha=0.5)
    ax.tick_params(axis='both', which='both', labelsize=6)

    ax.set_xlabel('X (Non-Dim)')
    ax.set_ylabel('Y (Non-Dim)')

    # start_state is (t, i, j) (it is unpacked that way below), so the column
    # index is element 2 and the row index is element 1.
    st_point = g.start_state
    plt.scatter(g.xs[st_point[2]],
                g.ys[g.ni - 1 - st_point[1]],
                marker='o',
                s=msize,
                color='k',
                zorder=1e5)
    plt.scatter(g.xs[g.endpos[1]],
                g.ys[g.ni - 1 - g.endpos[0]],
                marker='*',
                s=msize * 2,
                color='k',
                zorder=1e5)
    plt.gca().set_aspect('equal', adjustable='box')

    bad_count = 0
    t_list = []
    G_list = []
    traj_list = []

    for rzn in test_id_list[n_test_paths_range[0]:n_test_paths_range[1]]:
        g.set_state(g.start_state)
        bad_flag = False
        G = 0

        t, i, j = g.start_state
        a = policy[g.current_state()]

        xtr = [g.x]
        ytr = [g.y]
        loop_count = 0

        while True:
            loop_count += 1
            vx, vy = extract_velocity(vel_field_data, t, i, j, rzn)
            r = g.move_exact(a, vx, vy)

            G = G + r
            t, i, j = g.current_state()

            xtr.append(g.x)
            ytr.append(g.y)

            if g.if_edge_state((i, j)):
                # NOTE(review): counted as bad, but bad_flag stays False, so the
                # run is still stored as a normal trajectory - confirm intent.
                bad_count += 1
                break
            if g.is_terminal(almost=True):
                break
            if not g.if_within_actionable_time():
                # Not terminal and out of actionable time: a failed run.
                bad_count += 1
                bad_flag = True
                break

            a = policy[g.current_state()]

            # Debugging measure from the original: bail out if the loop runs
            # suspiciously long, because the simulation sometimes got stuck.
            if loop_count > g.ni * c_ni:
                print("t: ", t)
                print("g.current_state: ", g.current_state())
                print("xtr: ", xtr)
                print("ytr: ", ytr)
                break

        plt.plot(xtr, ytr, '--')

        # Failed runs are stored as None so they can be counted later.
        if bad_flag:
            traj_list.append(None)
            t_list.append(None)
            G_list.append(None)
        else:
            traj_list.append((xtr, ytr))
            t_list.append(t)
            G_list.append(G)

    if fname is not None:
        plt.savefig(join(fpath, fname), bbox_inches="tight", dpi=200)
        plt.cla()
        plt.close(fig)
        print("*** pickling traj_list ***")
        picklePolicy(traj_list, join(fpath, fname + 'coord_traj'))
        print("*** pickled ***")

    return t_list, G_list, bad_count


# g, xs, ys, X, Y, vel_field_data, nmodes, useful_num_rzns, paths, params, param_str = setup_grid(num_actions=16)

# rel_path = 'Experiments/26/DP'
# exp_num_case_dir = join(ROOT_DIR, rel_path)
# policy = read_pickled_File(join(exp_num_case_dir, 'policy'))

# # tlist_file = join(exp_num_case_dir, 'TrajTimes2.txt')
# # with open(tlist_file, 'r') as f:
# #     phase1_tlist = ast.literal_eval(f.read())

# test_id_rel_path ='Experiments/104/QL/num_passes_50/QL_Iter_x1/dt_size_2500/ALPHA_0.05/eps_0_0.1'
# test_id_list = read_pickled_File(join(test_id_rel_path, 'test_id_list'))

# global n_test_paths_range
# n_test_paths_range = [0, len(test_id_list)]

# t_list, G_list, bad_count = plot_and_return_exact_trajectory_set_train_data(g, policy, X, Y, vel_field_data, nmodes, test_id_list, n_test_paths_range, exp_num_case_dir, fname='Explicit_plot_DPpolicy_104_testid')

# phase1_results = calc_mean_and_std(t_list)
# avg_time_ph1, std_time_ph1, cnt_ph1 , none_cnt_ph1 = phase1_results
# print("avg_time_ph1", avg_time_ph1,'\n',
#         "std_time_ph1", std_time_ph1, '\n',
#         "cnt_ph1",cnt_ph1 , '\n',
#         "none_cnt_ph1", none_cnt_ph1)

# print(t_list)
# print("stats from explicit plot: ")
# print(calc_mean_and_std(t_list))

# summ = 0
# cnt = 0
# print(len(phase1_tlist))
# print(test_id_list[:n_test_paths])
# for i in range(n_test_paths):
#     rzn = test_id_list[i]
#     t =  phase1_tlist[rzn]
#     print(t)
#     if t != None:
#         summ += phase1_tlist[rzn]
#         cnt += 1
# print("----- phase 1 data and explict plot -----")
# print('n_test_paths= ',n_test_paths)
# print("mean= ", summ/cnt)
# print("cnt = ", cnt)
# print("pfail or badcount% = ", cnt/n_test_paths)
def plot_and_return_exact_trajectory_set_train_data(g, policy, X, Y, vel_field_data, nmodes, test_id_list, n_test_paths, fpath, fname='Trajectories'):
    """Simulate ``policy`` on test realisations and plot paths coloured by arrival time.

    Each path is coloured with the 'jet' colormap, normalised over the fixed
    range 50..60, according to the time index of the last state reached.
    Relies on the module-level name ``c_ni`` for the loop guard.

    NOTE: an earlier definition with the same name exists in this file; this
    later one is the binding that remains after import.

    Args:
        g: grid/environment object; its state is overwritten by the simulation.
        policy: mapping from state to action.
        X, Y: unused here; kept for interface compatibility.
        vel_field_data: velocity data handed to ``extract_velocity``.
        nmodes: unused here; kept for interface compatibility.
        test_id_list: sequence of realisation ids.
        n_test_paths: either a ``(start, stop)`` pair or a single count ``n``
            (treated as ``(0, n)``), selecting the slice of test_id_list to
            simulate.  Previously this argument was ignored and a global
            ``n_test_paths_range`` was read instead.
        fpath: output directory.
        fname: base file name; when None nothing is saved.

    Returns:
        ``(t_list, G_list, bad_count)``.  The lists have one entry per
        simulated realisation: the last time index / accumulated reward, or
        None when the run left actionable time without terminating.
        ``bad_count`` counts runs that hit an edge state or ran out of
        actionable time.
    """
    print("--- in plot_functions.plot_exact_trajectory_set---")

    # Accept both a (start, stop) pair and a plain count.
    try:
        first_path, last_path = n_test_paths
    except TypeError:
        first_path, last_path = 0, n_test_paths

    msize = 15
    fsize = 3

    #---------------------------- beautify plot ---------------------------
    fig = plt.figure(figsize=(fsize, fsize))
    ax = fig.add_subplot(1, 1, 1)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    minor_ticks = [i / 100 for i in range(101) if i % 20 != 0]
    major_ticks = [i / 100 for i in range(0, 120, 20)]

    ax.set_xticks(minor_ticks, minor=True)
    ax.set_xticks(major_ticks, minor=False)
    ax.set_yticks(major_ticks, minor=False)
    ax.set_yticks(minor_ticks, minor=True)

    # Visibility is passed positionally: the keyword was renamed from `b` to
    # `visible` in newer matplotlib, and the positional form works in both.
    ax.grid(True, which='both', color='#CCCCCC', axis='both', linestyle='-', alpha=0.5)
    ax.tick_params(axis='both', which='both', labelsize=6)

    ax.set_xlabel('X (Non-Dim)')
    ax.set_ylabel('Y (Non-Dim)')

    # start_state is (t, i, j): column index is element 2, row index element 1.
    st_point = g.start_state
    plt.scatter(g.xs[st_point[2]], g.ys[g.ni - 1 - st_point[1]], marker='o', s=msize, color='k', zorder=1e5)
    plt.scatter(g.xs[g.endpos[1]], g.ys[g.ni - 1 - g.endpos[0]], marker='*', s=msize * 2, color='k', zorder=1e5)
    plt.gca().set_aspect('equal', adjustable='box')

    # Colour scale for arrival times, fixed to the range 50..60.
    jet = plt.get_cmap('jet')
    cNorm = colors.Normalize(vmin=50, vmax=60)
    scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=jet)
    scalarMap._A = []

    bad_count = 0
    t_list = []
    G_list = []
    traj_list = []

    for rzn in test_id_list[first_path:last_path]:
        g.set_state(g.start_state)
        bad_flag = False
        G = 0

        t, i, j = g.start_state
        a = policy[g.current_state()]

        xtr = [g.x]
        ytr = [g.y]
        loop_count = 0

        while True:
            loop_count += 1
            vx, vy = extract_velocity(vel_field_data, t, i, j, rzn)
            r = g.move_exact(a, vx, vy)

            G = G + r
            t, i, j = g.current_state()

            xtr.append(g.x)
            ytr.append(g.y)

            if g.if_edge_state((i, j)):
                # NOTE(review): counted as bad, but bad_flag stays False, so the
                # run is still stored as a normal trajectory - confirm intent.
                bad_count += 1
                break
            if g.is_terminal(almost=True):
                break
            if not g.if_within_actionable_time():
                # Not terminal and out of actionable time: a failed run.
                bad_count += 1
                bad_flag = True
                break

            a = policy[g.current_state()]

            # Debugging measure from the original: bail out if the loop runs
            # suspiciously long, because the simulation sometimes got stuck.
            if loop_count > g.ni * c_ni:
                print("t: ", t)
                print("g.current_state: ", g.current_state())
                print("xtr: ", xtr)
                print("ytr: ", ytr)
                break

        colorval = scalarMap.to_rgba(t)
        plt.plot(xtr, ytr, color=colorval, alpha=0.6)

        # Failed runs are stored as None so they can be counted later.
        if bad_flag:
            traj_list.append(None)
            t_list.append(None)
            G_list.append(None)
        else:
            traj_list.append((xtr, ytr))
            t_list.append(t)
            G_list.append(G)

    if fname is not None:
        plt.savefig(join(fpath, fname), bbox_inches="tight", dpi=200)
        plt.cla()
        plt.close(fig)
        print("*** pickling traj_list ***")
        picklePolicy(traj_list, join(fpath, fname + 'coord_traj'))
        print("*** pickled ***")

    return t_list, G_list, bad_count
# Example #7
# 0
def plot_exact_trajectory_set_DP(g,
                                 policy,
                                 X,
                                 Y,
                                 vel_field_data,
                                 test_id_list,
                                 fpath,
                                 fname='Trajectories'):
    """Simulate ``policy`` on every realisation in ``test_id_list`` and plot the paths.

    Args:
        g: grid/environment object; its state is overwritten by the simulation.
        policy: mapping from state to action.
        X, Y: unused here; kept for interface compatibility.
        vel_field_data: velocity data handed to ``extract_velocity``.
        test_id_list: sequence of realisation ids to simulate.
        fpath: output directory.
        fname: file name used for both the figure and the pickled
            trajectories; when None nothing is saved.

    Returns:
        ``(t_list_all, t_list_reached, G_list, bad_count_tuple)``.
        ``t_list_all`` and ``G_list`` have one entry per realisation (None
        when the run left actionable time without terminating);
        ``t_list_reached`` holds only the non-None times;
        ``bad_count_tuple`` is ``(bad_count, '<percentage>%')`` where the
        percentage is reported as 0.0 for an empty ``test_id_list``.
    """
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)

    # Minor ticks are offset by half a cell from the major ones; major ticks
    # fall every 5 cells.
    minor_xticks = np.arange(g.xs[0] - 0.5 * g.dj, g.xs[-1] + 2 * g.dj, g.dj)
    minor_yticks = np.arange(g.ys[0] - 0.5 * g.di, g.ys[-1] + 2 * g.di, g.di)

    major_xticks = np.arange(g.xs[0], g.xs[-1] + 2 * g.dj, 5 * g.dj)
    major_yticks = np.arange(g.ys[0], g.ys[-1] + 2 * g.di, 5 * g.di)

    ax.set_xticks(minor_xticks, minor=True)
    ax.set_yticks(minor_yticks, minor=True)
    ax.set_xticks(major_xticks)
    ax.set_yticks(major_yticks)

    ax.grid(which='major', color='#CCCCCC', linestyle='')
    ax.grid(which='minor', color='#CCCCCC', linestyle='--')

    # start_state is (t, i, j): column index is element 2, row index element 1.
    st_point = g.start_state
    plt.scatter(g.xs[st_point[2]], g.ys[g.ni - 1 - st_point[1]], c='g')
    plt.scatter(g.xs[g.endpos[1]], g.ys[g.ni - 1 - g.endpos[0]], c='r')
    # NOTE(review): plt.grid() without arguments toggles the major grid in
    # matplotlib; kept as in the original - confirm this is intended.
    plt.grid()
    plt.gca().set_aspect('equal', adjustable='box')

    bad_count = 0
    t_list_all = []
    t_list_reached = []
    G_list = []
    traj_list = []

    for rzn in test_id_list:
        g.set_state(g.start_state)
        bad_flag = False
        G = 0

        t, i, j = g.start_state
        a = policy[g.current_state()]

        xtr = [g.x]
        ytr = [g.y]

        while True:
            vx, vy = extract_velocity(vel_field_data, t, i, j, rzn)
            r = g.move_exact(a, vx, vy)
            G = G + r
            t, i, j = g.current_state()

            xtr.append(g.x)
            ytr.append(g.y)

            if g.if_edge_state((i, j)):
                # NOTE(review): counted as bad, but bad_flag stays False, so the
                # run is still stored as a normal trajectory - confirm intent.
                bad_count += 1
                break
            if g.is_terminal(almost=True):
                break
            if not g.if_within_actionable_time():
                # Not terminal and out of actionable time: a failed run.
                bad_count += 1
                bad_flag = True
                break

            a = policy[g.current_state()]

        plt.plot(xtr, ytr)

        # Failed runs are stored as None so they can be counted later.
        if bad_flag:
            traj_list.append(None)
            t_list_all.append(None)
            G_list.append(None)
        else:
            traj_list.append((xtr, ytr))
            t_list_all.append(t)
            G_list.append(G)
            t_list_reached.append(t)

    if fname is not None:
        plt.savefig(join(fpath, fname), dpi=300)
        print("*** pickling traj_list ***")
        picklePolicy(traj_list, join(fpath, fname))
        print("*** pickled ***")

    # Avoid dividing by zero when there were no realisations to simulate.
    n_rzns = len(test_id_list)
    bad_percent = bad_count * 100 / n_rzns if n_rzns else 0.0
    bad_count_tuple = (bad_count, str(bad_percent) + '%')
    return t_list_all, t_list_reached, G_list, bad_count_tuple
# Example #8
# 0
def plot_exact_trajectory_set(g,
                              policy,
                              X,
                              Y,
                              vel_field_data,
                              nmodes,
                              train_id_set,
                              test_id_set,
                              goodlist,
                              fpath,
                              fname='Trajectories'):
    """Simulate ``policy`` on every realisation in ``goodlist`` and plot the paths.

    Paths are drawn blue for ids in ``train_id_set``, green (on top) for ids
    in ``test_id_set``, and red otherwise.  Relies on the module-level name
    ``c_ni`` for the loop guard.

    Args:
        g: grid/environment object; its state is overwritten by the simulation.
        policy: mapping from state to action.
        X, Y: unused here; kept for interface compatibility.
        vel_field_data: velocity data handed to ``extract_velocity``.
        nmodes: unused here; kept for interface compatibility.
        train_id_set: container of training realisation ids.
        test_id_set: container of test realisation ids.
        goodlist: iterable of realisation ids to simulate.
        fpath: output directory.
        fname: file name used for both the figure and the pickled
            trajectories; when None nothing is saved.

    Returns:
        ``(t_list, G_list, bad_count)``.  The lists have one entry per id in
        ``goodlist``: the last time index / accumulated reward, or None when
        the run left actionable time without terminating.  ``bad_count``
        counts runs that hit an edge state or ran out of actionable time.
    """
    print("--- in plot_functions.plot_exact_trajectory_set---")

    msize = 15

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    minor_ticks = [i / 100 for i in range(101) if i % 20 != 0]
    major_ticks = [i / 100 for i in range(0, 120, 20)]

    ax.set_xticks(minor_ticks, minor=True)
    ax.set_xticks(major_ticks, minor=False)
    ax.set_yticks(major_ticks, minor=False)
    ax.set_yticks(minor_ticks, minor=True)

    # Visibility is passed positionally: the keyword was renamed from `b` to
    # `visible` in newer matplotlib, and the positional form works in both.
    ax.grid(True,
            which='both',
            color='#CCCCCC',
            axis='both',
            linestyle='-',
            alpha=0.5)
    ax.tick_params(axis='both', which='both', labelsize=6)

    ax.set_xlabel('X (Non-Dim)')
    ax.set_ylabel('Y (Non-Dim)')

    # start_state is (t, i, j) (it is unpacked that way below), so the column
    # index is element 2 and the row index is element 1.
    st_point = g.start_state
    plt.scatter(g.xs[st_point[2]],
                g.ys[g.ni - 1 - st_point[1]],
                marker='o',
                s=msize,
                color='k',
                zorder=1e5)
    plt.scatter(g.xs[g.endpos[1]],
                g.ys[g.ni - 1 - g.endpos[0]],
                marker='*',
                s=msize * 2,
                color='k',
                zorder=1e5)
    plt.gca().set_aspect('equal', adjustable='box')

    bad_count = 0
    t_list = []
    G_list = []
    traj_list = []

    for rzn in goodlist:
        # Training membership takes precedence over test membership.
        if rzn in train_id_set:
            color = 'b'
        elif rzn in test_id_set:
            color = 'g'
        else:
            color = 'r'

        g.set_state(g.start_state)
        bad_flag = False
        G = 0

        t, i, j = g.start_state
        a = policy[g.current_state()]

        xtr = [g.x]
        ytr = [g.y]
        loop_count = 0

        while True:
            loop_count += 1
            vx, vy = extract_velocity(vel_field_data, t, i, j, rzn)
            r = g.move_exact(a, vx, vy)

            G = G + r
            t, i, j = g.current_state()

            xtr.append(g.x)
            ytr.append(g.y)

            if g.if_edge_state((i, j)):
                # NOTE(review): counted as bad, but bad_flag stays False, so the
                # run is still stored as a normal trajectory - confirm intent.
                bad_count += 1
                break
            if g.is_terminal(almost=True):
                break
            if not g.if_within_actionable_time():
                # Not terminal and out of actionable time: a failed run.
                bad_count += 1
                bad_flag = True
                break

            a = policy[g.current_state()]

            # Debugging measure from the original: bail out if the loop runs
            # suspiciously long, because the simulation sometimes got stuck.
            if loop_count > g.ni * c_ni:
                print("t: ", t)
                print("g.current_state: ", g.current_state())
                print("xtr: ", xtr)
                print("ytr: ", ytr)
                break

        # Test-set paths are raised above everything else.
        if color == 'g':
            plt.plot(xtr, ytr, color=color, zorder=1e5)
        else:
            plt.plot(xtr, ytr, color=color)

        # Failed runs are stored as None so they can be counted later.
        if bad_flag:
            traj_list.append(None)
            t_list.append(None)
            G_list.append(None)
        else:
            traj_list.append((xtr, ytr))
            t_list.append(t)
            G_list.append(G)

    if fname is not None:
        plt.savefig(join(fpath, fname), bbox_inches="tight", dpi=200)
        plt.cla()
        plt.close(fig)
        print("*** pickling traj_list ***")
        picklePolicy(traj_list, join(fpath, fname))
        print("*** pickled ***")

    return t_list, G_list, bad_count