def mkdir_if_missing(pathname, debug=True):
    '''
    create the directory for the given path when it does not exist yet;
    a missing parent location is handled first through a recursive call

    parameters:
        pathname:   a path, normalized with safepath before use
        debug:      when True, assert that the path exists or is creatable
    '''
    pathname = safepath(pathname)
    if debug:
        path_ok = is_path_exists_or_creatable(pathname)
        assert path_ok, 'input path is not valid or creatable: %s' % pathname

    # first element returned by fileparts is the containing location
    parent_dir, _, _ = fileparts(pathname)
    if not is_path_exists(parent_dir):
        mkdir_if_missing(parent_dir)

    # only paths that isfolder accepts are created; anything else is left alone
    if not isfolder(pathname):
        return
    if not is_path_exists(pathname):
        os.mkdir(pathname)
Exemple #2
0
def load_image(src_path, resize_factor=None, target_size=None, input_angle=0, gray=False, warning=True, debug=True):
    '''
    load an image from given path, with preprocessing of rotating and resizing

    parameters:
        src_path:           path of the image file, normalized with safe_path
        resize_factor:      a scalar, defaults to 1.0 when neither resize_factor nor target_size is given
        target_size:        a list or tuple or numpy array with 2 elements, representing height and width
        input_angle:        a scalar, counterclockwise rotation in degree
        gray:               if True the image is converted to PIL mode 'L', otherwise to 'RGB'
        warning, debug:     forwarded to safe_path, image_rotate and image_resize; debug also enables the path check

    output:
        np_image:           the value returned by image_resize (documented upstream as an uint8 numpy image)

    raises:
        IOError:            re-raised when the RGB conversion fails, after printing the offending path
    '''
    src_path = safe_path(src_path, warning=warning, debug=debug)
    if debug: assert is_path_exists(src_path), 'image path is not correct at %s' % src_path
    if resize_factor is None and target_size is None: resize_factor = 1.0           # default not to have resizing

    # open the path as a file object so the handle is closed deterministically
    with open(src_path, 'rb') as f:
        with Image.open(f) as img:
            if gray:
                img = img.convert('L')
            else:
                try:
                    img = img.convert('RGB')
                except IOError:
                    # report which file failed, then propagate the real decoding error
                    # (previously an undefined name was evaluated here, masking it with a NameError)
                    print(src_path)
                    raise
            np_image = image_rotate(img, input_angle=input_angle, warning=warning, debug=debug)
            np_image = image_resize(np_image, resize_factor=resize_factor, target_size=target_size, warning=warning, debug=debug)
    return np_image
Exemple #3
0
def extract_images_from_video_opencv(video_file, save_dir, debug=True):
	'''
	extract a list of images from a video file using opencv package
	Note that if the VideoCapture does not work, uninstall python-opencv and reinstall the newest version

	Parameters:
		video_file:		a file path to a video file
		save_dir:		a folder to save the images extracted from the video
		debug:			boolean, debug mode to check that the video file exists

	Returns:
		None; every frame is handed to visualize_image with a save path of
		save_dir/image%05d.png, numbered from 0
	'''
	if debug: assert is_path_exists(video_file), 'the input video file does not exist'
	mkdir_if_missing(save_dir)
	cap = cv2.VideoCapture(video_file)
	frame_id = 0

	# release the capture handle even when saving one of the frames raises
	try:
		while True:
			ret, frame = cap.read()
			if not ret: break			# no more frames (or the video could not be read)
			save_path = os.path.join(save_dir, 'image%05d.png' % frame_id)
			visualize_image(frame, bgr2rgb=True, save_path=save_path)
			frame_id += 1
			print('processing frame %d' % frame_id)
	finally:
		cap.release()

	return
Exemple #4
0
def extract_images_from_video_ffmpeg(video_file,
                                     save_dir,
                                     format='frame%06d.png',
                                     startnum=0,
                                     verbose=True,
                                     debug=True):
    '''
    extract a list of images from a video file using system ffmpeg library

    Parameters:
        video_file:     a file path to a video file
        save_dir:       a folder to save the images extracted from the video
        format:         string format for the extracted output images
        startnum:       integer passed to ffmpeg as -start_number
        verbose:        boolean, when False ffmpeg is run with -loglevel panic
        debug:          boolean, debug mode to check that the video file exists

    Returns:
        None; the exit status of ffmpeg is not inspected
    '''
    import subprocess  # local import so the block does not rely on the file header

    if debug:
        assert is_path_exists(
            video_file), 'the input video file does not exist'
    mkdir_if_missing(save_dir)

    # build an argument list instead of a shell string: paths containing spaces or
    # shell metacharacters are passed to ffmpeg verbatim and cannot inject commands
    command = ['ffmpeg']
    if not verbose:
        command += ['-loglevel', 'panic']
    command += [
        '-i', '%s' % video_file,
        '-start_number', '%d' % startnum,
        '%s/%s' % (save_dir, format),
    ]

    try:
        subprocess.call(command)
    except OSError as error:
        # os.system never raised when ffmpeg was missing; stay best-effort but say why
        print('failed to run ffmpeg: %s' % error)

    return
Exemple #5
0
def extract_images_from_video_ffmpeg(video_file, save_dir, format='frame%06d.png', debug=True):
	'''
	extract the frames of a video into save_dir by running the system ffmpeg

	Parameters:
		video_file:		a file path to a video file
		save_dir:		a folder for the extracted images, created through mkdir_if_missing
		format:			ffmpeg output file name pattern placed under save_dir
		debug:			boolean, debug mode to check that the video file exists

	Returns:
		None; the exit status of ffmpeg is not inspected
	'''
	import subprocess			# local import so the block does not rely on the file header

	if debug: assert is_path_exists(video_file), 'the input video file does not exist'
	mkdir_if_missing(save_dir)

	# argument list instead of a shell string: paths with spaces or shell
	# metacharacters reach ffmpeg verbatim and cannot inject commands
	command = ['ffmpeg', '-i', '%s' % video_file, '%s/%s' % (save_dir, format)]
	try:
		subprocess.call(command)
	except OSError as error:
		# os.system never raised when ffmpeg was missing; stay best-effort but say why
		print('failed to run ffmpeg: %s' % error)
Exemple #6
0
def mkdir_if_missing(input_path, warning=True, debug=True):
    '''
    create a directory if not existing

    the containing location reported by fileparts is created first (recursively)
    when it is missing; the input itself is only created when isfolder accepts it,
    so for other inputs just the containing location is prepared

    parameters:
        input_path:     a string path
        warning:        forwarded to safe_path
        debug:          forwarded to safe_path; also enables the creatable check
    '''
    good_path = safe_path(input_path, warning=warning, debug=debug)
    if debug:
        creatable = is_path_exists_or_creatable(good_path)
        assert creatable, 'input path is not valid or creatable: %s' % good_path

    parent_dir, _, _ = fileparts(good_path)
    parent_missing = not is_path_exists(parent_dir)
    if parent_missing:
        # the recursive call runs with the default warning/debug flags
        mkdir_if_missing(parent_dir)

    if not isfolder(good_path):
        return
    if is_path_exists(good_path):
        return
    os.mkdir(good_path)
def load_image(src_path,
               resize_factor=1.0,
               rotate=0,
               mode='numpy',
               debug=True):
    '''
    load an image from given path, convert it to RGB, then rotate and resize it

    parameters:
        src_path:           path of the image file, normalized with safepath
        resize_factor:      a positive scalar scaling both sides (>1 enlarge),
                            or a 2-element (width, height) target size
        rotate:             counterclockwise rotation in degree, applied with expand=True
        mode:               numpy or pil, specify the format of returned image
        debug:              when True, validate the path, the mode and the resize factor

    output:
        img:                an rgb image (numpy array when mode is 'numpy', otherwise a PIL image)
    '''
    src_path = safepath(src_path)

    if debug:
        # the message used to say "txt path", copied from a text loader
        assert is_path_exists(
            src_path), 'image path is not correct at %s' % src_path
        assert mode == 'numpy' or mode == 'pil', 'the input mode for returned image is not correct'

        # check the scalar and the sequence form separately: a non-positive scalar
        # previously fell through to len() and raised TypeError instead of this message
        resize_message = 'the resize factor is not correct: {}'.format(
            resize_factor)
        if isscalar(resize_factor):
            assert resize_factor > 0, resize_message
        else:
            assert len(resize_factor) == 2, resize_message

    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(src_path, 'rb') as f:
        with Image.open(f) as img:
            img = img.convert('RGB')

            # rotation
            if rotate != 0:
                img = img.rotate(rotate, expand=True)

            # scaling
            if isscalar(resize_factor):
                width, height = img.size
                img = img.resize(size=(int(width * resize_factor),
                                       int(height * resize_factor)),
                                 resample=Image.BILINEAR)
            elif len(resize_factor) == 2:
                resize_width, resize_height = int(resize_factor[0]), int(
                    resize_factor[1])
                img = img.resize(size=(resize_width, resize_height),
                                 resample=Image.BILINEAR)
            else:
                assert False, 'the resize factor is neither a scalar nor a (width, height)'

            # formatting
            if mode == 'numpy':
                img = np.array(img)

    return img
def load_2dmatrix_from_file(src_path,
                            delimiter=' ',
                            dtype='float32',
                            debug=True):
    '''
    load a numeric matrix from a delimited text file

    parameters:
        src_path:       path of the text file, normalized with safepath
        delimiter:      column separator handed to np.loadtxt
        dtype:          element type handed to np.loadtxt
        debug:          when True, assert that the file exists

    output:
        data:           the loaded array, with single-length axes squeezed
                        exactly as np.loadtxt does by default
        nrows:          number of rows read from the file
    '''
    src_path = safepath(src_path)
    if debug:
        assert is_path_exists(
            src_path), 'txt path is not correct at %s' % src_path

    # ndmin=2 keeps the row axis, so the row count is right for a single-row file
    # (which used to report its column count) and for a single-value file
    # (which used to raise IndexError on the 0-d result)
    matrix = np.loadtxt(src_path, delimiter=delimiter, dtype=dtype, ndmin=2)
    nrows = matrix.shape[0]

    # squeeze back so callers receive the same array shape as before
    data = np.squeeze(matrix)
    return data, nrows
Exemple #9
0
def load_txt_file(file_path, debug=True):
    '''
    load the lines of a text file

    parameters:
        file_path:      path of the text file, normalized with safe_path
        debug:          when True, assert that the file exists

    output:
        data:           a list with one string per line, line endings removed
        num_lines:      number of lines in data
    '''
    file_path = safe_path(file_path)
    if debug:
        assert is_path_exists(
            file_path), 'text file is not existing at path: %s!' % file_path

    # the with-block closes the handle; no explicit close() is needed afterwards
    with open(file_path, 'r') as txt_file:
        data = txt_file.read().splitlines()
    return data, len(data)
Exemple #10
0
def load_hdf5_file(hdf5_file, dataname, debug=True):
    '''
    load the requested datasets from a single hdf5 file

    parameters:
        hdf5_file:      path to the hdf5 file
        dataname:       a list of dataset names (strings) to read
        debug:          when True, validate the path and the dataset names

    output:
        datadict:       a dictionary mapping each dataset name to a numpy array copy of its content
    '''
    if debug:
        assert is_path_exists(hdf5_file) and isfile(
            hdf5_file), 'input hdf5 path does not exist: %s' % hdf5_file
        assert islist(dataname), 'dataset queried is not correct'
        assert all(
            isstring(dataset_tmp)
            for dataset_tmp in dataname), 'dataset queried is not correct'

    # np.array copies each dataset into memory, so the file can be closed on exit
    # (the handle used to stay open for the lifetime of the process)
    datadict = dict()
    with h5py.File(hdf5_file, 'r') as hdf5:
        for dataset in dataname:
            datadict[dataset] = np.array(hdf5[dataset])
    return datadict
Exemple #11
0
	def load(self, save_path):
		'''
		restore the point meta data from a checkpoint written with torch

		parameters:
			save_path:		path of the checkpoint file handed to torch.load

		every listed checkpoint entry is copied onto the attribute of the same name
		'''
		if self.debug:
			assert is_path_exists(save_path), '{} is not a file'.format(save_path)

		print('load point meta data from {}'.format(save_path))
		# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source
		checkpoint = torch.load(save_path)

		# same entries, in the same order, as the attributes restored one by one before
		restored_fields = (
			'pts_root',
			'pts_forward',
			'pts_backward',
			'pts_anno',
			'image_prev_path',
			'image_next_path',
			'length',
			'pts_valid_index',
			'index_dict',
			'key_list',
		)
		for field_name in restored_fields:
			setattr(self, field_name, checkpoint[field_name])
Exemple #12
0
def extract_images_from_video_opencv(video_file, save_dir, debug=True):
	'''
	read a video frame by frame with opencv and hand every frame to visualize_image
	with a save path of save_dir/image%05d.png (numbered from 0)

	if the VideoCapture does not work, uninstall python-opencv and reinstall the newest version

	Parameters:
		video_file:		a file path to a video file
		save_dir:		a folder for the extracted images, created through mkdir_if_missing
		debug:			boolean, debug mode to check that the video file exists
	'''
	if debug: assert is_path_exists(video_file), 'the input video file does not exist'
	mkdir_if_missing(save_dir)
	cap = cv2.VideoCapture(video_file)
	frame_id = 0

	try:
		while True:
			ret, frame = cap.read()
			if not ret: break			# end of the stream, or nothing could be decoded
			save_path = os.path.join(save_dir, 'image%05d.png' % frame_id)
			visualize_image(frame, bgr2rgb=True, save_path=save_path)
			frame_id += 1
			print('processing frame %d' % frame_id)
	finally:
		# the handle used to leak when visualize_image raised in the middle of the video
		cap.release()
def load_video_opencv(filename, debug=True, num_frames=None):
    '''
    load the frames of a video with opencv

    if the VideoCapture does not work, uninstall python-opencv and reinstall the newest version

    parameters:
        filename:       a file path to a video file
        debug:          when True, assert that the video file exists
        num_frames:     stop after this many frames have been read; None (default) reads the whole video

    output:
        frames:         a list with one entry per frame, each the result of
                        functional.to_tensor applied to the frame after image_bgr2rgb
    '''
    if debug: assert is_path_exists(filename), 'the input video file does not exist'
    cap = cv2.VideoCapture(filename)
    frame_id = 0
    frames = []

    # num_frames used to be read without ever being defined in this function;
    # it is now an optional argument, and the capture is always released
    try:
        while True:
            ret, image = cap.read()
            if not ret: break

            image = image_bgr2rgb(image)
            image = functional.to_tensor(image)
            frames.append(image)
            frame_id += 1
            if num_frames is not None and frame_id == num_frames: break
    finally:
        cap.release()

    return frames
Exemple #14
0
def rand_load_hdf5_from_folder(hdf5_src, dataname, debug=True):
    '''
    randomly load a single hdf5 file from a hdf5 folder

    parameters:
        hdf5_src:       folder searched for files with the .hdf5 extension
        dataname:       a list of dataset names (strings) to read
        debug:          when True, validate the folder and the dataset names

    output:
        datadict:       a dictionary mapping each dataset name to a numpy array copy of its content

    raises:
        ValueError:     when the folder contains no .hdf5 file
    '''
    if debug:
        assert is_path_exists(hdf5_src) and isfolder(
            hdf5_src), 'input hdf5 path does not exist: %s' % hdf5_src
        assert islist(dataname), 'dataset queried is not correct'
        assert all(
            isstring(dataset_tmp)
            for dataset_tmp in dataname), 'dataset queried is not correct'

    hdf5list, num_hdf5_files = load_list_from_folder(folder_path=hdf5_src,
                                                     ext_filter='.hdf5')
    if num_hdf5_files <= 0:
        # random.randrange(0, 0) raises the same exception type with an opaque message
        raise ValueError('no .hdf5 file found in folder: %s' % hdf5_src)

    check_index = random.randrange(0, num_hdf5_files)
    hdf5_path_sample = hdf5list[check_index]

    # np.array copies each dataset into memory, so the file can be closed on exit
    datadict = dict()
    with h5py.File(hdf5_path_sample, 'r') as hdf5_file:
        for dataset in dataname:
            datadict[dataset] = np.array(hdf5_file[dataset])
    return datadict
Exemple #15
0
def facial_landmark_evaluation(pred_dict_all, anno_dict, num_pts, error_threshold, normalization_ced=True, normalization_vec=False, covariance=True, display_list=None, debug=True, vis=False, save=True, save_path=None):
	'''
	evaluate the performance of facial landmark detection

	parameter:
		pred_dict_all:		a dictionary for all baseline methods. Each key is the method name and the value is the corresponding prediction dictionary,
							whose keys are the image paths and values are 2 x N prediction results (3 or 4 rows are accepted, only the first two are used)
		anno_dict: 			a dictionary of annotations, looked up with the file name part (fileparts) of each predicted image path;
							values are 2 x N, or 3 x N where the third row is an occlusion flag
		num_pts:			number of points
		error_threshold:	a scalar forwarded to visualize_ced
		normalization_ced:	if True, the point-to-point errors are divided by the bounding box size of the annotation
		normalization_vec:	if True, the error vectors are divided by the bounding box size of the annotation
		covariance:			forwarded to visualize_pts
		display_list:		list of method names, one per method; used to order the point-wise table and forwarded to the visualization helpers
		debug:				if True, check the format of the inputs
		vis:				determine if visualizing the pck curve
		save:				determine if saving the visualization results (not read anywhere in this function body)
		save_path:			a directory to save all the results

	visualization:
		1. 2d pck curve (total and point specific) for all points for all methods
		2. point error vector (total and point specific) for all points and for all methods
		3. mean square error

	return:
		metrics_all:		second value returned by visualize_ced for the overall curve (documented upstream as a list of list with detailed metrics over all methods)
		ptswise_mse_table:	a list of list (title row plus one row per method) with the 'MSE' metric of every key-point, reordered with list_reorder

	NOTE(review): this function is written for Python 2 (print statements, xrange, indexing of dict.keys()/dict.values()).
	NOTE(review): the names mse, truncated_list, display_range, xlim and ylim are read below but are neither parameters nor assigned
	in this function - presumably module-level names; confirm, otherwise they raise NameError.
	NOTE(review): save_path and display_list default to None but are used unconditionally (os.path.join, display_list.index) - confirm callers always pass them.
	'''
	num_methods = len(pred_dict_all)
	if debug:
		assert isdict(pred_dict_all) and num_methods > 0 and all(isdict(pred_dict) for pred_dict in pred_dict_all.values()), 'predictions result format is not correct'
		assert isdict(anno_dict), 'annotation result format is not correct'
		assert ispositiveinteger(num_pts), 'number of points is not correct'
		assert isscalar(error_threshold), 'error threshold is not correct'
		assert islogical(normalization_ced) and islogical(normalization_vec), 'normalization flag is not correct'
		if display_list is not None:
			assert len(display_list) == num_methods, 'display list is not correct %d vs %d' % (len(display_list), num_methods)

	# number of images is taken from the first method (Python 2: dict.values() is a list)
	num_images = len(pred_dict_all.values()[0])
	if debug:
		assert num_images > 0, 'the predictions are empty'
		assert num_images == len(anno_dict), 'number of images is not equal to number of annotations: %d vs %d' % (num_images, len(anno_dict))
		assert all(num_images == len(pred_dict) for pred_dict in pred_dict_all.values()), 'number of images in results from different methods are not equal'

	# calculate normalized mean error for each single image based on point-to-point Euclidean distance normalized by the bounding box size
	# calculate point error vector for each single image based on error vector normalized by the bounding box size
	# all dictionaries below are keyed by method name
	normed_mean_error_dict = dict()
	normed_mean_error_pts_specific_dict = dict()
	normed_mean_error_pts_specific_valid_dict = dict()
	pts_error_vec_dict = dict()
	pts_error_vec_pts_specific_dict = dict()
	mse_error_dict_dict = dict()
	for method_name, pred_dict in pred_dict_all.items():
		# per-method buffers, one row per image; rows of skipped images are trimmed with [:count] after the loop
		normed_mean_error_total = np.zeros((num_images, ), dtype='float32')
		normed_mean_error_pts_specific = np.zeros((num_images, num_pts), dtype='float32')
		normed_mean_error_pts_specific_valid = np.zeros((num_images, num_pts), dtype='bool')
		pts_error_vec = np.zeros((num_images, 2), dtype='float32')					
		pts_error_vec_pts_specific = np.zeros((num_images, 2, num_pts), dtype='float32')
		mse_error_dict = dict()
		count = 0
		count_skip_num_images = 0						# it's possible that no annotation exists on some images, then no error should be counted for those images, we count the number of those images
		for image_path, pts_prediction in pred_dict.items():
			_, filename, _ = fileparts(image_path)
			pts_anno = anno_dict[filename]				# 2 x N annotation
			pts_keep_index = range(num_pts)

			# to avoid list object type, do conversion here
			if islist(pts_anno):
				pts_anno = np.asarray(pts_anno)
			if islist(pts_prediction):
				pts_prediction = np.asarray(pts_prediction)
			if debug:
				assert (is2dptsarray(pts_anno) or is2dptsarray_occlusion(pts_anno)) and pts_anno.shape[1] == num_pts, 'shape of annotations is not correct (%d x %d) vs (%d x %d)' % (2, num_pts, pts_anno.shape[0], pts_anno.shape[1])

			# if the annotation has 3 channels (include extra occlusion channel, we keep only the points with annotations)
			# occlusion: -1 -> visible but not annotated, 0 -> invisible and not annotated, 1 -> visible, we keep only visible and annotated points
			if pts_anno.shape[0] == 3:	
				pts_keep_index = np.where(pts_anno[2, :] == 1)[0].tolist()
				if len(pts_keep_index) <= 0:		# if no point is annotated in current image
					count_skip_num_images += 1
					continue
				pts_anno = pts_anno[0:2, pts_keep_index]		
				pts_prediction = pts_prediction[:, pts_keep_index]
			
			# to avoid the point location includes the score or occlusion channel, only take the first two channels here	
			if pts_prediction.shape[0] == 3 or pts_prediction.shape[0] == 4:
				pts_prediction = pts_prediction[0:2, :]

			num_pts_tmp = len(pts_keep_index)
			if debug:
				assert pts_anno.shape[1] <= num_pts, 'number of points is not correct: %d vs %d' % (pts_anno.shape[1], num_pts)
				assert pts_anno.shape == pts_prediction.shape, 'shape of annotations and predictions are not the same {} vs {}'.format(print_np_shape(pts_anno, debug=debug), print_np_shape(pts_prediction, debug=debug))
				# print 'number of points to keep is %d' % num_pts_tmp

			# calculate bbox for normalization
			# normalization requires every point to be annotated, this assertion is not guarded by the debug flag
			if normalization_ced or normalization_vec:
				assert len(pts_keep_index) == num_pts, 'some points are not annotated. Normalization on PCK curve is not allowed.'
				bbox_anno = pts2bbox(pts_anno, debug=debug)							# 1 x 4
				bbox_TLWH = bbox_TLBR2TLWH(bbox_anno, debug=debug)					# 1 x 4
				bbox_size = math.sqrt(bbox_TLWH[0, 2] * bbox_TLWH[0, 3])			# scalar
			
			# calculate normalized error for all points
			normed_mean_error, _ = pts_euclidean(pts_prediction, pts_anno, debug=debug)	# scalar
			if normalization_ced:
				normed_mean_error /= bbox_size
			normed_mean_error_total[count] = normed_mean_error
			mse_error_dict[image_path] = normed_mean_error

			# an error of exactly zero is printed for inspection
			if normed_mean_error == 0:
				print pts_prediction
				print pts_anno

			# calculate normalized error point specifically
			for pts_index in xrange(num_pts):
				if pts_index in pts_keep_index:			# if current point not annotated in current image, just keep 0
					normed_mean_error_pts_specific_valid[count, pts_index] = True
				else:
					continue

				# column of this point inside the (possibly reduced) kept arrays
				pts_index_from_keep_list = pts_keep_index.index(pts_index)
				pts_prediction_tmp = np.reshape(pts_prediction[:, pts_index_from_keep_list], (2, 1))
				pts_anno_tmp = np.reshape(pts_anno[:, pts_index_from_keep_list], (2, 1))
				normed_mean_error_pts_specifc_tmp, _ = pts_euclidean(pts_prediction_tmp, pts_anno_tmp, debug=debug)

				if normalization_ced:
					normed_mean_error_pts_specifc_tmp /= bbox_size
				normed_mean_error_pts_specific[count, pts_index] = normed_mean_error_pts_specifc_tmp

			# calculate the point error vector
			error_vector = pts_prediction - pts_anno 			# 2 x num_pts_tmp
			if normalization_vec:
				error_vector /= bbox_size
			pts_error_vec_pts_specific[count, :, pts_keep_index] = np.transpose(error_vector)
			pts_error_vec[count, :] = np.sum(error_vector, axis=1) / num_pts_tmp

			count += 1

		print 'number of skipped images is %d' % count_skip_num_images
		assert count + count_skip_num_images == num_images, 'all cells in the array must be filled %d vs %d' % (count + count_skip_num_images, num_images)
		# print normed_mean_error_total
		# time.sleep(1000)

		# save results to dictionary
		normed_mean_error_dict[method_name] = normed_mean_error_total[:count]
		normed_mean_error_pts_specific_dict[method_name] = normed_mean_error_pts_specific[:count, :]
		normed_mean_error_pts_specific_valid_dict[method_name] = normed_mean_error_pts_specific_valid[:count, :]
		pts_error_vec_dict[method_name] = np.transpose(pts_error_vec[:count, :])												# 2 x num_images
		pts_error_vec_pts_specific_dict[method_name] = pts_error_vec_pts_specific[:count, :, :]
		mse_error_dict_dict[method_name] = mse_error_dict

	# calculate mean value
	# NOTE(review): mse is not defined in this function - confirm where it comes from
	if mse:
		mse_value = dict()		# dictionary to record all average MSE for different methods
		mse_dict = dict()		# dictionary to record all point-wise MSE for different keypoints
		for method_name, error_array in normed_mean_error_dict.items():
			mse_value[method_name] = np.mean(error_array)
	else:
		mse_value = None

	# save mse error list to file for each method
	error_list_savedir = os.path.join(save_path, 'error_list')
	mkdir_if_missing(error_list_savedir)
	for method_name, mse_error_dict in mse_error_dict_dict.items():
		mse_error_list_path = os.path.join(error_list_savedir, 'error_%s.txt' % method_name)
		mse_error_list = open(mse_error_list_path, 'w')
		
		# one line per image, largest error first
		sorted_tuple_list = sorted(mse_error_dict.items(), key=operator.itemgetter(1), reverse=True)
		for tuple_index in range(len(sorted_tuple_list)):
			image_path_tmp = sorted_tuple_list[tuple_index][0]
			mse_error_tmp = sorted_tuple_list[tuple_index][1]
			mse_error_list.write('{:<200} {}\n'.format(image_path_tmp, '%.2f' % mse_error_tmp))
		mse_error_list.close()
		print '\nsave mse error list for %s to %s' % (method_name, mse_error_list_path)

	# visualize the ced (cumulative error distribution curve)
	print('visualizing pck curve....\n')
	pck_savedir = os.path.join(save_path, 'pck')
	mkdir_if_missing(pck_savedir)
	pck_savepath = os.path.join(pck_savedir, 'pck_curve_overall.png')
	table_savedir = os.path.join(save_path, 'metrics')
	mkdir_if_missing(table_savedir)
	table_savepath = os.path.join(table_savedir, 'detailed_metrics_overall.txt')
	# NOTE(review): truncated_list is not defined in this function - confirm where it comes from
	_, metrics_all = visualize_ced(normed_mean_error_dict, error_threshold=error_threshold, normalized=normalization_ced, truncated_list=truncated_list, title='2D PCK curve (all %d points)' % num_pts, display_list=display_list, debug=debug, vis=vis, pck_savepath=pck_savepath, table_savepath=table_savepath)
	metrics_title = ['Method Name / Point Index']
	# one table row per method, starting with the method name
	ptswise_mse_table = [[normed_mean_error_pts_specific_dict.keys()[index_tmp]] for index_tmp in xrange(num_methods)]
	for pts_index in xrange(num_pts):
		metrics_title.append(str(pts_index + 1))
		normed_mean_error_dict_tmp = dict()

		for method_name, error_array in normed_mean_error_pts_specific_dict.items():
			normed_mean_error_pts_specific_valid_temp = normed_mean_error_pts_specific_valid_dict[method_name]
			
			# Some points at certain images might not be annotated. When calculating MSE for these specific point, we remove those images to avoid "false" mean average error
			valid_array_per_pts_per_method = np.where(normed_mean_error_pts_specific_valid_temp[:, pts_index] == True)[0].tolist()
			error_array_per_pts = error_array[:, pts_index]
			error_array_per_pts = error_array_per_pts[valid_array_per_pts_per_method]
			num_image_tmp = len(valid_array_per_pts_per_method)
			normed_mean_error_dict_tmp[method_name] = np.reshape(error_array_per_pts, (num_image_tmp, ))
		pck_savepath = os.path.join(pck_savedir, 'pck_curve_pts_%d.png' % (pts_index+1))
		table_savepath = os.path.join(table_savedir, 'detailed_metrics_pts_%d.txt' % (pts_index+1))
		metrics_dict, _ = visualize_ced(normed_mean_error_dict_tmp, error_threshold=error_threshold, normalized=normalization_ced, truncated_list=truncated_list, display2terminal=False, title='2D PCK curve for point %d' % (pts_index+1), display_list=display_list, debug=debug, vis=vis, pck_savepath=pck_savepath, table_savepath=table_savepath)
		for method_index in range(num_methods):
			method_name = normed_mean_error_pts_specific_dict.keys()[method_index]
			ptswise_mse_table[method_index].append('%.1f' % metrics_dict[method_name]['MSE'])
	
	# reorder the table
	# index 0 is the title row, the method rows follow the order given by display_list
	order_index_list = [display_list.index(method_name_tmp) for method_name_tmp in normed_mean_error_pts_specific_dict.keys()]
	order_index_list = [0] + [order_index_tmp + 1 for order_index_tmp in order_index_list]

	# print table to terminal
	ptswise_mse_table = list_reorder([metrics_title] + ptswise_mse_table, order_index_list, debug=debug)
	table = AsciiTable(ptswise_mse_table)
	print '\nprint point-wise average MSE'
	print table.table
	# save table to file
	ptswise_savepath = os.path.join(table_savedir, 'pointwise_average_MSE.txt')
	table_file = open(ptswise_savepath, 'w')
	table_file.write(table.table)
	table_file.close()
	print '\nsave point-wise average MSE to %s' % ptswise_savepath

	# visualize the error vector map
	print('visualizing error vector distribution map....\n')
	error_vec_save_dir = os.path.join(save_path, 'error_vec')
	mkdir_if_missing(error_vec_save_dir)
	savepath_tmp = os.path.join(error_vec_save_dir, 'error_vector_distribution_all.png')
	# NOTE(review): display_range, xlim and ylim are not defined in this function - confirm where they come from
	visualize_pts(pts_error_vec_dict, title='Point Error Vector Distribution (all %d points)' % num_pts, mse=mse, mse_value=mse_value, display_range=display_range, display_list=display_list, xlim=xlim, ylim=ylim, covariance=covariance, debug=debug, vis=vis, save_path=savepath_tmp)
	for pts_index in xrange(num_pts):
		pts_error_vec_pts_specific_dict_tmp = dict()
		for method_name, error_vec_dict in pts_error_vec_pts_specific_dict.items():
			pts_error_vec_pts_specific_valid = normed_mean_error_pts_specific_valid_dict[method_name]		# get valid flag
			valid_image_index_per_pts = np.where(pts_error_vec_pts_specific_valid[:, pts_index] == True)[0].tolist()		# get images where the points with current index are annotated

			pts_error_vec_pts_specific_dict_tmp[method_name] = np.transpose(error_vec_dict[valid_image_index_per_pts, :, pts_index])		# 2 x num_images
		savepath_tmp = os.path.join(error_vec_save_dir, 'error_vector_distribution_pts_%d.png' % (pts_index+1))
		if mse:
			mse_dict_tmp = visualize_pts(pts_error_vec_pts_specific_dict_tmp, title='Point Error Vector Distribution for Point %d' % (pts_index+1), mse=mse, display_range=display_range, display_list=display_list, xlim=xlim, ylim=ylim, covariance=covariance, debug=debug, vis=vis, save_path=savepath_tmp)
			# keep the lowest value over all methods for this point
			mse_best = min(mse_dict_tmp.values())
			mse_single = dict()
			mse_single['mse'] = mse_best
			mse_single['num_images'] = len(valid_image_index_per_pts)			# assume number of valid images is equal for all methods
			mse_dict[pts_index] = mse_single
		else:
			visualize_pts(pts_error_vec_pts_specific_dict_tmp, title='Point Error Vector Distribution for Point %d' % (pts_index+1), mse=mse, display_range=display_range, display_list=display_list, xlim=xlim, ylim=ylim, covariance=covariance, debug=debug, vis=vis, save_path=savepath_tmp)

	# save mse to json file for further use
	if mse:
		json_path = os.path.join(save_path, 'mse_pts.json')

		# if existing, compare and select the best
		if is_path_exists(json_path):
			with open(json_path, 'r') as file:
				mse_dict_old = json.load(file)
				file.close()

			# keys loaded from json are strings, hence the int() conversion for the lookup
			for pts_index, mse_single in mse_dict_old.items():
				mse_dict_new = mse_dict[int(pts_index)]
				mse_new = mse_dict_new['mse']
				if mse_new < mse_single['mse']:
					mse_single['mse'] = mse_new
				mse_dict_old[pts_index] = mse_single

			with open(json_path, 'w') as file:
				print('overwrite old mse to {}'.format(json_path))
				json.dump(mse_dict_old, file)
				file.close()

		else:
			with open(json_path, 'w') as file:
				print('save mse for all keypoings to {}'.format(json_path))
				json.dump(mse_dict, file)
				file.close()

	print('\ndone!!!!!\n')

	return metrics_all, ptswise_mse_table
Exemple #16
0
def visualize_nearest_neighbor(featuremap_dict,
                               num_neighbor=5,
                               top_number=5,
                               vis=True,
                               save_csv=False,
                               csv_save_path=None,
                               save_vis=False,
                               save_img=False,
                               save_thumb_name='nearest_neighbor.png',
                               img_src_folder=None,
                               ext_filter='.jpg',
                               nn_save_folder=None,
                               debug=True):
    '''
    visualize nearest neighbor for featuremap from images

    parameter:
        featuremap_dict:    a dictionary contains image path as key, and featuremap as value, the featuremap needs to be numpy array with any shape. No flatten needed
        num_neighbor:       number of neighbor to visualize, the first nearest is itself
        top_number:         number of top to visualize, since there might be tons of featuremap (length of dictionary), we choose the ones with lowest summed distance to their nearest neighbors
        vis:                show the thumbnail figure on screen (only used together with save_vis)
        save_csv:           write the distances and the sorted ids to csv_save_path
        csv_save_path:      path to save .csv file
        save_vis:           render a thumbnail grid of the top features and save it under nn_save_folder
        save_img:           copy the neighbor images of the top features into nn_save_folder
        save_thumb_name:    file name of the thumbnail figure
        img_src_folder:     folder holding the source images, named <key><ext_filter>
        ext_filter:         image extension, a leading dot is added when missing
        nn_save_folder:     save the nearest neighbor images for top featuremap
        debug:              check the format of the inputs and of the intermediate results

    return:
        all_sorted_nearest_id:  a 2d matrix, each row is a feature followed by its nearest neighbor in whole feature dataset, the rows are sorted by the summed distance of all nearest neighbors
        selected_nearest_id:    only top number of sorted nearest id
    '''
    print('processing feature map to nearest neightbor.......')
    if debug:
        assert isdict(featuremap_dict), 'featuremap should be dictionary'
        assert all(
            isnparray(featuremap_tmp) for featuremap_tmp in featuremap_dict.
            values()), 'value of dictionary should be numpy array'
        assert isinteger(
            num_neighbor
        ) and num_neighbor > 1, 'number of neighborhodd is an integer larger than 1'
        if save_csv and csv_save_path is not None:
            assert is_path_exists_or_creatable(
                csv_save_path), 'path to save .csv file is not correct'

        if save_vis or save_img:
            if nn_save_folder is not None:  # save image directly
                assert isstring(ext_filter), 'extension filter is not correct'
                assert is_path_exists(
                    img_src_folder), 'source folder for image is not correct'
                assert all(
                    isstring(path_tmp) for path_tmp in featuremap_dict.keys()
                )  # key should be the path for the image
                assert is_path_exists_or_creatable(
                    nn_save_folder
                ), 'folder to save top visualized images is not correct'
                assert isstring(
                    save_thumb_name), 'name of thumbnail is not correct'

    if ext_filter.find('.') == -1:
        ext_filter = '.%s' % ext_filter

    # flatten the feature map
    nn_feature_dict = dict()
    for key, featuremap_tmp in featuremap_dict.items():
        nn_feature_dict[key] = featuremap_tmp.flatten()
    num_features = len(nn_feature_dict)

    # nearest neighbor
    # materialize the dict views: np.array() on a Python 3 view yields a 0-d object array
    featuremap = np.array(list(nn_feature_dict.values()))
    nearbrs = NearestNeighbors(n_neighbors=num_neighbor,
                               algorithm='ball_tree').fit(featuremap)
    distances, indices = nearbrs.kneighbors(featuremap)

    if debug:
        assert featuremap.shape[
            0] == num_features, 'shape of feature map is not correct'
        assert indices.shape == (
            num_features, num_neighbor), 'shape of indices is not correct'
        assert distances.shape == (
            num_features, num_neighbor), 'shape of distances is not correct'

    # convert the nearest indices for all featuremap to the key accordingly
    # a list is required for positional indexing (Python 3 key views cannot be indexed)
    id_list = list(nn_feature_dict.keys())
    max_length = len(max(
        id_list, key=len))  # find the maximum length of string in the key
    # unicode storage so the ids read back as text instead of bytes on Python 3
    nearest_id = np.chararray(indices.shape,
                              itemsize=max_length + 1,
                              unicode=True)
    for x in range(nearest_id.shape[0]):
        for y in range(nearest_id.shape[1]):
            nearest_id[x, y] = id_list[indices[x, y]]

    if debug:
        assert list(nearest_id[:,
                               0]) == id_list, 'nearest neighbor has problem'

    # sort the feature based on distance
    print('sorting the feature based on distance')
    featuremap_distance = np.sum(distances, axis=1)
    if debug:
        assert featuremap_distance.shape == (
            num_features, ), 'distance is not correct'
    sorted_indices = np.argsort(featuremap_distance)
    all_sorted_nearest_id = nearest_id[sorted_indices, :]

    # save to the csv file
    # NOTE(review): distances are written in input order while the ids are written in sorted order - confirm this is intended
    if save_csv and csv_save_path is not None:
        print('Saving nearest neighbor result as .csv to path: %s' %
              csv_save_path)
        with open(csv_save_path, 'w+') as csv_file:
            np.savetxt(csv_file, distances, delimiter=',', fmt='%f')
            np.savetxt(csv_file,
                       all_sorted_nearest_id,
                       delimiter=',',
                       fmt='%s')

    # choose the best to visualize
    selected_sorted_indices = sorted_indices[0:top_number]
    if debug:
        # non-strict comparison: equal summed distances (e.g. duplicated features) are still sorted
        for i in range(num_features - 1):
            assert featuremap_distance[
                sorted_indices[i]] <= featuremap_distance[sorted_indices[
                    i + 1]], 'feature map is not well sorted based on distance'
    selected_nearest_id = nearest_id[selected_sorted_indices, :]
    # fewer rows than top_number are available when the dictionary is small
    num_selected = selected_nearest_id.shape[0]

    if save_vis:
        # squeeze=False keeps axarray two-dimensional even for a single row
        fig, axarray = plt.subplots(num_selected,
                                    num_neighbor,
                                    squeeze=False)
        for index in range(num_selected):
            for nearest_index in range(num_neighbor):
                img_path = os.path.join(
                    img_src_folder, '%s%s' %
                    (selected_nearest_id[index, nearest_index], ext_filter))
                if debug:
                    print('loading image from %s' % img_path)
                img = imread(img_path)
                if isgrayimage_dimension(img):
                    axarray[index, nearest_index].imshow(img, cmap='gray')
                elif iscolorimage_dimension(img):
                    axarray[index, nearest_index].imshow(img)
                else:
                    assert False, 'unknown error'
                axarray[index, nearest_index].axis('off')
        save_thumb = os.path.join(nn_save_folder, save_thumb_name)
        fig.savefig(save_thumb)
        if vis:
            plt.show()
        plt.close(fig)

    # save top visualization to the folder
    if save_img and nn_save_folder is not None:
        for top_index in range(num_selected):
            file_list = selected_nearest_id[top_index]
            # one sub-folder per top feature, named after the feature itself
            save_subfolder = os.path.join(nn_save_folder, file_list[0])
            mkdir_if_missing(save_subfolder)
            for file_tmp in file_list:
                file_src = os.path.join(img_src_folder,
                                        '%s%s' % (file_tmp, ext_filter))
                save_path = os.path.join(save_subfolder,
                                         '%s%s' % (file_tmp, ext_filter))
                if debug:
                    print('saving %s to %s' % (file_src, save_path))
                shutil.copyfile(file_src, save_path)

    return all_sorted_nearest_id, selected_nearest_id
Exemple #17
0
def generate_hdf5(data_src,
                  save_dir,
                  data_name='data',
                  batch_size=1,
                  ext_filter='png',
                  label_src1=None,
                  label_name1='label',
                  label_preprocess_function1=identity,
                  label_range1=None,
                  label_src2=None,
                  label_name2='label2',
                  label_preprocess_function2=identity,
                  label_range2=None,
                  debug=True,
                  vis=False):
    '''
    create hdf5 files from a set of images, optionally with up to two sets of scalar labels

    images are grouped into batches of batch_size, the last batch may be smaller. Every batch
    is written to its own hdf5 file: for file based data the file is named after the last image
    of the batch, otherwise it is named image_<running count>.hdf5

    images read from disk are converted to float32 and divided by 255 when their maximum value
    lies in (1, 255], images passed as numpy arrays are not rescaled

    parameters:
        data_src:       source of image data, which can be a folder containing a set of images, a file listing
                        the images (read by load_list_from_file), or a list of numpy array image data
        save_dir:       where to store the hdf5 data, created when missing
        data_name:      name of the image dataset inside every hdf5 file
        batch_size:     how many images to store in a single hdf5 file
        ext_filter:     what format of data to load when data_src is a folder
        label_src1, label_src2:
                        source of label data, which can be None, a json file, a dictionary of labels,
                        a 1-d numpy array or a list of scalar labels. A dictionary (or json file) is
                        indexed with the name part of every image path, so it requires file based data
        label_name1, label_name2:
                        name of the label dataset inside every hdf5 file
        label_preprocess_function1, label_preprocess_function2:
                        called as f(data=..., data_range=..., debug=...) on every batch of labels
                        before the batch is written
        label_range1, label_range2:
                        forwarded to the preprocess function as data_range, must not be None in debug
                        mode when the corresponding labels are given
        debug:          enable the sanity checks
        vis:            forwarded to preprocess_image_caffe

    outputs:
        num_hdf:        number of hdf5 files written
        num_data:       number of images found in the data source
    '''

    # parse input
    assert is_path_exists_or_creatable(
        save_dir), 'save path should be a folder to save all hdf5 files'
    mkdir_if_missing(save_dir)
    assert isstring(
        data_name), 'dataset name is not correct'  # name for hdf5 data
    assert batch_size >= 1, 'batch size should be a positive integer'

    # convert data source to a list of files or a list of numpy array image data
    if isfolder(data_src):
        print('data is loading from %s with extension .%s' % (data_src,
                                                              ext_filter))
        filelist, num_data = load_list_from_folder(data_src,
                                                   ext_filter=ext_filter)
        datalist = None
    elif isfile(data_src):
        print('data is loading from %s with extension .%s' % (data_src,
                                                              ext_filter))
        filelist, num_data = load_list_from_file(data_src)
        datalist = None
    elif islist(data_src):
        if debug:
            assert all(
                isimage(data_tmp) for data_tmp in data_src
            ), 'input data source is not a list of numpy array image data'
        datalist = data_src
        num_data = len(datalist)
        filelist = None
    else:
        assert False, 'data source format is not correct.'
    if debug:
        assert (datalist is None and filelist is not None) or (
            filelist is None and datalist is not None), 'data is not correct'
        if datalist is not None:
            assert len(datalist) == num_data, 'number of data is not equal'
        if filelist is not None:
            assert len(filelist) == num_data, 'number of data is not equal'

    # convert the label sources, every active label set keeps its own batch buffer
    label_sets = list()
    for label_src, label_name, label_function, label_range in (
        (label_src1, label_name1, label_preprocess_function1, label_range1),
        (label_src2, label_name2, label_preprocess_function2, label_range2)):
        labeldict, labellist = _parse_label_source(label_src,
                                                   num_data,
                                                   debug=debug)
        assert isfunction(
            label_function), 'label preprocess function is not correct.'
        if labeldict is None and labellist is None:
            continue
        if labeldict is not None:
            # dictionary labels are looked up by image name, which only exists for file based data
            assert filelist is not None, 'label dictionary requires data loaded from files'
            if debug:
                assert isstring(label_name), 'label name is not correct'
                assert len(filelist) == len(
                    labeldict), 'file list is not equal to label dictionary'
        if debug:
            assert label_range is not None, 'label range is not correct'
        label_buffer = np.zeros((batch_size, 1), dtype='float32')
        label_sets.append((labeldict, labellist, label_name, label_function,
                           label_range, label_buffer))

    # nothing to write, the warm up below needs at least one sample
    if num_data == 0:
        return 0, 0

    # warm up
    if datalist is not None:
        size_data = datalist[0].shape
    else:
        size_data = imread(filelist[0]).shape

    # start generating
    count_hdf = 0  # count number of hdf5 file
    clock = Timer()
    datalist_batch = list()
    for i in range(num_data):
        clock.tic()
        if filelist is not None:
            imagefile = filelist[i]
            _, name, _ = fileparts(imagefile)
            img = imread(imagefile).astype('float32')
            max_value = np.max(img)
            if max_value > 1 and max_value <= 255:
                img = img / 255.0  # scale the image data to (0, 1)
            if debug:
                min_value = np.min(img)
                assert min_value >= 0 and min_value <= 1, 'data is not in [0, 1]'
        else:
            img = datalist[i]
        if debug:
            assert size_data == img.shape
        datalist_batch.append(img)

        # process label, the row inside the buffer matches the position inside the batch
        slot = i % batch_size
        for labeldict, labellist, _, _, _, label_buffer in label_sets:
            if labeldict is not None:
                label_buffer[slot, 0] = float(labeldict[name])
            else:
                label_buffer[slot, 0] = float(labellist[i])

        # save to hdf5 once the batch is full or the data is exhausted, so that
        # images and labels stay aligned and the trailing images are not dropped
        batch_ready = (slot == batch_size - 1) or (i == num_data - 1)
        if batch_ready:
            num_in_batch = len(datalist_batch)
            data = preprocess_image_caffe(
                datalist_batch, debug=debug, vis=vis
            )  # swap channel, transfer from list of HxWxC to NxCxHxW

            # write to hdf5 format
            count_hdf = count_hdf + 1
            if filelist is not None:
                save_path = os.path.join(save_dir, '%s.hdf5' % name)
            else:
                save_path = os.path.join(save_dir,
                                         'image_%010d.hdf5' % count_hdf)
            # the context manager closes the file even when a write fails
            with h5py.File(save_path, 'w') as h5f:
                h5f.create_dataset(data_name, data=data, dtype='float32')
                for _, _, label_name, label_function, label_range, label_buffer in label_sets:
                    # only the rows filled by this batch are written
                    labels = label_function(data=label_buffer[:num_in_batch],
                                            data_range=label_range,
                                            debug=debug)
                    h5f.create_dataset(label_name,
                                       data=labels,
                                       dtype='float32')
                    label_buffer.fill(0)

            del datalist_batch[:]
        average_time = clock.toc()
        if batch_ready:
            print(
                'saving to %s: %d/%d, average time:%.3f, elapsed time:%s, estimated time remaining:%s'
                % (save_path, i + 1, num_data, average_time,
                   convert_secs2time(average_time * i),
                   convert_secs2time(average_time * (num_data - i))))

    return count_hdf, num_data


def _parse_label_source(label_src, num_data, debug=True):
    '''
    normalize a label source into either a label dictionary or a label sequence

    parameters:
        label_src:      None, a path to a json file, a dictionary, a 1-d numpy array or a list of scalars
        num_data:       number of data samples, compared with the number of labels stored in a json file
        debug:          enable the sanity checks on numpy array and list inputs

    outputs:
        labeldict:      the label dictionary, None when the source is not a json file or a dictionary
        labellist:      the label sequence, None when the source is not a numpy array or a list
    '''
    if label_src is None:
        return None, None
    if isfile(label_src):
        assert is_path_exists(label_src), 'file not found'
        _, _, ext = fileparts(label_src)
        assert ext == '.json', 'only json extension is supported'
        # json.load reads from a file object, it does not accept a path
        with open(label_src, 'r') as label_file:
            labeldict = json.load(label_file)
        assert num_data == len(
            labeldict), 'number of data and label is not equal.'
        return labeldict, None
    if isdict(label_src):
        return label_src, None
    if isnparray(label_src):
        if debug:
            assert label_src.ndim == 1, 'only 1-d label is supported'
        return None, label_src
    if islist(label_src):
        if debug:
            assert all(
                np.array(label_tmp).size == 1
                for label_tmp in label_src), 'only 1-d label is supported'
        return None, label_src
    assert False, 'label source format is not correct.'
def load_list_from_folder(folder_path,
                          ext_filter=None,
                          depth=1,
                          recursive=False,
                          sort=True,
                          save_path=None,
                          debug=True):
    '''
    load a list of files or folders from a system path

    parameters:
        folder_path:    root to search 
        ext_filter:     a string or a list of strings to represent the extension of files interested,
                        when it's None, everything matched by the wildcard is returned
        depth:          maximum depth of folder to search, when it's None, all levels of folders will be searched
        recursive:      False: only return current level
                        True: return all levels till to the input depth
        sort:           sort the elements found for every extension and every depth level,
                        the concatenated list is not sorted as a whole
        save_path:      when it's not None, the list is also written to this file, one element per line
        debug:          enable the sanity checks on the inputs

    outputs:
        fulllist:       a list of elements
        num_elem:       number of the elements
    '''
    folder_path = safepath(folder_path)
    if debug:
        assert isfolder(
            folder_path), 'input folder path is not correct: %s' % folder_path
    if not is_path_exists(folder_path): return [], 0

    if debug:
        assert islogical(
            recursive), 'recursive should be a logical variable: {}'.format(
                recursive)
        assert depth is None or (
            isinteger(depth)
            and depth >= 1), 'input depth is not correct {}'.format(depth)
        assert ext_filter is None or (islist(ext_filter) and all(
            isstring(ext_tmp) for ext_tmp in ext_filter)) or isstring(
                ext_filter), 'extension filter is not correct'
    if isstring(ext_filter):  # convert to a list
        ext_filter = [ext_filter]

    # build one wildcard per extension and pick the matching glob implementation
    if depth is None:  # find all files recursively, '**' is expanded by glob2
        glob_function = glob2.glob
        if ext_filter is not None:
            wildcard_list = [
                os.path.join('**', '*' + string2ext_filter(ext_tmp))
                for ext_tmp in ext_filter
            ]
        else:
            wildcard_list = ['**']
    else:  # find files based on depth and recursive flag
        glob_function = glob.glob
        wildcard_prefix = '*'
        for _ in range(depth - 1):
            wildcard_prefix = os.path.join(wildcard_prefix, '*')
        if ext_filter is not None:
            wildcard_list = [
                wildcard_prefix + string2ext_filter(ext_tmp)
                for ext_tmp in ext_filter
            ]
        else:
            wildcard_list = [wildcard_prefix]

    fulllist = list()
    for wildcard in wildcard_list:
        curlist = glob_function(os.path.join(folder_path, wildcard))
        if sort:
            curlist = sorted(curlist)
        fulllist += curlist

    # the wildcard above only matches the deepest level, collect the shallower ones
    # with the same sort and debug settings as requested by the caller
    if depth is not None and recursive and depth > 1:
        newlist, _ = load_list_from_folder(folder_path=folder_path,
                                           ext_filter=ext_filter,
                                           depth=depth - 1,
                                           recursive=True,
                                           sort=sort,
                                           debug=debug)
        fulllist += newlist

    fulllist = [os.path.normpath(path_tmp) for path_tmp in fulllist]
    num_elem = len(fulllist)

    # save list to a path
    if save_path is not None:
        save_path = safepath(save_path)
        if debug:
            assert is_path_exists_or_creatable(
                save_path), 'the file cannot be created'
        # the with statement closes the file, no explicit close is needed
        with open(save_path, 'w') as list_file:
            for item in fulllist:
                list_file.write('%s\n' % item)

    return fulllist, num_elem