def load(self, target, verbose: bool = False):
    """Load a dataset path configuration into ``self.data``.

    Args:
        target: One of:
            - str: path to a ``.yaml`` or ``.json`` config file,
            - list: a pre-built list of collection dicts,
            - dict: a single collection dict (wrapped into a one-element list).
        verbose: When True, log a success message after loading.

    Raises:
        Exception: If a string target has an unsupported extension, or if
            the target type is not str/list/dict.
    """
    # isinstance (rather than `type(...) is`) also accepts subclasses.
    if isinstance(target, str):
        # String targets are assumed to be file paths.
        extension = get_extension_from_path(target)
        if extension == 'yaml':
            check_file_exists(target)
            # `with` ensures the file handle is closed even on parse errors.
            with open(target, 'r') as f:
                loaded_data = yaml.load(f, Loader=yaml.FullLoader)
        elif extension == 'json':
            check_file_exists(target)
            with open(target, 'r') as f:
                loaded_data = json.load(f)
        else:
            logger.error(f"Invalid extension: {extension}")
            logger.error(f"Note that string targets are assumed to be paths.")
            logger.error(f"Valid file extensions: {self.valid_extensions}")
            raise Exception
    elif isinstance(target, list):
        loaded_data = target
    elif isinstance(target, dict):
        # Normalize a single collection dict to a one-element list.
        loaded_data = [target]
    else:
        logger.error(f"Invalid target type: {type(target)}")
        raise Exception
    # Validate before committing to self.data.
    self.check_valid_config(collection_dict_list=loaded_data)
    self.data = loaded_data
    if verbose:
        logger.good(f"Dataset path config has been loaded successfully.")
def init_save_dir(self):
    """Prepare the save directory before the first write.

    Creates the directory if needed; either clears it (when ``self._clear``
    is set) or records the file extensions already present so later writes
    can stay consistent with them. Always resets ``self._first``.
    """
    make_dir_if_not_exists(self._save_dir)
    if self._clear:
        # Fresh start: wipe everything without prompting.
        delete_all_files_in_dir(self._save_dir, ask_permission=False)
    else:
        # Collect the distinct extensions of images already in the directory.
        seen = {
            get_extension_from_path(image_path)
            for image_path in get_valid_image_paths(self._save_dir)
        }
        self._existing_extensions = list(seen)
    self._first = False
def load_from_path(cls, path: str) -> DatasetConfigCollection:
    """Build a DatasetConfigCollection from a ``.json`` or ``.yaml`` file.

    Args:
        path: Path to the configuration file.

    Returns:
        The DatasetConfigCollection parsed from the file contents.

    Raises:
        Exception: If the file extension is neither ``json`` nor ``yaml``.
    """
    check_file_exists(path)
    extension = get_extension_from_path(path)
    if extension == 'json':
        # `with` closes the file handle deterministically instead of leaking it.
        with open(path, 'r') as f:
            collection_dict = json.load(f)
    elif extension == 'yaml':
        with open(path, 'r') as f:
            collection_dict = yaml.load(f, Loader=yaml.FullLoader)
    else:
        logger.error(f'Invalid file extension encountered: {extension}')
        logger.error(f'Path specified: {path}')
        raise Exception
    return DatasetConfigCollection.from_dict(collection_dict)
def write_config(self, dest_path: str, verbose: bool = False):
    """Serialize ``self.data`` to ``dest_path`` as yaml or json.

    The output format is chosen from the destination path's extension.

    Args:
        dest_path: Destination file path ending in ``.yaml`` or ``.json``.
        verbose: When True, log a success message after writing.

    Raises:
        Exception: If the destination extension is unsupported.
    """
    extension = get_extension_from_path(dest_path)
    if extension == 'yaml':
        # `with` flushes and closes the handle instead of leaking it.
        with open(dest_path, 'w') as f:
            yaml.dump(self.data, f, allow_unicode=True)
    elif extension == 'json':
        with open(dest_path, 'w') as f:
            json.dump(self.data, f, indent=2, ensure_ascii=False)
    else:
        logger.error(f"Invalid extension: {extension}")
        logger.error(f"Valid file extensions: {self.valid_extensions}")
        raise Exception
    if verbose:
        logger.good(
            f"Dataset path config written successfully to:\n{dest_path}")
def save_to_path(self, save_path: str, overwrite: bool = False):
    """Serialize ``self.to_dict()`` to ``save_path`` as json or yaml.

    Args:
        save_path: Destination file path ending in ``.json`` or ``.yaml``.
        overwrite: When False, refuse to replace an existing file.

    Raises:
        Exception: If the file exists and overwrite is False, or if the
            extension is unsupported.
    """
    if file_exists(save_path) and not overwrite:
        logger.error(f'File already exists at: {save_path}')
        logger.error(f'Use overwrite=True to overwrite.')
        raise Exception
    extension = get_extension_from_path(save_path)
    if extension == 'json':
        # `with` flushes and closes the handle instead of leaking it.
        with open(save_path, 'w') as f:
            json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
    elif extension == 'yaml':
        with open(save_path, 'w') as f:
            yaml.dump(self.to_dict(), f, allow_unicode=True)
    else:
        logger.error(f'Invalid file extension encountered: {extension}')
        logger.error(f"Valid file extensions: {['json', 'yaml']}")
        logger.error(f'Path specified: {save_path}')
        raise Exception
def write_cropped_json(src_img_path: str, src_json_path: str, dst_img_path: str, dst_json_path: str, bound_type='rect', verbose: bool = False):
    """Crop a LabelMe-annotated image by each of its rectangle shapes.

    For every rectangle in the source annotation, writes a cropped image
    (``<rootname>_<i>.<ext>``) plus a matching LabelMe json containing only
    the shapes fully inside that rectangle, with point coordinates shifted
    into the crop's local frame.

    Args:
        src_img_path: Path to the source image.
        src_json_path: Path to the source LabelMe json annotation.
        dst_img_path: Template path for cropped images; an ``_<i>`` index is
            inserted before the extension per crop.
        dst_json_path: Template path for cropped jsons; same indexing scheme.
        bound_type: Passed through to LabelMeAnnotation (default 'rect').
        verbose: When True, log each written json path.

    Raises:
        Exception: On malformed rectangles, or when a shape point ends up
            with a negative coordinate after cropping.
    """
    def process_shape(shape: Shape, bbox: BBox, new_shape_handler: ShapeHandler):
        # Shift `shape` into bbox-local coordinates and register it with the
        # handler; shapes entirely outside the bbox are skipped.
        points = [Point.from_list(point) for point in shape.points]
        contained_count = 0
        for point in points:
            if bbox.contains(point):
                contained_count += 1
        if contained_count == 0:
            # Shape entirely outside this crop — ignore it.
            return
        elif contained_count == len(points):
            pass
        else:
            # NOTE(review): partial containment only logs errors and then
            # falls through to the cropping below; points left of/above the
            # bbox will trigger the negative-coordinate raise further down,
            # but points beyond xmax/ymax are silently kept — confirm intended.
            logger.error(
                f"Found a shape that is only partially contained by a bbox.")
            logger.error(f"Shape: {shape}")
            logger.error(f"BBox: {bbox}")
        # Translate points so the bbox's top-left corner becomes the origin.
        cropped_points = [
            Point(x=point.x - bbox.xmin, y=point.y - bbox.ymin)
            for point in points
        ]
        for point in cropped_points:
            if point.x < 0 or point.y < 0:
                logger.error(f"Encountered negative point after crop: {point}")
                raise Exception
        new_shape = shape.copy()
        new_shape.points = [
            cropped_point.to_list() for cropped_point in cropped_points
        ]
        new_shape_handler.add(new_shape)

    check_input_path_and_output_dir(input_path=src_img_path,
                                    output_path=dst_img_path)
    check_input_path_and_output_dir(input_path=src_json_path,
                                    output_path=dst_json_path)
    # NOTE(review): output_img_dir is never used below — dead local?
    output_img_dir = get_dirpath_from_filepath(dst_img_path)
    # NOTE(review): annotation_path is given src_img_path and img_dir is given
    # dst_img_path (a file path) — looks like these should be src_json_path /
    # a directory; confirm against LabelMeAnnotation's constructor.
    annotation = LabelMeAnnotation(annotation_path=src_img_path,
                                   img_dir=dst_img_path,
                                   bound_type=bound_type)
    parser = LabelMeAnnotationParser(annotation_path=src_json_path)
    parser.load()
    # Convert every rectangle shape into a BBox; rectangles are expected to
    # be stored as two corner points (shape (2, 2) as a numpy array).
    bbox_list = []
    for rect in parser.shape_handler.rectangles:
        numpy_array = np.array(rect.points)
        if numpy_array.shape != (2, 2):
            logger.error(
                f"Encountered rectangle with invalid shape: {numpy_array.shape}"
            )
            logger.error(f"rect: {rect}")
            raise Exception
        # Min/max over each column so corner ordering doesn't matter.
        xmin, xmax = numpy_array.T[0].min(), numpy_array.T[0].max()
        ymin, ymax = numpy_array.T[1].min(), numpy_array.T[1].max()
        bbox_list.append(BBox.from_list([xmin, ymin, xmax, ymax]))
    # NOTE(review): img / img_h / img_w are never used after this point —
    # dead locals (and an extra image read)?
    img = cv2.imread(src_img_path)
    img_h, img_w = img.shape[:2]
    for i, bbox in enumerate(bbox_list):
        # presumably BBox.buffer pads the bbox by some margin — TODO confirm.
        bbox = BBox.buffer(bbox)
        new_shape_handler = ShapeHandler()
        # Re-test every shape kind against this crop's bbox.
        for shape_group in [
                parser.shape_handler.points, parser.shape_handler.rectangles,
                parser.shape_handler.polygons
        ]:
            for shape in shape_group:
                process_shape(shape=shape,
                              bbox=bbox,
                              new_shape_handler=new_shape_handler)
        new_shape_list = new_shape_handler.to_shape_list()
        # Only write a crop when at least one shape survived the filtering.
        if len(new_shape_list) > 0:
            img_rootname, json_rootname = get_rootname_from_path(
                dst_img_path), get_rootname_from_path(dst_json_path)
            dst_img_dir, dst_json_dir = get_dirpath_from_filepath(
                dst_img_path), get_dirpath_from_filepath(dst_json_path)
            dst_img_extension = get_extension_from_path(dst_img_path)
            dst_cropped_img_path = f"{dst_img_dir}/{img_rootname}_{i}.{dst_img_extension}"
            dst_cropped_json_path = f"{dst_json_dir}/{json_rootname}_{i}.json"
            # Write the cropped image first; its real size is read back below.
            write_cropped_image(src_path=src_img_path,
                                dst_path=dst_cropped_img_path,
                                bbox=bbox,
                                verbose=verbose)
            # Clone the source annotation and retarget it at the crop.
            cropped_labelme_ann = annotation.copy()
            cropped_labelme_ann.annotation_path = dst_cropped_json_path
            cropped_labelme_ann.img_dir = dst_img_dir
            cropped_labelme_ann.img_path = dst_cropped_img_path
            # Read the written crop to get its actual pixel dimensions.
            cropped_img = cv2.imread(dst_cropped_img_path)
            cropped_img_h, cropped_img_w = cropped_img.shape[:2]
            cropped_labelme_ann.img_height = cropped_img_h
            cropped_labelme_ann.img_width = cropped_img_w
            cropped_labelme_ann.shapes = new_shape_list
            cropped_labelme_ann.shape_handler = new_shape_handler
            writer = LabelMeAnnotationWriter(cropped_labelme_ann)
            writer.write()
            if verbose:
                logger.info(f"Wrote {dst_cropped_json_path}")