def _do_batch_prediction(self, input_path: Text, output_path: Text) -> Text:
    """Do batch prediction.

    Do AutoML batch prediction using data in input_path, write results to
    output_path (or with output prefix). Since the AutoML API may add
    timestamps to the output path specified by the user, the real output path
    must be extracted from the response.

    Args:
      input_path: The path to input data.
      output_path: The output path (for bq://) or output prefix (for gs://)
        passed to AutoML API to write prediction results.

    Returns:
      The output path for the prediction results.

    Raises:
      RuntimeError: When there are errors in the AutoML API response, or when
        the response does not contain an output location.
    """
    hook = automl_tables_hook.AutoMLTablesHook(gcp_conn_id=self.conn_id)
    resp = hook.batch_predict(model_id=self.model_id,
                              input_path=input_path,
                              output_path=output_path)
    if 'error' in resp:
        raise RuntimeError(f'Error in predictions {resp}')

    # A response lacking metadata/batchPredictDetails/outputInfo would
    # otherwise escape as a bare KeyError; surface it as the documented
    # RuntimeError instead, keeping the original error as the cause.
    try:
        output_info = resp['metadata']['batchPredictDetails']['outputInfo']
    except (KeyError, TypeError) as err:
        raise RuntimeError(
            f'Output not found in prediction response: {resp}') from err

    # GCS destinations report the generated output directory, BigQuery
    # destinations the generated dataset; either may differ from output_path.
    for key in ('gcsOutputDirectory', 'bigqueryOutputDataset'):
        if key in output_info:
            return output_info[key]
    raise RuntimeError(f'Output not found in prediction response: {resp}')
def test_batch_predict_fail_error(self):
    """Batch prediction whose finished operation status carries an error."""
    # Submitting the prediction succeeds and yields an operation handle.
    submit_resp = self.mock_post.return_value
    submit_resp.json.return_value = {
        'name': 'test_operation_name',
        'metadata': {},
    }
    submit_resp.status_code = 200

    # Polling reports the operation as done, but with an error attached.
    poll_resp = self.mock_get.return_value
    poll_resp.json.return_value = {'done': True, 'error': 'error message'}
    poll_resp.status_code = 200

    hook = automl_tables_hook.AutoMLTablesHook()
    # poll_wait_time=0: unit tests should not sleep between polls.
    with self.assertRaises(RuntimeError) as ctx:
        hook.batch_predict(
            input_path='gs://input',
            output_path='gs://output',
            model_id='test_model_id',
            compute_region='test_region',
            poll_wait_time=0,
        )
    error_message = ctx.exception.args[0]
    self.assertIn('Error in prediction result', error_message)
def test_batch_predict_success(self):
    """Batch prediction that completes successfully after several polls."""
    # Submitting the prediction succeeds and yields an operation handle.
    submit_resp = self.mock_post.return_value
    submit_resp.json.return_value = {
        'name': 'test_operation_name',
        'metadata': {},
    }
    submit_resp.status_code = 200

    # Two "still running" polls, then a finished operation with a GCS output.
    still_running = [{'done': False, 'metadata': {}} for _ in range(2)]
    finished = {
        'done': True,
        'metadata': {
            'batchPredictDetails': {
                'outputInfo': {'gcsOutputDirectory': 'gs://outputfile'},
            },
        },
    }
    poll_resp = self.mock_get.return_value
    poll_resp.json.side_effect = still_running + [finished]
    poll_resp.status_code = 200

    hook = automl_tables_hook.AutoMLTablesHook()
    # poll_wait_time=0: unit tests should not sleep between polls.
    pred_result = hook.batch_predict(
        input_path='gs://input',
        output_path='gs://output',
        model_id='test_model_id',
        compute_region='test_region',
        poll_wait_time=0,
    )

    expected_post_url = (
        'https://automl.googleapis.com/v1beta1/projects/test_project'
        '/locations/test_region/models/test_model_id:batchPredict')
    expected_payload = {
        'inputConfig': {'gcsSource': {'inputUris': ['gs://input']}},
        'outputConfig': {'gcsDestination': {'outputUriPrefix': 'gs://output'}},
    }
    self.assertEqual(pred_result['done'], True)
    self.mock_get.assert_called_with(
        'https://automl.googleapis.com/v1beta1/test_operation_name')
    self.mock_post.assert_called_with(expected_post_url, json=expected_payload)
def test_batch_predict_fail(self):
    """Batch prediction whose initial submit request is rejected."""
    rejected = self.mock_post.return_value
    rejected.text = 'Bad Request'
    rejected.status_code = 400

    hook = automl_tables_hook.AutoMLTablesHook()
    # poll_wait_time=0: unit tests should not sleep between polls.
    with self.assertRaises(RuntimeError) as ctx:
        hook.batch_predict(
            model_id='test_model_id',
            compute_region='test_region',
            input_path='gs://input',
            output_path='gs://output',
            poll_wait_time=0,
        )
    error_message = ctx.exception.args[0]
    self.assertIn('Error calling AutoML API', error_message)
def test_build_predict_payload_bq2bq(self):
    """Payload for a BigQuery -> BigQuery batch prediction."""
    hook = automl_tables_hook.AutoMLTablesHook()
    payload = hook._build_batch_predict_payload('bq://input', 'bq://output')

    # NOTE(review): 'input_uri' is snake_case while every other payload key in
    # these tests is camelCase (e.g. 'outputUri'); confirm this matches what
    # the hook intends to send.
    expected = {
        'inputConfig': {'bigquerySource': {'input_uri': 'bq://input'}},
        'outputConfig': {'bigqueryDestination': {'outputUri': 'bq://output'}},
    }
    self.assertEqual(payload, expected)
def test_batch_predict_retry(self):
    """Test for a retried prediction.

    The first submit request is rejected; the retried submit is accepted and
    the resulting operation eventually completes.
    """
    rejected = mock.Mock(status_code=400, text='Bad Request')
    accepted = mock.Mock(
        status_code=200,
        json=mock.Mock(return_value={
            'name': 'test_operation_name',
            'metadata': {},
        }),
    )
    self.mock_post.side_effect = [rejected, accepted]

    # Two "still running" polls, then a finished operation with a GCS output.
    still_running = [{'done': False, 'metadata': {}} for _ in range(2)]
    finished = {
        'done': True,
        'metadata': {
            'batchPredictDetails': {
                'outputInfo': {'gcsOutputDirectory': 'gs://outputfile'},
            },
        },
    }
    poll_resp = self.mock_get.return_value
    poll_resp.json.side_effect = still_running + [finished]
    poll_resp.status_code = 200

    hook = automl_tables_hook.AutoMLTablesHook()
    # poll_wait_time=0: unit tests should not sleep between polls.
    pred_result = hook.batch_predict(
        input_path='gs://input',
        output_path='gs://output',
        model_id='test_model_id',
        compute_region='test_region',
        poll_wait_time=0,
    )
    self.assertEqual(pred_result['done'], True)
def test_build_predict_payload_gs2gs_multi_inputs(self):
    """Comma-separated GCS inputs are split into a list of input URIs."""
    hook = automl_tables_hook.AutoMLTablesHook()
    payload = hook._build_batch_predict_payload('gs://input1, gs://input2',
                                                'gs://output')

    expected_inputs = ['gs://input1', 'gs://input2']
    expected = {
        'inputConfig': {'gcsSource': {'inputUris': expected_inputs}},
        'outputConfig': {'gcsDestination': {'outputUriPrefix': 'gs://output'}},
    }
    self.assertEqual(payload, expected)
def test_build_predict_payload_gs2gs(self):
    """Payload for a Cloud Storage -> Cloud Storage batch prediction."""
    hook = automl_tables_hook.AutoMLTablesHook()
    payload = hook._build_batch_predict_payload('gs://input', 'gs://output')

    expected = {
        'inputConfig': {'gcsSource': {'inputUris': ['gs://input']}},
        'outputConfig': {'gcsDestination': {'outputUriPrefix': 'gs://output'}},
    }
    self.assertEqual(payload, expected)
def test_batch_predict_timeout(self):
    """Batch prediction that exceeds its prediction timeout while polling."""
    # Submitting the prediction succeeds and yields an operation handle.
    submit_resp = self.mock_post.return_value
    submit_resp.json.return_value = {
        'name': 'test_operation_name',
        'metadata': {},
    }
    submit_resp.status_code = 200

    # The operation would eventually finish, but the timeout fires first.
    still_running = [{'done': False, 'metadata': {}} for _ in range(2)]
    finished = {
        'done': True,
        'metadata': {
            'batchPredictDetails': {
                'outputInfo': {'gcsOutputDirectory': 'gs://outputfile'},
            },
        },
    }
    poll_resp = self.mock_get.return_value
    poll_resp.json.side_effect = still_running + [finished]
    poll_resp.status_code = 200

    hook = automl_tables_hook.AutoMLTablesHook()
    # prediction_timeout=0 guarantees the request times out.
    with self.assertRaises(RuntimeError) as ctx:
        hook.batch_predict(
            input_path='gs://input',
            output_path='gs://output',
            model_id='test_model_id',
            compute_region='test_region',
            prediction_timeout=0,
            poll_wait_time=0,
        )
    error_message = ctx.exception.args[0]
    self.assertIn('Timeout', error_message)
def test_batch_predict_fail_wait(self):
    """Batch prediction whose status polling request is rejected."""
    # Submitting the prediction succeeds and yields an operation handle.
    submit_resp = self.mock_post.return_value
    submit_resp.json.return_value = {
        'name': 'test_operation_name',
        'metadata': {},
    }
    submit_resp.status_code = 200

    # The status poll itself fails.
    poll_resp = self.mock_get.return_value
    poll_resp.text = 'Bad Request'
    poll_resp.status_code = 400

    hook = automl_tables_hook.AutoMLTablesHook()
    # poll_wait_time=0: unit tests should not sleep between polls.
    with self.assertRaises(RuntimeError) as ctx:
        hook.batch_predict(
            input_path='gs://input',
            output_path='gs://output',
            model_id='test_model_id',
            compute_region='test_region',
            poll_wait_time=0,
        )
    error_message = ctx.exception.args[0]
    self.assertIn('Error waiting for AutoML', error_message)
def test_build_predict_payload_invalid_multi_paths(self):
    """Several comma-separated BigQuery inputs are rejected."""
    hook = automl_tables_hook.AutoMLTablesHook()
    build_payload = hook._build_batch_predict_payload
    # Multiple input paths are only supported for Cloud Storage inputs.
    self.assertRaises(ValueError, build_payload,
                      'bq://input1, bq://input2', 'bq://output')
def test_build_predict_payload_invalid_uris(self):
    """An input URI that is neither gs:// nor bq:// is rejected."""
    hook = automl_tables_hook.AutoMLTablesHook()
    build_payload = hook._build_batch_predict_payload
    self.assertRaises(ValueError, build_payload, 'http://input', 'bq://output')