示例#1
0
 def write_dict(self, dict, base_path):
     """Write every value of *dict* to a text file through the dataloader.

     Each key is read by index: ``key[0]`` is joined onto
     ``self.f_prefix``/*base_path* to form the target directory, ``key[1]``
     is a file name passed through ``remove_file_extention``, and
     ``key[2]`` is appended to that name after an underscore. The result is
     passed through ``add_file_extention`` with ``'txt'`` and handed,
     together with the value and the directory, to
     ``self.dataloader.write_dataset``.

     Args:
         dict: Mapping from indexable keys (at least three elements) to the
             data to write. The name shadows the builtin but is kept for
             backward compatibility with keyword callers.
         base_path: Path component placed between ``self.f_prefix`` and
             ``key[0]``.
     """
     for key, value in dict.items():
         target_dir = os.path.join(self.f_prefix, base_path, key[0])
         stem = remove_file_extention(key[1]) + "_" + str(key[2])
         file_name = add_file_extention(stem, 'txt')
         self.dataloader.write_dataset(value, file_name, target_dir)
示例#2
0
def main():
    """Render saved trajectory results as per-sequence plots and videos.

    Parses the command-line options, samples result files from the
    ``validation`` plot directory (``test`` when ``--num_of_data`` is 0) and,
    for each sampled file: clears and recreates its plot and video output
    folders, unpickles the results, plots every sequence with
    ``plot_trajectories``, then calls ``save_video`` and
    ``plot_target_trajs``.
    """
    parser = argparse.ArgumentParser()

    # frame rate of video
    parser.add_argument('--frame', type=int, default=1,
                        help='Frame of video created from plots')
    # gru model
    parser.add_argument('--gru', action="store_true", default=False,
                        help='Visualization of GRU model')
    # number of validation dataset
    parser.add_argument('--num_of_data', type=int, default=3,
                        help='Number of validation data will be visualized (If 0 is given, will work on test data mode)')
    # drive support
    parser.add_argument('--drive', action="store_true", default=False,
                        help='Use Google drive or not')
    # minimum length of trajectory
    parser.add_argument('--min_traj', type=int,  default=3,
                        help='Min. treshold of number of frame to be removed from a sequence')
    # percentage of peds will be taken for each frame
    parser.add_argument('--max_ped_ratio', type=float,  default=0.8,
                        help='Percentage of pedestrian will be illustrated in a plot for a sequence')
    # maximum ped numbers
    parser.add_argument('--max_target_ped', type=int,  default=20,
                        help='Maximum number of peds in final plot')
    # method to be visualized
    parser.add_argument('--method', type=int, default=1,
                        help='Method of lstm will be used (1 = social lstm, 2 = obstacle lstm, 3 = vanilla lstm)')

    # Parse the parameters
    args = parser.parse_args()

    # With --drive every path is rooted in the project folder on the drive.
    f_prefix = 'drive/semester_project/social_lstm_final' if args.drive else '.'

    model_name = "GRU" if args.gru else "LSTM"
    method_name = get_method_name(args.method)

    # --num_of_data 0 selects test mode; compare by value, not identity.
    test_mode = args.num_of_data == 0
    plot_file_directory = 'test' if test_mode else 'validation'

    # creation of paths
    base_directory = os.path.join(f_prefix, 'plot', method_name, model_name)
    save_plot_directory = os.path.join(base_directory, 'plots/')
    plot_directory = os.path.join(base_directory, plot_file_directory)
    video_directory = os.path.join(base_directory, 'videos/')

    available_files = get_all_file_names(plot_directory)
    # NOTE(review): in test mode the requested count is 0, so the clip below
    # yields 0, no file is sampled and the loop never runs -- confirm whether
    # test mode is meant to process every available file instead.
    # int() turns the NumPy scalar returned by np.clip into a plain int.
    num_of_data = int(np.clip(args.num_of_data, 0, len(available_files)))
    plot_file_name = random.sample(available_files, num_of_data)

    # Values forwarded to the plotting helpers as [min_r, max_r, plot_offset].
    min_r = -10
    max_r = 10
    plot_offset = 1

    for file_name in plot_file_name:
        folder_name = remove_file_extention(file_name)
        print("Now processing: ", file_name)

        file_path = os.path.join(plot_directory, file_name)
        video_save_directory = os.path.join(video_directory, folder_name)
        figure_save_directory = os.path.join(save_plot_directory, folder_name)

        # remove existing plots, then make sure both output folders exist
        clear_folder(video_save_directory)
        clear_folder(figure_save_directory)
        os.makedirs(video_save_directory, exist_ok=True)
        os.makedirs(figure_save_directory, exist_ok=True)

        try:
            f = open(file_path, 'rb')
        except FileNotFoundError:
            print("File not found: %s"%file_path)
            continue

        # NOTE: unpickling can execute arbitrary code; only result files
        # produced by a trusted run should be placed in the plot directory.
        with f:
            results = pickle.load(f)

        # Clip into a local so one small file does not permanently lower the
        # cap requested on the command line for the files that follow.
        max_target_ped = int(np.clip(args.max_target_ped, 0, len(results) - 1))

        target_id_trajs = []
        for i, result in enumerate(results):
            print("##########################################################################################")
            name = 'sequence' + str(i).zfill(5)
            print("Now processing seq: ",name)

            # Test mode takes the final argument from the record itself;
            # validation mode uses the fixed value 20.
            last_argument = result[6] if test_mode else 20
            target_traj = plot_trajectories(
                result[0], result[1], result[2], result[3], result[4],
                name, figure_save_directory, args.min_traj,
                args.max_ped_ratio, result[5],
                [min_r, max_r, plot_offset], last_argument)
            target_traj.append(result[2])  # pedlist
            target_traj.append(result[3])  # lookup
            target_id_trajs.append(target_traj)

        save_video(figure_save_directory, video_save_directory, file_name, args.frame)
        plot_target_trajs(target_id_trajs, figure_save_directory, max_target_ped, plot_offset)