def test_http_error(self):
    """An HttpError raised by ``create_job`` must propagate out of ``execute``.

    ``MLEngineHook`` is patched so that ``create_job`` raises a 403
    ``HttpError``. The test checks that:

    * ``execute`` lets the error escape,
    * the hook was constructed with the default connection id and asked to
      create the expected prediction job, and
    * the escaped error still carries the original HTTP status.
    """
    http_error_code = 403

    with patch('airflow.gcp.operators.mlengine.MLEngineHook') \
            as mock_hook:
        input_with_model = self.INPUT_MISSING_ORIGIN.copy()
        # The operator only receives project_id='test-project' and the bare
        # model name, so the expected modelName has to use that same project
        # (the same expectation test_success_with_model uses).
        input_with_model['modelName'] = \
            'projects/test-project/models/test_model'

        hook_instance = mock_hook.return_value
        hook_instance.create_job.side_effect = HttpError(
            resp=httplib2.Response({'status': http_error_code}),
            content=b'Forbidden')

        prediction_task = MLEngineBatchPredictionOperator(
            job_id='test_prediction',
            project_id='test-project',
            region=input_with_model['region'],
            data_format=input_with_model['dataFormat'],
            input_paths=input_with_model['inputPaths'],
            output_path=input_with_model['outputPath'],
            model_name=input_with_model['modelName'].split('/')[-1],
            dag=self.dag,
            task_id='test-prediction')

        # Keep only the raising call inside assertRaises: any statement placed
        # after it in the same block would be skipped once the error is raised.
        with self.assertRaises(HttpError) as context:
            prediction_task.execute(None)

        mock_hook.assert_called_once_with('google_cloud_default', None)
        hook_instance.create_job.assert_called_once_with(
            'test-project',
            {
                'jobId': 'test_prediction',
                'predictionInput': input_with_model,
            },
            ANY)

    self.assertEqual(http_error_code, context.exception.resp.status)
def test_success_with_model(self):
    """A model-based batch prediction is submitted and its output returned.

    ``MLEngineHook`` is patched:

    * ``get_job`` raises a 404 ``HttpError``.
    * ``create_job`` returns a canned success message whose
      ``predictionInput`` is the model-based input.

    The test verifies that:

    * the hook is built with the default connection id,
    * ``create_job`` receives the expected job request, and
    * ``execute`` returns the message's ``predictionOutput``.
    """
    with patch('airflow.gcp.operators.mlengine.MLEngineHook') as hook_cls:
        model_input = self.INPUT_MISSING_ORIGIN.copy()
        model_input['modelName'] = 'projects/test-project/models/test_model'

        canned_response = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
        canned_response['predictionInput'] = model_input

        hook = hook_cls.return_value
        hook.get_job.side_effect = HttpError(
            resp=httplib2.Response({'status': 404}),
            content=b'some bytes')
        hook.create_job.return_value = canned_response

        operator_kwargs = {
            'job_id': 'test_prediction',
            'project_id': 'test-project',
            'region': model_input['region'],
            'data_format': model_input['dataFormat'],
            'input_paths': model_input['inputPaths'],
            'output_path': model_input['outputPath'],
            # Pass only the trailing segment of the model resource path.
            'model_name': model_input['modelName'].rsplit('/', 1)[-1],
            'dag': self.dag,
            'task_id': 'test-prediction',
        }
        result = MLEngineBatchPredictionOperator(**operator_kwargs).execute(None)

        expected_job = {
            'jobId': 'test_prediction',
            'predictionInput': model_input,
        }
        hook_cls.assert_called_once_with('google_cloud_default', None)
        hook.create_job.assert_called_once_with('test-project', expected_job, ANY)

        self.assertEqual(canned_response['predictionOutput'], result)
def test_success_with_uri(self, mock_hook):
    """A URI-based batch prediction is submitted and its output returned.

    ``mock_hook`` stands in for the ``MLEngineHook`` class. It is presumably
    injected by a ``@patch`` decorator applied outside this view.

    Setup:

    * ``get_job`` raises a 404 ``HttpError``.
    * ``create_job`` returns a canned success message whose
      ``predictionInput`` is the URI-based input.

    The test verifies that:

    * the hook is built with the default connection id,
    * ``create_job`` is called with keyword arguments describing the expected
      job, and
    * ``execute`` returns the message's ``predictionOutput``.
    """
    uri_input = self.INPUT_MISSING_ORIGIN.copy()
    uri_input['uri'] = 'gs://my_bucket/my_models/savedModel'

    canned_response = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
    canned_response['predictionInput'] = uri_input

    hook = mock_hook.return_value
    not_found = HttpError(
        resp=httplib2.Response({'status': 404}),
        content=b'some bytes')
    hook.get_job.side_effect = not_found
    hook.create_job.return_value = canned_response

    operator = MLEngineBatchPredictionOperator(
        task_id='test-prediction',
        dag=self.dag,
        job_id='test_prediction',
        project_id='test-project',
        region=uri_input['region'],
        data_format=uri_input['dataFormat'],
        input_paths=uri_input['inputPaths'],
        output_path=uri_input['outputPath'],
        uri=uri_input['uri'])
    result = operator.execute(None)

    mock_hook.assert_called_once_with('google_cloud_default', None)
    expected_job = {
        'jobId': 'test_prediction',
        'predictionInput': uri_input,
    }
    hook.create_job.assert_called_once_with(
        project_id='test-project',
        job=expected_job,
        use_existing_job_fn=ANY)

    self.assertEqual(canned_response['predictionOutput'], result)