Example #1
import os
from collections import defaultdict

import cv2
import numpy as np
import torch
import torch.nn.functional as F

# create_model, load_model, letterbox, mot_decode, _tranpose_and_gather_feat,
# map2orig and plot_detects are assumed to be imported from the MCMOT project.


def test_single(img_path, dev):
    """
    Run single-image detection with an MCMOT network.
    :param img_path: path to the input image
    :param dev: torch device to run inference on
    :return:
    """
    if not os.path.isfile(img_path):
        print('[Err]: invalid image path.')
        return

    # Head dimensions of the net
    heads = {'hm': 5, 'reg': 2, 'wh': 2, 'id': 128}

    # Load the model and move it to the device
    net = create_model(arch='resdcn_18', heads=heads, head_conv=256)
    model_path = '/mnt/diskb/even/MCMOT/exp/mot/default/mcmot_last_det_resdcn_18.pth'
    net = load_model(model=net, model_path=model_path)
    net = net.to(dev)
    net.eval()
    print(net)

    # Read image
    img_0 = cv2.imread(img_path)  # BGR
    assert img_0 is not None, 'Failed to load ' + img_path

    # Padded resize
    h_in, w_in = 608, 1088  # network input height, width; alternative: (320, 640)
    img, _, _, _ = letterbox(img=img_0, height=h_in, width=w_in)

    # Preprocess image: BGR -> RGB and H×W×C -> C×H×W
    img = img[:, :, ::-1].transpose(2, 0, 1)
    img = np.ascontiguousarray(img, dtype=np.float32)
    img /= 255.0

    # Convert to a tensor and move to the device
    blob = torch.from_numpy(img).unsqueeze(0).to(dev)

    with torch.no_grad():
        # Network output
        output = net(blob)[-1]  # the model returns a list of head dicts; take the last

        # Tracking output
        hm = output['hm'].sigmoid_()
        reg = output['reg']
        wh = output['wh']
        id_feature = output['id']
        id_feature = F.normalize(id_feature, dim=1)  # L2-normalize the ReID feature vectors

        # Decode output
        dets, inds, cls_inds_mask = mot_decode(hm, wh, reg, 5, False, 128)

        # Get ReID feature vector by object class
        cls_id_feats = []  # topK feature vectors of each object class
        for cls_id in range(5):  # cls_id starts from 0
            # get inds of each object class
            cls_inds = inds[:, cls_inds_mask[cls_id]]

            # gather feats for each object class
            cls_id_feature = _tranpose_and_gather_feat(id_feature, cls_inds)  # cls_inds: 1×128
            cls_id_feature = cls_id_feature.squeeze(0)  # n × feat_dim
            # .cpu() is a no-op on CPU tensors, so this handles both devices
            cls_id_feature = cls_id_feature.cpu().numpy()
            cls_id_feats.append(cls_id_feature)

        # Convert back to original image coordinate system
        height_0, width_0 = img_0.shape[0], img_0.shape[1]  # H, W of the original image
        dets = map2orig(dets, h_in // 4, w_in // 4, height_0, width_0, 5)  # translate and scale

        # Parse detections of each class
        dets_dict = defaultdict(list)
        for cls_id in range(5):  # cls_id starts from 0
            cls_dets = dets[cls_id]

            # filter out low-confidence detections (score threshold 0.4)
            remain_inds = cls_dets[:, 4] > 0.4
            cls_dets = cls_dets[remain_inds]
            # cls_id_feature = cls_id_feats[cls_id][remain_inds]  # if need re-id
            dets_dict[cls_id] = cls_dets

    # Visualize detection results
    img_draw = plot_detects(img_0, dets_dict, 5, frame_id=0, fps=30.0)
    # cv2.imshow('Detection', img_draw)
    # cv2.waitKey()
    cv2.imwrite('/mnt/diskb/even/MCMOT/results/00000.jpg', img_draw)
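
A minimal driver for test_single, as a hedged sketch: the image path is a placeholder, and the MCMOT modules used above are assumed to be importable.

import torch

if __name__ == '__main__':
    # run on the GPU when one is available, otherwise fall back to the CPU
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    test_single(img_path='test.jpg', dev=device)  # placeholder image path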
Example #2
def eval_seq_and_output_dets(opt,
                             data_loader,
                             data_type,
                             result_f_name,
                             out_dir,
                             save_dir=None,
                             show_image=True):
    """
    :param opt:
    :param data_loader:
    :param data_type:
    :param result_f_name:
    :param out_dir:
    :param save_dir:
    :param show_image:
    :return:
    """
    if save_dir:
        mkdir_if_missing(save_dir)

    # start from an empty output directory
    if os.path.isdir(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

    tracker = JDETracker(opt, frame_rate=30)

    timer = Timer()

    results_dict = defaultdict(list)

    frame_id = 0  # frame index
    for path, img, img_0 in data_loader:
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(
                frame_id, 1. / max(1e-5, timer.average_time)))

        # --- run detection
        timer.tic()
        blob = torch.from_numpy(img).to(opt.device).unsqueeze(0)

        # update detection results of this frame (or image)
        dets_dict = tracker.update_detection(blob, img_0)

        timer.toc()

        # plot detection results
        if show_image or save_dir is not None:
            online_im = vis.plot_detects(image=img_0,
                                         dets_dict=dets_dict,
                                         num_classes=opt.num_classes,
                                         frame_id=frame_id,
                                         fps=1.0 / max(1e-5, timer.average_time))

        if frame_id > 0:
            # optionally show the intermediate result
            if show_image:
                cv2.imshow('online_im', online_im)
            if save_dir is not None:
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
                    online_im)

        # ----- format the detections and write them (txt) to the output directory
        dets_list = format_dets_dict2dets_list(dets_dict,
                                               w=img_0.shape[1],
                                               h=img_0.shape[0])

        # write to the output directory
        out_img_name = os.path.split(path)[-1]
        out_f_name = out_img_name.replace('.jpg', '.txt')
        out_f_path = os.path.join(out_dir, out_f_name)
        with open(out_f_path, 'w', encoding='utf-8') as w_h:
            w_h.write('class prob x y w h total=' + str(len(dets_list)) + '\n')

            for det in dets_list:
                w_h.write('%d %f %f %f %f %f\n' %
                          (det[0], det[1], det[2], det[3], det[4], det[5]))

        # one frame processed, advance frame_id
        frame_id += 1
    print('Wrote detection results for {:d} frames.\n'.format(frame_id))

    # save final results (results_dict stays empty here: only per-image detections are written)
    write_results_dict(result_f_name, results_dict, data_type)

    # return statistics
    return frame_id, timer.average_time, timer.calls
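
The txt files written above have a fixed layout: a 'class prob x y w h total=N' header line, then one 'class prob x y w h' row per detection. A small reader for these files, purely illustrative (the helper name is made up, not part of MCMOT):

def read_dets_txt(txt_path):
    """Parse one detection txt written by eval_seq_and_output_dets."""
    dets = []
    with open(txt_path, 'r', encoding='utf-8') as f:
        next(f)  # skip the 'class prob x y w h total=N' header line
        for line in f:
            fields = line.split()
            # the class id is an int, the other five fields are floats
            dets.append((int(fields[0]),) + tuple(float(v) for v in fields[1:6]))
    return dets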
Example #3
def eval_seq(opt,
             data_loader,
             data_type,
             result_f_name,
             save_dir=None,
             show_image=True,
             frame_rate=30,
             mode='track'):
    """
    :param opt:
    :param data_loader:
    :param data_type:
    :param result_f_name:
    :param save_dir:
    :param show_image:
    :param frame_rate:
    :param mode: track or detect
    :return:
    """
    if save_dir:
        mkdir_if_missing(save_dir)

    tracker = JDETracker(opt, frame_rate=frame_rate)

    timer = Timer()

    results_dict = defaultdict(list)

    frame_id = 0  # frame index
    for path, img, img_0 in data_loader:
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(
                frame_id, 1. / max(1e-5, timer.average_time)))

        # --- run tracking
        timer.tic()
        blob = torch.from_numpy(img).to(opt.device).unsqueeze(0)

        if mode == 'track':  # process tracking
            # --- track updates of each frame
            online_targets_dict = tracker.update_tracking(blob, img_0)

            # aggregate this frame's results
            online_tlwhs_dict = defaultdict(list)
            online_ids_dict = defaultdict(list)
            for cls_id in range(opt.num_classes):
                # process each detection class
                online_targets = online_targets_dict[cls_id]
                for track in online_targets:
                    tlwh = track.tlwh
                    t_id = track.track_id
                    # vertical = tlwh[2] / tlwh[3] > 1.6  # aspect-ratio check: w/h must not exceed 1.6?
                    if tlwh[2] * tlwh[3] > opt.min_box_area:  # and not vertical:
                        online_tlwhs_dict[cls_id].append(tlwh)
                        online_ids_dict[cls_id].append(t_id)

            timer.toc()

            # store this frame's results
            for cls_id in range(opt.num_classes):
                results_dict[cls_id].append(
                    (frame_id + 1, online_tlwhs_dict[cls_id],
                     online_ids_dict[cls_id]))

            # draw this frame's results
            if show_image or save_dir is not None:
                if frame_id > 0:
                    online_im: ndarray = vis.plot_tracks(
                        image=img_0,
                        tlwhs_dict=online_tlwhs_dict,
                        obj_ids_dict=online_ids_dict,
                        num_classes=opt.num_classes,
                        frame_id=frame_id,
                        fps=1.0 / max(1e-5, timer.average_time))

        elif mode == 'detect':  # process detections
            # update detection results of this frame (or image)
            dets_dict = tracker.update_detection(blob, img_0)

            timer.toc()

            # plot detection results
            if show_image or save_dir is not None:
                online_im = vis.plot_detects(image=img_0,
                                             dets_dict=dets_dict,
                                             num_classes=opt.num_classes,
                                             frame_id=frame_id,
                                             fps=1.0 / max(1e-5, timer.average_time))
        else:
            print('[Err]: unrecognized mode.')


        if frame_id > 0:
            # optionally show the intermediate result
            if show_image:
                cv2.imshow('online_im', online_im)
            if save_dir is not None:
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
                    online_im)

        # one frame processed, advance frame_id
        frame_id += 1

    # save final results
    write_results_dict(result_f_name, results_dict, data_type)

    return frame_id, timer.average_time, timer.calls
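
For reference, each per-class record appended to results_dict above is a (frame_id, tlwh_list, track_id_list) tuple. A sketch of a consumer that counts the stored track boxes per class (illustrative only, not project code):

from collections import defaultdict

def count_tracks_per_class(results_dict):
    """Count how many track boxes eval_seq stored for each class."""
    counts = defaultdict(int)
    for cls_id, frame_records in results_dict.items():
        for frame_id, tlwhs, track_ids in frame_records:
            counts[cls_id] += len(track_ids)
    return dict(counts)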
Example #4
def eval_seq(opt,
             data_loader,
             data_type,
             result_f_name,
             save_dir=None,
             show_image=True,
             frame_rate=30,
             mode='track'):
    """
    :param opt:
    :param data_loader:
    :param data_type:
    :param result_f_name:
    :param save_dir:
    :param show_image:
    :param frame_rate:
    :param mode: track or detect
    :return:
    """
    if save_dir:
        mkdir_if_missing(save_dir)

    tracker = MCJDETracker(opt, frame_rate)

    timer = Timer()

    results_dict = defaultdict(list)

    frame_id = 0  # frame index
    for path, img, img0 in data_loader:
        if frame_id % 30 == 0 and frame_id != 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(
                frame_id, 1.0 / max(1e-5, timer.average_time)))

        # --- run tracking
        blob = torch.from_numpy(img).unsqueeze(0).to(opt.device)

        if mode == 'track':  # process tracking
            # ----- track updates of each frame
            timer.tic()

            online_targets_dict = tracker.update_tracking(blob, img0)

            timer.toc()
            # -----

            # collect current frame's result
            online_tlwhs_dict = defaultdict(list)
            online_ids_dict = defaultdict(list)
            online_scores_dict = defaultdict(list)
            for cls_id in range(opt.num_classes):  # process each class id
                online_targets = online_targets_dict[cls_id]
                for track in online_targets:
                    tlwh = track.tlwh
                    t_id = track.track_id
                    score = track.score
                    if tlwh[2] * tlwh[3] > opt.min_box_area:  # and not vertical:
                        online_tlwhs_dict[cls_id].append(tlwh)
                        online_ids_dict[cls_id].append(t_id)
                        online_scores_dict[cls_id].append(score)

            # collect result
            for cls_id in range(opt.num_classes):
                results_dict[cls_id].append(
                    (frame_id + 1, online_tlwhs_dict[cls_id],
                     online_ids_dict[cls_id], online_scores_dict[cls_id]))

            # draw track/detection
            if show_image or save_dir is not None:
                if frame_id > 0:
                    online_im: ndarray = vis.plot_tracks(
                        image=img0,
                        tlwhs_dict=online_tlwhs_dict,
                        obj_ids_dict=online_ids_dict,
                        num_classes=opt.num_classes,
                        frame_id=frame_id,
                        fps=1.0 / max(1e-5, timer.average_time))

        elif mode == 'detect':  # process detections
            timer.tic()

            # update detection results of this frame (or image)
            dets_dict = tracker.update_detection(blob, img0)

            timer.toc()

            # plot detection results
            if show_image or save_dir is not None:
                online_im = vis.plot_detects(image=img0,
                                             dets_dict=dets_dict,
                                             num_classes=opt.num_classes,
                                             frame_id=frame_id,
                                             fps=1.0 / max(1e-5, timer.average_time))
        else:
            print('[Err]: unrecognized mode.')

        if frame_id > 0:
            if show_image:
                cv2.imshow('online_im', online_im)
            if save_dir is not None:
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
                    online_im)

        # update frame id
        frame_id += 1

    # write track/detection results
    write_results_dict(result_f_name, results_dict, data_type)

    return frame_id, timer.average_time, timer.calls
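
A hedged sketch of invoking this version of eval_seq: opt and loader must come from the MCMOT project (the loader yielding (path, img, img0) tuples, as the loop above expects), and the 'mot' format tag is an assumption about what write_results_dict accepts.

# illustrative call; `opt` comes from the project's option parser, `loader`
# is a project dataloader yielding (path, preprocessed_img, original_img)
n_frames, avg_time, n_calls = eval_seq(opt,
                                       loader,
                                       data_type='mot',  # assumed format tag
                                       result_f_name='results.txt',
                                       save_dir='./vis',
                                       show_image=False,
                                       frame_rate=30,
                                       mode='track')
print('%d frames processed, %.3f s/frame on average' % (n_frames, avg_time))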