def create_multiclass_mask(image_shape, results: dict, config: Config = None):
    """
    Creates an image containing all the masks where pixel color is the mask's class ID
    :param image_shape: the shape of the initial image (only the first two dimensions are used)
    :param results: the results dictionary containing at least 'masks', 'class_ids' and 'rois'
    :param config: the config object used to get the classes hierarchy (drawing order by level)
                   and to expand mini_masks if enabled. If None, every predicted class is drawn
                   in a single level, in the masks' original order, without mini-mask expansion.
    :return: the multi-mask image (uint8, single channel)
    """
    res = np.zeros((image_shape[0], image_shape[1]), np.uint8)

    masks = results['masks']
    class_ids = results['class_ids']
    rois = results['rois']
    indices = np.arange(len(class_ids))

    if config is None:
        # No hierarchy available: all present classes belong to one single level
        levels = [np.unique(class_ids).tolist()]
        use_mini_mask = False
    else:
        classes_hierarchy = config.get_classes_hierarchy()
        if classes_hierarchy is None:
            # Class IDs 1..N, where N is the number of classes known by the config
            levels = [[i + 1 for i in range(len(config.get_classes_info()))]]
        else:
            levels = utils.remove_redundant_classes(
                utils.classes_level(classes_hierarchy), keepFirst=False)
        use_mini_mask = config.is_using_mini_mask()

    # Levels are drawn one after the other, so masks of a later level are applied after
    # (and presumably over) those of earlier levels
    for lvl in levels:
        for idx in indices[np.isin(class_ids, lvl)]:
            mask = masks[:, :, idx].astype(bool).astype(np.uint8) * 255
            roi = rois[idx]
            classID = int(class_ids[idx])
            if use_mini_mask:
                shifted_bbox = utils.shift_bbox(roi)
                mask = utils.expand_mask(shifted_bbox, mask, shifted_bbox[2:])
            res = apply_mask(res, mask, classID, 1, roi)
    return res
def mask_to_class_histogram(results: dict,
                            classes: dict,
                            config: Config = None,
                            count_zeros=True):
    """
    Merge every per-base-mask histogram into one global histogram.

    The result no longer records which base mask contained which masks.
    :param results: the results; its 'histos' entry is a sequence of per-base-mask histograms
                    (dict class_id -> number of masks) where None entries are skipped
    :param classes: dict whose values are class names or lists of class names to count;
                    the name "all" selects every class known by the config
    :param config: the config providing the classes info; nothing is computed without it
    :param count_zeros: if True, a base mask whose histogram lacks a class is counted in that
                        class's 0 bin
    :return: dict display_name -> {number_of_masks: number_of_base_masks}, or None when config is None
    """
    if config is None:
        return None

    requested = list(classes.values())
    select_all = "all" in requested or any(
        "all" in value for value in requested if type(value) is list)

    if select_all:
        chosen = {info['display_name']: info['id'] for info in config.get_classes_info()}
    else:
        # Flattening the requested names (values may be single names or lists of names)
        wanted = []
        for value in requested:
            if type(value) is list:
                wanted.extend(value)
            else:
                wanted.append(value)
        chosen = {info['display_name']: info['id']
                  for info in config.get_classes_info() if info['name'] in wanted}

    histogram = {display_name: {} for display_name in chosen}
    for per_mask in results['histos']:
        if per_mask is None:
            continue
        for display_name, class_id in chosen.items():
            if class_id in per_mask:
                bin_key = per_mask[class_id]
            elif count_zeros:
                bin_key = 0
            else:
                continue
            bins = histogram[display_name]
            bins[bin_key] = bins.get(bin_key, 0) + 1
    return histogram
def reduce_memory(results, config: Config, allow_sparse=True):
    """
    Minimize all masks in the results dict from inference.
    The given dict is modified in place and also returned.
    :param results: dict containing results of the inference ('masks' and 'rois'; when
                    allow_sparse is False, 'scores' and 'class_ids' are filtered as well)
    :param config: the config object (provides the mini mask shape)
    :param allow_sparse: if False, will only keep biggest region of a mask, drop the masks
                         for which unsparse_mask returns None and recompute the rois
    :return: the same results dict, with minimized masks
    """
    if not allow_sparse:
        emptyMasks = []
        for idx in range(results['masks'].shape[-1]):
            mask = unsparse_mask(results['masks'][:, :, idx])
            if mask is None:
                emptyMasks.append(idx)
            else:
                results['masks'][:, :, idx] = mask
        if emptyMasks:
            results['scores'] = np.delete(results['scores'], emptyMasks)
            results['class_ids'] = np.delete(results['class_ids'], emptyMasks)
            results['masks'] = np.delete(results['masks'], emptyMasks, axis=2)
        # rois are recomputed from the remaining masks, so they do not need to be filtered
        results['rois'] = extract_bboxes(results['masks'])
    results['masks'] = minimize_mask(results['rois'], results['masks'],
                                     config.get_mini_mask_shape())
    return results
 def fromLabelMe(cls,
                 filepath: str,
                 config: Config,
                 group_id_2_class: dict = None):
     """
     Constructs an instance of this adapter class (cls) from an exported LabelMe annotations file
     :param filepath: path to the LabelMe annotations file
     :param config: the config, every class it describes is registered as an annotation class
     :param group_id_2_class: dict that links group ids to classes names (if None, or if a shape's
                              group id is missing or not in the dict, the label is used)
     :return: the adapter instance with one annotation per LabelMe shape
     """
     res = cls({})
     for classInfo in config.get_classes_info():
         res.addAnnotationClass(classInfo)
     with open(filepath, 'r') as labelMeFile:
         data = json.load(labelMeFile)
     for shape in data['shapes']:
         # 'group_id' may be absent (or null): fall back on the label in that case
         group_id = shape.get('group_id')
         if group_id_2_class is not None and group_id in group_id_2_class:
             name = group_id_2_class[group_id]
         else:
             # Only the first word of the label is used as the class name
             name = shape['label'].split(' ')[0]
         res.addAnnotation({'name': name}, shape['points'])
     return res
 def __init__(self, dataset_id, image_info, config: Config, previous_mode=False, enable_occlusion=False):
     # The parent initializer runs before any attribute of this class is set
     super().__init__()
     self.__ID = dataset_id
     self.__CONFIG = config
     # "previous" is passed to get_classes_info when previous_mode is truthy, None otherwise
     classes_mode = None
     if previous_mode:
         classes_mode = "previous"
     class_names = [class_info['name'] for class_info in config.get_classes_info(classes_mode)]
     self.__CUSTOM_CLASS_NAMES = class_names
     # Maps the format_text() version of each class name back to the original name
     self.__CLASS_ASSOCIATION = {format_text(name): name for name in class_names}
     self.__IMAGE_INFO = image_info
     self.__ENABLE_OCCLUSION = enable_occlusion
Exemplo n.º 6
0
def getCenteredClassBboxes(datasetPath: str,
                           imageName: str,
                           classToCenter: str,
                           image_size=1024,
                           imageFormat="jpg",
                           allow_oversized=True,
                           config: Config = None,
                           verbose=0):
    """
    Computes and returns bboxes of all masks of the given image and class
    :param datasetPath: path to the dataset containing the image folder
    :param imageName: the image name
    :param classToCenter: the class to center and get the bbox from
    :param image_size: the minimal height and width of the bboxes
    :param imageFormat: the image format to use to get original image
    :param allow_oversized: if False, masks that does not fit image_size will be skipped
    :param config: if given, config file is used to know if mini_masks are used
    :param verbose: level of verbosity
    :return: (N, 4) ndarray of [y1, x1, y2, x2] matching bboxes (skipped masks are not included)
    """
    imagePath = os.path.join(datasetPath, imageName, 'images',
                             f'{imageName}.{imageFormat}')
    image = cv2.imread(imagePath, cv2.IMREAD_COLOR)
    image_shape = image.shape[:2]
    classDirPath = os.path.join(datasetPath, imageName, classToCenter)
    maskList = os.listdir(classDirPath)
    use_mini_mask = config is not None and config.is_using_mini_mask()
    classBboxes = np.zeros((len(maskList), 4), dtype=int)
    toDelete = []
    for idx, mask in enumerate(maskList):
        if use_mini_mask:
            # With mini-masks, the bbox is encoded in the mask file name
            bbox = getBboxFromName(mask)
        else:
            maskImage = cv2.imread(os.path.join(classDirPath, mask), cv2.IMREAD_GRAYSCALE)
            if maskImage is None:
                # Not a readable image (stray file in the class folder): skipping it
                if verbose > 0:
                    print(f"{mask} could not be read as an image")
                toDelete.append(idx)
                continue
            bbox = utils.extract_bboxes(maskImage)
        if not allow_oversized:
            h, w = bbox[2:] - bbox[:2]
            if h > image_size or w > image_size:
                if verbose > 1:
                    print(
                        f"{mask} mask could not fit into {(image_size, image_size)} image"
                    )
                toDelete.append(idx)
                # This row is deleted below, centering it would be wasted work
                continue
        classBboxes[idx] = center_mask(bbox,
                                       image_shape,
                                       min_output_shape=image_size,
                                       verbose=verbose)
    return np.delete(classBboxes, toDelete, axis=0)
def export_annotations(image_info: dict,
                       results: dict,
                       adapterClass: AnnotationAdapter.__class__,
                       save_path="predicted",
                       config: Config = None,
                       verbose=0):
    """
    Exports predicted results to an annotation file using the given adapter class
    :param image_info: Dict with at least {"NAME": str, "HEIGHT": int, "WIDTH": int, "IMAGE_FORMAT": str}
                       about the inferred image
    :param results: inference results of the image ('rois', 'masks' and 'class_ids')
    :param adapterClass: AnnotationAdapter class used to build and save the annotation file
    :param save_path: path to the dir you want to save the annotation file
    :param config: the config to get classes and mini_mask informations (nothing is exported if None)
    :param verbose: verbose level of the method (0 = nothing, 1 = information)
    :return: None
    """
    if config is None:
        print("Cannot export annotations as config is not given.")
        return

    rois = results['rois']
    masks = results['masks']
    class_ids = results['class_ids']
    height, width = masks.shape[:2]
    # Loop invariants, fetched once
    use_mini_mask = config.is_using_mini_mask()
    classes_info = config.get_classes_info()

    adapter_instance = adapterClass(
        {
            "name": image_info['NAME'],
            "height": image_info['HEIGHT'],
            'width': image_info['WIDTH'],
            'format': image_info['IMAGE_FORMAT']
        },
        verbose=verbose)
    if verbose > 0:
        print(
            f"Exporting to {adapter_instance.getName()} annotation file format."
        )
    # For each prediction
    for i in range(masks.shape[2]):
        if use_mini_mask:
            # Expanding the mini-mask into a canvas that appears to keep a 5 px margin around it
            shifted_roi = shift_bbox(rois[i])
            shifted_roi += [5, 5, 5, 5]
            image_size = shifted_roi[2:] + [5, 5]
            mask = expand_mask(shifted_roi, masks[:, :, i], image_size)
            yStart, xStart = rois[i][:2] - [5, 5]
        else:
            # Cropping the RoI (y1, x1, y2, x2) with a 10 px margin, clamped to the mask size
            yStart, xStart, yEnd, xEnd = rois[i]
            yStart = max(yStart - 10, 0)
            xStart = max(xStart - 10, 0)
            yEnd = min(yEnd + 10, height)
            xEnd = min(xEnd + 10, width)
            mask = masks[yStart:yEnd, xStart:xEnd, i]

        # Getting list of points coordinates and adding the prediction to the annotation
        points = getPoints(np.uint8(mask),
                           xOffset=xStart,
                           yOffset=yStart,
                           show=False,
                           waitSeconds=0,
                           info=False)
        if points is None:
            continue
        # class_ids[i] - 1 indexes classes_info: IDs are assumed to start at 1 and be contiguous
        adapter_instance.addAnnotation(classes_info[class_ids[i] - 1], points)

    for classInfo in classes_info:
        adapter_instance.addAnnotationClass(classInfo)

    os.makedirs(save_path, exist_ok=True)
    if verbose > 0:
        print('  - ', end='')
    adapter_instance.saveToFile(save_path, image_info['NAME'])
Exemplo n.º 8
0
def cleanImage(datasetPath: str,
               imageName: str,
               cleaningClasses: str,
               excludeClasses=None,
               imageFormat="jpg",
               cleanMasks=False,
               minAreaThreshold=300,
               config: Config = None):
    """
    Creating the full_images directory and cleaning the base image by removing non-cleaning-class areas
    :param datasetPath: the dataset that have been wrapped
    :param imageName: the image name
    :param cleaningClasses: the class (or list of classes) whose masks delimit the area to keep
    :param excludeClasses: optional class (or list of classes) whose masks delimit areas to remove
    :param imageFormat: the image format to use to save the image
    :param cleanMasks: if true, will clean masks based on the cleaning-class-mask
    :param minAreaThreshold: remove mask if its area is smaller than this threshold
    :param config: config object (used to know if masks must be saved as mini-masks)
    :return: None
    """
    assert cleaningClasses is not None and cleaningClasses != "", "Cleaning class is required."
    if type(cleaningClasses) is str:
        cleaningClasses = [cleaningClasses]
    if type(excludeClasses) is str:
        excludeClasses = [excludeClasses]

    # Getting the base image
    imageFileName = f"{imageName}.{imageFormat}"
    imagePath = os.path.join(datasetPath, imageName, 'images', imageFileName)
    fullImagePath = os.path.join(datasetPath, imageName, 'full_images', imageFileName)
    image = cv2.imread(imagePath)
    if image is None:
        print(f"Problem with {imageName} image")
        return

    # Fusing all the cleaning-class masks (and the excluded-class ones if any)
    cleaningClassMasks = gatherClassesMasks(datasetPath, imageName,
                                            image.shape, cleaningClasses)
    excludedClassMasks = None
    if excludeClasses is not None:
        excludedClassMasks = gatherClassesMasks(datasetPath, imageName,
                                                image.shape, excludeClasses)

    # Nothing to clean with: the image is left untouched
    if cleaningClassMasks is None and excludedClassMasks is None:
        return

    if cleaningClassMasks is None:
        cleaningClassMasks = np.ones_like(image)[..., 0] * 255
    if excludedClassMasks is not None:
        cleaningClassMasks = cv2.bitwise_and(cleaningClassMasks,
                                             cv2.bitwise_not(excludedClassMasks))

    # Copying the full image into the correct directory
    os.makedirs(os.path.dirname(fullImagePath), exist_ok=True)
    shutil.copy2(imagePath, fullImagePath)

    # Cleaning the image and saving it
    image = cv2.bitwise_and(
        image, np.repeat(cleaningClassMasks[:, :, np.newaxis], 3, axis=2))
    cv2.imwrite(imagePath, image, CV2_IMWRITE_PARAM)

    if not cleanMasks:
        return

    # Cleaning masks so that they cannot exist elsewhere; only the other classes' folders are checked
    skipClasses = {"images", "full_images", *cleaningClasses}
    if excludeClasses is not None:
        skipClasses.update(excludeClasses)
    imageDirPath = os.path.join(datasetPath, imageName)
    idPrefix = f"{imageName}_"
    folderToRemove = []
    for folder in os.listdir(imageDirPath):
        folderPath = os.path.join(imageDirPath, folder)
        if folder in skipClasses or not os.path.isdir(folderPath):
            continue
        for maskImageFileName in os.listdir(folderPath):
            maskImagePath = os.path.join(folderPath, maskImageFileName)
            mask = loadSameResImage(maskImagePath, image.shape)
            areaBefore = getBWCount(mask)[1]

            # Empty masks are removed
            if areaBefore == 0:
                os.remove(maskImagePath)
                continue

            # Cleaning it with the cleaning-class masks, removing it if it became too small
            mask = cv2.bitwise_and(mask, cleaningClassMasks)
            areaAfter = getBWCount(mask)[1]
            if areaAfter < minAreaThreshold:
                os.remove(maskImagePath)
                continue
            if areaAfter == areaBefore:
                # Mask unchanged by the cleaning, nothing to rewrite
                continue

            # Mask is different after cleaning: replacing the original one
            try:
                # Mask files are named "{imageName}_{id}[_{bbox}]" and imageName may itself
                # contain '_', so the ID is read after the image name prefix when present
                maskStem = maskImageFileName.split('.')[0]
                try:
                    if maskStem.startswith(idPrefix):
                        idMask = int(maskStem[len(idPrefix):].split('_')[0])
                    else:
                        idMask = int(maskStem.split('_')[1])
                except (ValueError, IndexError):
                    # If we could not retrieve the original mask ID, give it a unique one
                    idMask = int(time())

                # If mini-mask are enabled, we minimize it before saving it
                bbox_coordinates = ""
                if config is not None and config.is_using_mini_mask():
                    bbox = extract_bboxes(mask)
                    mask = minimize_mask(bbox, mask, config.get_mini_mask_shape())
                    mask = mask.astype(np.uint8) * 255
                    y1, x1, y2, x2 = bbox
                    bbox_coordinates = f"_{y1}_{x1}_{y2}_{x2}"

                # Saving cleaned mask
                outputName = f"{imageName}_{idMask:03d}{bbox_coordinates}.{imageFormat}"
                cv2.imwrite(os.path.join(folderPath, outputName), mask,
                            CV2_IMWRITE_PARAM)
                if outputName != maskImageFileName:  # Remove former mask if not the same name
                    os.remove(maskImagePath)
            except Exception:
                # Best effort: a mask that cannot be updated is reported and left as is
                print(f"Error on {maskImagePath} update")

        if len(os.listdir(folderPath)) == 0:
            folderToRemove.append(folderPath)
    for folderPath in folderToRemove:
        shutil.rmtree(folderPath, ignore_errors=True)
Exemplo n.º 9
0
def createMasksOfImage(rawDatasetPath: str,
                       imgName: str,
                       datasetName: str = 'dataset_train',
                       adapter: AnnotationAdapter = None,
                       classesInfo: dict = None,
                       imageFormat="jpg",
                       resize=None,
                       config: Config = None):
    """
    Create all the masks of a given image by parsing its annotations file
    :param rawDatasetPath: path to the folder containing images and associated annotations
    :param imgName: name w/o extension of an image
    :param datasetName: name of the output dataset
    :param adapter: the annotation adapter to use to create masks, if None looking for an adapter that can read the file
    :param classesInfo: Information about all classes that are used (sequence of {"id", "name", ...} dicts),
                        by default the config's classes, or nephrology classes Info if there is no config
    :param imageFormat: output format of the image and masks
    :param resize: if given, (height, width) the image and masks have to be resized to
    :param config: config object
    :return: None
    """
    if classesInfo is None:
        classesInfo = NEPHRO_CLASSES if config is None else config.get_classes_info()

    # Getting shape of original image (same for all this masks)
    img = cv2.imread(os.path.join(rawDatasetPath, f"{imgName}.{imageFormat}"))
    if img is None:
        print(f'Problem with {imgName} image')
        return
    shape = img.shape
    if resize is not None:
        yRatio = resize[0] / shape[0]
        xRatio = resize[1] / shape[1]
        assert yRatio > 0 and xRatio > 0, f"Error resize ratio not correct ({yRatio:3.2f}, {xRatio:3.2f})"
        # resize is (height, width) but cv2.resize expects dsize as (width, height)
        img = cv2.resize(img, (resize[1], resize[0]), interpolation=cv2.INTER_CUBIC)
        shape = img.shape

    # Copying the original image in the dataset
    targetDirectoryPath = os.path.join(datasetName, imgName, 'images')
    if not os.path.exists(targetDirectoryPath):
        os.makedirs(targetDirectoryPath)
        # TODO use file copy if unchanged else cv2
        cv2.imwrite(
            os.path.join(targetDirectoryPath, f"{imgName}.{imageFormat}"), img,
            CV2_IMWRITE_PARAM)

    # Finding annotation files (any file containing the image name with a known annotation extension)
    formats = adapt.ANNOTATION_FORMAT
    imageFiles = [
        f for f in os.listdir(rawDatasetPath)
        if imgName in f and f.split('.')[-1] in formats
    ]
    assert len(imageFiles) > 0, f"No annotation file found for {imgName} image"

    # Choosing the adapter to use (parameters to force it ?)
    file = None
    if adapter is None:
        # No adapter given, we are looking for the adapter with highest priority level that can read an/the
        # annotation file
        adapters = list(adapt.ANNOTATION_ADAPTERS.values())
        adapterPriority = -1
        for f in imageFiles:
            for a in adapters:
                if a.canRead(os.path.join(rawDatasetPath, f)) and a.getPriorityLevel() > adapterPriority:
                    adapterPriority = a.getPriorityLevel()
                    adapter = a
                    file = f
    else:
        # Using given adapter, we are looking for the first file that can be read
        file = next((f for f in imageFiles
                     if adapter.canRead(os.path.join(rawDatasetPath, f))), None)

    if adapter is None or file is None:
        print(f'No annotation file of {imgName} image could be read')
        return

    # Getting the masks data
    masks = adapter.readFile(os.path.join(rawDatasetPath, file))

    # Creating masks
    for noMask, (datasetClass, maskPoints) in enumerate(masks):
        # Converting class id to class name if needed
        if type(datasetClass) is int:
            maskClass = None
            if datasetClass < len(classesInfo) and classesInfo[datasetClass]["id"] == datasetClass:
                maskClass = classesInfo[datasetClass]["name"]
            else:
                for classInfo in classesInfo:
                    if classInfo["id"] == datasetClass:
                        maskClass = classInfo["name"]
                        break
            if maskClass is None:
                # Unknown ID: skipping rather than reusing the previous mask's class
                print(f" /!\\ {imgName} : unknown class ID {datasetClass}, mask skipped /!\\ ")
                continue
        else:
            maskClass = datasetClass
            if maskClass == "None":
                print(f" /!\\ {imgName} : None class present /!\\ ")
        points = maskPoints if resize is None else resizeMasks(maskPoints, xRatio, yRatio)
        createMask(imgName,
                   shape,
                   noMask,
                   points,
                   datasetName,
                   maskClass,
                   imageFormat,
                   config=config)
Exemplo n.º 10
0
def createMask(imgName: str,
               imgShape,
               idMask: int,
               ptsMask,
               datasetName: str = 'dataset_train',
               maskClass: str = 'masks',
               imageFormat="jpg",
               config: Config = None):
    """
    Create the mask image based on its polygon points and save it in the dataset
    :param imgName: name w/o extension of the base image
    :param imgShape: shape of the image
    :param idMask: the ID of the mask, a number not already used for that image
    :param ptsMask: array of [x, y] coordinates which are all the polygon points representing the mask
    :param datasetName: name of the output dataset
    :param maskClass: name of the associated class of the current mask (lowercased, spaces replaced
                      by '_' to build the output folder name)
    :param imageFormat: output format of the masks' images
    :param config: config object; if mini-masks are used, the mask is minimized and its bbox
                   (clamped to the image) is appended to the file name
    :return: None (with mini-masks, nothing is saved when the polygon bbox does not intersect the image)
    """
    # https://www.programcreek.com/python/example/89415/cv2.fillPoly
    # Rounding the coordinates to int32, as required by cv2.fillPoly
    ptsMask = np.int32(np.round(np.double(ptsMask)))

    bbox_coordinates = ""
    if config is not None and config.is_using_mini_mask():
        bbox = get_bbox_from_points(ptsMask)
        if get_bboxes_intersection(bbox, [0, 0, *imgShape[:2]]) <= 0:
            return
        # Clamping the bbox to the image: even indices are y (height), odd indices are x (width)
        kept_bbox = [min(max(0, bbox[i]), imgShape[i % 2]) for i in range(4)]
        y1, x1, y2, x2 = kept_bbox
        bbox_coordinates = f"_{y1}_{x1}_{y2}_{x2}"

        # Drawing the polygon in a canvas the size of its own bbox (points are [x, y], shift is [y, x])
        shiftedBbox = shift_bbox(bbox)
        shift = bbox[:2]
        mask = np.zeros((shiftedBbox[2], shiftedBbox[3]), dtype=np.uint8)
        cv2.fillPoly(mask, [ptsMask - shift[::-1]], 255)

        # Keeping only the part of the mask that lies inside the image
        y1, x1, y2, x2 = shift_bbox(kept_bbox, customShift=shift)
        mask = mask[y1:y2, x1:x2]

        # Minimizing the mask to the mini-mask shape
        mask = minimize_mask(shiftedBbox, mask, config.get_mini_mask_shape())
        mask = mask.astype(np.uint8) * 255
    else:
        # Creating black matrix with same size than original image and then drawing the mask
        mask = np.zeros((imgShape[0], imgShape[1]), dtype=np.uint8)
        cv2.fillPoly(mask, [ptsMask], 255)

    # Saving result image
    maskClass = maskClass.lower().strip(' ').replace(" ", "_")
    output_directory = os.path.join(datasetName, imgName, maskClass)
    os.makedirs(output_directory, exist_ok=True)
    output_name = f"{imgName}_{idMask:03d}{bbox_coordinates}.{imageFormat}"
    cv2.imwrite(os.path.join(output_directory, output_name), mask,
                CV2_IMWRITE_PARAM)
def get_count_and_area(results: dict,
                       image_info: dict,
                       selected_classes: [str],
                       save=None,
                       display=True,
                       config: Config = None,
                       verbose=0):
    """
    Computing count and area of classes from results
    :param results: the results ('rois', 'masks', 'class_ids' and optionally 'mask_areas' where -1
                    means the area has not been precomputed)
    :param image_info: Dict containing informations about the image ('NAME', 'HEIGHT', 'WIDTH' and
                       optionally 'BASE_CLASS', 'BASE_COUNT', 'BASE_AREA')
    :param selected_classes: class name or list of classes' names that you want to get statistics on
                             ("all" selects every class of the config)
    :param save: if given, path to the directory in which "{NAME}_stats.json" will be written
    :param display: if True, will print the statistics
    :param config: the config to get mini_mask informations
    :param verbose: 0 : nothing, 1+ : errors/problems, 2 : general information, ...
    :return: Dict of "className": {"display_name": str, "count": int, "area": int} elements for each
             class, or None if config is None or if there is neither save nor display to do
    """
    if config is None or (save is None and not display):
        return

    print(" - Computing statistics on predictions")

    rois = results['rois']
    masks = results['masks']
    class_ids = results['class_ids']
    indices = np.arange(len(class_ids))
    mini_mask_used = config.is_using_mini_mask()

    # When a 'resize' parameter is set, areas are scaled by (HEIGHT / resize[0]) * (WIDTH / resize[1]),
    # presumably to express them at the original image size
    resize = config.get_param().get('resize', None)
    ratio = 1
    if resize is not None:
        ratio = image_info['HEIGHT'] / resize[0]
        ratio *= (image_info['WIDTH'] / resize[1])

    if type(selected_classes) is str:
        selected_classes_ = [selected_classes]
    else:
        selected_classes_ = selected_classes

    # Getting the inferenceIDs of the wanted classes
    if "all" in selected_classes_:
        selectedClassesID = {
            aClass['id']: aClass['name']
            for aClass in config.get_classes_info()
        }
    else:
        selectedClassesID = {
            config.get_class_id(name): name
            for name in selected_classes_
        }
        indices = indices[np.isin(class_ids, list(selectedClassesID.keys()))]
    res = {
        c_name: {
            "display_name": config.get_class_name(c_id, display=True),
            "count": 0,
            "area": 0
        }
        for c_id, c_name in selectedClassesID.items()
    }

    # For each predictions, if class ID matching with one we want
    for index in indices:
        className = selectedClassesID[class_ids[index]]
        res[className]["count"] += 1
        if "mask_areas" in results and results['mask_areas'][index] != -1:
            area = int(results['mask_areas'][index])
        else:
            # Area not precomputed: rebuilding the mask to measure it
            if mini_mask_used:
                shifted_roi = utils.shift_bbox(rois[index])
                mask = utils.expand_mask(shifted_roi, masks[:, :, index],
                                         shifted_roi[2:])
            else:
                yStart, xStart, yEnd, xEnd = rois[index]
                mask = masks[yStart:yEnd, xStart:xEnd, index]
            area, _ = utils.get_mask_area(mask.astype(np.uint8))
            area = int(area)  # Cast to int to avoid "json 'int64' not serializable"
        if resize is None:
            res[className]["area"] += area
        else:
            res[className]["area"] += int(round(area * ratio))

    if 'BASE_CLASS' in image_info:
        mode = config.get_class_mode(image_info['BASE_CLASS'],
                                     only_in_previous="current")[0]
        res[image_info['BASE_CLASS']] = {
            "display_name":
            config.get_class_name(config.get_class_id(image_info['BASE_CLASS'],
                                                      mode),
                                  mode,
                                  display=True),
            "count":
            image_info['BASE_COUNT'],
            "area":
            image_info["BASE_AREA"]
        }
    if save is not None:
        # Serializing before opening the file so that a failure does not leave a partial json file
        try:
            content = json.dumps(res, indent='\t')
        except TypeError:
            content = None
            if verbose > 0:
                print("    Failed to save statistics", flush=True)
        if content is not None:
            with open(os.path.join(save, f"{image_info['NAME']}_stats.json"),
                      "w") as saveFile:
                saveFile.write(content)
    if display:
        for className in res:
            mode = config.get_class_mode(className,
                                         only_in_previous="current")[0]
            displayName = config.get_class_name(config.get_class_id(
                className, mode),
                                                mode,
                                                display=True)
            stat = res[className]
            print(
                f"    - {displayName} : count = {stat['count']}, area = {stat['area']} px"
            )

    return res
def mask_histo_per_base_mask(base_results,
                             results,
                             image_info,
                             classes=None,
                             box_epsilon: int = 0,
                             test_masks=True,
                             mask_threshold=0.9,
                             count_zeros=True,
                             config: Config = None,
                             display_per_base_mask=False,
                             display_global=False,
                             save=None,
                             verbose=0):
    """
    Return an histogram of the number of each mask of a class inside each base mask
    :param base_results: results of the previous inference mode or ground-truth
    :param results: results of the current inference mode or ground-truth
    :param image_info: Dict containing informations about the image
    :param classes: dict that link previous classes to current classes that we want to count
    :param box_epsilon: margin of the RoI to allow boxes that are not exactly inside
    :param test_masks: if True, will test that masks are at least 'mask_threshold' inside the base mask
    :param mask_threshold: threshold that will define if a mask is included inside the base mask
    :param count_zeros: if True, base masks without included masks will be counted
    :param config: Config object of the Inference Tool
    :param display_per_base_mask: if True, will display each base mask histogram
    :param display_global: if True, will display global histogram
    :param save: if given, will be used as directory path to save json file of
    :param verbose: 0 : nothing, 1+ : errors/problems, 2 : general information
    :return: global histogram of how many base masks contain a certain amount an included class mask
    """
    # If classes is None or empty, skip method
    if classes is None or classes == {} or config is None \
            or (save is None and not (display_per_base_mask or display_global)):
        return results

    print(" - Computing base masks histograms")

    if box_epsilon < 0:
        raise ValueError(f"box_epsilon ({box_epsilon}) cannot be negative")

    def get_class_data(classname):
        fromPreviousRes = config.get_previous_mode() in config.get_class_mode(
            classname, "current")
        class_id = config.get_class_id(
            b_class, "previous" if fromPreviousRes else "current")
        tempRes = base_results if fromPreviousRes else results
        if 'histos' not in tempRes:  # If histo does not exists, initiate it
            tempRes['histos'] = np.empty(len(tempRes['class_ids']),
                                         dtype=object)
        return (class_id, tempRes['class_ids'], tempRes['rois'],
                tempRes['masks'], tempRes['histos'],
                np.arange(len(tempRes['class_ids']),
                          dtype=int), fromPreviousRes)

    # Getting all the results/current data
    c_class_ids = results['class_ids']
    c_rois = results['rois']
    c_masks = results['masks']
    c_indices = np.arange(len(results['class_ids']), dtype=int)
    c_areas = results.get('mask_areas',
                          np.ones(len(c_class_ids), dtype=int) * -1)

    # For each base class that we want to get an histogram of the included current classes
    for b_class in classes:
        b_class_id, b_class_ids, b_rois, b_masks, histograms, b_indices, fromPrevious = get_class_data(
            b_class)
        b_cur_idx = b_indices[np.isin(b_class_ids, [b_class_id])]
        if classes[b_class] == "all" or (type(classes[b_class]) is list
                                         and "all" in classes[b_class]):
            c_cur_idx = c_indices
        else:
            if type(classes[b_class]) is str:
                temp_ = [classes[b_class]]
            else:
                temp_ = classes[b_class]
            c_class_id = [config.get_class_id(aClass) for aClass in temp_]
            c_cur_idx = c_indices[np.isin(c_class_ids, c_class_id)]
        for b_idx in b_cur_idx:  # For each base class mask
            b_roi = b_rois[b_idx]
            custom_shift = b_roi[:2] - box_epsilon
            padded_size = b_roi[2:] - b_roi[:2] + (box_epsilon * 2)
            if test_masks:
                b_mask = b_masks[..., b_idx]
                if config.is_using_mini_mask(config.get_previous_mode()):
                    b_shifted_roi = utils.shift_bbox(b_roi, custom_shift)
                    b_mask = utils.expand_mask(b_shifted_roi, b_mask,
                                               padded_size)
                else:
                    b_mask = np.pad(
                        b_mask[b_roi[0]:b_roi[2], b_roi[1]:b_roi[3]],
                        box_epsilon)
            if histograms[b_idx] is None:
                histograms[b_idx] = {}

            for c_idx in c_cur_idx:  # For each mask of one of the current classes
                c_roi = c_rois[c_idx]
                c_class = c_class_ids[c_idx]
                if fromPrevious and c_class == b_class_id:  # If using same results, skip base class elements
                    continue

                # If the bbox of the current mask is inside the base bbox
                if utils.in_roi(c_roi, b_roi, epsilon=box_epsilon):

                    if test_masks:  # If we have to check that masks are included
                        c_mask = c_masks[..., c_idx]
                        if config.is_using_mini_mask():
                            c_shifted_roi = utils.shift_bbox(
                                c_roi, custom_shift)
                            c_mask = utils.expand_mask(c_shifted_roi, c_mask,
                                                       padded_size)
                        else:
                            c_mask = np.pad(
                                c_mask[b_roi[0]:b_roi[2], b_roi[1]:b_roi[3]],
                                box_epsilon)
                        if c_areas[c_idx] == -1:
                            c_areas[c_idx] = dD.getBWCount(c_mask)[1]
                        c_mask = np.bitwise_and(b_mask, c_mask)
                        c_area_in = dD.getBWCount(c_mask)[1]
                        if c_area_in <= c_areas[
                                c_idx] * mask_threshold:  # If the included part is not enough, skip it
                            continue
                    if c_class not in histograms[b_idx]:
                        histograms[b_idx][c_class] = 0
                    histograms[b_idx][c_class] += 1

    # Display of each individual histogram
    if display_per_base_mask:
        for res in [base_results, results]:
            if 'histos' not in res:
                continue
            for idx, histogram in enumerate(res['histos']):
                if histogram is not None:
                    print(
                        f"    - mask n°{idx}:", ", ".join([
                            f"{nb} {config.get_class_name(c, display=True)}"
                            for c, nb in histogram.items()
                        ]))

    # Computing global histograms
    first = True
    for res in [base_results, results]:
        if 'histos' in res:
            if first:
                first = False
                global_histo = mask_to_class_histogram(res,
                                                       classes=classes,
                                                       count_zeros=count_zeros,
                                                       config=config)
            else:  # Updating manually global histo if there are base classes from both previous and current res
                temp_histo = mask_to_class_histogram(res,
                                                     classes=classes,
                                                     count_zeros=count_zeros,
                                                     config=config)
                for c in temp_histo:
                    if c not in global_histo:
                        global_histo[c] = temp_histo[c]
                    else:
                        for nb in temp_histo[c]:
                            if nb not in global_histo[c]:
                                global_histo[c][nb] = 0
                            global_histo[c][nb] += temp_histo[c][nb]

    for key in global_histo.keys():
        global_histo[key] = sort_dict(global_histo[key], key_type=int)

    # Displaying global histogram if needed
    baseName = 'BASE' if len(classes) > 1 else list(classes.keys())[0]
    if display_global:
        for class_, histogram in global_histo.items():
            print(
                f"    - {class_}:", ", ".join([
                    f"{nb_elt} [{nb_mask} {baseName.lower()} mask{'s' if nb_mask > 1 else ''}]"
                    for nb_elt, nb_mask in histogram.items()
                ]))
    if save is not None:
        temp = {
            '_comment':
            f"<class A>: {{N: <nb {baseName} masks with N class A masks>}}"
        }
        temp.update(global_histo)
        with open(os.path.join(save, f'{image_info["NAME"]}_histo.json'),
                  'w') as saveFile:
            json.dump(temp, saveFile, indent='\t')
    return global_histo
def display_instances(image,
                      boxes,
                      masks,
                      class_ids,
                      class_names,
                      scores=None,
                      title="",
                      figsize=(16, 16),
                      ax=None,
                      fig=None,
                      image_format="jpg",
                      show_mask=True,
                      show_bbox=True,
                      colors=None,
                      colorPerClass=False,
                      captions=None,
                      fileName=None,
                      save_cleaned_img=False,
                      silent=False,
                      config: Config = None):
    """
    Draw instances (masks, bounding boxes, captions, mask contours) on an image,
    optionally save the figure, and return the image with masks applied.

    boxes: [num_instance, (y1, x1, y2, x2, class_id)] in image coordinates.
    masks: [height, width, num_instances] (mini-masks if config uses them)
    class_ids: [num_instances]
    class_names: list of class names of the dataset
    scores: (optional) confidence scores for each box
    title: (optional) Figure title
    figsize: (optional) the size of the image
    ax, fig: (optional) axis and figure to draw on; if one is missing, a new figure is
             created (and shown unless silent is True)
    image_format: extension used when saving with fileName
    show_mask, show_bbox: To show masks and bounding boxes or not
    colors: (optional) An array of colors to use with each object (or each class if
            colorPerClass); integer colors are treated as 0-255 values
    colorPerClass: if True, one color per class instead of one color per instance
    captions: (optional) A list of strings to use as captions for each object
    fileName: (optional) path without extension where the figure is saved
    save_cleaned_img: if True (and fileName given), also save the masked image alone
    silent: if True, a figure created here is not shown
    config: (optional) Config object used to expand mini-masks if enabled
    Returns the masked image (the input image is not copied, so it may be modified
    in place by apply_mask).
    """
    # Number of instances
    N = boxes.shape[0]
    if not N:
        print("\n*** No instances to display *** \n")
    else:
        assert boxes.shape[0] == masks.shape[-1] == class_ids.shape[0]

    # If no axis is passed, create one and automatically call show()
    auto_show = False
    ownFig = False
    if ax is None or fig is None:
        ownFig = True
        fig, ax = plt.subplots(1, figsize=figsize)
        auto_show = not silent

    # Generate random colors
    nb_color = (len(class_names) - 1) if colorPerClass else N
    colors = colors if colors is not None else random_colors(
        nb_color, shuffle=(not colorPerClass))
    # Integer (python or numpy) colors are converted from 0-255 to 0-1. An empty color
    # list (no instance to display) is kept as-is instead of raising IndexError.
    first_channel = colors[0][0] if len(colors) > 0 else None
    if isinstance(first_channel, (int, np.integer)) \
            and not isinstance(first_channel, bool):
        _colors = [[c / 255. for c in color] for color in colors]
    else:
        _colors = colors
    # Show area outside image boundaries.
    height, width = image.shape[:2]
    ax.set_ylim(height + 10, -10)
    ax.set_xlim(-10, width + 10)
    ax.axis('off')
    ax.set_title(title)

    # To be usable on Google Colab we do not make a copy of the image leading to too much ram usage if it is a biopsy
    # or nephrectomy image
    masked_image = image
    for i in range(N):
        if colorPerClass:
            color = _colors[class_ids[i] - 1]
        else:
            color = _colors[i]
        # Bounding box
        if not np.any(boxes[i]):
            # Skip this instance. Has no bbox. Likely lost in image cropping.
            continue
        y1, x1, y2, x2 = boxes[i]
        if show_bbox:
            p = patches.Rectangle((x1, y1),
                                  x2 - x1,
                                  y2 - y1,
                                  linewidth=2,
                                  alpha=0.7,
                                  linestyle="dashed",
                                  edgecolor=color,
                                  facecolor='none')
            ax.add_patch(p)

        # Label
        if not captions:
            class_id = class_ids[i]
            score = scores[i] if scores is not None else None
            label = class_names[class_id]
            caption = "{} {:.3f}".format(label, score) if score else label
        else:
            caption = captions[i]
        ax.text(x1 + 4,
                y1 + 19,
                caption,
                color=get_text_color(color[0], color[1], color[2]),
                size=12,
                backgroundcolor=color)

        # Mask
        mask = masks[:, :, i]
        bbox = boxes[i]
        # Offset of the mask's origin in the image (non-zero only for mini-masks)
        shift = np.array([0, 0])
        if config is not None and config.is_using_mini_mask():
            shifted_bbox = utils.shift_bbox(bbox)
            shift = bbox[:2]
            mask = utils.expand_mask(shifted_bbox, mask,
                                     tuple(shifted_bbox[2:]))
            mask = mask.astype(np.uint8) * 255
        if show_mask:
            masked_image = apply_mask(masked_image, mask, color, bbox=bbox)

        # Mask Polygon
        # Pad to ensure proper polygons for masks that touch image edges.
        padded_mask = np.zeros((mask.shape[0] + 2, mask.shape[1] + 2),
                               dtype=np.uint8)
        padded_mask[1:-1, 1:-1] = mask
        contours = find_contours(padded_mask, 0.5)
        for verts in contours:
            verts = verts + shift
            # Subtract the padding and flip (y, x) to (x, y)
            verts = np.fliplr(verts) - 1
            p = Polygon(verts, facecolor="none", edgecolor=color)
            ax.add_patch(p)
    ax.imshow(masked_image)
    fig.tight_layout()
    if fileName is not None:
        fig.savefig(f"{fileName}.{image_format}")
        if save_cleaned_img:
            BGR_img = cv2.cvtColor(masked_image, cv2.COLOR_RGB2BGR)
            cv2.imwrite(f"{fileName}_clean.{image_format}", BGR_img,
                        CV2_IMWRITE_PARAM)
    if auto_show:
        plt.show()
    fig.clf()
    if ownFig:
        del ax, fig
    return masked_image