def test_runs_artifact_repo_uses_repo_download_artifacts(): """ The RunsArtifactRepo should delegate `download_artifacts` to it's self.repo.download_artifacts function """ artifact_location = "s3://blah_bucket/" experiment_id = mlflow.create_experiment("expr_abcd", artifact_location) with mlflow.start_run(experiment_id=experiment_id): run_id = mlflow.active_run().info.run_id runs_repo = RunsArtifactRepository('runs:/{}'.format(run_id)) runs_repo.repo = Mock() runs_repo.download_artifacts('artifact_path', 'dst_path') runs_repo.repo.download_artifacts.assert_called_once()
def test_runs_artifact_repo_init(): artifact_location = "s3://blah_bucket/" experiment_id = mlflow.create_experiment("expr_abc", artifact_location) with mlflow.start_run(experiment_id=experiment_id): run_id = mlflow.active_run().info.run_id runs_uri = "runs:/%s/path/to/model" % run_id runs_repo = RunsArtifactRepository(runs_uri) assert runs_repo.artifact_uri == runs_uri assert isinstance(runs_repo.repo, S3ArtifactRepository) expected_absolute_uri = "%s%s/artifacts/path/to/model" % (artifact_location, run_id) assert runs_repo.repo.artifact_uri == expected_absolute_uri
def test_parse_runs_uri_invalid_input(uri): with pytest.raises(MlflowException): RunsArtifactRepository.parse_runs_uri(uri)
def test_parse_runs_uri_valid_input(uri, expected_run_id, expected_artifact_path): (run_id, artifact_path) = RunsArtifactRepository.parse_runs_uri(uri) assert run_id == expected_run_id assert artifact_path == expected_artifact_path