Example #1
0
 def setUp(self):
     """Load the YAML configs this test case works with.

     Populates ``deep_sort_config``, ``custom_dataset_config`` and
     ``keypoints_config`` from files under ``./annolid/configs/``. The paths
     are relative, so they presumably resolve against the current working
     directory.
     """
     deep_sort_yaml = "./annolid/configs/deep_sort.yaml"
     custom_dataset_yaml = "./annolid/configs/custom_dataset.yaml"
     keypoints_yaml = "./annolid/configs/keypoints.yaml"

     self.deep_sort_config = get_config(deep_sort_yaml)
     self.custom_dataset_config = get_config(custom_dataset_yaml)
     self.keypoints_config = get_config(keypoints_yaml)
Example #2
0
def get_keypoint_connection_rules(keypoint_cfg_file=None):
    """Read keypoint connection rules from a keypoints YAML config file.

    Each entry under the ``HEAD`` and ``BODY`` sections maps a body part to a
    string of the form ``"other_body_part,r,g,b"``.

    Args:
        keypoint_cfg_file (str or pathlib.Path, optional): path to the yaml
            config file. Defaults to ``configs/keypoints.yaml`` in the parent
            of this module's directory.

    Returns:
        tuple: ``(rules, name, events, zones)``. ``rules`` is a list of
        ``(body_part_1, body_part_2, (r, g, b))`` tuples, e.g.
        ``[("nose", "left_ear", (255, 255, 0))]``. ``name``, ``events`` and
        ``zones`` are the config's ``NAME``, ``EVENTS`` and ``ZONES`` entries.

    Raises:
        FileNotFoundError: if the config file does not exist.
    """
    if keypoint_cfg_file is None:
        keypoint_cfg_file = Path(__file__).parent.parent / \
            'configs' / 'keypoints.yaml'
    # Accept plain string paths (as documented), not only Path objects.
    keypoint_cfg_file = Path(keypoint_cfg_file)
    if not keypoint_cfg_file.exists():
        # Without the file there is nothing to return for NAME/EVENTS/ZONES,
        # so fail with a clear message instead of an UnboundLocalError.
        raise FileNotFoundError(
            f"Keypoint config file not found: {keypoint_cfg_file}")

    key_points_rules = get_config(str(keypoint_cfg_file))

    keypoints_connection_rules = []
    for section in ('HEAD', 'BODY'):
        for body_part, rule in key_points_rules[section].items():
            other_part, r, g, b = rule.split(',')
            # color is a placeholder for future customization
            color = (int(r), int(g), int(b))
            keypoints_connection_rules.append(
                (body_part, other_part.strip(), color))

    return (keypoints_connection_rules,
            key_points_rules['NAME'],
            key_points_rules['EVENTS'],
            key_points_rules['ZONES'])
Example #3
0
def build_tracker(cfg_file="./configs/deep_sort.yaml",
                  use_cuda=None):
    """Construct a ``DeepSort`` tracker from the ``DEEPSORT`` section of a config.

    Args:
        cfg_file (str): path to the tracker YAML config.
        use_cuda (bool, optional): forwarded to ``DeepSort``. When ``None``,
            it is set to ``torch.cuda.is_available()``.

    Returns:
        DeepSort: a tracker built from the re-ID checkpoint and the
        matching parameters read from the config.
    """
    cuda = torch.cuda.is_available() if use_cuda is None else use_cuda
    ds_cfg = get_config(cfg_file).DEEPSORT
    return DeepSort(
        ds_cfg.REID_CKPT,
        max_dist=ds_cfg.MAX_DIST,
        min_confidence=ds_cfg.MIN_CONFIDENCE,
        nms_max_overlap=ds_cfg.NMS_MAX_OVERLAP,
        max_iou_distance=ds_cfg.MAX_IOU_DISTANCE,
        max_age=ds_cfg.MAX_AGE,
        n_init=ds_cfg.N_INIT,
        nn_budget=ds_cfg.NN_BUDGET,
        use_cuda=cuda,
    )
Example #4
0
def build_detector(cfg_file="./configs/yolov3_tiny.yaml",
                   use_cuda=None):
    """Build an object detector from a YAML config file.

    The YOLO version is inferred from the config file name: a basename
    containing the character ``'3'`` is treated as YOLOv3, and anything else
    as YOLOv5. Only YOLOv3 is supported by this builder.

    Args:
        cfg_file (str): path to the detector YAML config.
        use_cuda (bool, optional): run on the GPU. When ``None``, it is set
            to ``torch.cuda.is_available()``.

    Returns:
        YOLOv3: the configured detector. It returns boxes in xywh format
        (``is_xywh=True``).

    Raises:
        NotImplementedError: if the config name does not look like a YOLOv3
            config. The YOLOv5 path is not implemented here.
    """
    yolo_version = 3 if '3' in os.path.basename(cfg_file) else 5
    if yolo_version != 3:
        # Fail fast, before touching the config or the GPU, instead of
        # implicitly returning None to the caller.
        raise NotImplementedError(
            f"build_detector only supports YOLOv3 configs, got: {cfg_file}")

    if use_cuda is None:
        use_cuda = torch.cuda.is_available()
    cfg = get_config(cfg_file)

    from .YOLOv3 import YOLOv3
    return YOLOv3(cfg.YOLOV3.CFG,
                  cfg.YOLOV3.WEIGHT,
                  cfg.YOLOV3.CLASS_NAMES,
                  score_thresh=cfg.YOLOV3.SCORE_THRESH,
                  nms_thresh=cfg.YOLOV3.NMS_THRESH,
                  is_xywh=True,
                  use_cuda=use_cuda)
Example #5
0
def get_dataset(name, image_set, transform, data_path, num_classes=91):
    """Build a dataset by name and report its number of classes.

    Args:
        name (str): ``"coco"`` or ``"coco_kp"``. Any other name raises
            ``KeyError``.
        image_set: split identifier forwarded to the dataset builder.
        transform: transforms forwarded as ``transforms=``.
        data_path (str): dataset root. It is also searched for ``data.yaml``.
        num_classes (int): class count reported for ``"coco"``.
            ``"coco_kp"`` always reports 2.

    Returns:
        tuple: ``(dataset, num_classes)``. When ``data_path/data.yaml``
        exists, ``num_classes`` is overridden by
        ``len(DATASET.class_names) + 1`` from that file.
    """
    registry = {
        "coco": (data_path, get_coco, num_classes),
        "coco_kp": (data_path, get_coco_kp, 2)
    }
    root, builder, n_classes = registry[name]
    dataset = builder(root, image_set=image_set, transforms=transform)

    data_yaml = os.path.join(data_path, 'data.yaml')
    if not os.path.isfile(data_yaml):
        return dataset, n_classes

    class_names = get_config(data_yaml).DATASET.class_names
    # +1 presumably reserves an index for the background class.
    return dataset, len(class_names) + 1
Example #6
0
def set_cfg(config_name: str):
    """ Sets the active config. Works even if cfg is already imported!

    Args:
        config_name: either a path to a custom YAML config file, or a Python
            expression naming a config defined in this module (for example
            ``"yolact_resnet50_config"``), which is evaluated with ``eval``.

    When ``config_name`` is an existing file, two configs are derived from it:
    a custom dataset built from its ``DATASET`` section on top of
    ``dataset_base``, and a YOLACT config built from
    ``yolact_resnet50_config`` using its ``YOLACT`` section. The global
    ``cfg`` is updated in place. If it ends up without a name, the name is
    derived from ``config_name`` by stripping a ``_config`` suffix.
    """
    global cfg

    if os.path.isfile(config_name):
        from annolid.utils.config import get_config
        custom_config = get_config(config_name)
        custom_dataset = dataset_base.copy({
            'name': custom_config.DATASET.name,
            'train_info': custom_config.DATASET.train_info,
            'train_images': custom_config.DATASET.train_images,
            'valid_info': custom_config.DATASET.valid_info,
            'valid_images': custom_config.DATASET.valid_images,
            'class_names': custom_config.DATASET.class_names,
        })
        yolact_resnet50_custom_config = yolact_resnet50_config.copy({
            'name': custom_config.YOLACT.name,
            'dataset': custom_dataset,
            # YOLACT configs read ``num_classes``. The key was misspelled
            # ``name_classes``, so the custom class count was never applied.
            # The +1 presumably accounts for the background class.
            'num_classes': len(custom_dataset.class_names) + 1,
            'max_size': custom_config.YOLACT.max_size,
        })
        cfg.name = custom_config.YOLACT.name
        cfg.replace(yolact_resnet50_custom_config)
    else:
        # eval (rather than a plain name lookup) lets callers pass
        # expressions such as "ssd300_config.copy({'max_size': 400})" for
        # extreme fine-tuning.
        # NOTE(security): eval executes arbitrary code, so only pass trusted
        # config names here, never user-supplied strings.
        cfg.replace(eval(config_name))

    if cfg.name is None:
        cfg.name = config_name.split('_config')[0]
Example #7
0
def track(video_file=None,
          name="YOLOV5",
          weights=None
          ):
    """Run object detection and tracking on a video.

    Args:
        video_file (str, optional): path to the input video.
        name (str): ``"YOLOV5"`` runs the YOLOv5 ``detect`` pipeline with
            ``video_file`` as its source. Any other value runs a detector
            plus a DeepSORT tracker and shows the annotated frames in an
            OpenCV window (press Esc to stop).
        weights (str, optional): YOLOv5 weights overriding the config value.
            Ignored by the DeepSORT path.
    """
    # 1000 trail buffers, each keeping the last 30 points. They are
    # presumably indexed by track id when drawing — TODO confirm the id range.
    points = [deque(maxlen=30) for _ in range(1000)]

    if name == "YOLOV5":
        _track_with_yolov5(video_file, weights, points)
    else:
        _track_with_deep_sort(video_file, points)


def _track_with_yolov5(video_file, weights, points):
    """Run the YOLOv5 ``detect`` pipeline configured by ./configs/yolov5s.yaml."""
    # Imported lazily to avoid requiring pytorch when the user only wants to
    # extract frames (maybe there is a better way to do this).
    import torch
    from annolid.detector.yolov5.detect import detect
    from annolid.utils.config import get_config
    from annolid.detector.yolov5.utils.general import strip_optimizer

    opt = get_config("./configs/yolov5s.yaml")
    if weights is not None:
        opt.weights = weights
    opt.source = video_file

    with torch.no_grad():
        if opt.update:  # update all models (to fix SourceChangeWarning)
            for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
                detect(opt, points=points)
                strip_optimizer(opt.weights)
        else:
            detect(opt, points=points)
            strip_optimizer(opt.weights)


def _track_with_deep_sort(video_file, points):
    """Detect objects frame by frame, track them with DeepSORT and display them."""
    # Previously an invalid path only printed a warning and carried on.
    if video_file is None or not os.path.isfile(video_file):
        print("Please provide a valid video file")
        return

    from annolid.tracker import build_tracker
    from annolid.detector import build_detector
    from annolid.utils.draw import draw_boxes

    detector = build_detector()
    deep_sort = build_tracker()
    cap = cv2.VideoCapture(video_file)
    try:
        # Process every frame, including the first one, which was previously
        # read and discarded.
        while True:
            ret, frame = cap.read()
            if not ret:
                print("Finished tracking.")
                break
            im = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            bbox_xywh, cls_conf, cls_ids = detector(im)

            # Keep only class-id-0 detections. Boxes and confidences must be
            # filtered together so they stay aligned for the tracker.
            mask = cls_ids == 0
            bbox_xywh = bbox_xywh[mask]
            cls_conf = cls_conf[mask]
            # Enlarge the box height by 20% (assumes rows are x, y, w, h).
            bbox_xywh[:, 3:] *= 1.2

            outputs = deep_sort.update(bbox_xywh, cls_conf, im)
            if len(outputs) > 0:
                bbox_xyxy = outputs[:, :4]
                identities = outputs[:, -1]
                frame = draw_boxes(frame,
                                   bbox_xyxy,
                                   identities,
                                   draw_track=True,
                                   points=points
                                   )

            cv2.imshow("Frame", frame)
            if cv2.waitKey(1) == 27:  # Esc key
                break
    finally:
        # Release the capture and windows even if detection/tracking fails.
        cap.release()
        cv2.destroyAllWindows()