def test_failures(self):
    """create_evaluate_ops raises AirflowException for missing or ambiguous model origins and for non-callable metric_fn / validate_fn."""

    def create_test_dag(dag_id):
        dag = DAG(
            dag_id,
            default_args={
                'owner': 'airflow',
                'start_date': DEFAULT_DATE,
                'end_date': DEFAULT_DATE,
                'project_id': 'test-project',
                'region': 'us-east1',
            },
            schedule_interval='@daily',
        )
        return dag

    input_with_model = self.INPUT_MISSING_ORIGIN.copy()
    other_params_but_models = {
        'task_prefix': 'eval-test',
        'batch_prediction_job_id': 'eval-test-prediction',
        'data_format': input_with_model['dataFormat'],
        'input_paths': input_with_model['inputPaths'],
        'prediction_path': input_with_model['outputPath'],
        'metric_fn_and_keys': (self.metric_fn, ['err']),
        'validate_fn': (lambda x: 'err=%.1f' % x['err']),
    }

    with self.assertRaisesRegex(AirflowException, 'Missing model origin'):
        mlengine_operator_utils.create_evaluate_ops(
            dag=create_test_dag('test_dag_1'), **other_params_but_models
        )

    with self.assertRaisesRegex(AirflowException, 'Ambiguous model origin'):
        mlengine_operator_utils.create_evaluate_ops(
            dag=create_test_dag('test_dag_2'),
            model_uri='abc',
            model_name='cde',
            **other_params_but_models,
        )

    with self.assertRaisesRegex(AirflowException, 'Ambiguous model origin'):
        mlengine_operator_utils.create_evaluate_ops(
            dag=create_test_dag('test_dag_3'),
            model_uri='abc',
            version_name='vvv',
            **other_params_but_models,
        )

    with self.assertRaisesRegex(AirflowException, '`metric_fn` param must be callable'):
        params = other_params_but_models.copy()
        params['metric_fn_and_keys'] = (None, ['abc'])
        mlengine_operator_utils.create_evaluate_ops(
            dag=create_test_dag('test_dag_4'), model_uri='gs://blah', **params
        )

    with self.assertRaisesRegex(AirflowException, '`validate_fn` param must be callable'):
        params = other_params_but_models.copy()
        params['validate_fn'] = None
        mlengine_operator_utils.create_evaluate_ops(
            dag=create_test_dag('test_dag_5'), model_uri='gs://blah', **params
        )
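# NOTE: the test above (and test_successful_run further down) relies on fixtures built in the
# TestCase's setUp(), which is not part of this excerpt: self.dag, self.INPUT_MISSING_ORIGIN,
# self.metric_fn, and self.metric_fn_encoded. A minimal sketch is shown below; the values are
# assumptions, chosen only to be consistent with the bucket and output path asserted later
# ('legal-bucket', 'fake-output-path').
def setUp(self):
    import base64

    import dill

    self.dag = DAG(
        'test_dag',
        default_args={
            'owner': 'airflow',
            'start_date': DEFAULT_DATE,
            'end_date': DEFAULT_DATE,
            'project_id': 'test-project',
            'region': 'us-east1',
        },
        schedule_interval='@daily',
    )
    # Any picklable callable returning one value per metric key would do here (hypothetical value).
    self.metric_fn = lambda x: (0.1,)
    # create_evaluate_ops serializes metric_fn with dill and base64 before passing it to the
    # Dataflow summary step, so the expected encoding mirrors that.
    self.metric_fn_encoded = base64.b64encode(dill.dumps(self.metric_fn, recurse=True)).decode()
    # Hypothetical batch-prediction request body without a model/version/uri origin.
    self.INPUT_MISSING_ORIGIN = {
        'dataFormat': 'TEXT',
        'inputPaths': ['gs://legal-bucket/fake-input-path/*'],
        'outputPath': 'gs://legal-bucket/fake-output-path',
    }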
def test_non_callable_validate_fn(self):
    with self.assertRaises(AirflowException):
        create_evaluate_ops(
            task_prefix=TASK_PREFIX,
            data_format=DATA_FORMAT,
            input_paths=INPUT_PATHS,
            prediction_path=PREDICTION_PATH,
            metric_fn_and_keys=get_metric_fn_and_keys(),
            validate_fn="validate_err_and_count",
        )
def test_non_callable_metric_fn(self):
    with pytest.raises(AirflowException):
        create_evaluate_ops(
            task_prefix=TASK_PREFIX,
            data_format=DATA_FORMAT,
            input_paths=INPUT_PATHS,
            prediction_path=PREDICTION_PATH,
            metric_fn_and_keys=("error_and_squared_error", ['err', 'mse']),
            validate_fn=validate_err_and_count,
        )
def test_invalid_task_prefix(self):
    invalid_task_prefix_values = ["test-task-prefix&", "~test-task-prefix", "test-task(-prefix"]
    for invalid_task_prefix_value in invalid_task_prefix_values:
        with pytest.raises(AirflowException):
            create_evaluate_ops(
                task_prefix=invalid_task_prefix_value,
                data_format=DATA_FORMAT,
                input_paths=INPUT_PATHS,
                prediction_path=PREDICTION_PATH,
                metric_fn_and_keys=get_metric_fn_and_keys(),
                validate_fn=validate_err_and_count,
            )
def test_apply_validate_fn(self, mock_dataflow, mock_python, mock_download):
    result = create_evaluate_ops(
        task_prefix=TASK_PREFIX,
        data_format=DATA_FORMAT,
        input_paths=INPUT_PATHS,
        prediction_path=PREDICTION_PATH,
        metric_fn_and_keys=get_metric_fn_and_keys(),
        validate_fn=validate_err_and_count,
        batch_prediction_job_id=BATCH_PREDICTION_JOB_ID,
        project_id=PROJECT_ID,
        region=REGION,
        dataflow_options=DATAFLOW_OPTIONS,
        model_uri=MODEL_URI,
    )

    _, _, evaluate_validation = result

    mock_download.return_value = json.dumps({"err": 0.3, "mse": 0.04, "count": 1100})
    templates_dict = {"prediction_path": PREDICTION_PATH}
    with pytest.raises(ValueError) as ctx:
        evaluate_validation.python_callable(templates_dict=templates_dict)

    assert "Too high err>0.2; summary={'err': 0.3, 'mse': 0.04, 'count': 1100}" == str(ctx.value)
    mock_download.assert_called_once_with("path", "to/output/predictions.json/prediction.summary.json")

    invalid_prediction_paths = ["://path/to/output/predictions.json", "gs://", ""]
    for path in invalid_prediction_paths:
        templates_dict = {"prediction_path": path}
        with pytest.raises(ValueError) as ctx:
            evaluate_validation.python_callable(templates_dict=templates_dict)
        assert "Wrong format prediction_path:" == str(ctx.value)[:29]
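# The tests above pass get_metric_fn_and_keys() and validate_err_and_count without defining them in
# this excerpt. A minimal sketch consistent with the assertions (metric keys 'err'/'mse', the
# "Too high err>0.2" message, and a count of 1100) is given below; the mse threshold and the
# per-instance field names ('input_label', 'classes') are assumptions. The example-DAG fragment
# further down uses its own variant whose count check is `!= 20`.
def get_metric_fn_and_keys():
    """Return a per-instance metric callable plus the keys it emits, for create_evaluate_ops."""

    def error_and_squared_error(inst):
        import math  # imported inside so the callable stays self-contained when serialized

        label = float(inst["input_label"])
        classes = float(inst["classes"])
        err = abs(classes - label)
        squared_err = math.pow(classes - label, 2)
        return (err, squared_err)  # aggregated into the 'err' and 'mse' summary keys

    return error_and_squared_error, ["err", "mse"]


def validate_err_and_count(summary):
    """Raise ValueError when the aggregated prediction summary falls outside the expected ranges."""
    if summary["err"] > 0.2:
        raise ValueError(f"Too high err>0.2; summary={summary}")
    if summary["mse"] > 0.05:
        raise ValueError(f"Too high mse>0.05; summary={summary}")
    if summary["count"] != 1100:
        raise ValueError(f"Invalid value val != 1100; summary={summary}")
    return summary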
def test_create_evaluate_ops_dag(self, mock_dataflow, mock_python):
    result = create_evaluate_ops(
        task_prefix=TASK_PREFIX,
        data_format=DATA_FORMAT,
        input_paths=INPUT_PATHS,
        prediction_path=PREDICTION_PATH,
        metric_fn_and_keys=get_metric_fn_and_keys(),
        validate_fn=validate_err_and_count,
        batch_prediction_job_id=BATCH_PREDICTION_JOB_ID,
        dag=TEST_DAG,
    )

    evaluate_prediction, evaluate_summary, evaluate_validation = result

    mock_dataflow.assert_called_once_with(evaluate_prediction)
    mock_python.assert_called_once_with(evaluate_summary)

    assert TASK_PREFIX_PREDICTION == evaluate_prediction.task_id
    assert PROJECT_ID == evaluate_prediction._project_id
    assert BATCH_PREDICTION_JOB_ID == evaluate_prediction._job_id
    assert REGION == evaluate_prediction._region
    assert DATA_FORMAT == evaluate_prediction._data_format
    assert INPUT_PATHS == evaluate_prediction._input_paths
    assert PREDICTION_PATH == evaluate_prediction._output_path
    assert MODEL_NAME == evaluate_prediction._model_name
    assert VERSION_NAME == evaluate_prediction._version_name

    assert TASK_PREFIX_SUMMARY == evaluate_summary.task_id
    assert DATAFLOW_OPTIONS == evaluate_summary.dataflow_default_options
    assert PREDICTION_PATH == evaluate_summary.options["prediction_path"]
    assert METRIC_FN_ENCODED == evaluate_summary.options["metric_fn_encoded"]
    assert METRIC_KEYS_EXPECTED == evaluate_summary.options["metric_keys"]

    assert TASK_PREFIX_VALIDATION == evaluate_validation.task_id
    assert PREDICTION_PATH == evaluate_validation.templates_dict["prediction_path"]
def test_create_evaluate_ops_model_and_version_name(self, mock_dataflow, mock_python):
    result = create_evaluate_ops(
        task_prefix=TASK_PREFIX,
        data_format=DATA_FORMAT,
        input_paths=INPUT_PATHS,
        prediction_path=PREDICTION_PATH,
        metric_fn_and_keys=get_metric_fn_and_keys(),
        validate_fn=validate_err_and_count,
        batch_prediction_job_id=BATCH_PREDICTION_JOB_ID,
        project_id=PROJECT_ID,
        region=REGION,
        dataflow_options=DATAFLOW_OPTIONS,
        model_name=MODEL_NAME,
        version_name=VERSION_NAME,
    )

    evaluate_prediction, evaluate_summary, evaluate_validation = result

    mock_dataflow.assert_called_once_with(evaluate_prediction)
    mock_python.assert_called_once_with(evaluate_summary)

    self.assertEqual(TASK_PREFIX_PREDICTION, evaluate_prediction.task_id)
    self.assertEqual(PROJECT_ID, evaluate_prediction._project_id)
    self.assertEqual(BATCH_PREDICTION_JOB_ID, evaluate_prediction._job_id)
    self.assertEqual(REGION, evaluate_prediction._region)
    self.assertEqual(DATA_FORMAT, evaluate_prediction._data_format)
    self.assertEqual(INPUT_PATHS, evaluate_prediction._input_paths)
    self.assertEqual(PREDICTION_PATH, evaluate_prediction._output_path)
    self.assertEqual(MODEL_NAME, evaluate_prediction._model_name)
    self.assertEqual(VERSION_NAME, evaluate_prediction._version_name)

    self.assertEqual(TASK_PREFIX_SUMMARY, evaluate_summary.task_id)
    self.assertEqual(DATAFLOW_OPTIONS, evaluate_summary.dataflow_default_options)
    self.assertEqual(PREDICTION_PATH, evaluate_summary.options["prediction_path"])
    self.assertEqual(METRIC_FN_ENCODED, evaluate_summary.options["metric_fn_encoded"])
    self.assertEqual(METRIC_KEYS_EXPECTED, evaluate_summary.options["metric_keys"])

    self.assertEqual(TASK_PREFIX_VALIDATION, evaluate_validation.task_id)
    self.assertEqual(PREDICTION_PATH, evaluate_validation.templates_dict["prediction_path"])
        raise ValueError(f'Invalid value val != 20; summary={summary}')
    return summary

# [END howto_operator_gcp_mlengine_validate_error]

# [START howto_operator_gcp_mlengine_evaluate]
evaluate_prediction, evaluate_summary, evaluate_validation = mlengine_operator_utils.create_evaluate_ops(
    task_prefix="evaluate-ops",
    data_format="TEXT",
    input_paths=[PREDICTION_INPUT],
    prediction_path=PREDICTION_OUTPUT,
    metric_fn_and_keys=get_metric_fn_and_keys(),
    validate_fn=validate_err_and_count,
    batch_prediction_job_id="evaluate-ops-{{ ts_nodash }}-{{ params.model_name }}",
    project_id=PROJECT_ID,
    region="us-central1",
    dataflow_options={
        'project': PROJECT_ID,
        'tempLocation': SUMMARY_TMP,
        'stagingLocation': SUMMARY_STAGING,
    },
    model_name=MODEL_NAME,
    version_name="v1",
    py_interpreter="python3",
)
# [END howto_operator_gcp_mlengine_evaluate]

create_model >> create_version >> evaluate_prediction
evaluate_validation >> delete_version
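# The example fragment above assumes module-level constants (PROJECT_ID, MODEL_NAME,
# PREDICTION_INPUT, PREDICTION_OUTPUT, SUMMARY_TMP, SUMMARY_STAGING) and the create_model,
# create_version, and delete_version tasks defined elsewhere in the example DAG. Purely
# hypothetical placeholder values for the constants could look like:
PROJECT_ID = "example-project"            # placeholder
MODEL_NAME = "example_model"              # placeholder
PREDICTION_INPUT = "gs://example-bucket/prediction_input.json"   # placeholder GCS path
PREDICTION_OUTPUT = "gs://example-bucket/prediction_output"      # placeholder GCS path
SUMMARY_TMP = "gs://example-bucket/tmp"                          # placeholder Dataflow temp location
SUMMARY_STAGING = "gs://example-bucket/staging"                  # placeholder Dataflow staging location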
def test_successful_run(self):
    """Prediction, summary, and validation tasks run end-to-end against mocked MLEngine, Dataflow, and GCS hooks."""
    input_with_model = self.INPUT_MISSING_ORIGIN.copy()

    pred, summary, validate = mlengine_operator_utils.create_evaluate_ops(
        task_prefix='eval-test',
        batch_prediction_job_id='eval-test-prediction',
        data_format=input_with_model['dataFormat'],
        input_paths=input_with_model['inputPaths'],
        prediction_path=input_with_model['outputPath'],
        metric_fn_and_keys=(self.metric_fn, ['err']),
        validate_fn=(lambda x: 'err=%.1f' % x['err']),
        dag=self.dag,
        py_interpreter="python3",
    )

    with patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook') as mock_mlengine_hook:
        success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
        success_message['predictionInput'] = input_with_model
        hook_instance = mock_mlengine_hook.return_value
        hook_instance.create_job.return_value = success_message
        result = pred.execute(None)
        mock_mlengine_hook.assert_called_once_with('google_cloud_default', None)
        hook_instance.create_job.assert_called_once_with(
            project_id='test-project',
            job={
                'jobId': 'eval_test_prediction',
                'predictionInput': input_with_model,
            },
            use_existing_job_fn=ANY,
        )
        self.assertEqual(success_message['predictionOutput'], result)

    with patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook') as mock_dataflow_hook:
        hook_instance = mock_dataflow_hook.return_value
        hook_instance.start_python_dataflow.return_value = None
        summary.execute(None)
        mock_dataflow_hook.assert_called_once_with(
            gcp_conn_id='google_cloud_default',
            delegate_to=None,
            poll_sleep=10,
        )
        hook_instance.start_python_dataflow.assert_called_once_with(
            job_name='{{task.task_id}}',
            variables={
                'prediction_path': 'gs://legal-bucket/fake-output-path',
                'labels': {'airflow-version': TEST_VERSION},
                'metric_keys': 'err',
                'metric_fn_encoded': self.metric_fn_encoded,
            },
            dataflow=mock.ANY,
            py_options=[],
            py_requirements=['apache-beam[gcp]>=2.14.0'],
            py_interpreter='python3',
            py_system_site_packages=False,
            on_new_job_id_callback=ANY,
            project_id='test-project',
        )

    with patch('airflow.providers.google.cloud.utils.mlengine_operator_utils.GCSHook') as mock_gcs_hook:
        hook_instance = mock_gcs_hook.return_value
        hook_instance.download.return_value = '{"err": 0.9, "count": 9}'
        result = validate.execute({})
        hook_instance.download.assert_called_once_with(
            'legal-bucket', 'fake-output-path/prediction.summary.json')
        self.assertEqual('err=0.9', result)
def test_successful_run(self):
    """Prediction, summary, and validation tasks run end-to-end against mocked MLEngine, Beam, Dataflow, and GCS hooks."""
    input_with_model = self.INPUT_MISSING_ORIGIN.copy()

    pred, summary, validate = mlengine_operator_utils.create_evaluate_ops(
        task_prefix='eval-test',
        batch_prediction_job_id='eval-test-prediction',
        data_format=input_with_model['dataFormat'],
        input_paths=input_with_model['inputPaths'],
        prediction_path=input_with_model['outputPath'],
        metric_fn_and_keys=(self.metric_fn, ['err']),
        validate_fn=(lambda x: f"err={x['err']:.1f}"),
        dag=self.dag,
        py_interpreter="python3",
    )

    with patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook') as mock_mlengine_hook:
        success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
        success_message['predictionInput'] = input_with_model
        hook_instance = mock_mlengine_hook.return_value
        hook_instance.create_job.return_value = success_message
        result = pred.execute(None)
        mock_mlengine_hook.assert_called_once_with(
            'google_cloud_default',
            None,
            impersonation_chain=None,
        )
        hook_instance.create_job.assert_called_once_with(
            project_id='test-project',
            job={
                'jobId': 'eval_test_prediction',
                'predictionInput': input_with_model,
            },
            use_existing_job_fn=ANY,
        )
        assert success_message['predictionOutput'] == result

    with patch(
        'airflow.providers.google.cloud.operators.dataflow.DataflowHook'
    ) as mock_dataflow_hook, patch(
        'airflow.providers.google.cloud.operators.dataflow.BeamHook'
    ) as mock_beam_hook:
        dataflow_hook_instance = mock_dataflow_hook.return_value
        dataflow_hook_instance.start_python_dataflow.return_value = None
        beam_hook_instance = mock_beam_hook.return_value
        summary.execute(None)
        mock_dataflow_hook.assert_called_once_with(
            gcp_conn_id='google_cloud_default',
            delegate_to=None,
            poll_sleep=10,
            drain_pipeline=False,
            cancel_timeout=600,
            wait_until_finished=None,
            impersonation_chain=None,
        )
        mock_beam_hook.assert_called_once_with(runner="DataflowRunner")
        beam_hook_instance.start_python_pipeline.assert_called_once_with(
            variables={
                'prediction_path': 'gs://legal-bucket/fake-output-path',
                'labels': {'airflow-version': TEST_VERSION},
                'metric_keys': 'err',
                'metric_fn_encoded': self.metric_fn_encoded,
                'project': 'test-project',
                'region': 'us-central1',
                'job_name': mock.ANY,
            },
            py_file=mock.ANY,
            py_options=[],
            py_interpreter='python3',
            py_requirements=['apache-beam[gcp]>=2.14.0'],
            py_system_site_packages=False,
            process_line_callback=mock.ANY,
        )
        dataflow_hook_instance.wait_for_done.assert_called_once_with(
            job_name=mock.ANY,
            location='us-central1',
            job_id=mock.ANY,
            multiple_jobs=False,
        )

    with patch('airflow.providers.google.cloud.utils.mlengine_operator_utils.GCSHook') as mock_gcs_hook:
        hook_instance = mock_gcs_hook.return_value
        hook_instance.download.return_value = '{"err": 0.9, "count": 9}'
        result = validate.execute({})
        hook_instance.download.assert_called_once_with(
            'legal-bucket', 'fake-output-path/prediction.summary.json')
        assert 'err=0.9' == result