def test_single(img_path,
                dev,
                model_path='/mnt/diskb/even/MCMOT/exp/mot/default/mcmot_last_det_resdcn_18.pth',
                out_path='/mnt/diskb/even/MCMOT/results/00000.jpg',
                num_classes=5,
                conf_thresh=0.4):
    """
    Run the detection network on a single image and write a visualisation
    of the detections to disk.

    :param img_path: path of the image to process; if it is not an existing
                     file an error is printed and the function returns
    :param dev: device the network and the input tensor are moved to
    :param model_path: checkpoint handed to ``load_model``
    :param out_path: file the rendered detection image is written to
    :param num_classes: number of object classes (size of the 'hm' head and
                        number of per-class detection lists)
    :param conf_thresh: detections whose score (column 4) is not above this
                        value are dropped
    :return: None
    """
    if not os.path.isfile(img_path):
        print('[Err]: invalid image path.')
        return

    # Head dimensions of the net
    heads = {'hm': num_classes, 'reg': 2, 'wh': 2, 'id': 128}

    # Load model and put to device
    net = create_model(arch='resdcn_18', heads=heads, head_conv=256)
    net = load_model(model=net, model_path=model_path)
    net = net.to(dev)
    net.eval()
    print(net)

    # Read image
    img_0 = cv2.imread(img_path)  # BGR
    assert img_0 is not None, 'Failed to load ' + img_path

    # Padded resize to the network input size
    h_in, w_in = 608, 1088  # (608, 1088) (320, 640)
    img, _, _, _ = letterbox(img=img_0, height=h_in, width=w_in)

    # Preprocess image: BGR -> RGB and H×W×C -> C×H×W, scaled to [0, 1]
    img = img[:, :, ::-1].transpose(2, 0, 1)
    img = np.ascontiguousarray(img, dtype=np.float32)
    img /= 255.0

    # Convert to tensor, add the batch dimension and put to device
    blob = torch.from_numpy(img).unsqueeze(0).to(dev)

    with torch.no_grad():
        # Network output (last element of the returned sequence)
        output = net.forward(blob)[-1]

        hm = output['hm'].sigmoid_()
        reg = output['reg']
        wh = output['wh']
        # L2 normalization of the feature vectors
        id_feature = F.normalize(output['id'], dim=1)

        # Decode output
        dets, inds, cls_inds_mask = mot_decode(hm, wh, reg,
                                               num_classes, False, 128)

        # Gather the ReID feature vectors of each object class.
        # NOTE: the visualisation below does not consume them; they are
        # kept for callers that want to extend this routine with re-id.
        cls_id_feats = []
        for cls_id in range(num_classes):  # cls_id starts from 0
            cls_inds = inds[:, cls_inds_mask[cls_id]]
            cls_id_feature = _tranpose_and_gather_feat(id_feature, cls_inds)
            cls_id_feature = cls_id_feature.squeeze(0)  # n × FeatDim
            # .cpu() returns the tensor itself when it already lives on
            # the CPU, so no explicit device check is required.
            cls_id_feats.append(cls_id_feature.cpu().numpy())

        # Convert back to the original image coordinate system
        height_0, width_0 = img_0.shape[0], img_0.shape[1]
        dets = map2orig(dets, h_in // 4, w_in // 4,
                        height_0, width_0, num_classes)

        # Keep, per class, only the detections above the score threshold
        dets_dict = defaultdict(list)
        for cls_id in range(num_classes):
            cls_dets = dets[cls_id]
            remain_inds = cls_dets[:, 4] > conf_thresh
            dets_dict[cls_id] = cls_dets[remain_inds]

        # Visualize detection results
        img_draw = plot_detects(img_0, dets_dict, num_classes,
                                frame_id=0, fps=30.0)
        cv2.imwrite(out_path, img_draw)
def eval_seq_and_output_dets(opt,
                             data_loader,
                             data_type,
                             result_f_name,
                             out_dir,
                             save_dir=None,
                             show_image=True):
    """
    Run detection on every frame of a sequence and write one text file of
    detections per frame into ``out_dir``.

    :param opt: options object; ``opt.device`` and ``opt.num_classes`` are
                read here and the whole object is passed to ``JDETracker``
    :param data_loader: iterable yielding ``(path, img, img_0)`` per frame
    :param data_type: forwarded to ``write_results_dict``
    :param result_f_name: forwarded to ``write_results_dict``
    :param out_dir: directory for the per-frame detection text files; it is
                    deleted and re-created if it already exists
    :param save_dir: if not None, rendered frames are written here
    :param show_image: if True, rendered frames are shown in a window
    :return: ``(frame_id, timer.average_time, timer.calls)`` where
             ``frame_id`` is the number of frames processed
    """
    if save_dir:
        mkdir_if_missing(save_dir)

    # Always start from an empty output directory
    if os.path.isdir(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

    tracker = JDETracker(opt, frame_rate=30)
    timer = Timer()

    # NOTE: nothing is appended to results_dict in this function; it is
    # handed to write_results_dict unchanged at the end.
    results_dict = defaultdict(list)

    frame_id = 0  # frame index
    for path, img, img_0 in data_loader:
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(
                frame_id, 1. / max(1e-5, timer.average_time)))

        # --- run detection
        timer.tic()
        blob = torch.from_numpy(img).to(opt.device).unsqueeze(0)

        # update detection results of this frame (or image)
        dets_dict = tracker.update_detection(blob, img_0)
        timer.toc()

        # plot detection results
        online_im = None
        if show_image or save_dir is not None:
            online_im = vis.plot_detects(
                image=img_0,
                dets_dict=dets_dict,
                num_classes=opt.num_classes,
                frame_id=frame_id,
                fps=1.0 / max(1e-5, timer.average_time))

        # show / save intermediate results (the first frame is skipped)
        if frame_id > 0 and online_im is not None:
            if show_image:
                cv2.imshow('online_im', online_im)
                # imshow only refreshes the window once waitKey runs
                cv2.waitKey(1)
            if save_dir is not None:
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
                    online_im)

        # ----- format the detections and write them as txt into out_dir
        dets_list = format_dets_dict2dets_list(dets_dict,
                                               w=img_0.shape[1],
                                               h=img_0.shape[0])

        # Derive the txt name from the image name by swapping the
        # extension, whatever that extension is (.jpg, .png, .JPG, ...).
        out_img_name = os.path.split(path)[-1]
        out_f_name = os.path.splitext(out_img_name)[0] + '.txt'
        out_f_path = os.path.join(out_dir, out_f_name)

        with open(out_f_path, 'w', encoding='utf-8') as w_h:
            w_h.write('class prob x y w h total=' + str(len(dets_list)) + '\n')
            for det in dets_list:
                w_h.write('%d %f %f %f %f %f\n'
                          % (det[0], det[1], det[2], det[3], det[4], det[5]))

        # one frame done, advance the frame index
        frame_id += 1

    print('Total {:d} detection result output.\n'.format(frame_id))

    # save results
    write_results_dict(result_f_name, results_dict, data_type)

    return frame_id, timer.average_time, timer.calls
def eval_seq(opt,
             data_loader,
             data_type,
             result_f_name,
             save_dir=None,
             show_image=True,
             frame_rate=30,
             mode='track'):
    """
    Run tracking or detection over a sequence with ``JDETracker``,
    optionally show/save the rendered frames, and write the results.

    NOTE(review): another function named ``eval_seq`` appears later in this
    file; if both live in the same module the later definition is the one
    bound to the name -- confirm which one is intended.

    :param opt: options object; ``opt.device``, ``opt.num_classes`` and
                ``opt.min_box_area`` are read here
    :param data_loader: iterable yielding ``(path, img, img_0)`` per frame
    :param data_type: forwarded to ``write_results_dict``
    :param result_f_name: forwarded to ``write_results_dict``
    :param save_dir: if not None, rendered frames are written here
    :param show_image: if True, rendered frames are shown in a window
    :param frame_rate: frame rate handed to the tracker
    :param mode: track or detect; any other value only prints an error for
                 each frame and produces no results
    :return: ``(frame_id, timer.average_time, timer.calls)`` where
             ``frame_id`` is the number of frames processed
    """
    if save_dir:
        mkdir_if_missing(save_dir)

    tracker = JDETracker(opt, frame_rate=frame_rate)
    timer = Timer()
    results_dict = defaultdict(list)

    frame_id = 0  # frame index
    for path, img, img_0 in data_loader:
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(
                frame_id, 1. / max(1e-5, timer.average_time)))

        # Reset per frame so that a frame for which nothing was rendered
        # is never shown or saved (and never raises NameError).
        online_im = None

        # --- run tracking
        timer.tic()
        blob = torch.from_numpy(img).to(opt.device).unsqueeze(0)

        if mode == 'track':  # process tracking
            # --- track updates of each frame
            online_targets_dict = tracker.update_tracking(blob, img_0)

            # collect the results of this frame, per object class
            online_tlwhs_dict = defaultdict(list)
            online_ids_dict = defaultdict(list)
            for cls_id in range(opt.num_classes):
                for track in online_targets_dict[cls_id]:
                    tlwh = track.tlwh
                    t_id = track.track_id
                    # keep only boxes whose area (w * h) is large enough
                    if tlwh[2] * tlwh[3] > opt.min_box_area:
                        online_tlwhs_dict[cls_id].append(tlwh)
                        online_ids_dict[cls_id].append(t_id)

            timer.toc()

            # store the results of this frame (frame numbers are 1-based)
            for cls_id in range(opt.num_classes):
                results_dict[cls_id].append((frame_id + 1,
                                             online_tlwhs_dict[cls_id],
                                             online_ids_dict[cls_id]))

            # draw the results of this frame (the first frame is skipped)
            if (show_image or save_dir is not None) and frame_id > 0:
                online_im = vis.plot_tracks(
                    image=img_0,
                    tlwhs_dict=online_tlwhs_dict,
                    obj_ids_dict=online_ids_dict,
                    num_classes=opt.num_classes,
                    frame_id=frame_id,
                    # same zero-guard as the detect branch
                    fps=1.0 / max(1e-5, timer.average_time))

        elif mode == 'detect':  # process detections
            # update detection results of this frame (or image)
            dets_dict = tracker.update_detection(blob, img_0)
            timer.toc()

            # plot detection results
            if show_image or save_dir is not None:
                online_im = vis.plot_detects(
                    image=img_0,
                    dets_dict=dets_dict,
                    num_classes=opt.num_classes,
                    frame_id=frame_id,
                    fps=1.0 / max(1e-5, timer.average_time))
        else:
            print('[Err]: un-recognized mode.')

        # show / save intermediate results (the first frame is skipped)
        if frame_id > 0 and online_im is not None:
            if show_image:
                cv2.imshow('online_im', online_im)
                # imshow only refreshes the window once waitKey runs
                cv2.waitKey(1)
            if save_dir is not None:
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
                    online_im)

        # one frame done, advance the frame index
        frame_id += 1

    # save results
    write_results_dict(result_f_name, results_dict, data_type)

    return frame_id, timer.average_time, timer.calls
def eval_seq(opt,
             data_loader,
             data_type,
             result_f_name,
             save_dir=None,
             show_image=True,
             frame_rate=30,
             mode='track'):
    """
    Run multi-class tracking or detection over a sequence with
    ``MCJDETracker``, optionally show/save the rendered frames, and write
    the results.

    NOTE(review): an earlier function with the same name exists in this
    file; if both live in the same module this later definition is the one
    bound to the name -- confirm the earlier one is meant to be replaced.

    :param opt: options object; ``opt.device``, ``opt.num_classes`` and
                ``opt.min_box_area`` are read here
    :param data_loader: iterable yielding ``(path, img, img0)`` per frame
    :param data_type: forwarded to ``write_results_dict``
    :param result_f_name: forwarded to ``write_results_dict``
    :param save_dir: if not None, rendered frames are written here
    :param show_image: if True, rendered frames are shown in a window
    :param frame_rate: frame rate handed to the tracker
    :param mode: track or detect; any other value only prints an error for
                 each frame and produces no results
    :return: ``(frame_id, timer.average_time, timer.calls)`` where
             ``frame_id`` is the number of frames processed
    """
    if save_dir:
        mkdir_if_missing(save_dir)

    tracker = MCJDETracker(opt, frame_rate)
    timer = Timer()
    results_dict = defaultdict(list)

    frame_id = 0  # frame index
    for path, img, img0 in data_loader:
        if frame_id % 30 == 0 and frame_id != 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(
                frame_id, 1.0 / max(1e-5, timer.average_time)))

        # Reset per frame so that a frame for which nothing was rendered
        # is never shown or saved (and never raises NameError).
        online_im = None

        # --- run tracking
        blob = torch.from_numpy(img).unsqueeze(0).to(opt.device)

        if mode == 'track':  # process tracking
            # ----- track updates of each frame (only this call is timed)
            timer.tic()
            online_targets_dict = tracker.update_tracking(blob, img0)
            timer.toc()

            # collect current frame's result, per object class
            online_tlwhs_dict = defaultdict(list)
            online_ids_dict = defaultdict(list)
            online_scores_dict = defaultdict(list)
            for cls_id in range(opt.num_classes):
                for track in online_targets_dict[cls_id]:
                    tlwh = track.tlwh
                    t_id = track.track_id
                    score = track.score
                    # keep only boxes whose area (w * h) is large enough
                    if tlwh[2] * tlwh[3] > opt.min_box_area:
                        online_tlwhs_dict[cls_id].append(tlwh)
                        online_ids_dict[cls_id].append(t_id)
                        online_scores_dict[cls_id].append(score)

            # collect result (frame numbers are 1-based)
            for cls_id in range(opt.num_classes):
                results_dict[cls_id].append((frame_id + 1,
                                             online_tlwhs_dict[cls_id],
                                             online_ids_dict[cls_id],
                                             online_scores_dict[cls_id]))

            # draw track/detection (the first frame is skipped)
            if (show_image or save_dir is not None) and frame_id > 0:
                online_im = vis.plot_tracks(
                    image=img0,
                    tlwhs_dict=online_tlwhs_dict,
                    obj_ids_dict=online_ids_dict,
                    num_classes=opt.num_classes,
                    frame_id=frame_id,
                    # same zero-guard as the detect branch
                    fps=1.0 / max(1e-5, timer.average_time))

        elif mode == 'detect':  # process detections
            timer.tic()
            # update detection results of this frame (or image)
            dets_dict = tracker.update_detection(blob, img0)
            timer.toc()

            # plot detection results
            if show_image or save_dir is not None:
                online_im = vis.plot_detects(
                    image=img0,
                    dets_dict=dets_dict,
                    num_classes=opt.num_classes,
                    frame_id=frame_id,
                    fps=1.0 / max(1e-5, timer.average_time))
        else:
            print('[Err]: un-recognized mode.')

        # show / save intermediate results (the first frame is skipped)
        if frame_id > 0 and online_im is not None:
            if show_image:
                cv2.imshow('online_im', online_im)
                # imshow only refreshes the window once waitKey runs
                cv2.waitKey(1)
            if save_dir is not None:
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
                    online_im)

        # update frame id
        frame_id += 1

    # write track/detection results
    write_results_dict(result_f_name, results_dict, data_type)

    return frame_id, timer.average_time, timer.calls