def test_get_loaded_booster(model_info):
    """Check that the reported model format matches the expected one.

    ``model_info`` unpacks to a model directory name (looked up under
    ``RESOURCES_PATH/models``) and the format expected for that directory:

    'pickled_model' directory has a model dumped using pickle module
    'saved_booster' directory has a model saved using booster.save_model()
    """
    dir_name, expected_format = model_info
    model_path = os.path.join(RESOURCES_PATH, 'models', dir_name)
    # Only the reported format is verified here; the booster object itself
    # is not inspected.
    _, detected_format = serve_utils.get_loaded_booster(model_path)
    assert detected_format == expected_format
Example #2
0
 def default_model_fn(self, model_dir):
     """Load the model stored in ``model_dir``.

     Loading is delegated to ``serve_utils.get_loaded_booster``. Any
     exception raised while loading (or while unpacking its result) is
     re-raised as ``ModelLoadInferenceError`` carrying the original message.

     Args:
         model_dir: a directory where model is saved.
     Returns:
         A XGBoost model.
         XGBoost model format type.
     Raises:
         ModelLoadInferenceError: if the model cannot be loaded.
     """
     try:
         booster, model_format = serve_utils.get_loaded_booster(model_dir)
     except Exception as load_error:
         # Surface every loading failure under a single error type.
         message = "Unable to load model: {}".format(str(load_error))
         raise ModelLoadInferenceError(message)
     else:
         return booster, model_format
Example #3
0
 def load_model(cls):
     """Load the booster on first use and return the model format.

     When ``cls.booster`` is already set, nothing is reloaded and the
     previously stored ``cls.format`` is returned. Otherwise the booster and
     its format are read from ``ScoringService.MODEL_PATH`` via
     ``serve_utils.get_loaded_booster`` and stored on the class.

     Returns:
         The value of ``cls.format``.
     """
     if cls.booster is not None:
         # Already loaded; reuse the stored format.
         return cls.format

     booster, model_format = serve_utils.get_loaded_booster(
         ScoringService.MODEL_PATH)
     cls.booster = booster
     cls.format = model_format
     return cls.format
Example #4
0
 def load_model(cls, ensemble=True):
     """Load the booster on first use and return the model format.

     When ``cls.booster`` is already set, nothing is reloaded and the
     previously stored ``cls.format`` is returned. Otherwise the booster and
     its format are read from ``ScoringService.MODEL_PATH`` via
     ``serve_utils.get_loaded_booster`` and stored on the class, after which
     ``cls.get_config_json()`` is called.

     Args:
         ensemble: forwarded unchanged as the second argument to
             ``serve_utils.get_loaded_booster``.
     Returns:
         The value of ``cls.format``.
     """
     if cls.booster is not None:
         # Already loaded; reuse the stored format.
         return cls.format

     booster, model_format = serve_utils.get_loaded_booster(
         ScoringService.MODEL_PATH, ensemble)
     cls.booster = booster
     cls.format = model_format
     # NOTE(review): presumably reads the model's configuration once the
     # booster is available -- confirm against get_config_json().
     cls.get_config_json()
     return cls.format