Ejemplo n.º 1
0
    def test_http_error(self, mock_hook):
        """An HttpError raised by the hook's ``create_job`` propagates out of ``execute``.

        ``create_job`` is mocked to fail with a 403 response. The test checks that
        the operator re-raises that error unchanged, and that the hook was
        constructed and ``create_job`` was called with the expected request
        before the failure surfaced.
        """
        http_error_code = 403
        input_with_model = self.INPUT_MISSING_ORIGIN.copy()
        # The operator builds the full model path from ``project_id`` and the short
        # ``model_name`` (compare test_success_with_model), so the expected path
        # must use the same project id that is passed to the operator below.
        input_with_model['modelName'] = 'projects/test-project/models/test_model'

        hook_instance = mock_hook.return_value
        hook_instance.create_job.side_effect = HttpError(
            resp=httplib2.Response({'status': http_error_code}), content=b'Forbidden'
        )

        prediction_task = MLEngineStartBatchPredictionJobOperator(
            job_id='test_prediction',
            project_id='test-project',
            region=input_with_model['region'],
            data_format=input_with_model['dataFormat'],
            input_paths=input_with_model['inputPaths'],
            output_path=input_with_model['outputPath'],
            model_name=input_with_model['modelName'].split('/')[-1],
            dag=self.dag,
            task_id='test-prediction',
        )
        # Only the call expected to raise goes inside assertRaises; any statement
        # placed after it inside the block would silently never execute.
        with self.assertRaises(HttpError) as context:
            prediction_task.execute(None)

        mock_hook.assert_called_once_with(
            'google_cloud_default', None, impersonation_chain=None,
        )
        hook_instance.create_job.assert_called_once_with(
            project_id='test-project',
            job={'jobId': 'test_prediction', 'predictionInput': input_with_model},
            use_existing_job_fn=ANY,
        )
        self.assertEqual(http_error_code, context.exception.resp.status)
Ejemplo n.º 2
0
    def test_success_with_uri(self, mock_hook):
        """Starting a prediction job from a model URI returns the job's output.

        ``get_job`` is stubbed to answer with a 404 and ``create_job`` returns a
        canned success payload whose ``predictionInput`` carries the URI. The
        test verifies how the hook is built, the exact job request it receives,
        and that ``execute`` hands back the payload's ``predictionOutput``.
        """
        fields = self.INPUT_MISSING_ORIGIN.copy()
        fields['uri'] = 'gs://my_bucket/my_models/savedModel'
        canned_response = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
        canned_response['predictionInput'] = fields

        not_found = HttpError(
            resp=httplib2.Response({'status': 404}), content=b'some bytes'
        )
        hook = mock_hook.return_value
        hook.get_job.side_effect = not_found
        hook.create_job.return_value = canned_response

        operator = MLEngineStartBatchPredictionJobOperator(
            task_id='test-prediction',
            dag=self.dag,
            job_id='test_prediction',
            project_id='test-project',
            region=fields['region'],
            data_format=fields['dataFormat'],
            input_paths=fields['inputPaths'],
            output_path=fields['outputPath'],
            uri=fields['uri'],
        )
        result = operator.execute(None)

        expected_job = {'jobId': 'test_prediction', 'predictionInput': fields}
        mock_hook.assert_called_once_with(
            'google_cloud_default', None, impersonation_chain=None,
        )
        hook.create_job.assert_called_once_with(
            project_id='test-project',
            job=expected_job,
            use_existing_job_fn=ANY,
        )
        self.assertEqual(canned_response['predictionOutput'], result)
Ejemplo n.º 3
0
    def test_success_with_model(self, mock_hook):
        """Starting a prediction job from a model name (with labels) returns its output.

        The operator receives only the short model name. The test checks that the
        job request sent to ``create_job`` contains the full
        ``projects/<project>/models/<name>`` path and the labels, and that
        ``execute`` returns the canned ``predictionOutput``.
        """
        input_with_model = self.INPUT_MISSING_ORIGIN.copy()
        input_with_model['modelName'] = 'projects/test-project/models/test_model'
        success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
        success_message['predictionInput'] = input_with_model

        hook_instance = mock_hook.return_value
        hook_instance.get_job.side_effect = HttpError(
            resp=httplib2.Response({'status': 404}), content=b'some bytes'
        )
        hook_instance.create_job.return_value = success_message

        prediction_task = MLEngineStartBatchPredictionJobOperator(
            job_id='test_prediction',
            project_id='test-project',
            region=input_with_model['region'],
            data_format=input_with_model['dataFormat'],
            input_paths=input_with_model['inputPaths'],
            output_path=input_with_model['outputPath'],
            model_name=input_with_model['modelName'].split('/')[-1],
            labels={'some': 'labels'},
            dag=self.dag,
            task_id='test-prediction',
        )
        prediction_output = prediction_task.execute(None)

        # Matches the hook-construction assertion used by the other tests here;
        # assert_called_once_with is an exact match, so the keyword must be present.
        mock_hook.assert_called_once_with(
            'google_cloud_default', None, impersonation_chain=None,
        )
        hook_instance.create_job.assert_called_once_with(
            project_id='test-project',
            job={
                'jobId': 'test_prediction',
                'labels': {'some': 'labels'},
                'predictionInput': input_with_model,
            },
            use_existing_job_fn=ANY,
        )
        self.assertEqual(success_message['predictionOutput'], prediction_output)