Example #1
0
def test_model_version_with_labels(integration_test_url, project_name,
                                   use_google_oauth):
    """Verify labels attached to a new model version are stored and filterable."""
    merlin.set_url(integration_test_url, use_google_oauth=use_google_oauth)
    merlin.set_project(project_name)
    merlin.set_model("sklearn-labels", ModelType.SKLEARN)

    model_dir = "test/sklearn-model"
    model_file = "model.joblib"

    undeploy_all_version()

    with merlin.new_model_version(labels={"model": "T-800"}) as version:
        # Train a small classifier and serialize it so there is an artifact
        # to upload for this version.
        classifier = svm.SVC(gamma="scale")
        dataset = load_iris()
        classifier.fit(dataset.data, dataset.target)
        dump(classifier, os.path.join(model_dir, model_file))

        # Upload the serialized model to MLP
        merlin.log_model(model_dir=model_dir)
        assert len(version.labels) == 1
        assert version.labels["model"] == "T-800"

    active = merlin.active_model()

    # Filtering by the label we set must return only matching versions.
    matching = active.list_version(labels={"model": ["T-800"]})
    for found in matching:
        assert found.labels["model"] == "T-800"

    # A label value we never used must match nothing.
    non_matching = active.list_version(labels={"model": ["T-1000"]})
    assert len(non_matching) == 0
Example #2
0
def test_set_model(url, project, model, mock_oauth, use_google_oauth):
    """set_model must fail until URL and project are configured, then succeed."""
    # Nothing configured yet: setting a model must raise.
    with pytest.raises(Exception):
        merlin.set_model(model.name, model.type)

    merlin.set_url(url, use_google_oauth=use_google_oauth)

    # URL set but no active project: still an error.
    with pytest.raises(Exception):
        merlin.set_model(model.name, model.type)

    _mock_get_project_call(project)
    merlin.set_project(project.name)

    _mock_get_model_call(project, model)
    merlin.set_model(model.name, model.type)

    # The active model must now mirror the mocked model's attributes.
    active = merlin.active_model()
    assert active.name == model.name
    assert active.type == model.type
    assert active.id == model.id
    assert active.mlflow_experiment_id == model.mlflow_experiment_id
Example #3
0
def test_batch_pyfunc_v2_batch(integration_test_url, project_name,
                               service_account, use_google_oauth,
                               batch_bigquery_source, batch_bigquery_sink,
                               batch_gcs_staging_bucket):
    """End-to-end batch prediction: run one job to completion, then stop an async one."""
    merlin.set_url(integration_test_url, use_google_oauth=use_google_oauth)
    merlin.set_project(project_name)
    merlin.set_model("batch-iris", ModelType.PYFUNC_V2)
    service_account_name = "*****@*****.**"
    _create_secret(merlin.active_project(), service_account_name,
                   service_account)

    # Train and serialize a classifier to be wrapped by the pyfunc model.
    classifier = svm.SVC(gamma='scale')
    dataset = load_iris()
    classifier.fit(dataset.data, dataset.target)
    joblib.dump(classifier, MODEL_PATH)

    # Create new version of the model
    version = merlin.active_model().new_model_version()
    version.start()
    # Upload the serialized model to MLP
    version.log_pyfunc_model(model_instance=IrisClassifier(),
                             conda_env=ENV_PATH,
                             code_dir=["test"],
                             artifacts={MODEL_PATH_ARTIFACT_KEY: MODEL_PATH})
    version.finish()

    source = BigQuerySource(batch_bigquery_source,
                            features=["sepal_length", "sepal_width",
                                      "petal_length", "petal_width"])
    sink = BigQuerySink(batch_bigquery_sink,
                        staging_bucket=batch_gcs_staging_bucket,
                        result_column="prediction",
                        save_mode=SaveMode.OVERWRITE)
    config = PredictionJobConfig(source=source,
                                 sink=sink,
                                 service_account_name=service_account_name,
                                 env_vars={"ALPHA": "0.2"})

    # A synchronous run blocks until the job finishes.
    job = version.create_prediction_job(job_config=config)
    assert job.status == JobStatus.COMPLETED

    # An asynchronous run: poll until it leaves PENDING, then stop it.
    job = version.create_prediction_job(job_config=config, sync=False)
    while job.status == JobStatus.PENDING:
        sleep(20)
        job = job.refresh()
    job = job.stop()

    assert job.status == JobStatus.TERMINATED
Example #4
0
def test_xgboost(integration_test_url, project_name, use_google_oauth):
    """Serve the latest xgboost model version locally and check its predictions."""
    merlin.set_url(integration_test_url, use_google_oauth=use_google_oauth)
    merlin.set_project(project_name)
    merlin.set_model("xgboost-sample", ModelType.XGBOOST)

    version = _get_latest_version(merlin.active_model())
    port = _get_free_port()
    # Run the model server in a child process so the test can poll it.
    server = Process(target=version.start_server,
                     kwargs={"port": port, "build_image": True})
    server.start()
    _wait_server_ready(f"http://{host}:{port}")

    resp = requests.post(_get_local_endpoint(version, port), json=request_json)

    assert resp.status_code == 200
    payload = resp.json()
    assert payload is not None
    assert len(payload['predictions']) == len(request_json['instances'])
    server.terminate()
Example #5
0
def test_cli_deployment_undeployment(deployment_info, runner,
                                     use_google_oauth):
    """Deploy a model through the CLI, then undeploy that same version.

    Checks that the version id printed by ``deploy`` matches the latest
    version registered in Merlin, and that ``undeploy`` prints the expected
    confirmation message.
    """
    model_name = 'cli-test'
    merlin.set_url(deployment_info['url'], use_google_oauth=use_google_oauth)
    merlin.set_project(deployment_info['project'])
    merlin.set_model(model_name, ModelType.SKLEARN)

    undeploy_all_version()

    # Deployment
    result = runner.invoke(cli, [
        'deploy', '--env', deployment_info['env'], '--model-type',
        deployment_info['model_type'], '--model-dir',
        deployment_info['model_dir'], '--model-name', model_name, '--project',
        deployment_info['project'], '--url', deployment_info['url']
    ])

    if result.exception:
        traceback.print_exception(*result.exc_info)

    # The deploy command prints the new version id as the last token of its
    # first output line.
    test_deployed_model_version = result.output.split('\n')[0].split(' ')[-1]

    # Get latest deployed model's version
    merlin.set_url(deployment_info['url'], use_google_oauth=use_google_oauth)
    merlin.set_project(deployment_info['project'])
    merlin.set_model(model_name, ModelType.SKLEARN)

    merlin_active_model = merlin.active_model()
    all_versions = merlin_active_model.list_version()

    latest_version = all_versions[0]

    # Undeployment
    undeploy_result = runner.invoke(cli, [
        'undeploy', '--model-version', test_deployed_model_version,
        '--model-name', model_name, '--project', deployment_info['project'],
        '--url', deployment_info['url']
    ])
    # BUG FIX: this previously inspected `result` (the *deploy* invocation),
    # so exceptions raised by the undeploy command were silently ignored.
    if undeploy_result.exception:
        traceback.print_exception(*undeploy_result.exc_info)

    planned_output = "Deleting deployment of model {} version {}".format(
        model_name, test_deployed_model_version)
    received_output = undeploy_result.output.split(' from')[0]

    assert latest_version._id == int(test_deployed_model_version)
    assert received_output == planned_output
Example #6
0
File: merlin.py  Project: zhangchi1/merlin
def undeploy(model_name, model_version, project, url):
    """Undeploy one version of a model via the Merlin SDK.

    Args:
        model_name: Name of the model to look up.
        model_version: Version id (string or int) to undeploy.
        project: Merlin project containing the model.
        url: Merlin API URL.

    Errors are reported by printing; the function never raises.
    """
    merlin.set_url(url)
    merlin.set_project(project)
    merlin.set_model(model_name)

    merlin_active_model = merlin.active_model()
    all_versions = merlin_active_model.list_version()

    try:
        wanted_model_info = [
            model_info for model_info in all_versions
            if model_info._id == int(model_version)
        ][0]
    except Exception as e:
        print(e)
        print('Model Version {} is not found.'.format(model_version))
        # BUG FIX: previously execution fell through with
        # `wanted_model_info` unbound, so the undeploy call below raised a
        # NameError instead of cleanly reporting the missing version.
        return

    try:
        merlin.undeploy(wanted_model_info)
    except Exception as e:
        print(e)
Example #7
0
def test_tensorflow(integration_test_url, project_name, use_google_oauth):
    """Serve the latest tensorflow model version locally and check its predictions."""
    merlin.set_url(integration_test_url, use_google_oauth=use_google_oauth)
    merlin.set_project(project_name)
    merlin.set_model("tensorflow-sample", ModelType.TENSORFLOW)

    version = _get_latest_version(merlin.active_model())
    port = _get_free_port()
    # Run the model server in a child process so the test can poll it.
    server = Process(target=version.start_server,
                     kwargs={"port": port, "build_image": True})
    server.start()
    _wait_server_ready(f"http://{host}:{port}/v1/models/{version.model.name}-{version.id}")

    request_json = {
        "signature_name": "predict",
        "instances": [
            {"sepal_length": 2.8, "sepal_width": 1.0,
             "petal_length": 6.8, "petal_width": 0.4},
            {"sepal_length": 0.1, "sepal_width": 0.5,
             "petal_length": 1.8, "petal_width": 2.4}
        ]
    }
    resp = requests.post(_get_local_endpoint(version, port), json=request_json)

    assert resp.status_code == 200
    payload = resp.json()
    assert payload is not None
    assert len(payload['predictions']) == len(request_json['instances'])
    server.terminate()
Example #8
0
def undeploy_all_version():
    """Undeploy every version of the active model whose endpoint is RUNNING."""
    for version in merlin.active_model().list_version():
        # Read the endpoint once and skip versions with nothing running.
        endpoint = version.endpoint
        if endpoint is None or endpoint.status != Status.RUNNING:
            continue
        merlin.undeploy(version)