Ejemplo n.º 1
0
def test_get_latest_versions(mock_get_request_message, mock_model_registry_store):
    """Exercise ``_get_latest_versions`` against a mocked registry store.

    Verifies that the ``name`` and ``stages`` carried by the request message
    reach ``store.get_latest_versions`` as keyword arguments, and that the
    versions returned by the store come back under the ``"model_versions"``
    key of the response payload.
    """
    name = "model1"
    mock_get_request_message.return_value = GetLatestVersions(name=name)

    # One row per fake model version:
    # (version, creation_timestamp, last_updated_timestamp, description,
    #  user_id, current_stage, source)
    version_rows = [
        ("5", 1, 12, "v 5", "u1", "Production", "A/B"),
        ("1", 1, 1200, "v 1", "u1", "Archived", "A/B2"),
        ("12", 100, None, "v 12", "u2", "Staging", "A/B3"),
    ]
    # Fields that are identical for every fake version.
    shared_fields = {"name": name, "status": "READY", "status_message": None}
    mvds = [
        ModelVersion(
            version=version,
            creation_timestamp=created_at,
            last_updated_timestamp=updated_at,
            description=description,
            user_id=user_id,
            current_stage=stage,
            source=source,
            run_id=uuid.uuid4().hex,
            **shared_fields
        )
        for version, created_at, updated_at, description, user_id, stage, source
        in version_rows
    ]
    store_call = mock_model_registry_store.get_latest_versions
    store_call.return_value = mvds

    # A request without stages must be forwarded with an empty stage list.
    resp = _get_latest_versions()
    forwarded_kwargs = store_call.call_args[1]  # index 1 -> keyword arguments
    assert forwarded_kwargs == {"name": name, "stages": []}
    payload = json.loads(resp.get_data())
    assert payload == {"model_versions": jsonify(mvds)}

    # Whatever stage filter the request carries must be passed through as-is.
    stage_filters = ([], ["None"], ["Staging"], ["Staging", "Production"])
    for stages in stage_filters:
        request_message = GetLatestVersions(name=name, stages=stages)
        mock_get_request_message.return_value = request_message
        _get_latest_versions()
        forwarded_kwargs = store_call.call_args[1]
        assert forwarded_kwargs == {"name": name, "stages": stages}
Ejemplo n.º 2
0
def test_get_latest_versions(mock_get_request_message,
                             mock_model_registry_store):
    """Exercise ``_get_latest_versions`` against a mocked registry store.

    Verifies that the registered model and ``stages`` carried by the request
    message reach ``store.get_latest_versions`` as positional arguments, and
    that the versions returned by the store come back under the
    ``"model_versions_detailed"`` key of the response payload.
    """
    rm = RegisteredModel("model1")
    mock_get_request_message.return_value = GetLatestVersions(
        registered_model=rm.to_proto())

    # One row per fake model version:
    # (version, creation_timestamp, last_updated_timestamp, description,
    #  user_id, current_stage, source)
    version_rows = [
        (5, 1, 12, "v 5", "u1", "Production", "A/B"),
        (1, 1, 1200, "v 1", "u1", "Archived", "A/B2"),
        (12, 100, None, "v 12", "u2", "Staging", "A/B3"),
    ]
    # Fields that are identical for every fake version.
    shared_fields = {
        "registered_model": rm,
        "status": "READY",
        "status_message": None,
    }
    mvds = [
        ModelVersionDetailed(
            version=version,
            creation_timestamp=created_at,
            last_updated_timestamp=updated_at,
            description=description,
            user_id=user_id,
            current_stage=stage,
            source=source,
            run_id=uuid.uuid4().hex,
            **shared_fields
        )
        for version, created_at, updated_at, description, user_id, stage, source
        in version_rows
    ]
    store_call = mock_model_registry_store.get_latest_versions
    store_call.return_value = mvds

    # A request without stages must be forwarded with an empty stage list.
    resp = _get_latest_versions()
    forwarded_args = store_call.call_args[0]  # index 0 -> positional arguments
    assert forwarded_args == (rm, [])
    payload = json.loads(resp.get_data())
    assert payload == {"model_versions_detailed": jsonify(mvds)}

    # Whatever stage filter the request carries must be passed through as-is.
    stage_filters = ([], ["None"], ["Staging"], ["Staging", "Production"])
    for stages in stage_filters:
        request_message = GetLatestVersions(
            registered_model=rm.to_proto(), stages=stages)
        mock_get_request_message.return_value = request_message
        _get_latest_versions()
        forwarded_args = store_call.call_args[0]
        assert forwarded_args == (rm, stages)