Example #1
0
def update_patchs2sheet(service,
                        sheet_id,
                        starting_position,
                        data,
                        debug=True):
    '''
    update a list of list data to a google sheet continuously

    parameters:
        service:                a Google Sheets API service object (must expose spreadsheets().values().update)
        sheet_id:               a string to identify the sheet uniquely
        starting_position:      a range string (e.g. A1 notation) giving the left-top corner of the patch to fill in
        data:                   a list of list data to fill
        debug:                  if True, validate the input types before issuing the request

    returns:
        the response returned by executing the update request
    '''
    if debug:
        assert isstring(sheet_id), 'the sheet id is not a string'
        assert isstring(starting_position), 'the starting position is not correct'
        assert islistoflist(data), 'the input data is not a list of list'

    # 'RAW' stores the values as given instead of parsing them like user input
    value_input_option = 'RAW'

    value_range_body = {'values': data}
    request = service.spreadsheets().values().update(
        spreadsheetId=sheet_id,
        range=starting_position,
        valueInputOption=value_input_option,
        body=value_range_body)
    return request.execute()
Example #2
0
	def append(self, pts_root, pts_forward, pts_backward=None, pts_anno=None, image_prev_path=None, image_next_path=None, index=None):
		'''
		register one sample under the key derived from index and store its points and image paths

		parameters:
			pts_root:           point array of the root frame
			pts_forward:        point array of the forward frame
			pts_backward:       point array of the backward frame, or None
			pts_anno:           point array from the annotations, or None
			image_prev_path:    path string of the previous image, or None
			image_next_path:    path string of the next image, or None
			index:              index of the sample; defaults to self.length

		raises:
			AssertionError: if the key derived from index is already registered, or (when self.debug
			                is True) if an input does not pass the shape/type checks
		'''
		if index is None:
			index = self.length
		key = self.convert_index2key(index)
		# an explicit raise (instead of `assert False`) keeps the duplicate check active under python -O
		if key in self.key_list:
			raise AssertionError('the data with index %d already exists' % index)

		# validate before registering the key so a failed check does not leave a half-registered entry
		if self.debug:
			assert is2dptsarray_occlusion(pts_root) or is2dptsarray(pts_root), 'the shape of input point array %s in the root frame is not correct' % print_np_shape(pts_root)
			assert is2dptsarray_occlusion(pts_forward) or is2dptsarray(pts_forward), 'the shape of input point array %s in the forward frame is not correct' % print_np_shape(pts_forward)
			# test for None first so the shape helpers are never handed None
			assert pts_backward is None or is2dptsarray_occlusion(pts_backward) or is2dptsarray(pts_backward), 'the shape of input point array %s in the backward frame is not correct' % print_np_shape(pts_backward)
			assert pts_anno is None or is2dptsarray_occlusion(pts_anno) or is2dptsarray(pts_anno), 'the shape of input point array %s from the annotations is not correct' % print_np_shape(pts_anno)
			assert (isstring(image_prev_path) or image_prev_path is None) and (isstring(image_next_path) or image_next_path is None), 'the input image path is not correct'

		self.key_list.append(key)
		self.index_dict[key] = index

		self.pts_root[key] = pts_root
		self.pts_forward[key] = pts_forward
		self.pts_backward[key] = pts_backward
		self.pts_anno[key] = pts_anno
		self.image_prev_path[key] = image_prev_path
		self.image_next_path[key] = image_next_path

		self.length += 1
Example #3
0
def visualize_bar(data,
                  bin_size=2.0,
                  title='Bar Graph of Key-Value Pair',
                  xlabel='index',
                  ylabel='count',
                  vis=True,
                  save_path=None,
                  debug=True,
                  closefig=True):
    '''
    visualize the bar graph of a data, which can be a dictionary or a list

    different from function of visualize_bar_graph, this function does not depend on panda and dataframe, it's simpler but with less functionality
    also the key of this function takes continuous scalar variable

    parameters:
        data:       a dictionary mapping scalar keys (bar positions) to bar heights,
                    or a list of bar heights placed at positions 0, 1, ..., len(data) - 1
        bin_size:   width of each bar
        title, xlabel, ylabel:  strings used for the figure title and axis labels
        vis, save_path, closefig:   forwarded to save_vis_close_helper
        debug:      if True, validate the inputs

    returns:
        whatever save_vis_close_helper returns
    '''
    if debug:
        assert isstring(title) and isstring(xlabel) and isstring(
            ylabel), 'title/xlabel/ylabel is not correct'
        assert isdict(data) or islist(data), 'input data is not correct'
        assert isscalar(bin_size), 'the bin size is not a floating number'

    if isdict(data):
        # materialize keys and values as aligned lists: dict views cannot be turned into a
        # numeric array by np.array nor plotted directly on python 3
        index_list = list(data.keys())
        frequencies = [data[key] for key in index_list]
        if debug:
            assert islistofscalar(
                index_list
            ), 'the input dictionary does not contain a scalar key'
    else:
        index_list = list(range(len(data)))
        frequencies = data

    positions = np.array(index_list)
    fig, ax = get_fig_ax_helper(fig=None, ax=None)
    plt.bar(positions, frequencies, bin_size, color='r', alpha=0.5)
    plt.title(title, fontsize=20)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    return save_vis_close_helper(fig=fig,
                                 ax=ax,
                                 vis=vis,
                                 save_path=save_path,
                                 debug=debug,
                                 transparent=False,
                                 closefig=closefig)
Example #4
0
	def __init__(self, intrinsics, extrinsics, distortion=None, camera_id=None, warning=True, debug=True):
		'''
		store the camera parameters

		parameters:
			intrinsics:     intrinsic matrix, checked to be a (3, 3) numpy array when debug is True
			extrinsics:     extrinsic matrix, checked to be a (3, 4) numpy array when debug is True
			distortion:     distortion parameters stored as given, or None
			camera_id:      string identifier of the camera, or None
			warning:        not used by this constructor
			debug:          if True, validate the shapes of intrinsics and extrinsics
		'''
		self.intrinsics = intrinsics
		self.extrinsics = extrinsics
		# always define the optional attributes so later reads never raise AttributeError
		self.distortion = distortion
		if camera_id is not None:
			assert isstring(camera_id), 'the camera id is not a string'
		self.camera_id = camera_id

		if debug:
			assert isnparray(self.intrinsics) and self.intrinsics.shape == (3, 3), 'the intrinsics is not correct'
			assert isnparray(self.extrinsics) and self.extrinsics.shape == (3, 4), 'the extrinsics is not correct'
Example #5
0
def load_list_from_folders(folder_path_list,
                           ext_filter=None,
                           depth=1,
                           recursive=False,
                           save_path=None,
                           debug=True):
    '''
    load a list of files or folders from a list of system paths

    parameters:
        folder_path_list:   a single folder path string or a list of folder path strings
        ext_filter, depth, recursive:   forwarded unchanged to load_list_from_folder for every folder
        save_path:          if not None, the combined list is also written there, one item per line
        debug:              if True, validate the inputs

    returns:
        fulllist:   concatenation of the lists returned by load_list_from_folder, in input order
        num_elem:   sum of the element counts returned by load_list_from_folder
    '''
    if debug:
        assert islist(folder_path_list) or isstring(
            folder_path_list), 'input path list is not correct'
    if isstring(folder_path_list):
        folder_path_list = [folder_path_list]

    fulllist = list()
    num_elem = 0
    for folder_path_tmp in folder_path_list:
        fulllist_tmp, num_elem_tmp = load_list_from_folder(
            folder_path_tmp,
            ext_filter=ext_filter,
            depth=depth,
            recursive=recursive)
        fulllist += fulllist_tmp
        num_elem += num_elem_tmp

    # save list to a path
    if save_path is not None:
        save_path = safepath(save_path)
        if debug:
            assert is_path_exists_or_creatable(
                save_path), 'the file cannot be created'
        # the with-statement closes the file, so no explicit close() is needed
        with open(save_path, 'w') as save_file:
            save_file.writelines('%s\n' % item for item in fulllist)

    return fulllist, num_elem
Example #6
0
def get_data_from_sheet(service, sheet_id, search_range, debug=True, max_retries=None):
    '''
    get a list of data from a google sheet continuously

    parameters:
        service:        a Google Sheets API service object (must expose spreadsheets().values().batchGet)
        sheet_id:       a string to identify the sheet uniquely
        search_range:   a list of ranges queried
        debug:          if True, validate the input types
        max_retries:    how many times a failed request is retried before its error is re-raised;
                        None (the default) keeps retrying until the request succeeds

    returns:
        a list with one entry per returned value range: the first cell value of that range,
        or '' when the range holds no values
    '''
    if debug:
        assert isstring(sheet_id), 'the sheet id is not a string'
        assert islist(search_range), 'the search range is not a list'

    request = service.spreadsheets().values().batchGet(spreadsheetId=sheet_id,
                                                       ranges=search_range)

    num_failures = 0
    while True:
        try:
            response = request.execute()
            break
        except Exception:
            # Exception (not a bare except) so KeyboardInterrupt / SystemExit still stop the loop
            num_failures += 1
            if max_retries is not None and num_failures > max_retries:
                raise

    data = list()
    for raw_data in response['valueRanges']:
        values = raw_data.get('values')
        # a range may be missing 'values' entirely or contain an empty first row
        if values and values[0]:
            data.append(values[0][0])
        else:
            data.append('')

    return data
Example #7
0
def load_hdf5_file(hdf5_file, dataname, debug=True):
    '''
    load a list of datasets from a single hdf5 file

    parameters:
        hdf5_file:  path of an existing hdf5 file
        dataname:   a list of dataset names to read
        debug:      if True, validate the path and the dataset names

    returns:
        a dictionary mapping each dataset name to its content as a numpy array
    '''
    if debug:
        assert is_path_exists(hdf5_file) and isfile(
            hdf5_file), 'input hdf5 path does not exist: %s' % hdf5_file
        assert islist(dataname), 'dataset queried is not correct'
        assert all(
            isstring(dataset_tmp)
            for dataset_tmp in dataname), 'dataset queried is not correct'

    # np.array copies each dataset into memory, so the file can be closed right after reading
    with h5py.File(hdf5_file, 'r') as hdf5:
        datadict = {dataset: np.array(hdf5[dataset]) for dataset in dataname}
    return datadict
Example #8
0
def rand_load_hdf5_from_folder(hdf5_src, dataname, debug=True):
    '''
    randomly load a single hdf5 file from a hdf5 folder

    parameters:
        hdf5_src:   folder containing .hdf5 files
        dataname:   a list of dataset names to read from the randomly chosen file
        debug:      if True, validate the folder and the dataset names

    returns:
        a dictionary mapping each dataset name to its content as a numpy array

    raises:
        ValueError: if the folder contains no .hdf5 file
    '''
    if debug:
        assert is_path_exists(hdf5_src) and isfolder(
            hdf5_src), 'input hdf5 path does not exist: %s' % hdf5_src
        assert islist(dataname), 'dataset queried is not correct'
        assert all(
            isstring(dataset_tmp)
            for dataset_tmp in dataname), 'dataset queried is not correct'

    hdf5list, num_hdf5_files = load_list_from_folder(folder_path=hdf5_src,
                                                     ext_filter='.hdf5')
    if num_hdf5_files == 0:
        raise ValueError('no .hdf5 file found in %s' % hdf5_src)
    hdf5_path_sample = hdf5list[random.randrange(0, num_hdf5_files)]

    # np.array copies each dataset into memory, so the file can be closed right after reading
    with h5py.File(hdf5_path_sample, 'r') as hdf5_file:
        datadict = {dataset: np.array(hdf5_file[dataset]) for dataset in dataname}
    return datadict
Example #9
0
def visualize_bar_graph(data,
                        title='Bar Graph of Key-Value Pair',
                        xlabel='pixel error',
                        ylabel='keypoint index',
                        label=False,
                        label_list=None,
                        vis=True,
                        save_path=None,
                        debug=True,
                        closefig=True):
    '''
    visualize the bar graph of a data, which can be a dictionary or list of dictionary
    inside each dictionary, the keys (string) should be the same which is the y label, the values should be scalar

    parameters:
        data:       a dictionary {name: scalar} or a list of such dictionaries sharing the same keys;
                    each dictionary of a list is drawn as its own bar set colored from color_set
        label:      if True (list input only), give each bar set the legend label label_list[i]
        label_list: legend labels, one per dictionary, used when label is True
        title, xlabel, ylabel:  strings used for the figure title and axis labels
        vis, save_path, closefig:   forwarded to save_vis_close_helper
        debug:      if True, validate the inputs

    returns:
        whatever save_vis_close_helper returns
    '''
    if debug:
        assert isstring(title) and isstring(xlabel) and isstring(
            ylabel), 'title/xlabel/ylabel is not correct'
        assert isdict(data) or islistofdict(data), 'input data is not correct'
        if isdict(data):
            assert all(
                isstring(key_tmp)
                for key_tmp in data.keys()), 'the keys are not all strings'
            assert all(
                isscalar(value_tmp)
                for value_tmp in data.values()), 'the values are not all scalars'
        else:
            assert len(data) <= len(
                color_set
            ), 'number of data set is larger than number of color to use'
            if label:
                assert label_list is not None and len(label_list) >= len(
                    data), 'the label list does not cover every data set'
            keys = sorted(data[0].keys())
            for dict_tmp in data:
                if not (sorted(dict_tmp.keys()) == keys):
                    print(dict_tmp.keys())
                    print(keys)
                    assert False, 'the keys are not equal across different input set'
                assert all(isstring(key_tmp) for key_tmp in
                           dict_tmp.keys()), 'the keys are not all strings'
                assert all(
                    isscalar(value_tmp) for value_tmp in
                    dict_tmp.values()), 'the values are not all scalars'

    # convert dictionary to DataFrame: a 'names' column with the sorted keys plus one value column
    # per set; values are looked up by key so every column stays aligned with 'names'
    # regardless of each dictionary's own iteration order
    data_new = dict()
    if isdict(data):
        names = sorted(data.keys())
        data_new['names'] = names
        data_new['values'] = np.asarray([data[name] for name in names]).tolist()
    else:
        names = sorted(data[0].keys())
        data_new['names'] = names
        for set_index, dict_tmp in enumerate(data):
            data_new['value_%03d' % set_index] = np.asarray(
                [dict_tmp[name] for name in names]).tolist()
    dataframe = DataFrame(data_new)

    # plot
    width = 2000
    height = 2000
    alpha = 0.5
    figsize = width / float(dpi), height / float(dpi)
    fig = plt.figure(figsize=figsize)
    sns.set(style='whitegrid')
    if isdict(data):
        sns.barplot(x='values',
                    y='names',
                    data=dataframe,
                    label='data',
                    color='b')
        plt.legend(ncol=1, loc='lower right', frameon=True, fontsize=5)
    else:
        num_sets = len(data)
        for set_index in range(num_sets):
            # the first set uses the pastel palette, every later set the muted one
            sns.set_color_codes('pastel' if set_index == 0 else 'muted')
            barplot_kwargs = dict(x='value_%03d' % set_index,
                                  y='names',
                                  data=dataframe,
                                  color=color_set[set_index],
                                  alpha=alpha)
            if label:
                barplot_kwargs['label'] = label_list[set_index]
            sns.barplot(**barplot_kwargs)
        plt.legend(ncol=num_sets, loc='lower right', frameon=True, fontsize=5)

    sns.despine(left=True, bottom=True)
    plt.title(title, fontsize=20)
    plt.xlim([0, 50])
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    # shrink the y tick labels linearly as the number of bars grows
    num_yticks = len(data_new['names'])
    adaptive_fontsize = -0.0555556 * num_yticks + 15.111
    plt.yticks(fontsize=adaptive_fontsize)

    return save_vis_close_helper(fig=fig,
                                 vis=vis,
                                 save_path=save_path,
                                 debug=debug,
                                 closefig=closefig)
Example #10
0
def visualize_nearest_neighbor(featuremap_dict,
                               num_neighbor=5,
                               top_number=5,
                               vis=True,
                               save_csv=False,
                               csv_save_path=None,
                               save_vis=False,
                               save_img=False,
                               save_thumb_name='nearest_neighbor.png',
                               img_src_folder=None,
                               ext_filter='.jpg',
                               nn_save_folder=None,
                               debug=True):
    '''
    visualize nearest neighbor for featuremap from images

    parameter:
        featuremap_dict: a dictionary contains image path as key, and featuremap as value, the featuremap needs to be numpy array with any shape. No flatten needed
        num_neighbor: number of neighbor to visualize, the first nearest is itself
        top_number: number of top rows to visualize; the features with the lowest summed distance to their
                    neighbors are chosen. Clipped to the number of feature maps
        vis: show the thumbnail figure (only when save_vis is True)
        save_csv / csv_save_path: write the distance array followed by the sorted nearest ids as .csv
        save_vis: save a top_number x num_neighbor grid of images to nn_save_folder/save_thumb_name
        save_img: copy the images of each top feature and its neighbors into nn_save_folder/<key>/
        img_src_folder / ext_filter: images are read from img_src_folder/<key><ext_filter>
        nn_save_folder: save the nearest neighbor images for top featuremap

    return:
        all_sorted_nearest_id: a 2d matrix, each row is a feature followed by its nearest neighbor in whole feature dataset, the rows are sorted by the summed distance to the nearest neighbors
        selected_nearest_id: only top number of sorted nearest id
    '''
    print('processing feature map to nearest neightbor.......')
    if debug:
        assert isdict(featuremap_dict), 'featuremap should be dictionary'
        assert all(
            isnparray(featuremap_tmp) for featuremap_tmp in featuremap_dict.
            values()), 'value of dictionary should be numpy array'
        assert isinteger(
            num_neighbor
        ) and num_neighbor > 1, 'number of neighborhodd is an integer larger than 1'
        if save_csv and csv_save_path is not None:
            assert is_path_exists_or_creatable(
                csv_save_path), 'path to save .csv file is not correct'

        if save_vis or save_img:
            if nn_save_folder is not None:  # save image directly
                assert isstring(ext_filter), 'extension filter is not correct'
                assert is_path_exists(
                    img_src_folder), 'source folder for image is not correct'
                assert all(
                    isstring(path_tmp) for path_tmp in featuremap_dict.keys()
                )  # key should be the path for the image
                assert is_path_exists_or_creatable(
                    nn_save_folder
                ), 'folder to save top visualized images is not correct'
                assert isstring(
                    save_thumb_name), 'name of thumbnail is not correct'

    if ext_filter.find('.') == -1:
        ext_filter = '.%s' % ext_filter

    # flatten the feature maps; id_list fixes one order shared by the ids and the feature rows
    id_list = list(featuremap_dict.keys())
    featuremap = np.array([featuremap_dict[key].flatten() for key in id_list])
    num_features = len(id_list)

    # nearest neighbor
    nearbrs = NearestNeighbors(n_neighbors=num_neighbor,
                               algorithm='ball_tree').fit(featuremap)
    distances, indices = nearbrs.kneighbors(featuremap)

    if debug:
        assert featuremap.shape[
            0] == num_features, 'shape of feature map is not correct'
        assert indices.shape == (
            num_features, num_neighbor), 'shape of indices is not correct'
        assert distances.shape == (
            num_features, num_neighbor), 'shape of distances is not correct'

    # convert the nearest indices for all featuremap to the key accordingly
    max_length = len(max(
        id_list, key=len))  # find the maximum length of string in the key
    nearest_id = np.chararray(indices.shape, itemsize=max_length + 1)
    for x in range(nearest_id.shape[0]):
        for y in range(nearest_id.shape[1]):
            nearest_id[x, y] = id_list[indices[x, y]]

    if debug:
        assert list(nearest_id[:,
                               0]) == id_list, 'nearest neighbor has problem'

    # sort the feature based on distance
    print('sorting the feature based on distance')
    featuremap_distance = np.sum(distances, axis=1)
    if debug:
        assert featuremap_distance.shape == (
            num_features, ), 'distance is not correct'
    sorted_indices = np.argsort(featuremap_distance)
    all_sorted_nearest_id = nearest_id[sorted_indices, :]

    # save to the csv file
    if save_csv and csv_save_path is not None:
        print('Saving nearest neighbor result as .csv to path: %s' %
              csv_save_path)
        with open(csv_save_path, 'w+') as csv_file:
            np.savetxt(csv_file, distances, delimiter=',', fmt='%f')
            np.savetxt(csv_file, all_sorted_nearest_id, delimiter=',', fmt='%s')

    if debug:
        # non-strict comparison: equal summed distances (e.g. duplicated feature maps) are validly sorted
        sorted_distance = featuremap_distance[sorted_indices]
        assert np.all(sorted_distance[:-1] <= sorted_distance[1:]
                      ), 'feature map is not well sorted based on distance'

    # choose the best to visualize; never ask for more rows than there are feature maps
    top_number = min(top_number, num_features)
    selected_sorted_indices = sorted_indices[0:top_number]
    selected_nearest_id = nearest_id[selected_sorted_indices, :]

    if save_vis:
        # squeeze=False keeps axarray 2-d even when top_number or num_neighbor is 1
        fig, axarray = plt.subplots(top_number, num_neighbor, squeeze=False)
        for index in range(top_number):
            for nearest_index in range(num_neighbor):
                img_path = os.path.join(
                    img_src_folder, '%s%s' %
                    (selected_nearest_id[index, nearest_index], ext_filter))
                if debug:
                    print('loading image from %s' % img_path)
                img = imread(img_path)
                if isgrayimage_dimension(img):
                    axarray[index, nearest_index].imshow(img, cmap='gray')
                elif iscolorimage_dimension(img):
                    axarray[index, nearest_index].imshow(img)
                else:
                    assert False, 'unknown error'
                axarray[index, nearest_index].axis('off')
        save_thumb = os.path.join(nn_save_folder, save_thumb_name)
        fig.savefig(save_thumb)
        if vis:
            plt.show()
        plt.close(fig)

    # save top visualization to the folder
    if save_img and nn_save_folder is not None:
        for top_index in range(top_number):
            file_list = selected_nearest_id[top_index]
            # one subfolder per top feature, named after the feature itself
            save_subfolder = os.path.join(nn_save_folder, file_list[0])
            mkdir_if_missing(save_subfolder)
            for file_tmp in file_list:
                file_src = os.path.join(img_src_folder,
                                        '%s%s' % (file_tmp, ext_filter))
                save_path = os.path.join(save_subfolder,
                                         '%s%s' % (file_tmp, ext_filter))
                if debug:
                    print('saving %s to %s' % (file_src, save_path))
                shutil.copyfile(file_src, save_path)

    return all_sorted_nearest_id, selected_nearest_id
Example #11
0
def visualize_ced(normed_mean_error_dict,
                  error_threshold,
                  normalized=True,
                  truncated_list=None,
                  display2terminal=True,
                  display_list=None,
                  title='2D PCK curve',
                  debug=True,
                  vis=False,
                  pck_savepath=None,
                  table_savepath=None,
                  closefig=True):
    '''
    visualize the cumulative error distribution curve (also called NME curve or pck curve)
    all parameters are represented by percentage

    parameter:
        normed_mean_error_dict:     a dictionary whose keys are the method name and values are (N, ) numpy array to represent error in evaluation
        error_threshold:            threshold to display in x axis (checked to lie in (0, 100) when normalized is True)
        normalized:                 if True, the x axis spans [0, 1] and is plotted in percent
        truncated_list:             optional list of thresholds passed to calculate_truncated_mse; adds per-threshold columns
        display2terminal:           print the metrics table to the terminal
        display_list:               method names in plotting order; defaults to the dictionary key order
        pck_savepath:               save path forwarded to save_vis_close_helper for the curve figure
        table_savepath:             if not None, the ascii metrics table is written there

    return:
        metrics_dict:               {method name: {'AUC': ..., 'MSE': ..., plus per-threshold entries}}
        metrics_table:              list of table rows, title row first
    '''
    if debug:
        assert isdict(
            normed_mean_error_dict
        ), 'the input normalized mean error dictionary is not correct'
        assert islogical(
            normalized), 'the normalization flag should be logical'
        if normalized:
            assert error_threshold > 0 and error_threshold < 100, 'threshold percentage is not well set'
        # validate only the save paths that were actually provided
        if pck_savepath is not None:
            assert is_path_exists_or_creatable(
                pck_savepath
            ), 'please provide a valid path to save the pck results'
        if table_savepath is not None:
            assert is_path_exists_or_creatable(
                table_savepath
            ), 'please provide a valid path to save the table results'
        assert isstring(title), 'title is not correct'
        if truncated_list is not None:
            assert islistofscalar(
                truncated_list), 'the input truncated list is not correct'
        if display_list is not None:
            assert islist(display_list) and len(
                display_list) == len(normed_mean_error_dict
                                     ), 'the input display list is not correct'
            assert CHECK_EQ_LIST_UNORDERED(
                display_list, list(normed_mean_error_dict.keys()), debug=debug
            ), 'the input display list does not match the error dictionary key list'

    # the default plotting order is needed whether or not the debug checks ran
    if display_list is None:
        display_list = list(normed_mean_error_dict.keys())

    # set display parameters
    width, height = 1000, 800
    legend_fontsize = 10
    scale_distance = 48.8  # factor applied to MSE for the 'um' value in the legend
    line_index, color_index = 0, 0

    figsize = width / float(dpi), height / float(dpi)
    fig = plt.figure(figsize=figsize)

    # set figure handle
    num_bins = 1000
    if normalized:
        maximum_x = 1
        scale = num_bins / 100.0
    else:
        maximum_x = error_threshold + 1
        scale = float(num_bins) / maximum_x
    # number of leading bins that lie below error_threshold; slice bounds must be integers
    num_auc_bins = max(int(error_threshold * scale), 1)
    x_axis = np.linspace(
        0, maximum_x,
        num_bins)  # error axis, percentage of normalization factor
    y_axis = np.zeros(num_bins)
    interval_y = 10
    interval_x = 1
    plt.xlim(0, error_threshold)
    plt.ylim(0, 100)
    plt.yticks(np.arange(0, 100 + interval_y, interval_y))
    plt.xticks(np.arange(0, error_threshold + interval_x, interval_x))
    plt.grid()
    plt.title(title, fontsize=20)
    if normalized:
        plt.xlabel('Normalized error euclidean distance (%)', fontsize=16)
    else:
        plt.xlabel('Absolute error euclidean distance', fontsize=16)

    # calculate metrics for each method
    num_methods = len(normed_mean_error_dict)
    assert num_methods > 0, 'the input normalized mean error dictionary is empty'
    num_images = len(next(iter(normed_mean_error_dict.values())))
    metrics_dict = dict()
    metrics_table = list()
    table_title = ['Method Name / Metrics', 'AUC', 'MSE']
    append2title = False
    assert num_images > 0, 'number of error array should be larger than 0'
    for method_name in display_list[:num_methods]:
        normed_mean_error = normed_mean_error_dict[method_name]

        if debug:
            assert isnparray(
                normed_mean_error
            ) and normed_mean_error.ndim == 1, 'shape of error distance is not good'
            assert len(
                normed_mean_error
            ) == num_images, 'number of testing images should be equal for all methods'
            assert len(linestyle_set) * len(color_set) >= len(
                normed_mean_error_dict)

        color_tmp = color_set[color_index]
        line_tmp = linestyle_set[line_index]

        # fraction of images whose error is below each x bin
        for i in range(num_bins):
            y_axis[i] = float(
                (normed_mean_error < x_axis[i]).sum()) / num_images

        # calculate area under the curve and mean square error
        entry = dict()
        entry['AUC'] = np.sum(y_axis[:num_auc_bins]) / float(
            num_auc_bins)  # bigger, better
        entry['MSE'] = np.mean(normed_mean_error)  # smaller, better
        metrics_table_tmp = [
            str(method_name),
            '%.2f' % (entry['AUC']),
            '%.1f' % (entry['MSE'])
        ]
        if truncated_list is not None:
            tmse_dict = calculate_truncated_mse(normed_mean_error.tolist(),
                                                truncated_list,
                                                debug=debug)
            for threshold in truncated_list:
                # NOTE(review): this uses the same bins as the overall AUC, so it does not
                # depend on `threshold` — confirm whether a truncated AUC was intended
                entry['AUC/%s' % threshold] = np.sum(
                    y_axis[:num_auc_bins]) / float(num_auc_bins)  # bigger, better
                entry['MSE/%s' % threshold] = tmse_dict[threshold]['T-MSE']
                entry['percentage/%s' %
                      threshold] = tmse_dict[threshold]['percentage']

                if not append2title:
                    table_title.append('AUC/%s' % threshold)
                    table_title.append('MSE/%s' % threshold)
                    table_title.append('pct/%s' % threshold)
                metrics_table_tmp.append('%.2f' %
                                         (entry['AUC/%s' % threshold]))
                metrics_table_tmp.append('%.1f' %
                                         (entry['MSE/%s' % threshold]))
                metrics_table_tmp.append(
                    '%.1f' % (100 * entry['percentage/%s' % threshold]) + '%')

        metrics_table.append(metrics_table_tmp)
        append2title = True
        metrics_dict[method_name] = entry

        # draw
        label = '%s, AUC: %.2f, MSE: %.1f (%.0f um)' % (
            method_name, entry['AUC'], entry['MSE'],
            entry['MSE'] * scale_distance)
        x_plot = x_axis * 100 if normalized else x_axis
        plt.plot(x_plot,
                 y_axis * 100,
                 color=color_tmp,
                 linestyle=line_tmp,
                 label=label,
                 lw=3)
        plt.legend(loc=4, fontsize=legend_fontsize)

        # cycle through all colors, then advance to the next line style
        color_index += 1
        if color_index == len(color_set):
            line_index += 1
            color_index = 0

    plt.ylabel('{} Test Images (%)'.format(num_images), fontsize=16)
    save_vis_close_helper(fig=fig,
                          ax=None,
                          vis=vis,
                          transparent=False,
                          save_path=pck_savepath,
                          debug=debug,
                          closefig=closefig)

    # print table to terminal
    metrics_table = [table_title] + metrics_table
    table = AsciiTable(metrics_table)
    if display2terminal:
        print('\nprint detailed metrics')
        print(table.table)

    # save table to file
    if table_savepath is not None:
        with open(table_savepath, 'w') as table_file:
            table_file.write(table.table)
        if display2terminal:
            print('\nsave detailed metrics to %s' % table_savepath)

    return metrics_dict, metrics_table
Example #12
0
def _parse_label_source_hdf5(label_src,
                             label_preprocess_function,
                             num_data,
                             debug=True):
    '''
    normalize one label source of generate_hdf5 into a (dictionary, list) pair

    parameters:
        label_src:                  None, a path to a .json file, a dictionary, a 1-d numpy array or a list of scalar labels
        label_preprocess_function:  function later applied to every batch of labels, only validated here
        num_data:                   number of data samples, compared against the number of labels loaded from json
        debug:                      enable the optional sanity checks

    outputs:
        labeldict:                  dictionary of labels looked up by image name, or None
        labellist:                  sequence of labels indexed by sample position, or None
    '''
    if label_src is None:
        labeldict, labellist = None, None
    elif isfile(label_src):
        assert is_path_exists(label_src), 'file not found'
        _, _, ext = fileparts(label_src)
        assert ext == '.json', 'only json extension is supported'
        # json.load needs an open file object; passing the path string fails
        with open(label_src, 'r') as label_file:
            labeldict = json.load(label_file)
        assert num_data == len(
            labeldict), 'number of data and label is not equal.'
        labellist = None
    elif isdict(label_src):
        labeldict, labellist = label_src, None
    elif isnparray(label_src):
        if debug:
            assert label_src.ndim == 1, 'only 1-d label is supported'
        labeldict, labellist = None, label_src
    elif islist(label_src):
        if debug:
            assert all(
                np.array(label_tmp).size == 1
                for label_tmp in label_src), 'only 1-d label is supported'
        labeldict, labellist = None, label_src
    else:
        assert False, 'label source format is not correct.'
    assert isfunction(
        label_preprocess_function), 'label preprocess function is not correct.'
    return labeldict, labellist


def _check_label_setup_hdf5(label_src,
                            labeldict,
                            labellist,
                            label_name,
                            label_range,
                            filelist,
                            debug=True):
    '''
    sanity-check one parsed label source of generate_hdf5 before any file is written

    parameters:
        label_src:      the raw label source passed to generate_hdf5
        labeldict:      parsed label dictionary or None
        labellist:      parsed label list or None
        label_name:     name of the label dataset inside the hdf5 files
        label_range:    range later passed to the label preprocess function
        filelist:       list of image files, or None when the data is given as in-memory images
        debug:          enable the optional sanity checks

    outputs:
        True if this label source provides labels, False otherwise
    '''
    if labeldict is not None:
        # labels in a dictionary are looked up by image file name, so the data must come from files
        assert filelist is not None, 'label dictionary requires data loaded from image files'
        if debug:
            assert isstring(label_name), 'label name is not correct'
            assert len(filelist) == len(
                labeldict), 'file list is not equal to label dictionary'
    if label_src is not None and debug:
        assert label_range is not None, 'label range is not correct'
        assert (labeldict is not None and labellist is None) or (
            labellist is not None
            and labeldict is None), 'label is not correct'
    return labeldict is not None or labellist is not None


def generate_hdf5(data_src,
                  save_dir,
                  data_name='data',
                  batch_size=1,
                  ext_filter='png',
                  label_src1=None,
                  label_name1='label',
                  label_preprocess_function1=identity,
                  label_range1=None,
                  label_src2=None,
                  label_name2='label2',
                  label_preprocess_function2=identity,
                  label_range2=None,
                  debug=True,
                  vis=False):
    '''
    create hdf5 files from image data, optionally together with up to two label sets

    parameters:
        data_src:       source of image data: a folder of images, a file listing image paths (read by load_list_from_file),
                        or a list of numpy array images
        save_dir:       folder where the hdf5 files are stored, created if missing
        data_name:      name of the image dataset inside each hdf5 file
        batch_size:     how many images to store in a single hdf5 file; the last file may hold fewer
        ext_filter:     image extension used when loading images from a folder
        label_src1/2:   source of label data: None, a .json file containing a dictionary, a dictionary, a 1-d numpy array
                        or a list of scalars; dictionaries are looked up by image file name and need data loaded from files
        label_name1/2:  name of the label dataset inside each hdf5 file
        label_preprocess_function1/2:   called as f(data=labels, data_range=label_range, debug=debug) on each batch of labels
        label_range1/2: range passed to the label preprocess function, required when the label source is given (debug check)
        debug:          enable the optional sanity checks
        vis:            forwarded to preprocess_image_caffe

    outputs:
        num_hdf:        number of hdf5 files written
        num_data:       number of images processed
    '''

    # parse input
    assert is_path_exists_or_creatable(
        save_dir), 'save path should be a folder to save all hdf5 files'
    mkdir_if_missing(save_dir)
    assert isstring(
        data_name), 'dataset name is not correct'  # name for hdf5 data
    assert batch_size >= 1, 'batch size should be a positive integer'

    # the data ends up either as a list of image files or as a list of in-memory images
    if isfolder(data_src):
        print('data is loading from %s with extension .%s' %
              (data_src, ext_filter))
        filelist, num_data = load_list_from_folder(data_src,
                                                   ext_filter=ext_filter)
        datalist = None
    elif isfile(data_src):
        print('data is loading from %s with extension .%s' %
              (data_src, ext_filter))
        filelist, num_data = load_list_from_file(data_src)
        datalist = None
    elif islist(data_src):
        if debug:
            assert all(
                isimage(data_tmp) for data_tmp in data_src
            ), 'input data source is not a list of numpy array image data'
        datalist = data_src
        num_data = len(datalist)
        filelist = None
    else:
        assert False, 'data source format is not correct.'
    if debug:
        assert (datalist is None and filelist is not None) or (
            filelist is None and datalist is not None), 'data is not correct'
        if datalist is not None:
            assert len(datalist) == num_data, 'number of data is not equal'
        if filelist is not None:
            assert len(filelist) == num_data, 'number of data is not equal'

    # convert both label sources to either a dictionary or a list of labels
    labeldict1, labellist1 = _parse_label_source_hdf5(
        label_src1, label_preprocess_function1, num_data, debug=debug)
    labeldict2, labellist2 = _parse_label_source_hdf5(
        label_src2, label_preprocess_function2, num_data, debug=debug)

    # warm up: every image must share the shape of the first one
    if datalist is not None:
        size_data = datalist[0].shape
    else:
        size_data = imread(filelist[0]).shape

    has_label1 = _check_label_setup_hdf5(label_src1, labeldict1, labellist1,
                                         label_name1, label_range1, filelist,
                                         debug=debug)
    has_label2 = _check_label_setup_hdf5(label_src2, labeldict2, labellist2,
                                         label_name2, label_range2, filelist,
                                         debug=debug)
    labels1 = np.zeros((batch_size, 1), dtype='float32') if has_label1 else None
    labels2 = np.zeros((batch_size, 1), dtype='float32') if has_label2 else None

    # start generating
    count_hdf = 1  # index of the next hdf5 file to write
    clock = Timer()
    datalist_batch = list()
    save_path = save_dir  # reported in the progress line until the first file is written
    for i in range(num_data):
        clock.tic()
        if filelist is not None:
            imagefile = filelist[i]
            _, name, _ = fileparts(imagefile)
            img = imread(imagefile).astype('float32')
            max_value = np.max(img)
            if max_value > 1 and max_value <= 255:
                img = img / 255.0  # scale the image data to [0, 1]
            if debug:
                min_value = np.min(img)
                assert min_value >= 0 and min_value <= 1, 'data is not in [0, 1]'
        if datalist is not None:
            img = datalist[i]
        if debug:
            assert size_data == img.shape
        datalist_batch.append(img)

        # the label slot matches the position of the image inside the current batch
        slot = i % batch_size
        if labeldict1 is not None:
            labels1[slot, 0] = float(labeldict1[name])
        if labellist1 is not None:
            labels1[slot, 0] = float(labellist1[i])
        if labeldict2 is not None:
            labels2[slot, 0] = float(labeldict2[name])
        if labellist2 is not None:
            labels2[slot, 0] = float(labellist2[i])

        # save once the batch is full, or at the last image so a trailing partial batch is not dropped
        if (i + 1) % batch_size == 0 or i + 1 == num_data:
            num_batch = len(datalist_batch)
            data = preprocess_image_caffe(
                datalist_batch, debug=debug, vis=vis
            )  # swap channel, transfer from list of HxWxC to NxCxHxW

            # write to hdf5 format
            if filelist is not None:
                save_path = os.path.join(save_dir, '%s.hdf5' % name)
            else:
                save_path = os.path.join(save_dir,
                                         'image_%010d.hdf5' % count_hdf)
            with h5py.File(save_path, 'w') as h5f:
                h5f.create_dataset(data_name, data=data, dtype='float32')
                if has_label1:
                    labels_tmp = label_preprocess_function1(
                        data=labels1[:num_batch],
                        data_range=label_range1,
                        debug=debug)
                    h5f.create_dataset(label_name1,
                                       data=labels_tmp,
                                       dtype='float32')
                    labels1 = np.zeros((batch_size, 1), dtype='float32')
                if has_label2:
                    labels_tmp = label_preprocess_function2(
                        data=labels2[:num_batch],
                        data_range=label_range2,
                        debug=debug)
                    h5f.create_dataset(label_name2,
                                       data=labels_tmp,
                                       dtype='float32')
                    labels2 = np.zeros((batch_size, 1), dtype='float32')

            count_hdf = count_hdf + 1
            del datalist_batch[:]
            if debug:
                assert len(datalist_batch) == 0, 'list has not been cleared'

        # assumes toc() returns the average time per iteration; i + 1 images are done so far
        average_time = clock.toc()
        print(
            'saving to %s: %d/%d, average time:%.3f, elapsed time:%s, estimated time remaining:%s'
            % (save_path, i + 1, num_data, average_time,
               convert_secs2time(average_time * (i + 1)),
               convert_secs2time(average_time * (num_data - i - 1))))

    return count_hdf - 1, num_data
Example #13
0
def visualize_pts(pts,
                  title=None,
                  fig=None,
                  ax=None,
                  display_range=False,
                  xlim=None,
                  ylim=None,
                  display_list=None,
                  covariance=False,
                  mse=False,
                  mse_value=None,
                  vis=True,
                  save_path=None,
                  debug=True,
                  closefig=True,
                  warning=True):
    '''
    visualize point scatter plot

    parameter:
        pts:            2 x num_pts numpy array or a dictionary mapping a method name to a 2 x num_pts numpy array
        title:          figure title, defaults to 'Point Error Vector Distribution Map'
        fig, ax:        optional figure / axis to draw on; when ax is given, title, axis labels and grid setup are skipped
        display_range:  if True, restrict the axes to xlim and ylim
        xlim, ylim:     [min, max] lists of the displayed range, default [-100, 100]
        display_list:   order of method names in the legend (dictionary input only), defaults to the dictionary keys
        covariance:     if True, draw a covariance ellipse per point set and report its covariance value in the legend
        mse:            if True, report the MSE in the legend
        mse_value:      precomputed MSE (a number for array input, a dictionary keyed by method name for dictionary input);
                        when None it is computed with pts_euclidean against the origin
        vis, save_path, closefig, warning:  forwarded to save_vis_close_helper
        debug:          enable the optional sanity checks

    outputs:
        mse_return:     dictionary input: a dictionary of MSE per method (empty when mse is False);
                        array input: the MSE when mse is True, otherwise None
    '''

    if debug:
        if isdict(pts):
            for pts_tmp in pts.values():
                assert is2dptsarray(
                    pts_tmp
                ), 'input points within dictionary are not correct: (2, num_pts) vs %s' % print_np_shape(
                    pts_tmp)
            if display_list is not None:
                assert islist(display_list) and len(display_list) == len(
                    pts), 'the input display list is not correct'
                assert CHECK_EQ_LIST_UNORDERED(
                    display_list, list(pts.keys()), debug=debug
                ), 'the input display list does not match the points key list'
        else:
            assert is2dptsarray(
                pts
            ), 'input points are not correct: (2, num_pts) vs %s' % print_np_shape(
                pts)
        assert title is None or isstring(title), 'title is not correct'
        assert islogical(
            display_range
        ), 'the flag determine if to display in a specific range should be logical value'

    # resolve defaults outside the debug block so they also apply when debug is False
    if title is None:
        title = 'Point Error Vector Distribution Map'
    if xlim is None:
        xlim = [-100, 100]
    if ylim is None:
        ylim = [-100, 100]
    if isdict(pts) and display_list is None:
        display_list = list(pts.keys())

    if debug and display_range:
        assert islist(xlim) and islist(ylim) and len(xlim) == 2 and len(
            ylim) == 2, 'the input range for x and y is not correct'
        assert xlim[1] > xlim[0] and ylim[1] > ylim[
            0], 'the input range for x and y is not correct'

    # figure setting
    width, height = 1024, 1024
    fig, _ = get_fig_ax_helper(fig=fig, ax=ax, width=width, height=height)
    if ax is None:
        plt.title(title, fontsize=20)
        if isdict(pts):
            num_pts_list = [pts_tmp.shape[1] for pts_tmp in pts.values()]
            if len(set(num_pts_list)) == 1:
                plt.xlabel('x coordinate (%d points)' % num_pts_list[0],
                           fontsize=16)
                plt.ylabel('y coordinate (%d points)' % num_pts_list[0],
                           fontsize=16)
            else:
                print('number of points is different across different methods')
                plt.xlabel('x coordinate', fontsize=16)
                plt.ylabel('y coordinate', fontsize=16)
        else:
            plt.xlabel('x coordinate (%d points)' % pts.shape[1], fontsize=16)
            plt.ylabel('y coordinate (%d points)' % pts.shape[1], fontsize=16)
        plt.axis('equal')
        ax = plt.gca()
        ax.grid()

    # internal parameters
    pts_size = 5
    std = None
    conf = 0.98
    color_index = 0
    marker_index = 0
    hatch_index = 0
    alpha = 0.2
    legend_fontsize = 10
    scale_distance = 48.8  # multiplier applied to the MSE for the 'um' value in the legend
    linewidth = 2

    # plot points
    mse_return = None
    if isdict(pts):
        num_methods = len(pts)
        assert len(color_set) * len(marker_set) >= num_methods and len(
            color_set
        ) * len(
            hatch_set
        ) >= num_methods, 'color in color set is not enough to use, please use different markers'
        mse_return = dict()

        # parallel lists keep each legend entry tied to its method name for reordering
        method_list, handle_list, display_string_list = [], [], []
        for method_name, pts_tmp in pts.items():
            color_tmp = color_set[color_index]
            marker_tmp = marker_set[marker_index]
            hatch_tmp = hatch_set[hatch_index]

            # plot covariance ellipse
            covariance_number = None
            if covariance:
                _, covariance_number = visualize_pts_covariance(
                    pts_tmp[0:2, :],
                    std=std,
                    conf=conf,
                    ax=ax,
                    debug=debug,
                    color=color_tmp,
                    hatch=hatch_tmp,
                    linewidth=linewidth)
            handle_tmp = ax.scatter(pts_tmp[0, :],
                                    pts_tmp[1, :],
                                    color=color_tmp,
                                    marker=marker_tmp,
                                    s=pts_size,
                                    alpha=alpha)
            if mse:
                if mse_value is None:
                    num_pts = pts_tmp.shape[1]
                    mse_tmp, _ = pts_euclidean(pts_tmp[0:2, :],
                                               np.zeros((2, num_pts),
                                                        dtype='float32'),
                                               debug=debug)
                else:
                    mse_tmp = mse_value[method_name]
                display_string = '%s, MSE: %.1f (%.1f um)' % (
                    method_name, mse_tmp, mse_tmp * scale_distance)
                # the covariance value only exists when the ellipse was drawn
                if covariance_number is not None:
                    display_string += ', Covariance: %.1f' % covariance_number
                mse_return[method_name] = mse_tmp
            else:
                display_string = method_name
            method_list.append(method_name)
            handle_list.append(handle_tmp)
            display_string_list.append(display_string)

            # cycle through colors first, then switch marker and hatch
            color_index += 1
            if color_index == len(color_set):
                marker_index += 1
                hatch_index += 1
                color_index = 0

        # reorder the handle before plot
        order_index_list = [
            display_list.index(method_name_tmp)
            for method_name_tmp in method_list
        ]
        ordered_handle_key_list = list_reorder(display_string_list,
                                               order_index_list,
                                               debug=debug)
        ordered_handle_value_list = list_reorder(handle_list,
                                                 order_index_list,
                                                 debug=debug)
        plt.legend(list2tuple(ordered_handle_value_list),
                   list2tuple(ordered_handle_key_list),
                   scatterpoints=1,
                   markerscale=4,
                   loc='lower left',
                   fontsize=legend_fontsize)

    else:
        color_tmp = color_set[color_index]
        marker_tmp = marker_set[marker_index]
        hatch_tmp = hatch_set[hatch_index]
        handle_tmp = ax.scatter(pts[0, :],
                                pts[1, :],
                                color=color_tmp,
                                marker=marker_tmp,
                                s=pts_size,
                                alpha=alpha)

        # plot covariance ellipse
        covariance_number = None
        if covariance:
            _, covariance_number = visualize_pts_covariance(
                pts[0:2, :],
                std=std,
                conf=conf,
                ax=ax,
                debug=debug,
                color=color_tmp,
                hatch=hatch_tmp,
                linewidth=linewidth)

        if mse:
            if mse_value is None:
                num_pts = pts.shape[1]
                mse_return, _ = pts_euclidean(pts[0:2, :],
                                              np.zeros((2, num_pts),
                                                       dtype='float32'),
                                              debug=debug)
            else:
                mse_return = mse_value
            display_string = 'MSE: %.1f (%.1f um)' % (
                mse_return, mse_return * scale_distance)
            if covariance_number is not None:
                display_string += ', Covariance: %.1f' % covariance_number
            plt.legend(list2tuple([handle_tmp]),
                       list2tuple([display_string]),
                       scatterpoints=1,
                       markerscale=4,
                       loc='lower left',
                       fontsize=legend_fontsize)

    # display only specific range
    if display_range:
        axis_bin = 10 * 2
        interval_x = (xlim[1] - xlim[0]) / axis_bin
        interval_y = (ylim[1] - ylim[0]) / axis_bin
        plt.xlim(xlim[0], xlim[1])
        plt.ylim(ylim[0], ylim[1])
        plt.xticks(np.arange(xlim[0], xlim[1] + interval_x, interval_x))
        plt.yticks(np.arange(ylim[0], ylim[1] + interval_y, interval_y))
    plt.grid()

    save_vis_close_helper(fig=fig,
                          ax=ax,
                          vis=vis,
                          save_path=save_path,
                          warning=warning,
                          debug=debug,
                          closefig=closefig,
                          transparent=False)
    return mse_return
Example #14
0
def load_list_from_folder(folder_path,
                          ext_filter=None,
                          depth=1,
                          recursive=False,
                          sort=True,
                          save_path=None,
                          debug=True):
    '''
    load a list of files or folders from a system path

    parameters:
        folder_path:    root to search
        ext_filter:     a string or a list of strings representing the extensions of files interested
        depth:          maximum depth of folder to search, when it's None, all levels of folders will be searched
        recursive:      False: only return current level
                        True: return all levels till to the input depth
        sort:           sort the matches of every glob pattern (also applied in the recursive calls)
        save_path:      if given, write the resulting list to this file, one path per line
        debug:          enable the optional sanity checks

    outputs:
        fulllist:       a list of elements
        num_elem:       number of the elements
    '''
    folder_path = safepath(folder_path)
    if debug:
        assert isfolder(
            folder_path), 'input folder path is not correct: %s' % folder_path
    if not is_path_exists(folder_path): return [], 0

    if debug:
        assert islogical(
            recursive), 'recursive should be a logical variable: {}'.format(
                recursive)
        assert depth is None or (
            isinteger(depth)
            and depth >= 1), 'input depth is not correct {}'.format(depth)
        assert ext_filter is None or (islist(ext_filter) and all(
            isstring(ext_tmp) for ext_tmp in ext_filter)) or isstring(
                ext_filter), 'extension filter is not correct'
    if isstring(ext_filter):  # convert to a list
        ext_filter = [ext_filter]

    if depth is None:
        # find all files recursively: glob2 expands '**' to any number of nested folders
        wildcard_prefix = '**'
        if ext_filter is not None:
            wildcard_list = [
                os.path.join(wildcard_prefix, '*' + string2ext_filter(ext_tmp))
                for ext_tmp in ext_filter
            ]
        else:
            wildcard_list = [wildcard_prefix]
        glob_function = glob2.glob
    else:
        # exactly `depth` levels below folder_path, e.g. depth 2 -> '*/*'
        wildcard_prefix = os.path.join(*(['*'] * depth))
        if ext_filter is not None:
            wildcard_list = [
                wildcard_prefix + string2ext_filter(ext_tmp)
                for ext_tmp in ext_filter
            ]
        else:
            wildcard_list = [wildcard_prefix]
        glob_function = glob.glob

    fulllist = list()
    for wildcard in wildcard_list:
        curlist = glob_function(os.path.join(folder_path, wildcard))
        if sort:
            curlist = sorted(curlist)
        fulllist += curlist

    # shallower levels are collected by recursing with a smaller depth
    if depth is not None and recursive and depth > 1:
        newlist, _ = load_list_from_folder(folder_path=folder_path,
                                           ext_filter=ext_filter,
                                           depth=depth - 1,
                                           recursive=True,
                                           sort=sort,
                                           debug=debug)
        fulllist += newlist

    fulllist = [os.path.normpath(path_tmp) for path_tmp in fulllist]
    num_elem = len(fulllist)

    # save list to a path
    if save_path is not None:
        save_path = safepath(save_path)
        if debug:
            assert is_path_exists_or_creatable(
                save_path), 'the file cannot be created'
        with open(save_path, 'w') as save_file:
            save_file.writelines('%s\n' % item for item in fulllist)

    return fulllist, num_elem