def test_http_error(self):
        """An HttpError raised by ``create_job`` propagates out of ``execute``.

        ``MLEngineHook`` is patched and its ``create_job`` is made to raise a
        403 ``HttpError``. The test checks that the error escapes ``execute``
        with its original status. It also checks that the hook was
        constructed and that ``create_job`` was called with the expected
        request before the failure.
        """
        http_error_code = 403

        with patch('airflow.gcp.operators.mlengine.MLEngineHook') \
                as mock_hook:
            input_with_model = self.INPUT_MISSING_ORIGIN.copy()
            # The operator only receives the short model name; the expected
            # full path must use the same project_id that is handed to the
            # operator below ('test-project'), as in the success-path test.
            input_with_model['modelName'] = \
                'projects/test-project/models/test_model'

            hook_instance = mock_hook.return_value
            hook_instance.create_job.side_effect = HttpError(
                resp=httplib2.Response({'status': http_error_code}),
                content=b'Forbidden')

            prediction_task = MLEngineBatchPredictionOperator(
                job_id='test_prediction',
                project_id='test-project',
                region=input_with_model['region'],
                data_format=input_with_model['dataFormat'],
                input_paths=input_with_model['inputPaths'],
                output_path=input_with_model['outputPath'],
                model_name=input_with_model['modelName'].split('/')[-1],
                dag=self.dag,
                task_id='test-prediction')

            # Keep only the raising call inside assertRaises: any statement
            # placed after it inside this block would never execute.
            with self.assertRaises(HttpError) as context:
                prediction_task.execute(None)

            mock_hook.assert_called_once_with('google_cloud_default', None)
            hook_instance.create_job.assert_called_once_with(
                'test-project', {
                    'jobId': 'test_prediction',
                    'predictionInput': input_with_model
                }, ANY)

            self.assertEqual(http_error_code, context.exception.resp.status)
    # Example #2 (rating: 0)
    def test_success_with_model(self):
        """Predicting with a named model returns the job's predictionOutput.

        The patched hook's ``get_job`` answers with a 404 ``HttpError`` and
        its ``create_job`` returns a canned success payload. The test checks
        how the hook class was instantiated and how ``create_job`` was
        called. It also checks that ``execute`` hands back the payload's
        ``predictionOutput``.
        """
        hook_target = 'airflow.gcp.operators.mlengine.MLEngineHook'
        with patch(hook_target) as hook_cls:

            model_input = self.INPUT_MISSING_ORIGIN.copy()
            model_input['modelName'] = (
                'projects/test-project/models/test_model')

            expected_response = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
            expected_response['predictionInput'] = model_input

            hook = hook_cls.return_value
            not_found = httplib2.Response({'status': 404})
            hook.get_job.side_effect = HttpError(
                resp=not_found, content=b'some bytes')
            hook.create_job.return_value = expected_response

            operator_kwargs = dict(
                job_id='test_prediction',
                project_id='test-project',
                region=model_input['region'],
                data_format=model_input['dataFormat'],
                input_paths=model_input['inputPaths'],
                output_path=model_input['outputPath'],
                model_name=model_input['modelName'].rsplit('/', 1)[-1],
                dag=self.dag,
                task_id='test-prediction',
            )
            operator = MLEngineBatchPredictionOperator(**operator_kwargs)
            result = operator.execute(None)

            hook_cls.assert_called_once_with('google_cloud_default', None)
            expected_job = {
                'jobId': 'test_prediction',
                'predictionInput': model_input,
            }
            hook.create_job.assert_called_once_with(
                'test-project', expected_job, ANY)
            self.assertEqual(expected_response['predictionOutput'], result)

    def test_success_with_uri(self, mock_hook):
        """Predicting from a model URI returns the job's predictionOutput.

        ``mock_hook`` stands in for the hook class. Its ``get_job`` answers
        with a 404 ``HttpError`` and its ``create_job`` returns a canned
        success payload. ``create_job`` is expected to be called with
        keyword arguments.

        NOTE(review): no ``@patch`` decorator is visible above this method in
        this listing; confirm one injects ``mock_hook``. Also, the sibling
        model-name test expects positional ``create_job`` arguments; confirm
        which call form the operator under test actually uses.
        """
        uri_input = self.INPUT_MISSING_ORIGIN.copy()
        uri_input['uri'] = 'gs://my_bucket/my_models/savedModel'

        expected_response = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
        expected_response['predictionInput'] = uri_input

        hook = mock_hook.return_value
        not_found = httplib2.Response({'status': 404})
        hook.get_job.side_effect = HttpError(
            resp=not_found, content=b'some bytes')
        hook.create_job.return_value = expected_response

        operator = MLEngineBatchPredictionOperator(
            job_id='test_prediction',
            project_id='test-project',
            region=uri_input['region'],
            data_format=uri_input['dataFormat'],
            input_paths=uri_input['inputPaths'],
            output_path=uri_input['outputPath'],
            uri=uri_input['uri'],
            dag=self.dag,
            task_id='test-prediction',
        )
        result = operator.execute(None)

        mock_hook.assert_called_once_with('google_cloud_default', None)
        expected_job = {
            'jobId': 'test_prediction',
            'predictionInput': uri_input,
        }
        hook.create_job.assert_called_once_with(
            project_id='test-project',
            job=expected_job,
            use_existing_job_fn=ANY,
        )
        self.assertEqual(expected_response['predictionOutput'], result)