Example #1
def run_demo(opt):
    result_root = opt.output_root if opt.output_root != '' else '.'
    mkdir_if_missing(result_root)

    logger.info('Starting tracking...')
    data_loader = datasets.LoadVideo(opt.input_video, opt.img_size)
    result_file_name = os.path.join(result_root, 'results.txt')
    frame_rate = data_loader.frame_rate

    frame_dir = None if opt.output_format == 'text' else osp.join(result_root, 'frame')
    try:  # entry point for video inference
        eval_seq(opt=opt,
                 data_loader=data_loader,
                 data_type='mot',
                 result_f_name=result_file_name,
                 save_dir=frame_dir,
                 show_image=False,
                 frame_rate=frame_rate)
    except Exception as e:
        logger.info(e)

    if opt.output_format == 'video':
        output_video_path = osp.join(result_root, 'result.mp4')
        cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -b 5000k -c:v mpeg4 {}' \
            .format(osp.join(result_root, 'frame'),
                    output_video_path)
        os.system(cmd_str)
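
A hedged alternative to the os.system call above: a minimal sketch using subprocess.run (assuming ffmpeg is on PATH; frames_to_video is not part of the original code), which avoids shell parsing of unusual paths:

import os.path as osp
import subprocess

def frames_to_video(frame_dir, output_video_path):
    # encode the saved %05d.jpg frames into an MP4, mirroring the command above
    cmd = ['ffmpeg', '-f', 'image2', '-i', osp.join(frame_dir, '%05d.jpg'),
           '-b:v', '5000k', '-c:v', 'mpeg4', output_video_path]
    subprocess.run(cmd, check=True)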
Example #2
def write_results(filename, results, data_type):
    if data_type == 'mot':
        save_format = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n'
    elif data_type == 'kitti':
        save_format = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
    else:
        raise ValueError(data_type)

    with open(filename, 'w') as f:
        for frame_id, tlwhs, track_ids in results:
            if data_type == 'kitti':
                frame_id -= 1
            for tlwh, track_id in zip(tlwhs, track_ids):
                if track_id < 0:
                    continue
                x1, y1, w, h = tlwh
                x2, y2 = x1 + w, y1 + h
                line = save_format.format(frame=frame_id,
                                          id=track_id,
                                          x1=x1,
                                          y1=y1,
                                          x2=x2,
                                          y2=y2,
                                          w=w,
                                          h=h)
                f.write(line)
    logger.info('save results to {}'.format(filename))
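
A minimal usage sketch for write_results (the tuples mirror what eval_seq in Example #6 appends to results; the box values are invented):

results = [
    # (frame_id, tlwhs, track_ids)
    (1, [(10.0, 20.0, 50.0, 120.0)], [1]),
    (2, [(12.0, 20.5, 50.0, 120.0), (300.0, 40.0, 60.0, 110.0)], [1, 2]),
]
write_results('demo_results.txt', results, data_type='mot')
# each written line: frame,id,x1,y1,w,h,1,-1,-1,-1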
Example #3
def main(data_root='../results', seqs=('MOT16-02', ), exp_name='demo'):
    logger.setLevel(logging.INFO)

    data_type = 'mot'

    # evaluate existing tracking results
    accs = []
    for seq in seqs:

        result_root = os.path.join(data_root, seq)
        result_filename = os.path.join(result_root,
                                       '{}.txt'.format('kalman_iou'))
        print(result_filename)

        # eval
        logger.info('Evaluate seq: {}'.format(seq))
        evaluator = Evaluator(data_root, seq, data_type)
        accs.append(evaluator.eval_file(result_filename))

    # get summary
    metrics = mm.metrics.motchallenge_metrics
    mh = mm.metrics.create()
    summary = Evaluator.get_summary(accs, seqs, metrics)
    strsummary = mm.io.render_summary(summary,
                                      formatters=mh.formatters,
                                      namemap=mm.io.motchallenge_metric_names)
    print(strsummary)
    Evaluator.save_summary(
        summary, os.path.join(result_root, 'summary_{}.xlsx'.format(exp_name)))
Example #4
def write_results_dict(file_name, results_dict, data_type, num_classes=2):
    """
    :param file_name:
    :param results_dict:
    :param data_type:
    :param num_classes:
    :return:
    """
    if data_type == 'mot':
        save_format = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n'
    elif data_type == 'kitti':
        save_format = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
    else:
        raise ValueError(data_type)

    with open(file_name, 'w') as f:
        for cls_id in range(num_classes):
            if cls_id == 0:  # skip the background class
                continue

            # process each detection class's results
            results = results_dict[cls_id]
            for frame_id, tlwhs, track_ids in results:
                if data_type == 'kitti':
                    frame_id -= 1
                for tlwh, track_id in zip(tlwhs, track_ids):
                    if track_id < 0:
                        continue

                    x1, y1, w, h = tlwh
                    x2, y2 = x1 + w, y1 + h
                    line = save_format.format(frame=frame_id, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h)
                    f.write(line)

    logger.info('save results to {}'.format(file_name))
Example #5
def write_results(filename, results_dict: Dict, data_type: str):
    if not filename:
        return
    path = os.path.dirname(filename)
    if path and not os.path.exists(path):
        os.makedirs(path)

    if data_type in ('mot', 'mcmot', 'lab'):
        save_format = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n'
    elif data_type == 'kitti':
        save_format = '{frame} {id} pedestrian -1 -1 -10 {x1} {y1} {x2} {y2} -1 -1 -1 -1000 -1000 -1000 -10 {score}\n'
    else:
        raise ValueError(data_type)

    with open(filename, 'w') as f:
        for frame_id, frame_data in results_dict.items():
            if data_type == 'kitti':
                frame_id -= 1
            for tlwh, track_id in frame_data:
                if track_id < 0:
                    continue
                x1, y1, w, h = tlwh
                x2, y2 = x1 + w, y1 + h
                line = save_format.format(frame=frame_id, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h, score=1.0)
                f.write(line)
    logger.info('Save results to {}'.format(filename))
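
Here results_dict maps a frame id to that frame's (tlwh, track_id) pairs; a sketch of the expected shape (values invented):

results_dict = {
    1: [((10.0, 20.0, 50.0, 120.0), 1)],
    2: [((12.0, 20.5, 50.0, 120.0), 1),
        ((300.0, 40.0, 60.0, 110.0), 2)],
}
write_results('demo/results.txt', results_dict, data_type='mot')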
Example #6
def eval_seq(opt,
             dataloader,
             data_type,
             result_filename,
             save_dir=None,
             show_image=True,
             frame_rate=30):
    if save_dir:
        mkdir_if_missing(save_dir)
    tracker = JDETracker(opt, frame_rate=frame_rate)
    timer = Timer()
    results = []
    frame_id = 0
    for path, img, img0 in dataloader:
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(
                frame_id, 1. / max(1e-5, timer.average_time)))

        # run tracking
        timer.tic()
        blob = torch.from_numpy(img).cuda().unsqueeze(0)
        online_targets = tracker.update(blob, img0)
        online_tlwhs = []
        online_ids = []
        for t in online_targets:
            tlwh = t.tlwh
            tid = t.track_id
            vertical = tlwh[2] / tlwh[3] > 1.6
            if tlwh[2] * tlwh[3] > opt.min_box_area and not vertical:
                online_tlwhs.append(tlwh)
                online_ids.append(tid)
        timer.toc()
        # save results
        results.append((frame_id + 1, online_tlwhs, online_ids))
        if show_image or save_dir is not None:
            online_im = vis.plot_tracking(img0,
                                          online_tlwhs,
                                          online_ids,
                                          frame_id=frame_id,
                                          fps=1. / timer.average_time)
        if show_image:
            cv2.imshow('online_im', online_im)
            cv2.waitKey(1)
        if save_dir is not None:
            cv2.imwrite(os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
                        online_im)
        frame_id += 1
    # save results
    write_results(result_filename, results, data_type)
    return frame_id, timer.average_time, timer.calls
Example #7
def write_results_dict(file_name, results_dict, data_type, num_classes=5):
    """
    :param file_name:
    :param results_dict:
    :param data_type:
    :param num_classes:
    :return:
    """
    if data_type == 'mot':
        # save_format = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n'
        # save_format = '{frame},{id},{x1},{y1},{w},{h},1,{cls_id},1\n'
        save_format = '{frame},{id},{x1},{y1},{w},{h},{score},{cls_id},1\n'
    elif data_type == 'kitti':
        save_format = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
    else:
        raise ValueError(data_type)

    with open(file_name, 'w') as f:
        for cls_id in range(num_classes):  # process each object class
            cls_results = results_dict[cls_id]
            for frame_id, tlwhs, track_ids, scores in cls_results:
                if data_type == 'kitti':
                    frame_id -= 1

                for tlwh, track_id, score in zip(tlwhs, track_ids, scores):
                    if track_id < 0:
                        continue

                    x1, y1, w, h = tlwh
                    # x2, y2 = x1 + w, y1 + h
                    # line = save_format.format(frame=frame_id, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h)
                    line = save_format.format(
                        frame=frame_id,
                        id=track_id,
                        x1=x1,
                        y1=y1,
                        w=w,
                        h=h,
                        score=score,  # detection score
                        cls_id=cls_id)
                    f.write(line)

    logger.info('save results to {}'.format(file_name))
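
The active 'mot' format above writes frame,id,x1,y1,w,h,score,cls_id,1 per row; a hedged sketch of a reader that inverts it (read_results_dict is a made-up name, not part of the original code):

from collections import defaultdict

def read_results_dict(file_name):
    # cls_id -> list of (frame_id, track_id, tlwh, score), inverting the rows above
    results = defaultdict(list)
    with open(file_name, 'r') as f:
        for line in f:
            frame, t_id, x1, y1, w, h, score, cls_id, _ = line.strip().split(',')
            results[int(cls_id)].append((int(frame), int(t_id),
                                         (float(x1), float(y1), float(w), float(h)),
                                         float(score)))
    return results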
Example #8
def run_demo(opt):
    """
    :param opt:
    :return:
    """
    result_root = opt.output_root if opt.output_root != '' else '.'
    mkdir_if_missing(result_root)

    # clear existing frame results
    frame_res_dir = result_root + '/frames'
    if os.path.isdir(frame_res_dir):
        shutil.rmtree(frame_res_dir)
    os.makedirs(frame_res_dir)

    if opt.input_mode == 'video':
        logger.info('Starting tracking...')
        data_loader = datasets.LoadVideo(opt.input_video, opt.img_size)
    elif opt.input_mode == 'image_dir':
        logger.info('Starting detection...')
        data_loader = datasets.LoadImages(opt.input_img, opt.img_size)
    elif opt.input_mode == 'img_path_list_txt':
        if not os.path.isfile(opt.input_img):
            print('[Err]: invalid image file path list.')
            return

        with open(opt.input_img, 'r', encoding='utf-8') as r_h:
            logger.info('Starting detection...')
            paths = [x.strip() for x in r_h.readlines()]
            print('Total {:d} image files.'.format(len(paths)))
            data_loader = datasets.LoadImages(path=paths,
                                              img_size=opt.img_size)
    else:
        print('[Err]: un-recognized input mode.')
        return

    result_file_name = os.path.join(result_root, 'results.txt')
    frame_rate = getattr(data_loader, 'frame_rate', 30)  # image inputs may not define frame_rate

    frame_dir = None if opt.output_format == 'text' else osp.join(
        result_root, 'frame')

    opt.device = device
    try:  # entry point for video inference
        if opt.id_weight > 0:
            eval_seq(opt=opt,
                     data_loader=data_loader,
                     data_type='mot',
                     result_f_name=result_file_name,
                     save_dir=frame_dir,
                     show_image=False,
                     frame_rate=frame_rate,
                     mode='track')
        else:
            # eval_seq(opt=opt,
            #          data_loader=data_loader,
            #          data_type='mot',
            #          result_f_name=result_file_name,
            #          save_dir=frame_dir,
            #          show_image=False,
            #          frame_rate=frame_rate,
            #          mode='detect')

            # only for tmp detection evaluation...
            output_dir = '/users/duanyou/c5/results_new/results_all/mcmot_hrnet18_deconv_ep3'
            eval_seq_and_output_dets(opt=opt,
                                     data_loader=data_loader,
                                     data_type='mot',
                                     result_f_name=result_file_name,
                                     out_dir=output_dir,
                                     save_dir=frame_dir,
                                     show_image=False)
    except Exception as e:
        logger.info(e)

    if opt.output_format == 'video':
        output_video_path = osp.join(result_root, 'result.mp4')
        cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -b 5000k -c:v mpeg4 {}' \
            .format(osp.join(result_root, 'frame'),
                    output_video_path)
        os.system(cmd_str)
Example #9
def run_demo(opt):
    """
    :param opt:
    :return:
    """
    result_root = opt.output_root if opt.output_root != '' else '.'
    mkdir_if_missing(result_root)

    # clear existing frame results
    frame_res_dir = result_root + '/frame'
    if os.path.isdir(frame_res_dir):
        shutil.rmtree(frame_res_dir)
    os.makedirs(frame_res_dir)

    if opt.input_mode == 'video':
        if opt.id_weight > 0:
            logger.info('Starting tracking...')
        else:
            logger.info('Starting detection...')
        if not os.path.isfile(opt.input_video):
            print('[Err]: invalid input video file.')
            return

        data_loader = datasets.LoadVideo(opt.input_video,
                                         opt.img_size)  # load video as input
        f_name = os.path.split(opt.input_video)[-1][:-4]
    elif opt.input_mode == 'image_dir':
        logger.info('Starting detection...')
        data_loader = datasets.LoadImages(opt.input_img,
                                          opt.img_size)  # load images as input
        f_name = os.path.split(opt.input_img)[-1]  # use the image dir path here (was opt.input_video)
        opt.id_weight = 0  # only do detection in this mode
    elif opt.input_mode == 'img_path_list_txt':
        logger.info('Starting detection...')
        if not os.path.isfile(opt.input_img):
            print('[Err]: invalid image file path list.')
            return

        opt.id_weight = 0  # only do detection in this mode
        with open(opt.input_img, 'r', encoding='utf-8') as r_h:
            paths = [x.strip() for x in r_h.readlines()]
            print('Total {:d} image files.'.format(len(paths)))
            data_loader = datasets.LoadImages(path=paths,
                                              img_size=opt.img_size)
    else:
        print('[Err]: un-recognized input mode.')
        return

    result_file_name = os.path.join(result_root, 'results.txt')
    frame_rate = getattr(data_loader, 'frame_rate', 30)  # image inputs may not define frame_rate
    frame_dir = None if opt.output_format == 'text' else osp.join(
        result_root, 'frame')

    # set device
    opt.device = str(FindFreeGPU())
    print('Using gpu: {:s}'.format(opt.device))
    device = select_device(
        device='cpu' if not torch.cuda.is_available() else opt.device)
    opt.device = device

    try:
        if opt.input_mode == 'video':
            if opt.id_weight > 0:
                eval_seq(opt=opt,
                         data_loader=data_loader,
                         data_type='mot',
                         result_f_name=result_file_name,
                         save_dir=frame_dir,
                         show_image=False,
                         frame_rate=frame_rate,
                         mode='track')
            else:  # input video, do detection
                eval_seq(opt=opt,
                         data_loader=data_loader,
                         data_type='mot',
                         result_f_name=result_file_name,
                         save_dir=frame_dir,
                         show_image=False,
                         frame_rate=frame_rate,
                         mode='detect')
        else:
            # only for tmp detection evaluation...
            output_dir = '/users/duanyou/c5/results_new/results_all/tmp'
            eval_imgs_output_dets(opt=opt,
                                  data_loader=data_loader,
                                  data_type='mot',
                                  result_f_name=result_file_name,
                                  out_dir=output_dir,
                                  save_dir=frame_dir,
                                  show_image=False)
    except Exception as e:
        logger.info(e)

    if opt.output_format == 'video':
        output_video_path = 'result.mp4'
        if opt.input_mode == 'video':
            if opt.id_weight > 0:
                output_video_path = osp.join(result_root,
                                             f_name + '_track.mp4')
            else:
                output_video_path = osp.join(result_root, f_name + '_det.mp4')
        elif opt.input_mode == 'image_dir':
            output_video_path = osp.join(result_root, f_name + '_det.mp4')
        cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -b 5000k -c:v mpeg4 {}' \
            .format(osp.join(result_root, 'frame'), output_video_path)
        print(cmd_str)
        os.system(cmd_str)
Example #10
def main(opt,
         data_root='/data/MOT16/train',
         det_root=None,
         seqs=('MOT16-05', ),
         exp_name='demo',
         save_images=False,
         save_videos=False,
         show_image=True):
    """
    """

    logger.setLevel(logging.INFO)
    result_root = os.path.join(data_root, '..', 'results', exp_name)
    mkdir_if_missing(result_root)
    data_type = 'mot'

    # run tracking
    accs = []
    n_frame = 0
    timer_avgs, timer_calls = [], []
    for seq in seqs:
        output_dir = os.path.join(data_root, '..', 'outputs', exp_name,
                                  seq) if save_images or save_videos else None
        logger.info('start seq: {}'.format(seq))
        dataloader = datasets.LoadImages(osp.join(data_root, seq, 'img1'),
                                         opt.img_size)
        result_filename = os.path.join(result_root, '{}.txt'.format(seq))
        meta_info = open(os.path.join(data_root, seq, 'seqinfo.ini')).read()
        frame_rate = int(meta_info[meta_info.find('frameRate') +
                                   10:meta_info.find('\nseqLength')])
        nf, ta, tc = eval_seq(opt,
                              dataloader,
                              data_type,
                              result_filename,
                              save_dir=output_dir,
                              show_image=show_image,
                              frame_rate=frame_rate)
        n_frame += nf
        timer_avgs.append(ta)
        timer_calls.append(tc)

        # eval
        logger.info('Evaluate seq: {}'.format(seq))
        evaluator = Evaluator(data_root, seq, data_type)
        accs.append(evaluator.eval_file(result_filename))
        if save_videos:
            output_video_path = osp.join(output_dir, '{}.mp4'.format(seq))
            cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -c:v copy {}'.format(
                output_dir, output_video_path)
            os.system(cmd_str)
    timer_avgs = np.asarray(timer_avgs)
    timer_calls = np.asarray(timer_calls)
    all_time = np.dot(timer_avgs, timer_calls)
    avg_time = all_time / np.sum(timer_calls)
    logger.info('Time elapsed: {:.2f} seconds, FPS: {:.2f}'.format(
        all_time, 1.0 / avg_time))

    # get summary
    metrics = mm.metrics.motchallenge_metrics
    mh = mm.metrics.create()
    summary = Evaluator.get_summary(accs, seqs, metrics)
    strsummary = mm.io.render_summary(summary,
                                      formatters=mh.formatters,
                                      namemap=mm.io.motchallenge_metric_names)
    print(strsummary)
    Evaluator.save_summary(
        summary, os.path.join(result_root, 'summary_{}.xlsx'.format(exp_name)))
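
The seqinfo.ini slicing above (meta_info.find('frameRate') + 10) is brittle; since MOTChallenge seqinfo.ini files are plain INI, a sketch using the standard-library configparser (read_frame_rate is a made-up helper) reads the same key more robustly:

import configparser

def read_frame_rate(seq_info_path):
    # parse frameRate from the [Sequence] section of a MOTChallenge seqinfo.ini
    config = configparser.ConfigParser()
    config.read(seq_info_path)
    return int(config['Sequence']['frameRate'])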
Example #11
def eval_seq(opt,
             data_loader,
             data_type,
             result_f_name,
             save_dir=None,
             show_image=True,
             frame_rate=30,
             mode='track'):
    """
    :param opt:
    :param data_loader:
    :param data_type:
    :param result_f_name:
    :param save_dir:
    :param show_image:
    :param frame_rate:
    :param mode: track or detect
    :return:
    """
    if save_dir:
        mkdir_if_missing(save_dir)

    tracker = JDETracker(opt, frame_rate=frame_rate)

    timer = Timer()

    results_dict = defaultdict(list)

    frame_id = 0  # frame index
    for path, img, img_0 in data_loader:
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(
                frame_id, 1. / max(1e-5, timer.average_time)))

        # --- run tracking
        timer.tic()
        # blob = torch.from_numpy(img).cuda().unsqueeze(0)
        blob = torch.from_numpy(img).to(opt.device).unsqueeze(0)

        if mode == 'track':  # process tracking
            # --- track updates of each frame
            online_targets_dict = tracker.update_tracking(blob, img_0)

            # aggregate this frame's results
            online_tlwhs_dict = defaultdict(list)
            online_ids_dict = defaultdict(list)
            for cls_id in range(opt.num_classes):
                # process each object class
                online_targets = online_targets_dict[cls_id]
                for track in online_targets:
                    tlwh = track.tlwh
                    t_id = track.track_id
                    # vertical = tlwh[2] / tlwh[3] > 1.6  # aspect-ratio check: reject w/h > 1.6?
                    if tlwh[2] * tlwh[3] > opt.min_box_area:  # and not vertical:
                        online_tlwhs_dict[cls_id].append(tlwh)
                        online_ids_dict[cls_id].append(t_id)

            timer.toc()

            # store this frame's results
            for cls_id in range(opt.num_classes):
                results_dict[cls_id].append(
                    (frame_id + 1, online_tlwhs_dict[cls_id],
                     online_ids_dict[cls_id]))

            # draw this frame's results
            if show_image or save_dir is not None:
                if frame_id > 0:
                    online_im: ndarray = vis.plot_tracks(
                        image=img_0,
                        tlwhs_dict=online_tlwhs_dict,
                        obj_ids_dict=online_ids_dict,
                        num_classes=opt.num_classes,
                        frame_id=frame_id,
                        fps=1.0 / timer.average_time)

        elif mode == 'detect':  # process detections
            # update detection results of this frame(or image)
            dets_dict = tracker.update_detection(blob, img_0)

            timer.toc()

            # plot detection results
            if show_image or save_dir is not None:
                online_im = vis.plot_detects(image=img_0,
                                             dets_dict=dets_dict,
                                             num_classes=opt.num_classes,
                                             frame_id=frame_id,
                                             fps=1.0 /
                                             max(1e-5, timer.average_time))
        else:
            print('[Err]: un-recognized mode.')

        # # visualize intermediate results
        # if frame_id > 0:
        #     cv2.imshow('Frame {}'.format(str(frame_id)), online_im)
        #     cv2.waitKey()

        if frame_id > 0:
            # optionally show intermediate results
            if show_image:
                cv2.imshow('online_im', online_im)
            if save_dir is not None:
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
                    online_im)

        # finished this frame: update frame_id
        frame_id += 1

    # write final results
    write_results_dict(result_f_name, results_dict, data_type)

    return frame_id, timer.average_time, timer.calls
Example #12
def eval_seq_and_output_dets(opt,
                             data_loader,
                             data_type,
                             result_f_name,
                             out_dir,
                             save_dir=None,
                             show_image=True):
    """
    :param opt:
    :param data_loader:
    :param data_type:
    :param result_f_name:
    :param out_dir:
    :param save_dir:
    :param show_image:
    :return:
    """
    if save_dir:
        mkdir_if_missing(save_dir)

    if os.path.isdir(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

    tracker = JDETracker(opt, frame_rate=30)

    timer = Timer()

    results_dict = defaultdict(list)

    frame_id = 0  # frame index
    for path, img, img_0 in data_loader:
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(
                frame_id, 1. / max(1e-5, timer.average_time)))

        # --- run tracking
        timer.tic()
        blob = torch.from_numpy(img).to(opt.device).unsqueeze(0)

        # update detection results of this frame(or image)
        dets_dict = tracker.update_detection(blob, img_0)

        timer.toc()

        # plot detection results
        if show_image or save_dir is not None:
            online_im = vis.plot_detects(image=img_0,
                                         dets_dict=dets_dict,
                                         num_classes=opt.num_classes,
                                         frame_id=frame_id,
                                         fps=1.0 /
                                         max(1e-5, timer.average_time))

        if frame_id > 0:
            # optionally show intermediate results
            if show_image:
                cv2.imshow('online_im', online_im)
            if save_dir is not None:
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
                    online_im)

        # ----- format and write detection results (txt) to the output dir
        # format
        dets_list = format_dets_dict2dets_list(dets_dict,
                                               w=img_0.shape[1],
                                               h=img_0.shape[0])

        # write to the output dir
        out_img_name = os.path.split(path)[-1]
        # if out_img_name == '192.168.1.219_2_2018-02-13_14-46-00-688_3-1518504845.jpg':
        #     print('pause here')
        out_f_name = out_img_name.replace('.jpg', '.txt')
        out_f_path = out_dir + '/' + out_f_name
        with open(out_f_path, 'w', encoding='utf-8') as w_h:
            w_h.write('class prob x y w h total=' + str(len(dets_list)) + '\n')

            for det in dets_list:
                w_h.write('%d %f %f %f %f %f\n' %
                          (det[0], det[1], det[2], det[3], det[4], det[5]))
        # print('{} written'.format(out_f_path))

        # finished this frame: update frame_id
        frame_id += 1
    print('Total {:d} detection result output.\n'.format(frame_id))

    # write final results (note: results_dict is never populated in this
    # detection-only function, so the aggregated txt ends up empty)
    write_results_dict(result_f_name, results_dict, data_type)

    return frame_id, timer.average_time, timer.calls
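
Each per-image txt written above holds a 'class prob x y w h total=N' header line followed by one detection per line; a minimal reader sketch (read_dets_txt is an invented name):

def read_dets_txt(txt_path):
    # parse the per-image detection txt written by eval_seq_and_output_dets
    dets = []
    with open(txt_path, 'r', encoding='utf-8') as r_h:
        next(r_h)  # skip the 'class prob x y w h total=N' header
        for line in r_h:
            fields = line.strip().split()
            # (cls_id, prob, x, y, w, h)
            dets.append((int(fields[0]),) + tuple(float(v) for v in fields[1:6]))
    return dets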
Example #13
def eval_seq(opt,
             data_loader,
             data_type,
             result_f_name,
             save_dir=None,
             show_image=True,
             frame_rate=30,
             mode='track'):
    """
    :param opt:
    :param data_loader:
    :param data_type:
    :param result_f_name:
    :param save_dir:
    :param show_image:
    :param frame_rate:
    :param mode: track or detect
    :return:
    """
    if save_dir:
        mkdir_if_missing(save_dir)

    # tracker = JDETracker(opt, frame_rate)
    tracker = MCJDETracker(opt, frame_rate)

    timer = Timer()

    results_dict = defaultdict(list)

    frame_id = 0  # frame index
    for path, img, img0 in data_loader:
        if frame_id % 30 == 0 and frame_id != 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(
                frame_id, 1.0 / max(1e-5, timer.average_time)))

        # --- run tracking
        blob = torch.from_numpy(img).unsqueeze(0).to(opt.device)

        if mode == 'track':  # process tracking
            # ----- track updates of each frame
            timer.tic()

            online_targets_dict = tracker.update_tracking(blob, img0)

            timer.toc()
            # -----

            # collect current frame's result
            online_tlwhs_dict = defaultdict(list)
            online_ids_dict = defaultdict(list)
            online_scores_dict = defaultdict(list)
            for cls_id in range(opt.num_classes):  # process each class id
                online_targets = online_targets_dict[cls_id]
                for track in online_targets:
                    tlwh = track.tlwh
                    t_id = track.track_id
                    score = track.score
                    if tlwh[2] * tlwh[3] > opt.min_box_area:  # and not vertical:
                        online_tlwhs_dict[cls_id].append(tlwh)
                        online_ids_dict[cls_id].append(t_id)
                        online_scores_dict[cls_id].append(score)

            # collect result
            for cls_id in range(opt.num_classes):
                results_dict[cls_id].append(
                    (frame_id + 1, online_tlwhs_dict[cls_id],
                     online_ids_dict[cls_id], online_scores_dict[cls_id]))

            # draw track/detection
            if show_image or save_dir is not None:
                if frame_id > 0:
                    online_im: ndarray = vis.plot_tracks(
                        image=img0,
                        tlwhs_dict=online_tlwhs_dict,
                        obj_ids_dict=online_ids_dict,
                        num_classes=opt.num_classes,
                        frame_id=frame_id,
                        fps=1.0 / timer.average_time)

        elif mode == 'detect':  # process detections
            timer.tic()

            # update detection results of this frame(or image)
            dets_dict = tracker.update_detection(blob, img0)

            timer.toc()

            # plot detection results
            if show_image or save_dir is not None:
                online_im = vis.plot_detects(image=img0,
                                             dets_dict=dets_dict,
                                             num_classes=opt.num_classes,
                                             frame_id=frame_id,
                                             fps=1.0 /
                                             max(1e-5, timer.average_time))
        else:
            print('[Err]: un-recognized mode.')

        if frame_id > 0:
            if show_image:
                cv2.imshow('online_im', online_im)
            if save_dir is not None:
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
                    online_im)

        # update frame id
        frame_id += 1

    # write track/detection results
    write_results_dict(result_f_name, results_dict, data_type)

    return frame_id, timer.average_time, timer.calls
Example #14
def eval_seq(opt,
             data_loader,
             data_type,
             result_f_name,
             save_dir=None,
             show_image=True,
             frame_rate=30):
    """
    对序列进行跟踪
    :param opt:
    :param data_loader:
    :param data_type:
    :param result_f_name:
    :param save_dir:
    :param show_image:
    :param frame_rate:
    :return:
    """
    if save_dir:
        mkdir_if_missing(save_dir)

    tracker = JDETracker(opt, frame_rate=frame_rate)

    timer = Timer()

    results_dict = defaultdict(list)

    frame_id = 0  # frame index
    for path, img, img0 in data_loader:
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(
                frame_id, 1. / max(1e-5, timer.average_time)))

        # --- run tracking
        timer.tic()
        blob = torch.from_numpy(img).cuda().unsqueeze(0)

        # --- core call: update tracking state
        online_targets_dict = tracker.update(blob, img0)

        # aggregate this frame's results
        online_tlwhs_dict = defaultdict(list)
        online_ids_dict = defaultdict(list)
        for cls_id in range(opt.num_classes):
            # process each object class
            online_targets = online_targets_dict[cls_id]
            for track in online_targets:
                tlwh = track.tlwh
                t_id = track.track_id
                # vertical = tlwh[2] / tlwh[3] > 1.6  # aspect-ratio check: reject w/h > 1.6?
                if tlwh[2] * tlwh[3] > opt.min_box_area:  # and not vertical:
                    online_tlwhs_dict[cls_id].append(tlwh)
                    online_ids_dict[cls_id].append(t_id)

        timer.toc()

        # store this frame's results
        for cls_id in range(opt.num_classes):
            results_dict[cls_id].append(
                (frame_id + 1, online_tlwhs_dict[cls_id],
                 online_ids_dict[cls_id]))

        # draw this frame's results
        if show_image or save_dir is not None:
            if frame_id > 0:
                online_im = vis.plot_tracks(image=img0,
                                            tlwhs_dict=online_tlwhs_dict,
                                            obj_ids_dict=online_ids_dict,
                                            num_classes=opt.num_classes,
                                            frame_id=frame_id,
                                            fps=1.0 / timer.average_time)

        # # visualize intermediate results
        # if frame_id > 0:
        #     cv2.imshow('Frame {}'.format(str(frame_id)), online_im)
        #     cv2.waitKey()

        if frame_id > 0:
            # optionally show intermediate results
            if show_image:
                cv2.imshow('online_im', online_im)
            if save_dir is not None:
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
                    online_im)

        # finished this frame: update frame_id
        frame_id += 1

    # write final results
    write_results_dict(result_f_name, results_dict, data_type)

    return frame_id, timer.average_time, timer.calls