def test_failures(self):
    """create_evaluate_ops() must reject invalid configurations.

    Covers: no model origin at all, two origins supplied at once, and
    non-callable ``metric_fn`` / ``validate_fn`` arguments.
    """
    # Throwaway DAG so the factory has something to attach its tasks to.
    test_dag = DAG(
        'test_dag',
        default_args={
            'owner': 'airflow',
            'start_date': DEFAULT_DATE,
            'end_date': DEFAULT_DATE,
            'project_id': 'test-project',
            'region': 'us-east1',
        },
        schedule_interval='@daily')

    prediction_input = self.INPUT_MISSING_ORIGIN.copy()
    # Every kwarg create_evaluate_ops needs EXCEPT a model origin.
    base_kwargs = {
        'task_prefix': 'eval-test',
        'batch_prediction_job_id': 'eval-test-prediction',
        'data_format': prediction_input['dataFormat'],
        'input_paths': prediction_input['inputPaths'],
        'prediction_path': prediction_input['outputPath'],
        'metric_fn_and_keys': (self.metric_fn, ['err']),
        'validate_fn': (lambda x: 'err=%.1f' % x['err']),
        'dag': test_dag,
    }

    # No origin supplied at all.
    with self.assertRaisesRegex(AirflowException, 'Missing model origin'):
        mlengine_operator_utils.create_evaluate_ops(**base_kwargs)

    # More than one origin supplied -> ambiguous.
    with self.assertRaisesRegex(AirflowException, 'Ambiguous model origin'):
        mlengine_operator_utils.create_evaluate_ops(
            model_uri='abc', model_name='cde', **base_kwargs)

    with self.assertRaisesRegex(AirflowException, 'Ambiguous model origin'):
        mlengine_operator_utils.create_evaluate_ops(
            model_uri='abc', version_name='vvv', **base_kwargs)

    # metric_fn must be callable.
    with self.assertRaisesRegex(AirflowException,
                                '`metric_fn` param must be callable'):
        bad_kwargs = dict(base_kwargs, metric_fn_and_keys=(None, ['abc']))
        mlengine_operator_utils.create_evaluate_ops(
            model_uri='gs://blah', **bad_kwargs)

    # validate_fn must be callable.
    with self.assertRaisesRegex(AirflowException,
                                '`validate_fn` param must be callable'):
        bad_kwargs = dict(base_kwargs, validate_fn=None)
        mlengine_operator_utils.create_evaluate_ops(
            model_uri='gs://blah', **bad_kwargs)
def test_successful_run(self):
    """Happy path: build the three evaluate ops and execute each with its
    backing hook mocked out.

    Checks that create_evaluate_ops wires up:
      * a prediction task that submits an MLEngine batch-prediction job,
      * a summary task that launches a Dataflow run of
        mlengine_prediction_summary,
      * a validation task that downloads the summary JSON from GCS and
        applies validate_fn to it.
    """
    input_with_model = self.INPUT_MISSING_ORIGIN.copy()
    pred, summary, validate = mlengine_operator_utils.create_evaluate_ops(
        task_prefix='eval-test',
        batch_prediction_job_id='eval-test-prediction',
        data_format=input_with_model['dataFormat'],
        input_paths=input_with_model['inputPaths'],
        prediction_path=input_with_model['outputPath'],
        metric_fn_and_keys=(self.metric_fn, ['err']),
        validate_fn=(lambda x: 'err=%.1f' % x['err']),
        dag=self.dag,
        py_interpreter="python3",
    )

    # 1) Prediction task: MLEngineHook is mocked; executing the operator
    #    must create a job whose id is the prefix + job id with dashes
    #    normalized to underscores.
    with patch('airflow.gcp.operators.mlengine.MLEngineHook'
               ) as mock_mlengine_hook:
        success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
        success_message['predictionInput'] = input_with_model
        hook_instance = mock_mlengine_hook.return_value
        hook_instance.create_job.return_value = success_message
        result = pred.execute(None)
        mock_mlengine_hook.assert_called_once_with('google_cloud_default',
                                                   None)
        hook_instance.create_job.assert_called_once_with(
            'test-project',
            {
                'jobId': 'eval_test_prediction',
                'predictionInput': input_with_model,
            }, ANY)
        # The operator returns the job's predictionOutput section.
        self.assertEqual(success_message['predictionOutput'], result)

    # 2) Summary task: DataFlowHook is mocked; the operator must start a
    #    python dataflow running mlengine_prediction_summary with the
    #    base64-encoded metric_fn and the metric keys.
    with patch('airflow.gcp.operators.dataflow.DataFlowHook'
               ) as mock_dataflow_hook:
        hook_instance = mock_dataflow_hook.return_value
        hook_instance.start_python_dataflow.return_value = None
        summary.execute(None)
        mock_dataflow_hook.assert_called_once_with(
            gcp_conn_id='google_cloud_default', delegate_to=None,
            poll_sleep=10)
        hook_instance.start_python_dataflow.assert_called_once_with(
            '{{task.task_id}}', {
                'prediction_path': 'gs://legal-bucket/fake-output-path',
                'labels': {
                    'airflow-version': TEST_VERSION
                },
                'metric_keys': 'err',
                'metric_fn_encoded': self.metric_fn_encoded,
            },
            'airflow.gcp.utils.mlengine_prediction_summary', ['-m'],
            py_interpreter='python3')

    # 3) Validation task: GCS hook is mocked to return the summary JSON;
    #    the operator must download prediction.summary.json from the
    #    prediction path's bucket and feed it to validate_fn.
    with patch('airflow.gcp.utils.mlengine_operator_utils.'
               'GoogleCloudStorageHook') as mock_gcs_hook:
        hook_instance = mock_gcs_hook.return_value
        hook_instance.download.return_value = '{"err": 0.9, "count": 9}'
        result = validate.execute({})
        hook_instance.download.assert_called_once_with(
            'legal-bucket', 'fake-output-path/prediction.summary.json')
        # validate_fn formats the err metric from the downloaded summary.
        self.assertEqual('err=0.9', result)
if summary['val'] < 0: raise ValueError('Too low val<0; summary={}'.format(summary)) if summary['count'] != 20: raise ValueError( 'Invalid value val != 20; summary={}'.format(summary)) return summary evaluate_prediction, evaluate_summary, evaluate_validation = mlengine_operator_utils.create_evaluate_ops( task_prefix="evalueate-ops", # pylint:disable=too-many-arguments data_format="TEXT", input_paths=[PREDICTION_INPUT], prediction_path=PREDICTION_OUTPUT, metric_fn_and_keys=get_metric_fn_and_keys(), validate_fn=validate_err_and_count, batch_prediction_job_id= "evalueate-ops-{{ ts_nodash }}-{{ params.model_name }}", project_id=PROJECT_ID, region="us-central1", dataflow_options={ 'project': PROJECT_ID, 'tempLocation': SUMMARY_TMP, 'stagingLocation': SUMMARY_STAGING, }, model_name=MODEL_NAME, version_name="v1", py_interpreter="python3", ) create_model >> create_version >> evaluate_prediction evaluate_validation >> delete_version