Beispiel #1
0
    def __init__(self, num_classes=1, label_dict=None):
        """Create a deepforest PyTorch Lightning module.

        The config is read from ./deepforest_config.yml when that file exists,
        otherwise from the copy located via ``get_data``. The model is then
        built with ``self.create_model()``.

        Args:
            num_classes (int): number of classes in the model
            label_dict (dict): mapping from class label to integer id.
                Defaults to ``{"Tree": 0}`` when not given.
        Returns:
            self: a deepforest pytorch lightning module
        Raises:
            ValueError: if no deepforest_config.yml can be located.
        """
        super().__init__()

        # Build a fresh default per instance. A dict literal used as the
        # default argument would be one object shared by every instance.
        if label_dict is None:
            label_dict = {"Tree": 0}

        # Read config file - if a config file exists in local dir use it,
        # if not use installed.
        if os.path.exists("deepforest_config.yml"):
            config_path = "deepforest_config.yml"
        else:
            try:
                config_path = get_data("deepforest_config.yml")
            except Exception as e:
                # Chain the original error so its traceback is preserved as the cause.
                raise ValueError(
                    "No deepforest_config.yml found either in local "
                    "directory or in installed package location. {}".format(e)) from e

        print("Reading config file: {}".format(config_path))
        self.config = utilities.read_config(config_path)

        # release version id to flag if release is being used
        self.__release_version__ = None

        self.num_classes = num_classes
        self.create_model()

        # Label encoder (label -> id) and decoder (id -> label).
        self.label_dict = label_dict
        self.numeric_to_label_dict = {v: k for k, v in label_dict.items()}
Beispiel #2
0
    def __init__(self, weights=None):
        """Store ``weights``, read the config and load a model if weights are given.

        Args:
            weights: value passed to ``utilities.read_model`` together with the
                config; when None, ``self.model`` is left as None.
        """
        self.weights = weights

        # Config is read with utilities.read_config's own default location.
        self.config = utilities.read_config()

        # Only build a model when weights were actually supplied.
        self.model = (
            utilities.read_model(self.weights, self.config)
            if self.weights is not None
            else None
        )
Beispiel #3
0
    def __init__(self, num_classes=1, label_dict=None, transforms=None, config_file='deepforest_config.yml'):
        """Create a deepforest PyTorch Lightning module.

        Args:
            num_classes (int): number of classes in the model
            label_dict (dict): mapping from class label to integer id; must
                contain exactly ``num_classes`` entries. Defaults to
                ``{"Tree": 0}`` when not given.
            transforms: optional replacement for ``dataset.get_transform``;
                stored on ``self.transforms``.
            config_file (str): path to deepforest config file. If it does not
                exist, the deepforest_config.yml located via ``get_data`` is
                used instead.
        Returns:
            self: a deepforest pytorch lightning module
        Raises:
            ValueError: if ``label_dict`` does not have ``num_classes``
                entries, or if no config file can be located.
        """
        super().__init__()

        # Build a fresh default per instance. A dict literal used as the
        # default argument would be one object shared by every instance.
        if label_dict is None:
            label_dict = {"Tree": 0}

        # Validate the label mapping up front, so a bad argument fails before
        # the config is read and the model is built.
        if len(label_dict) != num_classes:
            raise ValueError(
                'label_dict {} does not match requested number of classes {}, please supply a label_dict argument {{"label1":0, "label2":1, "label3":2 ... etc}} for each label in the dataset'.format(label_dict, num_classes)
            )

        # Pytorch lightning handles the device, but we need one for adhoc methods like predict_image.
        if torch.cuda.is_available():
            self.current_device = torch.device("cuda")
        else:
            self.current_device = torch.device("cpu")

        # Read config file. Defaults to deepforest_config.yml in working directory.
        # Falls back to default installed version
        if os.path.exists(config_file):
            config_path = config_file
        else:
            try:
                config_path = get_data("deepforest_config.yml")
            except Exception as e:
                # Chain the original error so its traceback is preserved as the cause.
                raise ValueError(
                    "No config file provided and deepforest_config.yml not found either in local "
                    "directory or in installed package location. {}".format(e)) from e

        print("Reading config file: {}".format(config_path))
        self.config = utilities.read_config(config_path)

        # release version id to flag if release is being used
        self.__release_version__ = None

        self.num_classes = num_classes
        self.create_model()

        # Label encoder (label -> id) and decoder (id -> label).
        self.label_dict = label_dict
        self.numeric_to_label_dict = {v: k for k, v in label_dict.items()}

        # Add user supplied transforms, else fall back to dataset.get_transform.
        if transforms is None:
            self.transforms = dataset.get_transform
        else:
            self.transforms = transforms

        self.save_hyperparameters()
Beispiel #4
0
def config():
    """Return a test config wired to the bundled OSBS_029 sample data.

    Starts from the deepforest_config.yml located via ``get_data`` and
    overrides the patch settings and data locations. As a side effect, it
    rewrites tests/data/OSBS_029.csv from the sample XML annotations.
    """
    cfg = utilities.read_config(get_data("deepforest_config.yml"))

    # Patch settings.
    cfg["patch_size"] = 200
    cfg["patch_overlap"] = 0.25

    # Sample data locations.
    xml_path = get_data("OSBS_029.xml")
    cfg["annotations_xml"] = xml_path
    cfg["rgb_dir"] = "data"
    csv_path = "tests/data/OSBS_029.csv"
    cfg["annotations_file"] = csv_path
    cfg["path_to_raster"] = get_data("OSBS_029.tif")

    # Regenerate the annotations CSV from the XML so tests start from a clean copy.
    annotations = utilities.xml_to_annotations(xml_path=xml_path)
    annotations.to_csv(csv_path, index=False)

    return cfg
Beispiel #5
0
    def __init__(self, weights=None, saved_model=None):
        """Create a deepforest object, optionally loading or building a model.

        ``saved_model`` takes precedence over ``weights``. If neither is given,
        a blank object is created with ``self.model = None``.

        Args:
            weights: passed to ``create_models`` to build the model; used only
                when ``saved_model`` is falsy.
            saved_model: passed to ``models.load_model``.
        Raises:
            ValueError: if no deepforest_config.yml can be located.

        Attributes set per branch:
            saved_model: ``model`` and ``prediction_model``.
            weights: ``model``, ``training_model`` and ``prediction_model``.
            neither: ``model`` only (None).
        """
        self.weights = weights
        self.saved_model = saved_model

        # Read config file - if a config file exists in local dir use it,
        # if not use installed.
        if os.path.exists("deepforest_config.yml"):
            config_path = "deepforest_config.yml"
        else:
            try:
                config_path = get_data("deepforest_config.yml")
            except Exception as e:
                # Chain the original error so its traceback is preserved as the cause.
                raise ValueError(
                    "No deepforest_config.yml found either in local "
                    "directory or in installed package location. {}".format(e)) from e

        print("Reading config file: {}".format(config_path))
        self.config = utilities.read_config(config_path)

        # Create a label dict, defaults to "Tree"
        self.read_classes()

        # release version id to flag if release is being used
        self.__release_version__ = None

        # Load saved model if needed
        if self.saved_model:
            print("Loading saved model")
            # Capture user warning, not relevant here
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)
                self.model = models.load_model(saved_model)
                self.prediction_model = convert_model(self.model)

        elif self.weights:
            print("Creating model from weights")
            backbone = models.backbone(self.config["backbone"])
            # NOTE(review): num_classes is hard-coded to 1 even though
            # read_classes() builds a label dict above. Confirm this is
            # intended for multi-class configs.
            self.model, self.training_model, self.prediction_model = create_models(
                backbone.retinanet, num_classes=1, weights=self.weights)
        else:
            print(
                "A blank deepforest object created. "
                "To perform prediction, either train or load an existing model."
            )
            self.model = None
def config():
    """Read deepforest_config.yml from the current working directory.

    Returns:
        The value produced by ``utilities.read_config`` for that path.
    """
    return utilities.read_config("deepforest_config.yml")
Beispiel #7
0
def config():
    """Read the deepforest_config.yml whose path is resolved by ``get_data``.

    Unlike reading "deepforest_config.yml" directly, this does not depend on
    the current working directory.

    Returns:
        The value produced by ``utilities.read_config`` for that path.
    """
    return utilities.read_config(get_data("deepforest_config.yml"))
Beispiel #8
0
                local_annotations = future.result()
            except Exception as e:
                print("future {} failed with {}".format(future, e))

if __name__ == "__main__":

    # Generate anchor objects for each image and wrap them in tfrecords.
    # DEBUG switches between local development paths and the cluster paths.
    DEBUG = False

    # Number of images per tfrecord.
    SIZE = 50

    if DEBUG:
        # Local machine: no dask cluster.
        BASE_PATH = FILEPATH = "/Users/ben/Documents/DeepForest_Model/"
        BENCHMARK_PATH = "/Users/ben/Documents/NeonTreeEvaluation/"
        dask_client = None
    else:
        # Cluster: spin up dask workers for the generation jobs.
        BASE_PATH = FILEPATH = "/orange/ewhite/b.weinstein/NeonTreeEvaluation/"
        BENCHMARK_PATH = "/home/b.weinstein/NeonTreeEvaluation/"
        dask_client = start_dask_cluster(number_of_workers=15, mem_size="5GB")

    config = read_config("deepforest_config.yml")

    # Alternative generation jobs; uncomment to run them instead.
    # generate_hand_annotations(DEBUG, BASE_PATH, FILEPATH, SIZE, config, dask_client)
    # generate_pretraining(DEBUG, BASE_PATH, FILEPATH, SIZE, config, dask_client)
    generate_benchmark(DEBUG, BENCHMARK_PATH, BENCHMARK_PATH, SIZE, config, dask_client)