示例#1
0
    def __init__(
        self, instance_type: str, ndds_ann_obj: NDDS_Annotation_Object, instance_name: str=None, contained_instance_list: List[ObjectInstance]=None
    ):
        """Create an object instance that wraps an NDDS annotation object.

        Args:
            instance_type: ``'bbox'``, ``'seg'`` or ``'kpt'``. ``'bbox'`` and
                ``'seg'`` may carry a decimal part-number suffix
                (e.g. ``'bbox0'``, ``'seg12'``). The prefix is then stored in
                ``self.instance_type`` and the number in ``self.part_num``.
                Without a suffix, ``self.part_num`` is ``None``.
            ndds_ann_obj: the annotation object this instance refers to.
            instance_name: optional name of the instance.
            contained_instance_list: optional list of contained instances.
                ``None`` becomes a fresh empty list.

        Raises:
            Exception: if a ``'bbox'``/``'seg'`` suffix is not a decimal
                integer. Other unsupported types are rejected by ``check_value``.
        """
        super().__init__()

        # Required
        for prefix in ('bbox', 'seg'):
            if not instance_type.startswith(prefix):
                continue
            # Slice off only the leading prefix. str.replace would also drop
            # later occurrences (e.g. 'bboxbbox1' -> '1').
            suffix = instance_type[len(prefix):]
            if not suffix:
                continue
            # isdecimal (unlike isdigit) accepts only characters int() can
            # parse; isdigit would let e.g. '²' through and int() would crash.
            if not suffix.isdecimal():
                logger.error('Part number must be a string that can be converted to an integer.')
                logger.error(f'Valid example: {prefix}0')
                logger.error(f'Invalid example: {prefix}zero')
                raise Exception
            self.part_num = int(suffix)
            self.instance_type = prefix
            break
        else:
            # No numbered bbox/seg form matched: must be a plain type name.
            check_value(instance_type, valid_value_list=['bbox', 'seg', 'kpt'])
            self.instance_type = instance_type
            self.part_num = None
        self.ndds_ann_obj = ndds_ann_obj

        # Optional
        self.instance_name = instance_name
        self.contained_instance_list = contained_instance_list if contained_instance_list is not None else []
示例#2
0
 def _check_valid(shape_type: str, points: Point2D_List):
     """Validate the number of points for a Labelme shape.

     ``shape_type`` is first checked with ``check_value``. If the point count
     is not allowed for that shape, the reason is logged with
     ``logger.error`` and ``Exception`` is raised. Otherwise nothing is
     returned.
     """
     check_value(shape_type,
                 valid_value_list=[
                     'polygon', 'rectangle', 'circle', 'line', 'point',
                     'linestrip'
                 ])
     # shape_type -> (predicate on the point count, message logged on failure)
     rules = {
         'polygon': (lambda count: count >= 3,
                     'Labelme polygon requires at least 3 points.'),
         'rectangle': (lambda count: count == 2,
                       'Labelme rectangle requires exactly 2 points.'),
         'circle': (lambda count: count == 2,
                    'Labelme circle requires exactly 2 points.'),
         'line': (lambda count: count == 2,
                  'Labelme line requires exactly 2 points.'),
         'point': (lambda count: count == 1,
                   'Labelme point requires exactly 1 points.'),
         'linestrip': (lambda count: count >= 2,
                       'Labelme linestrip requires at least 2 points.'),
     }
     rule = rules.get(shape_type)
     if rule is None:
         return
     accepts, message = rule
     if not accepts(len(points)):
         logger.error(message)
         raise Exception
示例#3
0
 def _init_batch(self, usage: str, color: Tuple[int]):
     """Add one two-vertex GL_LINES vertex list per grid line to ``self.batch``.

     There are ``n_rows + 1`` horizontal lines and ``n_cols + 1`` vertical
     lines. Their positions are shifted by the grid origin and wrapped modulo
     the grid size. The horizontal vertex lists are appended to
     ``self._vertex_list_grid`` first, then the vertical ones.
     """
     check_value(usage, valid_value_list=['static', 'dynamic', 'stream'])
     n_rows = self.grid_height // self.tile_height
     n_cols = self.grid_width // self.tile_width

     def add_line(x0, y0, x1, y1):
         # One segment (x0, y0) -> (x1, y1), with both endpoints colored ``color``.
         return self.batch.add_indexed(
             2, gl.GL_LINES, None, [0, 1],
             (f'v2f/{usage}', (x0, y0, x1, y1)),
             (f'c3B/{usage}', tuple(list(color) * 2)))

     horizontal = []
     for row_idx in range(n_rows + 1):
         wrapped_y = (self.grid_origin_y + row_idx * self.tile_height) % self.grid_height
         horizontal.append(add_line(0, wrapped_y, self.grid_width, wrapped_y))

     vertical = []
     for col_idx in range(n_cols + 1):
         wrapped_x = (self.grid_origin_x + col_idx * self.tile_width) % self.grid_width
         vertical.append(add_line(wrapped_x, 0, wrapped_x, self.grid_height))

     self._vertex_list_grid.append(horizontal)
     self._vertex_list_grid.append(vertical)
示例#4
0
 def get_frame(self, side: str=None) -> np.ndarray:
     """Return one of the two frames produced by ``self.get_frames()``.

     * ``self.direction == 0`` (left-right): ``side`` is ``'left'``
       (the default) for ``frame0`` or ``'right'`` for ``frame1``.
     * ``self.direction == 1`` (top-down): ``side`` is ``'top'``
       (the default) for ``frame0`` or ``'down'`` for ``frame1``.

     Raises:
         Exception: if ``self.direction`` is neither 0 nor 1. Invalid
             ``side`` values are rejected via ``check_value``.
     """
     frame0, frame1 = self.get_frames()
     if self.direction == 0:  # left-right direction
         first_side, second_side = 'left', 'right'
     elif self.direction == 1:  # top-down direction
         first_side, second_side = 'top', 'down'
     else:
         raise Exception

     if side is None:
         use_side = first_side
     else:
         check_value(item=side, valid_value_list=[first_side, second_side])
         use_side = side

     if use_side == first_side:
         return frame0
     if use_side == second_side:
         # In top-down mode this must be 'down'. Comparing against 'right'
         # made side='down' pass validation and then raise.
         return frame1
     raise Exception
示例#5
0
 def vertex_list(self, n: int, orientation: str = 'horizontal'):
     """Return the ``n``-th grid-line vertex list for ``orientation``.

     ``self._vertex_list_grid[0]`` holds the horizontal lines and
     ``self._vertex_list_grid[1]`` the vertical ones.

     Args:
         n: index of the line within its orientation.
         orientation: ``'horizontal'`` or ``'vertical'``. This is a string;
             the previous ``int`` annotation was wrong.

     Raises:
         Exception: if ``orientation`` is not recognized.
     """
     check_value(orientation, valid_value_list=['horizontal', 'vertical'])
     if orientation == 'horizontal':
         return self._vertex_list_grid[0][n]
     elif orientation == 'vertical':
         return self._vertex_list_grid[1][n]
     else:
         raise Exception
示例#6
0
 def move(self, dx: int, dy: int, mode: str = 'both'):
     """Move the world position, the camera position, or both by ``(dx, dy)``.

     Args:
         dx: offset along x.
         dy: offset along y.
         mode: ``'world'``, ``'camera'`` or ``'both'``. With ``'both'``, the
             world position is moved before the camera position.

     Raises:
         Exception: if ``mode`` is not one of the accepted values.
     """
     check_value(mode, valid_value_list=['world', 'camera', 'both'])
     if mode not in ('world', 'camera', 'both'):
         raise Exception
     if mode in ('world', 'both'):
         # dy must be forwarded as dy; passing dx here made the vertical
         # world motion copy the horizontal offset.
         self.world_pos.move(dx=dx, dy=dy)
     if mode in ('camera', 'both'):
         self.camera_pos.move(dx=dx, dy=dy)
示例#7
0
    def __init__(self, src, scale_factor: float=1.0, direction: int=0):
        """
        Open ``src`` through the parent class and record how each frame is split.

        direction=0 : left-right direction
        direction=1 : top-bottom direction

        Args:
            src: source passed unchanged to the parent constructor.
            scale_factor: passed unchanged to the parent constructor.
            direction: 0 or 1 (validated with ``check_value``).
        """

        super().__init__(src, scale_factor)
        # Presumably raises if the parent failed to open ``src`` -- TODO confirm.
        self.assert_open()
        check_value(item=direction, valid_value_list=[0, 1])
        self.direction = direction
        # NOTE(review): init_dims is called after self.direction is set, so it
        # may depend on it -- confirm before reordering.
        self.init_dims()
        # Named left/right. For direction=1 these presumably hold the top/bottom
        # frames -- TODO confirm.
        self.current_left_frame = None
        self.current_right_frame = None
示例#8
0
    def append_contained(self, new_contained_instance: ObjectInstance):
        """Validate ``new_contained_instance`` and append it to ``self.contained_instance_list``.

        The following checks are made before appending:
          * ``self.instance_type`` must be ``'bbox'`` or ``'seg'``
            (``check_value``).
          * ``new_contained_instance`` must be an ``ObjectInstance``
            (``check_type``) whose ``instance_type`` is ``'bbox'``, ``'seg'``
            or ``'kpt'``.
          * Its ``ndds_ann_obj.instance_id`` must differ from this instance's
            id and from the id of every already-contained instance.
          * Its ``(instance_type, instance_name)`` pair must not already be
            present among the contained instances. On a collision, every
            ``NDDS_Annotation_Object`` constructor parameter whose value
            differs between the two instances is logged.

        Raises:
            Exception: after logging diagnostics via ``logger.error`` when an
                id or ``(instance_type, instance_name)`` check fails. The
                ``check_*`` helpers handle their own failures.
        """
        check_value(self.instance_type, valid_value_list=['bbox', 'seg'])
        check_type(new_contained_instance, valid_type_list=[ObjectInstance])
        check_value(new_contained_instance.instance_type, valid_value_list=['bbox', 'seg', 'kpt'])

        # Check Instance Id
        # An instance cannot contain itself.
        if new_contained_instance.ndds_ann_obj.instance_id == self.ndds_ann_obj.instance_id:
            logger.error(f'new_contained_instance.ndds_ann_obj.instance_id == self.ndds_ann_obj.instance_id')
            logger.error(f'new_contained_instance: {new_contained_instance}')
            logger.error(f'self: {self}')
            raise Exception
        # The same annotation id must not be contained twice.
        if new_contained_instance.ndds_ann_obj.instance_id in [
            contained_instance.ndds_ann_obj.instance_id for contained_instance in self.contained_instance_list
        ]:
            logger.error(
                f'new_contained_instance.ndds_ann_obj.instance_id in ' + \
                f'[contained_instance.ndds_ann_obj.instance_id for contained_instance in self.contained_instance_list] == True'
            )
            logger.error(f'new_contained_instance: {new_contained_instance}')
            logger.error(f'self: {self}')
            raise Exception
        # Check (instance_type, instance_name) pair
        if (new_contained_instance.instance_type, new_contained_instance.instance_name) in [
            (contained_instance.instance_type, contained_instance.instance_name) for contained_instance in self.contained_instance_list
        ]:
            logger.error(
                f'(new_contained_instance.instance_type, new_contained_instance.instance_name)=' + \
                f'{(new_contained_instance.instance_type, new_contained_instance.instance_name)} ' + \
                f'pair already exists in self.contained_instance_list'
            )
            logger.error(f'Existing pairs:')
            found_inst = None
            # Log every existing pair and keep a copy of the colliding instance.
            # There is no break, so if several matched, the last one wins.
            for inst in self.contained_instance_list:
                logger.error(f'\t(inst.instance_type, inst.instance_name)={(inst.instance_type, inst.instance_name)}')
                if (inst.instance_type, inst.instance_name) == (new_contained_instance.instance_type, new_contained_instance.instance_name):
                    found_inst = inst.copy()
            # NOTE(review): ObjectInstance.buffer presumably coerces/validates the
            # copied object as an ObjectInstance -- confirm its semantics.
            found_inst = ObjectInstance.buffer(found_inst)
            logger.error(f'\n')
            # Report which annotation fields differ, to help debug the duplicate.
            if new_contained_instance.ndds_ann_obj != found_inst.ndds_ann_obj:
                # logger.error(f'new_contained_instance:\n{new_contained_instance}')
                # logger.error(f'found_inst:\n{found_inst}')
                for key in NDDS_Annotation_Object.get_constructor_params():
                    if new_contained_instance.ndds_ann_obj.__dict__[key] != found_inst.ndds_ann_obj.__dict__[key]:
                        logger.error(f'Found difference in key={key}')
                        logger.error(f'\tnew_contained_instance.ndds_ann_obj.__dict__[{key}]:\n\t{new_contained_instance.ndds_ann_obj.__dict__[key]}')
                        logger.error(f'\tfound_inst.ndds_ann_obj.__dict__[{key}]:\n\t{found_inst.ndds_ann_obj.__dict__[key]}')
            else:
                logger.error(f"The two instance's ndds_ann_obj are identical.")
            raise Exception
        
        self.contained_instance_list.append(new_contained_instance)
示例#9
0
 def create_streamer(self,
                     src: str,
                     mode: str = 'mono',
                     scale_factor: float = 1.0,
                     verbose: bool = False,
                     direction: int = 0) -> StreamerObject:
     """Create a streamer for ``src``.

     Args:
         src: source path, checked with ``check_file_exists``.
         mode: ``'mono'`` for a ``Streamer`` or ``'dual'`` for a
             ``DualStreamer``.
         scale_factor: passed to the streamer constructor.
         verbose: if true, log a message before creating the streamer.
         direction: used only when ``mode == 'dual'``. It is forwarded to
             ``DualStreamer`` (0 = left-right, 1 = top-bottom). The default 0
             matches the previously hard-coded value.

     Returns:
         The created streamer.

     Raises:
         Exception: if ``mode`` is not recognized.
     """
     if verbose: logger.info(f"Creating Streamer for src={src}")
     check_value(item=mode, valid_value_list=['mono', 'dual'])
     check_file_exists(src)
     if mode == 'mono':
         streamer = Streamer(src=src, scale_factor=scale_factor)
     elif mode == 'dual':
         streamer = DualStreamer(src=src,
                                 scale_factor=scale_factor,
                                 direction=direction)
     else:
         raise Exception
     return streamer
示例#10
0
 def _init_batch(self, usage: str, color_seq: List[Tuple[int]]):
     """Add one filled quad (two GL_TRIANGLES) per grid tile to ``self.batch``.

     Tile colors cycle through ``color_seq`` in row-major order. One list of
     vertex lists per row is appended to ``self._vertex_list_grid``.
     """
     check_value(usage, valid_value_list=['static', 'dynamic', 'stream'])
     n_rows = self.grid_height // self.tile_height
     n_cols = self.grid_width // self.tile_width
     for row_idx in range(n_rows):
         row_lists = []
         for col_idx in range(n_cols):
             left = col_idx * self.tile_width
             bottom = row_idx * self.tile_height
             right = left + self.tile_width
             top = bottom + self.tile_height
             tile_index = row_idx * n_cols + col_idx
             tile_color = color_seq[tile_index % len(color_seq)]
             # The vertex order is (left, bottom), (left, top), (right, bottom),
             # (right, top); indices [0,1,2] and [1,2,3] split the quad in two.
             row_lists.append(self.batch.add_indexed(
                 4, gl.GL_TRIANGLES, None, [0, 1, 2, 1, 2, 3],
                 (f'v2f/{usage}',
                  (left, bottom, left, top, right, bottom, right, top)),
                 (f'c3B/{usage}', tuple(list(tile_color) * 4))))
         self._vertex_list_grid.append(row_lists)
示例#11
0
    def __init__(self,
                 x: int,
                 y: int,
                 width: int,
                 height: int,
                 color: Tuple[int] = (255, 0, 0),
                 transparency: int = 255,
                 usage: str = 'dynamic'):
        """Store the rectangle's geometry and appearance and build its vertex list.

        Args:
            x, y: stored as ``self.x`` / ``self.y``.
            width, height: stored as ``self.width`` / ``self.height``.
            color: stored as ``self.color``.
            transparency: stored as ``self.transparency``.
            usage: ``'static'``, ``'dynamic'`` or ``'stream'`` (validated with
                ``check_value``).
        """
        # Turn on standard alpha blending (src_alpha, 1 - src_alpha).
        gl.glEnable(gl.GL_BLEND)
        gl.glBlendFunc(gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA)

        self.x, self.y = x, y
        self.width, self.height = width, height
        self.color, self.transparency = color, transparency
        check_value(usage, valid_value_list=['static', 'dynamic', 'stream'])
        self.usage = usage
        self.vertex_list = self.__get_indexed_vertex_list()
示例#12
0
    def _get_ann_format_list(self, dataset_paths: list, ann_format) -> list:
        """Expand ``ann_format`` into one annotation format per dataset path.

        * A single ``str`` is validated against ``self.valid_ann_formats`` and
          repeated ``len(dataset_paths)`` times.
        * A ``list`` must contain only valid formats, have one entry per
          dataset path and contain only ``str`` items. It is then returned
          as-is (not copied).

        Raises:
            Exception: on a length mismatch or an unsupported ``ann_format`` type.
        """
        check_type(item=ann_format, valid_type_list=[str, list])

        if type(ann_format) is str:
            check_value(item=ann_format,
                        valid_value_list=self.valid_ann_formats)
            return [ann_format] * len(dataset_paths)

        if type(ann_format) is not list:
            raise Exception

        check_value_from_list(item_list=ann_format,
                              valid_value_list=self.valid_ann_formats)
        if len(ann_format) != len(dataset_paths):
            logger.error(
                f"type(ann_format) is list but len(ann_format) == {len(ann_format)} != {len(dataset_paths)} == len(dataset_paths)"
            )
            raise Exception
        check_type_from_list(item_list=ann_format, valid_type_list=[str])
        return ann_format
示例#13
0
    def check_valid_config(self, collection_dict_list: list):
        """Validate a list of dataset-collection configuration dicts.

        For every entry:
          * it must be a dict whose keys all come from
            ``self.main_required_keys``, with every required key present;
          * ``collection_dir`` must be a ``str``, ``dataset_names`` a list of
            ``str`` and ``dataset_specific`` a ``dict``;
          * ``dataset_specific`` keys must come from
            ``self.specific_required_keys``, with every required key present;
          * ``img_dir``, ``ann_path`` and ``ann_format`` must each be a ``str``
            or a ``list``, and a list must have one item per dataset name;
          * ``ann_format`` values must be in ``self.valid_ann_formats``.

        Raises:
            Exception: on the first problem found. Checks performed directly
                here log the reason with ``logger.error`` first.
        """

        def require_keys(container, required_keys, label, index):
            # Log the first required key missing from ``container`` and abort.
            for required_key in required_keys:
                if required_key not in container.keys():
                    logger.error(
                        f"{label} at index {index} is missing required key: {required_key}"
                    )
                    raise Exception

        check_type(item=collection_dict_list, valid_type_list=[list])
        for index, collection_dict in enumerate(collection_dict_list):
            check_type(item=collection_dict, valid_type_list=[dict])
            check_value_from_list(item_list=list(collection_dict.keys()),
                                  valid_value_list=self.main_required_keys)
            require_keys(collection_dict, self.main_required_keys,
                         'collection_dict', index)

            check_type(item=collection_dict['collection_dir'],
                       valid_type_list=[str])
            dataset_names = collection_dict['dataset_names']
            check_type(item=dataset_names, valid_type_list=[list])
            check_type_from_list(item_list=dataset_names,
                                 valid_type_list=[str])

            dataset_specific = collection_dict['dataset_specific']
            check_type(item=dataset_specific, valid_type_list=[dict])
            check_value_from_list(item_list=list(dataset_specific.keys()),
                                  valid_value_list=self.specific_required_keys)
            require_keys(dataset_specific, self.specific_required_keys,
                         'dataset_specific', index)

            # (field name, value) pairs that may be given per dataset as a list.
            per_dataset_fields = (
                ('img_dir', dataset_specific['img_dir']),
                ('ann_path', dataset_specific['ann_path']),
                ('ann_format', dataset_specific['ann_format']),
            )
            check_type_from_list(
                item_list=[value for _, value in per_dataset_fields],
                valid_type_list=[str, list])
            for field_name, value in per_dataset_fields:
                if type(value) is list and len(value) != len(dataset_names):
                    logger.error(f"Length mismatch at index: {index}")
                    logger.error(
                        f"type({field_name}) is list but len({field_name}) == {len(value)} != {len(dataset_names)} == len(dataset_names)"
                    )
                    raise Exception

            ann_format = per_dataset_fields[2][1]
            if type(ann_format) is str:
                check_value(item=ann_format,
                            valid_value_list=self.valid_ann_formats)
            elif type(ann_format) is list:
                check_value_from_list(item_list=ann_format,
                                      valid_value_list=self.valid_ann_formats)
            else:
                raise Exception
示例#14
0
def infer(
    path: str,
    weights_path: str,
    thresh: float = 0.5,
    key: str = 'R',
    infer_dump_dir: str = '',
    model: str = 'mask_rcnn_R_50_FPN_1x',
    size: int = 1024,
    class_names: List[str] = None,
    gt_path: str = '/home/jitesh/3d/data/coco_data/hook_test/json/cropped_hook.json'):
    """Run the hook/pole keypoint + segmentation model over a folder of images.

    Every ``jpg``/``JPG``/``png``/``PNG`` file found under ``path`` is
    processed in sorted order:
      * The highest-scoring ``'pole'`` prediction is drawn as a mask plus bbox
        labels, and its diameter comes from ``compute_diameter``.
      * The highest-scoring ``'hook'`` prediction is drawn with
        ``draw_inference_on_hook2``.
      * An info box with ``len_ab`` and the pole ``diameter`` is added.
      * The result is written to ``{infer_dump_dir}/infer_key_seg/<filename>``.

    Unreadable images are skipped.

    Args:
        path: directory searched for input images.
        weights_path: model weights passed to ``inferer``.
        thresh: currently unused.
        key: currently unused.
        infer_dump_dir: root output directory, prepared with ``confirm_folder``.
        model: currently unused. The inferer is always built with
            ``'keypoint_rcnn_R_50_FPN_1x'``.
        size: used for both ``INPUT.MIN_SIZE_TEST`` and ``INPUT.MAX_SIZE_TEST``.
        class_names: class names passed to ``inferer``. ``None`` means ``['hook']``.
        gt_path: COCO annotation json loaded with
            ``COCO_Dataset.load_from_path``. The loaded dataset is not used
            further.
    """
    # A None default avoids sharing one mutable list object across calls.
    if class_names is None:
        class_names = ['hook']
    # Keypoints with confidence <= conf_thresh are drawn as "not confident".
    conf_thresh = 0.001
    show_bbox_border = True
    # NOTE(review): gt_dataset is never used below. The load is kept so that an
    # invalid gt_path still fails early, as before.
    gt_dataset = COCO_Dataset.load_from_path(json_path=gt_path)
    inferer_seg = inferer(
        weights_path=weights_path,
        confidence_threshold=0.1,
        class_names=class_names,
        # The ``model`` argument is not used here (see docstring).
        model='keypoint_rcnn_R_50_FPN_1x',
    )
    inferer_seg.cfg.INPUT.MIN_SIZE_TEST = size
    inferer_seg.cfg.INPUT.MAX_SIZE_TEST = size
    inferer_seg.cfg.MODEL.MASK_ON = True

    # Hook keypoint names, skeleton and label mode are the same for every image.
    kpt_labels = [
        "kpt-a", "kpt-b", "kpt-cb", "kpt-c", "kpt-cd", "kpt-d", "kpt-e"
    ]
    kpt_skeleton = [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6]]
    bbox_label_mode = 'euler'

    possible_modes = ['save', 'preview']
    mode = 'save'
    check_value(mode, valid_value_list=possible_modes)
    img_extensions = ['jpg', 'JPG', 'png', 'PNG']
    img_pathlist = get_all_files_in_extension_list(
        dir_path=f'{path}', extension_list=img_extensions)
    img_pathlist.sort()

    confirm_folder(infer_dump_dir, mode)
    confirm_folder(f'{infer_dump_dir}/infer_key_seg', mode)

    for img_path in tqdm(img_pathlist, desc='Writing images'):
        img_filename = get_filename(img_path)
        printj.purple(img_path)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None for missing/unreadable files.
            printj.red(f'Could not read image, skipping: {img_path}')
            continue
        result = img
        (score_list, pred_class_list, bbox_list, pred_masks_list,
         pred_keypoints_list, vis_keypoints_list,
         kpt_confidences_list) = inferer_seg.predict(img=img)
        max_hook_score = -1
        max_pole_score = -1
        diameter = -1
        len_ab = -1
        found_hook = False
        found_pole = False
        for score, pred_class, bbox, mask, keypoints, vis_keypoints, kpt_confidences in zip(
                score_list, pred_class_list, bbox_list, pred_masks_list,
                pred_keypoints_list, vis_keypoints_list, kpt_confidences_list):

            if pred_class == 'pole':
                found_pole = True
                if max_pole_score < score:
                    # Keep only the highest-scoring pole.
                    max_pole_score = score
                    diameter = compute_diameter(mask)
                    pole_bbox_text = f'pole {str(round(score, 2))}'
                    pole_bbox = bbox
                    pole_mask = mask
                    printj.blue(f'diameter={diameter}')
            if pred_class == 'hook':
                found_hook = True
                if max_hook_score < score:
                    # Keep only the highest-scoring hook.
                    max_hook_score = score
                    hook_bbox = BBox.buffer(bbox)
                    hook_score = round(score, 2)
                    hook_mask = mask
                    hook_keypoints = keypoints
                    hook_vis_keypoints = vis_keypoints
                    hook_kpt_confidences = kpt_confidences
                    # Split keypoints/labels by confidence for drawing.
                    confidences = np.array(kpt_confidences)
                    conf_idx_list = np.argwhere(
                        confidences > conf_thresh).reshape(-1)
                    not_conf_idx_list = np.argwhere(
                        confidences <= conf_thresh).reshape(-1).astype(int)
                    conf_keypoints = np.array(vis_keypoints)[conf_idx_list]
                    conf_kpt_labels = np.array(kpt_labels)[conf_idx_list]
                    not_conf_keypoints = np.array(
                        vis_keypoints)[not_conf_idx_list]
                    not_conf_kpt_labels = np.array(
                        kpt_labels)[not_conf_idx_list]
                    cleaned_keypoints = np.array(vis_keypoints.copy()).astype(
                        np.float32)
        printj.green.on_white(max_hook_score)
        if found_pole:
            result = draw_bool_mask(img=result,
                                    mask=pole_mask,
                                    color=[0, 255, 255],
                                    transparent=True)
            result = draw_bbox(img=result,
                               bbox=pole_bbox,
                               text=pole_bbox_text,
                               label_only=not show_bbox_border,
                               label_orientation='top')
            result = draw_bbox(img=result,
                               bbox=pole_bbox,
                               text=pole_bbox_text,
                               label_only=not show_bbox_border,
                               label_orientation='bottom')
        if found_hook:
            result = draw_bool_mask(img=result,
                                    mask=hook_mask,
                                    color=[255, 255, 0],
                                    transparent=True)
            result, len_ab = draw_inference_on_hook2(
                img=result,
                cleaned_keypoints=cleaned_keypoints,
                kpt_labels=kpt_labels,
                kpt_skeleton=kpt_skeleton,
                score=hook_score,
                bbox=hook_bbox,
                vis_keypoints=hook_vis_keypoints,
                kpt_confidences=hook_kpt_confidences,
                conf_idx_list=conf_idx_list,
                not_conf_idx_list=not_conf_idx_list,
                conf_keypoints=conf_keypoints,
                conf_kpt_labels=conf_kpt_labels,
                not_conf_keypoints=not_conf_keypoints,
                not_conf_kpt_labels=not_conf_kpt_labels,
                conf_thresh=conf_thresh,
                show_bbox_border=show_bbox_border,
                bbox_label_mode=bbox_label_mode,
                index_offset=0,
                diameter=diameter)
        printj.purple(len_ab)
        if len_ab == 0:
            # len_ab only changes (from -1) when a hook was drawn, so the
            # selected hook's keypoints exist. Previously this printed the loop
            # variable ``keypoints``, i.e. the last prediction, not the hook.
            printj.green(hook_keypoints)
        result = draw_info_box(result, len_ab, diameter)
        cv2.imwrite(f"{infer_dump_dir}/infer_key_seg/{img_filename}", result)
示例#15
0
 def set_value(self, x: int, y: int, collision: int, x_scale: float=1.0, y_scale: float=1.0):
     """Store a collision flag (0 or 1) at ``(x, y)`` via the parent ``set_value``.

     ``collision`` is validated with ``check_value``. The scale arguments are
     forwarded unchanged.
     """
     check_value(collision, valid_value_list=[0, 1])
     parent = super()
     parent.set_value(value=collision, x=x, y=y, x_scale=x_scale, y_scale=y_scale)
示例#16
0
 def press(self, direction: str):
     """Record ``direction`` as pressed by appending it to ``self.buffer``.

     ``direction`` must be one of ``self._valid_directions``.
     """
     check_value(direction, valid_value_list=self._valid_directions)
     pressed = self.buffer
     pressed.append(direction)
示例#17
0
 def release(self, direction: str):
     """Remove every occurrence of ``direction`` from ``self.buffer``.

     ``direction`` must be one of ``self._valid_directions``.
     """
     check_value(direction, valid_value_list=self._valid_directions)
     # remove() deletes the first equal element, the same as index() + del.
     while direction in self.buffer:
         self.buffer.remove(direction)