def test_create_prediction_job(self, version):
    """A synchronous prediction job completes and submits the full config.

    Mocks the job-creation endpoint to return a completed job, then checks
    both the returned job object and the request payload that was sent.
    """
    job_1.status = "completed"
    responses.add("POST", '/v1/models/1/versions/1/jobs',
                  body=json.dumps(job_1.to_dict()), status=200,
                  content_type='application/json')

    source = BigQuerySource(table="project.dataset.source_table",
                            features=["feature_1", "feature2"],
                            options={"key": "val"})
    sink = BigQuerySink(table="project.dataset.result_table",
                        result_column="prediction",
                        save_mode=SaveMode.OVERWRITE,
                        staging_bucket="gs://test",
                        options={"key": "val"})
    config = PredictionJobConfig(source=source,
                                 sink=sink,
                                 service_account_name="my-service-account",
                                 result_type=ResultType.INTEGER)

    job = version.create_prediction_job(job_config=config)

    # The returned job mirrors the mocked server response.
    assert job.status == JobStatus.COMPLETED
    assert job.id == job_1.id
    assert job.error == job_1.error
    assert job.name == job_1.name

    # The submitted request body carries the full job configuration.
    payload = json.loads(responses.calls[0].request.body)
    submitted = payload["config"]["job_config"]
    assert submitted["bigquery_source"] == source.to_dict()
    assert submitted["bigquery_sink"] == sink.to_dict()
    assert submitted["model"]["result"]["type"] == ResultType.INTEGER.value
    assert submitted["model"]["uri"] == f"{version.artifact_uri}/model"
    assert submitted["model"]["type"] == ModelType.PYFUNC_V2.value.upper()
    assert payload["config"]["service_account_name"] == "my-service-account"
def test_batch_pyfunc_v2_batch(integration_test_url, project_name, service_account, use_google_oauth,
                               batch_bigquery_source, batch_bigquery_sink, batch_gcs_staging_bucket):
    """End-to-end batch prediction flow for a PYFUNC_V2 model.

    Trains a small SVC on the iris dataset, uploads it as a pyfunc model
    version, then runs the prediction job twice: synchronously (expected to
    complete) and asynchronously (polled, then stopped).  Requires a live
    integration environment — all fixtures come from conftest.
    """
    merlin.set_url(integration_test_url, use_google_oauth=use_google_oauth)
    merlin.set_project(project_name)
    merlin.set_model("batch-iris", ModelType.PYFUNC_V2)

    # Register the service account as a project secret; the job config below
    # references it by this name.
    service_account_name = "*****@*****.**"
    _create_secret(merlin.active_project(), service_account_name, service_account)

    # Train and serialize a toy classifier to use as the model artifact.
    clf = svm.SVC(gamma='scale')
    iris = load_iris()
    X, y = iris.data, iris.target
    clf.fit(X, y)
    joblib.dump(clf, MODEL_PATH)

    # Create new version of the model
    mdl = merlin.active_model()
    v = mdl.new_model_version()
    v.start()
    # Upload the serialized model to MLP
    v.log_pyfunc_model(model_instance=IrisClassifier(),
                       conda_env=ENV_PATH,
                       code_dir=["test"],
                       artifacts={MODEL_PATH_ARTIFACT_KEY: MODEL_PATH})
    v.finish()

    bq_source = BigQuerySource(batch_bigquery_source,
                               features=["sepal_length", "sepal_width",
                                         "petal_length", "petal_width"])
    bq_sink = BigQuerySink(batch_bigquery_sink,
                           staging_bucket=batch_gcs_staging_bucket,
                           result_column="prediction",
                           save_mode=SaveMode.OVERWRITE)
    job_config = PredictionJobConfig(source=bq_source,
                                     sink=bq_sink,
                                     service_account_name=service_account_name,
                                     env_vars={"ALPHA": "0.2"})

    # Synchronous run: blocks until the job reaches a terminal state.
    job = v.create_prediction_job(job_config=job_config)
    assert job.status == JobStatus.COMPLETED

    # Asynchronous run: poll until the job leaves PENDING, then stop it.
    job = v.create_prediction_job(job_config=job_config, sync=False)
    while job.status == JobStatus.PENDING:
        sleep(20)
        job = job.refresh()
    job = job.stop()
    assert job.status == JobStatus.TERMINATED
def test_stop_prediction_job(self, version):
    """Stopping an asynchronously-created job leaves it TERMINATED."""
    # The job is created in a pending state...
    job_1.status = "pending"
    responses.add("POST", '/v1/models/1/versions/1/jobs',
                  body=json.dumps(job_1.to_dict()), status=200,
                  content_type='application/json')
    # ...the stop endpoint accepts the request...
    responses.add("PUT", '/v1/models/1/versions/1/jobs/1/stop',
                  status=204, content_type='application/json')
    # ...and the follow-up status fetch reports termination.
    job_1.status = "terminated"
    responses.add("GET", '/v1/models/1/versions/1/jobs/1',
                  body=json.dumps(job_1.to_dict()), status=200,
                  content_type='application/json')

    source = BigQuerySource(table="project.dataset.source_table",
                            features=["feature_1", "feature2"],
                            options={"key": "val"})
    sink = BigQuerySink(table="project.dataset.result_table",
                        result_column="prediction",
                        save_mode=SaveMode.OVERWRITE,
                        staging_bucket="gs://test",
                        options={"key": "val"})
    config = PredictionJobConfig(source=source,
                                 sink=sink,
                                 service_account_name="my-service-account",
                                 result_type=ResultType.INTEGER)

    job = version.create_prediction_job(job_config=config, sync=False)
    job = job.stop()

    assert job.status == JobStatus.TERMINATED
    assert job.id == job_1.id
    assert job.error == job_1.error
    assert job.name == job_1.name
def test_create_prediction_job_with_retry_failed(self, version):
    """Job creation raises ValueError when every status poll fails.

    The POST succeeds, but all five GET polls return HTTP 500, exhausting
    the client's retries.

    Fix: the original asserted on ``j`` after ``create_prediction_job``
    inside the ``pytest.raises`` block — the call raises, so ``j`` was
    never bound and those assertions were unreachable dead code.  They
    are removed here.
    """
    job_1.status = "pending"
    responses.add("POST", '/v1/models/1/versions/1/jobs',
                  body=json.dumps(job_1.to_dict()), status=200,
                  content_type='application/json')
    # Every status poll fails with a server error.
    for _ in range(5):
        responses.add("GET", '/v1/models/1/versions/1/jobs/1',
                      body=json.dumps(job_1.to_dict()), status=500,
                      content_type='application/json')

    bq_src = BigQuerySource(table="project.dataset.source_table",
                            features=["feature_1", "feature2"],
                            options={"key": "val"})
    bq_sink = BigQuerySink(table="project.dataset.result_table",
                           result_column="prediction",
                           save_mode=SaveMode.OVERWRITE,
                           staging_bucket="gs://test",
                           options={"key": "val"})
    job_config = PredictionJobConfig(
        source=bq_src, sink=bq_sink,
        service_account_name="my-service-account",
        result_type=ResultType.INTEGER)

    with pytest.raises(ValueError):
        version.create_prediction_job(job_config=job_config)

    # 1 POST to create the job + 5 failed GET polls.
    assert len(responses.calls) == 6
def test_create_prediction_job_with_retry_pending_then_failed(
        self, version):
    """Retries recover once (pending), then the job fails permanently.

    Sequence of GET polls: 3 x HTTP 500 (retried), 1 x 200 still pending,
    then 5 x HTTP 500 which exhausts the retries and raises ValueError.

    Fix: the original asserted on ``j`` inside the ``pytest.raises`` block
    after the raising call — ``j`` was never bound, so those assertions
    were unreachable dead code and are removed here.
    """
    job_1.status = "pending"
    responses.add("POST", '/v1/models/1/versions/1/jobs',
                  body=json.dumps(job_1.to_dict()), status=200,
                  content_type='application/json')

    # Patch the method as currently it is not supported in the library
    # https://github.com/getsentry/responses/issues/135
    def _find_match(self, request):
        # Original implementation: return the first registered mock that
        # matches method and URL (mocks are reused indefinitely).
        for match in self._urls:
            if request.method == match['method'] and \
                    self._has_url_match(match, request.url):
                return match

    def _find_match_patched(self, request):
        # Consume each GET job-status mock exactly once so the registered
        # 500/200 responses play back in order.
        for index, match in enumerate(self._urls):
            if request.method == match['method'] and \
                    self._has_url_match(match, request.url):
                if request.method == "GET" and request.url == "/v1/models/1/versions/1/jobs/1":
                    return self._urls.pop(index)
                else:
                    return match

    responses._find_match = types.MethodType(_find_match_patched, responses)

    # Three transient server errors, then one successful poll (still pending).
    for _ in range(3):
        responses.add("GET", '/v1/models/1/versions/1/jobs/1',
                      body=json.dumps(job_1.to_dict()), status=500,
                      content_type='application/json')
    responses.add("GET", '/v1/models/1/versions/1/jobs/1',
                  body=json.dumps(job_1.to_dict()), status=200,
                  content_type='application/json')
    # After that, every poll fails until retries are exhausted.
    job_1.status = "failed"
    for _ in range(5):
        responses.add("GET", '/v1/models/1/versions/1/jobs/1',
                      body=json.dumps(job_1.to_dict()), status=500,
                      content_type='application/json')

    bq_src = BigQuerySource(table="project.dataset.source_table",
                            features=["feature_1", "feature2"],
                            options={"key": "val"})
    bq_sink = BigQuerySink(table="project.dataset.result_table",
                           result_column="prediction",
                           save_mode=SaveMode.OVERWRITE,
                           staging_bucket="gs://test",
                           options={"key": "val"})
    job_config = PredictionJobConfig(
        source=bq_src, sink=bq_sink,
        service_account_name="my-service-account",
        result_type=ResultType.INTEGER)

    with pytest.raises(ValueError):
        version.create_prediction_job(job_config=job_config)

    # unpatch
    responses._find_match = types.MethodType(_find_match, responses)
    # 1 POST + 3 failed GETs + 1 pending GET + 5 failed GETs.
    assert len(responses.calls) == 10
def test_create_prediction_job_with_retry_success(self, version):
    """Job creation succeeds after transient poll failures.

    The POST succeeds; four GET status polls return HTTP 500 (retried),
    and the fifth returns 200 with a completed job.  Also verifies the
    request payload and the total number of HTTP calls.
    """
    job_1.status = "pending"
    responses.add("POST", '/v1/models/1/versions/1/jobs',
                  body=json.dumps(job_1.to_dict()), status=200,
                  content_type='application/json')

    # Patch the method as currently it is not supported in the library
    # https://github.com/getsentry/responses/issues/135
    def _find_match(self, request):
        # Original implementation: return the first matching mock
        # (mocks are reused indefinitely).
        for match in self._urls:
            if request.method == match['method'] and \
                    self._has_url_match(match, request.url):
                return match

    def _find_match_patched(self, request):
        # Consume each GET job-status mock exactly once so the registered
        # 500/200 responses play back in registration order.
        for index, match in enumerate(self._urls):
            if request.method == match['method'] and \
                    self._has_url_match(match, request.url):
                if request.method == "GET" and request.url == "/v1/models/1/versions/1/jobs/1":
                    return self._urls.pop(index)
                else:
                    return match

    responses._find_match = types.MethodType(_find_match_patched, responses)

    # Four transient server errors, then one successful completed poll.
    for i in range(4):
        responses.add("GET", '/v1/models/1/versions/1/jobs/1',
                      body=json.dumps(job_1.to_dict()), status=500,
                      content_type='application/json')
    job_1.status = "completed"
    responses.add("GET", '/v1/models/1/versions/1/jobs/1',
                  body=json.dumps(job_1.to_dict()), status=200,
                  content_type='application/json')

    bq_src = BigQuerySource(table="project.dataset.source_table",
                            features=["feature_1", "feature2"],
                            options={"key": "val"})
    bq_sink = BigQuerySink(table="project.dataset.result_table",
                           result_column="prediction",
                           save_mode=SaveMode.OVERWRITE,
                           staging_bucket="gs://test",
                           options={"key": "val"})
    job_config = PredictionJobConfig(
        source=bq_src, sink=bq_sink,
        service_account_name="my-service-account",
        result_type=ResultType.INTEGER)

    j = version.create_prediction_job(job_config=job_config)

    # Returned job mirrors the final (completed) server response.
    assert j.status == JobStatus.COMPLETED
    assert j.id == job_1.id
    assert j.error == job_1.error
    assert j.name == job_1.name

    # The submitted request carries the full job configuration.
    actual_req = json.loads(responses.calls[0].request.body)
    assert actual_req["config"]["job_config"][
        "bigquery_source"] == bq_src.to_dict()
    assert actual_req["config"]["job_config"][
        "bigquery_sink"] == bq_sink.to_dict()
    assert actual_req["config"]["job_config"]["model"]["result"][
        "type"] == ResultType.INTEGER.value
    assert actual_req["config"]["job_config"]["model"][
        "uri"] == f"{version.artifact_uri}/model"
    assert actual_req["config"]["job_config"]["model"][
        "type"] == ModelType.PYFUNC_V2.value.upper()
    assert actual_req["config"][
        "service_account_name"] == "my-service-account"
    # 1 POST + 4 failed GETs + 1 successful GET.
    assert len(responses.calls) == 6
    # unpatch
    responses._find_match = types.MethodType(_find_match, responses)