def calculate_truncated_mse(error_list, truncated_list, debug=True):
    '''
    calculate the MSE truncated by a set of thresholds and return, for each threshold, the truncated MSE and the percentage of points whose error is below that threshold

    parameters:
        error_list:         a list of errors
        truncated_list:     a list of thresholds

    return
        tmse_dict:          a dictionary where each entry is a dict and has key 'T-MSE' & 'percentage'
    '''
    if debug:
        assert islist(error_list) and all(
            isscalar(error_tmp)
            for error_tmp in error_list), 'the input error list is not correct'
        assert islist(truncated_list) and all(
            isscalar(thres_tmp) for thres_tmp in
            truncated_list), 'the input truncated list is not correct'
        assert len(truncated_list) > 0, 'there is no entry in the truncated list'

    tmse_dict = dict()
    num_entry = len(error_list)
    error_array = np.asarray(error_list)

    for threshold in truncated_list:
        error_index = np.where(error_array[:] < threshold)[0].tolist()  # indices of errors below the threshold
        error_interested = error_array[error_index]

        entry = dict()
        entry['T-MSE'] = np.mean(error_interested)
        entry['percentage'] = len(error_index) / float(num_entry)
        tmse_dict[threshold] = entry

    return tmse_dict
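A minimal numpy-only sketch of the truncated-MSE logic above; the error values and thresholds are made up for illustration:

import numpy as np

errors = np.asarray([0.5, 1.2, 3.4, 0.1], dtype='float32')
for threshold in [1.0, 2.0]:
    kept = errors[errors < threshold]                    # errors below the threshold
    print(threshold, float(np.mean(kept)), kept.size / float(errors.size))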
Example #2
def load_image(src_path,
               resize_factor=1.0,
               rotate=0,
               mode='numpy',
               debug=True):
    '''
    load an image from given path

    parameters:
        resize_factor:      a scalar to resize the image (>1 enlarges), or a (width, height) tuple
        mode:               numpy or pil, specify the format of returned image
        rotate:             counterclockwise rotation in degree

    output:
        img:                a uint8 rgb image (numpy or pil)
    '''

    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    src_path = safepath(src_path)

    if debug:
        assert is_path_exists(
            src_path), 'image path is not correct at %s' % src_path
        assert mode == 'numpy' or mode == 'pil', 'the input mode for returned image is not correct'
        assert (isscalar(resize_factor) and resize_factor > 0) or len(
            resize_factor) == 2, 'the resize factor is not correct: {}'.format(
                resize_factor)

    with open(src_path, 'rb') as f:
        with Image.open(f) as img:
            img = img.convert('RGB')

            # rotation
            if rotate != 0:
                img = img.rotate(rotate, expand=True)

            # scaling
            if isscalar(resize_factor):
                width, height = img.size
                img = img.resize(size=(int(width * resize_factor),
                                       int(height * resize_factor)),
                                 resample=Image.BILINEAR)
            elif len(resize_factor) == 2:
                resize_width, resize_height = int(resize_factor[0]), int(
                    resize_factor[1])
                img = img.resize(size=(resize_width, resize_height),
                                 resample=Image.BILINEAR)
            else:
                assert False, 'the resize factor is neither a scalar nor a (width, height)'

            # formatting
            if mode == 'numpy':
                img = np.array(img)

    return img
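A minimal usage sketch of the same Pillow calls (the image path is a placeholder; Pillow and numpy are assumed to be installed):

import numpy as np
from PIL import Image

with open('input.jpg', 'rb') as f:                       # placeholder path
    with Image.open(f) as img:
        img = img.convert('RGB')
        img = img.rotate(90, expand=True)                # counterclockwise rotation
        img = img.resize((img.size[0] // 2, img.size[1] // 2),
                         resample=Image.BILINEAR)        # half-size
        np_img = np.array(img)                           # uint8 HWC array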
Example #3
def get_fig_ax_helper(fig=None, ax=None, width=None, height=None, debug=True):
    if fig is None:
        if width is not None and height is not None:
            if debug:
                assert isscalar(width) and isscalar(
                    height), 'the height and width are not correct'
            figsize = width / float(dpi), height / float(dpi)  # dpi is assumed to be a module-level constant
            fig = plt.figure(figsize=figsize)
        else:
            fig = plt.gcf()
    if ax is None: ax = plt.gca()
    return fig, ax
Example #4
def data_normalize(input_data, method='max', data_range=None, sum=1, warning=True, debug=True):
	'''
	this function normalizes N-d data in two ways: 1) normalize the data from a given range to [0, 1]; 2) normalize the data so that it sums to a given value

	parameters:
		input_data:			a list or a numpy N-d data to normalize
		method:				max:	normalize the data from a range to [0, 1], when the range is not given, the max and min are obtained from the data
							sum:	normalize the data such that all elements are summed to a value, the default value is 1
		data_range:			None or 2-element tuple, list or array
		sum:				a scalar

	outputs:
		normalized_data:	a float32 numpy array with same shape as the input data
	'''
	np_data = safe_npdata(input_data, warning=warning, debug=debug).astype('float32')
	if debug: 
		assert isnparray(np_data), 'the input data is not a numpy data'
		assert method in ['max', 'sum'], 'the method for normalization is not correct'

	if method == 'max':
		if data_range is None: max_value, min_value = np.max(np_data), np.min(np_data)
		else:	
			if debug: assert isrange(data_range), 'data range is not correct'
			max_value, min_value = data_range[1], data_range[0]
	elif method == 'sum':
		if debug: assert isscalar(sum), 'the sum is not correct'
		max_value, min_value = np.sum(np_data) / sum, 0

	normalized_data = (np_data - min_value) / (max_value - min_value)	# normalization

	return normalized_data
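A small numpy-only illustration of the two normalization modes described above (values chosen arbitrarily):

import numpy as np

data = np.array([2.0, 4.0, 6.0], dtype='float32')

# 'max' mode: map [min, max] of the data to [0, 1]
max_norm = (data - data.min()) / (data.max() - data.min())     # -> [0.0, 0.5, 1.0]

# 'sum' mode: rescale so that the elements sum to a target value (here 1)
sum_norm = data / (data.sum() / 1.0)                           # -> sums to 1.0

print(max_norm, sum_norm.sum())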
Example #5
def image_draw_mask(input_image,
                    input_image_mask,
                    transparency=0.3,
                    warning=True,
                    debug=True):
    '''
	draw a mask on top of an image with certain transparency

	parameters: 
		input_image:			a pil or numpy image
		input_image_mask:		a pil or numpy image
		transparency:			transparency factor

	outputs:
		masked_image:			uint8 numpy image
	'''
    np_image, _ = safe_image(input_image, warning=warning, debug=debug)
    np_image_mask, _ = safe_image(input_image_mask,
                                  warning=warning,
                                  debug=debug)
    if debug:
        assert isscalar(transparency), 'the transparency should be a scalar'
        assert np_image.shape == np_image_mask.shape, 'the shape of mask should be equal to the shape of input image'
    if isfloatimage(np_image): np_image = (np_image * 255.).astype('uint8')
    if isfloatimage(np_image_mask):
        np_image_mask = (np_image_mask * 255.).astype('uint8')

    pil_image, pil_image_mask = Image.fromarray(np_image), Image.fromarray(
        np_image_mask)
    masked_image = np.array(
        Image.blend(pil_image, pil_image_mask, alpha=transparency))
    return masked_image
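A minimal sketch of the blending step with synthetic uint8 arrays; Pillow's Image.blend performs the actual alpha mix:

import numpy as np
from PIL import Image

image = np.zeros((64, 64, 3), dtype='uint8')             # black image
mask = np.full((64, 64, 3), 255, dtype='uint8')          # white mask
blended = np.array(Image.blend(Image.fromarray(image),
                               Image.fromarray(mask), alpha=0.3))
print(blended[0, 0])                                      # roughly [76 76 76], i.e. 30% of 255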
Example #6
def safe_angle(input_angle, radian=False, warning=True, debug=True):
    '''
	ensure the angle is within (-180, 180] in degree, or (-pi, pi] if radian is True

	parameters:
		input_angle:	an angle which is supposed to be in degree
		radian:			if True, the angle is interpreted and wrapped in radian instead of degree

	outputs:
		angle:			an angle within (-180, 180] degree, or (-pi, pi] radian if radian is True
	'''
    angle = copy.copy(input_angle)
    if debug:
        assert isscalar(angle), 'the input angle should be a scalar'

    if isnparray(angle): angle = angle[0]  # single numpy scalar value
    if radian:
        while angle > np.pi:
            angle -= 2 * np.pi
        while angle <= -np.pi:
            angle += 2 * np.pi
    else:
        while angle > 180:
            angle -= 360
        while angle <= -180:
            angle += 360

    return angle
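A quick numeric illustration of the wrapping above; the while loops map 270 degrees to -90:

angle = 270
while angle > 180: angle -= 360
while angle <= -180: angle += 360
print(angle)    # -90, i.e. 270 degrees wrapped into (-180, 180]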
Example #7
def get_2dline_from_pts_slope(input_pts, slope, warning=True, debug=True):
    '''
	get the homogeneous line representation from a homogeneous point and a slope in degree

	parameters:
		input_pts:          a homogeneous 2D point, can be a list or tuple or numpy array: (x, y, z)
		slope:              a scalar in degree

	outputs:
		np_line:            a homogeneous 2D line as a 3 x 1 numpy array: (a, b, c)
	'''
    np_pts1 = safe_2dptsarray(input_pts,
                              homogeneous=True,
                              warning=warning,
                              debug=debug)
    if debug:
        assert is2dhomopts(np_pts1), 'point is not correct'
        assert isscalar(slope), 'the slope is not correct'

    y = math.sin(math.radians(slope))  # use sin/cos so that vertical lines (slope of 90 or -90) are handled
    x = math.cos(math.radians(slope))
    np_pts2 = np.array([x, y, 0]).reshape(
        (3, 1))  # a homogeneous point at infinity along the slope direction
    np_line = get_2dline_from_pts(np_pts1,
                                  np_pts2,
                                  warning=warning,
                                  debug=debug)

    return np_line
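The helper get_2dline_from_pts belongs to the surrounding toolbox; the underlying idea is that the homogeneous line through two homogeneous points is their cross product, sketched below with plain numpy (the point and slope are made up):

import math
import numpy as np

pt = np.array([1.0, 2.0, 1.0])                            # homogeneous 2D point (x, y, 1)
slope = 45.0                                              # degrees
direction = np.array([math.cos(math.radians(slope)),      # homogeneous point at infinity along the slope
                      math.sin(math.radians(slope)), 0.0])
line = np.cross(pt, direction)                            # homogeneous line (a, b, c)
print(line / np.linalg.norm(line[:2]))                    # a*x + b*y + c = 0, here y = x + 1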
Example #8
def tracking_lk_opencv(input_image1, input_image2, input_pts, backward=False, win_size=15, pyramid=5, warning=True, debug=True):
	'''
	tracking a set of points in two images using Lucas-Kanade tracking implemented in opencv

	parameters:
		input_image1, input_image2:				a pil or numpy image, color or gray
		input_pts: 			a list of 2 elements, a listoflist of 2 elements: 
							e.g., [[1,2], [5,6]], or a numpy array with shape (2, N) or (2, )
		backward:			run backward tracking if true
		win_size:			window size used for lucas-kanade tracking
		pyramid:			number of levels of pyramid used for lucas kanade tracking

	outputs:
		pts_forward:		tracked points in forward pass, 2 x N float32 numpy
		pts_bacward:		tracked points in backward pass, 2 x N float32 numpy, None if not running the backward pass
		backward_err_list:	a list of errors from the forward-backward consistency check, None if not running the backward pass
		found_index_list:	a list of 0 or 1, 1 if the tracking converges, 0 if not converging
	'''
	np_image1, _ = safe_image(input_image1, warning=warning, debug=debug)
	np_image2, _ = safe_image(input_image2, warning=warning, debug=debug)
	np_pts = safe_2dptsarray(input_pts, homogeneous=False, warning=warning, debug=debug).astype('float32')		# 2 x N
	if debug: assert isscalar(win_size) and isscalar(pyramid), 'the hyperparameters of lucas-kanade tracking are not correct'
	num_pts = np_pts.shape[1]

	# formatting the input
	if iscolorimage_dimension(np_image1): np_image1 = rgb2gray(np_image1)
	if iscolorimage_dimension(np_image2): np_image2 = rgb2gray(np_image2)

	lk_params = dict(winSize=(win_size, win_size), maxLevel=pyramid, criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10000, 0.03))
	pts_root = np.expand_dims(np_pts.transpose(), axis=1) 			# N x 1 x 2
	
	pts_forward, status_for, err_for = cv2.calcOpticalFlowPyrLK(np_image1, np_image2, pts_root, None, **lk_params) 
	found_index_for = np.where(status_for[:, 0] == 1)[0].tolist()
	if backward: 
		pts_bacward, status_bac, err_bac = cv2.calcOpticalFlowPyrLK(np_image2, np_image1, pts_forward, None, **lk_params)
		# print(status_bac)
		found_index_bac = np.where(status_bac[:, 0] == 1)[0].tolist()
		pts_bacward = pts_bacward.reshape((-1, 2)).transpose()
		_, backward_err_list = pts_euclidean(np_pts, pts_bacward, warning=warning, debug=debug)
		found_index_list = find_unique_common_from_lists(found_index_for, found_index_bac, warning=warning, debug=debug)
	else: 
		pts_bacward, backward_err_list = None, None
		found_index_list = found_index_for

	pts_forward = pts_forward.reshape((-1, 2)).transpose()			#  2 x N
	return pts_forward, pts_bacward, backward_err_list, found_index_list
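A self-contained sketch of the underlying OpenCV call with two synthetic frames (a bright square shifted by a few pixels); the parameter values mirror the defaults above, except for a smaller pyramid:

import cv2
import numpy as np

frame1 = np.zeros((100, 100), dtype='uint8')
frame2 = np.zeros((100, 100), dtype='uint8')
frame1[30:50, 30:50] = 255                                # square in frame 1
frame2[33:53, 35:55] = 255                                # same square shifted by (+5, +3)

pts = np.array([[[30.0, 30.0]]], dtype='float32')         # N x 1 x 2, the top-left corner
lk_params = dict(winSize=(15, 15), maxLevel=3,
                 criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10000, 0.03))
tracked, status, err = cv2.calcOpticalFlowPyrLK(frame1, frame2, pts, None, **lk_params)
print(tracked.reshape(-1, 2), status.ravel())             # roughly [[35. 33.]] if the corner converged
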
def image_find_peaks(input_image, percent_threshold=0.5, warning=True, debug=True):
	'''
	this function finds all strict local peaks and a strict global peak from a grayscale image
	the strict local maximum means that the pixel value must be larger than all nearby pixel values

	parameters:
		input_image:			a pil or numpy grayscale image
		percent_threshold:		determines which pixel values are smoothed out,
								e.g., when 0.4, all pixel values less than 0.4 * np.max(input_image) are smoothed out to 0

	outputs:
		peak_array:				a numpy float32 array, 3 x num_peaks, (x, y, score)
		peak_global:			a numpy float32 array, 3 x 1: (x, y, score)
	'''
	np_image, _ = safe_image_like(input_image, warning=warning, debug=debug)
	if isuintimage(np_image): np_image = np_image.astype('float32') / 255.
	if debug: 
		assert isgrayimage(np_image) and isfloatimage(np_image), 'the input image is not a grayscale and float image'
		assert isscalar(percent_threshold) and percent_threshold >= 0 and percent_threshold <= 1, 'the percent_threshold is not correct'

	max_value = np.max(np_image)
	np_image[np_image < percent_threshold * max_value] = 0.0
	height, width = np_image.shape[0], np_image.shape[1]
	npimage_center, npimage_top, npimage_bottom, npimage_left, npimage_right = np.zeros([height + 2, width + 2]), np.zeros([height + 2, width + 2]), np.zeros([height + 2, width + 2]), np.zeros([height + 2, width + 2]), np.zeros([height + 2, width + 2])

	# shift in different directions to find local peak, only works for convex blob
	npimage_center[1:-1, 1:-1] = np_image
	npimage_left[1:-1, 0:-2] = np_image
	npimage_right[1:-1, 2:] = np_image
	npimage_top[0:-2, 1:-1] = np_image
	npimage_bottom[2:, 1:-1] = np_image

	# compute pixels larger than its shifted version of heatmap
	right_bool = npimage_center > npimage_right
	left_bool = npimage_center > npimage_left
	bottom_bool = npimage_center > npimage_bottom
	top_bool = npimage_center > npimage_top

	# the strict local maximum must be bigger than all nearby pixel values
	peakMap = np.logical_and(np.logical_and(np.logical_and(right_bool, left_bool), top_bool), bottom_bool)		
	peakMap = peakMap[1:-1, 1:-1]
	peak_location_tuple = np.nonzero(peakMap)     # find true
	num_peaks = len(peak_location_tuple[0])
	if num_peaks == 0:
		if warning: print('No single local peak found')
		return np.zeros((3, 0), dtype='float32'), np.zeros((3, 0), dtype='float32')

	# convert to a numpy array format
	peak_array = np.zeros((3, num_peaks), dtype='float32')
	peak_array[0, :], peak_array[1, :] = peak_location_tuple[1], peak_location_tuple[0]
	for peak_index in range(num_peaks):
		peak_array[2, peak_index] = np_image[int(peak_array[1, peak_index]), int(peak_array[0, peak_index])]

	# find the global peak from all local peaks
	global_peak_index = np.argmax(peak_array[2, :])
	peak_global = peak_array[:, global_peak_index].reshape((3, 1))

	return peak_array, peak_global
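A small numpy-only sketch of the shift-and-compare idea used above: pad the image, shift it in four directions, and keep pixels strictly larger than all four neighbours (the array values are made up):

import numpy as np

img = np.array([[0.0, 0.0, 0.0, 0.0],
                [0.0, 0.9, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.6],
                [0.0, 0.0, 0.0, 0.0]], dtype='float32')
h, w = img.shape
padded = np.zeros((h + 2, w + 2), dtype='float32')
padded[1:-1, 1:-1] = img
center = padded[1:-1, 1:-1]
peaks = (center > padded[1:-1, :-2]) & (center > padded[1:-1, 2:]) & \
        (center > padded[:-2, 1:-1]) & (center > padded[2:, 1:-1])
ys, xs = np.nonzero(peaks)
print(list(zip(xs.tolist(), ys.tolist())))                # [(1, 1), (3, 2)] as (x, y) peaks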
Example #10
def nparray_resize(input_nparray,
                   resize_factor=None,
                   target_size=None,
                   interp='bicubic',
                   warning=True,
                   debug=True):
    '''
    resize the numpy array given a resize factor (e.g., 0.25), or given a target size (height, width)
    e.g., the numpy array has 600 x 800:
        1. given a resize factor of 0.25 -> results in an image with 150 x 200
        2. given a target size of (300, 400) -> results in an image with 300 x 400
    note that:
        resize_factor and target_size cannot exist at the same time

    parameters:
        input_nparray:      a numpy array
        resize_factor:      a scalar
        target_size:        a list, tuple or numpy array with 2 elements, representing height and width
        interp:             interpolation methods: bicubic or bilinear

    outputs:
        resized_nparray:    a numpy array
    '''
    np_array = safe_npdata(input_nparray, warning=warning, debug=debug)
    if debug:
        assert interp in ['bicubic', 'bilinear'
                          ], 'the interpolation method is not correct'
        assert (resize_factor is not None and target_size is None) or (
            resize_factor is None and target_size
            is not None), 'resize_factor and target_size cannot co-exist'

    if target_size is not None:
        if debug:
            assert isimsize(
                target_size), 'the input target size is not correct'
        target_width, target_height = int(round(target_size[1])), int(
            round(target_size[0]))
    elif resize_factor is not None:
        if debug:
            assert isscalar(resize_factor), 'the resize factor is not a scalar'
        height, width = np_array.shape[:2]
        target_width, target_height = int(round(resize_factor * width)), int(
            round(resize_factor * height))
    else:
        assert False, 'the target_size and resize_factor do not exist'

    if interp == 'bicubic':
        resized_nparray = cv2.resize(np_array, (target_width, target_height),
                                     interpolation=cv2.INTER_CUBIC)
    elif interp == 'bilinear':
        resized_nparray = cv2.resize(np_array, (target_width, target_height),
                                     interpolation=cv2.INTER_LINEAR)
    else:
        assert False, 'interpolation is wrong'

    return resized_nparray
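A short cv2.resize usage sketch matching the two modes above; note that OpenCV expects the size as (width, height):

import cv2
import numpy as np

array = np.random.rand(600, 800).astype('float32')
small = cv2.resize(array, (200, 150), interpolation=cv2.INTER_CUBIC)    # resize_factor = 0.25
medium = cv2.resize(array, (400, 300), interpolation=cv2.INTER_LINEAR)  # target_size = (300, 400)
print(small.shape, medium.shape)                                        # (150, 200) (300, 400)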
Example #11
def degree2radian(degree, debug=True):
    '''
    this function returns the radian given a degree; the difference from math.radians is that this function normalizes the output to the range [0, 2*pi)
    '''
    if debug: assert isscalar(degree), 'input degree number is not correct'
    radian = math.radians(degree)
    while radian < 0:
        radian += 2 * math.pi
    while radian >= 2 * math.pi:
        radian -= 2 * math.pi

    return radian
Example #12
def radian2degree(radian, debug=True):
    '''
    this function returns the degree given a radian; the difference from math.degrees is that this function normalizes the output to the range [0, 360)
    '''
    if debug: assert isscalar(radian), 'input radian number is not correct'
    degree = math.degrees(radian)
    while degree < 0:
        degree += 360.0
    while degree >= 360.0:
        degree -= 360.0

    return degree
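A quick numeric check of the two conversions; the modulo below is equivalent to the normalizing while loops above:

import math

radian = math.radians(-90) % (2 * math.pi)     # -pi/2 wrapped into [0, 2*pi) -> 3*pi/2
degree = math.degrees(-math.pi / 2) % 360.0    # -90 wrapped into [0, 360) -> 270.0
print(radian, degree)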
Example #13
def visualize_bar(data,
                  bin_size=2.0,
                  title='Bar Graph of Key-Value Pair',
                  xlabel='index',
                  ylabel='count',
                  vis=True,
                  save_path=None,
                  debug=True,
                  closefig=True):
    '''
    visualize the bar graph of a data, which can be a dictionary or list of dictionary

    different from visualize_bar_graph, this function does not depend on pandas and DataFrame; it is simpler but has less functionality
    also, the keys of the input data are treated as continuous scalar variables
    '''
    if debug:
        assert isstring(title) and isstring(xlabel) and isstring(
            ylabel), 'title/xlabel/ylabel is not correct'
        assert isdict(data) or islist(data), 'input data is not correct'
        assert isscalar(bin_size), 'the bin size is not a floating number'

    if isdict(data):
        index_list = list(data.keys())
        if debug:
            assert islistofscalar(
                index_list
            ), 'the input dictionary does not contain a scalar key'
        frequencies = list(data.values())
    else:
        index_list = range(len(data))
        frequencies = data

    index_str_list = scalarlist2strlist(index_list, debug=debug)
    index_list = np.array(index_list)
    fig, ax = get_fig_ax_helper(fig=None, ax=None)
    # ax.set_xticks(index_list)
    # ax.set_xticklabels(index_str_list)
    plt.bar(index_list, frequencies, bin_size, color='r', alpha=0.5)
    plt.title(title, fontsize=20)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    return save_vis_close_helper(fig=fig,
                                 ax=ax,
                                 vis=vis,
                                 save_path=save_path,
                                 debug=debug,
                                 transparent=False,
                                 closefig=closefig)
def image_concatenate(input_image, target_size=[1600, 2560], grid_size=None, edge_factor=0.99, warning=True, debug=True):
	'''
	concatenate a list of images automatically

	parameters:	
		input_image:			NHWC numpy image, uint8 or float32
		target_size:			a tuple or list or numpy array with 2 elements, for [H, W]
		grid_size:				a tuple or list or numpy array with 2 elements, for [num_rows, num_cols] 
		edge_factor:			controls the margin between images after concatenation; the bigger the factor, the smaller the margin, in [0, 1]

	outputs:
		image_merged: 			HWC uint8 numpy image with size of target_size
	'''
	np_image, _ = safe_batch_image(input_image, warning=warning, debug=debug)
	if debug:
		assert isimsize(target_size), 'the input image size is not correct'
		if grid_size is not None: assert isimsize(grid_size), 'the input grid size is not correct'
		assert isscalar(edge_factor) and edge_factor <= 1 and edge_factor >= 0, 'the edge factor is not correct'

	num_images = np_image.shape[0]
	if grid_size is None:
		num_rows = int(np.sqrt(num_images))
		num_cols = int(np.ceil(num_images * 1.0 / num_rows))
	else:
		num_rows, num_cols = np.ceil(grid_size[0]), np.ceil(grid_size[1])

	window_height, window_width = target_size[0], target_size[1]
	grid_height = int(window_height / num_rows)
	grid_width  = int(window_width  / num_cols)
	im_height   = int(grid_height   * edge_factor)
	im_width 	= int(grid_width 	 * edge_factor)
	im_channel 	= np_image.shape[-1]

	# concatenate
	image_merged = np.zeros((window_height, window_width, im_channel), dtype='uint8')
	for image_index in range(num_images):
		image_tmp = np_image[image_index, :, :, :]
		image_tmp = image_resize(image_tmp, target_size=(im_height, im_width), warning=warning, debug=debug)

		rows_index = int(np.ceil((image_index + 1.0) / num_cols))				# 1-indexed
		cols_index = int(image_index + 1 - (rows_index - 1) * num_cols)			# 1-indexed
		rows_start = int((rows_index - 1) * grid_height)						# 0-indexed
		rows_end   = int(rows_start + im_height)								# 0-indexed
		cols_start = int((cols_index - 1) * grid_width)							# 0-indexed
		cols_end   = int(cols_start + im_width)									# 0-indexed
		image_merged[rows_start:rows_end, cols_start:cols_end, :] = image_tmp

	return image_merged
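A small illustration of the 1-indexed grid arithmetic used above to place each image into the merged canvas (num_cols is assumed to be 3 here):

import math

num_cols = 3
for image_index in range(7):
    rows_index = int(math.ceil((image_index + 1.0) / num_cols))     # 1-indexed row
    cols_index = image_index + 1 - (rows_index - 1) * num_cols      # 1-indexed column
    print(image_index, (rows_index, cols_index))                    # 0 -> (1, 1), 3 -> (2, 1), 6 -> (3, 1)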
Example #15
def safe_npdata(input_data, warning=True, debug=True):
    '''
	copy a list, a scalar or a numpy array into a new numpy buffer for use

	parameters:
		input_data:		a list, a scalar or numpy data

	outputs:
		np_data:		a copy of numpy data
	'''
    if islist(input_data): np_data = np.array(input_data)
    elif isscalar(input_data): np_data = np.array(input_data).reshape((1, ))
    elif isnparray(input_data): np_data = input_data.copy()
    else: assert False, 'only list of data, scalar or numpy data are supported'

    return np_data
def image_rotate(input_image, input_angle, warning=True, debug=True):
	'''
	rotate the image given an angle in degree (e.g., 90). The rotation is counter-clockwise
	
	parameters:
		input_image:		a pil or numpy image
		input_angle:		a scalar

	outputs:
		rotated_image:		a numpy uint8 image
	'''	
	if debug: assert isscalar(input_angle), 'the input angle is not a scalar'
	rotation_angle = safe_angle(input_angle, warning=warning, debug=debug)             # ensure the angle is within (-180, 180]

	np_image, _ = safe_image(input_image, warning=warning, debug=debug)
	if isfloatimage(np_image): np_image = (np_image * 255.).astype('uint8')
	pil_image = Image.fromarray(np_image)
	pil_image = pil_image.rotate(rotation_angle, expand=True)
	rotated_image = np.array(pil_image).astype('uint8')

	return rotated_image
Example #17
def safe_angle(input_angle, warning=True, debug=True):
    '''
	ensure the rotation angle is within (-180, 180] in degree

	parameters:
		input_angle:	an angle which is supposed to be in degree

	outputs:
		angle:			an angle in degree within (-180, 180]
	'''
    angle = copy.copy(input_angle)
    if debug:
        assert isscalar(angle), 'the input angle should be a scalar'

    if isnparray(angle): angle = angle[0]  # single numpy scalar value
    while angle > 180:
        angle -= 360
    while angle <= -180:
        angle += 360

    return angle
Example #18
def facial_landmark_evaluation(pred_dict_all, anno_dict, num_pts, error_threshold, normalization_ced=True, normalization_vec=False, covariance=True, display_list=None, debug=True, vis=False, save=True, save_path=None):
	'''
	evaluate the performance of facial landmark detection

	parameters:
		pred_dict_all:	a dictionary for all baseline methods. Each key is the method name and the value is the corresponding prediction dictionary, 
						whose keys are the image paths and values are 2 x N prediction results
		anno_dict: 		a dictionary whose keys are the image paths and values are 2 x N annotation results
		num_pts:		number of points
		vis:			determine if visualizing the pck curve
		save:			determine if saving the visualization results
		save_path:		a directory to save all the results

	visualization:
		1. 2d pck curve (total and point specific) for all points for all methods
		2. point error vector (total and point specific) for all points and for all methods
		3. mean square error

	return:
		metrics_all:	a list of list to have detailed metrics over all methods
		ptswise_mse:	a list of list to have average MSE over all key-points for all methods
	'''
	num_methods = len(pred_dict_all)
	if debug:
		assert isdict(pred_dict_all) and num_methods > 0 and all(isdict(pred_dict) for pred_dict in pred_dict_all.values()), 'predictions result format is not correct'
		assert isdict(anno_dict), 'annotation result format is not correct'
		assert ispositiveinteger(num_pts), 'number of points is not correct'
		assert isscalar(error_threshold), 'error threshold is not correct'
		assert islogical(normalization_ced) and islogical(normalization_vec), 'normalization flag is not correct'
		if display_list is not None: assert len(display_list) == num_methods, 'display list is not correct %d vs %d' % (len(display_list), num_methods)

	num_images = len(pred_dict_all.values()[0])
	if debug:
		assert num_images > 0, 'the predictions are empty'
		assert num_images == len(anno_dict), 'number of images is not equal to number of annotations: %d vs %d' % (num_images, len(anno_dict))
		assert all(num_images == len(pred_dict) for pred_dict in pred_dict_all.values()), 'number of images in results from different methods are not equal'

	# calculate normalized mean error for each single image based on point-to-point Euclidean distance normalized by the bounding box size
	# calculate point error vector for each single image based on error vector normalized by the bounding box size
	normed_mean_error_dict = dict()
	normed_mean_error_pts_specific_dict = dict()
	normed_mean_error_pts_specific_valid_dict = dict()
	pts_error_vec_dict = dict()
	pts_error_vec_pts_specific_dict = dict()
	mse_error_dict_dict = dict()
	for method_name, pred_dict in pred_dict_all.items():
		normed_mean_error_total = np.zeros((num_images, ), dtype='float32')
		normed_mean_error_pts_specific = np.zeros((num_images, num_pts), dtype='float32')
		normed_mean_error_pts_specific_valid = np.zeros((num_images, num_pts), dtype='bool')
		pts_error_vec = np.zeros((num_images, 2), dtype='float32')					
		pts_error_vec_pts_specific = np.zeros((num_images, 2, num_pts), dtype='float32')
		mse_error_dict = dict()
		count = 0
		count_skip_num_images = 0						# it is possible that no annotation exists on some images, then no error should be counted for those images; we count the number of such images
		for image_path, pts_prediction in pred_dict.items():
			_, filename, _ = fileparts(image_path)
			pts_anno = anno_dict[filename]				# 2 x N annotation
			pts_keep_index = range(num_pts)

			# to avoid list object type, do conversion here
			if islist(pts_anno): pts_anno = np.asarray(pts_anno)
			if islist(pts_prediction): pts_prediction = np.asarray(pts_prediction)
			if debug: assert (is2dptsarray(pts_anno) or is2dptsarray_occlusion(pts_anno)) and pts_anno.shape[1] == num_pts, 'shape of annotations is not correct (%d x %d) vs (%d x %d)' % (2, num_pts, pts_anno.shape[0], pts_anno.shape[1])

			# if the annotation has 3 channels (include extra occlusion channel, we keep only the points with annotations)
			# occlusion: -1 -> visible but not annotated, 0 -> invisible and not annotated, 1 -> visible, we keep only visible and annotated points
			if pts_anno.shape[0] == 3:	
				pts_keep_index = np.where(pts_anno[2, :] == 1)[0].tolist()
				if len(pts_keep_index) <= 0:		# if no point is annotated in current image
					count_skip_num_images += 1
					continue
				pts_anno = pts_anno[0:2, pts_keep_index]		
				pts_prediction = pts_prediction[:, pts_keep_index]
			
			# to avoid the point location includes the score or occlusion channel, only take the first two channels here	
			if pts_prediction.shape[0] == 3 or pts_prediction.shape[0] == 4: 
				pts_prediction = pts_prediction[0:2, :]

			num_pts_tmp = len(pts_keep_index)
			if debug:
				assert pts_anno.shape[1] <= num_pts, 'number of points is not correct: %d vs %d' % (pts_anno.shape[1], num_pts)
				assert pts_anno.shape == pts_prediction.shape, 'shape of annotations and predictions are not the same {} vs {}'.format(print_np_shape(pts_anno, debug=debug), print_np_shape(pts_prediction, debug=debug))
				# print 'number of points to keep is %d' % num_pts_tmp

			# calculate bbox for normalization
			if normalization_ced or normalization_vec:
				assert len(pts_keep_index) == num_pts, 'some points are not annotated. Normalization on PCK curve is not allowed.'
				bbox_anno = pts2bbox(pts_anno, debug=debug)							# 1 x 4
				bbox_TLWH = bbox_TLBR2TLWH(bbox_anno, debug=debug)					# 1 x 4
				bbox_size = math.sqrt(bbox_TLWH[0, 2] * bbox_TLWH[0, 3])			# scalar
			
			# calculate normalized error for all points
			normed_mean_error, _ = pts_euclidean(pts_prediction, pts_anno, debug=debug)	# scalar
			if normalization_ced: normed_mean_error /= bbox_size
			normed_mean_error_total[count] = normed_mean_error
			mse_error_dict[image_path] = normed_mean_error

			if normed_mean_error == 0:
				print pts_prediction
				print pts_anno

			# calculate normalized error point specifically
			for pts_index in xrange(num_pts):
				if pts_index in pts_keep_index:			# if current point not annotated in current image, just keep 0
					normed_mean_error_pts_specific_valid[count, pts_index] = True
				else: continue

				pts_index_from_keep_list = pts_keep_index.index(pts_index)
				pts_prediction_tmp = np.reshape(pts_prediction[:, pts_index_from_keep_list], (2, 1))
				pts_anno_tmp = np.reshape(pts_anno[:, pts_index_from_keep_list], (2, 1))
				normed_mean_error_pts_specifc_tmp, _ = pts_euclidean(pts_prediction_tmp, pts_anno_tmp, debug=debug)

				if normalization_ced: normed_mean_error_pts_specifc_tmp /= bbox_size
				normed_mean_error_pts_specific[count, pts_index] = normed_mean_error_pts_specifc_tmp

			# calculate the point error vector
			error_vector = pts_prediction - pts_anno 			# 2 x num_pts_tmp
			if normalization_vec: error_vector /= bbox_size
			pts_error_vec_pts_specific[count, :, pts_keep_index] = np.transpose(error_vector)
			pts_error_vec[count, :] = np.sum(error_vector, axis=1) / num_pts_tmp

			count += 1

		print 'number of skipped images is %d' % count_skip_num_images
		assert count + count_skip_num_images == num_images, 'all cells in the array must be filled %d vs %d' % (count + count_skip_num_images, num_images)
		# print normed_mean_error_total
		# time.sleep(1000)

		# save results to dictionary
		normed_mean_error_dict[method_name] = normed_mean_error_total[:count]
		normed_mean_error_pts_specific_dict[method_name] = normed_mean_error_pts_specific[:count, :]
		normed_mean_error_pts_specific_valid_dict[method_name] = normed_mean_error_pts_specific_valid[:count, :]
		pts_error_vec_dict[method_name] = np.transpose(pts_error_vec[:count, :])												# 2 x num_images
		pts_error_vec_pts_specific_dict[method_name] = pts_error_vec_pts_specific[:count, :, :]
		mse_error_dict_dict[method_name] = mse_error_dict

	# calculate mean value
	if mse:
		mse_value = dict()		# dictionary to record all average MSE for different methods
		mse_dict = dict()		# dictionary to record all point-wise MSE for different keypoints
		for method_name, error_array in normed_mean_error_dict.items():
			mse_value[method_name] = np.mean(error_array)
	else: mse_value = None

	# save mse error list to file for each method
	error_list_savedir = os.path.join(save_path, 'error_list')
	mkdir_if_missing(error_list_savedir)
	for method_name, mse_error_dict in mse_error_dict_dict.items():
		mse_error_list_path = os.path.join(error_list_savedir, 'error_%s.txt' % method_name)
		mse_error_list = open(mse_error_list_path, 'w')
		
		sorted_tuple_list = sorted(mse_error_dict.items(), key=operator.itemgetter(1), reverse=True)
		for tuple_index in range(len(sorted_tuple_list)):
			image_path_tmp = sorted_tuple_list[tuple_index][0]
			mse_error_tmp = sorted_tuple_list[tuple_index][1]
			mse_error_list.write('{:<200} {}\n'.format(image_path_tmp, '%.2f' % mse_error_tmp))
		mse_error_list.close()
		print '\nsave mse error list for %s to %s' % (method_name, mse_error_list_path)

	# visualize the ced (cumulative error distribution curve)
	print('visualizing pck curve....\n')
	pck_savedir = os.path.join(save_path, 'pck')
	mkdir_if_missing(pck_savedir)
	pck_savepath = os.path.join(pck_savedir, 'pck_curve_overall.png')
	table_savedir = os.path.join(save_path, 'metrics')
	mkdir_if_missing(table_savedir)
	table_savepath = os.path.join(table_savedir, 'detailed_metrics_overall.txt')
	_, metrics_all = visualize_ced(normed_mean_error_dict, error_threshold=error_threshold, normalized=normalization_ced, truncated_list=truncated_list, title='2D PCK curve (all %d points)' % num_pts, display_list=display_list, debug=debug, vis=vis, pck_savepath=pck_savepath, table_savepath=table_savepath)
	metrics_title = ['Method Name / Point Index']
	ptswise_mse_table = [[normed_mean_error_pts_specific_dict.keys()[index_tmp]] for index_tmp in xrange(num_methods)]
	for pts_index in xrange(num_pts):
		metrics_title.append(str(pts_index + 1))
		normed_mean_error_dict_tmp = dict()

		for method_name, error_array in normed_mean_error_pts_specific_dict.items():
			normed_mean_error_pts_specific_valid_temp = normed_mean_error_pts_specific_valid_dict[method_name]
			
			# some points in certain images might not be annotated; when calculating the MSE for a specific point, we remove those images to avoid a "false" mean average error
			valid_array_per_pts_per_method = np.where(normed_mean_error_pts_specific_valid_temp[:, pts_index] == True)[0].tolist()
			error_array_per_pts = error_array[:, pts_index]
			error_array_per_pts = error_array_per_pts[valid_array_per_pts_per_method]
			num_image_tmp = len(valid_array_per_pts_per_method)
			# print(num_image_tmp)
			if num_image_tmp == 0: continue
			normed_mean_error_dict_tmp[method_name] = np.reshape(error_array_per_pts, (num_image_tmp, ))
		pck_savepath = os.path.join(pck_savedir, 'pck_curve_pts_%d.png' % (pts_index+1))
		table_savepath = os.path.join(table_savedir, 'detailed_metrics_pts_%d.txt' % (pts_index+1))

		if len(normed_mean_error_dict_tmp) == 0: continue
		metrics_dict, _ = visualize_ced(normed_mean_error_dict_tmp, error_threshold=error_threshold, normalized=normalization_ced, truncated_list=truncated_list, display2terminal=False, title='2D PCK curve for point %d' % (pts_index+1), display_list=display_list, debug=debug, vis=vis, pck_savepath=pck_savepath, table_savepath=table_savepath)
		for method_index in range(num_methods):
			method_name = normed_mean_error_pts_specific_dict.keys()[method_index]
			ptswise_mse_table[method_index].append('%.1f' % metrics_dict[method_name]['MSE'])
	
	# reorder the table
	order_index_list = [display_list.index(method_name_tmp) for method_name_tmp in normed_mean_error_pts_specific_dict.keys()]
	order_index_list = [0] + [order_index_tmp + 1 for order_index_tmp in order_index_list]

	# print table to terminal
	ptswise_mse_table = list_reorder([metrics_title] + ptswise_mse_table, order_index_list, debug=debug)
	table = AsciiTable(ptswise_mse_table)
	print '\nprint point-wise average MSE'
	print table.table
	# save table to file
	ptswise_savepath = os.path.join(table_savedir, 'pointwise_average_MSE.txt')
	table_file = open(ptswise_savepath, 'w')
	table_file.write(table.table)
	table_file.close()
	print '\nsave point-wise average MSE to %s' % ptswise_savepath

	# visualize the error vector map
	# print('visualizing error vector distribution map....\n')
	# error_vec_save_dir = os.path.join(save_path, 'error_vec')
	# mkdir_if_missing(error_vec_save_dir)
	# savepath_tmp = os.path.join(error_vec_save_dir, 'error_vector_distribution_all.png')
	# visualize_pts(pts_error_vec_dict, title='Point Error Vector Distribution (all %d points)' % num_pts, mse=mse, mse_value=mse_value, display_range=display_range, display_list=display_list, xlim=xlim, ylim=ylim, covariance=covariance, debug=debug, vis=vis, save_path=savepath_tmp)
	# for pts_index in xrange(num_pts):
	# 	pts_error_vec_pts_specific_dict_tmp = dict()
	# 	for method_name, error_vec_dict in pts_error_vec_pts_specific_dict.items():
	# 		pts_error_vec_pts_specific_valid = normed_mean_error_pts_specific_valid_dict[method_name]		# get valid flag
	# 		valid_image_index_per_pts = np.where(pts_error_vec_pts_specific_valid[:, pts_index] == True)[0].tolist()		# get images where the points with current index are annotated
	# 		print(len(valid_image_index_per_pts))

	# 		pts_error_vec_pts_specific_dict_tmp[method_name] = np.transpose(error_vec_dict[valid_image_index_per_pts, :, pts_index])		# 2 x num_images
	# 	savepath_tmp = os.path.join(error_vec_save_dir, 'error_vector_distribution_pts_%d.png' % (pts_index+1))
	# 	if mse:
	# 		mse_dict_tmp = visualize_pts(pts_error_vec_pts_specific_dict_tmp, title='Point Error Vector Distribution for Point %d' % (pts_index+1), mse=mse, display_range=display_range, display_list=display_list, xlim=xlim, ylim=ylim, covariance=covariance, debug=debug, vis=vis, save_path=savepath_tmp)
	# 		mse_best = min(mse_dict_tmp.values())
	# 		mse_single = dict()
	# 		mse_single['mse'] = mse_best
	# 		mse_single['num_images'] = len(valid_image_index_per_pts)			# assume number of valid images is equal for all methods
	# 		mse_dict[pts_index] = mse_single
	# 	else:
	# 		visualize_pts(pts_error_vec_pts_specific_dict_tmp, title='Point Error Vector Distribution for Point %d' % (pts_index+1), mse=mse, display_range=display_range, display_list=display_list, xlim=xlim, ylim=ylim, covariance=covariance, debug=debug, vis=vis, save_path=savepath_tmp)

	# save mse to json file for further use
	# if mse: 
	# 	json_path = os.path.join(save_path, 'mse_pts.json')

	# 	# if existing, compare and select the best
	# 	if is_path_exists(json_path):
	# 		with open(json_path, 'r') as file:
	# 			mse_dict_old = json.load(file)
	# 			file.close()

	# 		for pts_index, mse_single in mse_dict_old.items():
	# 			mse_dict_new = mse_dict[int(pts_index)]
	# 			mse_new = mse_dict_new['mse']
	# 			if mse_new < mse_single['mse']:
	# 				mse_single['mse'] = mse_new
	# 			mse_dict_old[pts_index] = mse_single

	# 		with open(json_path, 'w') as file:
	# 			print('overwrite old mse to {}'.format(json_path))
	# 			json.dump(mse_dict_old, file)
	# 			file.close()

	# 	else:
	# 		with open(json_path, 'w') as file:
	# 			print('save mse for all keypoings to {}'.format(json_path))
	# 			json.dump(mse_dict, file)
	# 			file.close()

	print('\ndone!!!!!\n')
	return metrics_all, ptswise_mse_table
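A numpy-only sketch of the per-image normalized mean error used above: the mean point-to-point Euclidean distance divided by the square root of the annotated bounding-box area; the toolbox helpers (pts_euclidean, pts2bbox, bbox_TLBR2TLWH) are not reproduced here and all coordinates are made up:

import numpy as np

pts_pred = np.array([[10.0, 20.0, 30.0],
                     [12.0, 22.0, 32.0]])                  # 2 x N prediction
pts_anno = np.array([[11.0, 20.0, 29.0],
                     [12.0, 23.0, 32.0]])                  # 2 x N annotation

per_pt_error = np.sqrt(np.sum((pts_pred - pts_anno) ** 2, axis=0))    # Euclidean distance per point
width = pts_anno[0].max() - pts_anno[0].min()
height = pts_anno[1].max() - pts_anno[1].min()
bbox_size = np.sqrt(width * height)                                   # sqrt of the bbox area
normed_mean_error = per_pt_error.mean() / bbox_size
print(normed_mean_error)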
Example #19
def visualize_bar_graph(data,
                        title='Bar Graph of Key-Value Pair',
                        xlabel='pixel error',
                        ylabel='keypoint index',
                        label=False,
                        label_list=None,
                        vis=True,
                        save_path=None,
                        debug=True,
                        closefig=True):
    '''
    visualize the bar graph of a data, which can be a dictionary or list of dictionary
    inside each dictionary, the keys (strings) should be the same and are used as the y labels; the values should be scalars
    '''
    if debug:
        assert isstring(title) and isstring(xlabel) and isstring(
            ylabel), 'title/xlabel/ylabel is not correct'
        assert isdict(data) or islistofdict(data), 'input data is not correct'
        if isdict(data):
            assert all(
                isstring(key_tmp)
                for key_tmp in data.keys()), 'the keys are not all strings'
            assert all(
                isscalar(value_tmp)
                for value_tmp in data.values()), 'the values are not all scalars'
        else:
            assert len(data) <= len(
                color_set
            ), 'number of data set is larger than number of color to use'
            keys = sorted(data[0].keys())
            for dict_tmp in data:
                if not (sorted(dict_tmp.keys()) == keys):
                    print(dict_tmp.keys())
                    print(keys)
                    assert False, 'the keys are not equal across different input set'
                assert all(isstring(key_tmp) for key_tmp in
                           dict_tmp.keys()), 'the keys are not all strings'
                assert all(
                    isscalar(value_tmp) for value_tmp in
                    dict_tmp.values()), 'the values are not all scalars'

    # convert dictionary to DataFrame
    data_new = dict()
    if isdict(data):
        key_list = list(data.keys())
        sorted_index = sorted(range(len(key_list)), key=lambda k: key_list[k])
        data_new['names'] = (np.asarray(key_list)[sorted_index]).tolist()
        data_new['values'] = (np.asarray(list(data.values()))[sorted_index]).tolist()
    else:
        key_list = list(data[0].keys())
        sorted_index = sorted(range(len(key_list)), key=lambda k: key_list[k])
        data_new['names'] = (np.asarray(key_list)[sorted_index]).tolist()
        num_sets = len(data)
        for set_index in range(num_sets):
            data_new['value_%03d' % set_index] = (np.asarray(
                list(data[set_index].values()))[sorted_index]).tolist()
    dataframe = DataFrame(data_new)

    # plot
    width = 2000
    height = 2000
    alpha = 0.5
    figsize = width / float(dpi), height / float(dpi)
    fig = plt.figure(figsize=figsize)
    sns.set(style='whitegrid')
    # fig, ax = get_fig_ax_helper(fig=None, ax=None)
    if isdict(data):
        g = sns.barplot(x='values',
                        y='names',
                        data=dataframe,
                        label='data',
                        color='b')
        plt.legend(ncol=1, loc='lower right', frameon=True, fontsize=5)
    else:
        num_sets = len(data)
        for set_index in range(num_sets):
            if set_index == 0:
                sns.set_color_codes('pastel')
            else:
                sns.set_color_codes('muted')

            if label:
                sns.barplot(x='value_%03d' % set_index,
                            y='names',
                            data=dataframe,
                            label=label_list[set_index],
                            color=color_set[set_index],
                            alpha=alpha)
            else:
                sns.barplot(x='value_%03d' % set_index,
                            y='names',
                            data=dataframe,
                            color=color_set[set_index],
                            alpha=alpha)
        plt.legend(ncol=len(data), loc='lower right', frameon=True, fontsize=5)

    sns.despine(left=True, bottom=True)
    plt.title(title, fontsize=20)
    plt.xlim([0, 50])
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    num_yticks = len(data_new['names'])
    adaptive_fontsize = -0.0555556 * num_yticks + 15.111
    plt.yticks(fontsize=adaptive_fontsize)

    return save_vis_close_helper(fig=fig,
                                 vis=vis,
                                 save_path=save_path,
                                 debug=debug,
                                 closefig=closefig)
Example #20
def visualize_covariance_ellipse(covariance,
                                 center,
                                 conf=None,
                                 std=None,
                                 fig=None,
                                 ax=None,
                                 debug=True,
                                 **kwargs):
    """
    Plots a confidence error ellipse based on the specified covariance
    matrix (`covariance`). Additional keyword arguments are passed on to the 
    ellipse patch artist.

    Parameters
        covariance      : The 2x2 covariance matrix to base the ellipse on
        center          : The location of the center of the ellipse. Expects a 2-element sequence of [x0, y0].
        conf            : a floating number between [0, 1]
        std             : The radius of the ellipse in numbers of standard deviations. Defaults to 2 standard deviations.
        ax              : The axis that the ellipse will be plotted on. Defaults to the current axis.
        Additional keyword arguments are passed on to the ellipse patch.

    Returns
        A covariance ellipse
    """
    if debug:
        if conf is not None:
            assert isscalar(
                conf
            ) and conf >= 0 and conf <= 1, 'the confidence is not in a good range'
        if std is not None:
            assert ispositiveinteger(
                std
            ), 'the number of standard deviation should be a positive integer'
    fig, ax = get_fig_ax_helper(fig=fig, ax=ax)

    def eigsorted(covariance):
        vals, vecs = np.linalg.eigh(covariance)
        # order = vals.argsort()[::-1]
        # return vals[order], vecs[:,order]
        return vals, vecs

    if conf is not None: conf = np.asarray(conf)
    elif std is not None: conf = 2 * norm.cdf(std) - 1
    else: raise ValueError('One of `conf` and `std` should be specified.')
    r2 = chi2.ppf(conf, 2)
    vals, vecs = eigsorted(covariance)
    theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))
    # theta = np.degrees(np.arctan2(*vecs[::-1, 0]))
    # Width and height are "full" widths, not radius
    # width, height = 2 * std * np.sqrt(vals)
    width, height = 2 * np.sqrt(np.sqrt(vals) * r2)
    # width, height = 2 * np.sqrt(vals[:, None] * r2)
    ellipse = Ellipse(xy=center,
                      width=width,
                      height=height,
                      angle=theta,
                      **kwargs)
    ellipse.set_facecolor('none')

    ax.add_artist(ellipse)
    return ellipse
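A small sketch of the ellipse parameters for a given confidence level, using the conventional relation (full axis length = 2 * sqrt(eigenvalue * r2)), which differs slightly from the scaling in the function above; the covariance values are made up:

import numpy as np
from scipy.stats import chi2

covariance = np.array([[4.0, 1.0],
                       [1.0, 2.0]])
conf = 0.95
r2 = chi2.ppf(conf, 2)                          # squared Mahalanobis radius for 95% confidence, 2 dof
vals, vecs = np.linalg.eigh(covariance)         # eigenvalues and eigenvectors of the covariance
theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))
width, height = 2 * np.sqrt(vals * r2)          # full axis lengths of the confidence ellipse
print(width, height, theta)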
Example #21
def image_resize(input_image,
                 resize_factor=None,
                 target_size=None,
                 interp='bicubic',
                 warning=True,
                 debug=True):
    '''
	resize the image given a resize factor (e.g., 0.25), or given a target size (height, width)
	e.g., the input image has 600 x 800:
		1. given a resize factor of 0.25 -> results in an image with 150 x 200
		2. given a target size of (300, 400) -> results in an image with 300 x 400
	note that:
		resize_factor and target_size cannot exist at the same time

	parameters:
		input_image:		a pil or numpy image
		resize_factor:		a scalar
		target_size:		a list, tuple or numpy array with 2 elements, representing height and width
		interp:				interpolation methods: bicubic or bilinear

	outputs:
		resized_image:		a numpy uint8 image
	'''
    np_image, _ = safe_image(input_image, warning=warning, debug=debug)
    if isfloatimage(np_image): np_image = (np_image * 255.).astype('uint8')

    if debug:
        assert interp in ['bicubic', 'bilinear'
                          ], 'the interpolation method is not correct'
        assert (resize_factor is not None and target_size is None) or (
            resize_factor is None and target_size
            is not None), 'resize_factor and target_size cannot co-exist'

    if target_size is not None:
        if debug:
            assert isimsize(
                target_size), 'the input target size is not correct'
        target_width, target_height = int(round(target_size[1])), int(
            round(target_size[0]))
        if target_width == np_image.shape[
                1] and target_height == np_image.shape[0]:
            return np_image
    elif resize_factor is not None:
        if debug:
            assert isscalar(
                resize_factor
            ) and resize_factor > 0, 'the resize factor is not a scalar'
        if resize_factor == 1: return np_image  # no resizing
        height, width = np_image.shape[:2]
        target_width, target_height = int(round(resize_factor * width)), int(
            round(resize_factor * height))
    else:
        assert False, 'the target_size and resize_factor do not exist'

    if interp == 'bicubic':
        resized_image = cv2.resize(np_image, (target_width, target_height),
                                   interpolation=cv2.INTER_CUBIC)
    elif interp == 'bilinear':
        resized_image = cv2.resize(np_image, (target_width, target_height),
                                   interpolation=cv2.INTER_LINEAR)
    else:
        assert False, 'interpolation is wrong'

    return resized_image
Example #22
def generate_gaussian_heatmap(input_pts,
                              image_size,
                              std,
                              warning=True,
                              debug=True):
    '''
	generate a heatmap based on the input points array, create a 2-D gaussian with given std around each points provided
	the mask is generated by the occlusion from the point array: only occlusion with -1 will be masked out
	    0 -> invisible points without location
	    1 -> visible points with location
	    -1 -> visible points without location, masked

	parameters:
	    input_pts:          a list of 3 elements, a listoflist of 3 elements (e.g., [[1,2], [5,6], [0, 1]]),
	                        or a numpy array with shape (3, N) or (3, )
	    image_size:         a tuple, list or numpy array with 2 elements, representing (height, width)
	    std:                the standard deviation used for gaussian distribution

	outputs:
	    masked_heatmap:         numpy float32 multichannel numpy array, (height, width, num_pts + 1)
	    mask_valid:             numpy float32 multichannel numpy array, (1, 1, num_pts + 1)
	    mask_visible:           numpy float32 multichannel numpy array, (1, 1, num_pts + 1)
	'''
    pts_array = safe_2dptsarray_occlusion(input_pts,
                                          warning=warning,
                                          debug=debug)
    if debug:
        assert isscalar(std), 'the standard deviation should be a scalar'
        assert isimsize(image_size), 'the image size is not correct'
    height, width = image_size[0], image_size[1]
    num_pts, threshold = pts_array.shape[1], 0.01
    heatmap = np.fromfunction( lambda y, x, pts_id : ((x - pts_array[0, pts_id])**2 \
                                                    + (y - pts_array[1, pts_id])**2) \
                                                    / -2.0 / std / std, (height, width, num_pts), dtype=int)
    heatmap = np.exp(heatmap)

    valid = np.logical_or(
        pts_array[2, :] == 0,
        pts_array[2, :] == 1)  # mask out invalid points with -1 in the third
    visible = pts_array[2, :] == 1  # mask out invalid and occuluded points
    mask_valid = np.ones((1, 1, num_pts + 1), dtype='float32')
    mask_valid[0, 0, :num_pts] = valid  # never mask out the background channel
    mask_visible = np.ones((1, 1, num_pts + 1), dtype='float32')
    mask_visible[
        0, 0, :num_pts] = visible  # never mask out the background channel

    # mask out the invalid channel
    heatmap[heatmap < threshold] = 0  # ceiling and flooring
    heatmap[heatmap > 1] = 1
    masked_heatmap = heatmap * mask_valid[:, :, :
                                          num_pts]  # (height, width, num_pts)

    background_label = 1 - np.amax(
        masked_heatmap,
        axis=2)  # (height, width), maximize along the channel axis
    background_label[background_label < 0] = 0  # (height, width, 1)
    masked_heatmap = np.concatenate(
        (masked_heatmap, np.expand_dims(background_label, axis=2)),
        axis=2).astype('float32')

    return masked_heatmap, mask_valid, mask_visible
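A minimal numpy sketch of the per-point Gaussian used above, for a single keypoint (image size, std and the keypoint location are arbitrary):

import numpy as np

height, width, std = 64, 64, 4.0
x0, y0 = 20.0, 30.0                                        # keypoint location
heatmap = np.fromfunction(
    lambda y, x: np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2.0 * std * std)),
    (height, width), dtype=float)
heatmap[heatmap < 0.01] = 0                                # floor small responses, as above
print(heatmap.max(), heatmap[30, 20])                      # peak value 1.0 at (x0, y0)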