Exemple #1
0
def eval_seq(opt,
             dataloader,
             data_type,
             result_filename,
             save_dir=None,
             show_image=True,
             frame_rate=30):
    """Track every frame of a sequence and write the collected results.

    Each item yielded by ``dataloader`` is fed to ``JDETracker.update``.
    Boxes that pass the area and aspect-ratio filters are recorded per
    frame, optionally drawn / displayed / saved, and the full list is
    finally handed to ``write_results``.

    :param opt: options object. ``opt.min_box_area`` is read here; the
        object itself is forwarded to ``JDETracker``.
    :param dataloader: iterable yielding ``(path, img, img0)`` triples.
        ``img`` is converted with ``torch.from_numpy`` and moved to CUDA,
        ``img0`` is passed to the tracker and used for drawing.
    :param data_type: forwarded unchanged to ``write_results``.
    :param result_filename: forwarded unchanged to ``write_results``.
    :param save_dir: if truthy, created with ``mkdir_if_missing``; if not
        ``None``, every drawn frame is written there as
        ``'{frame_id:05d}.jpg'`` (0-based frame index).
    :param show_image: when true, each drawn frame is shown in the
        ``'online_im'`` window.
    :param frame_rate: forwarded to ``JDETracker``.
    :return: ``(number_of_frames, timer.average_time, timer.calls)``.
    """
    if save_dir:
        mkdir_if_missing(save_dir)
    tracker = JDETracker(opt, frame_rate=frame_rate)
    timer = Timer()
    results = []
    frame_id = 0
    for _path, img, img0 in dataloader:
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(
                frame_id, 1. / max(1e-5, timer.average_time)))

        # run tracking
        timer.tic()
        blob = torch.from_numpy(img).cuda().unsqueeze(0)
        online_targets = tracker.update(blob, img0)
        online_tlwhs = []
        online_ids = []
        for target in online_targets:
            tlwh = target.tlwh
            # The area test comes first so that the w/h division is only
            # evaluated for boxes that are large enough; boxes rejected by
            # the area test (e.g. zero height) never reach the division.
            big_enough = tlwh[2] * tlwh[3] > opt.min_box_area
            if big_enough and not (tlwh[2] / tlwh[3] > 1.6):
                online_tlwhs.append(tlwh)
                online_ids.append(target.track_id)
        timer.toc()

        # Results use 1-based frame numbers.
        results.append((frame_id + 1, online_tlwhs, online_ids))

        if show_image or save_dir is not None:
            # Clamp the average time exactly like the log line above so a
            # zero average can never raise ZeroDivisionError here.
            online_im = vis.plot_tracking(
                img0,
                online_tlwhs,
                online_ids,
                frame_id=frame_id,
                fps=1. / max(1e-5, timer.average_time))
            if show_image:
                cv2.imshow('online_im', online_im)
                cv2.waitKey(1)
            if save_dir is not None:
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
                    online_im)
        frame_id += 1

    # save results
    write_results(result_filename, results, data_type)
    return frame_id, timer.average_time, timer.calls
Exemple #2
0
def eval_seq(opt,
             data_loader,
             data_type,
             result_f_name,
             save_dir=None,
             show_image=True,
             frame_rate=30):
    """Track a sequence with per-class results and write them out.

    Each item yielded by ``data_loader`` is fed to ``JDETracker.update``,
    whose return value is indexed by class id. For every class id in
    ``range(opt.num_classes)`` the boxes larger than ``opt.min_box_area``
    are recorded per frame. The accumulated mapping is finally passed to
    ``write_results_dict``.

    Frame 0 is never drawn, displayed or saved; visualisation starts at
    frame 1.

    :param opt: options object. ``opt.num_classes`` and
        ``opt.min_box_area`` are read here; the object itself is forwarded
        to ``JDETracker``.
    :param data_loader: iterable yielding ``(path, img, img0)`` triples.
        ``img`` is converted with ``torch.from_numpy`` and moved to CUDA,
        ``img0`` is passed to the tracker and used for drawing.
    :param data_type: forwarded unchanged to ``write_results_dict``.
    :param result_f_name: forwarded unchanged to ``write_results_dict``.
    :param save_dir: if truthy, created with ``mkdir_if_missing``; if not
        ``None``, every drawn frame is written there as
        ``'{frame_id:05d}.jpg'`` (0-based frame index).
    :param show_image: when true, each drawn frame is shown in the
        ``'online_im'`` window.
    :param frame_rate: forwarded to ``JDETracker``.
    :return: ``(number_of_frames, timer.average_time, timer.calls)``.
    """
    if save_dir:
        mkdir_if_missing(save_dir)

    tracker = JDETracker(opt, frame_rate=frame_rate)
    timer = Timer()
    results_dict = defaultdict(list)

    frame_id = 0  # 0-based index of the frame being processed
    for _path, img, img0 in data_loader:
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(
                frame_id, 1. / max(1e-5, timer.average_time)))

        # --- run tracking
        timer.tic()
        blob = torch.from_numpy(img).cuda().unsqueeze(0)

        # Core call: update the tracker state with this frame.
        online_targets_dict = tracker.update(blob, img0)

        # Collect this frame's boxes and ids, grouped by class id.
        online_tlwhs_dict = defaultdict(list)
        online_ids_dict = defaultdict(list)
        for cls_id in range(opt.num_classes):
            for track in online_targets_dict[cls_id]:
                tlwh = track.tlwh
                # Only the minimum-area filter is applied; no aspect-ratio
                # (w/h) filter is used in this variant.
                if tlwh[2] * tlwh[3] > opt.min_box_area:
                    online_tlwhs_dict[cls_id].append(tlwh)
                    online_ids_dict[cls_id].append(track.track_id)

        timer.toc()

        # Store this frame's results (1-based frame number) for every class.
        for cls_id in range(opt.num_classes):
            results_dict[cls_id].append(
                (frame_id + 1, online_tlwhs_dict[cls_id],
                 online_ids_dict[cls_id]))

        # Draw / show / save this frame (skipped for frame 0).
        if frame_id > 0 and (show_image or save_dir is not None):
            online_im = vis.plot_tracks(
                image=img0,
                tlwhs_dict=online_tlwhs_dict,
                obj_ids_dict=online_ids_dict,
                num_classes=opt.num_classes,
                frame_id=frame_id,
                # Clamp like the log line above to rule out division by zero.
                fps=1.0 / max(1e-5, timer.average_time))

            if show_image:
                cv2.imshow('online_im', online_im)
                # HighGUI only repaints the window while waitKey runs.
                cv2.waitKey(1)
            if save_dir is not None:
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
                    online_im)

        # Frame done; advance the counter.
        frame_id += 1

    # Write the final results.
    write_results_dict(result_f_name, results_dict, data_type)

    return frame_id, timer.average_time, timer.calls
Exemple #3
0
def eval_seq(opt,
             data_loader,
             data_type,
             result_f_name,
             save_dir=None,
             show_image=True,
             frame_rate=30,
             mode='track'):
    """Run tracking or detection over a sequence and write the results.

    In ``'track'`` mode each frame is fed to ``JDETracker.update`` and, for
    every class id in ``range(opt.num_classes)``, the boxes larger than
    ``opt.min_box_area`` are recorded. In ``'detect'`` mode each frame is
    fed to ``JDETracker.update_detections`` and nothing is recorded, so
    ``write_results_dict`` receives an empty mapping.

    Frame 0 is never drawn, displayed or saved; visualisation starts at
    frame 1.

    :param opt: options object. ``opt.device``, ``opt.num_classes`` and
        ``opt.min_box_area`` are read here; the object itself is forwarded
        to ``JDETracker``.
    :param data_loader: iterable yielding ``(path, img, img_0)`` triples.
        ``img`` is converted with ``torch.from_numpy`` and moved to
        ``opt.device``, ``img_0`` is passed to the tracker and used for
        drawing.
    :param data_type: forwarded unchanged to ``write_results_dict``.
    :param result_f_name: forwarded unchanged to ``write_results_dict``.
    :param save_dir: if truthy, created with ``mkdir_if_missing``; if not
        ``None``, every drawn frame is written there as
        ``'{frame_id:05d}.jpg'`` (0-based frame index).
    :param show_image: when true, each drawn frame is shown in the
        ``'online_im'`` window.
    :param frame_rate: forwarded to ``JDETracker``.
    :param mode: ``'track'`` or ``'detect'``.
    :return: ``(number_of_frames, timer.average_time, timer.calls)``.
    :raises ValueError: if ``mode`` is neither ``'track'`` nor ``'detect'``.
    """
    # Fail fast on a bad mode. Previously this only printed a message per
    # frame, left the timer's tic() unmatched and then failed on an unbound
    # image variable as soon as a frame had to be shown or saved.
    if mode not in ('track', 'detect'):
        raise ValueError(
            "un-recognized mode: {!r} (expected 'track' or 'detect')".format(
                mode))

    if save_dir:
        mkdir_if_missing(save_dir)

    tracker = JDETracker(opt, frame_rate=frame_rate)
    timer = Timer()
    results_dict = defaultdict(list)

    frame_id = 0  # 0-based index of the frame being processed
    for _path, img, img_0 in data_loader:
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(
                frame_id, 1. / max(1e-5, timer.average_time)))

        # Visualisation is skipped for frame 0 and when nobody consumes it.
        need_vis = frame_id > 0 and (show_image or save_dir is not None)
        online_im = None

        # --- run tracking / detection
        timer.tic()
        blob = torch.from_numpy(img).to(opt.device).unsqueeze(0)

        if mode == 'track':
            # Update the tracker state with this frame.
            online_targets_dict = tracker.update(blob, img_0)

            # Collect this frame's boxes and ids, grouped by class id.
            online_tlwhs_dict = defaultdict(list)
            online_ids_dict = defaultdict(list)
            for cls_id in range(opt.num_classes):
                for track in online_targets_dict[cls_id]:
                    tlwh = track.tlwh
                    # Only the minimum-area filter is applied; no
                    # aspect-ratio (w/h) filter is used in this variant.
                    if tlwh[2] * tlwh[3] > opt.min_box_area:
                        online_tlwhs_dict[cls_id].append(tlwh)
                        online_ids_dict[cls_id].append(track.track_id)

            timer.toc()

            # Store this frame's results (1-based frame number) per class.
            for cls_id in range(opt.num_classes):
                results_dict[cls_id].append(
                    (frame_id + 1, online_tlwhs_dict[cls_id],
                     online_ids_dict[cls_id]))

            if need_vis:
                online_im = vis.plot_tracks(
                    image=img_0,
                    tlwhs_dict=online_tlwhs_dict,
                    obj_ids_dict=online_ids_dict,
                    num_classes=opt.num_classes,
                    frame_id=frame_id,
                    # Clamped like the detect branch and the log line.
                    fps=1.0 / max(1e-5, timer.average_time))
        else:  # mode == 'detect'
            # Update detection results of this frame (or image).
            dets_dict = tracker.update_detections(blob, img_0)

            timer.toc()

            if need_vis:
                online_im = vis.plot_detects(
                    image=img_0,
                    dets_dict=dets_dict,
                    num_classes=opt.num_classes,
                    frame_id=frame_id,
                    fps=1.0 / max(1e-5, timer.average_time))

        if need_vis:
            if show_image:
                cv2.imshow('online_im', online_im)
                # HighGUI only repaints the window while waitKey runs.
                cv2.waitKey(1)
            if save_dir is not None:
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
                    online_im)

        # Frame done; advance the counter.
        frame_id += 1

    # Write the final results.
    write_results_dict(result_f_name, results_dict, data_type)

    return frame_id, timer.average_time, timer.calls