def test_list_registered_models(mock_get_request_message, mock_model_registry_store):
    """The list handler forwards a no-argument call to the store and serializes its result."""
    mock_get_request_message.return_value = ListRegisteredModels()
    expected_models = [
        RegisteredModelDetailed(name="model_1", creation_timestamp=111,
                                last_updated_timestamp=222, description="Test model",
                                latest_versions=[]),
        RegisteredModelDetailed(name="model_2", creation_timestamp=111,
                                last_updated_timestamp=333, description="Another model",
                                latest_versions=[]),
    ]
    store_method = mock_model_registry_store.list_registered_models
    store_method.return_value = expected_models

    response = _list_registered_models()

    # The store method takes no positional arguments.
    positional_args, _ = store_method.call_args
    assert positional_args == ()
    payload = json.loads(response.get_data())
    assert payload == {"registered_models_detailed": jsonify(expected_models)}
def get_registered_model_details(self, registered_model):
    """
    Fetch the detailed view of a registered model through ``self._call_endpoint``.

    :param registered_model: :py:class:`mlflow.entities.model_registry.RegisteredModel` object
                             identifying the model to look up.
    :return: A single :py:class:`mlflow.entities.model_registry.RegisteredModelDetailed` object.
    """
    request = GetRegisteredModelDetails(registered_model=registered_model.to_proto())
    response_proto = self._call_endpoint(GetRegisteredModelDetails, message_to_json(request))
    detailed_proto = response_proto.registered_model_detailed
    return RegisteredModelDetailed.from_proto(detailed_proto)
def list_registered_models(self):
    """
    List all registered models.

    Sends an empty ``ListRegisteredModels`` request through ``self._call_endpoint`` and
    converts every entry of the response's ``registered_models_detailed`` field.

    :return: List of :py:class:`mlflow.entities.model_registry.RegisteredModelDetailed` objects.
    """
    req_body = message_to_json(ListRegisteredModels())
    response_proto = self._call_endpoint(ListRegisteredModels, req_body)
    return [RegisteredModelDetailed.from_proto(registered_model_detailed)
            for registered_model_detailed in response_proto.registered_models_detailed]
def test_get_registered_model_details(mock_store):
    """The client returns exactly the detailed model produced by the store."""
    model_name = "Model 1"
    versions = [
        _model_version_detailed(model_name, 3, "None"),
        _model_version_detailed(model_name, 2, "Staging"),
        _model_version_detailed(model_name, 1, "Production"),
    ]
    mock_store.get_registered_model_details.return_value = RegisteredModelDetailed(
        model_name, "1263283747835", "1283168374623874", "I am a model", versions)

    client = newModelRegistryClient()
    result = client.get_registered_model_details(model_name)

    mock_store.get_registered_model_details.assert_called_once()
    assert result.name == model_name
    assert len(result.latest_versions) == 3
def test_get_registered_model_details(mock_get_request_message, mock_model_registry_store):
    """The details handler passes the requested model to the store and serializes the result."""
    registered_model = RegisteredModel("model1")
    mock_get_request_message.return_value = GetRegisteredModelDetails(
        registered_model=registered_model.to_proto())
    detailed = RegisteredModelDetailed(name="model_1", creation_timestamp=111,
                                       last_updated_timestamp=222, description="Test model",
                                       latest_versions=[])
    store_method = mock_model_registry_store.get_registered_model_details
    store_method.return_value = detailed

    response = _get_registered_model_details()

    # The handler should rebuild the RegisteredModel from the request and pass it positionally.
    positional_args, _ = store_method.call_args
    assert positional_args == (registered_model, )
    payload = json.loads(response.get_data())
    assert payload == {"registered_model_detailed": jsonify(detailed)}
def to_mlflow_detailed_entity(self):
    """
    Build a ``RegisteredModelDetailed`` holding, per stage, the highest-numbered version.

    ``self.model_versions`` contains every version of this model, so it is reduced to one
    entry per ``current_stage``. On equal ``version`` values the version seen first is kept.
    Stages appear in the order they are first encountered.
    """
    newest_by_stage = {}
    for candidate in self.model_versions:
        stage = candidate.current_stage
        incumbent = newest_by_stage.get(stage)
        if incumbent is None or incumbent.version < candidate.version:
            newest_by_stage[stage] = candidate
    return RegisteredModelDetailed(
        name=self.name,
        creation_timestamp=self.creation_time,
        last_updated_timestamp=self.last_updated_time,
        description=self.description,
        latest_versions=[version.to_mlflow_detailed_entity()
                         for version in newest_by_stage.values()],
    )