예제 #1
0
    def test_list_artifacts(self, databricks_artifact_repo):
        """list_artifacts() converts the REST FileInfo protos into FileInfo entities.

        Directory entries are reported with a ``file_size`` of ``None`` even though the
        mocked proto carries 0, and listing a path that resolves to a single file
        yields an empty list.
        """
        list_artifact_file_proto_mock = [
            FileInfo(path="a.txt", is_dir=False, file_size=0)
        ]
        list_artifacts_dir_proto_mock = [
            FileInfo(path="test/a.txt", is_dir=False, file_size=100),
            FileInfo(path="test/dir", is_dir=True, file_size=0),
        ]
        with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY +
                        "._call_endpoint") as call_endpoint_mock:
            list_artifact_response_proto = ListArtifacts.Response(
                root_uri="",
                files=list_artifacts_dir_proto_mock,
                next_page_token=None)
            call_endpoint_mock.return_value = list_artifact_response_proto
            artifacts = databricks_artifact_repo.list_artifacts("test/")
            assert isinstance(artifacts, list)
            # Check the length before indexing so a short result fails with a clear
            # assertion rather than an IndexError.
            assert len(artifacts) == 2
            assert isinstance(artifacts[0], FileInfoEntity)
            assert artifacts[0].path == "test/a.txt"
            assert artifacts[0].is_dir is False
            assert artifacts[0].file_size == 100
            assert artifacts[1].path == "test/dir"
            assert artifacts[1].is_dir is True
            assert artifacts[1].file_size is None

            # Calling list_artifacts() on a path that's a file should return an empty list
            list_artifact_response_proto = ListArtifacts.Response(
                root_uri="",
                files=list_artifact_file_proto_mock,
                next_page_token=None)
            call_endpoint_mock.return_value = list_artifact_response_proto
            artifacts = databricks_artifact_repo.list_artifacts("a.txt")
            assert len(artifacts) == 0
예제 #2
0
    def test_download_artifacts_provides_failure_info(self, databricks_artifact_repo):
        """When every per-file cloud download fails, one MlflowException reports them all.

        Both mocked downloads raise; the resulting error message must contain the run's
        root URI together with each failing file's path and its original error text.
        """
        credentials_patcher = mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY + "._get_read_credential_infos"
        )
        listing_patcher = mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + ".list_artifacts")
        cloud_patcher = mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + "._download_from_cloud")
        with credentials_patcher as credentials_mock, listing_patcher as listing_mock, \
                cloud_patcher as cloud_download_mock:
            listing_mock.return_value = [
                FileInfo(path="file_1.txt", is_dir=False, file_size=100),
                FileInfo(path="file_2.txt", is_dir=False, file_size=0),
            ]
            # Two identical read credentials, matching the two listed files.
            credentials_mock.return_value = [
                ArtifactCredentialInfo(
                    signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
                )
                for _ in range(2)
            ]
            # Successive download attempts fail with distinct, recognisable messages.
            cloud_download_mock.side_effect = [
                MlflowException("MOCK ERROR 1"),
                MlflowException("MOCK ERROR 2"),
            ]

            with pytest.raises(MlflowException) as raised:
                databricks_artifact_repo.download_artifacts("test_path")

        message = str(raised.value)
        expected_fragments = (
            MOCK_RUN_ROOT_URI,
            "file_1.txt",
            "MOCK ERROR 1",
            "file_2.txt",
            "MOCK ERROR 2",
        )
        for fragment in expected_fragments:
            assert fragment in message
예제 #3
0
 def to_proto(self):
     """Serialize this entity into a ``ProtoFileInfo`` message.

     ``path`` and ``is_dir`` are always copied. ``file_size`` is copied whenever it
     is known (not ``None``), so a zero-byte file keeps an explicit size of 0
     instead of being emitted with the size field left unset.
     """
     proto = ProtoFileInfo()
     proto.path = self.path
     proto.is_dir = self.is_dir
     # Compare against None rather than relying on truthiness: 0 is a valid size
     # for an empty file and must not be dropped from the message.
     if self.file_size is not None:
         proto.file_size = self.file_size
     return proto
예제 #4
0
    def test_list_artifacts_with_relative_path(self):
        """list_artifacts() on a repository rooted at a run subdirectory.

        Request paths are sent to the REST API prefixed with ``MOCK_SUBDIR``, the
        returned entries are reported relative to that subdirectory again, and
        listing a relative path that resolves to a single file yields an empty list.
        """
        list_artifact_file_proto_mock = [
            FileInfo(path=posixpath.join(MOCK_SUBDIR, "a.txt"),
                     is_dir=False,
                     file_size=0)
        ]
        list_artifacts_dir_proto_mock = [
            FileInfo(path=posixpath.join(MOCK_SUBDIR, "test/a.txt"),
                     is_dir=False,
                     file_size=100),
            FileInfo(path=posixpath.join(MOCK_SUBDIR, "test/dir"),
                     is_dir=True,
                     file_size=0),
        ]
        with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY +
                        "._get_run_artifact_root"
                        ) as get_run_artifact_root_mock, mock.patch(
                            DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE +
                            ".message_to_json") as message_mock, mock.patch(
                                DATABRICKS_ARTIFACT_REPOSITORY +
                                "._call_endpoint") as call_endpoint_mock:
            get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
            list_artifact_response_proto = ListArtifacts.Response(
                root_uri="",
                files=list_artifacts_dir_proto_mock,
                next_page_token=None)
            call_endpoint_mock.return_value = list_artifact_response_proto
            message_mock.return_value = None
            databricks_artifact_repo = get_artifact_repository(
                MOCK_SUBDIR_ROOT_URI)
            artifacts = databricks_artifact_repo.list_artifacts("test")
            assert isinstance(artifacts, list)
            # Check the length before indexing so a short result fails with a clear
            # assertion rather than an IndexError.
            assert len(artifacts) == 2
            assert isinstance(artifacts[0], FileInfoEntity)
            assert artifacts[0].path == "test/a.txt"
            assert artifacts[0].is_dir is False
            assert artifacts[0].file_size == 100
            assert artifacts[1].path == "test/dir"
            assert artifacts[1].is_dir is True
            assert artifacts[1].file_size is None
            message_mock.assert_called_with(
                ListArtifacts(run_id=MOCK_RUN_ID,
                              path=posixpath.join(MOCK_SUBDIR, "test")))

            # Calling list_artifacts() on a relative path that's a file should return an empty list
            list_artifact_response_proto = ListArtifacts.Response(
                root_uri="",
                files=list_artifact_file_proto_mock,
                next_page_token=None)
            call_endpoint_mock.return_value = list_artifact_response_proto
            artifacts = databricks_artifact_repo.list_artifacts("a.txt")
            assert len(artifacts) == 0
            # The file lookup must also be issued against the subdirectory-prefixed path.
            message_mock.assert_called_with(
                ListArtifacts(run_id=MOCK_RUN_ID,
                              path=posixpath.join(MOCK_SUBDIR, "a.txt")))
예제 #5
0
    def test_download_artifacts_awaits_download_completion(self, databricks_artifact_repo, tmpdir):
        """
        Verifies that all asynchronous artifact downloads are joined before `download_artifacts()`
        returns a result to the caller.

        The mocked cloud download sleeps before writing its destination file, so if
        `download_artifacts()` returned without waiting, the files would likely not be
        fully written yet when the assertions run.
        """
        with mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY + "._get_read_credential_infos"
        ) as read_credential_infos_mock, mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY + ".list_artifacts"
        ) as get_list_mock, mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY + "._download_from_cloud"
        ) as download_mock:
            read_credential_infos_mock.return_value = [
                ArtifactCredentialInfo(
                    signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
                ),
                ArtifactCredentialInfo(
                    signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
                ),
            ]
            get_list_mock.return_value = [
                FileInfo(path="file_1.txt", is_dir=False, file_size=100),
                FileInfo(path="file_2.txt", is_dir=False, file_size=0),
            ]

            def mock_download_from_cloud(
                cloud_credential_info, dst_local_file_path
            ):  # pylint: disable=unused-argument
                # Sleep in order to simulate a longer-running asynchronous download
                time.sleep(2)
                with open(dst_local_file_path, "w") as f:
                    f.write("content")

            download_mock.side_effect = mock_download_from_cloud

            dst_dir = str(tmpdir)
            databricks_artifact_repo.download_artifacts("test_path", dst_dir)

            # Every listed file must exist with its full content once the call returns.
            for file_name in ("file_1.txt", "file_2.txt"):
                path = os.path.join(dst_dir, file_name)
                assert os.path.exists(path)
                with open(path, "r") as f:
                    assert f.read() == "content"
 def test_list_artifacts_with_relative_path(self):
     """list_artifacts() on a repository rooted at a run subdirectory.

     The REST request path is prefixed with ``MOCK_SUBDIR`` and the returned
     entries are reported relative to that subdirectory again; the directory
     entry comes back with a ``file_size`` of ``None``.
     """
     list_artifacts_dir_proto_mock = [
         FileInfo(path=posixpath.join(MOCK_SUBDIR, 'test/a.txt'),
                  is_dir=False,
                  file_size=100),
         FileInfo(path=posixpath.join(MOCK_SUBDIR, 'test/dir'),
                  is_dir=True,
                  file_size=0),
     ]
     with mock.patch(
             DATABRICKS_ARTIFACT_REPOSITORY + '._get_run_artifact_root'
     ) as get_run_artifact_root_mock, mock.patch(
             DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + '.message_to_json'
     ) as message_mock, mock.patch(
             DATABRICKS_ARTIFACT_REPOSITORY + '._call_endpoint'
     ) as call_endpoint_mock:
         get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
         list_artifact_response_proto = ListArtifacts.Response(
             root_uri='', files=list_artifacts_dir_proto_mock, next_page_token=None)
         call_endpoint_mock.return_value = list_artifact_response_proto
         message_mock.return_value = None
         databricks_artifact_repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI)
         artifacts = databricks_artifact_repo.list_artifacts('test')
         assert isinstance(artifacts, list)
         # Check the length before indexing so a short result fails with a clear
         # assertion rather than an IndexError.
         assert len(artifacts) == 2
         assert isinstance(artifacts[0], FileInfoEntity)
         assert artifacts[0].path == 'test/a.txt'
         assert artifacts[0].is_dir is False
         assert artifacts[0].file_size == 100
         assert artifacts[1].path == 'test/dir'
         assert artifacts[1].is_dir is True
         assert artifacts[1].file_size is None
         message_mock.assert_called_with(
             ListArtifacts(run_id=MOCK_RUN_ID,
                           path=posixpath.join(MOCK_SUBDIR, "test")))
예제 #7
0
 def test_paginated_list_artifacts(self, databricks_artifact_repo):
     """list_artifacts() follows ``next_page_token`` across several response pages.

     Four pages are mocked; the last one is empty but still carries a token. Only
     four responses exist, so a fifth request would exhaust the mock. The entries
     from the first three pages must be returned in page order, and every follow-up
     request must carry the token from the previous response.
     """
     list_artifacts_proto_mock_1 = [
         FileInfo(path="a.txt", is_dir=False, file_size=100),
         FileInfo(path="b", is_dir=True, file_size=0),
     ]
     list_artifacts_proto_mock_2 = [
         FileInfo(path="c.txt", is_dir=False, file_size=100),
         FileInfo(path="d", is_dir=True, file_size=0),
     ]
     list_artifacts_proto_mock_3 = [
         FileInfo(path="e.txt", is_dir=False, file_size=100),
         FileInfo(path="f", is_dir=True, file_size=0),
     ]
     list_artifacts_proto_mock_4 = []
     with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE +
                     ".message_to_json") as message_mock, mock.patch(
                         DATABRICKS_ARTIFACT_REPOSITORY +
                         "._call_endpoint") as call_endpoint_mock:
         list_artifact_paginated_response_protos = [
             ListArtifacts.Response(root_uri="",
                                    files=list_artifacts_proto_mock_1,
                                    next_page_token="2"),
             ListArtifacts.Response(root_uri="",
                                    files=list_artifacts_proto_mock_2,
                                    next_page_token="4"),
             ListArtifacts.Response(root_uri="",
                                    files=list_artifacts_proto_mock_3,
                                    next_page_token="6"),
             ListArtifacts.Response(root_uri="",
                                    files=list_artifacts_proto_mock_4,
                                    next_page_token="8"),
         ]
         call_endpoint_mock.side_effect = list_artifact_paginated_response_protos
         message_mock.return_value = None
         artifacts = databricks_artifact_repo.list_artifacts()
         # Compare as an ordered list rather than a set so that duplicated or
         # reordered entries (e.g. a page processed twice) are caught.
         assert [info.path for info in artifacts] == [
             "a.txt", "b", "c.txt", "d", "e.txt", "f"
         ]
         calls = [
             mock.call(ListArtifacts(run_id=MOCK_RUN_ID, path="")),
             mock.call(
                 ListArtifacts(run_id=MOCK_RUN_ID, path="",
                               page_token="2")),
             mock.call(
                 ListArtifacts(run_id=MOCK_RUN_ID, path="",
                               page_token="4")),
             mock.call(
                 ListArtifacts(run_id=MOCK_RUN_ID, path="",
                               page_token="6")),
         ]
         message_mock.assert_has_calls(calls)
 def test_paginated_list_artifacts(self, databricks_artifact_repo):
     """list_artifacts() follows ``next_page_token`` across several response pages.

     Four pages are mocked; the last one is empty but still carries a token. Only
     four responses exist, so a fifth request would exhaust the mock. The entries
     from the first three pages must be returned in page order, and every follow-up
     request must carry the token from the previous response.
     """
     list_artifacts_proto_mock_1 = [
         FileInfo(path='a.txt', is_dir=False, file_size=100),
         FileInfo(path='b', is_dir=True, file_size=0)
     ]
     list_artifacts_proto_mock_2 = [
         FileInfo(path='c.txt', is_dir=False, file_size=100),
         FileInfo(path='d', is_dir=True, file_size=0)
     ]
     list_artifacts_proto_mock_3 = [
         FileInfo(path='e.txt', is_dir=False, file_size=100),
         FileInfo(path='f', is_dir=True, file_size=0)
     ]
     list_artifacts_proto_mock_4 = []
     with mock.patch(
             DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + '.message_to_json'
     ) as message_mock, mock.patch(
             DATABRICKS_ARTIFACT_REPOSITORY + '._call_endpoint'
     ) as call_endpoint_mock:
         list_artifact_paginated_response_protos = [
             ListArtifacts.Response(root_uri='',
                                    files=list_artifacts_proto_mock_1,
                                    next_page_token='2'),
             ListArtifacts.Response(root_uri='',
                                    files=list_artifacts_proto_mock_2,
                                    next_page_token='4'),
             ListArtifacts.Response(root_uri='',
                                    files=list_artifacts_proto_mock_3,
                                    next_page_token='6'),
             ListArtifacts.Response(root_uri='',
                                    files=list_artifacts_proto_mock_4,
                                    next_page_token='8'),
         ]
         call_endpoint_mock.side_effect = list_artifact_paginated_response_protos
         message_mock.return_value = None
         artifacts = databricks_artifact_repo.list_artifacts()
         # Compare as an ordered list rather than a set so that duplicated or
         # reordered entries (e.g. a page processed twice) are caught.
         assert [info.path for info in artifacts] == [
             'a.txt', 'b', 'c.txt', 'd', 'e.txt', 'f'
         ]
         calls = [
             mock.call(ListArtifacts(run_id=MOCK_RUN_ID, path="")),
             mock.call(
                 ListArtifacts(run_id=MOCK_RUN_ID, path="",
                               page_token='2')),
             mock.call(
                 ListArtifacts(run_id=MOCK_RUN_ID, path="",
                               page_token='4')),
             mock.call(
                 ListArtifacts(run_id=MOCK_RUN_ID, path="", page_token='6'))
         ]
         message_mock.assert_has_calls(calls)