コード例 #1
0
def test_project_file_db_roundtrip(create_server):
    """Round-trip a fully populated project through the HTTPRunDB project API.

    Builds a project with metadata, a full spec, two functions, one workflow
    and one artifact. It is then sent through create, store, patch (with an
    empty patch), get and list. After every call, the project the DB returns
    is compared with the local one via ``_assert_projects``.
    """
    server: Server = create_server()
    db: HTTPRunDB = server.conn

    project_name = "project-name"
    metadata = mlrun.projects.project.ProjectMetadata(
        project_name,
        labels={"key": "value"},
        annotations={"annotation-key": "annotation-value"},
    )
    spec = mlrun.projects.project.ProjectSpec(
        "project description",
        {"param_key": "param value"},
        artifact_path="/tmp",
        conda="conda",
        source="source",
        subpath="subpath",
        origin_url="origin_url",
        goals="project goals",
        desired_state=mlrun.api.schemas.ProjectState.archived,
    )
    project = mlrun.projects.project.MlrunProject(metadata=metadata, spec=spec)

    trainer_name = "trainer-function"
    trainer = mlrun.new_function(trainer_name, project_name)
    project.set_function(trainer, trainer_name)
    project.set_function("hub://describe", "describe")

    workflow_path = Path(tests_root_directory) / "rundb" / "workflow.py"
    project.set_workflow("workflow-name", str(workflow_path))

    project.artifacts = [
        {
            "key": "raw-data",
            "kind": "",
            "iter": 0,
            "tree": "latest",
            "target_path": "https://raw.githubusercontent.com/mlrun/demos/master/"
            "customer-churn-prediction/WA_Fn-UseC_-Telco-Customer-Churn.csv",
            "db_key": "raw-data",
        }
    ]

    # Each DB call is deferred in a lambda so calls and checks still alternate:
    # a failing check stops the remaining calls from running.
    roundtrip_steps = (
        lambda: db.create_project(project),
        lambda: db.store_project(project_name, project),
        lambda: db.patch_project(project_name, {}),
        lambda: db.get_project(project_name),
        lambda: db.list_projects()[0],
    )
    for step in roundtrip_steps:
        _assert_projects(project, step())
コード例 #2
0
ファイル: test_project.py プロジェクト: yonittanenbaum/mlrun
def test_sync_functions():
    """Check that ``sync_functions`` restores function objects after export/load.

    A project with one hub function is exported to YAML and loaded back. The
    loaded project must start with an empty function-object map. After
    ``sync_functions`` its function objects must match the original project's.
    """
    original = mlrun.new_project("project-name")
    original.set_function("hub://describe")
    expected_functions = original.spec._function_objects

    yaml_path = str(pathlib.Path(tests.conftest.results) / "project.yaml")
    original.export(yaml_path)

    loaded = mlrun.load_project(None, yaml_path)
    assert loaded.spec._function_objects == {}
    loaded.sync_functions()
    _assert_project_function_objects(loaded, expected_functions)
コード例 #3
0
def test_sync_functions():
    """Check ``sync_functions`` after export/load, then project ``func`` lookups.

    A project with a hub function named "describe" is exported to YAML and
    loaded back. The loaded project must have no function objects until
    ``sync_functions`` is called, and must then match the original. Finally,
    ``func`` is checked for a function registered with ``set_function`` and
    for one that was only saved (never passed to ``set_function``).
    """
    source_project = mlrun.new_project("project-name")
    source_project.set_function("hub://describe", "describe")
    functions_before_export = source_project.spec._function_objects

    exported_yaml = str(pathlib.Path(tests.conftest.results) / "project.yaml")
    source_project.export(exported_yaml)

    reloaded = mlrun.load_project("./", exported_yaml)
    assert reloaded.spec._function_objects == {}
    reloaded.sync_functions()
    _assert_project_function_objects(reloaded, functions_before_export)

    described = source_project.func("describe")
    assert described.metadata.name == "describe", "func did not return"

    # "train" is never given to set_function; it only exists because it is
    # saved here, so func() has to fetch it from the DB.
    mlrun.import_function("hub://sklearn_classifier", new_name="train").save()
    trained = source_project.func("train")
    assert trained.metadata.name == "train", "train func did not return"