dst_dir = params['dst_dir'] images_ext = params['images_ext'] labels_ext = params['labels_ext'] n_classes = params['n_classes'] n_indices = params['n_indices'] start_id = params['start_id'] end_id = params['end_id'] copy_images = params['copy_images'] images_path = os.path.join(db_root_dir, src_dir, 'images') labels_path = os.path.join(db_root_dir, src_dir, 'labels') images_path = os.path.abspath(images_path) labels_path = os.path.abspath(labels_path) src_files, src_labels_list, total_frames = read_data( images_path, images_ext, labels_path, labels_ext) if start_id < 0: if end_id < 0: raise AssertionError( 'end_id must be non negative for random selection') elif end_id >= total_frames: raise AssertionError( 'end_id must be less than total_frames for random selection') print('Using {} random images for selection'.format(end_id + 1)) img_ids = np.random.choice(total_frames, end_id + 1, replace=False) else: if end_id < start_id: end_id = total_frames - 1 print('Using all {} images for selection'.format(end_id - start_id + 1)) img_ids = range(start_id, end_id + 1)
def main(): parser = argparse.ArgumentParser() parser.add_argument("--log_dir", type=str) parser.add_argument("--images_path", type=str) parser.add_argument("--images_ext", type=str, default='png') parser.add_argument("--labels_path", type=str, default='') parser.add_argument("--labels_ext", type=str, default='png') parser.add_argument("--labels_col", type=str, default='green') parser.add_argument("--seg_paths", type=str_to_list, default=[]) parser.add_argument("--seg_ext", type=str, default='png') parser.add_argument("--seg_root_dir", type=str, default='') parser.add_argument("--seg_labels", type=str_to_list, default=[]) parser.add_argument( "--seg_cols", type=str_to_list, default=['blue', 'forest_green', 'magenta', 'cyan', 'red']) parser.add_argument("--out_path", type=str, default='') parser.add_argument("--out_ext", type=str, default='jpg') parser.add_argument("--out_size", type=str, default='1920x1080') parser.add_argument("--fps", type=float, default=30) parser.add_argument("--codec", type=str, default='H264') parser.add_argument("--save_path", type=str, default='') parser.add_argument("--n_classes", type=int) parser.add_argument("--save_stitched", type=int, default=0) parser.add_argument("--load_ice_conc_diff", type=int, default=0) parser.add_argument("--start_id", type=int, default=0) parser.add_argument("--end_id", type=int, default=-1) parser.add_argument("--show_img", type=int, default=0) parser.add_argument("--stitch", type=int, default=0) parser.add_argument("--stitch_seg", type=int, default=1) parser.add_argument("--plot_changed_seg_count", type=int, default=0) parser.add_argument("--normalize_labels", type=int, default=0) parser.add_argument("--selective_mode", type=int, default=0) parser.add_argument("--ice_type", type=int, default=0, help='0: combined, 1: anchor, 2: frazil') parser.add_argument("--enable_plotting", type=int, default=1, help='enable_plotting') args = parser.parse_args() images_path = args.images_path images_ext = args.images_ext 
labels_path = args.labels_path labels_ext = args.labels_ext labels_col = args.labels_col seg_paths = args.seg_paths seg_root_dir = args.seg_root_dir seg_ext = args.seg_ext out_path = args.out_path out_ext = args.out_ext out_size = args.out_size fps = args.fps codec = args.codec # save_path = args.save_path n_classes = args.n_classes end_id = args.end_id start_id = args.start_id show_img = args.show_img stitch = args.stitch stitch_seg = args.stitch_seg save_stitched = args.save_stitched normalize_labels = args.normalize_labels selective_mode = args.selective_mode seg_labels = args.seg_labels seg_cols = args.seg_cols ice_type = args.ice_type plot_changed_seg_count = args.plot_changed_seg_count load_ice_conc_diff = args.load_ice_conc_diff enable_plotting = args.enable_plotting ice_types = { 0: 'Ice', 1: 'Anchor Ice', 2: 'Frazil Ice', } loc = (5, 120) size = 8 thickness = 6 fgr_col = (255, 255, 255) bgr_col = (0, 0, 0) font_id = 0 video_exts = ['mp4', 'mkv', 'avi', 'mpg', 'mpeg', 'mjpg'] labels_col_rgb = col_bgr[labels_col] seg_cols_rgb = [col_bgr[seg_col] for seg_col in seg_cols] ice_type_str = ice_types[ice_type] print('ice_type_str: {}'.format(ice_type_str)) src_files, src_labels_list, total_frames = read_data( images_path, images_ext, labels_path, labels_ext) if end_id < start_id: end_id = total_frames - 1 if seg_paths: n_seg_paths = len(seg_paths) n_seg_labels = len(seg_labels) if n_seg_paths != n_seg_labels: raise IOError( 'Mismatch between n_seg_labels: {} and n_seg_paths: {}'.format( n_seg_labels, n_seg_paths)) if seg_root_dir: seg_paths = [ os.path.join(seg_root_dir, name) for name in seg_paths ] if not out_path: if labels_path: out_path = labels_path + '_conc' elif seg_paths: out_path = seg_paths[0] + '_conc' if not os.path.isdir(out_path): os.makedirs(out_path) # print('Saving results data to {}'.format(out_path)) # if not save_path: # save_path = os.path.join(os.path.dirname(images_path), 'ice_concentration') # if not os.path.isdir(save_path): # 
os.makedirs(save_path) # if stitch and save_stitched: # print('Saving ice_concentration plots to: {}'.format(save_path)) # log_fname = os.path.join(out_path, 'vis_log_{:s}.txt'.format(getDateTime())) # print('Saving log to: {}'.format(log_fname)) if selective_mode: label_diff = int(255.0 / n_classes) else: label_diff = int(255.0 / (n_classes - 1)) print('label_diff: {}'.format(label_diff)) n_frames = end_id - start_id + 1 print_diff = int(n_frames * 0.01) labels_img = None n_cols = len(seg_cols_rgb) plot_y_label = '{} concentration (%)'.format(ice_type_str) plot_x_label = 'distance in pixels from left edge' dists = {} for _label in seg_labels: dists[_label] = { # 'bhattacharyya': [], 'euclidean': [], 'mae': [], 'mse': [], # 'frobenius': [], } plot_title = '{} concentration'.format(ice_type_str) out_size = tuple([int(x) for x in out_size.split('x')]) write_to_video = out_ext in video_exts out_width, out_height = out_size out_seq_name = os.path.basename(out_path) if enable_plotting: if write_to_video: stitched_seq_path = os.path.join( out_path, '{}.{}'.format(out_seq_name, out_ext)) print('Writing {}x{} output video to: {}'.format( out_width, out_height, stitched_seq_path)) save_dir = os.path.dirname(stitched_seq_path) fourcc = cv2.VideoWriter_fourcc(*codec) video_out = cv2.VideoWriter(stitched_seq_path, fourcc, fps, out_size) else: stitched_seq_path = os.path.join(out_path, out_seq_name) print('Writing {}x{} output images of type {} to: {}'.format( out_width, out_height, out_ext, stitched_seq_path)) save_dir = stitched_seq_path if save_dir and not os.path.isdir(save_dir): os.makedirs(save_dir) prev_seg_img = {} prev_conc_data_y = {} changed_seg_count = {} ice_concentration_diff = {} if load_ice_conc_diff: for seg_id in seg_labels: ice_concentration_diff[seg_id] = np.loadtxt(os.path.join( out_path, '{}_ice_concentration_diff.txt'.format(seg_id)), dtype=np.float64) _pause = 0 mae_data_y = [] for seg_id, _ in enumerate(seg_paths): mae_data_y.append([]) for img_id in 
range(start_id, end_id + 1): start_t = time.time() # img_fname = '{:s}_{:d}.{:s}'.format(fname_templ, img_id + 1, img_ext) img_fname = src_files[img_id] img_fname_no_ext = os.path.splitext(img_fname)[0] src_img_fname = os.path.join(images_path, img_fname) src_img = imread(src_img_fname) if src_img is None: raise SystemError('Source image could not be read from: {}'.format( src_img_fname)) try: src_height, src_width = src_img.shape[:2] except ValueError as e: print('src_img_fname: {}'.format(src_img_fname)) print('src_img: {}'.format(src_img)) print('src_img.shape: {}'.format(src_img.shape)) print('error: {}'.format(e)) sys.exit(1) conc_data_x = np.asarray(range(src_width), dtype=np.float64) plot_data_x = conc_data_x plot_data_y = [] plot_cols = [] plot_labels = [] stitched_img = src_img if labels_path: labels_img_fname = os.path.join( labels_path, img_fname_no_ext + '.{}'.format(labels_ext)) labels_img_orig = imread(labels_img_fname) if labels_img_orig is None: raise SystemError( 'Labels image could not be read from: {}'.format( labels_img_fname)) labels_height, labels_width = labels_img_orig.shape[:2] if labels_height != src_height or labels_width != src_width: raise AssertionError( 'Mismatch between dimensions of source: {} and label: {}'. 
format((src_height, src_width), (seg_height, seg_width))) if len(labels_img_orig.shape) == 3: labels_img_orig = np.squeeze(labels_img_orig[:, :, 0]) if show_img: cv2.imshow('labels_img_orig', labels_img_orig) if normalize_labels: labels_img = (labels_img_orig.astype(np.float64) / label_diff).astype(np.uint8) else: labels_img = np.copy(labels_img_orig) if len(labels_img.shape) == 3: labels_img = labels_img[:, :, 0].squeeze() conc_data_y = np.zeros((labels_width, ), dtype=np.float64) for i in range(labels_width): curr_pix = np.squeeze(labels_img[:, i]) if ice_type == 0: ice_pix = curr_pix[curr_pix != 0] else: ice_pix = curr_pix[curr_pix == ice_type] conc_data_y[i] = (len(ice_pix) / float(src_height)) * 100.0 conc_data = np.zeros((labels_width, 2), dtype=np.float64) conc_data[:, 0] = conc_data_x conc_data[:, 1] = conc_data_y plot_data_y.append(conc_data_y) plot_cols.append(labels_col_rgb) gt_dict = { conc_data_x[i]: conc_data_y[i] for i in range(labels_width) } if not normalize_labels: labels_img_orig = (labels_img_orig.astype(np.float64) * label_diff).astype(np.uint8) if len(labels_img_orig.shape) == 2: labels_img_orig = np.stack( (labels_img_orig, labels_img_orig, labels_img_orig), axis=2) stitched_img = np.concatenate((stitched_img, labels_img_orig), axis=1) plot_labels.append('GT') # gt_cl, _ = eval.extract_classes(labels_img_orig) # print('gt_cl: {}'.format(gt_cl)) mean_seg_counts = {} seg_count_data_y = [] curr_mae_data_y = [] mean_conc_diff = {} conc_diff_data_y = [] seg_img_disp_list = [] for seg_id, seg_path in enumerate(seg_paths): seg_img_fname = os.path.join( seg_path, img_fname_no_ext + '.{}'.format(seg_ext)) seg_img_orig = imread(seg_img_fname) seg_col = seg_cols_rgb[seg_id % n_cols] _label = seg_labels[seg_id] if seg_img_orig is None: raise SystemError( 'Seg image could not be read from: {}'.format( seg_img_fname)) seg_height, seg_width = seg_img_orig.shape[:2] if seg_height != src_height or seg_width != src_width: raise AssertionError( 'Mismatch 
between dimensions of source: {} and seg: {}'. format((src_height, src_width), (seg_height, seg_width))) if len(seg_img_orig.shape) == 3: seg_img_orig = np.squeeze(seg_img_orig[:, :, 0]) if seg_img_orig.max() > n_classes - 1: seg_img = (seg_img_orig.astype(np.float64) / label_diff).astype(np.uint8) seg_img_disp = seg_img_orig else: seg_img = seg_img_orig seg_img_disp = (seg_img_orig.astype(np.float64) * label_diff).astype(np.uint8) if len(seg_img_disp.shape) == 2: seg_img_disp = np.stack( (seg_img_disp, seg_img_disp, seg_img_disp), axis=2) ann_fmt = (font_id, loc[0], loc[1], size, thickness) + fgr_col + bgr_col put_text_with_background(seg_img_disp, seg_labels[seg_id], fmt=ann_fmt) seg_img_disp_list.append(seg_img_disp) # eval_cl, _ = eval.extract_classes(seg_img) # print('eval_cl: {}'.format(eval_cl)) if show_img: cv2.imshow('seg_img_orig', seg_img_orig) if len(seg_img.shape) == 3: seg_img = seg_img[:, :, 0].squeeze() conc_data_y = np.zeros((seg_width, ), dtype=np.float64) for i in range(seg_width): curr_pix = np.squeeze(seg_img[:, i]) if ice_type == 0: ice_pix = curr_pix[curr_pix != 0] else: ice_pix = curr_pix[curr_pix == ice_type] conc_data_y[i] = (len(ice_pix) / float(src_height)) * 100.0 plot_cols.append(seg_col) plot_data_y.append(conc_data_y) if labels_path: seg_dict = { conc_data_x[i]: conc_data_y[i] for i in range(seg_width) } # dists['bhattacharyya'].append(bhattacharyya(gt_dict, seg_dict)) dists[_label]['euclidean'].append(euclidean(gt_dict, seg_dict)) dists[_label]['mse'].append(mse(gt_dict, seg_dict)) dists[_label]['mae'].append(mae(gt_dict, seg_dict)) # dists['frobenius'].append(np.linalg.norm(conc_data_y - plot_data_y[0])) curr_mae_data_y.append(dists[_label]['mae'][-1]) else: if img_id > 0: if plot_changed_seg_count: flow = cv2.calcOpticalFlowFarneback( prev_seg_img[_label], seg_img, None, 0.5, 3, 15, 3, 5, 1.2, 0) print('flow: {}'.format(flow.shape)) # # Obtain the flow magnitude and direction angle # mag, ang = cv2.cartToPolar(flow[..., 0], 
flow[..., 1]) # hsvImg = np.zeros((2160, 3840, 3), dtype=np.uint8) # hsvImg[..., 1] = 255 # # Update the color image # hsvImg[..., 0] = 0.5 * ang * 180 / np.pi # hsvImg[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) # rgbImg = cv2.cvtColor(hsvImg, cv2.COLOR_HSV2BGR) # rgbImg = resizeAR(rgbImg, width=out_width, height=out_height) # # Display the resulting frame # cv2.imshow('dense optical flow', rgbImg) # k = cv2.waitKey(0) curr_x, curr_y = (prev_x + flow[..., 0]).astype( np.int32), (prev_y + flow[..., 1]).astype(np.int32) seg_img_flow = seg_img[curr_y, curr_x] changed_seg_count[_label].append( np.count_nonzero( np.not_equal(seg_img, prev_seg_img[_label]))) seg_count_data_y.append(changed_seg_count[_label]) mean_seg_counts[_label] = np.mean( changed_seg_count[_label]) else: ice_concentration_diff[_label].append( np.mean( np.abs(conc_data_y - prev_conc_data_y[_label]))) conc_diff_data_y.append(ice_concentration_diff[_label]) mean_conc_diff[_label] = np.mean( ice_concentration_diff[_label]) else: if plot_changed_seg_count: prev_x, prev_y = np.meshgrid(range(seg_width), range(seg_height), sparse=False, indexing='xy') changed_seg_count[_label] = [] else: ice_concentration_diff[_label] = [] prev_seg_img[_label] = seg_img prev_conc_data_y[_label] = conc_data_y # conc_data = np.concatenate([conc_data_x, conc_data_y], axis=1) if labels_path: for i, k in enumerate(curr_mae_data_y): mae_data_y[i].append(k) n_test_images = img_id + 1 mae_data_X = np.asarray(range(1, n_test_images + 1), dtype=np.float64) print('') # print('mae_data_X:\n {}'.format(pformat(mae_data_X))) # print('mae_data_y:\n {}'.format(pformat(np.array(mae_data_y).transpose()))) if img_id == end_id: mae_data_y_arr = np.array(mae_data_y).transpose() print('mae_data_y:\n {}'.format( tabulate(mae_data_y_arr, headers=seg_labels, tablefmt='plain'))) pd.DataFrame(data=mae_data_y_arr, columns=seg_labels).to_clipboard(excel=True) mae_img = getPlotImage(mae_data_X, mae_data_y, plot_cols, 'MAE', seg_labels, 
'frame', 'MAE') cv2.imshow('mae_img', mae_img) conc_diff_img = resize_ar(mae_img, seg_width, src_height, bkg_col=255) else: if img_id > 0: n_test_images = img_id seg_count_data_X = np.asarray(range(1, n_test_images + 1), dtype=np.float64) if plot_changed_seg_count: seg_count_img = getPlotImage(seg_count_data_X, seg_count_data_y, plot_cols, 'Count', seg_labels, 'frame', 'Changed Label Count') cv2.imshow('seg_count_img', seg_count_img) else: # print('seg_count_data_X:\n {}'.format(pformat(seg_count_data_X))) # print('conc_diff_data_y:\n {}'.format(pformat(conc_diff_data_y))) conc_diff_img = getPlotImage( seg_count_data_X, conc_diff_data_y, plot_cols, 'Mean concentration difference between consecutive frames' .format(ice_type_str), seg_labels, 'frame', 'Concentration Difference (%)') # cv2.imshow('conc_diff_img', conc_diff_img) conc_diff_img = resize_ar(conc_diff_img, seg_width, src_height, bkg_col=255) else: conc_diff_img = np.zeros((src_height, seg_width, 3), dtype=np.uint8) plot_labels += seg_labels if enable_plotting: plot_img = getPlotImage(plot_data_x, plot_data_y, plot_cols, plot_title, plot_labels, plot_x_label, plot_y_label, legend=0 # ylim=(0, 100) ) plot_img = resize_ar(plot_img, seg_width, src_height, bkg_col=255) # plt.plot(conc_data_x, conc_data_y) # plt.show() # conc_data_fname = os.path.join(out_path, img_fname_no_ext + '.txt') # np.savetxt(conc_data_fname, conc_data, fmt='%.6f') ann_fmt = (font_id, loc[0], loc[1], size, thickness) + labels_col_rgb + bgr_col put_text_with_background(src_img, 'frame {}'.format(img_id + 1), fmt=ann_fmt) if n_seg_paths == 1: print('seg_img_disp: {}'.format(seg_img_disp.shape)) print('plot_img: {}'.format(plot_img.shape)) stitched_seg_img = np.concatenate((seg_img_disp, plot_img), axis=1) print('stitched_seg_img: {}'.format(stitched_seg_img.shape)) print('stitched_img: {}'.format(stitched_img.shape)) stitched_img = np.concatenate((stitched_img, stitched_seg_img), axis=0 if labels_path else 1) elif n_seg_paths == 2: 
stitched_img = np.concatenate(( np.concatenate((src_img, conc_diff_img), axis=1), np.concatenate(seg_img_disp_list, axis=1), ), axis=0) elif n_seg_paths == 3: stitched_img = np.concatenate(( np.concatenate((src_img, plot_img, conc_diff_img), axis=1), np.concatenate(seg_img_disp_list, axis=1), ), axis=0) stitched_img = resize_ar(stitched_img, width=out_width, height=out_height) # print('dists: {}'.format(dists)) if write_to_video: video_out.write(stitched_img) else: stacked_img_path = os.path.join( stitched_seq_path, '{}.{}'.format(img_fname_no_ext, out_ext)) cv2.imwrite(stacked_img_path, stitched_img) cv2.imshow('stitched_img', stitched_img) k = cv2.waitKey(1 - _pause) if k == 27: break elif k == 32: _pause = 1 - _pause end_t = time.time() sys.stdout.write('\rDone {:d}/{:d} frames. fps: {}'.format( img_id + 1 - start_id, n_frames, 1.0 / (end_t - start_t))) sys.stdout.flush() print() if enable_plotting and write_to_video: video_out.release() if labels_path: median_dists = {} mean_dists = {} mae_data_y = [] for _label in seg_labels: _dists = dists[_label] mae_data_y.append(_dists['mae']) mean_dists[_label] = {k: np.mean(_dists[k]) for k in _dists} median_dists[_label] = {k: np.median(_dists[k]) for k in _dists} print('mean_dists:\n{}'.format(pformat(mean_dists))) print('median_dists:\n{}'.format(pformat(median_dists))) n_test_images = len(mae_data_y[0]) mae_data_x = np.asarray(range(1, n_test_images + 1), dtype=np.float64) mae_img = getPlotImage(mae_data_x, mae_data_y, plot_cols, 'MAE', seg_labels, 'test image', 'Mean Absolute Error') # plt.show() cv2.imshow('MAE', mae_img) k = cv2.waitKey(0) else: mean_seg_counts = {} median_seg_counts = {} seg_count_data_y = [] mean_conc_diff = {} median_conc_diff = {} conc_diff_data_y = [] for seg_id in ice_concentration_diff: if plot_changed_seg_count: seg_count_data_y.append(changed_seg_count[seg_id]) mean_seg_counts[seg_id] = np.mean(changed_seg_count[seg_id]) median_seg_counts[seg_id] = np.median( changed_seg_count[seg_id]) 
else: _ice_concentration_diff = ice_concentration_diff[seg_id] n_test_images = len(_ice_concentration_diff) conc_diff_data_y.append(_ice_concentration_diff) mean_conc_diff[seg_id] = np.mean(_ice_concentration_diff) median_conc_diff[seg_id] = np.median(_ice_concentration_diff) np.savetxt(os.path.join( out_path, '{}_ice_concentration_diff.txt'.format(seg_id)), _ice_concentration_diff, fmt='%8.4f', delimiter='\t') if plot_changed_seg_count: print('mean_seg_counts:\n{}'.format(pformat(mean_seg_counts))) print('median_seg_counts:\n{}'.format(pformat(median_seg_counts))) else: print('mean_conc_diff:') for seg_id in mean_conc_diff: print('{}\t{}'.format(seg_id, mean_conc_diff[seg_id])) print('median_conc_diff:') for seg_id in mean_conc_diff: print('{}\t{}'.format(seg_id, median_conc_diff[seg_id]))
def run(params):
    """Stitch per-frame patch images back into full frames and save them.

    Source frames are listed from ``params.src_path`` (files ending in
    ``.<params.images_ext>``). For each frame in ``[start_id, end_id]`` the
    patches ``<frame>_<k>.<patch_ext>`` in ``params.patch_seq_path`` are
    reassembled according to ``params.method``:

    * ``-1``: a single pre-stitched ``<frame>.<patch_ext>`` image is loaded,
    * ``0``: patches are concatenated column-wise into rows, then row-wise,
    * otherwise: patches are pasted into a canvas the size of the source frame.

    The result is written as a video (when ``params.out_ext`` is a video
    extension) or as one image per frame under ``params.stitched_seq_path``.
    When ``params.labels_path`` and ``params.labels_ext`` are set, every
    stitched frame is also evaluated against its ground-truth labels, and
    running averages of pixel accuracy, mean accuracy, mean IU,
    frequency-weighted IU and per-class ice metrics are logged.

    :param StitchParams params:
    :return: None
    """
    video_exts = ['mp4', 'mkv', 'avi', 'mpg', 'mpeg', 'mjpg']

    print('Reading source images from: {}'.format(params.src_path))

    img_suffix = '.{:s}'.format(params.images_ext)
    src_files = [k for k in os.listdir(params.src_path) if k.endswith(img_suffix)]
    total_frames = len(src_files)

    if total_frames <= 0:
        print('params: {}'.format(params))
        raise SystemError('No input frames of type {} found'.format(
            params.images_ext))

    print('total_frames: {}'.format(total_frames))

    classes, composite_classes = read_class_info(params.class_info_path)
    n_classes = len(classes)
    class_ids = list(range(n_classes))

    src_files.sort(key=sort_key)

    if params.n_frames <= 0:
        params.n_frames = total_frames

    if params.end_id < params.start_id:
        params.end_id = params.n_frames - 1

    if params.patch_width <= 0:
        params.patch_width = params.patch_height

    if not params.patch_seq_path:
        if not params.patch_seq_name:
            params.patch_seq_name = '{:s}_{:d}_{:d}_{:d}_{:d}_{:d}_{:d}'.format(
                params.seq_name, params.start_id, params.end_id,
                params.patch_height, params.patch_width, params.patch_height,
                params.patch_width)

        params.patch_seq_path = os.path.join(params.db_root_dir,
                                             params.patch_seq_name,
                                             params.patch_seq_type)
        if not os.path.isdir(params.patch_seq_path):
            raise SystemError(
                'params.patch_seq_path does not exist: {}'.format(
                    params.patch_seq_path))
    else:
        params.patch_seq_name = os.path.basename(params.patch_seq_path)

    if not params.stitched_seq_path:
        stitched_seq_name = '{}_stitched_{}'.format(params.patch_seq_name,
                                                    params.method)
        if params.stacked:
            stitched_seq_name = '{}_stacked'.format(stitched_seq_name)
            params.method = 1
        stitched_seq_name = '{}_{}_{}'.format(stitched_seq_name,
                                              params.start_id, params.end_id)
        params.stitched_seq_path = os.path.join(params.db_root_dir,
                                                stitched_seq_name,
                                                params.patch_seq_type)

    gt_labels_orig = gt_labels = video_out = None
    write_to_video = params.out_ext in video_exts
    if write_to_video:
        if not params.stitched_seq_path.endswith('.{}'.format(params.out_ext)):
            params.stitched_seq_path = '{}.{}'.format(params.stitched_seq_path,
                                                      params.out_ext)
        print('Writing {}x{} output video to: {}'.format(
            params.width, params.height, params.stitched_seq_path))
        save_dir = os.path.dirname(params.stitched_seq_path)
    else:
        print('Writing output images to: {}'.format(params.stitched_seq_path))
        save_dir = params.stitched_seq_path

    # The output folder must exist before cv2.VideoWriter opens its file,
    # otherwise the writer fails to open and silently writes nothing.
    if save_dir and not os.path.isdir(save_dir):
        os.makedirs(save_dir)

    if write_to_video:
        fourcc = cv2.VideoWriter_fourcc(*params.codec)
        video_out = cv2.VideoWriter(params.stitched_seq_path, fourcc,
                                    params.fps, (params.width, params.height))

    log_fname = os.path.join(save_dir, 'log_{:s}.txt'.format(getDateTime()))
    print_and_write('Saving log to: {}'.format(log_fname), log_fname)
    print_and_write(
        'Reading patch images from: {}'.format(params.patch_seq_path),
        log_fname)

    n_patches = 0
    pause_after_frame = 1
    # NOTE(review): uses params.n_classes rather than len(classes) read above
    label_diff = int(255.0 / (params.n_classes - 1))

    eval_mode = False
    if params.labels_path and params.labels_ext:
        _, labels_list, labels_total_frames = read_data(
            labels_path=params.labels_path, labels_ext=params.labels_ext)
        if labels_total_frames != total_frames:
            raise SystemError(
                'Mismatch between no. of frames in GT and seg labels')
        eval_mode = True
        import densenet.evaluation.eval_segm as eval_segm

    # running averages of the evaluation metrics (only updated in eval_mode)
    avg_pix_acc = avg_mean_acc = avg_mean_IU = avg_fw_IU = 0
    avg_mean_acc_ice = avg_mean_acc_ice_1 = avg_mean_acc_ice_2 = 0
    avg_mean_IU_ice = avg_mean_IU_ice_1 = avg_mean_IU_ice_2 = 0

    # frames in which a given ice class was absent and hence not averaged
    skip_mean_acc_ice_1 = skip_mean_acc_ice_2 = 0
    skip_mean_IU_ice_1 = skip_mean_IU_ice_2 = 0

    _n_frames = params.end_id - params.start_id + 1

    for img_id in range(params.start_id, params.end_id + 1):
        # Number of frames processed so far including this one; used as the
        # running-average denominator (img_id + 1 is wrong when start_id > 0).
        n_done = img_id - params.start_id + 1

        img_fname = src_files[img_id]
        img_fname_no_ext = os.path.splitext(img_fname)[0]

        src_img_fname = os.path.join(params.src_path, img_fname)
        src_img = cv2.imread(src_img_fname)
        if src_img is None:
            raise SystemError('Source image could not be read from: {}'.format(
                src_img_fname))

        n_rows, n_cols, n_channels = src_img.shape

        out_id = 0
        min_row = 0

        enable_stitching = 1

        if params.method == -1:
            patch_src_img_fname = os.path.join(
                params.patch_seq_path,
                '{:s}.{:s}'.format(img_fname_no_ext, params.patch_ext))
            if not os.path.exists(patch_src_img_fname):
                raise SystemError('Patch image does not exist: {}'.format(
                    patch_src_img_fname))
            stitched_img = cv2.imread(patch_src_img_fname)
            enable_stitching = 0
        elif params.method == 0:
            stitched_img = None
        else:
            stitched_img = np.zeros((n_rows, n_cols, n_channels), dtype=np.uint8)

        while enable_stitching:
            max_row = min_row + params.patch_height
            if max_row > n_rows:
                # last row of patches is shifted up to end at the image border
                diff = max_row - n_rows
                min_row -= diff
                max_row -= diff

            curr_row = None
            min_col = 0
            while True:
                max_col = min_col + params.patch_width
                if max_col > n_cols:
                    # last patch in the row is shifted left to end at the border
                    diff = max_col - n_cols
                    min_col -= diff
                    max_col -= diff

                patch_img_fname = '{:s}_{:d}'.format(img_fname_no_ext, out_id + 1)
                patch_src_img_fname = os.path.join(
                    params.patch_seq_path,
                    '{:s}.{:s}'.format(patch_img_fname, params.patch_ext))

                if not os.path.exists(patch_src_img_fname):
                    raise SystemError('Patch image does not exist: {}'.format(
                        patch_src_img_fname))

                src_patch = cv2.imread(patch_src_img_fname)
                seg_height, seg_width, _ = src_patch.shape

                if seg_width == 2 * params.patch_width or seg_width == 3 * params.patch_width:
                    # horizontally stacked patch image: keep only the last panel
                    _start_id = seg_width - params.patch_width
                    src_patch = src_patch[:, _start_id:]

                if params.normalize_patches:
                    src_patch = (src_patch * label_diff).astype(np.uint8)

                out_id += 1

                if params.method == 0:
                    if curr_row is None:
                        curr_row = src_patch
                    else:
                        curr_row = np.concatenate((curr_row, src_patch), axis=1)
                else:
                    stitched_img[min_row:max_row, min_col:max_col, :] = src_patch

                if params.show_img:
                    disp_img = src_img.copy()
                    cv2.rectangle(disp_img, (min_col, min_row),
                                  (max_col, max_row), (255, 0, 0), 2)

                    stitched_img_disp = stitched_img
                    if params.disp_resize_factor != 1:
                        disp_img = cv2.resize(disp_img, (0, 0),
                                              fx=params.disp_resize_factor,
                                              fy=params.disp_resize_factor)
                        if stitched_img_disp is not None:
                            stitched_img_disp = cv2.resize(
                                stitched_img_disp, (0, 0),
                                fx=params.disp_resize_factor,
                                fy=params.disp_resize_factor)

                    cv2.imshow('disp_img', disp_img)
                    cv2.imshow('src_patch', src_patch)
                    if stitched_img_disp is not None:
                        cv2.imshow('stacked_img', stitched_img_disp)
                    if curr_row is not None:
                        cv2.imshow('curr_row', curr_row)

                    k = cv2.waitKey(1 - pause_after_frame)
                    if k == 27:
                        # finalize a partially written video before exiting
                        if video_out is not None:
                            video_out.release()
                        sys.exit(0)
                    elif k == 32:
                        pause_after_frame = 1 - pause_after_frame

                if max_col >= n_cols:
                    break

                min_col = max_col

            if params.method == 0:
                if stitched_img is None:
                    stitched_img = curr_row
                else:
                    stitched_img = np.concatenate((stitched_img, curr_row), axis=0)

            if max_row >= n_rows:
                break

            min_row = max_row

        if eval_mode:
            labels_img_fname = os.path.join(
                params.labels_path, img_fname_no_ext + '.{}'.format(params.labels_ext))
            gt_labels_orig = imageio.imread(labels_img_fname)
            if gt_labels_orig is None:
                raise SystemError(
                    'Labels image could not be read from: {}'.format(
                        labels_img_fname))

            if len(gt_labels_orig.shape) == 3:
                gt_labels = np.squeeze(gt_labels_orig[:, :, 0])
            else:
                gt_labels = gt_labels_orig

            # NOTE(review): with params.normalize_patches the stitched image holds
            # scaled gray levels rather than class ids; confirm eval is meant for
            # that mode.
            seg_labels = np.squeeze(stitched_img[:, :, 0])

            pix_acc = eval_segm.pixel_accuracy(seg_labels, gt_labels, class_ids)
            _acc, mean_acc = eval_segm.mean_accuracy(seg_labels, gt_labels,
                                                     class_ids, return_acc=1)
            _IU, mean_IU = eval_segm.mean_IU(seg_labels, gt_labels, class_ids,
                                             return_iu=1)
            fw_IU = eval_segm.frequency_weighted_IU(seg_labels, gt_labels,
                                                    class_ids)

            avg_pix_acc += (pix_acc - avg_pix_acc) / n_done
            avg_mean_acc += (mean_acc - avg_mean_acc) / n_done
            avg_mean_IU += (mean_IU - avg_mean_IU) / n_done
            avg_fw_IU += (fw_IU - avg_fw_IU) / n_done

            # class 0 is excluded from the "ice" averages
            mean_acc_ice = np.mean(list(_acc.values())[1:])
            avg_mean_acc_ice += (mean_acc_ice - avg_mean_acc_ice) / n_done
            try:
                mean_acc_ice_1 = _acc[1]
                avg_mean_acc_ice_1 += (mean_acc_ice_1 - avg_mean_acc_ice_1) / (
                    n_done - skip_mean_acc_ice_1)
            except KeyError:
                print('\nskip_mean_acc_ice_1: {}'.format(img_id))
                skip_mean_acc_ice_1 += 1
            try:
                mean_acc_ice_2 = _acc[2]
                avg_mean_acc_ice_2 += (mean_acc_ice_2 - avg_mean_acc_ice_2) / (
                    n_done - skip_mean_acc_ice_2)
            except KeyError:
                print('\nskip_mean_acc_ice_2: {}'.format(img_id))
                skip_mean_acc_ice_2 += 1

            mean_IU_ice = np.mean(list(_IU.values())[1:])
            avg_mean_IU_ice += (mean_IU_ice - avg_mean_IU_ice) / n_done
            try:
                mean_IU_ice_1 = _IU[1]
                avg_mean_IU_ice_1 += (mean_IU_ice_1 - avg_mean_IU_ice_1) / (
                    n_done - skip_mean_IU_ice_1)
            except KeyError:
                print('\nskip_mean_IU_ice_1: {}'.format(img_id))
                skip_mean_IU_ice_1 += 1
            try:
                mean_IU_ice_2 = _IU[2]
                avg_mean_IU_ice_2 += (mean_IU_ice_2 - avg_mean_IU_ice_2) / (
                    n_done - skip_mean_IU_ice_2)
            except KeyError:
                print('\nskip_mean_IU_ice_2: {}'.format(img_id))
                skip_mean_IU_ice_2 += 1

            log_txt = '\nDone {:d}/{:d} frames '.format(n_done, _n_frames)
            log_txt += "pix_acc: {:.5f} mean_acc: {:.5f} mean_IU: {:.5f} fw_IU: {:.5f}" \
                       " avg_acc_ice: {:.5f} avg_acc_ice_1: {:.5f} avg_acc_ice_2: {:.5f}" \
                       " avg_IU_ice: {:.5f} avg_IU_ice_1: {:.5f} avg_IU_ice_2: {:.5f}".format(
                avg_pix_acc, avg_mean_acc, avg_mean_IU, avg_fw_IU,
                avg_mean_acc_ice, avg_mean_acc_ice_1, avg_mean_acc_ice_2,
                avg_mean_IU_ice, avg_mean_IU_ice_1, avg_mean_IU_ice_2,
            )
            print_and_write(log_txt, log_fname)
        else:
            sys.stdout.write('\rDone {:d}/{:d} frames'.format(n_done, _n_frames))
            sys.stdout.flush()

        if not params.normalize_patches:
            # patches hold class ids: scale to gray levels for output
            seg_img = (stitched_img * label_diff).astype(np.uint8)
        else:
            seg_img = stitched_img

        if params.stacked:
            if eval_mode:
                labels_img = (gt_labels_orig * label_diff).astype(np.uint8)
                if labels_img.ndim == 2:
                    # single-channel GT: replicate so it can be concatenated
                    # with the 3-channel source image
                    labels_img = np.stack((labels_img, labels_img, labels_img),
                                          axis=2)
                stitched = np.concatenate((src_img, labels_img), axis=1)
            else:
                stitched = src_img
            out_img = np.concatenate((stitched, seg_img), axis=1)
        else:
            out_img = seg_img

        if write_to_video:
            out_img = resize_ar(out_img, params.width, params.height)
            video_out.write(out_img)
        else:
            if params.resize_factor != 1:
                out_img = cv2.resize(out_img, (0, 0), fx=params.resize_factor,
                                     fy=params.resize_factor)
            stacked_img_path = os.path.join(
                params.stitched_seq_path,
                '{}.{}'.format(img_fname_no_ext, params.out_ext))
            cv2.imwrite(stacked_img_path, out_img)

        n_patches += out_id

    sys.stdout.write('\n')
    sys.stdout.flush()
    sys.stdout.write('Total patches processed: {}\n'.format(n_patches))

    if params.show_img:
        cv2.destroyAllWindows()
    if write_to_video:
        video_out.release()

    if params.del_patch_seq:
        print('Removing patch folder {}'.format(params.patch_seq_path))
        shutil.rmtree(params.patch_seq_path)

    if eval_mode:
        log_txt = "pix_acc\t mean_acc\t mean_IU\t fw_IU\n{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\n".format(
            avg_pix_acc, avg_mean_acc, avg_mean_IU, avg_fw_IU)
        log_txt += "mean_acc_ice\t mean_acc_ice_1\t mean_acc_ice_2\n{:.5f}\t{:.5f}\t{:.5f}\n".format(
            avg_mean_acc_ice, avg_mean_acc_ice_1, avg_mean_acc_ice_2)
        log_txt += "mean_IU_ice\t mean_IU_ice_1\t mean_IU_ice_2\n{:.5f}\t{:.5f}\t{:.5f}\n".format(
            avg_mean_IU_ice, avg_mean_IU_ice_1, avg_mean_IU_ice_2)
        print_and_write(log_txt, log_fname)

    print_and_write('Saved log to: {}'.format(log_fname), log_fname)
    print_and_write(
        'Read patch images from: {}'.format(params.patch_seq_path), log_fname)
def _file_stems(paths):
    """Return the base names of ``paths`` with their extensions stripped."""
    return [os.path.splitext(os.path.basename(p))[0] for p in paths]


def _blend_with_src(vis_img, src_img):
    """Blend ``vis_img`` 50/50 over ``src_img``; pixels that are black in ``vis_img`` keep ``src_img``."""
    mask_binary = cv2.cvtColor(vis_img, cv2.COLOR_BGR2GRAY) == 0
    blended = (0.5 * src_img + 0.5 * vis_img).astype(np.uint8)
    blended[mask_binary] = src_img[mask_binary]
    return blended


def _mean_over_ids(per_class, base_ids):
    """Return the mean of ``per_class[i]`` over the ids in ``base_ids`` present in the dict, or None if none are."""
    # look values up by class id rather than by position in per_class.values():
    # classes absent from the current frame are missing from the dict
    values = [per_class[i] for i in np.atleast_1d(base_ids).tolist()
              if i in per_class]
    if not values:
        return None
    return np.mean(values)


def _update_running_mean(avg_dict, skip_dict, name, value, frame_idx):
    """Fold ``value`` into the incremental mean ``avg_dict[name]``.

    ``frame_idx`` is the 0-based position of the current frame within the
    processed range and ``skip_dict[name]`` counts the earlier frames that had
    no value, so after this call the mean covers
    ``frame_idx - skip_dict[name] + 1`` values. A ``None`` value is recorded
    as a skip instead.

    :return: True if the value was folded in, False if it was skipped
    """
    if value is None:
        skip_dict[name] += 1
        return False
    n_values = frame_idx - skip_dict[name] + 1
    avg_dict[name] += (value - avg_dict[name]) / n_values
    return True


def _read_multi_sequence_data(params):
    """Collect image / label / segmentation file lists over all sequences in ``params.vis_split``.

    :param VisParams params:
    :return: (src_files, src_labels_list, seg_labels_list, total_frames, eval_mode);
        ``src_labels_list`` is None when ``params.no_labels`` is set and
        ``eval_mode`` is True only when segmentation results were read
    :raises AssertionError: on a missing vis_split, an unsupported dataset,
        image/label filename mismatch or a GT/seg frame-count mismatch
    """
    if not params.vis_split:
        raise AssertionError("vis_split must be provided for CTC")

    # some repeated code here to allow better IntelliSense
    dataset = params.dataset.lower()
    if dataset == 'ctc':
        from new_deeplab.datasets.build_ctc_data import CTCInfo
        db_splits = CTCInfo.DBSplits().__dict__
        sequences = CTCInfo.sequences
    elif dataset in ('ipsc', 'ipsc_2_class', 'ipsc_5_class'):
        from new_deeplab.datasets.ipsc_info import IPSCInfo
        db_splits = IPSCInfo.DBSplits().__dict__
        sequences = IPSCInfo.sequences
    elif dataset == 'ipsc_patches':
        from new_deeplab.datasets.ipsc_info import IPSCPatchesInfo
        db_splits = IPSCPatchesInfo.DBSplits().__dict__
        sequences = IPSCPatchesInfo.sequences
    else:
        raise AssertionError(
            'multi_sequence_db {} is not supported yet'.format(
                params.dataset))

    src_files = []
    seg_labels_list = []
    src_labels_list = None if params.no_labels else []
    total_frames = 0
    eval_mode = False

    for seq_id in db_splits[params.vis_split]:
        seq_name, _ = sequences[seq_id]
        images_path = os.path.join(params.images_path, seq_name)
        if params.no_labels:
            labels_path = ''
        else:
            labels_path = os.path.join(params.labels_path, seq_name)

        _src_files, _src_labels_list, _total_frames = read_data(
            images_path, params.images_ext, labels_path, params.labels_ext)
        _src_filenames = _file_stems(_src_files)

        if not params.no_labels and _file_stems(_src_labels_list) != _src_filenames:
            raise AssertionError("mismatch between image and label filenames")

        if params.seg_path and params.seg_ext:
            seg_path = os.path.join(params.seg_path, seq_name)
            _, _seg_labels_list, _seg_total_frames = read_data(
                labels_path=seg_path, labels_ext=params.seg_ext,
                labels_type='seg')

            if _seg_total_frames != _total_frames:
                if params.seg_on_subset and _seg_total_frames < _total_frames:
                    # restrict images / labels to the frames that have seg results;
                    # first occurrence wins, matching list.index semantics
                    name_to_idx = {}
                    for i, name in enumerate(_src_filenames):
                        name_to_idx.setdefault(name, i)
                    # KeyError here means a seg file has no matching source image
                    matching_ids = [name_to_idx[k]
                                    for k in _file_stems(_seg_labels_list)]

                    _src_files = [_src_files[i] for i in matching_ids]
                    if not params.no_labels:
                        _src_labels_list = [_src_labels_list[i]
                                            for i in matching_ids]
                    _total_frames = _seg_total_frames
                else:
                    raise AssertionError(
                        'Mismatch between no. of frames in GT and seg labels: {} and {}'
                        .format(_total_frames, _seg_total_frames))

            seg_labels_list += _seg_labels_list
            eval_mode = True

        src_files += _src_files
        if not params.no_labels:
            src_labels_list += _src_labels_list
        total_frames += _total_frames

    return src_files, src_labels_list, seg_labels_list, total_frames, eval_mode


def run(params):
    """Visualize segmentation results and, when available, evaluate them against GT labels.

    Reads source images, optional ground-truth label masks and optional
    segmentation outputs (``params.seg_path``), stitches them side by side for
    display / saving, and in evaluation mode accumulates per-frame pixel
    accuracy, mean accuracy, mean IU and frequency-weighted IU plus running
    per-class means. Progress and summary metrics are written to a log file
    under ``params.save_path``.

    :param VisParams params: visualization settings; ``params.end_id`` and
        ``params.save_path`` may be filled in by this function
    :return: None
    :raises SystemError: if an image cannot be read or GT/seg frame counts differ
    :raises AssertionError: on invalid multi-sequence configuration
    """
    seg_labels_list = None
    if params.multi_sequence_db:
        (src_files, src_labels_list, seg_labels_list, total_frames,
         eval_mode) = _read_multi_sequence_data(params)
        if eval_mode and params.no_labels:
            # evaluation compares seg results against GT labels, which were not loaded
            raise AssertionError(
                'evaluating seg results requires labels: disable no_labels')
    else:
        src_files, src_labels_list, total_frames = read_data(
            params.images_path, params.images_ext, params.labels_path,
            params.labels_ext)

        eval_mode = False
        if params.labels_path and params.seg_path and params.seg_ext:
            _, seg_labels_list, seg_total_frames = read_data(
                labels_path=params.seg_path, labels_ext=params.seg_ext,
                labels_type='seg')
            if seg_total_frames != total_frames:
                raise SystemError(
                    'Mismatch between no. of frames in GT and seg labels: {} and {}'
                    .format(total_frames, seg_total_frames))
            eval_mode = True

    # an end_id before start_id means "process up to the last frame"
    if params.end_id < params.start_id:
        params.end_id = total_frames - 1

    classes, composite_classes = read_class_info(params.class_info_path)
    n_classes = len(classes)
    class_ids = list(range(n_classes))
    class_id_to_color = {i: k[1] for i, k in enumerate(classes)}
    all_classes = [k[0] for k in classes + composite_classes]

    if not params.save_path:
        vis_root = params.seg_path if eval_mode else params.images_path
        params.save_path = os.path.join(os.path.dirname(vis_root), 'vis')
    if not os.path.isdir(params.save_path):
        os.makedirs(params.save_path)
    if params.stitch and params.save_stitched:
        print('Saving visualization images to: {}'.format(params.save_path))

    log_fname = os.path.join(params.save_path,
                             'vis_log_{:s}.txt'.format(getDateTime()))
    print('Saving log to: {}'.format(log_fname))

    save_path_parent = os.path.dirname(params.save_path)
    templ = '{}_{}'.format(
        os.path.basename(save_path_parent),
        os.path.basename(os.path.dirname(save_path_parent)))
    print('templ: {}'.format(templ))

    n_frames = params.end_id - params.start_id + 1

    # per-frame metrics, indexed by the frame's position within [start_id, end_id]
    pix_acc = np.zeros((n_frames, ))
    mean_acc = np.zeros((n_frames, ))
    mean_IU = np.zeros((n_frames, ))
    fw_IU = np.zeros((n_frames, ))

    print_diff = max(1, int(n_frames * 0.01))

    # running per-class means and the number of frames each class was absent from
    avg_mean_acc = {c: 0 for c in all_classes}
    avg_mean_IU = {c: 0 for c in all_classes}
    skip_mean_acc = {c: 0 for c in all_classes}
    skip_mean_IU = {c: 0 for c in all_classes}

    _pause = 1
    labels_img = None

    for frame_idx, img_id in enumerate(range(params.start_id, params.end_id + 1)):
        stitched = []
        src_img_fname = src_files[img_id]
        seq_name = os.path.basename(os.path.dirname(src_img_fname))
        img_fname_no_ext = os.path.splitext(os.path.basename(src_img_fname))[0]

        src_img = None
        border_img = None
        if params.stitch or params.show_img:
            src_img = cv2.imread(src_img_fname)
            if src_img is None:
                raise SystemError(
                    'Source image could not be read from: {}'.format(
                        src_img_fname))
            try:
                _, src_width, _ = src_img.shape
            except ValueError as e:
                print('src_img_fname: {}'.format(src_img_fname))
                print('src_img: {}'.format(src_img))
                print('src_img.shape: {}'.format(src_img.shape))
                print('error: {}'.format(e))
                sys.exit(1)

            if not params.blended:
                stitched.append(src_img)

            # 5-pixel-wide white strip placed between stitched panels
            border_img = np.full_like(src_img, 255)[:, :5, ...]

        if not params.no_labels:
            labels_img_fname = src_labels_list[img_id]
            labels_img_orig = cv2.imread(labels_img_fname)
            if labels_img_orig is None:
                raise SystemError(
                    'Labels image could not be read from: {}'.format(
                        labels_img_fname))
            _, src_width = labels_img_orig.shape[:2]

            if params.show_img:
                cv2.imshow('labels_img_orig', labels_img_orig)

            labels_img_orig, label_img_raw, _ = remove_fuzziness_in_mask(
                labels_img_orig, n_classes, class_id_to_color, fuzziness=5,
                check_equality=0)
            labels_img = np.copy(labels_img_orig)

            if params.stitch:
                if params.blended:
                    labels_img_vis = _blend_with_src(labels_img, src_img)
                else:
                    labels_img_vis = labels_img
                if params.add_border:
                    stitched.append(border_img)
                stitched.append(labels_img_vis)

        if eval_mode:
            seg_img_fname = seg_labels_list[img_id]
            seg_img = cv2.imread(seg_img_fname)
            if seg_img is None:
                raise SystemError(
                    'Segmentation image could not be read from: {}'.format(
                        seg_img_fname))
            if len(seg_img.shape) == 3:
                seg_img = np.squeeze(seg_img[:, :, 0])

            # a seg image 2x / 3x as wide as the source is presumably a stitched
            # visualization: keep only its rightmost src_width columns
            _, seg_width = seg_img.shape
            if seg_width in (2 * src_width, 3 * src_width):
                seg_img = seg_img[:, seg_width - src_width:]

            pix_acc[frame_idx] = eval.pixel_accuracy(
                seg_img, label_img_raw, class_ids)
            _acc, mean_acc[frame_idx] = eval.mean_accuracy(
                seg_img, label_img_raw, class_ids, return_acc=1)
            _IU, mean_IU[frame_idx] = eval.mean_IU(
                seg_img, label_img_raw, class_ids, return_iu=1)
            fw_IU[frame_idx], _ = eval.frequency_weighted_IU(
                seg_img, label_img_raw, class_ids, return_freq=1)

            # composite classes average the metrics of their base classes
            for _class_name, _, base_ids in composite_classes:
                _update_running_mean(avg_mean_acc, skip_mean_acc, _class_name,
                                     _mean_over_ids(_acc, base_ids), frame_idx)
                _update_running_mean(avg_mean_IU, skip_mean_IU, _class_name,
                                     _mean_over_ids(_IU, base_ids), frame_idx)

            # classes missing from the per-class dicts are skipped for this frame
            for _class_id, _class_data in enumerate(classes):
                _class_name = _class_data[0]
                if not _update_running_mean(avg_mean_acc, skip_mean_acc,
                                            _class_name, _acc.get(_class_id),
                                            frame_idx):
                    print('\nskip_mean_acc {}: {}'.format(_class_name, img_id))
                if not _update_running_mean(avg_mean_IU, skip_mean_IU,
                                            _class_name, _IU.get(_class_id),
                                            frame_idx):
                    print('\nskip_mean_IU {}: {}'.format(_class_name, img_id))

            seg_img_vis = raw_seg_to_rgb(seg_img, class_id_to_color)
            if params.stitch and params.stitch_seg:
                if params.blended:
                    seg_img_vis = _blend_with_src(seg_img_vis, src_img)
                if params.add_border:
                    stitched.append(border_img)
                stitched.append(seg_img_vis)

            if not params.stitch and params.show_img:
                cv2.imshow('seg_img', seg_img_vis)

        if params.stitch:
            stitched = np.concatenate(stitched, axis=1)
            if params.save_stitched:
                seg_save_path = os.path.join(
                    params.save_path,
                    '{}_{}.{}'.format(seq_name, img_fname_no_ext, params.out_ext))
                cv2.imwrite(seg_save_path, stitched)
            if params.show_img:
                cv2.imshow('stitched', stitched)
        elif params.show_img:
            cv2.imshow('src_img', src_img)
            if params.labels_path and labels_img is not None:
                cv2.imshow('labels_img', labels_img)

        if params.show_img:
            # paused -> waitKey(0) blocks for a key; unpaused -> waitKey(1)
            k = cv2.waitKey(1 - _pause)
            if k == 27:  # Esc quits
                sys.exit(0)
            elif k == 32:  # space toggles pause
                _pause = 1 - _pause

        img_done = frame_idx + 1
        if img_done % print_diff == 0:
            log_txt = 'Done {:5d}/{:5d} frames'.format(img_done, n_frames)
            if eval_mode:
                log_txt = '{:s} pix_acc: {:.5f} mean_acc: {:.5f} mean_IU: {:.5f} fw_IU: {:.5f}'.format(
                    log_txt, pix_acc[frame_idx], mean_acc[frame_idx],
                    mean_IU[frame_idx], fw_IU[frame_idx])
                log_txt += ''.join(' acc {}: {:.5f} '.format(c, avg_mean_acc[c])
                                   for c in all_classes)
                log_txt += ''.join(' IU {}: {:.5f} '.format(c, avg_mean_IU[c])
                                   for c in all_classes)
            print_and_write(log_txt, log_fname)

    sys.stdout.write('\n')
    sys.stdout.flush()

    if eval_mode:
        log_txt = "pix_acc\t mean_acc\t mean_IU\t fw_IU\n{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\n".format(
            np.mean(pix_acc), np.mean(mean_acc), np.mean(mean_IU), np.mean(fw_IU))
        log_txt += ''.join(' mean_acc {}\t'.format(c) for c in all_classes) + '\n'
        log_txt += ''.join(' {:.5f}\t'.format(avg_mean_acc[c]) for c in all_classes) + '\n'
        log_txt += ''.join(' mean_IU {}\t'.format(c) for c in all_classes) + '\n'
        log_txt += ''.join(' {:.5f}\t'.format(avg_mean_IU[c]) for c in all_classes) + '\n'
        print_and_write(log_txt, log_fname)

        # NOTE(review): the 'recall' row holds mean accuracy and the 'precision'
        # row holds mean IU; labels kept as-is for downstream log parsers
        class_names = [c[0] for c in classes]
        log_txt = templ + '\n\t'
        log_txt += ''.join('{}\t '.format(name) for name in class_names)
        log_txt += 'all_classes\t all_classes(fw)\n'

        log_txt += 'recall\t'
        log_txt += ''.join('{:.5f}\t'.format(avg_mean_acc[name]) for name in class_names)
        log_txt += '{:.5f}\t{:.5f}\n'.format(np.mean(mean_acc), np.mean(pix_acc))

        log_txt += 'precision\t'
        log_txt += ''.join('{:.5f}\t'.format(avg_mean_IU[name]) for name in class_names)
        log_txt += '{:.5f}\t{:.5f}\n'.format(np.mean(mean_IU), np.mean(fw_IU))

        print_and_write(log_txt, log_fname)

    print('Wrote log to: {}'.format(log_fname))