def use_release(self, gpus=1):
    """Use the latest DeepForest model release from github and load model.

    Optionally download if release doesn't exist. The release weights path is
    stored on ``self.weights`` (and copied into ``self.config["weights"]``),
    and the release tag on ``self.__release_version__``.

    Args:
        gpus (int): number of gpus to parallelize, default to 1. Must be >= 1.

    Returns:
        None. With ``gpus == 1`` the trained keras model is stored on
        ``self.model`` and its converted inference model on
        ``self.prediction_model``. With ``gpus > 1`` the three models returned
        by ``create_models`` are stored on ``self.model``,
        ``self.training_model`` and ``self.prediction_model``.

    Raises:
        ValueError: if ``gpus`` is less than 1.
    """
    # Validate before downloading. A value < 1 matched neither branch below
    # and left the model unloaded without any error.
    if gpus < 1:
        raise ValueError("gpus must be >= 1, got {}".format(gpus))

    # Download latest model from github release
    release_tag, self.weights = utilities.use_release()

    # Tag the release
    self.__release_version__ = release_tag
    print("Loading pre-built model: {}".format(release_tag))

    if gpus == 1:
        with warnings.catch_warnings():
            # Suppress compile warning, not relevant here
            warnings.filterwarnings("ignore", category=UserWarning)
            self.model = utilities.read_model(self.weights, self.config)

        # Convert model
        self.prediction_model = convert_model(self.model)
    else:
        # gpus > 1: rebuild the architecture from the configured backbone and
        # initialise it from the release weights across multiple gpus.
        backbone = models.backbone(self.config["backbone"])
        n_classes = len(self.labels)
        self.model, self.training_model, self.prediction_model = create_models(
            backbone.retinanet,
            num_classes=n_classes,
            weights=self.weights,
            multi_gpu=gpus)

    # add to config
    self.config["weights"] = self.weights
def use_release(self):
    """Load the newest DeepForest release from github into this object.

    The release is fetched through ``utilities.use_release`` (optionally
    downloading it if it does not exist). The resulting weight path is kept
    on ``self.weights``, and the model read from those weights with
    ``self.config`` is kept on ``self.model``. Nothing is returned.
    """
    # Fetch (or reuse) the release weights, then build the model from them.
    self.weights = utilities.use_release()
    self.model = utilities.read_model(self.weights, self.config)
def use_release(self):
    """Use the latest DeepForest model release from github and load model.

    Optionally download if release doesn't exist. The release weights
    location returned by ``utilities.use_release`` is stored on
    ``self.release_state_dict``. Those weights are loaded into the existing
    PyTorch ``self.model`` (mapped onto ``self.device``), and the release tag
    is stored on ``self.__release_version__``.

    Returns:
        None. ``self.model`` is updated in place with the release weights.
    """
    # Download latest model from github release
    release_tag, self.release_state_dict = utilities.use_release()

    # Load the saved release weights into the current model.
    # NOTE(review): torch.load unpickles the file; only load trusted release files.
    self.model.load_state_dict(
        torch.load(self.release_state_dict, map_location=self.device))

    # Tag release
    self.__release_version__ = release_tag
    print("Loading pre-built model: {}".format(release_tag))
def use_release(self, check_release=True):
    """Use the latest DeepForest model release from github and load model.

    Optionally download if release doesn't exist.

    Args:
        check_release (bool): whether to check github for a more recent model
            release. If you are hitting the github API rate limit, set to
            False to skip the check and use a previously downloaded local
            model instead. If no model has been downloaded, an error is
            raised.

    Returns:
        None. The release weights are loaded into the existing PyTorch
        ``self.model``. The weights location is stored on
        ``self.release_state_dict`` and the release tag on
        ``self.__release_version__``.
    """
    # Download latest model from github release (or reuse the local copy)
    release_tag, self.release_state_dict = utilities.use_release(
        check_release=check_release)

    # Load the saved release weights into the current model.
    # NOTE(review): torch.load unpickles the file; only load trusted release files.
    self.model.load_state_dict(
        torch.load(self.release_state_dict, map_location=self.device))

    # Tag release
    self.__release_version__ = release_tag
    print("Loading pre-built model: {}".format(release_tag))
def test_use_release(download_release):
    """After fetching the latest release, the NEON.pt model file is present."""
    # Fetch the latest github release; the returned (tag, state dict) pair is
    # unpacked but not otherwise inspected.
    _tag, _weights = utilities.use_release()
    model_file = get_data("NEON.pt")
    assert os.path.exists(model_file)
def test_use_release():
    """After fetching the latest release, the NEON.h5 model file is present."""
    # Fetch the latest github release; the returned (tag, weights) pair is
    # unpacked but not otherwise inspected.
    _tag, _weights = utilities.use_release()
    model_file = get_data("NEON.h5")
    assert os.path.exists(model_file)
def download_release():
    """Make sure the latest model release has been fetched before tests run."""
    print("running fixtures")
    # Return value is intentionally ignored; only the fetch side effect matters.
    _ = utilities.use_release()
def download_release():
    """Fetch the latest model release into the tests/data directory."""
    print("running fixtures")
    # Return value is intentionally ignored; only the fetch side effect matters.
    _ = utilities.use_release(save_dir="tests/data")
def download_release():
    """Fetch the latest model release and check NEON.pt is available."""
    print("running fixtures")
    _ = utilities.use_release()
    model_file = get_data("NEON.pt")
    assert os.path.exists(model_file)
def test_use_release(download_release):
    """use_release(check_release=False) works from the locally stored release.

    Previously this test made no assertion and could never fail on a missing
    model. It now checks the NEON.pt model file is present, mirroring the
    check_release=True variant of this test.
    """
    import os

    # Skip the github API release check (avoids rate limits) and rely on the
    # release already fetched by the download_release fixture.
    release_tag, state_dict = utilities.use_release(check_release=False)
    assert os.path.exists(get_data("NEON.pt"))