Example #1
0
File: model.py  Project: timroz24/h1st
    def load(self, version: str = None):
        """
        Restore this model's state from the ModelRepository.

        :param version: version identifier to load; leave as ``None`` to load the latest version
        :returns: this model instance, so the call can be chained
        """
        ModelRepository.get_model_repo(self).load(model=self, version=version)
        return self
Example #2
0
    def load(self, version: str = None) -> "Model":
        """
        Populate this model's parameters from the ModelRepository.

        :param version: the version to load; leave as ``None`` to load the latest version
        :returns: ``self``, allowing fluent chaining after the load
        """
        model_repository = ModelRepository.get_model_repo(self)
        model_repository.load(model=self, version=version)
        return self
Example #3
0
    def persist(self, version=None):
        """
        Save this model's properties to the ModelRepository.

        Only the `stats`, `metrics` and `model` properties are currently persisted.
        The `model` property may hold a single model, or a list or dict of models;
        only sklearn and tensorflow-keras models are currently supported.

        :param version: model version; leave blank to have one generated automatically
        :returns: the model version the repository persisted under
        """
        return ModelRepository.get_model_repo(self).persist(model=self, version=version)
Example #4
0
    def test_serialize_sklearn_model(self):
        """Round-trip a fitted sklearn model through a local ModelRepository.

        A freshly constructed ``MyModel`` already holds an (unfitted)
        ``LogisticRegression``. Checking only the type of ``model_2.model``
        would therefore pass even if ``load`` restored nothing. The test also
        checks that the loaded estimator is fitted and predicts exactly like
        the persisted one.
        """
        class MyModel(Model):
            def __init__(self):
                super().__init__()
                self.model = LogisticRegression(random_state=0)

            def train(self, prepared_data):
                X, y = prepared_data['X'], prepared_data['y']
                self.model.fit(X, y)

        X, y = load_iris(return_X_y=True)
        prepared_data = {'X': X, 'y': y}

        model = MyModel()
        model.train(prepared_data)
        # Reference predictions from the original (persisted) estimator.
        expected = model.model.predict(X)

        with tempfile.TemporaryDirectory() as path:
            mm = ModelRepository(storage=LocalStorage(storage_path=path))
            version = mm.persist(model=model)

            model_2 = MyModel()
            mm.load(model=model_2, version=version)

            assert 'sklearn' in str(type(model_2.model))
            # `coef_` is only set by `fit`, so this fails if `load` left the
            # freshly constructed, unfitted estimator in place.
            assert hasattr(model_2.model, 'coef_')
            assert (model_2.model.predict(X) == expected).all()
Example #5
0
 def test_explicit_init_model_repo(self):
     """Check that ``init(MODEL_REPO_PATH=...)`` configures the model repository's storage path.

     The repository is held in a ``MODEL_REPO`` class attribute on
     ``ModelRepository``. The test removes that attribute afterwards so the
     configured repository does not leak into other tests. Cleanup runs in
     ``finally`` so it also happens when the assertion fails.
     """
     from h1st.model_repository import ModelRepository
     try:
         init(MODEL_REPO_PATH=".model")
         mm = ModelRepository.get_model_repo()
         assert mm._storage.storage_path == ".model"
     finally:
         # Reset shared class-level state even if the steps above fail. The
         # hasattr guard keeps an AttributeError here from masking the
         # original failure.
         if hasattr(ModelRepository, 'MODEL_REPO'):
             delattr(ModelRepository, 'MODEL_REPO')
Example #6
0
 def init_model_repo(cls, repo_path):
     """Create the class-level ``ModelRepository.MODEL_REPO`` on first call; later calls do nothing.

     :param repo_path: value passed as ``storage`` to the ``ModelRepository`` constructor
     """
     from h1st.model_repository import ModelRepository
     if hasattr(ModelRepository, 'MODEL_REPO'):
         # A repository is already configured; keep it and ignore repo_path.
         return
     ModelRepository.MODEL_REPO = ModelRepository(storage=repo_path)