Example #1
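A unittest-style test of error handling in mlengine_operator_utils.create_evaluate_ops: each call builds a fresh helper DAG, and the test asserts that AirflowException is raised for a missing model origin, for ambiguous origins (model_uri combined with model_name or version_name), and for non-callable metric_fn and validate_fn arguments.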
    def test_failures(self):
        def create_test_dag(dag_id):
            dag = DAG(
                dag_id,
                default_args={
                    'owner': 'airflow',
                    'start_date': DEFAULT_DATE,
                    'end_date': DEFAULT_DATE,
                    'project_id': 'test-project',
                    'region': 'us-east1',
                },
                schedule_interval='@daily',
            )
            return dag

        input_with_model = self.INPUT_MISSING_ORIGIN.copy()
        other_params_but_models = {
            'task_prefix': 'eval-test',
            'batch_prediction_job_id': 'eval-test-prediction',
            'data_format': input_with_model['dataFormat'],
            'input_paths': input_with_model['inputPaths'],
            'prediction_path': input_with_model['outputPath'],
            'metric_fn_and_keys': (self.metric_fn, ['err']),
            'validate_fn': (lambda x: 'err=%.1f' % x['err']),
        }

        with self.assertRaisesRegex(AirflowException, 'Missing model origin'):
            mlengine_operator_utils.create_evaluate_ops(
                dag=create_test_dag('test_dag_1'), **other_params_but_models
            )

        with self.assertRaisesRegex(AirflowException, 'Ambiguous model origin'):
            mlengine_operator_utils.create_evaluate_ops(
                dag=create_test_dag('test_dag_2'),
                model_uri='abc',
                model_name='cde',
                **other_params_but_models,
            )

        with self.assertRaisesRegex(AirflowException, 'Ambiguous model origin'):
            mlengine_operator_utils.create_evaluate_ops(
                dag=create_test_dag('test_dag_3'),
                model_uri='abc',
                version_name='vvv',
                **other_params_but_models,
            )

        with self.assertRaisesRegex(AirflowException, '`metric_fn` param must be callable'):
            params = other_params_but_models.copy()
            params['metric_fn_and_keys'] = (None, ['abc'])
            mlengine_operator_utils.create_evaluate_ops(
                dag=create_test_dag('test_dag_4'), model_uri='gs://blah', **params
            )

        with self.assertRaisesRegex(AirflowException, '`validate_fn` param must be callable'):
            params = other_params_but_models.copy()
            params['validate_fn'] = None
            mlengine_operator_utils.create_evaluate_ops(
                dag=create_test_dag('test_dag_5'), model_uri='gs://blah', **params
            )
Example #2
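Asserts that create_evaluate_ops rejects a validate_fn passed as a string rather than a callable.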
    def test_non_callable_validate_fn(self):
        with self.assertRaises(AirflowException):
            create_evaluate_ops(
                task_prefix=TASK_PREFIX,
                data_format=DATA_FORMAT,
                input_paths=INPUT_PATHS,
                prediction_path=PREDICTION_PATH,
                metric_fn_and_keys=get_metric_fn_and_keys(),
                validate_fn="validate_err_and_count",
            )
Example #3
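The pytest counterpart for the metric function: passing a string inside metric_fn_and_keys instead of a callable should raise AirflowException.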
    def test_non_callable_metric_fn(self):
        with pytest.raises(AirflowException):
            create_evaluate_ops(
                task_prefix=TASK_PREFIX,
                data_format=DATA_FORMAT,
                input_paths=INPUT_PATHS,
                prediction_path=PREDICTION_PATH,
                metric_fn_and_keys=("error_and_squared_error", ['err', 'mse']),
                validate_fn=validate_err_and_count,
            )
Example #4
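Checks that task prefixes containing characters such as '&', '~', or '(' are rejected with AirflowException.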
    def test_invalid_task_prefix(self):
        invalid_task_prefix_values = ["test-task-prefix&", "~test-task-prefix", "test-task(-prefix"]

        for invalid_task_prefix_value in invalid_task_prefix_values:
            with pytest.raises(AirflowException):
                create_evaluate_ops(
                    task_prefix=invalid_task_prefix_value,
                    data_format=DATA_FORMAT,
                    input_paths=INPUT_PATHS,
                    prediction_path=PREDICTION_PATH,
                    metric_fn_and_keys=get_metric_fn_and_keys(),
                    validate_fn=validate_err_and_count,
                )
Example #5
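Exercises the validation step directly: with the GCS download mocked, the returned evaluate_validation task's python_callable should raise ValueError both when the error metric exceeds the threshold and when prediction_path is malformed.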
    def test_apply_validate_fn(self, mock_dataflow, mock_python, mock_download):
        result = create_evaluate_ops(
            task_prefix=TASK_PREFIX,
            data_format=DATA_FORMAT,
            input_paths=INPUT_PATHS,
            prediction_path=PREDICTION_PATH,
            metric_fn_and_keys=get_metric_fn_and_keys(),
            validate_fn=validate_err_and_count,
            batch_prediction_job_id=BATCH_PREDICTION_JOB_ID,
            project_id=PROJECT_ID,
            region=REGION,
            dataflow_options=DATAFLOW_OPTIONS,
            model_uri=MODEL_URI,
        )

        _, _, evaluate_validation = result

        mock_download.return_value = json.dumps({"err": 0.3, "mse": 0.04, "count": 1100})
        templates_dict = {"prediction_path": PREDICTION_PATH}
        with pytest.raises(ValueError) as ctx:
            evaluate_validation.python_callable(templates_dict=templates_dict)

        assert "Too high err>0.2; summary={'err': 0.3, 'mse': 0.04, 'count': 1100}" == str(ctx.value)
        mock_download.assert_called_once_with("path", "to/output/predictions.json/prediction.summary.json")

        invalid_prediction_paths = ["://path/to/output/predictions.json", "gs://", ""]

        for path in invalid_prediction_paths:
            templates_dict = {"prediction_path": path}
            with pytest.raises(ValueError) as ctx:
                evaluate_validation.python_callable(templates_dict=templates_dict)
            assert "Wrong format prediction_path:" == str(ctx.value)[:29]
Example #6
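Builds the three operators against an existing test DAG (project, region, and Dataflow settings presumably come from its default_args) and verifies the task ids plus the attributes and options wired into the prediction, summary, and validation tasks.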
    def test_create_evaluate_ops_dag(self, mock_dataflow, mock_python):
        result = create_evaluate_ops(
            task_prefix=TASK_PREFIX,
            data_format=DATA_FORMAT,
            input_paths=INPUT_PATHS,
            prediction_path=PREDICTION_PATH,
            metric_fn_and_keys=get_metric_fn_and_keys(),
            validate_fn=validate_err_and_count,
            batch_prediction_job_id=BATCH_PREDICTION_JOB_ID,
            dag=TEST_DAG,
        )

        evaluate_prediction, evaluate_summary, evaluate_validation = result

        mock_dataflow.assert_called_once_with(evaluate_prediction)
        mock_python.assert_called_once_with(evaluate_summary)

        assert TASK_PREFIX_PREDICTION == evaluate_prediction.task_id
        assert PROJECT_ID == evaluate_prediction._project_id
        assert BATCH_PREDICTION_JOB_ID == evaluate_prediction._job_id
        assert REGION == evaluate_prediction._region
        assert DATA_FORMAT == evaluate_prediction._data_format
        assert INPUT_PATHS == evaluate_prediction._input_paths
        assert PREDICTION_PATH == evaluate_prediction._output_path
        assert MODEL_NAME == evaluate_prediction._model_name
        assert VERSION_NAME == evaluate_prediction._version_name

        assert TASK_PREFIX_SUMMARY == evaluate_summary.task_id
        assert DATAFLOW_OPTIONS == evaluate_summary.dataflow_default_options
        assert PREDICTION_PATH == evaluate_summary.options["prediction_path"]
        assert METRIC_FN_ENCODED == evaluate_summary.options["metric_fn_encoded"]
        assert METRIC_KEYS_EXPECTED == evaluate_summary.options["metric_keys"]

        assert TASK_PREFIX_VALIDATION == evaluate_validation.task_id
        assert PREDICTION_PATH == evaluate_validation.templates_dict["prediction_path"]
Example #7
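A unittest-style variant of the previous check: model_name, version_name, project_id, region, and dataflow_options are passed explicitly, and the same wiring is asserted with assertEqual.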
    def test_create_evaluate_ops_model_and_version_name(
            self, mock_dataflow, mock_python):
        result = create_evaluate_ops(
            task_prefix=TASK_PREFIX,
            data_format=DATA_FORMAT,
            input_paths=INPUT_PATHS,
            prediction_path=PREDICTION_PATH,
            metric_fn_and_keys=get_metric_fn_and_keys(),
            validate_fn=validate_err_and_count,
            batch_prediction_job_id=BATCH_PREDICTION_JOB_ID,
            project_id=PROJECT_ID,
            region=REGION,
            dataflow_options=DATAFLOW_OPTIONS,
            model_name=MODEL_NAME,
            version_name=VERSION_NAME,
        )

        evaluate_prediction, evaluate_summary, evaluate_validation = result

        mock_dataflow.assert_called_once_with(evaluate_prediction)
        mock_python.assert_called_once_with(evaluate_summary)

        self.assertEqual(TASK_PREFIX_PREDICTION, evaluate_prediction.task_id)
        self.assertEqual(PROJECT_ID, evaluate_prediction._project_id)
        self.assertEqual(BATCH_PREDICTION_JOB_ID, evaluate_prediction._job_id)
        self.assertEqual(REGION, evaluate_prediction._region)
        self.assertEqual(DATA_FORMAT, evaluate_prediction._data_format)
        self.assertEqual(INPUT_PATHS, evaluate_prediction._input_paths)
        self.assertEqual(PREDICTION_PATH, evaluate_prediction._output_path)
        self.assertEqual(MODEL_NAME, evaluate_prediction._model_name)
        self.assertEqual(VERSION_NAME, evaluate_prediction._version_name)

        self.assertEqual(TASK_PREFIX_SUMMARY, evaluate_summary.task_id)
        self.assertEqual(DATAFLOW_OPTIONS,
                         evaluate_summary.dataflow_default_options)
        self.assertEqual(PREDICTION_PATH,
                         evaluate_summary.options["prediction_path"])
        self.assertEqual(METRIC_FN_ENCODED,
                         evaluate_summary.options["metric_fn_encoded"])
        self.assertEqual(METRIC_KEYS_EXPECTED,
                         evaluate_summary.options["metric_keys"])

        self.assertEqual(TASK_PREFIX_VALIDATION, evaluate_validation.task_id)
        self.assertEqual(PREDICTION_PATH,
                         evaluate_validation.templates_dict["prediction_path"])
Example #8
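An excerpt from an example DAG: the tail of a validate function, followed by the howto block that calls create_evaluate_ops and chains the resulting tasks between model creation and version deletion.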
            raise ValueError(f'Invalid value val != 20; summary={summary}')
        return summary

    # [END howto_operator_gcp_mlengine_validate_error]

    # [START howto_operator_gcp_mlengine_evaluate]
    evaluate_prediction, evaluate_summary, evaluate_validation = mlengine_operator_utils.create_evaluate_ops(
        task_prefix="evaluate-ops",
        data_format="TEXT",
        input_paths=[PREDICTION_INPUT],
        prediction_path=PREDICTION_OUTPUT,
        metric_fn_and_keys=get_metric_fn_and_keys(),
        validate_fn=validate_err_and_count,
        batch_prediction_job_id="evaluate-ops-{{ ts_nodash }}-{{ params.model_name }}",
        project_id=PROJECT_ID,
        region="us-central1",
        dataflow_options={
            'project': PROJECT_ID,
            'tempLocation': SUMMARY_TMP,
            'stagingLocation': SUMMARY_STAGING,
        },
        model_name=MODEL_NAME,
        version_name="v1",
        py_interpreter="python3",
    )
    # [END howto_operator_gcp_mlengine_evaluate]

    create_model >> create_version >> evaluate_prediction
    evaluate_validation >> delete_version
Example #9
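An end-to-end run of the three returned operators with MLEngineHook, DataflowHook, and GCSHook patched, checking the batch-prediction job payload, the start_python_dataflow arguments, and the 'err=0.9' result of the validation step.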
    def test_successful_run(self):
        input_with_model = self.INPUT_MISSING_ORIGIN.copy()

        pred, summary, validate = mlengine_operator_utils.create_evaluate_ops(
            task_prefix='eval-test',
            batch_prediction_job_id='eval-test-prediction',
            data_format=input_with_model['dataFormat'],
            input_paths=input_with_model['inputPaths'],
            prediction_path=input_with_model['outputPath'],
            metric_fn_and_keys=(self.metric_fn, ['err']),
            validate_fn=(lambda x: 'err=%.1f' % x['err']),
            dag=self.dag,
            py_interpreter="python3",
        )

        with patch(
                'airflow.providers.google.cloud.operators.mlengine.MLEngineHook'
        ) as mock_mlengine_hook:
            success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
            success_message['predictionInput'] = input_with_model
            hook_instance = mock_mlengine_hook.return_value
            hook_instance.create_job.return_value = success_message
            result = pred.execute(None)
            mock_mlengine_hook.assert_called_once_with('google_cloud_default', None)
            hook_instance.create_job.assert_called_once_with(
                project_id='test-project',
                job={
                    'jobId': 'eval_test_prediction',
                    'predictionInput': input_with_model,
                },
                use_existing_job_fn=ANY)
            self.assertEqual(success_message['predictionOutput'], result)

        with patch(
                'airflow.providers.google.cloud.operators.dataflow.DataflowHook'
        ) as mock_dataflow_hook:
            hook_instance = mock_dataflow_hook.return_value
            hook_instance.start_python_dataflow.return_value = None
            summary.execute(None)
            mock_dataflow_hook.assert_called_once_with(
                gcp_conn_id='google_cloud_default',
                delegate_to=None,
                poll_sleep=10)
            hook_instance.start_python_dataflow.assert_called_once_with(
                job_name='{{task.task_id}}',
                variables={
                    'prediction_path': 'gs://legal-bucket/fake-output-path',
                    'labels': {
                        'airflow-version': TEST_VERSION
                    },
                    'metric_keys': 'err',
                    'metric_fn_encoded': self.metric_fn_encoded,
                },
                dataflow=mock.ANY,
                py_options=[],
                py_requirements=['apache-beam[gcp]>=2.14.0'],
                py_interpreter='python3',
                py_system_site_packages=False,
                on_new_job_id_callback=ANY,
                project_id='test-project',
            )

        with patch(
                'airflow.providers.google.cloud.utils.mlengine_operator_utils.GCSHook'
        ) as mock_gcs_hook:
            hook_instance = mock_gcs_hook.return_value
            hook_instance.download.return_value = '{"err": 0.9, "count": 9}'
            result = validate.execute({})
            hook_instance.download.assert_called_once_with(
                'legal-bucket', 'fake-output-path/prediction.summary.json')
            self.assertEqual('err=0.9', result)
Example #10
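What appears to be a later revision of the same test: the summary step is now asserted against BeamHook.start_python_pipeline plus DataflowHook.wait_for_done, and the hooks are constructed with impersonation_chain.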
    def test_successful_run(self):
        input_with_model = self.INPUT_MISSING_ORIGIN.copy()

        pred, summary, validate = mlengine_operator_utils.create_evaluate_ops(
            task_prefix='eval-test',
            batch_prediction_job_id='eval-test-prediction',
            data_format=input_with_model['dataFormat'],
            input_paths=input_with_model['inputPaths'],
            prediction_path=input_with_model['outputPath'],
            metric_fn_and_keys=(self.metric_fn, ['err']),
            validate_fn=(lambda x: f"err={x['err']:.1f}"),
            dag=self.dag,
            py_interpreter="python3",
        )

        with patch(
                'airflow.providers.google.cloud.operators.mlengine.MLEngineHook'
        ) as mock_mlengine_hook:
            success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
            success_message['predictionInput'] = input_with_model
            hook_instance = mock_mlengine_hook.return_value
            hook_instance.create_job.return_value = success_message
            result = pred.execute(None)
            mock_mlengine_hook.assert_called_once_with(
                'google_cloud_default',
                None,
                impersonation_chain=None,
            )
            hook_instance.create_job.assert_called_once_with(
                project_id='test-project',
                job={
                    'jobId': 'eval_test_prediction',
                    'predictionInput': input_with_model,
                },
                use_existing_job_fn=ANY,
            )
            assert success_message['predictionOutput'] == result

        with patch(
                'airflow.providers.google.cloud.operators.dataflow.DataflowHook'
        ) as mock_dataflow_hook, patch(
                'airflow.providers.google.cloud.operators.dataflow.BeamHook'
        ) as mock_beam_hook:
            dataflow_hook_instance = mock_dataflow_hook.return_value
            dataflow_hook_instance.start_python_dataflow.return_value = None
            beam_hook_instance = mock_beam_hook.return_value
            summary.execute(None)
            mock_dataflow_hook.assert_called_once_with(
                gcp_conn_id='google_cloud_default',
                delegate_to=None,
                poll_sleep=10,
                drain_pipeline=False,
                cancel_timeout=600,
                wait_until_finished=None,
                impersonation_chain=None,
            )
            mock_beam_hook.assert_called_once_with(runner="DataflowRunner")
            beam_hook_instance.start_python_pipeline.assert_called_once_with(
                variables={
                    'prediction_path': 'gs://legal-bucket/fake-output-path',
                    'labels': {
                        'airflow-version': TEST_VERSION
                    },
                    'metric_keys': 'err',
                    'metric_fn_encoded': self.metric_fn_encoded,
                    'project': 'test-project',
                    'region': 'us-central1',
                    'job_name': mock.ANY,
                },
                py_file=mock.ANY,
                py_options=[],
                py_interpreter='python3',
                py_requirements=['apache-beam[gcp]>=2.14.0'],
                py_system_site_packages=False,
                process_line_callback=mock.ANY,
            )
            dataflow_hook_instance.wait_for_done.assert_called_once_with(
                job_name=mock.ANY,
                location='us-central1',
                job_id=mock.ANY,
                multiple_jobs=False)

        with patch(
                'airflow.providers.google.cloud.utils.mlengine_operator_utils.GCSHook'
        ) as mock_gcs_hook:
            hook_instance = mock_gcs_hook.return_value
            hook_instance.download.return_value = '{"err": 0.9, "count": 9}'
            result = validate.execute({})
            hook_instance.download.assert_called_once_with(
                'legal-bucket', 'fake-output-path/prediction.summary.json')
            assert 'err=0.9' == result