Пример #1
0
    def test_create_prediction_job(self, version):
        """Create a synchronous prediction job and inspect the request sent.

        The mocked POST endpoint answers immediately with ``job_1`` in the
        ``completed`` state. The test checks the returned job against
        ``job_1`` and then decodes the JSON body the SDK posted to verify the
        source, sink, model and service-account fields.
        """
        # json.dumps runs at registration time, so the mocked body carries
        # the "completed" status set here.
        job_1.status = "completed"
        responses.add("POST", '/v1/models/1/versions/1/jobs',
                      body=json.dumps(job_1.to_dict()),
                      status=200,
                      content_type='application/json')

        source = BigQuerySource(table="project.dataset.source_table",
                                features=["feature_1", "feature2"],
                                options={"key": "val"})
        sink = BigQuerySink(table="project.dataset.result_table",
                            result_column="prediction",
                            save_mode=SaveMode.OVERWRITE,
                            staging_bucket="gs://test",
                            options={"key": "val"})
        config = PredictionJobConfig(source=source,
                                     sink=sink,
                                     service_account_name="my-service-account",
                                     result_type=ResultType.INTEGER)

        created = version.create_prediction_job(job_config=config)
        assert created.status == JobStatus.COMPLETED
        for attr in ("id", "error", "name"):
            assert getattr(created, attr) == getattr(job_1, attr)

        # Decode the first recorded request (the POST) and check its payload.
        sent = json.loads(responses.calls[0].request.body)
        job_cfg = sent["config"]["job_config"]
        assert job_cfg["bigquery_source"] == source.to_dict()
        assert job_cfg["bigquery_sink"] == sink.to_dict()
        model_cfg = job_cfg["model"]
        assert model_cfg["result"]["type"] == ResultType.INTEGER.value
        assert model_cfg["uri"] == f"{version.artifact_uri}/model"
        assert model_cfg["type"] == ModelType.PYFUNC_V2.value.upper()
        assert sent["config"]["service_account_name"] == "my-service-account"
Пример #2
0
def test_dictionary_generated(table, staging_bucket, result_column, save_mode,
                              options, expected_dict, expected_valid):
    """Check ``BigQuerySink.to_dict`` for one parametrised case.

    When ``expected_valid`` is truthy, the sink must serialise to
    ``expected_dict``. Otherwise ``to_dict`` must raise ``ValueError``.
    """
    sink = BigQuerySink(table, staging_bucket, result_column, save_mode,
                        options)
    if not expected_valid:
        with pytest.raises(ValueError):
            sink.to_dict()
        return
    assert sink.to_dict() == expected_dict
Пример #3
0
def test_batch_pyfunc_v2_batch(integration_test_url, project_name,
                               service_account, use_google_oauth,
                               batch_bigquery_source, batch_bigquery_sink,
                               batch_gcs_staging_bucket):
    """End-to-end batch prediction with a PYFUNC_V2 iris classifier.

    Trains an SVM on the iris dataset and logs it as a new model version. It
    then runs a synchronous BigQuery-to-BigQuery prediction job, which must
    complete. Finally it starts an asynchronous job, waits for it to leave
    PENDING, and stops it; the stopped job must end in TERMINATED.
    """
    merlin.set_url(integration_test_url, use_google_oauth=use_google_oauth)
    merlin.set_project(project_name)
    merlin.set_model("batch-iris", ModelType.PYFUNC_V2)
    service_account_name = "*****@*****.**"
    _create_secret(merlin.active_project(), service_account_name,
                   service_account)

    clf = svm.SVC(gamma='scale')
    iris = load_iris()
    X, y = iris.data, iris.target
    clf.fit(X, y)
    joblib.dump(clf, MODEL_PATH)
    # Create new version of the model
    mdl = merlin.active_model()
    v = mdl.new_model_version()
    v.start()
    # Upload the serialized model to MLP
    v.log_pyfunc_model(model_instance=IrisClassifier(),
                       conda_env=ENV_PATH,
                       code_dir=["test"],
                       artifacts={MODEL_PATH_ARTIFACT_KEY: MODEL_PATH})

    v.finish()

    bq_source = BigQuerySource(batch_bigquery_source,
                               features=[
                                   "sepal_length", "sepal_width",
                                   "petal_length", "petal_width"
                               ])
    bq_sink = BigQuerySink(batch_bigquery_sink,
                           staging_bucket=batch_gcs_staging_bucket,
                           result_column="prediction",
                           save_mode=SaveMode.OVERWRITE)
    job_config = PredictionJobConfig(source=bq_source,
                                     sink=bq_sink,
                                     service_account_name=service_account_name,
                                     env_vars={"ALPHA": "0.2"})
    job = v.create_prediction_job(job_config=job_config)

    assert job.status == JobStatus.COMPLETED

    # Wait for the async job to leave PENDING before stopping it. The wait is
    # bounded so a job stuck in PENDING fails the test instead of hanging
    # the run forever (20 s x 30 polls = 10 minutes).
    poll_interval_seconds = 20
    max_pending_polls = 30
    job = v.create_prediction_job(job_config=job_config, sync=False)
    for _ in range(max_pending_polls):
        if job.status != JobStatus.PENDING:
            break
        sleep(poll_interval_seconds)
        job = job.refresh()
    assert job.status != JobStatus.PENDING, (
        f"prediction job still PENDING after "
        f"{max_pending_polls * poll_interval_seconds} seconds")
    job = job.stop()

    assert job.status == JobStatus.TERMINATED
Пример #4
0
def test_valid(table, staging_bucket, result_column, save_mode, options,
               expected):
    """Check ``BigQuerySink._validate`` for one parametrised case.

    ``expected`` is truthy when the configuration should pass validation.
    Otherwise ``_validate`` must raise ``ValueError``.
    """
    sink = BigQuerySink(table, staging_bucket, result_column, save_mode,
                        options)
    if not expected:
        with pytest.raises(ValueError):
            sink._validate()
    else:
        sink._validate()
Пример #5
0
    def test_stop_prediction_job(self, version):
        """Start an asynchronous job, stop it, and expect TERMINATED.

        Three endpoints are mocked:
        - job creation, which returns ``job_1`` in the ``pending`` state;
        - the stop endpoint, which returns HTTP 204;
        - the job GET, which returns ``job_1`` in the ``terminated`` state.
        """
        # Each body is serialised when it is registered, so the status flips
        # below determine what each endpoint reports.
        job_1.status = "pending"
        responses.add("POST", '/v1/models/1/versions/1/jobs',
                      body=json.dumps(job_1.to_dict()), status=200,
                      content_type='application/json')
        responses.add("PUT", '/v1/models/1/versions/1/jobs/1/stop',
                      status=204, content_type='application/json')
        job_1.status = "terminated"
        responses.add("GET", '/v1/models/1/versions/1/jobs/1',
                      body=json.dumps(job_1.to_dict()), status=200,
                      content_type='application/json')

        config = PredictionJobConfig(
            source=BigQuerySource(table="project.dataset.source_table",
                                  features=["feature_1", "feature2"],
                                  options={"key": "val"}),
            sink=BigQuerySink(table="project.dataset.result_table",
                              result_column="prediction",
                              save_mode=SaveMode.OVERWRITE,
                              staging_bucket="gs://test",
                              options={"key": "val"}),
            service_account_name="my-service-account",
            result_type=ResultType.INTEGER)

        started = version.create_prediction_job(job_config=config, sync=False)
        stopped = started.stop()
        assert stopped.status == JobStatus.TERMINATED
        for attr in ("id", "error", "name"):
            assert getattr(stopped, attr) == getattr(job_1, attr)
Пример #6
0
    def test_create_prediction_job_with_retry_failed(self, version):
        """A synchronous job whose status polling keeps failing must raise.

        The POST succeeds with ``job_1`` in the ``pending`` state. Every
        registered GET of the job answers HTTP 500, so
        ``create_prediction_job`` is expected to give up with ``ValueError``.
        The recorded call count is checked afterwards.
        """
        job_1.status = "pending"
        responses.add("POST",
                      '/v1/models/1/versions/1/jobs',
                      body=json.dumps(job_1.to_dict()),
                      status=200,
                      content_type='application/json')

        for _ in range(5):
            responses.add("GET",
                          '/v1/models/1/versions/1/jobs/1',
                          body=json.dumps(job_1.to_dict()),
                          status=500,
                          content_type='application/json')

        bq_src = BigQuerySource(table="project.dataset.source_table",
                                features=["feature_1", "feature2"],
                                options={"key": "val"})

        bq_sink = BigQuerySink(table="project.dataset.result_table",
                               result_column="prediction",
                               save_mode=SaveMode.OVERWRITE,
                               staging_bucket="gs://test",
                               options={"key": "val"})

        job_config = PredictionJobConfig(
            source=bq_src,
            sink=bq_sink,
            service_account_name="my-service-account",
            result_type=ResultType.INTEGER)

        # Only the raising call belongs inside pytest.raises. Statements
        # placed after it in the block never execute, and no job object is
        # returned to inspect.
        with pytest.raises(ValueError):
            version.create_prediction_job(job_config=job_config)

        # One POST plus the GET attempts made before giving up.
        assert len(responses.calls) == 6
Пример #7
0
    def test_create_prediction_job_with_retry_pending_then_failed(
            self, version):
        """Polling survives transient 500s, then fails after repeated 500s.

        The job GET is mocked in sequence: three HTTP 500s, one 200 (job still
        ``pending``), then five more HTTP 500s. ``create_prediction_job`` must
        raise ``ValueError``, and ten requests must have been recorded.
        """
        job_1.status = "pending"
        responses.add("POST",
                      '/v1/models/1/versions/1/jobs',
                      body=json.dumps(job_1.to_dict()),
                      status=200,
                      content_type='application/json')

        # Patch the method as currently it is not supported in the library
        # https://github.com/getsentry/responses/issues/135
        def _find_match(self, request):
            # Unpatched behaviour: always return the first matching entry.
            for match in self._urls:
                if request.method == match['method'] and \
                        self._has_url_match(match, request.url):
                    return match

        def _find_match_patched(self, request):
            # Patched behaviour: consume (pop) matching job GET entries so
            # the registered responses are served in order.
            for index, match in enumerate(self._urls):
                if request.method == match['method'] and \
                        self._has_url_match(match, request.url):
                    # NOTE(review): request.url is usually absolute, so this
                    # equality with a relative path may never hold — confirm
                    # the patch actually takes effect.
                    if request.method == "GET" and request.url == "/v1/models/1/versions/1/jobs/1":
                        return self._urls.pop(index)
                    else:
                        return match

        responses._find_match = types.MethodType(_find_match_patched,
                                                 responses)
        try:
            for _ in range(3):
                responses.add("GET",
                              '/v1/models/1/versions/1/jobs/1',
                              body=json.dumps(job_1.to_dict()),
                              status=500,
                              content_type='application/json')

            responses.add("GET",
                          '/v1/models/1/versions/1/jobs/1',
                          body=json.dumps(job_1.to_dict()),
                          status=200,
                          content_type='application/json')

            job_1.status = "failed"
            for _ in range(5):
                responses.add("GET",
                              '/v1/models/1/versions/1/jobs/1',
                              body=json.dumps(job_1.to_dict()),
                              status=500,
                              content_type='application/json')

            bq_src = BigQuerySource(table="project.dataset.source_table",
                                    features=["feature_1", "feature2"],
                                    options={"key": "val"})

            bq_sink = BigQuerySink(table="project.dataset.result_table",
                                   result_column="prediction",
                                   save_mode=SaveMode.OVERWRITE,
                                   staging_bucket="gs://test",
                                   options={"key": "val"})

            job_config = PredictionJobConfig(
                source=bq_src,
                sink=bq_sink,
                service_account_name="my-service-account",
                result_type=ResultType.INTEGER)

            # Only the raising call belongs inside pytest.raises; anything
            # after it in the block would never run.
            with pytest.raises(ValueError):
                version.create_prediction_job(job_config=job_config)
        finally:
            # unpatch even if the test body failed, so later tests are
            # unaffected
            responses._find_match = types.MethodType(_find_match, responses)

        # 1 POST + 3 failed GETs + 1 pending GET + 5 failed GETs.
        assert len(responses.calls) == 10
Пример #8
0
    def test_create_prediction_job_with_retry_success(self, version):
        """Polling survives four transient 500s and returns the completed job.

        The POST returns ``job_1`` in the ``pending`` state. The job GET is
        mocked as four HTTP 500s followed by a 200 carrying ``job_1`` in the
        ``completed`` state. The test checks the returned job, the payload of
        the POST, and the total number of recorded requests.
        """
        job_1.status = "pending"
        responses.add("POST",
                      '/v1/models/1/versions/1/jobs',
                      body=json.dumps(job_1.to_dict()),
                      status=200,
                      content_type='application/json')

        # Patch the method as currently it is not supported in the library
        # https://github.com/getsentry/responses/issues/135
        def _find_match(self, request):
            # Unpatched behaviour: always return the first matching entry.
            for match in self._urls:
                if request.method == match['method'] and \
                        self._has_url_match(match, request.url):
                    return match

        def _find_match_patched(self, request):
            # Patched behaviour: consume (pop) matching job GET entries so
            # the registered responses are served in order.
            for index, match in enumerate(self._urls):
                if request.method == match['method'] and \
                        self._has_url_match(match, request.url):
                    # NOTE(review): request.url is usually absolute, so this
                    # equality with a relative path may never hold — confirm
                    # the patch actually takes effect.
                    if request.method == "GET" and request.url == "/v1/models/1/versions/1/jobs/1":
                        return self._urls.pop(index)
                    else:
                        return match

        responses._find_match = types.MethodType(_find_match_patched,
                                                 responses)
        try:
            for _ in range(4):
                responses.add("GET",
                              '/v1/models/1/versions/1/jobs/1',
                              body=json.dumps(job_1.to_dict()),
                              status=500,
                              content_type='application/json')

            job_1.status = "completed"
            responses.add("GET",
                          '/v1/models/1/versions/1/jobs/1',
                          body=json.dumps(job_1.to_dict()),
                          status=200,
                          content_type='application/json')

            bq_src = BigQuerySource(table="project.dataset.source_table",
                                    features=["feature_1", "feature2"],
                                    options={"key": "val"})

            bq_sink = BigQuerySink(table="project.dataset.result_table",
                                   result_column="prediction",
                                   save_mode=SaveMode.OVERWRITE,
                                   staging_bucket="gs://test",
                                   options={"key": "val"})

            job_config = PredictionJobConfig(
                source=bq_src,
                sink=bq_sink,
                service_account_name="my-service-account",
                result_type=ResultType.INTEGER)

            j = version.create_prediction_job(job_config=job_config)
            assert j.status == JobStatus.COMPLETED
            assert j.id == job_1.id
            assert j.error == job_1.error
            assert j.name == job_1.name

            actual_req = json.loads(responses.calls[0].request.body)
            assert actual_req["config"]["job_config"][
                "bigquery_source"] == bq_src.to_dict()
            assert actual_req["config"]["job_config"][
                "bigquery_sink"] == bq_sink.to_dict()
            assert actual_req["config"]["job_config"]["model"]["result"][
                "type"] == ResultType.INTEGER.value
            assert actual_req["config"]["job_config"]["model"][
                "uri"] == f"{version.artifact_uri}/model"
            assert actual_req["config"]["job_config"]["model"][
                "type"] == ModelType.PYFUNC_V2.value.upper()
            assert actual_req["config"][
                "service_account_name"] == "my-service-account"
            # 1 POST + 4 failed GETs + 1 successful GET.
            assert len(responses.calls) == 6
        finally:
            # unpatch even if an assertion above failed, so later tests are
            # unaffected
            responses._find_match = types.MethodType(_find_match, responses)