def test_download_artifacts_calls_expected_gcs_client_methods(gcs_mock, tmpdir):
    """Downloading one artifact targets the expected bucket/blob and produces the local file."""
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)

    # Fake for ``download_to_filename``: writes a file with the same basename into ``tmpdir``.
    # Extra keyword arguments (e.g. a ``timeout``) are accepted and ignored so the fake keeps
    # working if the repository passes download options alongside the local path.
    def mkfile(fname, **kwargs):  # pylint: disable=unused-argument
        fname = os.path.basename(fname)
        f = tmpdir.join(fname)
        f.write("hello world!")

    mock_method_chain(
        gcs_mock,
        [
            "Client",
            "bucket",
            "blob",
            "download_to_filename",
        ],
        side_effect=mkfile,
    )

    repo.download_artifacts("test.txt")

    assert os.path.exists(os.path.join(tmpdir.strpath, "test.txt"))
    gcs_mock.Client().bucket.assert_called_with("test_bucket")
    gcs_mock.Client().bucket().blob.assert_called_with("some/path/test.txt")
    download_calls = gcs_mock.Client().bucket().blob().download_to_filename.call_args_list
    assert len(download_calls) == 1
    # First positional argument of the single download call is the local destination path.
    download_path_arg = download_calls[0][0][0]
    assert "test.txt" in download_path_arg
def test_log_artifact(gcs_mock, tmpdir):
    """Logging a single file uploads it to the expected bucket/blob from its local path."""
    # NOTE(review): another ``test_log_artifact`` is defined later in this source. If both live
    # in the same module, the later definition rebinds the name and pytest only collects that
    # one, so this version never runs — confirm and drop or rename one of them.
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)
    d = tmpdir.mkdir("data")
    f = d.join("test.txt")
    f.write("hello world!")
    fpath = d + "/test.txt"
    fpath = fpath.strpath

    # Fake for ``upload_from_filename`` that fails the test unless it is given the path of an
    # existing file. A bare ``os.path.isfile`` side effect would only return a bool that the
    # test never inspects, so a wrong path would pass silently.
    def assert_is_file(path):
        assert os.path.isfile(path), "upload_from_filename got a non-file path: %r" % (path,)
        return True

    mock_method_chain(
        gcs_mock,
        [
            "Client",
            "bucket",
            "blob",
            "upload_from_filename",
        ],
        side_effect=assert_is_file,
    )

    repo.log_artifact(fpath)

    gcs_mock.Client().bucket.assert_called_with("test_bucket")
    gcs_mock.Client().bucket().blob.assert_called_with("some/path/test.txt")
    gcs_mock.Client().bucket().blob().upload_from_filename.assert_called_with(fpath)
def test_log_artifacts(gcs_mock, tmpdir):
    """Logging a directory uploads every file it contains from its local path."""
    # NOTE(review): another ``test_log_artifacts`` is defined later in this source. If both live
    # in the same module, the later definition rebinds the name and pytest only collects that
    # one, so this version never runs — confirm and drop or rename one of them.
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)
    subd = tmpdir.mkdir("data").mkdir("subdir")
    subd.join("a.txt").write("A")
    subd.join("b.txt").write("B")
    subd.join("c.txt").write("C")

    # Fake for ``upload_from_filename`` that fails the test unless it is given the path of an
    # existing file; a bare ``os.path.isfile`` side effect would return a bool nobody checks.
    def assert_is_file(path):
        assert os.path.isfile(path), "upload_from_filename got a non-file path: %r" % (path,)
        return True

    mock_method_chain(
        gcs_mock,
        [
            "Client",
            "bucket",
            "blob",
            "upload_from_filename",
        ],
        side_effect=assert_is_file,
    )

    repo.log_artifacts(subd.strpath)

    gcs_mock.Client().bucket.assert_called_with("test_bucket")
    gcs_mock.Client().bucket().blob().upload_from_filename.assert_has_calls(
        [
            mock.call(os.path.normpath("%s/a.txt" % subd.strpath)),
            mock.call(os.path.normpath("%s/b.txt" % subd.strpath)),
            mock.call(os.path.normpath("%s/c.txt" % subd.strpath)),
        ],
        any_order=True,
    )
def test_log_artifact(gcs_mock, tmpdir):
    """Logging a single file uploads it with the repo's chunk size and default timeout."""
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)
    d = tmpdir.mkdir("data")
    f = d.join("test.txt")
    f.write("hello world!")
    fpath = d + "/test.txt"
    fpath = fpath.strpath

    # Fake for ``upload_from_filename`` that fails the test unless the path it receives (either
    # positionally or as ``filename=``) names an existing file. Extra keyword arguments such as
    # ``timeout`` are tolerated. Returning the bare ``isfile`` result, as before, let a wrong
    # path pass silently because nothing asserted on it.
    def assert_uploaded_path_is_file(*args, **kwargs):
        path = args[0] if args else kwargs.get("filename")
        assert path is not None and os.path.isfile(path), (
            "upload_from_filename got a non-file path: %r" % (path,)
        )
        return True

    mock_method_chain(
        gcs_mock,
        [
            "Client",
            "bucket",
            "blob",
            "upload_from_filename",
        ],
        side_effect=assert_uploaded_path_is_file,
    )

    repo.log_artifact(fpath)

    gcs_mock.Client().bucket.assert_called_with("test_bucket")
    gcs_mock.Client().bucket().blob.assert_called_with(
        "some/path/test.txt", chunk_size=repo._GCS_UPLOAD_CHUNK_SIZE
    )
    gcs_mock.Client().bucket().blob().upload_from_filename.assert_called_with(
        fpath, timeout=repo._GCS_DEFAULT_TIMEOUT
    )
def test_delete_artifacts(gcs_mock):
    """After ``delete_artifacts``, a previously listed artifact no longer appears in listings."""
    experiment_root_path = "/experiment_id/"
    repo = GCSArtifactRepository("gs://test_bucket" + experiment_root_path, gcs_mock)

    # Side effect for the fake blob's ``delete``: strip the attributes that ``get_mock_listing``
    # checks, so every listing produced afterwards is empty.
    def delete_file():
        del obj_mock.name
        del obj_mock.size
        return obj_mock

    obj_mock = mock.Mock()
    run_id_path = experiment_root_path + "run_id/"
    file_path = "file"
    attrs = {
        "name": run_id_path + file_path,
        "size": 1,
        "delete.side_effect": delete_file,
    }
    obj_mock.configure_mock(**attrs)

    def get_mock_listing(prefix, **kwargs):
        """
        Produces a mock listing that yields the single fake blob while it still exists, and
        an empty listing once ``delete_file`` has removed its ``name``/``size`` attributes.
        ``prefix`` and any other arguments are ignored.
        """
        # pylint: disable=unused-argument
        blob_still_exists = hasattr(obj_mock, "name") and hasattr(obj_mock, "size")
        mock_results = mock.MagicMock()
        mock_results.__iter__.return_value = [obj_mock] if blob_still_exists else []
        return mock_results

    mock_method_chain(
        gcs_mock,
        [
            "Client",
            "bucket",
            "list_blobs",
        ],
        side_effect=get_mock_listing,
    )

    artifact_file_names = [obj.path for obj in repo.list_artifacts()]
    assert "run_id/file" in artifact_file_names
    repo.delete_artifacts()
    artifact_file_names = [obj.path for obj in repo.list_artifacts()]
    assert not artifact_file_names
def test_get_workspace_info_from_dbutils():
    """Host and workspace ID are taken from the mocked dbutils notebook context accessors."""
    dbutils_mock = mock.MagicMock()
    context_chain = ["notebook.entry_point.getDbutils", "notebook", "getContext"]

    # Each context accessor ends in ``.get()``; wire up the value it should return.
    accessor_values = (
        ("browserHostName", "mlflow.databricks.com"),
        ("workspaceId", "1111"),
    )
    for accessor, value in accessor_values:
        mock_method_chain(dbutils_mock, context_chain + [accessor, "get"], return_value=value)

    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=dbutils_mock):
        host, workspace = get_workspace_info_from_dbutils()
        assert host == "https://mlflow.databricks.com"
        assert workspace == "1111"
def test_get_workspace_info_from_dbutils_old_runtimes():
    """Without a ``workspaceId`` accessor, host and workspace ID still resolve from other context data."""
    dbutils_mock = mock.MagicMock()
    context_chain = ["notebook.entry_point.getDbutils", "notebook", "getContext"]

    def chain_to(*leaf_methods):
        # Full method path from the dbutils mock down to the given context accessor(s).
        return context_chain + list(leaf_methods)

    mock_method_chain(
        dbutils_mock,
        chain_to("toJson", "get"),
        return_value='{"tags": {"orgId" : "1111", "browserHostName": "mlflow.databricks.com"}}',
    )
    mock_method_chain(
        dbutils_mock, chain_to("browserHostName", "get"), return_value="mlflow.databricks.com"
    )

    # Option-like tag value: reports itself as defined and yields the workspace ID.
    workspace_id_tag = mock.MagicMock(
        **{"isDefined.return_value": True, "get.return_value": "1111"}
    )
    mock_method_chain(dbutils_mock, chain_to("tags", "get"), return_value=workspace_id_tag)

    # Old runtimes lack ``workspaceId``; simulate that by making the call raise.
    mock_method_chain(
        dbutils_mock,
        chain_to("workspaceId"),
        side_effect=Exception("workspaceId method not defined!"),
    )

    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=dbutils_mock):
        host, workspace = get_workspace_info_from_dbutils()
        assert host == "https://mlflow.databricks.com"
        assert workspace == "1111"
def test_log_artifacts(gcs_mock, tmpdir):
    """Logging a directory uploads every contained file with the repo's default timeout."""
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)
    subd = tmpdir.mkdir("data").mkdir("subdir")
    subd.join("a.txt").write("A")
    subd.join("b.txt").write("B")
    subd.join("c.txt").write("C")

    # Fake for ``upload_from_filename`` that fails the test unless the path it receives (either
    # positionally or as ``filename=``) names an existing file. Extra keyword arguments such as
    # ``timeout`` are tolerated. Returning the bare ``isfile`` result, as before, let a wrong
    # path pass silently because nothing asserted on it.
    def assert_uploaded_path_is_file(*args, **kwargs):
        path = args[0] if args else kwargs.get("filename")
        assert path is not None and os.path.isfile(path), (
            "upload_from_filename got a non-file path: %r" % (path,)
        )
        return True

    mock_method_chain(
        gcs_mock,
        [
            "Client",
            "bucket",
            "blob",
            "upload_from_filename",
        ],
        side_effect=assert_uploaded_path_is_file,
    )

    repo.log_artifacts(subd.strpath)

    gcs_mock.Client().bucket.assert_called_with("test_bucket")
    gcs_mock.Client().bucket().blob().upload_from_filename.assert_has_calls(
        [
            mock.call(
                os.path.normpath("%s/a.txt" % subd.strpath), timeout=repo._GCS_DEFAULT_TIMEOUT
            ),
            mock.call(
                os.path.normpath("%s/b.txt" % subd.strpath), timeout=repo._GCS_DEFAULT_TIMEOUT
            ),
            mock.call(
                os.path.normpath("%s/c.txt" % subd.strpath), timeout=repo._GCS_DEFAULT_TIMEOUT
            ),
        ],
        any_order=True,
    )
def test_download_artifacts_downloads_expected_content(gcs_mock, tmpdir):
    """Downloading the artifact root fetches every blob listed directly under it."""
    artifact_root_path = "/experiment_id/run_id/"
    repo = GCSArtifactRepository("gs://test_bucket" + artifact_root_path, gcs_mock)

    obj_mock_1 = mock.Mock()
    file_path_1 = "file1"
    obj_mock_1.configure_mock(name=os.path.join(artifact_root_path, file_path_1), size=1)
    obj_mock_2 = mock.Mock()
    file_path_2 = "file2"
    obj_mock_2.configure_mock(name=os.path.join(artifact_root_path, file_path_2), size=1)

    mock_populated_results = mock.MagicMock()
    mock_populated_results.__iter__.return_value = [obj_mock_1, obj_mock_2]

    mock_empty_results = mock.MagicMock()
    mock_empty_results.__iter__.return_value = []

    def get_mock_listing(prefix, **kwargs):
        """
        Produces a mock listing that only contains content if the specified prefix is the
        artifact root. This allows us to mock `list_artifacts` during the
        `_download_artifacts_into` subroutine without recursively listing the same artifacts at
        every level of the directory traversal.
        """
        # pylint: disable=unused-argument
        prefix = os.path.join("/", prefix)
        if os.path.abspath(prefix) == os.path.abspath(artifact_root_path):
            return mock_populated_results
        return mock_empty_results

    # Fake for ``download_to_filename``: writes a file with the same basename into ``tmpdir``.
    # Extra keyword arguments (e.g. a ``timeout``) are accepted and ignored so the fake keeps
    # working if the repository passes download options alongside the local path.
    def mkfile(fname, **kwargs):  # pylint: disable=unused-argument
        fname = os.path.basename(fname)
        f = tmpdir.join(fname)
        f.write("hello world!")

    mock_method_chain(
        gcs_mock,
        [
            "Client",
            "bucket",
            "list_blobs",
        ],
        side_effect=get_mock_listing,
    )
    mock_method_chain(
        gcs_mock,
        [
            "Client",
            "bucket",
            "blob",
            "download_to_filename",
        ],
        side_effect=mkfile,
    )

    # Ensure that the root directory can be downloaded successfully
    repo.download_artifacts("")
    # Ensure that the `mkfile` side effect copied all of the download artifacts into `tmpdir`
    dir_contents = os.listdir(tmpdir.strpath)
    assert file_path_1 in dir_contents
    assert file_path_2 in dir_contents