Example #1
0
def predict_video(detector,
                  video_file,
                  threshold,
                  output_dir,
                  save_images=True,
                  save_mot_txts=True,
                  draw_center_traj=False,
                  secs_interval=10,
                  do_entrance_counting=False,
                  camera_id=-1):
    """Track objects in a video file or camera stream and save the results.

    ``detector.predict([frame], threshold)`` is expected to return
    ``(online_tlwhs, online_scores, online_ids)``, each indexable by class id.
    Annotated frames (drawn with ``plot_tracking_dict``) are either saved as
    numbered JPEGs that are afterwards muxed into a video with ffmpeg, or
    written directly to an mp4 file. Flow statistics (``flow_statistic``)
    are only computed for single-class models.

    Args:
        detector: tracker exposing ``predict``, ``num_classes`` and
            ``pred_config.labels``.
        video_file (str): input video path; ignored when ``camera_id != -1``.
        threshold (float): score threshold forwarded to ``detector.predict``.
        output_dir (str): output directory, created if missing.
        save_images (bool): save per-frame JPEGs under
            ``output_dir/<video stem>/`` and mux them with ffmpeg; otherwise
            write frames directly with ``cv2.VideoWriter``.
        save_mot_txts (bool): write MOT results (plus flow statistics for
            single-class models) as text files in ``output_dir``.
        draw_center_traj (bool): pass one trajectory dict per class to
            ``plot_tracking_dict`` as ``center_traj``.
        secs_interval (int): interval forwarded to ``flow_statistic``.
        do_entrance_counting (bool): forwarded to ``flow_statistic`` and
            ``plot_tracking_dict``; the entrance is the horizontal line at
            half the frame height.
        camera_id (int): camera index, or -1 to read ``video_file``.

    Raises:
        NotImplementedError: if ``do_entrance_counting`` is requested for a
            multi-class model.
    """
    import subprocess

    num_classes = detector.num_classes
    # Fail fast, before any output file is opened or any frame processed.
    if num_classes > 1 and do_entrance_counting:
        raise NotImplementedError(
            'Multi-class flow counting is not implemented now!')

    video_name = 'mot_output.mp4'
    if camera_id != -1:
        capture = cv2.VideoCapture(camera_id)
    else:
        capture = cv2.VideoCapture(video_file)
        video_name = os.path.split(video_file)[-1]
    # splitext keeps multi-dot names intact ("a.b.mp4" -> "a.b") and does
    # not fail on names without an extension.
    video_stem = os.path.splitext(video_name)[0]
    out_path = os.path.join(output_dir, video_name)
    save_dir = os.path.join(output_dir, video_stem)

    results = defaultdict(list)  # support single class and multi classes
    data_type = 'mcmot' if num_classes > 1 else 'mot'
    records = None
    writer = None
    try:
        # Get video info: resolution, fps, frame count
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        video_fps = int(capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        print("fps: %d, frame_count: %d" % (video_fps, frame_count))

        os.makedirs(output_dir, exist_ok=True)
        if save_images:
            os.makedirs(save_dir, exist_ok=True)
        else:
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            writer = cv2.VideoWriter(out_path, fourcc, video_fps,
                                     (width, height))

        timer = MOTTimer()
        ids2names = detector.pred_config.labels
        center_traj = None
        entrance = None
        if draw_center_traj:
            center_traj = [{} for _ in range(num_classes)]

        if num_classes == 1:
            id_set = set()
            interval_id_set = set()
            in_id_list = list()
            out_id_list = list()
            prev_center = dict()
            records = list()
            # Horizontal entrance line at half the frame height.
            entrance = [0, height / 2., width, height / 2.]

        frame_id = 0  # 0-based index; MOT results use frame_id + 1
        while True:
            ret, frame = capture.read()
            if not ret:
                break
            timer.tic()
            online_tlwhs, online_scores, online_ids = detector.predict(
                [frame], threshold)
            timer.toc()

            for cls_id in range(num_classes):
                results[cls_id].append(
                    (frame_id + 1, online_tlwhs[cls_id],
                     online_scores[cls_id], online_ids[cls_id]))

            fps = 1. / timer.duration  # inference fps of this frame
            # NOTE: flow statistics are only implemented for one class.
            if num_classes == 1:
                result = (frame_id + 1, online_tlwhs[0], online_scores[0],
                          online_ids[0])
                statistic = flow_statistic(result, secs_interval,
                                           do_entrance_counting, video_fps,
                                           entrance, id_set, interval_id_set,
                                           in_id_list, out_id_list,
                                           prev_center, records, data_type,
                                           num_classes)
                id_set = statistic['id_set']
                interval_id_set = statistic['interval_id_set']
                in_id_list = statistic['in_id_list']
                out_id_list = statistic['out_id_list']
                prev_center = statistic['prev_center']
                records = statistic['records']

            im = plot_tracking_dict(frame,
                                    num_classes,
                                    online_tlwhs,
                                    online_ids,
                                    online_scores,
                                    frame_id=frame_id,
                                    fps=fps,
                                    ids2names=ids2names,
                                    do_entrance_counting=do_entrance_counting,
                                    entrance=entrance,
                                    records=records,
                                    center_traj=center_traj)

            if save_images:
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
            else:
                writer.write(im)

            frame_id += 1
            print('detect frame: %d, fps: %f' % (frame_id, fps))
            if camera_id != -1:
                cv2.imshow('Tracking Detection', im)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    finally:
        # Always free the capture device / file and flush the writer.
        capture.release()
        if writer is not None:
            writer.release()

    if save_mot_txts:
        result_filename = os.path.join(output_dir, video_stem + '.txt')
        write_mot_results(result_filename, results, data_type, num_classes)

        if num_classes == 1:
            result_filename = os.path.join(
                output_dir, video_stem + '_flow_statistic.txt')
            with open(result_filename, 'w') as f:
                f.writelines(records)
            print('Flow statistic save in {}'.format(result_filename))

    if save_images:
        # Argument list, no shell: paths with spaces or shell metacharacters
        # are passed to ffmpeg verbatim.
        cmd = ['ffmpeg', '-f', 'image2', '-i',
               os.path.join(save_dir, '%05d.jpg'), out_path]
        try:
            subprocess.run(cmd)
        except FileNotFoundError:
            print('ffmpeg not found, frames are kept in {}.'.format(save_dir))
        else:
            print('Save video in {}.'.format(out_path))
def predict_video(detector,
                  reid_model,
                  video_file,
                  scaled,
                  threshold,
                  output_dir,
                  save_images=True,
                  save_mot_txts=True,
                  draw_center_traj=False,
                  secs_interval=10,
                  do_entrance_counting=False,
                  camera_id=-1):
    """Track objects with a detector plus a ReID model on a video or camera.

    For every frame, ``detector.predict`` produces detections. The ReID model
    turns crops of those detections into tracks
    (``tracking_outs['online_tlwhs' | 'online_scores' | 'online_ids']``).
    Tracks are accumulated for MOT output, fed to ``flow_statistic`` and
    drawn with ``plot_tracking``. Frames without detections are written
    unannotated and are neither added to the results nor counted.

    Args:
        detector: exposes ``predict([frame], ori_image_shape, threshold,
            scaled)`` returning ``(pred_dets, pred_xyxys)``.
        reid_model: exposes ``get_crops(pred_xyxys, frame)`` and
            ``predict(crops, pred_dets)``.
        video_file (str): input video path; ignored when ``camera_id != -1``.
        scaled: forwarded to ``detector.predict``.
        threshold (float): score threshold forwarded to ``detector.predict``.
        output_dir (str): output directory, created if missing.
        save_images (bool): save per-frame JPEGs under
            ``output_dir/<video stem>/`` and mux them with ffmpeg; otherwise
            write frames directly with ``cv2.VideoWriter``.
        save_mot_txts (bool): write MOT results and flow statistics as text
            files in ``output_dir``.
        draw_center_traj (bool): accepted for interface compatibility; not
            used by this function.
        secs_interval (int): interval forwarded to ``flow_statistic``.
        do_entrance_counting (bool): forwarded to ``flow_statistic`` and
            ``plot_tracking``; the entrance is the horizontal line at half
            the frame height.
        camera_id (int): camera index, or -1 to read ``video_file``.
    """
    import subprocess

    video_name = 'mot_output.mp4'
    if camera_id != -1:
        capture = cv2.VideoCapture(camera_id)
    else:
        capture = cv2.VideoCapture(video_file)
        video_name = os.path.split(video_file)[-1]
    # splitext keeps multi-dot names intact ("a.b.mp4" -> "a.b") and does
    # not fail on names without an extension.
    video_stem = os.path.splitext(video_name)[0]
    out_path = os.path.join(output_dir, video_name)
    save_dir = os.path.join(output_dir, video_stem)

    results = defaultdict(list)
    records = list()
    writer = None
    try:
        # Get video info: resolution, fps, frame count
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        video_fps = int(capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        print("fps: %d, frame_count: %d" % (video_fps, frame_count))

        os.makedirs(output_dir, exist_ok=True)
        if save_images:
            os.makedirs(save_dir, exist_ok=True)
        else:
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            writer = cv2.VideoWriter(out_path, fourcc, video_fps,
                                     (width, height))

        timer = MOTTimer()
        id_set = set()
        interval_id_set = set()
        in_id_list = list()
        out_id_list = list()
        prev_center = dict()
        # Horizontal entrance line at half the frame height.
        entrance = [0, height / 2., width, height / 2.]

        frame_id = 0  # 0-based index; MOT results use frame_id + 1
        while True:
            ret, frame = capture.read()
            if not ret:
                break
            timer.tic()
            ori_image_shape = list(frame.shape[:2])
            pred_dets, pred_xyxys = detector.predict([frame], ori_image_shape,
                                                     threshold, scaled)

            if len(pred_dets) == 1 and np.sum(pred_dets) == 0:
                print('Frame {} has no object, try to modify score threshold.'.
                      format(frame_id))
                timer.toc()
                # Update fps here too so the progress line is not stale.
                fps = 1. / timer.duration
                im = frame
            else:
                # reid process
                crops = reid_model.get_crops(pred_xyxys, frame)
                tracking_outs = reid_model.predict(crops, pred_dets)

                online_tlwhs = tracking_outs['online_tlwhs']
                online_scores = tracking_outs['online_scores']
                online_ids = tracking_outs['online_ids']

                result = (frame_id + 1, online_tlwhs, online_scores,
                          online_ids)
                results[0].append(result)
                # NOTE: flow statistics are only implemented for one class.
                statistic = flow_statistic(result, secs_interval,
                                           do_entrance_counting, video_fps,
                                           entrance, id_set, interval_id_set,
                                           in_id_list, out_id_list,
                                           prev_center, records)
                id_set = statistic['id_set']
                interval_id_set = statistic['interval_id_set']
                in_id_list = statistic['in_id_list']
                out_id_list = statistic['out_id_list']
                prev_center = statistic['prev_center']
                records = statistic['records']

                timer.toc()
                fps = 1. / timer.duration
                im = plot_tracking(frame,
                                   online_tlwhs,
                                   online_ids,
                                   online_scores,
                                   frame_id=frame_id,
                                   fps=fps,
                                   do_entrance_counting=do_entrance_counting,
                                   entrance=entrance)

            if save_images:
                cv2.imwrite(
                    os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)
            else:
                writer.write(im)

            frame_id += 1
            print('detect frame:%d, fps: %f' % (frame_id, fps))

            if camera_id != -1:
                cv2.imshow('Tracking Detection', im)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    finally:
        # Always free the capture device / file and flush the writer.
        capture.release()
        if writer is not None:
            writer.release()

    if save_mot_txts:
        result_filename = os.path.join(output_dir, video_stem + '.txt')
        write_mot_results(result_filename, results)

        result_filename = os.path.join(output_dir,
                                       video_stem + '_flow_statistic.txt')
        with open(result_filename, 'w') as f:
            f.writelines(records)
        print('Flow statistic save in {}'.format(result_filename))

    if save_images:
        # Argument list, no shell: paths with spaces or shell metacharacters
        # are passed to ffmpeg verbatim.
        cmd = ['ffmpeg', '-f', 'image2', '-i',
               os.path.join(save_dir, '%05d.jpg'), out_path]
        try:
            subprocess.run(cmd)
        except FileNotFoundError:
            print('ffmpeg not found, frames are kept in {}.'.format(save_dir))
        else:
            print('Save video in {}.'.format(out_path))
def predict_mtmct_seq(detector,
                      reid_model,
                      mtmct_dir,
                      seq_name,
                      scaled,
                      threshold,
                      output_dir,
                      save_images=True,
                      save_mot_txts=True):
    """Track one image sequence of an MTMCT dataset and collect ReID features.

    Images are read from ``mtmct_dir/seq_name/img1`` if that exists, else
    from ``mtmct_dir/seq_name``, in sorted file-name order. Entries that
    ``cv2.imread`` cannot decode are skipped with a message.

    Args:
        detector: exposes ``predict([frame_path], ori_image_shape, threshold,
            scaled)`` returning ``(pred_dets, pred_xyxys)``.
        reid_model: exposes ``get_crops`` and ``predict(..., MTMCT=True,
            frame_id=..., seq_name=...)`` returning a dict with
            ``feat_data``, ``online_tlwhs``, ``online_scores`` and
            ``online_ids``.
        mtmct_dir (str): dataset root directory.
        seq_name (str): sequence sub-directory name.
        scaled: forwarded to ``detector.predict``.
        threshold (float): score threshold forwarded to ``detector.predict``.
        output_dir (str): destination for annotated images / MOT text file.
        save_images (bool): save annotated frames under
            ``output_dir/seq_name/`` using the original file names.
        save_mot_txts (bool): write ``output_dir/seq_name.txt``.

    Returns:
        dict: union of every frame's ``tracking_outs['feat_data']``; later
        frames override earlier entries with the same key.
    """
    fpath = os.path.join(mtmct_dir, seq_name)
    img_dir = os.path.join(fpath, 'img1')
    if os.path.exists(img_dir):
        fpath = img_dir

    assert os.path.isdir(fpath), '{} should be a directory'.format(fpath)
    image_list = sorted(os.listdir(fpath))
    assert len(image_list) > 0, '{} has no images.'.format(fpath)

    results = defaultdict(list)
    mot_features_dict = {}  # cid_tid_fid feats
    print('Totally {} frames found in seq {}.'.format(len(image_list),
                                                      seq_name))

    save_dir = os.path.join(output_dir, seq_name)
    if save_images:
        os.makedirs(save_dir, exist_ok=True)

    for frame_id, img_file in enumerate(image_list):
        if frame_id % 10 == 0:
            print('Processing frame {} of seq {}.'.format(frame_id, seq_name))
        frame_path = os.path.join(fpath, img_file)
        frame = cv2.imread(frame_path)
        if frame is None:
            # imread returns None for unreadable or non-image entries.
            print('Skip {}: not a readable image.'.format(frame_path))
            continue
        ori_image_shape = list(frame.shape[:2])
        pred_dets, pred_xyxys = detector.predict([frame_path], ori_image_shape,
                                                 threshold, scaled)

        if len(pred_dets) == 1 and np.sum(pred_dets) == 0:
            print('Frame {} has no object, try to modify score threshold.'.
                  format(frame_id))
            online_im = frame
        else:
            # reid process
            crops = reid_model.get_crops(pred_xyxys, frame)

            tracking_outs = reid_model.predict(crops,
                                               pred_dets,
                                               MTMCT=True,
                                               frame_id=frame_id,
                                               seq_name=seq_name)

            # In-place merge: rebuilding the dict every frame was quadratic
            # and rejected non-string keys.
            mot_features_dict.update(tracking_outs['feat_data'])

            online_tlwhs = tracking_outs['online_tlwhs']
            online_scores = tracking_outs['online_scores']
            online_ids = tracking_outs['online_ids']

            online_im = plot_tracking(frame, online_tlwhs, online_ids,
                                      online_scores, frame_id)
            results[0].append(
                (frame_id + 1, online_tlwhs, online_scores, online_ids))

        if save_images:
            # listdir entries are already bare file names.
            cv2.imwrite(os.path.join(save_dir, img_file), online_im)

    if save_mot_txts:
        result_filename = os.path.join(output_dir, seq_name + '.txt')
        write_mot_results(result_filename, results)

    return mot_features_dict
Example #4
0
    def predict_video(self, video_file, camera_id):
        """Track objects in a video file or camera stream and save results.

        Frames are read from ``video_file`` (or from camera ``camera_id``
        when it is not -1), tracked with ``self.predict_image``, drawn with
        ``plot_tracking_dict`` and written to an mp4 in ``self.output_dir``.
        Flow statistics (``flow_statistic``) are only computed for
        single-class models. When ``self.save_mot_txts`` is set, MOT results
        (plus flow statistics for single-class models) are written as text
        files in ``self.output_dir``.

        MOT frame numbers are 1-based: the first frame read is frame 1.

        Args:
            video_file (str): input video path; ignored when
                ``camera_id != -1``.
            camera_id (int): camera index, or -1 to read ``video_file``.
        """
        video_out_name = 'mot_output.mp4'
        if camera_id != -1:
            capture = cv2.VideoCapture(camera_id)
        else:
            capture = cv2.VideoCapture(video_file)
            video_out_name = os.path.split(video_file)[-1]
        # One stem is used both as seq_name and for the text outputs;
        # splitext keeps multi-dot names intact ("a.b.mp4" -> "a.b").
        video_stem = os.path.splitext(video_out_name)[0]

        writer = None
        try:
            # Get video info: resolution, fps, frame count
            width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
            video_fps = int(capture.get(cv2.CAP_PROP_FPS))
            frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
            print("fps: %d, frame_count: %d" % (video_fps, frame_count))

            os.makedirs(self.output_dir, exist_ok=True)
            out_path = os.path.join(self.output_dir, video_out_name)
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            writer = cv2.VideoWriter(out_path, fourcc, video_fps,
                                     (width, height))

            timer = MOTTimer()
            results = defaultdict(list)  # support single class and multi classes
            num_classes = self.num_classes
            data_type = 'mcmot' if num_classes > 1 else 'mot'
            ids2names = self.pred_config.labels

            center_traj = None
            entrance = None
            records = None
            if self.draw_center_traj:
                center_traj = [{} for _ in range(num_classes)]
            if num_classes == 1:
                id_set = set()
                interval_id_set = set()
                in_id_list = list()
                out_id_list = list()
                prev_center = dict()
                records = list()
                # Horizontal entrance line at half the frame height.
                entrance = [0, height / 2., width, height / 2.]

            frame_id = 0  # 0-based index of the current frame
            while True:
                ret, frame = capture.read()
                if not ret:
                    break
                # 1-based frame number for MOT results and flow statistics
                # (previously the first frame was recorded as frame 3).
                mot_frame_id = frame_id + 1
                if mot_frame_id % 10 == 0:
                    print('Tracking frame: %d' % (mot_frame_id))

                timer.tic()
                mot_results = self.predict_image([frame],
                                                 visual=False,
                                                 seq_name=video_stem)
                timer.toc()

                online_tlwhs, online_scores, online_ids = mot_results[0]
                for cls_id in range(num_classes):
                    results[cls_id].append(
                        (mot_frame_id, online_tlwhs[cls_id],
                         online_scores[cls_id], online_ids[cls_id]))

                # NOTE: just implement flow statistic for single class
                if num_classes == 1:
                    result = (mot_frame_id, online_tlwhs[0], online_scores[0],
                              online_ids[0])
                    statistic = flow_statistic(result, self.secs_interval,
                                               self.do_entrance_counting,
                                               video_fps, entrance, id_set,
                                               interval_id_set, in_id_list,
                                               out_id_list, prev_center,
                                               records, data_type, num_classes)
                    # Carry all counting state forward, not only ``records``,
                    # so the next call sees whatever flow_statistic returned.
                    id_set = statistic['id_set']
                    interval_id_set = statistic['interval_id_set']
                    in_id_list = statistic['in_id_list']
                    out_id_list = statistic['out_id_list']
                    prev_center = statistic['prev_center']
                    records = statistic['records']

                fps = 1. / timer.duration  # inference fps of this frame
                im = plot_tracking_dict(
                    frame,
                    num_classes,
                    online_tlwhs,
                    online_ids,
                    online_scores,
                    frame_id=frame_id,
                    fps=fps,
                    ids2names=ids2names,
                    do_entrance_counting=self.do_entrance_counting,
                    entrance=entrance,
                    records=records,
                    center_traj=center_traj)

                writer.write(im)
                frame_id += 1
                if camera_id != -1:
                    cv2.imshow('Mask Detection', im)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
        finally:
            # Always free the capture device / file and flush the writer.
            capture.release()
            if writer is not None:
                writer.release()

        if self.save_mot_txts:
            result_filename = os.path.join(self.output_dir,
                                           video_stem + '.txt')
            write_mot_results(result_filename, results, data_type, num_classes)

            if num_classes == 1:
                result_filename = os.path.join(
                    self.output_dir, video_stem + '_flow_statistic.txt')
                with open(result_filename, 'w') as f:
                    f.writelines(records)
                print('Flow statistic save in {}'.format(result_filename))