예제 #1
0
def test_get_model_version_download_uri(mock_get_request_message, mock_model_registry_store):
    """Check that the download-URI handler queries the store by name/version and returns its URI.

    The mocked request message carries the model ``name`` and ``version``. The handler must
    call the store's ``get_model_version_download_uri`` exactly once with those values as
    keyword arguments. It must then return the store's result under ``artifact_uri`` in
    the JSON response body.
    """
    name = "model1"
    version = "32"
    mock_get_request_message.return_value = GetModelVersionDownloadUri(name=name, version=version)
    mock_model_registry_store.get_model_version_download_uri.return_value = "some/download/path"
    resp = _get_model_version_download_uri()
    # assert_called_once_with verifies the call count as well as the exact call signature
    # (no positional arguments; keyword arguments name/version only). A duplicated or
    # missing store call therefore fails the test instead of passing unnoticed.
    mock_model_registry_store.get_model_version_download_uri.assert_called_once_with(
        name=name, version=version
    )
    assert json.loads(resp.get_data()) == {"artifact_uri": "some/download/path"}
예제 #2
0
def test_get_model_version_download_uri(mock_get_request_message, mock_model_registry_store):
    """Check that the download-URI handler forwards the requested model version to the store.

    The mocked request message carries ``model_version`` as a proto. The store call must
    receive a ``ModelVersion`` equal to the one that proto was built from, as its sole
    positional argument. The JSON response body must contain the store's return value
    under ``artifact_uri``.
    """
    download_path = "some/download/path"
    model_version = ModelVersion(registered_model=RegisteredModel("model1"), version=32)
    mock_get_request_message.return_value = GetModelVersionDownloadUri(
        model_version=model_version.to_proto()
    )

    store_method = mock_model_registry_store.get_model_version_download_uri
    store_method.return_value = download_path

    response = _get_model_version_download_uri()

    # call_args is an (args, kwargs) pair; index 0 holds the positional arguments.
    positional_args = store_method.call_args[0]
    assert positional_args == (model_version,)
    assert json.loads(response.get_data()) == {"artifact_uri": download_path}