示例#1
0
    def __init__(
        self,
        model_dir,
        weights_path,
        model_config=defaults.DefaultConfig(),
        class_names=defaults.CLASSES,
        layout_params=None,
        init=None,
    ):
        """
        Build an inference-mode MRCNN model and load its weights.

        Parameters
        ----------
        model_dir : str or Path
            Directory passed (as a string) to `modellib.MaskRCNN` as `model_dir`.
        weights_path : str or Path
            Weights file to load. Not used when `init == "imagenet"`.
        model_config : `mrcnn.Config`, optional
            Model configuration, default=`defaults.DefaultConfig()`. The default
            instance is created once, when this method is defined, and is shared
            by every instance that relies on it.
        class_names : list(str), optional
            Class names for model output. Either `model_config.NUM_CLASSES - 1`
            names (a leading "BG" is then prepended) or exactly `NUM_CLASSES`
            names with "BG" first. Default=`defaults.CLASSES`.
        layout_params : dict, optional
            Layout parameters to override on top of `defaults.LAYOUT_PARAMS`.
            `None` (the default) means no overrides.
        init : str, optional
            Weight initialisation mode:
              - "coco": load `weights_path` by layer name, skipping the layers
                `mrcnn_class_logits`, `mrcnn_bbox_fc`, `mrcnn_bbox`, `mrcnn_mask`.
              - "imagenet": load the file returned by
                `self.model.get_imagenet_weights()`; `weights_path` is ignored.
              - anything else (including None): load `weights_path` by layer name.

        Raises
        ------
        AssertionError
            If `class_names` does not agree with `model_config.NUM_CLASSES`, or
            "BG" is not the first class name.
        CorebreakoutError
            If loading `weights_path` in the default branch raises `ValueError`.
        """
        import copy

        # `None` instead of a `{}` default so no dict is shared between calls.
        if layout_params is None:
            layout_params = {}

        self.model_config = model_config

        # Check that `class_names` make sense
        if len(class_names) == (model_config.NUM_CLASSES - 1):
            # list() so tuples (or other sequences) are accepted too
            self.class_names = ["BG"] + list(class_names)
        else:
            assert (
                len(class_names) == model_config.NUM_CLASSES
            ), "Number of `class_names` must match number in `model_config`"
            assert (
                class_names[0] == "BG"
            ), "Background `BG` must be first in `class_names`"
            self.class_names = class_names

        # Start from a private copy of the defaults, then validate/apply the
        # overrides via the `layout_params` setter (defined elsewhere). If that
        # setter updates `self._layout_params` in place, assigning the shared
        # module-level dict directly would leak the overrides into
        # `defaults.LAYOUT_PARAMS` and every later instance.
        self._layout_params = copy.deepcopy(defaults.LAYOUT_PARAMS)
        self.layout_params = layout_params

        # Build and load the saved model
        print(f"Building MRCNN model from directory: {str(model_dir)}")
        self.model = modellib.MaskRCNN(
            mode="inference", config=self.model_config, model_dir=str(model_dir)
        )

        if init == "coco":
            print(f"Loading model weights from file: {str(weights_path)}")
            # Skip these head layers; presumably their shapes depend on the
            # number of classes and would not match a differently-sized model.
            self.model.load_weights(
                str(weights_path),
                by_name=True,
                exclude=[
                    "mrcnn_class_logits",
                    "mrcnn_bbox_fc",
                    "mrcnn_bbox",
                    "mrcnn_mask",
                ],
            )
        elif init == "imagenet":
            # `weights_path` is not used here, so log the file actually loaded.
            imagenet_path = self.model.get_imagenet_weights()
            print(f"Loading ImageNet weights from file: {imagenet_path}")
            self.model.load_weights(imagenet_path, by_name=True)
        else:
            print(f"Loading model weights from file: {str(weights_path)}")
            try:
                self.model.load_weights(str(weights_path), by_name=True)
            except ValueError as exc:
                msg = f"Error during loading pretrained weights: {exc}"
                raise CorebreakoutError(msg) from None
示例#2
0
    def __init__(
        self,
        model_dir,
        weights_path,
        model_config=defaults.DefaultConfig(),
        class_names=defaults.CLASSES,
        layout_params=None,
    ):
        """
        Build an inference-mode MRCNN model and load its weights by layer name.

        Parameters
        ----------
        model_dir : str or Path
            Path to saved MRCNN model directory
        weights_path : str or Path
            Path to saved weights file of corresponding model
        model_config : `mrcnn.Config`, optional
            Instance of MRCNN configuration object, default=`defaults.DefaultConfig()`.
            The default instance is created once, when this method is defined, and
            is shared by every instance that relies on it.
        class_names : list(str), optional
            A list of the class names for model output. Should be in same order as in
            the `Dataset` object that model was trained on. Either
            `model_config.NUM_CLASSES - 1` names (a leading "BG" is then prepended)
            or exactly `NUM_CLASSES` names with "BG" first. Default=`defaults.CLASSES`
        layout_params : dict, optional
            Any layout parameters to override from default=`defaults.LAYOUT_PARAMS`.
            `None` (the default) means no overrides.
            See `docs/layout_parameters.md` for explanations and options for each parameter.

        Raises
        ------
        AssertionError
            If `class_names` does not agree with `model_config.NUM_CLASSES`, or
            "BG" is not the first class name.
        """
        import copy

        # `None` instead of a `{}` default so no dict is shared between calls.
        if layout_params is None:
            layout_params = {}

        self.model_config = model_config

        # Check that `class_names` make sense
        if len(class_names) == (model_config.NUM_CLASSES - 1):
            # list() so tuples (or other sequences) are accepted too
            self.class_names = ["BG"] + list(class_names)
        else:
            assert (
                len(class_names) == model_config.NUM_CLASSES
            ), "Number of `class_names` must match number in `model_config`"
            assert (
                class_names[0] == "BG"
            ), "Background `BG` must be first in `class_names`"
            self.class_names = class_names

        # Start from a private copy of the defaults, then validate/apply the
        # overrides via the `layout_params` setter (defined elsewhere). If that
        # setter updates `self._layout_params` in place, assigning the shared
        # module-level dict directly would leak the overrides into
        # `defaults.LAYOUT_PARAMS` and every later instance.
        self._layout_params = copy.deepcopy(defaults.LAYOUT_PARAMS)
        self.layout_params = layout_params

        # Build and load the saved model
        print(f"Building MRCNN model from directory: {str(model_dir)}")
        self.model = modellib.MaskRCNN(
            mode="inference",
            config=self.model_config,
            model_dir=str(model_dir),
        )

        print(f"Loading model weights from file: {str(weights_path)}")
        self.model.load_weights(str(weights_path), by_name=True)
"""
import os
import argparse
from glob import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from functools import reduce
from operator import add

from corebreakout import defaults
from corebreakout import CoreSegmenter, CoreColumn

# Hand-edited settings: change the assignments below directly.

# Mask R-CNN configuration object to use
model_config = defaults.DefaultConfig()

# Class names for the model output
class_names = defaults.CLASSES

# Layout parameters; edit here to use non-default values
layout_params = defaults.LAYOUT_PARAMS

# Command-line interface
parser = argparse.ArgumentParser(description=(
    'Convert image directories with Mask R-CNN and save results as `CoreColumn`s.'
))
parser.add_argument(
    'path',
    help="Path to directory of images (and depth information csv) to process.",
    type=str,
)