Ejemplo n.º 1
0
def test_download_artifacts_calls_expected_gcs_client_methods(
        gcs_mock, tmpdir):
    """
    Download a single artifact and check the calls recorded on the mocked
    GCS client: the bucket name, the blob name, and exactly one
    ``download_to_filename`` call whose path argument mentions ``test.txt``.
    """
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)

    def mkfile(fname):
        # Stand-in for the download: write a file with the same basename
        # into tmpdir, regardless of the directory that was requested.
        tmpdir.join(os.path.basename(fname)).write("hello world!")

    download_chain = ["Client", "bucket", "blob", "download_to_filename"]
    mock_method_chain(gcs_mock, download_chain, side_effect=mkfile)

    repo.download_artifacts("test.txt")

    expected_local_file = os.path.join(tmpdir.strpath, "test.txt")
    assert os.path.exists(expected_local_file)

    # Each assertion runs before the next level of the chain is called, so
    # the calls made here never overwrite the call being checked.
    client = gcs_mock.Client()
    client.bucket.assert_called_with("test_bucket")
    bucket = client.bucket()
    bucket.blob.assert_called_with("some/path/test.txt")
    blob = bucket.blob()

    download_calls = blob.download_to_filename.call_args_list
    assert len(download_calls) == 1
    # call[0] is the positional-args tuple; its first entry is the target path.
    positional_args = download_calls[0][0]
    download_path_arg = positional_args[0]
    assert "test.txt" in download_path_arg
Ejemplo n.º 2
0
def test_log_artifact(gcs_mock, tmpdir):
    """
    Log one local file and check that the mocked GCS client was asked for
    bucket ``test_bucket``, blob ``some/path/test.txt``, and an upload of
    the local file path.
    """
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)

    data_dir = tmpdir.mkdir("data")
    data_dir.join("test.txt").write("hello world!")
    fpath = (data_dir + "/test.txt").strpath

    # With os.path.isfile as the side effect, the mocked upload calls
    # isfile on whatever argument the code under test passes to it, so the
    # upload is exercised with an actual file path.
    upload_chain = ["Client", "bucket", "blob", "upload_from_filename"]
    mock_method_chain(gcs_mock, upload_chain, side_effect=os.path.isfile)

    repo.log_artifact(fpath)

    # Assert on each level before calling it to reach the next one.
    client = gcs_mock.Client()
    client.bucket.assert_called_with("test_bucket")
    bucket = client.bucket()
    bucket.blob.assert_called_with("some/path/test.txt")
    bucket.blob().upload_from_filename.assert_called_with(fpath)
Ejemplo n.º 3
0
def test_log_artifacts(gcs_mock, tmpdir):
    """
    Log a directory holding three files and check that the mocked GCS
    client received one ``upload_from_filename`` call per file, in any
    order.
    """
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)

    subd = tmpdir.mkdir("data").mkdir("subdir")
    for filename, content in (("a.txt", "A"), ("b.txt", "B"), ("c.txt", "C")):
        subd.join(filename).write(content)

    # os.path.isfile is invoked with the same arguments as each mocked upload.
    upload_chain = ["Client", "bucket", "blob", "upload_from_filename"]
    mock_method_chain(gcs_mock, upload_chain, side_effect=os.path.isfile)

    repo.log_artifacts(subd.strpath)

    client = gcs_mock.Client()
    client.bucket.assert_called_with("test_bucket")

    expected_uploads = [
        mock.call(os.path.normpath(template % subd.strpath))
        for template in ("%s/a.txt", "%s/b.txt", "%s/c.txt")
    ]
    upload_mock = client.bucket().blob().upload_from_filename
    upload_mock.assert_has_calls(expected_uploads, any_order=True)
Ejemplo n.º 4
0
def test_log_artifact(gcs_mock, tmpdir):
    """
    Log one local file and check the calls recorded on the mocked GCS
    client, including the ``chunk_size`` keyword passed to ``blob`` and the
    ``timeout`` keyword passed to ``upload_from_filename``.
    """
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)

    data_dir = tmpdir.mkdir("data")
    data_dir.join("test.txt").write("hello world!")
    fpath = (data_dir + "/test.txt").strpath

    def custom_isfile(*args, **kwargs):
        # The upload may receive the path positionally or as ``filename=``;
        # run isfile on whichever form was used, ignoring other keywords.
        target = args[0] if args else kwargs.get("filename")
        return os.path.isfile(target)

    upload_chain = ["Client", "bucket", "blob", "upload_from_filename"]
    mock_method_chain(gcs_mock, upload_chain, side_effect=custom_isfile)

    repo.log_artifact(fpath)

    # Assert on each level before calling it to reach the next one.
    client = gcs_mock.Client()
    client.bucket.assert_called_with("test_bucket")
    bucket = client.bucket()
    bucket.blob.assert_called_with(
        "some/path/test.txt", chunk_size=repo._GCS_UPLOAD_CHUNK_SIZE)
    bucket.blob().upload_from_filename.assert_called_with(
        fpath, timeout=repo._GCS_DEFAULT_TIMEOUT)
Ejemplo n.º 5
0
def test_delete_artifacts(gcs_mock):
    """
    List artifacts, delete them, and list again: the single mocked blob is
    expected to show up as ``run_id/file`` before ``delete_artifacts`` and
    the listing is expected to be empty afterwards.
    """
    experiment_root_path = "/experiment_id/"
    repo = GCSArtifactRepository("gs://test_bucket" + experiment_root_path,
                                 gcs_mock)

    blob_mock = mock.Mock()
    blob_mock.name = experiment_root_path + "run_id/" + "file"
    blob_mock.size = 1

    def delete_file():
        # Dropping both attributes is how this test marks the blob as
        # deleted; get_mock_listing checks for them via hasattr.
        del blob_mock.name
        del blob_mock.size
        return blob_mock

    blob_mock.delete.side_effect = delete_file

    def get_mock_listing(prefix, **kwargs):
        """
        Produces a mock listing containing the single mocked blob for as
        long as that blob still has both its `name` and `size` attributes
        (i.e. until its `delete` side effect has run), and an empty listing
        afterwards. The requested prefix is ignored.
        """

        # pylint: disable=unused-argument
        still_present = hasattr(blob_mock, "name") and hasattr(blob_mock, "size")
        listing = mock.MagicMock()
        listing.__iter__.return_value = [blob_mock] if still_present else []
        return listing

    list_chain = ["Client", "bucket", "list_blobs"]
    mock_method_chain(gcs_mock, list_chain, side_effect=get_mock_listing)

    paths_before = [artifact.path for artifact in repo.list_artifacts()]
    assert "run_id/file" in paths_before

    repo.delete_artifacts()

    paths_after = [artifact.path for artifact in repo.list_artifacts()]
    assert not paths_after
Ejemplo n.º 6
0
def test_get_workspace_info_from_dbutils():
    """
    With a mocked dbutils whose context exposes ``browserHostName`` and
    ``workspaceId``, the workspace host is expected to come back with an
    ``https://`` prefix and the workspace ID unchanged.
    """
    mock_dbutils = mock.MagicMock()
    context_chain = ["notebook.entry_point.getDbutils", "notebook", "getContext"]
    context_values = (
        ("browserHostName", "mlflow.databricks.com"),
        ("workspaceId", "1111"),
    )
    for accessor, value in context_values:
        mock_method_chain(mock_dbutils, context_chain + [accessor, "get"], return_value=value)

    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils):
        workspace_host, workspace_id = get_workspace_info_from_dbutils()

    assert workspace_host == "https://mlflow.databricks.com"
    assert workspace_id == "1111"
Ejemplo n.º 7
0
def test_get_workspace_info_from_dbutils_old_runtimes():
    """
    Same expectations as the regular dbutils test, but with the
    ``workspaceId`` accessor raising so that the workspace ID has to be
    obtained some other way (the JSON context and the ``tags`` option are
    both mocked to carry ``"1111"``).
    """
    mock_dbutils = mock.MagicMock()
    context_chain = ["notebook.entry_point.getDbutils", "notebook", "getContext"]

    context_json = '{"tags": {"orgId" : "1111", "browserHostName": "mlflow.databricks.com"}}'
    mock_method_chain(mock_dbutils, context_chain + ["toJson", "get"], return_value=context_json)
    mock_method_chain(
        mock_dbutils,
        context_chain + ["browserHostName", "get"],
        return_value="mlflow.databricks.com",
    )

    # Option-like mock for the workspace ID tag: defined, holding "1111".
    workspace_id_tag_opt = mock.MagicMock()
    workspace_id_tag_opt.isDefined.return_value = True
    workspace_id_tag_opt.get.return_value = "1111"
    mock_method_chain(
        mock_dbutils, context_chain + ["tags", "get"], return_value=workspace_id_tag_opt
    )

    # Old runtimes have no "workspaceId" method; mimic that by making the
    # mocked call raise.
    missing_method_error = Exception("workspaceId method not defined!")
    mock_method_chain(
        mock_dbutils, context_chain + ["workspaceId"], side_effect=missing_method_error
    )

    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils):
        workspace_host, workspace_id = get_workspace_info_from_dbutils()

    assert workspace_host == "https://mlflow.databricks.com"
    assert workspace_id == "1111"
Ejemplo n.º 8
0
def test_log_artifacts(gcs_mock, tmpdir):
    """
    Log a directory holding three files and check that the mocked GCS
    client received one ``upload_from_filename`` call per file, each with
    the repository's default timeout, in any order.
    """
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)

    subd = tmpdir.mkdir("data").mkdir("subdir")
    for filename, content in (("a.txt", "A"), ("b.txt", "B"), ("c.txt", "C")):
        subd.join(filename).write(content)

    def custom_isfile(*args, **kwargs):
        # The upload may receive the path positionally or as ``filename=``;
        # run isfile on whichever form was used, ignoring other keywords.
        target = args[0] if args else kwargs.get("filename")
        return os.path.isfile(target)

    upload_chain = ["Client", "bucket", "blob", "upload_from_filename"]
    mock_method_chain(gcs_mock, upload_chain, side_effect=custom_isfile)

    repo.log_artifacts(subd.strpath)

    client = gcs_mock.Client()
    client.bucket.assert_called_with("test_bucket")

    expected_uploads = [
        mock.call(os.path.normpath(template % subd.strpath),
                  timeout=repo._GCS_DEFAULT_TIMEOUT)
        for template in ("%s/a.txt", "%s/b.txt", "%s/c.txt")
    ]
    upload_mock = client.bucket().blob().upload_from_filename
    upload_mock.assert_has_calls(expected_uploads, any_order=True)
Ejemplo n.º 9
0
def test_download_artifacts_downloads_expected_content(gcs_mock, tmpdir):
    """
    Download the artifact root and check that a file for each of the two
    mocked blobs ends up in ``tmpdir``.
    """
    artifact_root_path = "/experiment_id/run_id/"
    repo = GCSArtifactRepository("gs://test_bucket" + artifact_root_path,
                                 gcs_mock)

    file_path_1 = "file1"
    file_path_2 = "file2"

    def make_blob(relative_path):
        # Mocked blob of size 1 whose name sits under the artifact root.
        blob = mock.Mock()
        blob.configure_mock(name=os.path.join(artifact_root_path,
                                              relative_path),
                            size=1)
        return blob

    mock_populated_results = mock.MagicMock()
    mock_populated_results.__iter__.return_value = [
        make_blob(file_path_1),
        make_blob(file_path_2),
    ]

    mock_empty_results = mock.MagicMock()
    mock_empty_results.__iter__.return_value = []

    def get_mock_listing(prefix, **kwargs):
        """
        Return the populated listing only when the requested prefix resolves
        to the artifact root, and the empty listing for any other prefix, so
        that listing below the root does not yield the same blobs again.
        """
        # pylint: disable=unused-argument
        requested = os.path.abspath(os.path.join("/", prefix))
        if requested != os.path.abspath(artifact_root_path):
            return mock_empty_results
        return mock_populated_results

    def mkfile(fname):
        # Stand-in for the download: write a file with the same basename
        # into tmpdir, regardless of the directory that was requested.
        tmpdir.join(os.path.basename(fname)).write("hello world!")

    list_chain = ["Client", "bucket", "list_blobs"]
    mock_method_chain(gcs_mock, list_chain, side_effect=get_mock_listing)

    download_chain = ["Client", "bucket", "blob", "download_to_filename"]
    mock_method_chain(gcs_mock, download_chain, side_effect=mkfile)

    # Ensure that the root directory can be downloaded successfully
    repo.download_artifacts("")

    # The `mkfile` side effect should have written every downloaded
    # artifact into `tmpdir`.
    dir_contents = os.listdir(tmpdir.strpath)
    for expected_name in (file_path_1, file_path_2):
        assert expected_name in dir_contents