def _do_batch_prediction(self, input_path: Text,
                             output_path: Text) -> Text:
        """Run an AutoML batch prediction and return where results were written.

        Submits a batch prediction for ``self.model_id`` that reads from
        ``input_path`` and writes to ``output_path`` (or under it, when it is an
        output prefix). The AutoML API may add timestamps to the output path
        specified by the user, so the real output location is read back from
        the response instead of being assumed.

        Args:
          input_path: The path to input data.
          output_path: The output path (for bq://) or output prefix (for gs://)
            passed to the AutoML API to write prediction results.

        Returns:
          The output location reported in the response: the GCS output
          directory when present, otherwise the BigQuery output dataset.

        Raises:
          RuntimeError: When the response contains an error, or when it does
            not report a GCS or BigQuery output location (including when the
            ``metadata`` / ``batchPredictDetails`` / ``outputInfo`` sections
            are missing or empty).
        """
        hook = automl_tables_hook.AutoMLTablesHook(gcp_conn_id=self.conn_id)
        resp = hook.batch_predict(model_id=self.model_id,
                                  input_path=input_path,
                                  output_path=output_path)
        if 'error' in resp:
            raise RuntimeError(f'Error in predictions {resp}')
        # Walk the nested response defensively: a partial response should
        # surface as the documented RuntimeError below, not as a bare KeyError.
        metadata = resp.get('metadata') or {}
        details = metadata.get('batchPredictDetails') or {}
        output_info = details.get('outputInfo') or {}
        # GCS output is checked first, then BigQuery.
        for key in ('gcsOutputDirectory', 'bigqueryOutputDataset'):
            if key in output_info:
                return output_info[key]
        raise RuntimeError(
            f'Output not found in prediction response: {resp}')
Esempio n. 2
0
    def test_batch_predict_fail_error(self):
        """Batch prediction whose finished operation carries an error.

        Submission succeeds, but the status poll reports the operation as done
        with an 'error' field, so batch_predict is expected to raise a
        RuntimeError about the prediction result.
        """
        submit_response = self.mock_post.return_value
        submit_response.status_code = 200
        submit_response.json.return_value = {
            'name': 'test_operation_name',
            'metadata': {},
        }

        status_response = self.mock_get.return_value
        status_response.status_code = 200
        status_response.json.return_value = {
            'done': True,
            'error': 'error message',
        }

        hook = automl_tables_hook.AutoMLTablesHook()
        # poll_wait_time=0 keeps the polling loop from sleeping in tests.
        with self.assertRaises(RuntimeError) as ctx:
            hook.batch_predict(input_path='gs://input',
                               output_path='gs://output',
                               model_id='test_model_id',
                               compute_region='test_region',
                               poll_wait_time=0)

        self.assertIn('Error in prediction result', ctx.exception.args[0])
Esempio n. 3
0
    def test_batch_predict_success(self):
        """Successful batch prediction after two in-progress status polls.

        Checks that the completed operation is returned and that the hook hit
        the expected operation-status and batchPredict endpoints, the latter
        with the expected request body.
        """
        submit = self.mock_post.return_value
        submit.status_code = 200
        submit.json.return_value = {
            'name': 'test_operation_name',
            'metadata': {},
        }

        completed = {
            'done': True,
            'metadata': {
                'batchPredictDetails': {
                    'outputInfo': {
                        'gcsOutputDirectory': 'gs://outputfile',
                    },
                },
            },
        }
        status = self.mock_get.return_value
        status.status_code = 200
        # Two fresh "still running" payloads, then the finished operation.
        status.json.side_effect = (
            [{'done': False, 'metadata': {}} for _ in range(2)] + [completed])

        hook = automl_tables_hook.AutoMLTablesHook()
        # poll_wait_time=0 keeps the polling loop from sleeping in tests.
        pred_result = hook.batch_predict(input_path='gs://input',
                                         output_path='gs://output',
                                         model_id='test_model_id',
                                         compute_region='test_region',
                                         poll_wait_time=0)

        self.assertEqual(pred_result['done'], True)
        self.mock_get.assert_called_with(
            'https://automl.googleapis.com/v1beta1/test_operation_name')
        predict_url = ('https://automl.googleapis.com/v1beta1/projects/'
                       'test_project/locations/test_region/models/'
                       'test_model_id:batchPredict')
        expected_body = {
            'inputConfig': {
                'gcsSource': {'inputUris': ['gs://input']},
            },
            'outputConfig': {
                'gcsDestination': {'outputUriPrefix': 'gs://output'},
            },
        }
        self.mock_post.assert_called_with(predict_url, json=expected_body)
Esempio n. 4
0
    def test_batch_predict_fail(self):
        """Batch prediction whose submit request is rejected.

        Every submission POST answers 400, so batch_predict is expected to
        raise a RuntimeError about calling the AutoML API.
        """
        rejected = self.mock_post.return_value
        rejected.status_code = 400
        rejected.text = 'Bad Request'

        hook = automl_tables_hook.AutoMLTablesHook()
        # poll_wait_time=0 keeps the polling loop from sleeping in tests.
        with self.assertRaises(RuntimeError) as ctx:
            hook.batch_predict(input_path='gs://input',
                               output_path='gs://output',
                               compute_region='test_region',
                               model_id='test_model_id',
                               poll_wait_time=0)

        self.assertIn('Error calling AutoML API', ctx.exception.args[0])
Esempio n. 5
0
 def test_build_predict_payload_bq2bq(self):
     """Request payload for a BigQuery -> BigQuery batch prediction."""
     hook = automl_tables_hook.AutoMLTablesHook()
     payload = hook._build_batch_predict_payload('bq://input', 'bq://output')
     expected = {
         'inputConfig': {
             # NOTE(review): the source key is snake_case 'input_uri' while the
             # destination uses camelCase 'outputUri'; this pins the hook's
             # current output — confirm which spelling is intended.
             'bigquerySource': {'input_uri': 'bq://input'},
         },
         'outputConfig': {
             'bigqueryDestination': {'outputUri': 'bq://output'},
         },
     }
     self.assertEqual(payload, expected)
Esempio n. 6
0
    def test_batch_predict_retry(self):
        """Batch prediction that succeeds after one rejected submission.

        The first submission POST answers 400 and the second is accepted; the
        status poll then reports two in-progress states before completion.
        """
        rejected_submit = mock.Mock(status_code=400, text='Bad Request')
        accepted_submit = mock.Mock(
            status_code=200,
            json=mock.Mock(return_value={
                'name': 'test_operation_name',
                'metadata': {},
            }))
        self.mock_post.side_effect = [rejected_submit, accepted_submit]

        completed = {
            'done': True,
            'metadata': {
                'batchPredictDetails': {
                    'outputInfo': {
                        'gcsOutputDirectory': 'gs://outputfile',
                    },
                },
            },
        }
        status = self.mock_get.return_value
        # Two fresh "still running" payloads, then the finished operation.
        status.json.side_effect = (
            [{'done': False, 'metadata': {}} for _ in range(2)] + [completed])
        status.status_code = 200

        hook = automl_tables_hook.AutoMLTablesHook()
        # poll_wait_time=0 keeps the polling loop from sleeping in tests.
        pred_result = hook.batch_predict(input_path='gs://input',
                                         output_path='gs://output',
                                         model_id='test_model_id',
                                         compute_region='test_region',
                                         poll_wait_time=0)
        self.assertEqual(pred_result['done'], True)
Esempio n. 7
0
 def test_build_predict_payload_gs2gs_multi_inputs(self):
     """Comma-separated gs:// inputs become separate entries in inputUris."""
     hook = automl_tables_hook.AutoMLTablesHook()
     payload = hook._build_batch_predict_payload('gs://input1, gs://input2',
                                                 'gs://output')
     expected_input = {'gcsSource': {'inputUris': ['gs://input1',
                                                   'gs://input2']}}
     expected_output = {'gcsDestination': {'outputUriPrefix': 'gs://output'}}
     self.assertEqual(payload, {
         'inputConfig': expected_input,
         'outputConfig': expected_output,
     })
Esempio n. 8
0
 def test_build_predict_payload_gs2gs(self):
     """Request payload for a Cloud Storage -> Cloud Storage batch prediction."""
     hook = automl_tables_hook.AutoMLTablesHook()
     payload = hook._build_batch_predict_payload('gs://input', 'gs://output')
     expected_input = {'gcsSource': {'inputUris': ['gs://input']}}
     expected_output = {'gcsDestination': {'outputUriPrefix': 'gs://output'}}
     self.assertEqual(payload, {
         'inputConfig': expected_input,
         'outputConfig': expected_output,
     })
Esempio n. 9
0
    def test_batch_predict_timeout(self):
        """Batch prediction that exceeds its timeout before completing.

        The operation would finish on the third poll, but prediction_timeout=0
        is expected to make batch_predict raise a RuntimeError mentioning
        'Timeout'.
        """
        submit = self.mock_post.return_value
        submit.status_code = 200
        submit.json.return_value = {
            'name': 'test_operation_name',
            'metadata': {},
        }

        completed = {
            'done': True,
            'metadata': {
                'batchPredictDetails': {
                    'outputInfo': {
                        'gcsOutputDirectory': 'gs://outputfile',
                    },
                },
            },
        }
        status = self.mock_get.return_value
        status.status_code = 200
        # Two fresh "still running" payloads, then the finished operation.
        status.json.side_effect = (
            [{'done': False, 'metadata': {}} for _ in range(2)] + [completed])

        hook = automl_tables_hook.AutoMLTablesHook()
        # A zero timeout should trip before the operation reports completion.
        with self.assertRaises(RuntimeError) as ctx:
            hook.batch_predict(input_path='gs://input',
                               output_path='gs://output',
                               prediction_timeout=0,
                               model_id='test_model_id',
                               compute_region='test_region',
                               poll_wait_time=0)

        self.assertIn('Timeout', ctx.exception.args[0])
Esempio n. 10
0
    def test_batch_predict_fail_wait(self):
        """Batch prediction whose status poll is rejected.

        Submission succeeds, but polling the operation answers 400, so
        batch_predict is expected to raise a RuntimeError about waiting for
        AutoML.
        """
        submit = self.mock_post.return_value
        submit.status_code = 200
        submit.json.return_value = {
            'name': 'test_operation_name',
            'metadata': {},
        }

        failed_poll = self.mock_get.return_value
        failed_poll.status_code = 400
        failed_poll.text = 'Bad Request'

        hook = automl_tables_hook.AutoMLTablesHook()
        # poll_wait_time=0 keeps the polling loop from sleeping in tests.
        with self.assertRaises(RuntimeError) as ctx:
            hook.batch_predict(input_path='gs://input',
                               output_path='gs://output',
                               model_id='test_model_id',
                               compute_region='test_region',
                               poll_wait_time=0)

        self.assertIn('Error waiting for AutoML', ctx.exception.args[0])
Esempio n. 11
0
 def test_build_predict_payload_invalid_multi_paths(self):
     """Several comma-separated bq:// inputs are rejected with ValueError."""
     hook = automl_tables_hook.AutoMLTablesHook()
     multi_bq_inputs = 'bq://input1, bq://input2'
     # Only Cloud Storage inputs may list more than one path.
     with self.assertRaises(ValueError):
         hook._build_batch_predict_payload(multi_bq_inputs, 'bq://output')
Esempio n. 12
0
 def test_build_predict_payload_invalid_uris(self):
     """An http:// input URI is rejected with ValueError."""
     hook = automl_tables_hook.AutoMLTablesHook()
     unsupported_input = 'http://input'
     with self.assertRaises(ValueError):
         hook._build_batch_predict_payload(unsupported_input, 'bq://output')