def testHttpError(self):
    """An HttpError raised by CloudMLHook.create_job propagates out of execute.

    After the error surfaces, the test also checks how the hook was built and
    what request was handed to create_job.
    """
    http_error_code = 403

    with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
            as mock_hook:
        input_with_model = self.INPUT_MISSING_ORIGIN.copy()
        # Keep the project in the resource name equal to project_id below.
        # The success-path tests expect the request to carry
        # 'projects/<project_id>/models/<model>', so a different project here
        # (the old 'experimental') could never match the recorded call.
        input_with_model['modelName'] = \
            'projects/test-project/models/test_model'

        hook_instance = mock_hook.return_value
        hook_instance.create_job.side_effect = errors.HttpError(
            resp=httplib2.Response({'status': http_error_code}),
            content=b'Forbidden')

        with self.assertRaises(errors.HttpError) as context:
            prediction_task = CloudMLBatchPredictionOperator(
                job_id='test_prediction',
                project_id='test-project',
                region=input_with_model['region'],
                data_format=input_with_model['dataFormat'],
                input_paths=input_with_model['inputPaths'],
                output_path=input_with_model['outputPath'],
                model_name=input_with_model['modelName'].split('/')[-1],
                dag=self.dag,
                task_id='test-prediction')
            prediction_task.execute(None)

        # These checks must live outside the assertRaises block: statements
        # after the raising call inside it are never executed.
        mock_hook.assert_called_with('google_cloud_default', None)
        hook_instance.create_job.assert_called_with(
            'test-project', {
                'jobId': 'test_prediction',
                'predictionInput': input_with_model
            }, ANY)

        # assertEquals is a deprecated alias removed in Python 3.12.
        self.assertEqual(http_error_code, context.exception.resp.status)
def testSuccessWithModel(self):
    """A bare model name is sent as a full 'modelName' in the prediction input.

    get_job is stubbed to fail with a 404 and create_job returns a canned
    success message. execute() must return that message's 'predictionOutput'.
    """
    with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
            as mock_hook:
        input_with_model = self.INPUT_MISSING_ORIGIN.copy()
        input_with_model['modelName'] = \
            'projects/test-project/models/test_model'
        success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
        success_message['predictionInput'] = input_with_model

        hook_instance = mock_hook.return_value
        # Presumably signals "no existing job with this id"; confirm against
        # the operator/hook implementation.
        hook_instance.get_job.side_effect = errors.HttpError(
            resp=httplib2.Response({'status': 404}),
            content=b'some bytes')
        hook_instance.create_job.return_value = success_message

        prediction_task = CloudMLBatchPredictionOperator(
            job_id='test_prediction',
            project_id='test-project',
            region=input_with_model['region'],
            data_format=input_with_model['dataFormat'],
            input_paths=input_with_model['inputPaths'],
            output_path=input_with_model['outputPath'],
            model_name=input_with_model['modelName'].split('/')[-1],
            dag=self.dag,
            task_id='test-prediction')
        prediction_output = prediction_task.execute(None)

        mock_hook.assert_called_with('google_cloud_default', None)
        hook_instance.create_job.assert_called_once_with(
            'test-project', {
                'jobId': 'test_prediction',
                'predictionInput': input_with_model
            }, ANY)
        # assertEquals is a deprecated alias removed in Python 3.12.
        self.assertEqual(success_message['predictionOutput'],
                         prediction_output)
def testHttpError(self):
    """An HttpError raised by CloudMLHook.create_job propagates out of execute.

    After the error surfaces, the test also checks how the hook was built and
    what request was handed to create_job.

    NOTE(review): this chunk defines testHttpError twice. If both live in the
    same class, only the later definition is collected by the test runner.
    """
    http_error_code = 403

    with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
            as mock_hook:
        input_with_model = self.INPUT_MISSING_ORIGIN.copy()
        # Keep the project in the resource name equal to project_id below.
        # The success-path tests expect the request to carry
        # 'projects/<project_id>/models/<model>', so a different project here
        # (the old 'experimental') could never match the recorded call.
        input_with_model['modelName'] = \
            'projects/test-project/models/test_model'

        hook_instance = mock_hook.return_value
        hook_instance.create_job.side_effect = errors.HttpError(
            resp=httplib2.Response({
                'status': http_error_code
            }), content=b'Forbidden')

        with self.assertRaises(errors.HttpError) as context:
            prediction_task = CloudMLBatchPredictionOperator(
                job_id='test_prediction',
                project_id='test-project',
                region=input_with_model['region'],
                data_format=input_with_model['dataFormat'],
                input_paths=input_with_model['inputPaths'],
                output_path=input_with_model['outputPath'],
                model_name=input_with_model['modelName'].split('/')[-1],
                dag=self.dag,
                task_id='test-prediction')
            prediction_task.execute(None)

        # These checks must live outside the assertRaises block: statements
        # after the raising call inside it are never executed.
        mock_hook.assert_called_with('google_cloud_default', None)
        hook_instance.create_job.assert_called_with(
            'test-project', {
                'jobId': 'test_prediction',
                'predictionInput': input_with_model
            }, ANY)

        # assertEquals is a deprecated alias removed in Python 3.12.
        self.assertEqual(http_error_code, context.exception.resp.status)
def testSuccessWithVersion(self):
    """Model plus version names are sent as a full 'versionName' input.

    get_job is stubbed to fail with a 404 and create_job returns a canned
    success message. execute() must return that message's 'predictionOutput'.
    """
    with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
            as mock_hook:
        # Fixtures are read through self, matching the sibling tests; the
        # bare names used previously are not guaranteed to exist at module
        # scope.
        input_with_version = self.INPUT_MISSING_ORIGIN.copy()
        input_with_version['versionName'] = \
            'projects/test-project/models/test_model/versions/test_version'
        success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
        success_message['predictionInput'] = input_with_version

        hook_instance = mock_hook.return_value
        # Presumably signals "no existing job with this id"; confirm against
        # the operator/hook implementation.
        hook_instance.get_job.side_effect = errors.HttpError(
            resp=httplib2.Response({
                'status': 404
            }), content=b'some bytes')
        hook_instance.create_job.return_value = success_message

        version_parts = input_with_version['versionName'].split('/')
        prediction_task = CloudMLBatchPredictionOperator(
            job_id='test_prediction',
            project_id='test-project',
            region=input_with_version['region'],
            data_format=input_with_version['dataFormat'],
            input_paths=input_with_version['inputPaths'],
            output_path=input_with_version['outputPath'],
            model_name=version_parts[-3],
            version_name=version_parts[-1],
            dag=self.dag,
            task_id='test-prediction')
        prediction_output = prediction_task.execute(None)

        mock_hook.assert_called_with('google_cloud_default', None)
        # create_job receives a third positional argument, matched with ANY as
        # in the sibling tests. Without it this assertion can never match.
        hook_instance.create_job.assert_called_with(
            'test-project', {
                'jobId': 'test_prediction',
                'predictionInput': input_with_version
            }, ANY)
        # assertEquals is a deprecated alias removed in Python 3.12.
        self.assertEqual(
            success_message['predictionOutput'], prediction_output)
def testSuccessWithURI(self):
    """A model URI is passed through unchanged as 'uri' in the input.

    get_job is stubbed to fail with a 404 and create_job returns a canned
    success message. execute() must return that message's 'predictionOutput'.
    """
    with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
            as mock_hook:
        input_with_uri = self.INPUT_MISSING_ORIGIN.copy()
        input_with_uri['uri'] = 'gs://my_bucket/my_models/savedModel'
        success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
        success_message['predictionInput'] = input_with_uri

        hook_instance = mock_hook.return_value
        # Presumably signals "no existing job with this id"; confirm against
        # the operator/hook implementation.
        hook_instance.get_job.side_effect = errors.HttpError(
            resp=httplib2.Response({
                'status': 404
            }), content=b'some bytes')
        hook_instance.create_job.return_value = success_message

        prediction_task = CloudMLBatchPredictionOperator(
            job_id='test_prediction',
            project_id='test-project',
            region=input_with_uri['region'],
            data_format=input_with_uri['dataFormat'],
            input_paths=input_with_uri['inputPaths'],
            output_path=input_with_uri['outputPath'],
            uri=input_with_uri['uri'],
            dag=self.dag,
            task_id='test-prediction')
        prediction_output = prediction_task.execute(None)

        mock_hook.assert_called_with('google_cloud_default', None)
        hook_instance.create_job.assert_called_with(
            'test-project', {
                'jobId': 'test_prediction',
                'predictionInput': input_with_uri
            }, ANY)
        # assertEquals is a deprecated alias removed in Python 3.12.
        self.assertEqual(
            success_message['predictionOutput'], prediction_output)