예제 #1
0
 def test_search_model_versions(self, mock_http):
     """Search model versions by name and check the outgoing REST request.

     The request check is delegated to ``self._verify_requests``. It is given the
     ``model-versions/search`` endpoint, the ``GET`` method and a
     ``SearchModelVersions`` message carrying the same filter string.
     """
     name_filter = "name='model_12'"
     self.store.search_model_versions(filter_string=name_filter)
     # Built after the store call, matching the order of the original test.
     expected_message = SearchModelVersions(filter=name_filter)
     self._verify_requests(
         mock_http, "model-versions/search", "GET", expected_message
     )
예제 #2
0
def test_search_model_versions(mock_get_request_message,
                               mock_model_registry_store):
    """The search handler forwards the request filter and serializes the results.

    The incoming request message is mocked to carry a source-path filter. The
    registry store is mocked to return four model versions. The test asserts
    that the store received the filter as its only positional argument, and
    that the response body equals ``{"model_versions": jsonify(<versions>)}``.
    """
    source_filter = "source_path = 'A/B/CD'"
    mock_get_request_message.return_value = SearchModelVersions(
        filter=source_filter)

    # Columns: name, version, creation_ts, last_updated_ts, description,
    # user_id, current_stage. Every version shares the same source and status.
    version_specs = [
        ("model_1", "5", 100, 1200, "v 5", "u1", "Production"),
        ("model_1", "12", 110, 2000, "v 12", "u2", "Production"),
        ("ads_model", "8", 200, 2000, "v 8", "u1", "Staging"),
        ("fraud_detection_model", "345", 1000, 1001, "newest version", "u12",
         "None"),
    ]
    mvds = [
        ModelVersion(
            name=model_name,
            version=version_number,
            creation_timestamp=created_at,
            last_updated_timestamp=updated_at,
            description=summary,
            user_id=owner,
            current_stage=stage,
            source="A/B/CD",
            run_id=uuid.uuid4().hex,
            status="READY",
            status_message=None,
        )
        for (model_name, version_number, created_at, updated_at, summary,
             owner, stage) in version_specs
    ]

    store_search = mock_model_registry_store.search_model_versions
    store_search.return_value = mvds

    resp = _search_model_versions()

    positional_args, _ = store_search.call_args
    assert positional_args == (source_filter, )
    assert json.loads(resp.get_data()) == {"model_versions": jsonify(mvds)}