def test_plugin_registration():
    """Registering a scheme routes URIs with that scheme to the registered factory."""
    plugin_factory = mock.Mock()
    registry = ArtifactRepositoryRegistry()
    uri = "mock-scheme://fake-host/fake-path"

    registry.register("mock-scheme", plugin_factory)
    # White-box check: the scheme must be stored in the registry's internal mapping.
    assert "mock-scheme" in registry._registry

    repo = registry.get_artifact_repository(artifact_uri=uri)

    # The repository handed back is whatever the factory produced for this URI.
    assert repo == plugin_factory.return_value
    plugin_factory.assert_called_once_with(uri)


def test_dbfs_instantiation():
    """A "dbfs" URI resolved with a RestStore is built from that store's credential getter."""
    registry = ArtifactRepositoryRegistry()
    dbfs_factory = mock.Mock()
    registry.register("dbfs", dbfs_factory)

    host_creds_getter = mock.Mock()
    store = RestStore(host_creds_getter)

    repo = registry.get_artifact_repository(artifact_uri="dbfs://test-path", store=store)

    assert repo == dbfs_factory.return_value
    # The factory receives the URI plus the exact callable RestStore was constructed with.
    dbfs_factory.assert_called_once_with("dbfs://test-path", host_creds_getter)


@pytest.mark.parametrize(
    "exception",
    # Entry-point load failures that should be downgraded to a warning; both carry the
    # message the ``pytest.warns(match=...)`` check below looks for.
    [AttributeError("test exception"), ImportError("test exception")],
)
def test_plugin_registration_failure_via_entrypoints(exception):
    """A plugin whose entry point fails to load yields a UserWarning rather than an error.

    ``exception`` is raised by the mocked entry point's ``load()``; its message must be
    surfaced in the emitted warning, and ``register_entrypoints`` must not propagate it.
    """
    mock_entrypoint = mock.Mock(load=mock.Mock(side_effect=exception))
    # ``name`` is reserved by the Mock constructor, so the attribute is assigned afterwards.
    mock_entrypoint.name = "mock-scheme"

    with mock.patch(
        "entrypoints.get_group_all", return_value=[mock_entrypoint]
    ) as mock_get_group_all:

        repo_registry = ArtifactRepositoryRegistry()

        # Check that the raised warning contains the message from the original exception
        with pytest.warns(UserWarning, match="test exception"):
            repo_registry.register_entrypoints()

    mock_entrypoint.load.assert_called_once()
    mock_get_group_all.assert_called_once_with("mlflow.artifact_repository")


def test_plugin_registration_via_entrypoints():
    """Plugins from the "mlflow.artifact_repository" entry-point group become usable schemes."""
    plugin_factory = mock.Mock()
    fake_entrypoint = mock.Mock(load=mock.Mock(return_value=plugin_factory))
    # ``name`` is reserved by the Mock constructor, so the attribute is assigned afterwards.
    fake_entrypoint.name = "mock-scheme"

    patcher = mock.patch("entrypoints.get_group_all", return_value=[fake_entrypoint])
    with patcher as patched_group_all:
        registry = ArtifactRepositoryRegistry()
        registry.register_entrypoints()

    # Discovery must have queried exactly the artifact-repository entry-point group.
    patched_group_all.assert_called_once_with("mlflow.artifact_repository")

    uri = "mock-scheme://fake-host/fake-path"
    repo = registry.get_artifact_repository(uri)
    assert repo == plugin_factory.return_value
    plugin_factory.assert_called_once_with(uri)


def test_incorrect_dbfs_instantiation():
    """Resolving a "dbfs" URI with a non-RestStore store raises before any repo is built."""
    registry = ArtifactRepositoryRegistry()
    dbfs_factory = mock.Mock()
    registry.register("dbfs", dbfs_factory)

    non_rest_store = SqlAlchemyStore("sqlite://", "./mlruns")
    expected_error = "must be an instance of RestStore"

    with pytest.raises(mlflow.exceptions.MlflowException, match=expected_error):
        registry.get_artifact_repository(artifact_uri="dbfs://test-path", store=non_rest_store)

    # The DBFS factory must never be invoked when the store type is rejected.
    dbfs_factory.assert_not_called()


def test_get_unknown_scheme():
    """Looking up a URI whose scheme was never registered raises MlflowException."""
    registry = ArtifactRepositoryRegistry()
    expected_error = "Could not find a registered artifact repository"

    with pytest.raises(mlflow.exceptions.MlflowException, match=expected_error):
        registry.get_artifact_repository("unknown-scheme://")