예제 #1
0
    def testHttpError(self):
        """A 403 HttpError raised by the hook propagates out of execute().

        The mocked hook's create_job raises an HttpError with status 403.
        The test checks that execute() re-raises it with the status intact,
        and that the hook and create_job were called with the expected
        arguments before the failure.
        """
        http_error_code = 403

        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
                as mock_hook:
            input_with_model = self.INPUT_MISSING_ORIGIN.copy()
            # The operator only receives project_id='test-project' and the
            # short model name, so the request it builds is expected to carry
            # a 'projects/test-project/...' path, the same as the success-path
            # model test that uses identical constructor arguments.
            input_with_model['modelName'] = \
                'projects/test-project/models/test_model'

            hook_instance = mock_hook.return_value
            hook_instance.create_job.side_effect = errors.HttpError(
                resp=httplib2.Response({'status': http_error_code}),
                content=b'Forbidden')

            prediction_task = CloudMLBatchPredictionOperator(
                job_id='test_prediction',
                project_id='test-project',
                region=input_with_model['region'],
                data_format=input_with_model['dataFormat'],
                input_paths=input_with_model['inputPaths'],
                output_path=input_with_model['outputPath'],
                model_name=input_with_model['modelName'].split('/')[-1],
                dag=self.dag,
                task_id='test-prediction')

            # Only execute() is expected to raise. Call assertions must live
            # outside this block: statements after the raising call inside
            # it never run.
            with self.assertRaises(errors.HttpError) as context:
                prediction_task.execute(None)

            # Mocks record a call even when its side_effect raises.
            mock_hook.assert_called_with('google_cloud_default', None)
            hook_instance.create_job.assert_called_with(
                'test-project', {
                    'jobId': 'test_prediction',
                    'predictionInput': input_with_model
                }, ANY)

            self.assertEqual(http_error_code, context.exception.resp.status)
예제 #2
0
    def testSuccessWithModel(self):
        """Submitting a prediction job by model name returns its output.

        The mocked hook's get_job raises a 404 HttpError (presumably to model
        a job that does not exist yet), and create_job returns a success
        message. The test checks the hook construction arguments, that
        create_job is called exactly once with the expected request, and that
        execute() returns the message's 'predictionOutput'.
        """
        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
                as mock_hook:

            input_with_model = self.INPUT_MISSING_ORIGIN.copy()
            input_with_model['modelName'] = \
                'projects/test-project/models/test_model'
            success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
            success_message['predictionInput'] = input_with_model

            hook_instance = mock_hook.return_value
            hook_instance.get_job.side_effect = errors.HttpError(
                resp=httplib2.Response({'status': 404}), content=b'some bytes')
            hook_instance.create_job.return_value = success_message

            prediction_task = CloudMLBatchPredictionOperator(
                job_id='test_prediction',
                project_id='test-project',
                region=input_with_model['region'],
                data_format=input_with_model['dataFormat'],
                input_paths=input_with_model['inputPaths'],
                output_path=input_with_model['outputPath'],
                # Only the short name is passed; the full path above is what
                # the operator is expected to rebuild from project_id.
                model_name=input_with_model['modelName'].split('/')[-1],
                dag=self.dag,
                task_id='test-prediction')
            prediction_output = prediction_task.execute(None)

            mock_hook.assert_called_with('google_cloud_default', None)
            hook_instance.create_job.assert_called_once_with(
                'test-project', {
                    'jobId': 'test_prediction',
                    'predictionInput': input_with_model
                }, ANY)
            self.assertEqual(success_message['predictionOutput'],
                             prediction_output)
예제 #3
0
    def testHttpError(self):
        """A 403 HttpError raised by the hook propagates out of execute().

        The mocked hook's create_job raises an HttpError with status 403.
        The test checks that execute() re-raises it with the status intact,
        and that the hook and create_job were called with the expected
        arguments before the failure.
        """
        http_error_code = 403

        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
                as mock_hook:
            input_with_model = self.INPUT_MISSING_ORIGIN.copy()
            # The operator only receives project_id='test-project' and the
            # short model name, so the request it builds is expected to carry
            # a 'projects/test-project/...' path, the same as the success-path
            # model test that uses identical constructor arguments.
            input_with_model['modelName'] = \
                'projects/test-project/models/test_model'

            hook_instance = mock_hook.return_value
            hook_instance.create_job.side_effect = errors.HttpError(
                resp=httplib2.Response({
                    'status': http_error_code
                }), content=b'Forbidden')

            prediction_task = CloudMLBatchPredictionOperator(
                job_id='test_prediction',
                project_id='test-project',
                region=input_with_model['region'],
                data_format=input_with_model['dataFormat'],
                input_paths=input_with_model['inputPaths'],
                output_path=input_with_model['outputPath'],
                model_name=input_with_model['modelName'].split('/')[-1],
                dag=self.dag,
                task_id='test-prediction')

            # Only execute() is expected to raise. Call assertions must live
            # outside this block: statements after the raising call inside
            # it never run.
            with self.assertRaises(errors.HttpError) as context:
                prediction_task.execute(None)

            # Mocks record a call even when its side_effect raises.
            mock_hook.assert_called_with('google_cloud_default', None)
            hook_instance.create_job.assert_called_with(
                'test-project',
                {
                    'jobId': 'test_prediction',
                    'predictionInput': input_with_model
                }, ANY)

            self.assertEqual(http_error_code, context.exception.resp.status)
    def testSuccessWithVersion(self):
        """Submitting a prediction job by model version returns its output.

        The mocked hook's get_job raises a 404 HttpError (presumably to model
        a job that does not exist yet), and create_job returns a success
        message. The test checks the hook construction arguments, the request
        passed to create_job, and that execute() returns the message's
        'predictionOutput'.
        """
        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
                as mock_hook:

            # NOTE(review): the other tests in this listing read these
            # fixtures via self.; here they are bare names, so presumably
            # module-level constants in this file. Confirm.
            input_with_version = INPUT_MISSING_ORIGIN.copy()
            input_with_version['versionName'] = \
                'projects/test-project/models/test_model/versions/test_version'
            success_message = SUCCESS_MESSAGE_MISSING_INPUT.copy()
            success_message['predictionInput'] = input_with_version

            hook_instance = mock_hook.return_value
            hook_instance.get_job.side_effect = errors.HttpError(
                resp=httplib2.Response({
                    'status': 404
                }), content=b'some bytes')
            hook_instance.create_job.return_value = success_message

            prediction_task = CloudMLBatchPredictionOperator(
                job_id='test_prediction',
                project_id='test-project',
                region=input_with_version['region'],
                data_format=input_with_version['dataFormat'],
                input_paths=input_with_version['inputPaths'],
                output_path=input_with_version['outputPath'],
                # Path segments: .../models/<model>/versions/<version>.
                model_name=input_with_version['versionName'].split('/')[-3],
                version_name=input_with_version['versionName'].split('/')[-1],
                dag=self.dag,
                task_id='test-prediction')
            prediction_output = prediction_task.execute(None)

            mock_hook.assert_called_with('google_cloud_default', None)
            # NOTE(review): the sibling tests expect a third (ANY) argument to
            # create_job; this one expects two. Confirm against the operator
            # version this test targets.
            hook_instance.create_job.assert_called_with(
                'test-project',
                {
                    'jobId': 'test_prediction',
                    'predictionInput': input_with_version
                })
            self.assertEqual(
                success_message['predictionOutput'],
                prediction_output)
예제 #5
0
    def testSuccessWithURI(self):
        """Submitting a prediction job by saved-model URI returns its output.

        The mocked hook's get_job raises a 404 HttpError (presumably to model
        a job that does not exist yet), and create_job returns a success
        message. The test checks the hook construction arguments, the request
        passed to create_job, and that execute() returns the message's
        'predictionOutput'.
        """
        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
                as mock_hook:

            input_with_uri = self.INPUT_MISSING_ORIGIN.copy()
            input_with_uri['uri'] = 'gs://my_bucket/my_models/savedModel'
            success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
            success_message['predictionInput'] = input_with_uri

            hook_instance = mock_hook.return_value
            hook_instance.get_job.side_effect = errors.HttpError(
                resp=httplib2.Response({
                    'status': 404
                }), content=b'some bytes')
            hook_instance.create_job.return_value = success_message

            prediction_task = CloudMLBatchPredictionOperator(
                job_id='test_prediction',
                project_id='test-project',
                region=input_with_uri['region'],
                data_format=input_with_uri['dataFormat'],
                input_paths=input_with_uri['inputPaths'],
                output_path=input_with_uri['outputPath'],
                uri=input_with_uri['uri'],
                dag=self.dag,
                task_id='test-prediction')
            prediction_output = prediction_task.execute(None)

            mock_hook.assert_called_with('google_cloud_default', None)
            hook_instance.create_job.assert_called_with(
                'test-project',
                {
                    'jobId': 'test_prediction',
                    'predictionInput': input_with_uri
                }, ANY)
            self.assertEqual(
                success_message['predictionOutput'],
                prediction_output)