def eval_seq(opt, dataloader, data_type, result_filename, save_dir=None, show_image=True, frame_rate=30):
    """Run the JDE tracker over one image sequence and collect per-frame results.

    :param opt: options namespace; must provide ``min_box_area``.
    :param dataloader: iterable yielding ``(path, img, img0)`` where ``img`` is the
        preprocessed network input (numpy) and ``img0`` the original BGR frame.
    :param data_type: result-format key forwarded to ``write_results``.
    :param result_filename: path of the text file the tracking results are written to.
    :param save_dir: if given, annotated frames are written there as ``%05d.jpg``.
    :param show_image: show each annotated frame in an OpenCV window.
    :param frame_rate: sequence frame rate, used to size the tracker's track buffer.
    :return: ``(frame_id, average_time, calls)`` — number of processed frames and
        timing statistics from the Timer.
    """
    if save_dir:
        mkdir_if_missing(save_dir)
    tracker = JDETracker(opt, frame_rate=frame_rate)
    timer = Timer()
    results = []
    frame_id = 0
    for path, img, img0 in dataloader:
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(
                frame_id, 1. / max(1e-5, timer.average_time)))

        # run tracking
        timer.tic()
        blob = torch.from_numpy(img).cuda().unsqueeze(0)
        online_targets = tracker.update(blob, img0)
        online_tlwhs = []
        online_ids = []
        for t in online_targets:
            tlwh = t.tlwh
            tid = t.track_id
            # Discard boxes that are wider than tall (w/h > 1.6): for pedestrian
            # tracking such boxes are almost certainly false positives.
            vertical = tlwh[2] / tlwh[3] > 1.6
            if tlwh[2] * tlwh[3] > opt.min_box_area and not vertical:
                online_tlwhs.append(tlwh)
                online_ids.append(tid)
        timer.toc()

        # save results (frame ids in the output file are 1-based)
        results.append((frame_id + 1, online_tlwhs, online_ids))
        if show_image or save_dir is not None:
            # BUGFIX: guard the fps division like the logging line above does;
            # timer.average_time can be 0 and would raise ZeroDivisionError.
            online_im = vis.plot_tracking(img0, online_tlwhs, online_ids,
                                          frame_id=frame_id,
                                          fps=1. / max(1e-5, timer.average_time))
        if show_image:
            cv2.imshow('online_im', online_im)
            cv2.waitKey(1)
        if save_dir is not None:
            cv2.imwrite(os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), online_im)
        frame_id += 1

    # save results
    write_results(result_filename, results, data_type)
    return frame_id, timer.average_time, timer.calls
def eval_seq(opt, data_loader, data_type, result_f_name, save_dir=None, show_image=True, frame_rate=30):
    """Run multi-class tracking over one sequence and collect per-class results.

    :param opt: options namespace; must provide ``num_classes`` and ``min_box_area``.
    :param data_loader: iterable yielding ``(path, img, img0)`` where ``img`` is the
        preprocessed network input (numpy) and ``img0`` the original BGR frame.
    :param data_type: result-format key forwarded to ``write_results_dict``.
    :param result_f_name: path of the text file the tracking results are written to.
    :param save_dir: if given, annotated frames are written there as ``%05d.jpg``.
    :param show_image: show each annotated frame in an OpenCV window.
    :param frame_rate: sequence frame rate, used to size the tracker's track buffer.
    :return: ``(frame_id, average_time, calls)`` — number of processed frames and
        timing statistics from the Timer.
    """
    if save_dir:
        mkdir_if_missing(save_dir)
    tracker = JDETracker(opt, frame_rate=frame_rate)
    timer = Timer()
    results_dict = defaultdict(list)
    frame_id = 0  # frame index within the sequence
    for path, img, img0 in data_loader:
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(
                frame_id, 1. / max(1e-5, timer.average_time)))

        # --- run tracking
        timer.tic()
        blob = torch.from_numpy(img).cuda().unsqueeze(0)

        # --- core call: update tracker state and get live tracks per class
        online_targets_dict = tracker.update(blob, img0)

        # aggregate this frame's results per detection class
        online_tlwhs_dict = defaultdict(list)
        online_ids_dict = defaultdict(list)
        for cls_id in range(opt.num_classes):  # process each object class
            online_targets = online_targets_dict[cls_id]
            for track in online_targets:
                tlwh = track.tlwh
                t_id = track.track_id
                # vertical = tlwh[2] / tlwh[3] > 1.6  # aspect-ratio filter disabled:
                # it only makes sense for pedestrian-only tracking.
                if tlwh[2] * tlwh[3] > opt.min_box_area:  # and not vertical:
                    online_tlwhs_dict[cls_id].append(tlwh)
                    online_ids_dict[cls_id].append(t_id)
        timer.toc()

        # record this frame's results (frame ids in the output file are 1-based)
        for cls_id in range(opt.num_classes):
            results_dict[cls_id].append(
                (frame_id + 1, online_tlwhs_dict[cls_id], online_ids_dict[cls_id]))

        # draw this frame's results (skip frame 0: no timing stats yet)
        if show_image or save_dir is not None:
            if frame_id > 0:
                # BUGFIX: guard the fps division like the logging line above does;
                # timer.average_time can be 0 and would raise ZeroDivisionError.
                online_im = vis.plot_tracks(image=img0,
                                            tlwhs_dict=online_tlwhs_dict,
                                            obj_ids_dict=online_ids_dict,
                                            num_classes=opt.num_classes,
                                            frame_id=frame_id,
                                            fps=1.0 / max(1e-5, timer.average_time))

        if frame_id > 0:  # show / save intermediate results
            if show_image:
                cv2.imshow('online_im', online_im)
            if save_dir is not None:
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), online_im)

        # frame done, advance the frame counter
        frame_id += 1

    # write final results
    write_results_dict(result_f_name, results_dict, data_type)
    return frame_id, timer.average_time, timer.calls
def eval_seq(opt, data_loader, data_type, result_f_name, save_dir=None, show_image=True, frame_rate=30, mode='track'):
    """Run multi-class tracking or plain detection over one sequence.

    :param opt: options namespace; must provide ``device``, ``num_classes`` and
        ``min_box_area``.
    :param data_loader: iterable yielding ``(path, img, img_0)`` where ``img`` is the
        preprocessed network input (numpy) and ``img_0`` the original BGR frame.
    :param data_type: result-format key forwarded to ``write_results_dict``.
    :param result_f_name: path of the text file the tracking results are written to.
    :param save_dir: if given, annotated frames are written there as ``%05d.jpg``.
    :param show_image: show each annotated frame in an OpenCV window.
    :param frame_rate: sequence frame rate, used to size the tracker's track buffer.
    :param mode: ``'track'`` or ``'detect'``.
    :raises ValueError: if ``mode`` is neither ``'track'`` nor ``'detect'``.
    :return: ``(frame_id, average_time, calls)`` — number of processed frames and
        timing statistics from the Timer.
    """
    if save_dir:
        mkdir_if_missing(save_dir)
    tracker = JDETracker(opt, frame_rate=frame_rate)
    timer = Timer()
    results_dict = defaultdict(list)
    frame_id = 0  # frame index within the sequence
    for path, img, img_0 in data_loader:
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(
                frame_id, 1. / max(1e-5, timer.average_time)))

        # --- run tracking
        timer.tic()

        # blob = torch.from_numpy(img).cuda().unsqueeze(0)
        blob = torch.from_numpy(img).to(opt.device).unsqueeze(0)

        if mode == 'track':  # process tracking
            # --- track updates of each frame
            online_targets_dict = tracker.update(blob, img_0)

            # aggregate this frame's results per detection class
            online_tlwhs_dict = defaultdict(list)
            online_ids_dict = defaultdict(list)
            for cls_id in range(opt.num_classes):  # process each object class
                online_targets = online_targets_dict[cls_id]
                for track in online_targets:
                    tlwh = track.tlwh
                    t_id = track.track_id
                    # vertical = tlwh[2] / tlwh[3] > 1.6  # aspect-ratio filter disabled:
                    # it only makes sense for pedestrian-only tracking.
                    if tlwh[2] * tlwh[3] > opt.min_box_area:  # and not vertical:
                        online_tlwhs_dict[cls_id].append(tlwh)
                        online_ids_dict[cls_id].append(t_id)
            timer.toc()

            # record this frame's results (frame ids in the output file are 1-based)
            for cls_id in range(opt.num_classes):
                results_dict[cls_id].append((frame_id + 1,
                                             online_tlwhs_dict[cls_id],
                                             online_ids_dict[cls_id]))

            # draw this frame's results (skip frame 0: no timing stats yet)
            if show_image or save_dir is not None:
                if frame_id > 0:
                    # BUGFIX: guard the fps division exactly like the detect branch
                    # below; timer.average_time can be 0 → ZeroDivisionError.
                    online_im: ndarray = vis.plot_tracks(image=img_0,
                                                         tlwhs_dict=online_tlwhs_dict,
                                                         obj_ids_dict=online_ids_dict,
                                                         num_classes=opt.num_classes,
                                                         frame_id=frame_id,
                                                         fps=1.0 / max(1e-5, timer.average_time))
        elif mode == 'detect':  # process detections
            # update detection results of this frame (or image)
            dets_dict = tracker.update_detections(blob, img_0)
            timer.toc()

            # plot detection results
            if show_image or save_dir is not None:
                online_im = vis.plot_detects(image=img_0,
                                             dets_dict=dets_dict,
                                             num_classes=opt.num_classes,
                                             frame_id=frame_id,
                                             fps=1.0 / max(1e-5, timer.average_time))
        else:
            # BUGFIX: previously only printed and fell through, which could display
            # or save a stale 'online_im' from an earlier frame; fail fast instead.
            raise ValueError('[Err]: un-recognized mode.')

        if frame_id > 0:  # show / save intermediate results
            if show_image:
                cv2.imshow('online_im', online_im)
            if save_dir is not None:
                cv2.imwrite(os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), online_im)

        # frame done, advance the frame counter
        frame_id += 1

    # write final results
    write_results_dict(result_f_name, results_dict, data_type)
    return frame_id, timer.average_time, timer.calls