def visualize_bar(data, bin_size=2.0, title='Bar Graph of Key-Value Pair', xlabel='index', ylabel='count', vis=True, save_path=None, debug=True, closefig=True):
    '''
    visualize the bar graph of a data, which can be a dictionary or a list

    different from visualize_bar_graph, this function does not depend on pandas and dataframe,
    it is simpler but has less functionality. The keys are used as scalar positions on the x axis

    parameters:
        data:                   a dictionary (scalar key -> bar height) or a list of bar heights,
                                list items are drawn at x = 0, 1, 2, ...
        bin_size:               a scalar, passed to plt.bar as the width of every bar
        title, xlabel, ylabel:  strings used to annotate the plot
        vis, save_path, closefig:   forwarded to save_vis_close_helper
        debug:                  check the inputs before plotting

    outputs:
        the return value of save_vis_close_helper
    '''
    if debug:
        assert isstring(title) and isstring(xlabel) and isstring(ylabel), 'title/xlabel/ylabel is not correct'
        assert isdict(data) or islist(data), 'input data is not correct'
        assert isscalar(bin_size), 'the bin size is not a floating number'

    if isdict(data):
        # materialise the key/value views: under Python 3 np.array() over a dict view
        # produces a 0-d object array instead of a 1-d array of positions
        index_list = list(data.keys())
        if debug:
            assert islistofscalar(index_list), 'the input dictionary does not contain a scalar key'
        frequencies = list(data.values())
    else:
        index_list = list(range(len(data)))
        frequencies = data

    index_list = np.array(index_list)
    fig, ax = get_fig_ax_helper(fig=None, ax=None)
    plt.bar(index_list, frequencies, bin_size, color='r', alpha=0.5)
    plt.title(title, fontsize=20)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    return save_vis_close_helper(fig=fig, ax=ax, vis=vis, save_path=save_path, debug=debug, transparent=False, closefig=closefig)
def visualize_image_with_pts(input_image, input_pts, color_index=0, pts_size=20, vis_threshold=0.3, label=False, label_list=None, label_size=20, bgr2rgb=False, save_path=None, vis=False, warning=True, debug=True, closefig=True):
    '''
    draw an image and overlay one or several sets of points on it

    parameters:
        input_image:        a pil or numpy image
        input_pts:          2(3) x num_pts numpy array or a dictionary of 2(3) x num_pts array
                            when there are 3 channels in pts, the third one denotes the occlusion/confidence flag
                            occlusion: 0 -> invisible and not annotated, 1 -> visible and annotated, -1 -> visible but not annotated
        color_index:        a scalar or a list of color indexes, when input_pts is a dictionary
                            the index is increased by one after every point array
        vis_threshold:      the points with confidence above the threshold will be drawn
        label:              determine to add text label for each point, if label list is None, then an automatic list is created
        label_list:         label string for all points, if label list is not None, the label is True automatically
                            if the input points is a dictionary, then every point array in the dict follows the same label list
        bgr2rgb:            true if the image needs to be converted from bgr to rgb
        pts_size:           size of points
        label_size:         font of labels

    outputs:
        the return value of save_vis_close_helper (fig, ax handles for future use)
    '''
    # the image is drawn first and kept open so that the points land on the same axes
    fig, ax = visualize_image(input_image, bgr2rgb=bgr2rgb, vis=False, save_path=None, warning=warning, debug=debug, closefig=False)

    # options which are identical for every point array drawn below
    shared_options = dict(
        fig=fig, ax=ax, pts_size=pts_size,
        label=label, label_list=label_list, label_size=label_size,
        plot_occl=False, covariance=False, xlim=None, ylim=None,
        vis_threshold=vis_threshold, debug=debug,
        vis=False, save_path=None, warning=warning, closefig=False)

    if not isdict(input_pts):
        visualize_pts_array(input_pts, color_index=color_index, **shared_options)
    else:
        current_color = color_index
        for single_pts_array in input_pts.values():
            visualize_pts_array(single_pts_array, color_index=current_color, **shared_options)
            current_color += 1          # next point array gets the next color

    return save_vis_close_helper(fig=fig, ax=ax, vis=vis, save_path=save_path, debug=debug, warning=warning, closefig=closefig)
def facial_landmark_evaluation(pred_dict_all, anno_dict, num_pts, error_threshold, normalization_ced=True, normalization_vec=False, covariance=True, display_list=None, debug=True, vis=False, save=True, save_path=None, truncated_list=None):
    '''
    evaluate the performance of facial landmark detection

    parameter:
        pred_dict_all:      a dictionary for all baseline methods. Each key is the method name and the value is the corresponding
                            prediction dictionary, whose keys are the image path and values are 2 x N prediction results
                            (rows after the first two, e.g. a score channel, are dropped)
        anno_dict:          a dictionary whose values are 2 x N or 3 x N annotation results, it is looked up with the
                            second element returned by fileparts(image_path)
                            when there are 3 rows, only the points whose third row equals 1 are evaluated
        num_pts:            number of points
        error_threshold:    a scalar forwarded to visualize_ced
        normalization_ced:  divide the point-to-point error by the size of the annotation bounding box
        normalization_vec:  divide the error vector by the size of the annotation bounding box
                            with either normalization flag on, every image must have all num_pts points annotated
        covariance:         not used in this function at the moment
        display_list:       order of the methods in the output table, None keeps the order of pred_dict_all
        vis:                determine if visualizing the pck curve
        save:               not used in this function at the moment, results are always saved
        save_path:          a directory to save all the results, required
        truncated_list:     forwarded to visualize_ced

    files written under save_path:
        error_list/error_<method>.txt           per-image error sorted from the largest to the smallest
        pck/pck_curve_overall.png               pck curve over all points (path handed to visualize_ced)
        pck/pck_curve_pts_<k>.png               pck curve of the k-th point (path handed to visualize_ced)
        metrics/detailed_metrics_*.txt          tables (paths handed to visualize_ced)
        metrics/pointwise_average_MSE.txt       point-wise table printed by this function

    return:
        metrics_all:        the second value returned by visualize_ced for the overall curve
        ptswise_mse_table:  a list of list, the first row is the title and every other row holds the
                            point-wise 'MSE' entries of one method
    '''
    num_methods = len(pred_dict_all)
    if debug:
        assert isdict(pred_dict_all) and num_methods > 0 and all(isdict(pred_dict) for pred_dict in pred_dict_all.values()), 'predictions result format is not correct'
        assert isdict(anno_dict), 'annotation result format is not correct'
        assert ispositiveinteger(num_pts), 'number of points is not correct'
        assert isscalar(error_threshold), 'error threshold is not correct'
        assert islogical(normalization_ced) and islogical(normalization_vec), 'normalization flag is not correct'
        if display_list is not None:
            assert len(display_list) == num_methods, 'display list is not correct %d vs %d' % (len(display_list), num_methods)

    # everything below is written under save_path, fail early instead of inside os.path.join
    assert save_path is not None, 'save_path is required, all results are saved to this directory'

    num_images = len(list(pred_dict_all.values())[0])
    if debug:
        assert num_images > 0, 'the predictions are empty'
        assert num_images == len(anno_dict), 'number of images is not equal to number of annotations: %d vs %d' % (num_images, len(anno_dict))
        assert all(num_images == len(pred_dict) for pred_dict in pred_dict_all.values()), 'number of images in results from different methods are not equal'

    # calculate normalized mean error for each single image based on point-to-point Euclidean distance normalized by the bounding box size
    # calculate point error vector for each single image based on error vector normalized by the bounding box size
    normed_mean_error_dict = dict()
    normed_mean_error_pts_specific_dict = dict()
    normed_mean_error_pts_specific_valid_dict = dict()
    # NOTE: the two error vector dictionaries are filled but not consumed, the error vector
    # distribution visualization which used them is currently disabled
    pts_error_vec_dict = dict()
    pts_error_vec_pts_specific_dict = dict()
    mse_error_dict_dict = dict()
    for method_name, pred_dict in pred_dict_all.items():
        normed_mean_error_total = np.zeros((num_images, ), dtype='float32')
        normed_mean_error_pts_specific = np.zeros((num_images, num_pts), dtype='float32')
        normed_mean_error_pts_specific_valid = np.zeros((num_images, num_pts), dtype='bool')
        pts_error_vec = np.zeros((num_images, 2), dtype='float32')
        pts_error_vec_pts_specific = np.zeros((num_images, 2, num_pts), dtype='float32')
        mse_error_dict = dict()
        count = 0
        count_skip_num_images = 0   # it's possible that no annotation exists on some images, then no error should be counted for those images, we count the number of those images
        for image_path, pts_prediction in pred_dict.items():
            _, filename, _ = fileparts(image_path)
            pts_anno = anno_dict[filename]              # 2 x N annotation
            pts_keep_index = list(range(num_pts))

            # to avoid list object type, do conversion here
            if islist(pts_anno):
                pts_anno = np.asarray(pts_anno)
            if islist(pts_prediction):
                pts_prediction = np.asarray(pts_prediction)
            if debug:
                assert (is2dptsarray(pts_anno) or is2dptsarray_occlusion(pts_anno)) and pts_anno.shape[1] == num_pts, 'shape of annotations is not correct (%d x %d) vs (%d x %d)' % (2, num_pts, pts_anno.shape[0], pts_anno.shape[1])

            # if the annotation has 3 channels (include extra occlusion channel), we keep only the points with annotations
            # occlusion: -1 -> visible but not annotated, 0 -> invisible and not annotated, 1 -> visible, we keep only visible and annotated points
            if pts_anno.shape[0] == 3:
                pts_keep_index = np.where(pts_anno[2, :] == 1)[0].tolist()
                if len(pts_keep_index) <= 0:            # if no point is annotated in current image
                    count_skip_num_images += 1
                    continue
                pts_anno = pts_anno[0:2, pts_keep_index]
                pts_prediction = pts_prediction[:, pts_keep_index]

            # to avoid the point location includes the score or occlusion channel, only take the first two channels here
            if pts_prediction.shape[0] == 3 or pts_prediction.shape[0] == 4:
                pts_prediction = pts_prediction[0:2, :]

            num_pts_tmp = len(pts_keep_index)
            if debug:
                assert pts_anno.shape[1] <= num_pts, 'number of points is not correct: %d vs %d' % (pts_anno.shape[1], num_pts)
                assert pts_anno.shape == pts_prediction.shape, 'shape of annotations and predictions are not the same {} vs {}'.format(print_np_shape(pts_anno, debug=debug), print_np_shape(pts_prediction, debug=debug))

            # calculate bbox for normalization
            if normalization_ced or normalization_vec:
                assert len(pts_keep_index) == num_pts, 'some points are not annotated. Normalization on PCK curve is not allowed.'
                bbox_anno = pts2bbox(pts_anno, debug=debug)                         # 1 x 4
                bbox_TLWH = bbox_TLBR2TLWH(bbox_anno, debug=debug)                  # 1 x 4
                bbox_size = math.sqrt(bbox_TLWH[0, 2] * bbox_TLWH[0, 3])            # scalar

            # calculate normalized error for all points
            normed_mean_error, _ = pts_euclidean(pts_prediction, pts_anno, debug=debug)    # scalar
            if normalization_ced:
                normed_mean_error /= bbox_size
            normed_mean_error_total[count] = normed_mean_error
            mse_error_dict[image_path] = normed_mean_error

            # an exact zero error is suspicious, dump the points for inspection
            if normed_mean_error == 0:
                print(pts_prediction)
                print(pts_anno)

            # calculate normalized error point specifically
            # points which are not annotated in the current image keep error 0 and valid flag False
            for pts_index_from_keep_list, pts_index in enumerate(pts_keep_index):
                normed_mean_error_pts_specific_valid[count, pts_index] = True
                pts_prediction_tmp = np.reshape(pts_prediction[:, pts_index_from_keep_list], (2, 1))
                pts_anno_tmp = np.reshape(pts_anno[:, pts_index_from_keep_list], (2, 1))
                normed_mean_error_pts_specifc_tmp, _ = pts_euclidean(pts_prediction_tmp, pts_anno_tmp, debug=debug)
                if normalization_ced:
                    normed_mean_error_pts_specifc_tmp /= bbox_size
                normed_mean_error_pts_specific[count, pts_index] = normed_mean_error_pts_specifc_tmp

            # calculate the point error vector
            error_vector = pts_prediction - pts_anno            # 2 x num_pts_tmp
            if normalization_vec:
                error_vector /= bbox_size
            # indexing with (scalar, slice, list) yields num_pts_tmp x 2, hence the transpose
            pts_error_vec_pts_specific[count, :, pts_keep_index] = np.transpose(error_vector)
            pts_error_vec[count, :] = np.sum(error_vector, axis=1) / num_pts_tmp

            count += 1

        print('number of skipped images is %d' % count_skip_num_images)
        assert count + count_skip_num_images == num_images, 'all cells in the array must be filled %d vs %d' % (count + count_skip_num_images, num_images)

        # save results to dictionary, only the first count rows are filled
        normed_mean_error_dict[method_name] = normed_mean_error_total[:count]
        normed_mean_error_pts_specific_dict[method_name] = normed_mean_error_pts_specific[:count, :]
        normed_mean_error_pts_specific_valid_dict[method_name] = normed_mean_error_pts_specific_valid[:count, :]
        pts_error_vec_dict[method_name] = np.transpose(pts_error_vec[:count, :])                # 2 x num_images
        pts_error_vec_pts_specific_dict[method_name] = pts_error_vec_pts_specific[:count, :, :]
        mse_error_dict_dict[method_name] = mse_error_dict

    # save mse error list to file for each method, the worst image comes first
    error_list_savedir = os.path.join(save_path, 'error_list')
    mkdir_if_missing(error_list_savedir)
    for method_name, mse_error_dict in mse_error_dict_dict.items():
        mse_error_list_path = os.path.join(error_list_savedir, 'error_%s.txt' % method_name)
        sorted_tuple_list = sorted(mse_error_dict.items(), key=operator.itemgetter(1), reverse=True)
        with open(mse_error_list_path, 'w') as mse_error_list:
            for image_path_tmp, mse_error_tmp in sorted_tuple_list:
                mse_error_list.write('{:<200} {}\n'.format(image_path_tmp, '%.2f' % mse_error_tmp))
        print('\nsave mse error list for %s to %s' % (method_name, mse_error_list_path))

    # visualize the ced (cumulative error distribution curve)
    print('visualizing pck curve....\n')
    pck_savedir = os.path.join(save_path, 'pck')
    mkdir_if_missing(pck_savedir)
    pck_savepath = os.path.join(pck_savedir, 'pck_curve_overall.png')
    table_savedir = os.path.join(save_path, 'metrics')
    mkdir_if_missing(table_savedir)
    table_savepath = os.path.join(table_savedir, 'detailed_metrics_overall.txt')
    _, metrics_all = visualize_ced(normed_mean_error_dict, error_threshold=error_threshold, normalized=normalization_ced, truncated_list=truncated_list, title='2D PCK curve (all %d points)' % num_pts, display_list=display_list, debug=debug, vis=vis, pck_savepath=pck_savepath, table_savepath=table_savepath)

    # point-wise table: one row per method, one column per point
    method_name_list = list(normed_mean_error_pts_specific_dict.keys())
    metrics_title = ['Method Name / Point Index']
    ptswise_mse_table = [[method_name_tmp] for method_name_tmp in method_name_list]
    for pts_index in range(num_pts):
        metrics_title.append(str(pts_index + 1))
        normed_mean_error_dict_tmp = dict()

        for method_name, error_array in normed_mean_error_pts_specific_dict.items():
            normed_mean_error_pts_specific_valid_temp = normed_mean_error_pts_specific_valid_dict[method_name]

            # Some points at certain images might not be annotated. When calculating MSE for these specific point,
            # we remove those images to avoid "false" mean average error
            valid_array_per_pts_per_method = np.where(normed_mean_error_pts_specific_valid_temp[:, pts_index])[0].tolist()
            error_array_per_pts = error_array[:, pts_index]
            error_array_per_pts = error_array_per_pts[valid_array_per_pts_per_method]
            num_image_tmp = len(valid_array_per_pts_per_method)
            if num_image_tmp == 0:
                continue
            normed_mean_error_dict_tmp[method_name] = np.reshape(error_array_per_pts, (num_image_tmp, ))

        pck_savepath = os.path.join(pck_savedir, 'pck_curve_pts_%d.png' % (pts_index + 1))
        table_savepath = os.path.join(table_savedir, 'detailed_metrics_pts_%d.txt' % (pts_index + 1))

        # the point is not annotated in any image, nothing to plot
        if len(normed_mean_error_dict_tmp) == 0:
            continue
        metrics_dict, _ = visualize_ced(normed_mean_error_dict_tmp, error_threshold=error_threshold, normalized=normalization_ced, truncated_list=truncated_list, display2terminal=False, title='2D PCK curve for point %d' % (pts_index + 1), display_list=display_list, debug=debug, vis=vis, pck_savepath=pck_savepath, table_savepath=table_savepath)
        for method_index, method_name in enumerate(method_name_list):
            ptswise_mse_table[method_index].append('%.1f' % metrics_dict[method_name]['MSE'])

    # reorder the table, without a display list the rows keep the order of the methods
    order_name_list = method_name_list if display_list is None else display_list
    order_index_list = [order_name_list.index(method_name_tmp) for method_name_tmp in method_name_list]
    order_index_list = [0] + [order_index_tmp + 1 for order_index_tmp in order_index_list]     # the title row stays on top

    # print table to terminal
    ptswise_mse_table = list_reorder([metrics_title] + ptswise_mse_table, order_index_list, debug=debug)
    table = AsciiTable(ptswise_mse_table)
    print('\nprint point-wise average MSE')
    print(table.table)

    # save table to file
    ptswise_savepath = os.path.join(table_savedir, 'pointwise_average_MSE.txt')
    with open(ptswise_savepath, 'w') as table_file:
        table_file.write(table.table)
    print('\nsave point-wise average MSE to %s' % ptswise_savepath)

    # NOTE: the visualization of the error vector distribution and the export of the
    # point-wise mse to a json file are disabled in this version

    print('\ndone!!!!!\n')
    return metrics_all, ptswise_mse_table
def visualize_bar_graph(data, title='Bar Graph of Key-Value Pair', xlabel='pixel error', ylabel='keypoint index', label=False, label_list=None, vis=True, save_path=None, debug=True, closefig=True):
    '''
    visualize the horizontal bar graph of a data, which can be a dictionary or list of dictionary

    inside each dictionary, the keys (string) should be the same which is the y label, the values should be scalar

    parameters:
        data:                   a dictionary (name -> scalar) or a list of such dictionaries sharing the same keys,
                                every dictionary of the list is drawn as one set of bars
        title, xlabel, ylabel:  strings used to annotate the plot
        label:                  add a legend label for every set, only used when data is a list
        label_list:             one label per set, required when label is True and data is a list
        vis, save_path, closefig:   forwarded to save_vis_close_helper
        debug:                  check the inputs before plotting

    outputs:
        the return value of save_vis_close_helper

    the figure size depends on the module-level dpi, the bar colors of a list come from the module-level color_set
    '''
    if debug:
        assert isstring(title) and isstring(xlabel) and isstring(ylabel), 'title/xlabel/ylabel is not correct'
        assert isdict(data) or islistofdict(data), 'input data is not correct'
        if isdict(data):
            assert all(isstring(key_tmp) for key_tmp in data.keys()), 'the keys are not all strings'
            assert all(isscalar(value_tmp) for value_tmp in data.values()), 'the values are not all scalars'
        else:
            assert len(data) > 0, 'input data is not correct'
            assert len(data) <= len(color_set), 'number of data set is larger than number of color to use'
            keys = sorted(data[0].keys())
            for dict_tmp in data:
                if not (sorted(dict_tmp.keys()) == keys):
                    print(dict_tmp.keys())
                    print(keys)
                    assert False, 'the keys are not equal across different input set'
                assert all(isstring(key_tmp) for key_tmp in dict_tmp.keys()), 'the keys are not all strings'
                assert all(isscalar(value_tmp) for value_tmp in dict_tmp.values()), 'the values are not all scalars'
            if label:
                assert label_list is not None and len(label_list) >= len(data), 'a label is needed for every data set'

    # convert dictionary to DataFrame, rows are sorted by name
    # the values are looked up by key so that every set lines up with the names
    # even when the dictionaries do not store their keys in the same order
    data_new = dict()
    if isdict(data):
        name_list = sorted(data.keys())
        data_new['names'] = name_list
        data_new['values'] = [data[name_tmp] for name_tmp in name_list]
    else:
        name_list = sorted(data[0].keys())
        data_new['names'] = name_list
        for set_index, dict_tmp in enumerate(data):
            data_new['value_%03d' % set_index] = [dict_tmp[name_tmp] for name_tmp in name_list]
    dataframe = DataFrame(data_new)

    # plot
    width = 2000
    height = 2000
    alpha = 0.5
    figsize = width / float(dpi), height / float(dpi)
    fig = plt.figure(figsize=figsize)
    sns.set(style='whitegrid')

    if isdict(data):
        sns.barplot(x='values', y='names', data=dataframe, label='data', color='b')
        plt.legend(ncol=1, loc='lower right', frameon=True, fontsize=5)
    else:
        num_sets = len(data)
        for set_index in range(num_sets):
            if set_index == 0:
                sns.set_color_codes('pastel')
            else:
                sns.set_color_codes('muted')

            if label:
                sns.barplot(x='value_%03d' % set_index, y='names', data=dataframe, label=label_list[set_index], color=color_set[set_index], alpha=alpha)
            else:
                sns.barplot(x='value_%03d' % set_index, y='names', data=dataframe, color=color_set[set_index], alpha=alpha)
        plt.legend(ncol=len(data), loc='lower right', frameon=True, fontsize=5)

    sns.despine(left=True, bottom=True)
    plt.title(title, fontsize=20)
    plt.xlim([0, 50])
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    # shrink the tick font when there are many rows, the linear rule would reach zero
    # at roughly 270 rows so it is clamped to stay positive
    num_yticks = len(data_new['names'])
    adaptive_fontsize = max(1.0, -0.0555556 * num_yticks + 15.111)
    plt.yticks(fontsize=adaptive_fontsize)

    return save_vis_close_helper(fig=fig, vis=vis, save_path=save_path, debug=debug, closefig=closefig)
def visualize_distribution(data, bin_size=None, vis=False, save_path=None, debug=True, closefig=True):
    '''
    visualize the histogram of a data, which can be a dictionary or list or numpy array or tuple or a list of list

    parameters:
        data:           a tuple, a dictionary (its values are used), a list, a numpy array or a list of list,
                        every nested list is drawn as a separate histogram on shared bins
        bin_size:       width of a bin, None divides the range of the data into 1000 bins
        vis, save_path, closefig:   forwarded to save_vis_close_helper
        debug:          check the inputs before plotting

    outputs:
        the return value of save_vis_close_helper

    raises:
        TypeError:      the bin size cannot be converted to float
        ValueError:     the bin size is not positive
    '''
    if debug:
        assert istuple(data) or isdict(data) or islist(data) or isnparray(data), 'input data is not correct'

    # convert data type
    if istuple(data):
        data = list(data)
    elif isdict(data):
        data = list(data.values())
    elif isnparray(data):
        data = data.tolist()

    num_bins = 1000.0
    fig, ax = get_fig_ax_helper(fig=None, ax=None)

    # range of the data, the nested lists are flattened first so that lists of
    # different lengths still give the global extremes
    nested = islistoflist(data)
    if nested:
        flat_values = [value_tmp for data_list_tmp in data for value_tmp in data_list_tmp]
    else:
        flat_values = data
    assert len(flat_values) > 0, 'input data is empty'
    max_value = np.max(flat_values)
    min_value = np.min(flat_values)

    # calculate bin size
    if bin_size is None:
        bin_size = (max_value - min_value) / num_bins
        if bin_size <= 0:
            bin_size = 1.0          # all values are identical, np.arange cannot take a zero step
    else:
        try:
            bin_size = float(bin_size)
        except TypeError:
            raise TypeError('size of bin should be a float value')
        if bin_size <= 0:
            raise ValueError('size of bin should be positive')

    # plot
    if nested:
        bins = np.arange(min_value - bin_size, max_value + bin_size, bin_size)      # fixed bin size
        plt.xlim([min_value - bin_size, max_value + bin_size])
        for data_list_tmp in data:
            if debug:
                assert islist(data_list_tmp), 'the nested list is not correct!'
            sns.distplot(data_list_tmp, bins=bins, kde=False)
    else:
        bins = np.arange(min_value - 10 * bin_size, max_value + 10 * bin_size, bin_size)    # fixed bin size
        plt.xlim([min_value - bin_size, max_value + bin_size])
        plt.hist(data, bins=bins, alpha=0.5)

    plt.title('distribution of data')
    plt.xlabel('data (bin size = %f)' % bin_size)
    plt.ylabel('count')

    return save_vis_close_helper(fig=fig, ax=ax, vis=vis, save_path=save_path, debug=debug, closefig=closefig)
def visualize_nearest_neighbor(featuremap_dict, num_neighbor=5, top_number=5, vis=True, save_csv=False, csv_save_path=None, save_vis=False, save_img=False, save_thumb_name='nearest_neighbor.png', img_src_folder=None, ext_filter='.jpg', nn_save_folder=None, debug=True):
    '''
    visualize nearest neighbor for featuremap from images

    parameter:
        featuremap_dict:    a dictionary contains image path as key, and featuremap as value, the featuremap needs to be numpy array with any shape. No flatten needed
        num_neighbor:       number of neighbor to visualize, the first nearest is itself
        top_number:         number of top to visualize, since there might be tons of featuremap (length of dictionary),
                            we choose the top ones with lowest summed distance to their nearest neighbors,
                            it is limited to the number of featuremaps
        vis:                show the thumbnail figure, only used when save_vis is True
        save_csv:           save the distances and the sorted nearest ids to csv_save_path
        csv_save_path:      path to save .csv file
        save_vis:           draw a top_number x num_neighbor thumbnail grid, it is saved as
                            save_thumb_name inside nn_save_folder when that folder is given
        save_img:           copy the images of the top featuremaps to sub folders of nn_save_folder
        img_src_folder:     folder the images are loaded from, the file name is the key plus ext_filter
        ext_filter:         extension of the images, a leading dot is added when missing
        nn_save_folder:     save the nearest neighbor images for top featuremap

    return:
        all_sorted_nearest_id:  a 2d matrix, each row is a feature followed by its nearest neighbor in whole feature dataset, the rows are sorted by the summed distance of all nearest neighbors of each row
        selected_nearest_id:    only top number of sorted nearest id
    '''
    print('processing feature map to nearest neightbor.......')
    if debug:
        assert isdict(featuremap_dict), 'featuremap should be dictionary'
        assert all(isnparray(featuremap_tmp) for featuremap_tmp in featuremap_dict.values()), 'value of dictionary should be numpy array'
        assert isinteger(num_neighbor) and num_neighbor > 1, 'number of neighborhodd is an integer larger than 1'
        if save_csv and csv_save_path is not None:
            assert is_path_exists_or_creatable(csv_save_path), 'path to save .csv file is not correct'

        if save_vis or save_img:
            if nn_save_folder is not None:      # save image directly
                assert isstring(ext_filter), 'extension filter is not correct'
                assert is_path_exists(img_src_folder), 'source folder for image is not correct'
                assert all(isstring(path_tmp) for path_tmp in featuremap_dict.keys())      # key should be the path for the image
                assert is_path_exists_or_creatable(nn_save_folder), 'folder to save top visualized images is not correct'
                assert isstring(save_thumb_name), 'name of thumbnail is not correct'

    if ext_filter.find('.') == -1:
        ext_filter = '.%s' % ext_filter

    # flatten the feature map
    nn_feature_dict = dict()
    for key, featuremap_tmp in featuremap_dict.items():
        nn_feature_dict[key] = featuremap_tmp.flatten()
    num_features = len(nn_feature_dict)

    # nearest neighbor
    # the views are materialised since they are stacked and indexed by position below
    featuremap = np.array(list(nn_feature_dict.values()))
    nearbrs = NearestNeighbors(n_neighbors=num_neighbor, algorithm='ball_tree').fit(featuremap)
    distances, indices = nearbrs.kneighbors(featuremap)

    if debug:
        assert featuremap.shape[0] == num_features, 'shape of feature map is not correct'
        assert indices.shape == (num_features, num_neighbor), 'shape of indices is not correct'
        assert distances.shape == (num_features, num_neighbor), 'shape of indices is not correct'

    # convert the nearest indices for all featuremap to the key accordingly
    id_list = list(nn_feature_dict.keys())
    max_length = len(max(id_list, key=len))     # find the maximum length of string in the key
    # store text keys as text: a byte chararray would hand back bytes for str keys on Python 3
    nearest_id = np.chararray(indices.shape, itemsize=max_length + 1, unicode=not isinstance(id_list[0], bytes))
    for x in range(nearest_id.shape[0]):
        for y in range(nearest_id.shape[1]):
            nearest_id[x, y] = id_list[indices[x, y]]

    if debug:
        assert list(nearest_id[:, 0]) == id_list, 'nearest neighbor has problem'

    # sort the feature based on distance
    print('sorting the feature based on distance')
    featuremap_distance = np.sum(distances, axis=1)
    if debug:
        assert featuremap_distance.shape == (num_features, ), 'distance is not correct'
    sorted_indices = np.argsort(featuremap_distance)
    all_sorted_nearest_id = nearest_id[sorted_indices, :]

    # save to the csv file
    # NOTE(review): distances are written in the original row order while the ids are written
    # in sorted order, so the rows of the two tables do not line up - confirm this is intended
    if save_csv and csv_save_path is not None:
        print('Saving nearest neighbor result as .csv to path: %s' % csv_save_path)
        with open(csv_save_path, 'w+') as csv_file:
            np.savetxt(csv_file, distances, delimiter=',', fmt='%f')
            np.savetxt(csv_file, all_sorted_nearest_id, delimiter=',', fmt='%s')

    # choose the best to visualize, there cannot be more rows than featuremaps
    top_number = min(top_number, num_features)
    selected_sorted_indices = sorted_indices[0:top_number]
    if debug:
        # equal distances are legal (e.g. duplicated featuremaps), so only a decrease is an error
        for i in range(num_features - 1):
            assert featuremap_distance[sorted_indices[i]] <= featuremap_distance[sorted_indices[i + 1]], 'feature map is not well sorted based on distance'
    selected_nearest_id = nearest_id[selected_sorted_indices, :]

    if save_vis and top_number > 0:
        # squeeze=False keeps the axes 2d even for a single row, so [index, nearest_index] always works
        fig, axarray = plt.subplots(top_number, num_neighbor, squeeze=False)
        for index in range(top_number):
            for nearest_index in range(num_neighbor):
                img_path = os.path.join(img_src_folder, '%s%s' % (selected_nearest_id[index, nearest_index], ext_filter))
                if debug:
                    print('loading image from %s' % img_path)
                img = imread(img_path)
                if isgrayimage_dimension(img):
                    axarray[index, nearest_index].imshow(img, cmap='gray')
                elif iscolorimage_dimension(img):
                    axarray[index, nearest_index].imshow(img)
                else:
                    assert False, 'unknown error'
                axarray[index, nearest_index].axis('off')
        if nn_save_folder is not None:
            save_thumb = os.path.join(nn_save_folder, save_thumb_name)
            fig.savefig(save_thumb)
        if vis:
            plt.show()
        plt.close(fig)

    # save top visualization to the folder, one sub folder per top featuremap named after its key
    if save_img and nn_save_folder is not None:
        for top_index in range(top_number):
            file_list = selected_nearest_id[top_index]
            save_subfolder = os.path.join(nn_save_folder, file_list[0])
            mkdir_if_missing(save_subfolder)
            for file_tmp in file_list:
                file_src = os.path.join(img_src_folder, '%s%s' % (file_tmp, ext_filter))
                save_path = os.path.join(save_subfolder, '%s%s' % (file_tmp, ext_filter))
                if debug:
                    print('saving %s to %s' % (file_src, save_path))
                shutil.copyfile(file_src, save_path)

    return all_sorted_nearest_id, selected_nearest_id
def visualize_ced(normed_mean_error_dict, error_threshold, normalized=True, truncated_list=None, display2terminal=True, display_list=None, title='2D PCK curve', debug=True, vis=False, pck_savepath=None, table_savepath=None, closefig=True):
    '''
    visualize the cumulative error distribution curve (also called NME curve or pck curve)
    and summarize every method in a metrics table

    parameters:
        normed_mean_error_dict:     a dictionary whose keys are the method name and values are (N, ) numpy array to represent error in evaluation
        error_threshold:            threshold to display in x axis, when normalized it is a percentage in (0, 100)
        normalized:                 if True the error axis spans [0, 1] and is plotted in percent,
                                    otherwise it spans [0, error_threshold + 1] in absolute units
        truncated_list:             optional list of scalar thresholds forwarded to calculate_truncated_mse,
                                    three extra columns are added to the table for each of them
        display2terminal:           print the metrics table (and the path it is saved to) to the terminal
        display_list:               optional list of method names fixing the plotting order, defaults to the dictionary keys
        title:                      title of the figure
        pck_savepath:               where to save the figure, None to skip saving
        table_savepath:             where to save the ascii table, None to skip saving

    outputs:
        metrics_dict:               a dictionary mapping each method name to a dictionary of its metrics ('AUC', 'MSE', ...)
        metrics_table:              a list of rows (list of strings), the first row being the table title
    '''
    if debug:
        assert isdict(normed_mean_error_dict), 'the input normalized mean error dictionary is not correct'
        assert islogical(normalized), 'the normalization flag should be logical'
        if normalized:
            assert error_threshold > 0 and error_threshold < 100, 'threshold percentage is not well set'
        # the save paths are optional, so they are only validated when they are provided
        if pck_savepath is not None:
            assert is_path_exists_or_creatable(pck_savepath), 'please provide a valid path to save the pck results'
        if table_savepath is not None:
            assert is_path_exists_or_creatable(table_savepath), 'please provide a valid path to save the table results'
        assert isstring(title), 'title is not correct'
        if truncated_list is not None:
            assert islistofscalar(truncated_list), 'the input truncated list is not correct'
        if display_list is not None:
            assert islist(display_list) and len(display_list) == len(normed_mean_error_dict), 'the input display list is not correct'
            assert CHECK_EQ_LIST_UNORDERED(display_list, list(normed_mean_error_dict.keys()), debug=debug), 'the input display list does not match the error dictionary key list'

    # the default order has to exist no matter if debug is on, and it has to be indexable (a list)
    if display_list is None:
        display_list = list(normed_mean_error_dict.keys())

    # set display parameters
    width, height = 1000, 800
    legend_fontsize = 10
    scale_distance = 48.8       # factor applied to the MSE to display it in um in the legend
    line_index, color_index = 0, 0

    figsize = width / float(dpi), height / float(dpi)
    fig = plt.figure(figsize=figsize)

    # set figure handle
    num_bins = 1000
    if normalized:
        maximum_x = 1
        scale = num_bins / 100          # number of bins per percent
    else:
        maximum_x = error_threshold + 1
        scale = num_bins / maximum_x    # number of bins per unit of error
    x_axis = np.linspace(0, maximum_x, num_bins)    # error axis, percentage of normalization factor
    interval_y = 10
    interval_x = 1
    plt.xlim(0, error_threshold)
    plt.ylim(0, 100)
    plt.yticks(np.arange(0, 100 + interval_y, interval_y))
    plt.xticks(np.arange(0, error_threshold + interval_x, interval_x))
    plt.grid()
    plt.title(title, fontsize=20)
    if normalized:
        plt.xlabel('Normalized error euclidean distance (%)', fontsize=16)
    else:
        plt.xlabel('Absolute error euclidean distance', fontsize=16)

    # calculate metrics for each method
    num_methods = len(normed_mean_error_dict)
    num_images = len(list(normed_mean_error_dict.values())[0])
    metrics_dict = dict()
    metrics_table = list()
    table_title = ['Method Name / Metrics', 'AUC', 'MSE']
    append2title = False
    assert num_images > 0, 'number of error array should be larger than 0'

    # a slice bound has to be an integer, while the product can be a float
    num_auc_bins = int(error_threshold * scale)

    for ordered_index in range(num_methods):
        method_name = display_list[ordered_index]
        normed_mean_error = normed_mean_error_dict[method_name]
        if debug:
            assert isnparray(normed_mean_error) and normed_mean_error.ndim == 1, 'shape of error distance is not good'
            assert len(normed_mean_error) == num_images, 'number of testing images should be equal for all methods'
            assert len(linestyle_set) * len(color_set) >= len(normed_mean_error_dict)
        color_tmp = color_set[color_index]
        line_tmp = linestyle_set[line_index]

        # fraction of errors strictly below every bin value: on a sorted array, the
        # left insertion point of a value is the number of elements smaller than it
        sorted_error = np.sort(normed_mean_error)
        y_axis = np.searchsorted(sorted_error, x_axis, side='left') / float(num_images)

        # calculate area under the curve and mean error
        entry = dict()
        auc_value = np.sum(y_axis[:num_auc_bins]) / (error_threshold * scale)
        entry['AUC'] = auc_value                        # bigger, better
        entry['MSE'] = np.mean(normed_mean_error)       # smaller, better
        metrics_table_tmp = [str(method_name), '%.2f' % (entry['AUC']), '%.1f' % (entry['MSE'])]

        if truncated_list is not None:
            tmse_dict = calculate_truncated_mse(normed_mean_error.tolist(), truncated_list, debug=debug)
            for threshold in truncated_list:
                # NOTE(review): this AUC is computed up to error_threshold, not up to the
                # truncation threshold, so it equals entry['AUC'] - confirm this is intended
                entry['AUC/%s' % threshold] = auc_value                         # bigger, better
                entry['MSE/%s' % threshold] = tmse_dict[threshold]['T-MSE']
                entry['percentage/%s' % threshold] = tmse_dict[threshold]['percentage']

                # the title row is extended only once, while processing the first method
                if not append2title:
                    table_title.append('AUC/%s' % threshold)
                    table_title.append('MSE/%s' % threshold)
                    table_title.append('pct/%s' % threshold)
                metrics_table_tmp.append('%.2f' % (entry['AUC/%s' % threshold]))
                metrics_table_tmp.append('%.1f' % (entry['MSE/%s' % threshold]))
                metrics_table_tmp.append('%.1f' % (100 * entry['percentage/%s' % threshold]) + '%')

        metrics_table.append(metrics_table_tmp)
        append2title = True
        metrics_dict[method_name] = entry

        # draw
        label = '%s, AUC: %.2f, MSE: %.1f (%.0f um)' % (method_name, entry['AUC'], entry['MSE'], entry['MSE'] * scale_distance)
        if normalized:
            plt.plot(x_axis * 100, y_axis * 100, color=color_tmp, linestyle=line_tmp, label=label, lw=3)
        else:
            plt.plot(x_axis, y_axis * 100, color=color_tmp, linestyle=line_tmp, label=label, lw=3)
        plt.legend(loc=4, fontsize=legend_fontsize)

        # once all the colors have been used, start over with the next line style
        color_index += 1
        if color_index == len(color_set):
            line_index += 1
            color_index = 0

    plt.ylabel('{} Test Images (%)'.format(num_images), fontsize=16)
    save_vis_close_helper(fig=fig, ax=None, vis=vis, transparent=False, save_path=pck_savepath, debug=debug, closefig=closefig)

    # print table to terminal, rows follow the display order
    metrics_table = [table_title] + metrics_table
    table = AsciiTable(metrics_table)
    if display2terminal:
        print('\nprint detailed metrics')
        print(table.table)

    # save table to file
    if table_savepath is not None:
        with open(table_savepath, 'w') as table_file:
            table_file.write(table.table)
        if display2terminal:
            print('\nsave detailed metrics to %s' % table_savepath)

    return metrics_dict, metrics_table
def _parse_label_source(label_src, num_data, debug=True):
    '''
    convert a label source to either a dictionary of labels or a list of labels

    parameters:
        label_src:      None, path to a .json file, a dictionary, a 1-d numpy array or a list of labels
        num_data:       number of data, checked against the number of labels loaded from a json file

    outputs:
        labeldict, labellist:   at most one of them is not None
    '''
    if label_src is None:
        return None, None
    if isfile(label_src):
        assert is_path_exists(label_src), 'file not found'
        _, _, ext = fileparts(label_src)
        assert ext == '.json', 'only json extension is supported'
        # json.load reads from a file object, not from a path
        with open(label_src, 'r') as label_file:
            labeldict = json.load(label_file)
        assert num_data == len(labeldict), 'number of data and label is not equal.'
        return labeldict, None
    if isdict(label_src):
        return label_src, None
    if isnparray(label_src):
        if debug:
            assert label_src.ndim == 1, 'only 1-d label is supported'
        return None, label_src
    if islist(label_src):
        if debug:
            assert all(np.array(label_tmp).size == 1 for label_tmp in label_src), 'only 1-d label is supported'
        return None, label_src
    assert False, 'label source format is not correct.'


def generate_hdf5(data_src, save_dir, data_name='data', batch_size=1, ext_filter='png', label_src1=None, label_name1='label', label_preprocess_function1=identity, label_range1=None, label_src2=None, label_name2='label2', label_preprocess_function2=identity, label_range2=None, debug=True, vis=False):
    '''
    create data in hdf5 format from a set of images, with up to two optional sets of scalar labels

    parameters:
        data_src:                       source of image data, which can be a list of numpy array image data, a txt file contains a list of image path, a folder contains a set of images
        save_dir:                       where to store the hdf5 data
        data_name:                      name of the image dataset inside every hdf5 file
        batch_size:                     how many image to store in a single hdf5 file, the last file holds the remaining images
        ext_filter:                     what format of data to use for generating hdf5 data
        label_src1/2:                   source of label data, which can be None, a json file contains a dictionary of labels, a dictionary of labels keyed by image file name, a 1-d numpy array data, a list of label data
        label_name1/2:                  name of the label dataset inside every hdf5 file
        label_preprocess_function1/2:   function called as f(data=, data_range=, debug=) on the labels of a batch before they are written
        label_range1/2:                 range forwarded to the label preprocess function

    outputs:
        number of hdf5 files written, number of data
    '''
    # parse input
    assert is_path_exists_or_creatable(save_dir), 'save path should be a folder to save all hdf5 files'
    mkdir_if_missing(save_dir)
    assert isstring(data_name), 'dataset name is not correct'      # name for hdf5 data
    assert batch_size > 0, 'batch size should be positive'

    # convert data source to a list of image path or a list of numpy array image data
    if isfolder(data_src):
        print('data is loading from %s with extension .%s' % (data_src, ext_filter))
        filelist, num_data = load_list_from_folder(data_src, ext_filter=ext_filter)
        datalist = None
    elif isfile(data_src):
        print('data is loading from %s with extension .%s' % (data_src, ext_filter))
        filelist, num_data = load_list_from_file(data_src)
        datalist = None
    elif islist(data_src):
        if debug:
            assert all(isimage(data_tmp) for data_tmp in data_src), 'input data source is not a list of numpy array image data'
        datalist = data_src
        num_data = len(datalist)
        filelist = None
    else:
        assert False, 'data source format is not correct.'
    if debug:
        assert (datalist is None and filelist is not None) or (filelist is None and datalist is not None), 'data is not correct'
        if datalist is not None:
            assert len(datalist) == num_data, 'number of data is not equal'
        if filelist is not None:
            assert len(filelist) == num_data, 'number of data is not equal'
    assert num_data > 0, 'no data is found in the data source'

    # convert label source to a dictionary or a list of labels
    labeldict1, labellist1 = _parse_label_source(label_src1, num_data, debug=debug)
    assert isfunction(label_preprocess_function1), 'label preprocess function is not correct.'
    labeldict2, labellist2 = _parse_label_source(label_src2, num_data, debug=debug)
    assert isfunction(label_preprocess_function2), 'label preprocess function is not correct.'
    has_label1 = (labeldict1 is not None) or (labellist1 is not None)
    has_label2 = (labeldict2 is not None) or (labellist2 is not None)

    # dictionary labels are looked up with the image file name, which only exists for file-based data
    for labeldict_tmp in (labeldict1, labeldict2):
        if labeldict_tmp is None:
            continue
        assert filelist is not None, 'label dictionary is only supported when the data is loaded from image files'
        if debug:
            assert len(filelist) == len(labeldict_tmp), 'file list is not equal to label dictionary'

    # warm up
    if datalist is not None:
        size_data = datalist[0].shape
    else:
        size_data = imread(filelist[0]).shape

    if has_label1:
        if debug:
            assert isstring(label_name1), 'label name is not correct'
            assert label_range1 is not None, 'label range is not correct'
        labels1 = np.zeros((batch_size, 1), dtype='float32')
    if has_label2:
        if debug:
            assert isstring(label_name2), 'label name is not correct'
            assert label_range2 is not None, 'label range is not correct'
        labels2 = np.zeros((batch_size, 1), dtype='float32')

    # start generating
    count_hdf = 1       # count number of hdf5 file
    clock = Timer()
    datalist_batch = list()
    for i in range(num_data):
        clock.tic()
        if filelist is not None:
            imagefile = filelist[i]
            _, name, _ = fileparts(imagefile)
            img = imread(imagefile).astype('float32')
            max_value = np.max(img)
            if max_value > 1 and max_value <= 255:
                img = img / 255.0       # scale the image data to (0, 1)
            if debug:
                min_value = np.min(img)
                assert min_value >= 0 and min_value <= 1, 'data is not in [0, 1]'
        else:
            img = datalist[i]
        if debug:
            assert size_data == img.shape
        datalist_batch.append(img)

        # process label, the row inside the batch is the position of the image inside the batch
        batch_row = i % batch_size
        if labeldict1 is not None:
            labels1[batch_row, 0] = float(labeldict1[name])
        if labellist1 is not None:
            labels1[batch_row, 0] = float(labellist1[i])
        if labeldict2 is not None:
            labels2[batch_row, 0] = float(labeldict2[name])
        if labellist2 is not None:
            labels2[batch_row, 0] = float(labellist2[i])

        # save to hdf5 once the batch is full, or when the data runs out with a partial batch
        batch_is_full = (i + 1) % batch_size == 0
        is_last_data = (i == num_data - 1)
        file_written = batch_is_full or is_last_data
        if file_written:
            num_in_batch = len(datalist_batch)
            data = preprocess_image_caffe(datalist_batch, debug=debug, vis=vis)    # swap channel, transfer from list of HxWxC to NxCxHxW

            # write to hdf5 format, a file-based batch is named after its last image
            if filelist is not None:
                save_path = os.path.join(save_dir, '%s.hdf5' % name)
            else:
                save_path = os.path.join(save_dir, 'image_%010d.hdf5' % count_hdf)
            with h5py.File(save_path, 'w') as h5f:
                h5f.create_dataset(data_name, data=data, dtype='float32')
                if has_label1:
                    # only the rows filled for this batch are written, so labels match the images one to one
                    batch_labels1 = label_preprocess_function1(data=labels1[:num_in_batch], data_range=label_range1, debug=debug)
                    h5f.create_dataset(label_name1, data=batch_labels1, dtype='float32')
                    labels1 = np.zeros((batch_size, 1), dtype='float32')
                if has_label2:
                    batch_labels2 = label_preprocess_function2(data=labels2[:num_in_batch], data_range=label_range2, debug=debug)
                    h5f.create_dataset(label_name2, data=batch_labels2, dtype='float32')
                    labels2 = np.zeros((batch_size, 1), dtype='float32')
            count_hdf = count_hdf + 1
            del datalist_batch[:]

        average_time = clock.toc()
        if file_written:
            print('saving to %s: %d/%d, average time:%.3f, elapsed time:%s, estimated time remaining:%s' % (save_path, i + 1, num_data, average_time, convert_secs2time(average_time * i), convert_secs2time(average_time * (num_data - i))))

    return count_hdf - 1, num_data
def visualize_pts(pts, title=None, fig=None, ax=None, display_range=False, xlim=[-100, 100], ylim=[-100, 100], display_list=None, covariance=False, mse=False, mse_value=None, vis=True, save_path=None, debug=True, closefig=True, warning=True):
    '''
    visualize point scatter plot

    parameters:
        pts:                2 x num_pts numpy array or a dictionary containing 2 x num_pts numpy array, keyed by method name
        title:              title of the figure, a default one is used if None
        fig, ax:            optional figure and axis to draw on, the title/labels/grid are only set when no axis is given
        display_range:      if True, the axes are limited to xlim and ylim
        xlim, ylim:         [min, max] lists for the axes, only used when display_range is True (never modified here)
        display_list:       optional list of method names fixing the legend order, only used for dictionary input
        covariance:         plot the covariance ellipse of the points
        mse:                compute the mean error to the origin and display it in the legend
        mse_value:          precomputed value(s) to display instead, a dictionary keyed by method name for dictionary input and a scalar otherwise
        warning:            forwarded to save_vis_close_helper

    outputs:
        mse_return:         for dictionary input, a dictionary from method name to its mse (empty if mse is False)
                            for array input, the mse value, or None if mse is False
    '''
    pts_is_dict = isdict(pts)
    if debug:
        if pts_is_dict:
            for pts_tmp in pts.values():
                assert is2dptsarray(pts_tmp), 'input points within dictionary are not correct: (2, num_pts) vs %s' % print_np_shape(pts_tmp)
            if display_list is not None:
                assert islist(display_list) and len(display_list) == len(pts), 'the input display list is not correct'
                assert CHECK_EQ_LIST_UNORDERED(display_list, list(pts.keys()), debug=debug), 'the input display list does not match the points key list'
        else:
            assert is2dptsarray(pts), 'input points are not correct: (2, num_pts) vs %s' % print_np_shape(pts)
        if title is not None:
            assert isstring(title), 'title is not correct'
        assert islogical(display_range), 'the flag determine if to display in a specific range should be logical value'
        if display_range:
            assert islist(xlim) and islist(ylim) and len(xlim) == 2 and len(ylim) == 2, 'the input range for x and y is not correct'
            assert xlim[1] > xlim[0] and ylim[1] > ylim[0], 'the input range for x and y is not correct'

    # the defaults are needed below no matter if debug is on
    if pts_is_dict and display_list is None:
        display_list = list(pts.keys())
    if title is None:
        title = 'Point Error Vector Distribution Map'

    # figure setting
    width, height = 1024, 1024
    fig, _ = get_fig_ax_helper(fig=fig, ax=ax, width=width, height=height)
    if ax is None:
        plt.title(title, fontsize=20)
        if pts_is_dict:
            pts_array_list = list(pts.values())
            num_pts_all = pts_array_list[0].shape[1]
            if all(pts_tmp.shape[1] == num_pts_all for pts_tmp in pts_array_list):
                plt.xlabel('x coordinate (%d points)' % num_pts_all, fontsize=16)
                plt.ylabel('y coordinate (%d points)' % num_pts_all, fontsize=16)
            else:
                print('number of points is different across different methods')
                plt.xlabel('x coordinate', fontsize=16)
                plt.ylabel('y coordinate', fontsize=16)
        else:
            plt.xlabel('x coordinate (%d points)' % pts.shape[1], fontsize=16)
            plt.ylabel('y coordinate (%d points)' % pts.shape[1], fontsize=16)
        plt.axis('equal')
        ax = plt.gca()
        ax.grid()

    # internal parameters
    pts_size = 5
    std = None
    conf = 0.98
    color_index = 0
    marker_index = 0
    hatch_index = 0
    alpha = 0.2
    legend_fontsize = 10
    scale_distance = 48.8       # factor applied to the MSE to display it in um in the legend
    linewidth = 2

    # plot points
    handle_dict = dict()        # for legend, display string -> scatter handle
    if pts_is_dict:
        num_methods = len(pts)
        assert len(color_set) * len(marker_set) >= num_methods and len(color_set) * len(hatch_set) >= num_methods, 'color in color set is not enough to use, please use different markers'
        mse_return = dict()
        method_of_display = dict()      # display string -> method name, to order the legend
        for method_name, pts_tmp in pts.items():
            color_tmp = color_set[color_index]
            marker_tmp = marker_set[marker_index]
            hatch_tmp = hatch_set[hatch_index]

            # plot covariance ellipse
            if covariance:
                _, covariance_number = visualize_pts_covariance(pts_tmp[0:2, :], std=std, conf=conf, ax=ax, debug=debug, color=color_tmp, hatch=hatch_tmp, linewidth=linewidth)
            handle_tmp = ax.scatter(pts_tmp[0, :], pts_tmp[1, :], color=color_tmp, marker=marker_tmp, s=pts_size, alpha=alpha)

            if mse:
                if mse_value is None:
                    num_pts = pts_tmp.shape[1]
                    mse_tmp, _ = pts_euclidean(pts_tmp[0:2, :], np.zeros((2, num_pts), dtype='float32'), debug=debug)
                else:
                    mse_tmp = mse_value[method_name]
                display_string = '%s, MSE: %.1f (%.1f um)' % (method_name, mse_tmp, mse_tmp * scale_distance)
                # the covariance number only exists when the ellipse has been computed
                if covariance:
                    display_string += ', Covariance: %.1f' % covariance_number
                mse_return[method_name] = mse_tmp
            else:
                display_string = method_name
            handle_dict[display_string] = handle_tmp
            method_of_display[display_string] = method_name

            # once all the colors have been used, start over with the next marker and hatch
            color_index += 1
            if color_index == len(color_set):
                marker_index += 1
                hatch_index += 1
                color_index = 0

        # reorder the handle before plot
        handle_key_list = list(handle_dict.keys())
        handle_value_list = list(handle_dict.values())
        display_order = list(display_list)
        order_index_list = [display_order.index(method_of_display[display_tmp]) for display_tmp in handle_key_list]
        ordered_handle_key_list = list_reorder(handle_key_list, order_index_list, debug=debug)
        ordered_handle_value_list = list_reorder(handle_value_list, order_index_list, debug=debug)
        plt.legend(list2tuple(ordered_handle_value_list), list2tuple(ordered_handle_key_list), scatterpoints=1, markerscale=4, loc='lower left', fontsize=legend_fontsize)
    else:
        mse_return = None
        color_tmp = color_set[color_index]
        marker_tmp = marker_set[marker_index]
        hatch_tmp = hatch_set[hatch_index]
        handle_tmp = ax.scatter(pts[0, :], pts[1, :], color=color_tmp, marker=marker_tmp, s=pts_size, alpha=alpha)

        # plot covariance ellipse
        if covariance:
            _, covariance_number = visualize_pts_covariance(pts[0:2, :], std=std, conf=conf, ax=ax, debug=debug, color=color_tmp, hatch=hatch_tmp, linewidth=linewidth)

        if mse:
            if mse_value is None:
                num_pts = pts.shape[1]
                mse_return, _ = pts_euclidean(pts[0:2, :], np.zeros((2, num_pts), dtype='float32'), debug=debug)
            else:
                mse_return = mse_value
            display_string = 'MSE: %.1f (%.1f um)' % (mse_return, mse_return * scale_distance)
            # the covariance number only exists when the ellipse has been computed
            if covariance:
                display_string += ', Covariance: %.1f' % covariance_number
            handle_dict[display_string] = handle_tmp
            plt.legend(list2tuple(list(handle_dict.values())), list2tuple(list(handle_dict.keys())), scatterpoints=1, markerscale=4, loc='lower left', fontsize=legend_fontsize)

    # display only specific range
    if display_range:
        axis_bin = 10 * 2
        # float division, an integer division could produce a zero step for small ranges
        interval_x = (xlim[1] - xlim[0]) / float(axis_bin)
        interval_y = (ylim[1] - ylim[0]) / float(axis_bin)
        plt.xlim(xlim[0], xlim[1])
        plt.ylim(ylim[0], ylim[1])
        plt.xticks(np.arange(xlim[0], xlim[1] + interval_x, interval_x))
        plt.yticks(np.arange(ylim[0], ylim[1] + interval_y, interval_y))
        plt.grid()

    save_vis_close_helper(fig=fig, ax=ax, vis=vis, save_path=save_path, warning=warning, debug=debug, closefig=closefig, transparent=False)
    return mse_return