def test_download_file_artifact_succeeds_when_artifact_root_is_s3_bucket_root(
        s3_artifact_root, tmpdir):
    """A file logged at the S3 bucket root round-trips through download."""
    artifact_name = "a.txt"
    artifact_text = "A"
    local_path = os.path.join(str(tmpdir), artifact_name)
    with open(local_path, "w") as handle:
        handle.write(artifact_text)

    repo = get_artifact_repository(s3_artifact_root)
    repo.log_artifact(local_path)

    downloaded_path = repo.download_artifacts(artifact_name)
    with open(downloaded_path, "r") as handle:
        assert handle.read() == artifact_text
def test_file_artifact_is_logged_and_downloaded_successfully(
        s3_artifact_root, tmpdir):
    """A file logged under a sub-path of the S3 root round-trips through
    download with identical content."""
    file_name = "test.txt"
    file_path = os.path.join(str(tmpdir), file_name)
    file_text = "Hello world!"
    with open(file_path, "w") as f:
        f.write(file_text)

    repo = get_artifact_repository(
        posixpath.join(s3_artifact_root, "some/path"))
    repo.log_artifact(file_path)

    # Fix: the original read via bare open(...).read(), leaking the file
    # handle; use a context manager so it is closed deterministically.
    with open(repo.download_artifacts(file_name)) as f:
        downloaded_text = f.read()
    assert downloaded_text == file_text
def test_list_artifacts_with_relative_path(self):
    """list_artifacts() on a repo rooted at a run-relative subdirectory
    returns FileInfo entities with paths relative to that root, and
    listing a path that is a file returns an empty list."""
    # Backend-style proto FileInfos: paths are absolute under MOCK_SUBDIR.
    list_artifact_file_proto_mock = [
        FileInfo(path=posixpath.join(MOCK_SUBDIR, 'a.txt'), is_dir=False, file_size=0)]
    list_artifacts_dir_proto_mock = [
        FileInfo(path=posixpath.join(MOCK_SUBDIR, 'test/a.txt'),
                 is_dir=False, file_size=100),
        FileInfo(path=posixpath.join(MOCK_SUBDIR, 'test/dir'),
                 is_dir=True, file_size=0)]
    with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_run_artifact_root') \
            as get_run_artifact_root_mock, \
            mock.patch(DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE
                       + '.message_to_json') as message_mock, \
            mock.patch(DATABRICKS_ARTIFACT_REPOSITORY
                       + '._call_endpoint') as call_endpoint_mock:
        get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
        # First endpoint response: a directory listing for 'test'.
        list_artifact_response_proto = \
            ListArtifacts.Response(root_uri='', files=list_artifacts_dir_proto_mock,
                                   next_page_token=None)
        call_endpoint_mock.return_value = list_artifact_response_proto
        message_mock.return_value = None
        databricks_artifact_repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI)
        artifacts = databricks_artifact_repo.list_artifacts('test')
        assert isinstance(artifacts, list)
        assert isinstance(artifacts[0], FileInfoEntity)
        assert len(artifacts) == 2
        # Returned paths are stripped of the MOCK_SUBDIR prefix.
        assert artifacts[0].path == 'test/a.txt'
        assert artifacts[0].is_dir is False
        assert artifacts[0].file_size == 100
        assert artifacts[1].path == 'test/dir'
        assert artifacts[1].is_dir is True
        # Directories report no file size.
        assert artifacts[1].file_size is None
        # The request sent to the backend must use the full run-relative path.
        message_mock.assert_called_with(
            ListArtifacts(run_id=MOCK_RUN_ID,
                          path=posixpath.join(MOCK_SUBDIR, "test")))
        # Calling list_artifacts() on a relative path that's a file should
        # return an empty list
        list_artifact_response_proto = \
            ListArtifacts.Response(root_uri='', files=list_artifact_file_proto_mock,
                                   next_page_token=None)
        call_endpoint_mock.return_value = list_artifact_response_proto
        artifacts = databricks_artifact_repo.list_artifacts('a.txt')
        assert len(artifacts) == 0
def log_artifact(self, run_id, local_path, artifact_path=None):
    """
    Write a local file or directory to the remote ``artifact_uri``.

    :param run_id: ID of the run whose artifact root receives the data.
    :param local_path: Path to the file or directory to write.
    :param artifact_path: If provided, the directory in ``artifact_uri``
                          to write to.
    """
    run = self.get_run(run_id)
    artifact_repo = get_artifact_repository(run.info.artifact_uri)
    # Plain files go straight through; directories keep their base name
    # under the optional destination prefix.
    if not os.path.isdir(local_path):
        artifact_repo.log_artifact(local_path, artifact_path)
        return
    dir_name = os.path.basename(os.path.normpath(local_path))
    if artifact_path is None:
        dest_path = dir_name
    else:
        dest_path = os.path.join(artifact_path, dir_name)
    artifact_repo.log_artifacts(local_path, dest_path)
def test_download_directory_artifact_succeeds_when_artifact_root_is_s3_bucket_root(
        s3_artifact_root, tmpdir):
    """A directory tree logged at the bucket root can have a nested
    subdirectory downloaded with its files intact."""
    leaf_name = "a.txt"
    leaf_text = "A"
    subdir = str(tmpdir.mkdir("subdir"))
    nested = os.path.join(subdir, "nested")
    os.makedirs(nested)
    with open(os.path.join(nested, leaf_name), "w") as handle:
        handle.write(leaf_text)

    repo = get_artifact_repository(s3_artifact_root)
    repo.log_artifacts(subdir)

    downloaded_dir = repo.download_artifacts("nested")
    assert leaf_name in os.listdir(downloaded_dir)
    with open(os.path.join(downloaded_dir, leaf_name), "r") as handle:
        assert handle.read() == leaf_text
def test_plugin_registration_via_installed_package():
    """This test requires the package in tests/resources/mlflow-test-plugin
    to be installed."""
    reload(artifact_repository_registry)

    registered = artifact_repository_registry._artifact_repository_registry._registry
    assert "file-plugin" in registered

    from mlflow_test_plugin.local_artifact import PluginLocalArtifactRepository

    test_uri = "file-plugin:test-path"
    plugin_repo = artifact_repository_registry.get_artifact_repository(test_uri)
    assert isinstance(plugin_repo, PluginLocalArtifactRepository)
    assert plugin_repo.is_plugin
def test_init_validation_and_cleaning(self):
    """A well-formed dbfs run-artifact URI is accepted and parsed into run
    fields; malformed URIs raise MlflowException."""
    with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_run_artifact_root') \
            as get_run_artifact_root_mock:
        get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
        # Basic artifact uri
        valid_uri = 'dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts'
        repo = get_artifact_repository(valid_uri)
        assert repo.artifact_uri == valid_uri
        assert repo.run_id == MOCK_RUN_ID
        assert repo.run_relative_artifact_repo_root_path == ""
        # Non-dbfs schemes are rejected.
        with pytest.raises(MlflowException):
            DatabricksArtifactRepository('s3://test')
        # dbfs paths outside mlflow-tracking are rejected too.
        with pytest.raises(MlflowException):
            DatabricksArtifactRepository('dbfs:/databricks/mlflow/EXP/RUN/artifact')
def test_file_artifact_is_logged_with_content_metadata(s3_artifact_root, tmpdir):
    """Uploading a .txt artifact attaches an inferred Content-Type and no
    Content-Encoding to the S3 object."""
    file_name = "test.txt"
    file_path = os.path.join(str(tmpdir), file_name)
    file_text = "Hello world!"
    with open(file_path, "w") as handle:
        handle.write(file_text)

    repo = get_artifact_repository(
        posixpath.join(s3_artifact_root, "some/path"))
    repo.log_artifact(file_path)

    bucket, _ = repo.parse_s3_uri(s3_artifact_root)
    head = repo._get_s3_client().head_object(
        Bucket=bucket, Key="some/path/test.txt")
    assert head.get("ContentType") == "text/plain"
    assert head.get("ContentEncoding") is None
def download_artifacts(self, run_id, path, dst_path=None):
    """
    Download an artifact file or directory from a run to a local directory
    if applicable, and return a local path for it.

    :param run_id: The run to download artifacts from.
    :param path: Relative source path to the desired artifact.
    :param dst_path: Absolute path of the local filesystem destination
                     directory to which to download the specified artifacts.
                     This directory must already exist. If unspecified, the
                     artifacts will either be downloaded to a new
                     uniquely-named directory on the local filesystem or will
                     be returned directly in the case of the
                     LocalArtifactRepository.
    :return: Local path of desired artifact.
    """
    # Resolve the run's artifact root, then delegate to its repository.
    artifact_root = self.get_run(run_id).info.artifact_uri
    repo = get_artifact_repository(artifact_root)
    return repo.download_artifacts(path, dst_path)
def test_plugin_registration_via_entrypoints():
    """A scheme registered through entrypoints is loadable via the registry,
    and the loaded factory is invoked with the full artifact URI."""
    plugin_factory = mock.Mock()
    entrypoint = mock.Mock(load=mock.Mock(return_value=plugin_factory))
    entrypoint.name = "mock-scheme"

    with mock.patch(
        "entrypoints.get_group_all", return_value=[entrypoint]
    ) as get_group_all_mock:
        registry = ArtifactRepositoryRegistry()
        registry.register_entrypoints()

        uri = "mock-scheme://fake-host/fake-path"
        assert registry.get_artifact_repository(uri) == plugin_factory.return_value
        plugin_factory.assert_called_once_with(uri)
        get_group_all_mock.assert_called_once_with("mlflow.artifact_repository")
def download_artifacts(run_id, artifact_path, artifact_uri):
    """
    Download an artifact file or directory to a local directory. The output is
    the name of the file or directory on the local disk.

    Either ``--run-id`` or ``--artifact-uri`` must be provided.
    """
    if run_id is None and artifact_uri is None:
        _logger.error(
            "Either ``--run-id`` or ``--artifact-uri`` must be provided.")
        sys.exit(1)

    if artifact_uri is not None:
        # Direct-URI mode: no tracking-store lookup needed.
        print(_download_artifact_from_uri(artifact_uri))
        return

    # Run-ID mode: resolve the run's artifact root through the tracking store.
    relative_path = artifact_path if artifact_path is not None else ""
    store = _get_store()
    root_uri = store.get_run(run_id).info.artifact_uri
    repo = get_artifact_repository(root_uri)
    print(repo.download_artifacts(relative_path))
def test_file_and_directories_artifacts_are_logged_and_downloaded_successfully_in_batch(
        s3_artifact_root, tmpdir):
    """Batch-logged files and directories round-trip through S3, whether
    downloaded as individual files, a nested directory, or the root."""

    def read_text(path):
        # Fix: the original used bare open(...).read() four times, leaking
        # file handles; read through a context manager instead.
        with open(path, "r") as f:
            return f.read()

    subdir_path = str(tmpdir.mkdir("subdir"))
    nested_path = os.path.join(subdir_path, "nested")
    os.makedirs(nested_path)
    with open(os.path.join(subdir_path, "a.txt"), "w") as f:
        f.write("A")
    with open(os.path.join(subdir_path, "b.txt"), "w") as f:
        f.write("B")
    with open(os.path.join(nested_path, "c.txt"), "w") as f:
        f.write("C")

    repo = get_artifact_repository(
        posixpath.join(s3_artifact_root, "some/path"))
    repo.log_artifacts(subdir_path)

    # Download individual files and verify correctness of their contents
    assert read_text(repo.download_artifacts("a.txt")) == "A"
    assert read_text(repo.download_artifacts("b.txt")) == "B"
    assert read_text(repo.download_artifacts("nested/c.txt")) == "C"

    # Download the nested directory and verify correctness of its contents
    downloaded_dir = repo.download_artifacts("nested")
    assert os.path.basename(downloaded_dir) == "nested"
    assert read_text(os.path.join(downloaded_dir, "c.txt")) == "C"

    # Download the root directory and verify correctness of its contents
    downloaded_dir = repo.download_artifacts("")
    dir_contents = os.listdir(downloaded_dir)
    assert "nested" in dir_contents
    assert os.path.isdir(os.path.join(downloaded_dir, "nested"))
    assert "a.txt" in dir_contents
    assert "b.txt" in dir_contents
def dbfs_fuse_artifact_repo(force_dbfs_fuse_repo):  # pylint: disable=unused-argument
    """Fixture: a DBFS FUSE repository; the path itself is never used because
    the force_dbfs_fuse_repo fixture mocks the backend."""
    return get_artifact_repository('dbfs:/unused/path/replaced/by/mock')
def test_artifact_uri_factory():
    """An ftp:// URI resolves to an FTPArtifactRepository."""
    ftp_uri = "ftp://*****:*****@test_ftp:123/some/path"
    assert isinstance(get_artifact_repository(ftp_uri), FTPArtifactRepository)
def dbfs_artifact_repo():
    """Fixture: a DBFS repository with host-credential resolution stubbed."""
    creds_target = ('mlflow.store.artifact.dbfs_artifact_repo'
                    '._get_host_creds_from_default_store')
    with mock.patch(creds_target) as get_creds_mock:
        get_creds_mock.return_value = lambda: MlflowHostCreds('http://host')
        return get_artifact_repository('dbfs:/test/')
def __init__(self, artifact_uri):
    """Resolve a runs:/ URI and delegate to the underlying repository.

    :param artifact_uri: The runs:/-scheme artifact URI to resolve.
    """
    # Imported locally to avoid a circular import with the registry module.
    from kiwi.store.artifact.artifact_repository_registry import get_artifact_repository

    underlying_uri = RunsArtifactRepository.get_underlying_uri(artifact_uri)
    super(RunsArtifactRepository, self).__init__(artifact_uri)
    # All artifact operations are forwarded to this backing repository.
    self.repo = get_artifact_repository(underlying_uri)
def test_artifact_uri_factory():
    """Constructing an SFTP repo opens a connection, which fails here with
    SSHException because the host is unreachable."""
    from paramiko.ssh_exception import SSHException

    with pytest.raises(SSHException):
        get_artifact_repository("sftp://*****:*****@test_sftp:123/some/path")
def _get_artifact_repo(run):
    """Return the artifact repository rooted at the given run's artifact URI."""
    return get_artifact_repository(run.info.artifact_uri)
def test_artifact_uri_factory():
    """A gs:// URI resolves to a GCSArtifactRepository."""
    gcs_uri = "gs://test_bucket/some/path"
    assert isinstance(get_artifact_repository(gcs_uri), GCSArtifactRepository)
def databricks_artifact_repo():
    """Fixture: a Databricks artifact repository with the run-artifact-root
    lookup mocked out."""
    with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_run_artifact_root') \
            as get_run_artifact_root_mock:
        get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
        return get_artifact_repository(
            "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts")
def test_get_unknown_scheme():
    """Looking up an unregistered URI scheme raises a descriptive error."""
    registry = ArtifactRepositoryRegistry()
    with pytest.raises(kiwi.exceptions.MlflowException,
                       match="Could not find a registered artifact repository"):
        registry.get_artifact_repository("unknown-scheme://")