def test_execute_success_output_to_bq(self):
    """Check that execute() forwards the BigQuery dataset the prediction reports.

    The mocked prediction response is finished (``done``). Its metadata
    names ``bq://output_path`` under ``batchPredictDetails.outputInfo.
    bigqueryOutputDataset``. The test then checks three things:
    predict received the constructor's model id and paths; the input-path
    mock received the execution context; and the output-path mock received
    the context together with that dataset URI.

    NOTE(review): ``self.mock_predict``, ``self.mock_input_path`` and
    ``self.mock_output_path`` are created outside this chunk (presumably in
    ``setUp``); confirm what each one patches.
    """
    bq_dataset = 'bq://output_path'
    output_info = {'bigqueryOutputDataset': bq_dataset}
    predict_details = {'outputInfo': output_info}
    self.mock_predict.return_value = {
        'done': True,
        'metadata': {'batchPredictDetails': predict_details},
    }

    operator = automl_pred_op.AutoMLTablesBatchPredictionOperator(
        task_id='test',
        model_id='my_test_model_id',
        input_path='gs://input',
        output_key='output_key',
        output_path='gs://output',
    )
    operator.execute(context=dict(foo='bar'))

    # A separate object that is never handed to the operator, so the
    # comparison is always against the pristine context contents.
    expected_context = {'foo': 'bar'}
    self.mock_predict.assert_called_with(
        model_id='my_test_model_id',
        input_path='gs://input',
        output_path='gs://output',
    )
    self.mock_input_path.assert_called_with(expected_context)
    self.mock_output_path.assert_called_with(expected_context, bq_dataset)
 def test_init_fail_no_input_path_or_input_key(self):
     """Expect ValueError when neither input_path nor input_key is supplied.

     Only the task id, model id, output key and output path are passed to
     the operator's constructor.
     """
     ctor_kwargs = {
         'task_id': 'test',
         'model_id': 'my_test_model_id',
         'output_key': 'output_key',
         'output_path': 'gs://output',
     }
     with self.assertRaises(ValueError):
         automl_pred_op.AutoMLTablesBatchPredictionOperator(**ctor_kwargs)
 def test_init_success(self):
     """Check that the operator can be built with input_key instead of input_path.

     The only assertion is that construction yields a non-None object.
     """
     self.assertIsNotNone(
         automl_pred_op.AutoMLTablesBatchPredictionOperator(
             task_id='test',
             model_id='my_test_model_id',
             input_key='input_key',
             output_key='output_key',
             output_path='gs://output',
         )
     )
    def test_execute_raise_error_when_prediction_fail(self):
        """Expect RuntimeError when the finished prediction carries an error.

        The mocked response is ``done`` and contains an ``error`` entry.
        Beyond the exception, the test checks two things: predict was still
        called with the constructor's model id and paths, and the input-path
        mock received the context. The output-path mock must never be
        called.

        NOTE(review): the ``self.mock_*`` fixtures are created outside this
        chunk (presumably in ``setUp``).
        """
        self.mock_predict.return_value = dict(done=True, error='error message')

        operator = automl_pred_op.AutoMLTablesBatchPredictionOperator(
            task_id='test',
            model_id='my_test_model_id',
            input_path='gs://input',
            output_key='output_key',
            output_path='gs://output',
        )
        with self.assertRaises(RuntimeError):
            operator.execute(context=dict(foo='bar'))

        # Fresh object, never passed to the operator.
        expected_context = {'foo': 'bar'}
        self.mock_predict.assert_called_with(
            model_id='my_test_model_id',
            input_path='gs://input',
            output_path='gs://output',
        )
        self.mock_input_path.assert_called_with(expected_context)
        self.mock_output_path.assert_not_called()