def eval_seq(opt, data_loader, data_type, result_f_name, save_dir=None, show_image=True, frame_rate=30, mode='track'):
    """Run JDETracker over a sequence, in tracking or detection mode.

    For every ``(path, img, img_0)`` item yielded by ``data_loader``, ``img``
    is converted to a tensor on ``opt.device`` with a leading batch dimension
    and fed to ``tracker.update_tracking`` (``mode='track'``) or
    ``tracker.update_detection`` (``mode='detect'``). In track mode the kept
    per-class boxes and ids are accumulated and finally written to
    ``result_f_name`` with ``write_results_dict``. In detect mode nothing is
    accumulated, so the written results are empty.

    The first frame (``frame_id == 0``) is never displayed or saved.

    :param opt: options object, also passed to ``JDETracker``. This function
        reads ``opt.device`` and ``opt.num_classes``, plus ``opt.min_box_area``
        in track mode.
    :param data_loader: iterable of ``(path, img, img_0)``. ``img`` must be
        accepted by ``torch.from_numpy``; ``img_0`` is handed to the tracker
        and the plotting helpers.
    :param data_type: forwarded to ``write_results_dict``.
    :param result_f_name: path of the results file to write.
    :param save_dir: if not None, rendered frames are written there as
        ``{frame_id:05d}.jpg``.
    :param show_image: if True, rendered frames are shown in an OpenCV window.
    :param frame_rate: forwarded to ``JDETracker``.
    :param mode: ``'track'`` or ``'detect'``.
    :raises ValueError: if ``mode`` is neither ``'track'`` nor ``'detect'``.
    :return: ``(number of frames processed, timer.average_time, timer.calls)``
    """
    if mode not in ('track', 'detect'):
        # Fail fast: continuing would leave `online_im` undefined and crash
        # later with an unrelated NameError.
        raise ValueError("[Err]: un-recognized mode '{}', expected 'track' or 'detect'.".format(mode))

    if save_dir:
        mkdir_if_missing(save_dir)

    tracker = JDETracker(opt, frame_rate=frame_rate)
    timer = Timer()
    results_dict = defaultdict(list)  # cls_id -> list of (frame_no, tlwhs, ids)
    need_vis = show_image or save_dir is not None

    frame_id = 0  # 0-based index of the current frame
    for _path, img, img_0 in data_loader:
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(
                frame_id, 1. / max(1e-5, timer.average_time)))

        # --- run tracking / detection (the timed section)
        timer.tic()
        blob = torch.from_numpy(img).to(opt.device).unsqueeze(0)

        online_im = None
        if mode == 'track':
            online_targets_dict = tracker.update_tracking(blob, img_0)

            # Aggregate this frame's tracks per class. Boxes whose area
            # (w * h of the tlwh box) does not exceed opt.min_box_area
            # are dropped.
            online_tlwhs_dict = defaultdict(list)
            online_ids_dict = defaultdict(list)
            for cls_id in range(opt.num_classes):
                for track in online_targets_dict[cls_id]:
                    tlwh = track.tlwh
                    t_id = track.track_id
                    if tlwh[2] * tlwh[3] > opt.min_box_area:
                        online_tlwhs_dict[cls_id].append(tlwh)
                        online_ids_dict[cls_id].append(t_id)
            timer.toc()

            # Store this frame's results; stored frame numbers are 1-based.
            for cls_id in range(opt.num_classes):
                results_dict[cls_id].append(
                    (frame_id + 1, online_tlwhs_dict[cls_id], online_ids_dict[cls_id]))

            if need_vis and frame_id > 0:
                online_im = vis.plot_tracks(image=img_0,
                                            tlwhs_dict=online_tlwhs_dict,
                                            obj_ids_dict=online_ids_dict,
                                            num_classes=opt.num_classes,
                                            frame_id=frame_id,
                                            fps=1.0 / max(1e-5, timer.average_time))
        else:  # mode == 'detect'
            # Detection results of this frame (or image).
            dets_dict = tracker.update_detection(blob, img_0)
            timer.toc()

            if need_vis:
                online_im = vis.plot_detects(image=img_0,
                                             dets_dict=dets_dict,
                                             num_classes=opt.num_classes,
                                             frame_id=frame_id,
                                             fps=1.0 / max(1e-5, timer.average_time))

        # NOTE(review): frame 0 is intentionally skipped for display and
        # saving, matching the previous behavior; reason not visible here.
        if frame_id > 0:
            if show_image:
                cv2.imshow('online_im', online_im)
                # imshow only renders when the GUI event loop is pumped.
                cv2.waitKey(1)
            if save_dir is not None:
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), online_im)

        frame_id += 1

    # Write the accumulated per-class results.
    write_results_dict(result_f_name, results_dict, data_type)
    return frame_id, timer.average_time, timer.calls
def eval_seq_and_output_dets(opt, data_loader, data_type, result_f_name, out_dir, save_dir=None, show_image=True, frame_rate=30):
    """Run per-frame detection and dump each frame's detections to a txt file.

    ``out_dir`` is wiped and recreated first. For every ``(path, img, img_0)``
    item from ``data_loader``, detections are produced by
    ``tracker.update_detection`` and formatted by
    ``format_dets_dict2dets_list``. They are then written to
    ``out_dir/<image stem>.txt``.

    Each txt file starts with a header line
    ``class prob x y w h total=<count>``, followed by one line per detection
    formatted as ``'%d %f %f %f %f %f'``.

    The first frame (``frame_id == 0``) is never displayed or saved as an
    image, but its txt file is written.

    :param opt: options object, also passed to ``JDETracker``. This function
        reads ``opt.device`` and ``opt.num_classes``.
    :param data_loader: iterable of ``(path, img, img_0)``. ``path`` names the
        source image; ``img`` must be accepted by ``torch.from_numpy``.
    :param data_type: forwarded to ``write_results_dict``.
    :param result_f_name: path passed to ``write_results_dict``. Nothing is
        accumulated here, so the written results are empty.
    :param out_dir: directory receiving the per-image txt files. Any existing
        contents are deleted.
    :param save_dir: if not None, rendered frames are written there as
        ``{frame_id:05d}.jpg``.
    :param show_image: if True, rendered frames are shown in an OpenCV window.
    :param frame_rate: forwarded to ``JDETracker`` (previously hard-coded to 30).
    :return: ``(number of frames processed, timer.average_time, timer.calls)``
    """
    if save_dir:
        mkdir_if_missing(save_dir)

    # Start from an empty out_dir so stale txt files never mix with new ones.
    if os.path.isdir(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

    tracker = JDETracker(opt, frame_rate=frame_rate)
    timer = Timer()
    results_dict = defaultdict(list)  # never filled in this function
    need_vis = show_image or save_dir is not None

    frame_id = 0  # 0-based index of the current frame
    for path, img, img_0 in data_loader:
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(
                frame_id, 1. / max(1e-5, timer.average_time)))

        # --- run detection (the timed section)
        timer.tic()
        blob = torch.from_numpy(img).to(opt.device).unsqueeze(0)
        dets_dict = tracker.update_detection(blob, img_0)
        timer.toc()

        # Plot detection results.
        if need_vis:
            online_im = vis.plot_detects(image=img_0,
                                         dets_dict=dets_dict,
                                         num_classes=opt.num_classes,
                                         frame_id=frame_id,
                                         fps=1.0 / max(1e-5, timer.average_time))

        # NOTE(review): frame 0 is skipped for display and saving, matching
        # the previous behavior.
        if frame_id > 0:
            if show_image:
                cv2.imshow('online_im', online_im)
                # imshow only renders when the GUI event loop is pumped.
                cv2.waitKey(1)
            if save_dir is not None:
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), online_im)

        # ----- Format detections and write them as txt to out_dir.
        # The formatter receives the original image width (shape[1]) and
        # height (shape[0]).
        dets_list = format_dets_dict2dets_list(dets_dict, w=img_0.shape[1], h=img_0.shape[0])

        # Replace the image's extension (whatever it is) with .txt.
        img_stem = os.path.splitext(os.path.basename(path))[0]
        out_f_path = os.path.join(out_dir, img_stem + '.txt')
        lines = ['class prob x y w h total=' + str(len(dets_list)) + '\n']
        lines.extend('%d %f %f %f %f %f\n' % (det[0], det[1], det[2], det[3], det[4], det[5])
                     for det in dets_list)
        with open(out_f_path, 'w', encoding='utf-8') as w_h:
            w_h.writelines(lines)

        frame_id += 1

    print('Total {:d} detection result output.\n'.format(frame_id))

    # Write the (empty) final results.
    write_results_dict(result_f_name, results_dict, data_type)
    return frame_id, timer.average_time, timer.calls