    def test_failures(self):
        dag = DAG('test_dag',
                  default_args={
                      'owner': 'airflow',
                      'start_date': DEFAULT_DATE,
                      'end_date': DEFAULT_DATE,
                      'project_id': 'test-project',
                      'region': 'us-east1',
                  },
                  schedule_interval='@daily')

        input_with_model = self.INPUT_MISSING_ORIGIN.copy()
        other_params_but_models = {
            'task_prefix': 'eval-test',
            'batch_prediction_job_id': 'eval-test-prediction',
            'data_format': input_with_model['dataFormat'],
            'input_paths': input_with_model['inputPaths'],
            'prediction_path': input_with_model['outputPath'],
            'metric_fn_and_keys': (self.metric_fn, ['err']),
            'validate_fn': (lambda x: 'err=%.1f' % x['err']),
            'dag': dag,
        }

        with self.assertRaisesRegex(AirflowException, 'Missing model origin'):
            mlengine_operator_utils.create_evaluate_ops(
                **other_params_but_models)

        with self.assertRaisesRegex(AirflowException,
                                    'Ambiguous model origin'):
            mlengine_operator_utils.create_evaluate_ops(
                model_uri='abc', model_name='cde', **other_params_but_models)

        with self.assertRaisesRegex(AirflowException,
                                    'Ambiguous model origin'):
            mlengine_operator_utils.create_evaluate_ops(
                model_uri='abc', version_name='vvv', **other_params_but_models)

        with self.assertRaisesRegex(AirflowException,
                                    '`metric_fn` param must be callable'):
            params = other_params_but_models.copy()
            params['metric_fn_and_keys'] = (None, ['abc'])
            mlengine_operator_utils.create_evaluate_ops(model_uri='gs://blah',
                                                        **params)

        with self.assertRaisesRegex(AirflowException,
                                    '`validate_fn` param must be callable'):
            params = other_params_but_models.copy()
            params['validate_fn'] = None
            mlengine_operator_utils.create_evaluate_ops(model_uri='gs://blah',
                                                        **params)
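
The self.INPUT_MISSING_ORIGIN fixture copied above is defined elsewhere in the
test class. A minimal sketch of its likely shape, inferred from the keys these
tests read and the gs://legal-bucket paths asserted in Example #2; the exact
values, in particular inputPaths, are assumptions:

    # Sketch only: values inferred from the assertions, not from the source.
    INPUT_MISSING_ORIGIN = {
        'dataFormat': 'TEXT',
        'inputPaths': ['gs://legal-bucket/fake-input-path/*'],
        'outputPath': 'gs://legal-bucket/fake-output-path',
    }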
Example #2
    def test_successful_run(self):
        input_with_model = self.INPUT_MISSING_ORIGIN.copy()

        pred, summary, validate = mlengine_operator_utils.create_evaluate_ops(
            task_prefix='eval-test',
            batch_prediction_job_id='eval-test-prediction',
            data_format=input_with_model['dataFormat'],
            input_paths=input_with_model['inputPaths'],
            prediction_path=input_with_model['outputPath'],
            metric_fn_and_keys=(self.metric_fn, ['err']),
            validate_fn=(lambda x: 'err=%.1f' % x['err']),
            dag=self.dag,
            py_interpreter="python3",
        )

        with patch('airflow.gcp.operators.mlengine.MLEngineHook'
                   ) as mock_mlengine_hook:
            success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
            success_message['predictionInput'] = input_with_model
            hook_instance = mock_mlengine_hook.return_value
            hook_instance.create_job.return_value = success_message
            result = pred.execute(None)
            mock_mlengine_hook.assert_called_once_with('google_cloud_default',
                                                       None)
            hook_instance.create_job.assert_called_once_with(
                'test-project', {
                    'jobId': 'eval_test_prediction',
                    'predictionInput': input_with_model,
                }, ANY)
            self.assertEqual(success_message['predictionOutput'], result)

        with patch('airflow.gcp.operators.dataflow.DataFlowHook'
                   ) as mock_dataflow_hook:
            hook_instance = mock_dataflow_hook.return_value
            hook_instance.start_python_dataflow.return_value = None
            summary.execute(None)
            mock_dataflow_hook.assert_called_once_with(
                gcp_conn_id='google_cloud_default',
                delegate_to=None,
                poll_sleep=10)
            hook_instance.start_python_dataflow.assert_called_once_with(
                '{{task.task_id}}', {
                    'prediction_path': 'gs://legal-bucket/fake-output-path',
                    'labels': {
                        'airflow-version': TEST_VERSION
                    },
                    'metric_keys': 'err',
                    'metric_fn_encoded': self.metric_fn_encoded,
                },
                'airflow.gcp.utils.mlengine_prediction_summary', ['-m'],
                py_interpreter='python3')

        with patch('airflow.gcp.utils.mlengine_operator_utils.'
                   'GoogleCloudStorageHook') as mock_gcs_hook:
            hook_instance = mock_gcs_hook.return_value
            hook_instance.download.return_value = '{"err": 0.9, "count": 9}'
            result = validate.execute({})
            hook_instance.download.assert_called_once_with(
                'legal-bucket', 'fake-output-path/prediction.summary.json')
            self.assertEqual('err=0.9', result)
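
The self.metric_fn and self.metric_fn_encoded fixtures asserted above also sit
outside this excerpt. A minimal sketch, assuming the create_evaluate_ops
contract (metric_fn maps one prediction-instance dict to a tuple aligned with
metric_keys) and assuming the encoding is base64 over a dill pickle, matching
what the summary step decodes; the instance fields are hypothetical:

    import base64

    import dill

    def metric_fn(inst):
        # Any picklable callable returning a tuple aligned with ['err'] would
        # satisfy the test; 'classes' and 'input_label' are assumed fields.
        return (abs(float(inst['classes']) - float(inst['input_label'])),)

    metric_fn_encoded = base64.b64encode(dill.dumps(metric_fn, recurse=True))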
Example #3
    def validate_err_and_count(summary):
        if summary['val'] < 0:
            raise ValueError('Too low val<0; summary={}'.format(summary))
        if summary['count'] != 20:
            raise ValueError(
                'Invalid count: count != 20; summary={}'.format(summary))
        return summary

    evaluate_prediction, evaluate_summary, evaluate_validation = mlengine_operator_utils.create_evaluate_ops(  # pylint:disable=too-many-arguments
        task_prefix="evaluate-ops",
        data_format="TEXT",
        input_paths=[PREDICTION_INPUT],
        prediction_path=PREDICTION_OUTPUT,
        metric_fn_and_keys=get_metric_fn_and_keys(),
        validate_fn=validate_err_and_count,
        batch_prediction_job_id="evaluate-ops-{{ ts_nodash }}-{{ params.model_name }}",
        project_id=PROJECT_ID,
        region="us-central1",
        dataflow_options={
            'project': PROJECT_ID,
            'tempLocation': SUMMARY_TMP,
            'stagingLocation': SUMMARY_STAGING,
        },
        model_name=MODEL_NAME,
        version_name="v1",
        py_interpreter="python3",
    )

    create_model >> create_version >> evaluate_prediction
    evaluate_validation >> delete_version
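
get_metric_fn_and_keys() is referenced above but not shown. A minimal sketch of
a compatible helper: the metric computation and instance fields are assumptions,
but the returned keys must include 'val', since validate_err_and_count reads
summary['val'] ('count' appears to come from the prediction-summary step rather
than from metric_fn):

    def get_metric_fn_and_keys():
        def metric_fn(inst):
            # Assumed fields on each prediction instance.
            return (float(inst['classes']) - float(inst['input_label']),)
        return metric_fn, ['val']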