예제 #1
0
def tracks2nix(video_file=None,
               tracking_results='tracking.csv',
               out_nix_csv_file='my_glitter_format.csv',
               zone_info='zone_info.json',
               overlay_mask=True,
               score_threshold=None,
               motion_threshold=None,
               deep=False,
               pretrained_model=None,
               subject_names=None,
               behavior_names=None):
    """Create a NIX format CSV file (for Glitter2) and an annotated video.

    Args:
        video_file (str): video file path. Defaults to None.
        tracking_results (str, optional): the tracking results CSV file from a
            model. Defaults to 'tracking.csv'.
        out_nix_csv_file (str, optional): output path of the NIX format CSV
            file. Defaults to 'my_glitter_format.csv'.
        zone_info (str, optional): either a path to a zone JSON file (any
            string containing '.json') or a comma separated polygon string,
            e.g. "0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0", which is written
            as the background zone. Defaults to 'zone_info.json'.
        overlay_mask (bool): overlay instance masks on the video or not.
            Defaults to True.
        score_threshold (float): class score threshold between 0.0 and 1.0
            that decides what is drawn / counted. ``None`` (the default) is
            treated as 0.0, i.e. no filtering.
        motion_threshold (float): threshold for motion between frames.
            ``None`` (the default) or a value <= 0 disables the freezing
            analysis.
        deep (bool): use the deep learning based motion model. Forced to True
            when the freezing analysis runs and ``pretrained_model`` points to
            an existing file. Defaults to False.
        pretrained_model (str): path to the trained motion model.
            Defaults to None.
        subject_names (str): comma separated subject names appended to the
            configured animal names, e.g. "vole_01,vole_02".
        behavior_names (str): comma separated behavior names appended to the
            configured behaviors, e.g. "rearing,walking".

    The first three configured animal names are used as the subject, the
    left interaction object and the right interaction object, so at least
    three must be configured. The annotated video is written next to
    ``video_file`` with a ``_tracked.mp4`` suffix.
    """

    print(f"Class or Instance score threshold is: {score_threshold}.")
    # ``score >= score_threshold`` comparisons below raise TypeError against
    # None, so a missing threshold means "keep everything".
    if score_threshold is None:
        score_threshold = 0.0

    print(f"Please update the definitions of keypoints, instances, and events")
    keypoint_cfg_file = Path(__file__).parent.parent / \
        'configs' / 'keypoints.yaml'
    open_or_start_file(keypoint_cfg_file)

    if zone_info and '.json' in zone_info:
        zone_info = Path(zone_info)
    elif zone_info == 'zone_info.json':
        # NOTE(review): unreachable -- 'zone_info.json' already matches the
        # branch above, so the default file is resolved relative to the CWD.
        zone_info = Path(__file__).parent / zone_info

    # Anything that is not a JSON path is a polygon string for the background
    # zone. Plain strings have no ``.suffix``, so never probe it on them.
    has_background_zone = bool(zone_info) and (
        isinstance(zone_info, str) or zone_info.suffix != '.json')

    _class_meta_data = draw.get_keypoint_connection_rules()
    keypoints_connection_rules, animal_names, behaviors, zones_names = _class_meta_data

    if subject_names is not None:
        animal_names = f"{animal_names} {' '.join(subject_names.split(','))}"
    if behavior_names is not None:
        behaviors = f"{behaviors} {' '.join(behavior_names.split(','))}"

    _animal_object_list = animal_names.split()

    # here we assume and use the first subject from the config file;
    # the next two names are the left / right interaction objects
    subject_animal_name = _animal_object_list[0]
    left_interact_object = _animal_object_list[1]
    right_interact_object = _animal_object_list[2]
    body_parts = list(keypoints_connection_rules[0])

    df_motion = None
    if motion_threshold is not None and motion_threshold > 0:
        fa = FreezingAnalyzer(video_file,
                              tracking_results,
                              motion_threshold=motion_threshold)
        if pretrained_model is not None and Path(pretrained_model).exists():
            deep = True
        df_motion = fa.run(deep=deep, pretrained_model=pretrained_model)

    df = pd.read_csv(tracking_results)
    try:
        df = df.drop(columns=['Unnamed: 0'])
    except KeyError:
        print("data frame does not have a column named Unnmaed: 0")

    def get_bbox(frame_number):
        """Return all tracking rows of ``frame_number`` as a list of dicts."""
        _df = df[df.frame_number == frame_number]
        try:
            res = _df.to_dict(orient='records')
        except Exception:
            res = []
        return res

    def is_freezing(frame_number, instance_name):
        """Return True if the motion analysis flagged this instance as freezing.

        Returns False when no motion analysis was run or when there is no
        motion row for this frame / instance (instead of raising IndexError).
        """
        if df_motion is None:
            return False
        matches = df_motion[(df_motion.frame_number == frame_number) & (
            df_motion.instance_name == instance_name)].freezing.values
        return len(matches) > 0 and matches[0] > 0

    def keypoint_in_body_mask(frame_number, keypoint_name, animal_name=None):
        """Return True if the keypoint mask overlaps the animal's body mask.

        ``animal_name`` defaults to the subject animal. Returns False when
        either mask is missing in this frame.
        """

        if animal_name is None:
            animal_name = subject_animal_name

        _df_k_b = df[df.frame_number == frame_number]
        try:
            body_seg = _df_k_b[_df_k_b.instance_name ==
                               animal_name]['segmentation'].values[0]
            body_seg = ast.literal_eval(body_seg)
        except IndexError:
            return False

        try:
            keypoint_seg = _df_k_b[_df_k_b.instance_name ==
                                   keypoint_name]['segmentation'].values[0]
            keypoint_seg = ast.literal_eval(keypoint_seg)
        except IndexError:
            return False

        if keypoint_seg and body_seg:
            overlap = mask_util.iou([body_seg], [keypoint_seg],
                                    [False, False]).flatten()[0]
            return overlap > 0
        else:
            return False

    def left_right_interact(fn,
                            subject_instance='subject_vole',
                            left_instance='left_vole',
                            right_instance='right_vole'):
        """Return the (left, right) mask IoU with the subject in frame ``fn``.

        Missing instances contribute 0.0; a missing subject returns (0.0, 0.0).
        """
        _df_top = df[df.frame_number == fn]
        right_interact = None
        left_interact = None
        try:
            subject_instance_seg = _df_top[
                _df_top.instance_name ==
                subject_instance]['segmentation'].values[0]
            subject_instance_seg = ast.literal_eval(subject_instance_seg)
        except IndexError:
            return 0.0, 0.0
        try:
            left_instance_seg = _df_top[
                _df_top.instance_name ==
                left_instance]['segmentation'].values[0]
            left_instance_seg = ast.literal_eval(left_instance_seg)
            left_interact = mask_util.iou([left_instance_seg],
                                          [subject_instance_seg],
                                          [False, False]).flatten()[0]
        except IndexError:
            left_interact = 0.0
        try:
            right_instance_seg = _df_top[
                _df_top.instance_name ==
                right_instance]['segmentation'].values[0]
            right_instance_seg = ast.literal_eval(right_instance_seg)
            right_interact = mask_util.iou([right_instance_seg],
                                           [subject_instance_seg],
                                           [False, False]).flatten()[0]
        except IndexError:
            right_interact = 0.0

        return left_interact, right_interact

    cap = cv2.VideoCapture(video_file)

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    target_fps = int(cap.get(cv2.CAP_PROP_FPS))
    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # fix left right switch by checking the middle
    # point of the video frame width
    df['instance_name'] = df.apply(
        lambda row: TracksResults.switch_left_right(row, width), axis=1)

    # find the unique list of all instance names and remove nan
    instance_names = df['instance_name'].dropna().unique()

    # add an instance name to the animal names if it is not already an
    # animal, a behavior, or a zone (previously ``or`` let behaviors through,
    # which blocked the labelled behavior drawing further below)
    for instance_name in instance_names:
        if (instance_name not in animal_names
                and instance_name not in behaviors
                and instance_name not in zones_names):
            animal_names += ' ' + instance_name

    metadata_dict = {}
    metadata_dict['filename'] = video_file
    metadata_dict['pixels_per_meter'] = 0
    metadata_dict['video_width'] = f"{width}"
    metadata_dict['video_height'] = f"{height}"
    metadata_dict['saw_all_timestamps'] = 'TRUE'

    zone_dict = {}
    # polygon zones loaded from the JSON file (if any)
    zones = None

    if has_background_zone:
        zone_background_dict = {}

        zone_background_dict['zone:background:property'] = ['type', 'points']

        zone_background_dict['zone:background:value'] = ['polygon', zone_info]

        for isn in df['instance_name'].dropna().unique():

            if isn != 'nan' and 'object' in isn:
                zone_dict[f"zone:{isn}:property"] = [
                    'type', 'center', 'radius'
                ]
                zone_dict[f'zone:{isn}:value'] = ['circle', "0, 0", 0]
    elif zone_info and zone_info.exists():
        zone_file = json.loads(zone_info.read_bytes())
        zones = zone_file['shapes']

    timestamps = {}

    num_grooming = 0
    num_rearing = 0
    num_object_investigation = 0
    num_left_interact = 0
    num_right_interact = 0

    out_video_file = f"{os.path.splitext(video_file)[0]}_tracked.mp4"

    video_writer = cv2.VideoWriter(out_video_file,
                                   cv2.VideoWriter_fourcc(*"mp4v"), target_fps,
                                   (width, height))

    # try multiple times if opencv cannot read a frame
    for frame_number, frame in enumerate(frame_from_video(cap, num_frames)):
        # timestamp in seconds
        frame_timestamp = cap.get(cv2.CAP_PROP_POS_MSEC) / 1000

        bbox_info = get_bbox(frame_number)

        # calculate left or right interact in the frame
        left_interact, right_interact = left_right_interact(
            frame_number, subject_animal_name, left_interact_object,
            right_interact_object)

        timestamps.setdefault(frame_timestamp, {})
        timestamps[frame_timestamp].setdefault('event:Grooming', 0)
        timestamps[frame_timestamp].setdefault('event:Rearing', 0)
        timestamps[frame_timestamp].setdefault('event:Object_investigation', 0)
        timestamps[frame_timestamp].setdefault('event:RightInteract', 0)
        timestamps[frame_timestamp].setdefault('event:LeftInteract', 0)
        timestamps[frame_timestamp].setdefault('event:Freezing', 0)

        timestamps[frame_timestamp].setdefault('pos:animal_center:x', -1)
        timestamps[frame_timestamp].setdefault('pos:animal_center:y', -1)

        timestamps[frame_timestamp].setdefault('pos:interact_center:x', -1)
        timestamps[frame_timestamp].setdefault('pos:interact_center:y', -1)

        timestamps[frame_timestamp].setdefault('pos:animal_nose:x', -1)
        timestamps[frame_timestamp].setdefault('pos:animal_nose:y', -1)
        timestamps[frame_timestamp].setdefault('pos:animal_:x', -1)
        timestamps[frame_timestamp].setdefault('pos:animal_:y', -1)
        timestamps[frame_timestamp].setdefault('frame_number', frame_number)

        right_zone_box = None
        left_zone_box = None

        if zones:
            for zs in zones:
                zone_box = zs['points']
                zone_label = zs['label']
                zone_box = functools.reduce(operator.iconcat, zone_box, [])
                if 'right' in zone_label.lower():
                    right_zone_box = zone_box
                elif 'left' in zone_label.lower():
                    left_zone_box = zone_box

                # draw masks labeled as zones
                # encode and merge polygons with format [[x1,y1,x2,y2,x3,y3....]]
                try:
                    rles = mask_util.frPyObjects([zone_box], height, width)
                    rle = mask_util.merge(rles)

                    # convert the polygons to mask
                    m = mask_util.decode(rle)
                    frame = draw.draw_binary_masks(frame, [m], [zone_label])
                except Exception:
                    # skip non polygon zones
                    continue

        parts_locations = {}

        timestamps[frame_timestamp]['frame_number'] = frame_number

        for bf in bbox_info:
            _frame_num = bf['frame_number']
            x1 = bf['x1']
            y1 = bf['y1']
            x2 = bf['x2']
            y2 = bf['y2']
            _class = bf['instance_name']
            score = bf['class_score']
            _mask = bf['segmentation']

            if not pd.isnull(_mask) and overlay_mask:
                if _class not in zones_names:
                    if score >= score_threshold and (_class in animal_names
                                                     or _class.lower()
                                                     in animal_names):
                        _mask = ast.literal_eval(_mask)
                        _mask = mask_util.decode(_mask)[:, :]
                        frame = draw.draw_binary_masks(frame, [_mask],
                                                       [_class])

            # In glitter, the y-axis is such that the bottom is zero and the top is height.
            # i.e. origin is bottom left
            glitter_y1 = height - y1
            glitter_y2 = height - y2

            if 'right' in str(_class).lower() and 'interact' in _class.lower():
                _class = 'RightInteract'
            elif 'left' in str(
                    _class).lower() and 'interact' in _class.lower():
                _class = 'LeftInteract'

            # draw bbox if model predicted with interact
            # (an extra "and masks overlap" condition is intentionally off)
            is_draw = True
            if _class == "RightInteract" and (right_zone_box is not None
                                              and x1 < right_zone_box[0]):
                is_draw = False

            if not math.isnan(x1) and _frame_num == frame_number:
                cx = int((x1 + x2) / 2)
                cy_glitter = int((glitter_y1 + glitter_y2) / 2)
                cy = int((y1 + y2) / 2)
                _, color = draw.get_label_color(_class)

                # the first animal
                if keypoint_in_body_mask(_frame_num, _class,
                                         subject_animal_name):
                    parts_locations[_class] = (cx, cy, color)

                if _class == 'nose' or 'nose' in _class.lower():
                    timestamps[frame_timestamp]['pos:animal_nose:x'] = cx
                    timestamps[frame_timestamp][
                        'pos:animal_nose:y'] = cy_glitter
                elif _class == 'centroid' or _class.lower().endswith('mouse') \
                        or _class.lower().endswith('vole'):
                    timestamps[frame_timestamp]['pos:animal_center:x'] = cx
                    timestamps[frame_timestamp][
                        'pos:animal_center:y'] = cy_glitter
                elif _class == 'grooming':
                    timestamps[frame_timestamp]['event:Grooming'] = 1
                    timestamps[frame_timestamp]['pos:animal_:x'] = cx
                    timestamps[frame_timestamp]['pos:animal_:y'] = cy_glitter
                    num_grooming += 1
                elif 'rearing' in _class:
                    timestamps[frame_timestamp]['event:Rearing'] = 1
                    timestamps[frame_timestamp]['pos:animal_:x'] = cx
                    timestamps[frame_timestamp]['pos:animal_:y'] = cy_glitter
                    num_rearing += 1
                elif _class == 'object_investigation':
                    timestamps[frame_timestamp][
                        'event:Object_investigation'] = 1
                    timestamps[frame_timestamp]['pos:animal_:x'] = cx
                    timestamps[frame_timestamp]['pos:animal_:y'] = cy_glitter
                    num_object_investigation += 1
                elif _class == 'LeftInteract' and left_interact > 0:

                    timestamps[frame_timestamp]['pos:interact_center:x'] = cx
                    timestamps[frame_timestamp][
                        'pos:interact_center:y'] = cy_glitter

                    if cx > width / 2:
                        timestamps[frame_timestamp]['event:RightInteract'] = 1
                        num_right_interact += 1
                        _class = "RightInteract"
                    else:
                        timestamps[frame_timestamp]['event:LeftInteract'] = 1
                        num_left_interact += 1
                elif (is_draw and _class == 'RightInteract'
                      and score >= score_threshold and right_interact > 0):
                    timestamps[frame_timestamp]['event:RightInteract'] = 1
                    # same column as the setdefault above (was misspelled
                    # 'pos:interact_center_:*', creating stray columns)
                    timestamps[frame_timestamp]['pos:interact_center:x'] = cx
                    timestamps[frame_timestamp][
                        'pos:interact_center:y'] = cy_glitter
                    num_right_interact += 1

                elif 'object' in _class.lower(
                ) and _class != 'object_investigation':
                    zone_dict[f'zone:{_class}:value'] = [
                        'circle', [cx, cy_glitter],
                        min(int((x2 - x1) / 2), int(glitter_y2 - glitter_y1))
                    ]
                bbox = [[x1, y1, x2, y2]]

                # only draw behavior with bbox not body parts
                if (is_draw and _class in behaviors
                        and score >= score_threshold
                        and _class not in body_parts
                        and _class not in animal_names):

                    if _class == 'grooming':
                        label = f"{_class}: {num_grooming} times"
                    elif _class == 'rearing':
                        label = f"{_class}: {num_rearing} times"
                    elif _class == "object_investigation":
                        label = f"{_class}: {num_object_investigation} times"
                    elif _class == "LeftInteract":
                        label = f"{_class}: {num_left_interact} times"
                    elif _class == "RightInteract":
                        label = f"{_class}: {num_right_interact} times"
                    elif "rearing" in _class:
                        label = 'rearing'
                    else:
                        label = f"{_class}:{round(score * 100,2)}%"

                    if _class == 'RightInteract' and right_interact <= 0:
                        pass
                    elif _class == 'LeftInteract' and left_interact <= 0:
                        pass
                    else:
                        draw.draw_boxes(frame,
                                        bbox,
                                        identities=[label],
                                        draw_track=False,
                                        points=points)
                elif score >= score_threshold:
                    # draw box center as keypoints
                    # do not draw point center for zones
                    is_keypoint_in_mask = keypoint_in_body_mask(
                        _frame_num, _class, subject_animal_name)
                    if (is_keypoint_in_mask or any(map(str.isdigit, _class))
                            or _class in _animal_object_list):
                        if 'zone' not in _class.lower():
                            cv2.circle(frame, (cx, cy), 6, color, -1)
                        if _class in animal_names and 'zone' not in _class.lower(
                        ):
                            cv2.putText(frame, f"-{_class}:{score*100:.2f}%",
                                        (cx + 3, cy + 3),
                                        cv2.FONT_HERSHEY_SIMPLEX, 0.65, color,
                                        2)

                if (left_interact > 0 and 'left' in _class.lower()
                        and 'interact' in _class.lower()):
                    num_left_interact += 1
                    timestamps[frame_timestamp]['event:LeftInteract'] = 1
                    timestamps[frame_timestamp]['pos:interact_center:x'] = cx
                    timestamps[frame_timestamp][
                        'pos:interact_center:y'] = cy_glitter
                    label = f"left interact:{num_left_interact} times"
                    draw.draw_boxes(frame,
                                    bbox,
                                    identities=[label],
                                    draw_track=False,
                                    points=points)
                if (right_interact > 0 and 'right' in _class.lower()
                        and 'interact' in _class.lower()):
                    num_right_interact += 1
                    timestamps[frame_timestamp]['event:RightInteract'] = 1
                    timestamps[frame_timestamp]['pos:interact_center:x'] = cx
                    timestamps[frame_timestamp][
                        'pos:interact_center:y'] = cy_glitter
                    label = f"right interact:{num_right_interact} times"
                    draw.draw_boxes(frame,
                                    bbox,
                                    identities=[label],
                                    draw_track=False,
                                    points=points)

                freezing = is_freezing(_frame_num, _class)
                if freezing:
                    timestamps[frame_timestamp]['event:Freezing'] = 1
                    draw.draw_boxes(frame,
                                    bbox,
                                    identities=['freezing'],
                                    draw_track=False,
                                    points=points)

                if _class in behaviors:
                    cv2.rectangle(frame, (5, 35), (5 + 140, 35 + 35),
                                  (0, 0, 0), -1)
                    cv2.putText(frame, f"{_class}", (15, 55),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.65,
                                (255, 255, 255), 2)

                if _class in zones_names:
                    draw.draw_boxes(frame,
                                    bbox,
                                    identities=[_class],
                                    draw_track=False,
                                    points=points)

        # draw the lines between predefined keypoints
        draw.draw_keypoint_connections(frame, parts_locations,
                                       keypoints_connection_rules)

        cv2.putText(frame, f"Timestamp: {frame_timestamp}", (25, 25),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.65, (255, 255, 255), 2)

        cv2.imshow("Frame", frame)
        video_writer.write(frame)

        key = cv2.waitKey(1)
        if key == 27:
            break

    cv2.destroyAllWindows()
    cap.release()
    video_writer.release()

    # save a NIX format CSV file for Glitter2
    df_res = pd.DataFrame.from_dict(timestamps, orient='index')
    df_res.index.rename('timestamps', inplace=True)

    df_meta = pd.DataFrame.from_dict(metadata_dict, orient='index')

    if has_background_zone:
        df_zone_background = pd.DataFrame.from_dict(zone_background_dict)

    if zone_dict:
        df_zone = pd.DataFrame.from_dict(zone_dict)

    df_res.reset_index(inplace=True)
    df_meta.reset_index(inplace=True)
    df_meta.columns = ['metadata', 'value']
    df_res.insert(0, "metadata", df_meta['metadata'])
    df_res.insert(1, "value", df_meta['value'])

    if has_background_zone:
        df_res = pd.concat([df_res, df_zone_background], axis=1)

    if zone_dict:
        df_res = pd.concat([df_res, df_zone], axis=1)

    df_res.to_csv(out_nix_csv_file, index=False)
예제 #2
0
def detect(opt, save_img=False, tracking=True, points=None):
    """Run YOLOv5 detection (optionally with tracking) on ``opt.source``.

    Args:
        opt: configuration object. Reads ``output``, ``source``, ``weights``,
            ``view_img``, ``save_txt``, ``img_size``, ``device``, ``augment``,
            ``conf_thres``, ``iou_thres``, ``classes``, ``agnostic_nms`` and
            ``update``.
        save_img (bool): save annotated images/videos. Forced to True for
            non-stream sources.
        tracking (bool): feed detections to ``tracker`` and draw track
            identities. NOTE(review): ``tracker`` is not defined in this
            function; it is assumed to be a module-level object.
        points: per-identity track history passed to ``draw_boxes``.

    The output folder ``opt.output`` is deleted and recreated. Pressing ``q``
    in the preview window stops processing early. The open video writer is
    still released, and the function returns normally instead of leaking a
    ``StopIteration`` to the caller.
    """
    import subprocess

    out, source, weights, view_img, save_txt, imgsz = \
        opt.output, opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
    webcam = source.isnumeric() or source.startswith(
        'rtsp') or source.startswith('http') or source.endswith('.txt')

    # Initialize
    set_logging()
    device = select_device(opt.device)
    if os.path.exists(out):
        shutil.rmtree(out)  # delete output folder
    os.makedirs(out)  # make new output folder
    half = device.type != 'cpu'  # half precision only supported on CUDA

    # Load model
    model = attempt_load(weights, map_location=device)  # load FP32 model
    imgsz = check_img_size(imgsz, s=model.stride.max())  # check img_size
    if half:
        model.half()  # to FP16

    # Second-stage classifier (disabled)
    classify = False
    if classify:
        modelc = load_classifier(name='resnet101', n=2)  # initialize
        modelc.load_state_dict(
            torch.load('weights/resnet101.pt',
                       map_location=device)['model'])  # load weights
        modelc.to(device).eval()

    # Set Dataloader
    vid_path, vid_writer = None, None
    save_path = None  # stays None if the dataset yields nothing
    if webcam:
        view_img = True
        cudnn.benchmark = True  # set True to speed up constant image size inference
        dataset = LoadStreams(source, img_size=imgsz)
    else:
        save_img = True
        dataset = LoadImages(source, img_size=imgsz)

    # Get names and colors
    names = model.module.names if hasattr(model, 'module') else model.names
    colors = [[random.randint(0, 255) for _ in range(3)]
              for _ in range(len(names))]

    # Run inference
    t0 = time.time()
    img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
    # run once
    _ = model(img.half() if half else img) if device.type != 'cpu' else None
    stop_requested = False
    for path, img, im0s, vid_cap in dataset:
        img = torch.from_numpy(img).to(device)
        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)

        # Inference
        t1 = time_synchronized()
        pred = model(img, augment=opt.augment)[0]

        # Apply NMS
        pred = non_max_suppression(pred,
                                   opt.conf_thres,
                                   opt.iou_thres,
                                   classes=opt.classes,
                                   agnostic=opt.agnostic_nms)
        t2 = time_synchronized()

        # Apply Classifier
        if classify:
            pred = apply_classifier(pred, modelc, img, im0s)
        # Process detections
        for i, det in enumerate(pred):  # detections per image
            if webcam:  # batch_size >= 1
                p, s, im0 = path[i], '%g: ' % i, im0s[i].copy()
            else:
                p, s, im0 = path, '', im0s

            save_path = str(Path(out) / Path(p).name)
            txt_path = str(Path(out) / Path(p).stem) + (
                '_%g' % dataset.frame if dataset.mode == 'video' else '')
            s += '%gx%g ' % img.shape[2:]  # print string
            # normalization gain whwh
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]
            if det is not None and len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4],
                                          im0.shape).round()

                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += '%g %ss, ' % (n, names[int(c)])  # add to string

                if tracking:
                    xywh = xyxy2xywh((det[:, :4]).cpu().numpy())
                    conf = (det[:, -2]).cpu().numpy()
                    outputs = tracker.update(xywh, conf, im0)
                    if len(outputs) > 0:
                        bbox_xyxy = outputs[:, :4]
                        identities = outputs[:, -1]
                        im0 = draw_boxes(im0,
                                         bbox_xyxy,
                                         identities,
                                         draw_track=True,
                                         points=points)

                for *xyxy, conf, cls in reversed(det):
                    if save_txt:  # Write to file
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) /
                                gn).view(-1).tolist()  # normalized xywh
                        with open(txt_path + '.txt', 'a') as f:
                            f.write(('%g ' * 5 + '\n') %
                                    (cls, *xywh))  # label format

                    if save_img or view_img:  # Add bbox to image
                        label = f"{names[int(cls)]}:{conf: .2f}"
                        plot_one_box(xyxy,
                                     im0,
                                     label=label,
                                     color=colors[int(cls)],
                                     line_thickness=3,
                                     tracking=tracking)

            # Print time (inference + NMS)
            print('%sDone. (%.3fs)' % (s, t2 - t1))

            # Stream results
            if view_img:
                cv2.imshow(p, im0)
                if cv2.waitKey(1) == ord('q'):  # q to quit
                    # Previously ``raise StopIteration``, which escaped to the
                    # caller and skipped releasing the video writer.
                    stop_requested = True
                    break

            # Save results (image with detections)
            if save_img:
                if dataset.mode == 'images':
                    cv2.imwrite(save_path, im0)
                else:
                    if vid_path != save_path:  # new video
                        vid_path = save_path
                        if isinstance(vid_writer, cv2.VideoWriter):
                            vid_writer.release(
                            )  # release previous video writer

                        fourcc = 'mp4v'  # output video codec
                        fps = vid_cap.get(cv2.CAP_PROP_FPS)
                        w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                        h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        vid_writer = cv2.VideoWriter(
                            save_path, cv2.VideoWriter_fourcc(*fourcc), fps,
                            (w, h))
                    vid_writer.write(im0)
        if stop_requested:
            break

    # Release the last writer so the final video file is finalized.
    if isinstance(vid_writer, cv2.VideoWriter):
        vid_writer.release()

    if save_txt or save_img:
        print('Results saved to %s' % Path(out))
        if (platform.system() == 'Darwin' and not opt.update
                and save_path is not None):  # MacOS
            # argument list instead of a shell string: paths containing
            # spaces or shell metacharacters are passed through safely
            subprocess.run(['open', save_path], check=False)

    print('Done. (%.3fs)' % (time.time() - t0))
예제 #3
0
    def run(self, deep=False, pretrained_model=None):
        """Measure per-instance motion between consecutive frames.

        For each pair of consecutive frames, the method does the following:
        - Computes a flow field with ``self.motion`` (deep model or classic
          flow).
        - Averages the flow magnitude inside each instance mask.
        - Counts an instance as "still" when its previous/current mask IoU is
          at least ``self.iou_threshold`` and its motion is below
          ``self.motion_threshold``.

        An annotated video is shown and written via ``self.get_video_writer``.

        Args:
            deep (bool): forwarded to ``self.motion`` to choose the motion
                model; the classic flow overlay is drawn only when False.
            pretrained_model (str): path forwarded to
                ``deformation.build_model``.

        Returns:
            pandas.DataFrame: ``self.tracking_results`` left-merged on
            ``frame_number``/``instance_name`` with ``motion_index``,
            ``mask_iou`` and ``freezing`` columns. The same frame is also
            written to ``<tracking_results_name>_motion.csv``.
        """
        video = cv2.VideoCapture(self.video_file)
        width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames_per_second = video.get(cv2.CAP_PROP_FPS)
        ret, frame1 = video.read()
        # CAP_PROP_POS_FRAMES is the index of the *next* frame to decode, so
        # the frame just read is ``frame_number - 1``; this matches the
        # lookup used inside the loop (previously the first pair compared
        # frame 1's instances with themselves).
        frame_number = int(video.get(cv2.CAP_PROP_POS_FRAMES))
        prvs = cv2.cvtColor(frame1, cv2.COLOR_BGR2GRAY)
        prvs = cv2.blur(prvs, (5, 5))
        prvs_instances = self.instances(frame_number - 1)
        hsv = np.zeros_like(frame1)
        hsv[..., 1] = 255
        self.motion_model = deformation.build_model(
            vol_shape=prvs.shape, pretrained_model=pretrained_model)

        video_writer = self.get_video_writer(self.video_file,
                                             frames_per_second, width, height)

        # running "still" counter per instance: +1 when still, -3 on motion
        instance_status = {name: 0 for name in self.instance_names()}

        while video.isOpened():
            ret, frame2 = video.read()
            frame_number = int(video.get(cv2.CAP_PROP_POS_FRAMES))
            current_instances = self.instances(frame_number - 1)
            df_prvs_cur = pd.merge(prvs_instances,
                                   current_instances,
                                   how='inner',
                                   on='instance_name')

            try:
                df_prvs_cur['mask_iou'] = df_prvs_cur.apply(self.mask_iou,
                                                            axis=1)
            except ValueError:
                # e.g. no instance present in both frames
                df_prvs_cur['mask_iou'] = 0.0

            if not ret:
                break

            next_gray = cv2.cvtColor(frame2, cv2.COLOR_BGR2GRAY)
            next_gray = cv2.blur(next_gray, (5, 5))
            flow = self.motion(prvs, next_gray, deep)
            mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
            hsv[..., 0] = ang * 180 / np.pi / 2
            hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
            bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
            for _, _row in df_prvs_cur.iterrows():
                segmentation = _row.segmentation_y
                # Reset per row: a row without a mask used to reuse the
                # previous row's value (or raise NameError on the first row).
                mask_motion = 0.0
                # Missing masks are NaN after the CSV round trip; NaN is
                # truthy but cannot be parsed by ast.literal_eval.
                if isinstance(segmentation, str) and segmentation:
                    _mask = ast.literal_eval(segmentation)
                    _mask = mask_util.decode(_mask)[:, :]
                    # mean flow magnitude inside the instance mask
                    mask_motion = np.sum(_mask * mag) / np.sum(_mask)
                    bgr = draw.draw_binary_masks(bgr, [_mask],
                                                 [_row.instance_name])
                else:
                    print("No mask")

                self.motion_values.append(
                    (_row.frame_number_x, _row.instance_name, mask_motion,
                     _row.mask_iou,
                     int(instance_status[_row.instance_name] > 0)))

                if _row.instance_name == self.target_instance:
                    draw.draw_boxes(
                        bgr,
                        [[_row.x1_y, _row.y1_y, _row.x2_y, _row.y2_y]],
                        identities=[f"Motion: {mask_motion:.2f}"],
                        draw_track=False)

                if _row.mask_iou >= self.iou_threshold and mask_motion < self.motion_threshold:
                    instance_status[_row.instance_name] += 1

                    if instance_status[_row.instance_name] >= 5:
                        draw.draw_boxes(
                            bgr,
                            [[_row.x1_y, _row.y1_y, _row.x2_y, _row.y2_y]],
                            identities=[f"Motion: {mask_motion:.2f}"],
                            draw_track=False)
                else:
                    instance_status[_row.instance_name] -= 3

            dst = cv2.addWeighted(frame1, 1, bgr, 1, 0)
            if not deep:
                dst = draw.draw_flow(dst, flow)
            cv2.imshow('frame2', dst)
            video_writer.write(dst)
            k = cv2.waitKey(30) & 0xff
            if k == 27:
                break
            prvs = next_gray
            frame1 = frame2
            prvs_instances = current_instances

        video.release()
        video_writer.release()
        cv2.destroyAllWindows()

        # passing columns to the constructor also works when no motion
        # values were collected (assigning .columns on an empty frame fails)
        df_motion = pd.DataFrame(self.motion_values,
                                 columns=[
                                     "frame_number", 'instance_name',
                                     'motion_index', 'mask_iou', 'freezing'
                                 ])
        df_res = pd.merge(self.tracking_results,
                          df_motion,
                          how='left',
                          on=['frame_number', 'instance_name'])
        tracking_results_motion = self.tracking_results_name.replace(
            ".csv", '_motion.csv')
        df_res.to_csv(tracking_results_motion, index=False)

        return df_res
예제 #4
0
파일: videos.py 프로젝트: shamavir/annolid
def track(video_file=None,
          name="YOLOV5",
          weights=None
          ):
    """Run object detection plus multi-object tracking on a video.

    Args:
        video_file (str): path to the input video. For the ``"YOLOV5"``
            backend it is forwarded unchanged as the detector's ``source``;
            for any other backend it must be an existing file.
        name (str): tracking backend. ``"YOLOV5"`` runs the bundled YOLOv5
            ``detect`` routine. Any other value runs the project detector
            (``build_detector``) combined with the DeepSORT tracker
            (``build_tracker``) and shows annotated frames in an OpenCV window.
        weights (str, optional): YOLOv5 weights overriding the value from
            ``./configs/yolov5s.yaml``. Ignored by the non-YOLOv5 backend.

    Returns:
        None. If ``video_file`` is missing or invalid for the non-YOLOv5
        backend, a message is printed and the function returns immediately.
    """
    # Bounded trails of recent positions handed to the drawing code
    # (presumably indexed by track id — confirm against draw_boxes/detect).
    points = [deque(maxlen=30) for _ in range(1000)]

    if name == "YOLOV5":
        # Imported lazily so users who only want e.g. frame extraction
        # do not need PyTorch installed.
        import torch
        from annolid.detector.yolov5.detect import detect
        from annolid.utils.config import get_config
        from annolid.detector.yolov5.utils.general import strip_optimizer

        opt = get_config("./configs/yolov5s.yaml")
        if weights is not None:
            opt.weights = weights
        opt.source = video_file

        with torch.no_grad():
            if opt.update:
                # Update all stock models (to fix SourceChangeWarning).
                for opt.weights in ['yolov5s.pt', 'yolov5m.pt',
                                    'yolov5l.pt', 'yolov5x.pt']:
                    detect(opt, points=points)
                    strip_optimizer(opt.weights)
            else:
                # Plain inference only. strip_optimizer is deliberately not
                # called here: in upstream YOLOv5 it saves the stripped
                # checkpoint back over the weights file, which would discard
                # the optimizer state of a user's own checkpoint.
                detect(opt, points=points)
        return

    # Validate before the heavy imports / model construction. The None check
    # matters because os.path.isfile(None) raises TypeError.
    if video_file is None or not os.path.isfile(video_file):
        print("Please provide a valid video file")
        return

    from annolid.tracker import build_tracker
    from annolid.detector import build_detector
    from annolid.utils.draw import draw_boxes

    detector = build_detector()
    deep_sort = build_tracker()

    cap = cv2.VideoCapture(video_file)
    try:
        # Read every frame, including the first one (previously consumed
        # by a priming read and never tracked).
        while True:
            ret, frame = cap.read()
            if not ret:
                print("Finished tracking.")
                break

            im = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            bbox_xywh, cls_conf, cls_ids = detector(im)

            # Keep only detections with class id 0. Boxes and confidences
            # must be filtered with the same mask so they stay aligned
            # row-for-row when passed to the tracker.
            mask = cls_ids == 0
            bbox_xywh = bbox_xywh[mask]
            cls_conf = cls_conf[mask]
            # Enlarge column 3 (the height, assuming x, y, w, h columns)
            # by 20%.
            bbox_xywh[:, 3:] *= 1.2

            outputs = deep_sort.update(bbox_xywh, cls_conf, im)

            if len(outputs) > 0:
                bbox_xyxy = outputs[:, :4]
                identities = outputs[:, -1]
                frame = draw_boxes(frame,
                                   bbox_xyxy,
                                   identities,
                                   draw_track=True,
                                   points=points
                                   )

            cv2.imshow("Frame", frame)

            # Mask to the low byte so the ESC comparison is portable.
            if cv2.waitKey(1) & 0xff == 27:
                break
    finally:
        # Always release the capture and close windows, even on errors.
        cv2.destroyAllWindows()
        cap.release()