def load_from_temp_to_permanent_table(bq_client: BigQueryClientImpl,
                                      project_id: str) -> None:
    """Query temporary table and persist view to permanent table.

    Appends the de-duplicated contents of the temp table into the final
    destination table, logs how many new rows landed, and drops the temp table.
    """
    destination_dataset_ref = bigquery.DatasetReference(
        project=project_id,
        dataset_id=DATASET_ID,
    )
    # Row count before the append, used below to report only newly-added rows.
    rows_before_insert = bq_client.get_table(
        dataset_ref=destination_dataset_ref,
        table_id=FINAL_DESTINATION_TABLE,
    ).num_rows

    dedup_query = INSERT_QUERY_TEMPLATE.format(
        project_id=project_id,
        dataset_id=DATASET_ID,
        temp_table=TEMP_DESTINATION_TABLE,
        final_table=FINAL_DESTINATION_TABLE,
    )

    insert_job = bq_client.insert_into_table_from_query(
        destination_dataset_id=DATASET_ID,
        destination_table_id=FINAL_DESTINATION_TABLE,
        query=dedup_query,
        write_disposition=WriteDisposition.WRITE_APPEND,
    )

    # Block until the append completes before counting rows / deleting the temp table.
    insert_job_result = insert_job.result()

    logging.info(
        "Loaded [%d] non-duplicate rows into table [%s]",
        (insert_job_result.total_rows - rows_before_insert),
        FINAL_DESTINATION_TABLE,
    )

    # The temp table has served its purpose once its rows are persisted.
    bq_client.delete_table(dataset_id=DATASET_ID,
                           table_id=TEMP_DESTINATION_TABLE)
def _decommission_dataflow_metric_table(bq_client: BigQueryClientImpl,
                                        table_ref: TableListItem) -> None:
    """Decommissions a deprecated Dataflow metric table. Moves all remaining rows
    to cold storage and deletes the table in the DATAFLOW_METRICS_DATASET."""
    table_id = table_ref.table_id
    logging.info("Decommissioning Dataflow metric table: [%s]", table_id)

    source_dataset = DATAFLOW_METRICS_DATASET

    # Archive every remaining row to the cold storage dataset.
    archive_query = (
        """SELECT * FROM `{project_id}.{dataflow_metrics_dataset}.{table_id}`"""
        .format(
            project_id=table_ref.project,
            dataflow_metrics_dataset=source_dataset,
            table_id=table_id,
        ))

    archive_job = bq_client.insert_into_table_from_query(
        destination_dataset_id=dataflow_config.DATAFLOW_METRICS_COLD_STORAGE_DATASET,
        destination_table_id=table_id,
        query=archive_query,
        allow_field_additions=True,
        write_disposition=WriteDisposition.WRITE_APPEND,
    )

    # Block until the archive copy finishes so rows are never deleted unarchived.
    archive_job.result()

    bq_client.delete_table(dataset_id=source_dataset, table_id=table_id)
# Example no. 3 (extraction artifact; commented out to keep the file parseable)
def _copy_regional_dataset_to_multi_region(
        config: CloudSqlToBQConfig,
        dataset_override_prefix: Optional[str]) -> None:
    """Copies the unioned regional dataset for a schema to the multi-region dataset
    that contains the same data. Backs up the multi-region dataset before performing
    the copy. This backup dataset will get cleaned up if the copy succeeds, but
    otherwise will stick around for 1 week before tables expire.

    Args:
        config: Export config for the schema whose datasets are being copied.
        dataset_override_prefix: Optional sandbox prefix; when set, the
            destination dataset is created with a default table expiration so
            sandbox copies clean themselves up.
    """
    bq_client = BigQueryClientImpl()

    source_dataset_id = config.unioned_regional_dataset(
        dataset_override_prefix)
    destination_dataset_id = config.unioned_multi_region_dataset(
        dataset_override_prefix)
    destination_dataset = bq_client.dataset_ref_for_id(destination_dataset_id)

    # Snapshot the destination's current contents first so a failed copy below
    # is recoverable; may be None if the destination dataset does not exist yet.
    backup_dataset = bq_client.backup_dataset_tables_if_dataset_exists(
        destination_dataset_id)

    try:
        # Clear stale tables so the destination exactly mirrors the source.
        if bq_client.dataset_exists(destination_dataset):
            tables = bq_client.list_tables(destination_dataset_id)
            for table in tables:
                bq_client.delete_table(table.dataset_id, table.table_id)

        bq_client.create_dataset_if_necessary(
            destination_dataset,
            default_table_expiration_ms=TEMP_DATASET_DEFAULT_TABLE_EXPIRATION_MS
            if dataset_override_prefix else None,
        )

        # Copy into the canonical unioned source datasets in the US multi-region
        bq_client.copy_dataset_tables_across_regions(
            source_dataset_id=source_dataset_id,
            destination_dataset_id=destination_dataset_id,
        )
    except Exception:
        # Log at ERROR (was INFO) so the failure is surfaced, not buried in info logs.
        logging.error(
            "Failed to flash [%s] to [%s] - contents backup can be found at [%s]",
            source_dataset_id,
            destination_dataset_id,
            backup_dataset.dataset_id if backup_dataset else "NO BACKUP",
        )
        # Bare `raise` (not `raise e`) preserves the original traceback untouched.
        raise

    if backup_dataset:
        # The copy succeeded, so the safety backup is no longer needed.
        bq_client.delete_dataset(backup_dataset,
                                 delete_contents=True,
                                 not_found_ok=True)
# Example no. 4 (extraction artifact; commented out to keep the file parseable)
def compare_dataflow_output_to_sandbox(
    sandbox_dataset_prefix: str,
    job_name_to_compare: str,
    base_output_job_id: str,
    sandbox_output_job_id: str,
    additional_columns_to_compare: List[str],
    allow_overwrite: bool = False,
) -> None:
    """Compares the output for all metrics produced by the daily pipeline job with the given |job_name_to_compare|
    between the output from the |base_output_job_id| job in the dataflow_metrics dataset and the output from the
    |sandbox_output_job_id| job in the sandbox dataflow dataset.

    One diff table per compared metric is written to a sandbox comparison output
    dataset. Metrics whose diff table comes back empty are treated as identical
    and their table is deleted; if every metric matches, the entire comparison
    dataset is deleted. Raises ValueError (or exits with status 1 when run as a
    script) if the comparison dataset already contains tables and
    |allow_overwrite| is False.
    """
    bq_client = BigQueryClientImpl()
    sandbox_dataflow_dataset_id = (sandbox_dataset_prefix + "_" +
                                   DATAFLOW_METRICS_DATASET)

    sandbox_comparison_output_dataset_id = (sandbox_dataset_prefix +
                                            "_dataflow_comparison_output")
    sandbox_comparison_output_dataset_ref = bq_client.dataset_ref_for_id(
        sandbox_comparison_output_dataset_id)

    # Refuse to clobber an existing, non-empty comparison dataset unless the
    # caller explicitly opted in via allow_overwrite.
    if bq_client.dataset_exists(sandbox_comparison_output_dataset_ref) and any(
            bq_client.list_tables(sandbox_comparison_output_dataset_id)):
        if not allow_overwrite:
            if __name__ == "__main__":
                # Script invocation: report via CLI-friendly error + exit code.
                logging.error(
                    "Dataset %s already exists in project %s. To overwrite, set --allow_overwrite.",
                    sandbox_comparison_output_dataset_id,
                    bq_client.project_id,
                )
                sys.exit(1)
            else:
                # Library invocation: surface the same condition as an exception.
                raise ValueError(
                    f"Cannot write comparison output to a non-empty dataset. Please delete tables in dataset: "
                    f"{bq_client.project_id}.{sandbox_comparison_output_dataset_id}."
                )
        else:
            # Clean up the existing tables in the dataset
            for table in bq_client.list_tables(
                    sandbox_comparison_output_dataset_id):
                bq_client.delete_table(table.dataset_id, table.table_id)

    bq_client.create_dataset_if_necessary(
        sandbox_comparison_output_dataset_ref,
        TEMP_DATASET_DEFAULT_TABLE_EXPIRATION_MS)

    # (async comparison job, destination table id) pairs, awaited below.
    query_jobs: List[Tuple[QueryJob, str]] = []

    pipelines = YAMLDict.from_path(PRODUCTION_TEMPLATES_PATH).pop_dicts(
        "daily_pipelines")

    # Find the pipeline config matching the requested job, then kick off one
    # async diff query per metric type that pipeline produces.
    for pipeline in pipelines:
        if pipeline.pop("job_name", str) == job_name_to_compare:
            pipeline_metric_types = pipeline.peek_optional("metric_types", str)

            if not pipeline_metric_types:
                raise ValueError(
                    f"Pipeline job {job_name_to_compare} missing required metric_types attribute."
                )

            # metric_types is a whitespace-separated string in the YAML config.
            metric_types_for_comparison = pipeline_metric_types.split()

            for metric_class, metric_table in DATAFLOW_METRICS_TO_TABLES.items(
            ):
                metric_type_value = DATAFLOW_TABLES_TO_METRIC_TYPES[
                    metric_table].value

                if metric_type_value in metric_types_for_comparison:
                    comparison_query = _query_for_metric_comparison(
                        bq_client,
                        base_output_job_id,
                        sandbox_output_job_id,
                        sandbox_dataflow_dataset_id,
                        metric_class,
                        metric_table,
                        additional_columns_to_compare,
                    )

                    query_job = bq_client.create_table_from_query_async(
                        dataset_id=sandbox_comparison_output_dataset_id,
                        table_id=metric_table,
                        query=comparison_query,
                        overwrite=True,
                    )

                    # Add query job to the list of running jobs
                    query_jobs.append((query_job, metric_table))

    for query_job, output_table_id in query_jobs:
        # Wait for the insert job to complete before looking for the table
        query_job.result()

        output_table = bq_client.get_table(
            sandbox_comparison_output_dataset_ref, output_table_id)

        if output_table.num_rows == 0:
            # If there are no rows in the output table, then the output was identical
            bq_client.delete_table(sandbox_comparison_output_dataset_id,
                                   output_table_id)

    # peekable: truthy only when at least one diff table remains
    # (more_itertools behavior), so this distinguishes "diffs" vs "identical".
    metrics_with_different_output = peekable(
        bq_client.list_tables(sandbox_comparison_output_dataset_id))

    logging.info(
        "\n*************** DATAFLOW OUTPUT COMPARISON RESULTS ***************\n"
    )

    if metrics_with_different_output:
        for metric_table in metrics_with_different_output:
            # This will always be true, and is here to silence mypy warnings
            assert isinstance(metric_table, bigquery.table.TableListItem)

            logging.warning(
                "Dataflow output differs for metric %s. See %s.%s for diverging rows.",
                metric_table.table_id,
                sandbox_comparison_output_dataset_id,
                metric_table.table_id,
            )
    else:
        logging.info(
            "Dataflow output identical. Deleting dataset %s.",
            sandbox_comparison_output_dataset_ref.dataset_id,
        )
        bq_client.delete_dataset(sandbox_comparison_output_dataset_ref,
                                 delete_contents=True)
class BigQueryClientImplTest(unittest.TestCase):
    """Tests for BigQueryClientImpl.

    The underlying google.cloud client and the project-id metadata lookup are
    both patched out in setUp, so each test only verifies which client methods
    BigQueryClientImpl delegates to (not real BigQuery behavior).
    """
    def setUp(self) -> None:
        self.mock_project_id = 'fake-recidiviz-project'
        self.mock_dataset_id = 'fake-dataset'
        self.mock_table_id = 'test_table'
        self.mock_dataset = bigquery.dataset.DatasetReference(
            self.mock_project_id, self.mock_dataset_id)
        self.mock_table = self.mock_dataset.table(self.mock_table_id)

        # Patch the project-id lookup so BigQueryClientImpl sees a fake project.
        self.metadata_patcher = mock.patch(
            'recidiviz.utils.metadata.project_id')
        self.mock_project_id_fn = self.metadata_patcher.start()
        self.mock_project_id_fn.return_value = self.mock_project_id

        # Patch the wrapped google.cloud client; assertions below inspect it.
        self.client_patcher = mock.patch(
            'recidiviz.big_query.big_query_client.client')
        self.mock_client = self.client_patcher.start().return_value

        self.mock_view = BigQueryView(
            dataset_id='dataset',
            view_id='test_view',
            view_query_template='SELECT NULL LIMIT 0',
            materialized_view_table_id='test_view_table')

        self.bq_client = BigQueryClientImpl()

    def tearDown(self) -> None:
        self.client_patcher.stop()
        self.metadata_patcher.stop()

    def test_create_dataset_if_necessary(self) -> None:
        """Check that a dataset is created if it does not exist."""
        # NotFound from get_dataset signals that the dataset is missing.
        self.mock_client.get_dataset.side_effect = exceptions.NotFound('!')
        self.bq_client.create_dataset_if_necessary(self.mock_dataset)
        self.mock_client.create_dataset.assert_called()

    def test_create_dataset_if_necessary_dataset_exists(self) -> None:
        """Check that a dataset is not created if it already exists."""
        self.mock_client.get_dataset.side_effect = None
        self.bq_client.create_dataset_if_necessary(self.mock_dataset)
        self.mock_client.create_dataset.assert_not_called()

    def test_table_exists(self) -> None:
        """Check that table_exists returns True if the table exists."""
        self.mock_client.get_table.side_effect = None
        self.assertTrue(
            self.bq_client.table_exists(self.mock_dataset, self.mock_table_id))

    def test_table_exists_does_not_exist(self) -> None:
        """Check that table_exists returns False if the table does not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')
        with self.assertLogs(level='WARNING'):
            table_exists = self.bq_client.table_exists(self.mock_dataset,
                                                       self.mock_table_id)
            self.assertFalse(table_exists)

    def test_create_or_update_view_creates_view(self) -> None:
        """create_or_update_view creates a View if it does not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')
        self.bq_client.create_or_update_view(self.mock_dataset, self.mock_view)
        self.mock_client.create_table.assert_called()
        self.mock_client.update_table.assert_not_called()

    def test_create_or_update_view_updates_view(self) -> None:
        """create_or_update_view updates a View if it already exist."""
        self.mock_client.get_table.side_effect = None
        self.bq_client.create_or_update_view(self.mock_dataset, self.mock_view)
        self.mock_client.update_table.assert_called()
        self.mock_client.create_table.assert_not_called()

    def test_export_to_cloud_storage(self) -> None:
        """export_to_cloud_storage extracts the table corresponding to the
        view."""
        self.assertIsNotNone(
            self.bq_client.export_table_to_cloud_storage_async(
                source_table_dataset_ref=self.mock_dataset,
                source_table_id='source-table',
                destination_uri=
                f'gs://{self.mock_project_id}-bucket/destination_path.json',
                destination_format=bigquery.DestinationFormat.
                NEWLINE_DELIMITED_JSON))
        self.mock_client.extract_table.assert_called()

    def test_export_to_cloud_storage_no_table(self) -> None:
        """export_to_cloud_storage does not extract from a table if the table
        does not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')
        with self.assertLogs(level='WARNING'):
            self.assertIsNone(
                self.bq_client.export_table_to_cloud_storage_async(
                    source_table_dataset_ref=self.mock_dataset,
                    source_table_id='source-table',
                    destination_uri=
                    f'gs://{self.mock_project_id}-bucket/destination_path.json',
                    destination_format=bigquery.DestinationFormat.
                    NEWLINE_DELIMITED_JSON))
            self.mock_client.extract_table.assert_not_called()

    def test_load_table_async_create_dataset(self) -> None:
        """Test that load_table_from_cloud_storage_async tries to create a parent dataset."""

        self.mock_client.get_dataset.side_effect = exceptions.NotFound('!')

        self.bq_client.load_table_from_cloud_storage_async(
            destination_dataset_ref=self.mock_dataset,
            destination_table_id=self.mock_table_id,
            destination_table_schema=[
                SchemaField('my_column', 'STRING', 'NULLABLE', None, ())
            ],
            source_uri='gs://bucket/export-uri')

        self.mock_client.create_dataset.assert_called()
        self.mock_client.load_table_from_uri.assert_called()

    def test_load_table_async_dataset_exists(self) -> None:
        """Test that load_table_from_cloud_storage_async does not try to create a parent dataset if it already exists.
        """

        self.bq_client.load_table_from_cloud_storage_async(
            destination_dataset_ref=self.mock_dataset,
            destination_table_id=self.mock_table_id,
            destination_table_schema=[
                SchemaField('my_column', 'STRING', 'NULLABLE', None, ())
            ],
            source_uri='gs://bucket/export-uri')

        self.mock_client.create_dataset.assert_not_called()
        self.mock_client.load_table_from_uri.assert_called()

    def test_export_query_results_to_cloud_storage_no_table(self) -> None:
        bucket = self.mock_project_id + '-bucket'
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')
        with self.assertLogs(level='WARNING'):
            self.bq_client.export_query_results_to_cloud_storage([
                ExportQueryConfig.from_view_query(
                    view=self.mock_view,
                    view_filter_clause='WHERE x = y',
                    intermediate_table_name=self.mock_table_id,
                    output_uri=f'gs://{bucket}/view.json',
                    output_format=bigquery.DestinationFormat.
                    NEWLINE_DELIMITED_JSON)
            ])

    def test_export_query_results_to_cloud_storage(self) -> None:
        """export_query_results_to_cloud_storage creates the table from the view query and
        exports the table."""
        bucket = self.mock_project_id + '-bucket'
        # Completed futures stand in for finished query/extract jobs so that
        # .result() calls inside the client return immediately.
        query_job = futures.Future()
        query_job.set_result([])
        extract_job = futures.Future()
        extract_job.set_result(None)
        self.mock_client.query.return_value = query_job
        self.mock_client.extract_table.return_value = extract_job
        self.bq_client.export_query_results_to_cloud_storage([
            ExportQueryConfig.from_view_query(
                view=self.mock_view,
                view_filter_clause='WHERE x = y',
                intermediate_table_name=self.mock_table_id,
                output_uri=f'gs://{bucket}/view.json',
                output_format=bigquery.DestinationFormat.NEWLINE_DELIMITED_JSON
            )
        ])
        self.mock_client.query.assert_called()
        self.mock_client.extract_table.assert_called()
        # The intermediate table must be cleaned up after the export.
        self.mock_client.delete_table.assert_called_with(
            bigquery.DatasetReference(self.mock_project_id,
                                      self.mock_view.dataset_id).table(
                                          self.mock_table_id))

    def test_create_table_from_query(self) -> None:
        """Tests that the create_table_from_query function calls the function to create a table from a query."""
        self.bq_client.create_table_from_query_async(
            self.mock_dataset_id,
            self.mock_table_id,
            query="SELECT * FROM some.fake.table",
            query_parameters=[])
        self.mock_client.query.assert_called()

    def test_insert_into_table_from_table(self) -> None:
        """Tests that the insert_into_table_from_table function runs a query."""
        self.bq_client.insert_into_table_from_table_async(
            'fake_source_dataset_id', 'fake_table_id', self.mock_dataset_id,
            self.mock_table_id)
        self.mock_client.get_table.assert_called()
        self.mock_client.query.assert_called()

    def test_insert_into_table_from_table_invalid_destination(self) -> None:
        """Tests that the insert_into_table_from_table function does not run the query if the destination table does
        not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')

        with pytest.raises(ValueError):
            self.bq_client.insert_into_table_from_table_async(
                self.mock_dataset_id, self.mock_table_id,
                'fake_source_dataset_id', 'fake_table_id')
        self.mock_client.get_table.assert_called()
        self.mock_client.query.assert_not_called()

    def test_insert_into_table_from_cloud_storage_async(self) -> None:
        self.mock_client.get_dataset.side_effect = exceptions.NotFound('!')

        self.bq_client.insert_into_table_from_cloud_storage_async(
            destination_dataset_ref=self.mock_dataset,
            destination_table_id=self.mock_table_id,
            destination_table_schema=[
                SchemaField('my_column', 'STRING', 'NULLABLE', None, ())
            ],
            source_uri='gs://bucket/export-uri')

        self.mock_client.create_dataset.assert_called()
        self.mock_client.load_table_from_uri.assert_called()

    def test_delete_from_table(self) -> None:
        """Tests that the delete_from_table function runs a query."""
        self.bq_client.delete_from_table_async(self.mock_dataset_id,
                                               self.mock_table_id,
                                               filter_clause="WHERE x > y")
        self.mock_client.query.assert_called()

    def test_delete_from_table_invalid_filter_clause(self) -> None:
        """Tests that the delete_from_table function does not run a query when the filter clause is invalid."""
        # Filter clause must start with WHERE; "x > y" alone is rejected.
        with pytest.raises(ValueError):
            self.bq_client.delete_from_table_async(self.mock_dataset_id,
                                                   self.mock_table_id,
                                                   filter_clause="x > y")
        self.mock_client.query.assert_not_called()

    def test_materialize_view_to_table(self) -> None:
        """Tests that the materialize_view_to_table function calls the function to create a table from a query."""
        self.bq_client.materialize_view_to_table(self.mock_view)
        self.mock_client.query.assert_called()

    def test_materialize_view_to_table_no_materialized_view_table_id(self) -> None:
        """Tests that the materialize_view_to_table function does not call the function to create a table from a
        query if there is no set materialized_view_table_id on the view."""
        invalid_view = BigQueryView(dataset_id='dataset',
                                    view_id='test_view',
                                    view_query_template='SELECT NULL LIMIT 0',
                                    materialized_view_table_id=None)

        with pytest.raises(ValueError):
            self.bq_client.materialize_view_to_table(invalid_view)
        self.mock_client.query.assert_not_called()

    def test_create_table_with_schema(self) -> None:
        """Tests that the create_table_with_schema function calls the create_table function on the client."""
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')
        schema_fields = [bigquery.SchemaField('new_schema_field', 'STRING')]

        self.bq_client.create_table_with_schema(self.mock_dataset_id,
                                                self.mock_table_id,
                                                schema_fields)
        self.mock_client.create_table.assert_called()

    def test_create_table_with_schema_table_exists(self) -> None:
        """Tests that the create_table_with_schema function raises an error when the table already exists."""
        self.mock_client.get_table.side_effect = None
        schema_fields = [bigquery.SchemaField('new_schema_field', 'STRING')]

        with pytest.raises(ValueError):
            self.bq_client.create_table_with_schema(self.mock_dataset_id,
                                                    self.mock_table_id,
                                                    schema_fields)
        self.mock_client.create_table.assert_not_called()

    def test_add_missing_fields_to_schema(self) -> None:
        """Tests that the add_missing_fields_to_schema function calls the client to update the table."""
        table_ref = bigquery.TableReference(self.mock_dataset,
                                            self.mock_table_id)
        schema_fields = [bigquery.SchemaField('fake_schema_field', 'STRING')]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [
            bigquery.SchemaField('new_schema_field', 'STRING')
        ]

        self.bq_client.add_missing_fields_to_schema(self.mock_dataset_id,
                                                    self.mock_table_id,
                                                    new_schema_fields)

        self.mock_client.update_table.assert_called()

    def test_add_missing_fields_to_schema_no_missing_fields(self) -> None:
        """Tests that the add_missing_fields_to_schema function does not call the client to update the table when all
        of the fields are already present."""
        table_ref = bigquery.TableReference(self.mock_dataset,
                                            self.mock_table_id)
        schema_fields = [bigquery.SchemaField('fake_schema_field', 'STRING')]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [
            bigquery.SchemaField('fake_schema_field', 'STRING')
        ]

        self.bq_client.add_missing_fields_to_schema(self.mock_dataset_id,
                                                    self.mock_table_id,
                                                    new_schema_fields)

        self.mock_client.update_table.assert_not_called()

    def test_add_missing_fields_to_schema_no_table(self) -> None:
        """Tests that the add_missing_fields_to_schema function does not call the client to update the table when the
        table does not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')
        new_schema_fields = [
            bigquery.SchemaField('fake_schema_field', 'STRING')
        ]

        with pytest.raises(ValueError):
            self.bq_client.add_missing_fields_to_schema(
                self.mock_dataset_id, self.mock_table_id, new_schema_fields)

        self.mock_client.update_table.assert_not_called()

    def test_add_missing_fields_to_schema_fields_with_same_name_different_type(
            self) -> None:
        """Tests that the add_missing_fields_to_schema function raises an error when the user is trying to add a field
        with the same name but different field_type as an existing field."""
        table_ref = bigquery.TableReference(self.mock_dataset,
                                            self.mock_table_id)
        schema_fields = [bigquery.SchemaField('fake_schema_field', 'STRING')]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [
            bigquery.SchemaField('fake_schema_field', 'INTEGER')
        ]

        with pytest.raises(ValueError):
            self.bq_client.add_missing_fields_to_schema(
                self.mock_dataset_id, self.mock_table_id, new_schema_fields)

        self.mock_client.update_table.assert_not_called()

    def test_add_missing_fields_to_schema_fields_with_same_name_different_mode(
            self) -> None:
        """Tests that the add_missing_fields_to_schema function raises an error when the user is trying to add a field
        with the same name but different mode as an existing field."""
        table_ref = bigquery.TableReference(self.mock_dataset,
                                            self.mock_table_id)
        schema_fields = [
            bigquery.SchemaField('fake_schema_field',
                                 'STRING',
                                 mode="NULLABLE")
        ]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [
            bigquery.SchemaField('fake_schema_field',
                                 'STRING',
                                 mode="REQUIRED")
        ]

        with pytest.raises(ValueError):
            self.bq_client.add_missing_fields_to_schema(
                self.mock_dataset_id, self.mock_table_id, new_schema_fields)

        self.mock_client.update_table.assert_not_called()

    def test_delete_table(self) -> None:
        """Tests that our delete table function calls the correct client method."""
        self.bq_client.delete_table(self.mock_dataset_id, self.mock_table_id)
        self.mock_client.delete_table.assert_called()
# Example no. 6 (extraction artifact; commented out to keep the file parseable)
class BigQueryClientImplTest(unittest.TestCase):
    """Tests for BigQueryClientImpl.

    The raw google.cloud.bigquery client is patched out in setUp, so each test
    only verifies that BigQueryClientImpl forwards the expected calls/queries
    to the client — no real BigQuery requests are made.
    """
    def setUp(self) -> None:
        self.location = 'US'
        self.mock_project_id = 'fake-recidiviz-project'
        self.mock_dataset_id = 'fake-dataset'
        self.mock_table_id = 'test_table'
        self.mock_dataset_ref = bigquery.dataset.DatasetReference(
            self.mock_project_id, self.mock_dataset_id)
        self.mock_table = self.mock_dataset_ref.table(self.mock_table_id)

        # Pin the project id so table paths built by the client are deterministic.
        self.metadata_patcher = mock.patch(
            'recidiviz.utils.metadata.project_id')
        self.mock_project_id_fn = self.metadata_patcher.start()
        self.mock_project_id_fn.return_value = self.mock_project_id

        # Patch the underlying google.cloud client so no network calls happen.
        self.client_patcher = mock.patch(
            'recidiviz.big_query.big_query_client.client')
        self.mock_client = self.client_patcher.start().return_value

        self.mock_view = BigQueryView(
            dataset_id='dataset',
            view_id='test_view',
            view_query_template='SELECT NULL LIMIT 0',
            should_materialize=True)

        self.bq_client = BigQueryClientImpl()

    def tearDown(self) -> None:
        self.client_patcher.stop()
        self.metadata_patcher.stop()

    def test_create_dataset_if_necessary(self) -> None:
        """Check that a dataset is created if it does not exist."""
        self.mock_client.get_dataset.side_effect = exceptions.NotFound('!')
        self.bq_client.create_dataset_if_necessary(self.mock_dataset_ref)
        self.mock_client.create_dataset.assert_called()

    def test_create_dataset_if_necessary_dataset_exists(self) -> None:
        """Check that a dataset is not created if it already exists."""
        self.mock_client.get_dataset.side_effect = None
        self.bq_client.create_dataset_if_necessary(self.mock_dataset_ref)
        self.mock_client.create_dataset.assert_not_called()

    def test_create_dataset_if_necessary_table_expiration(self) -> None:
        """Check that the dataset is created with a set table expiration if the dataset does not exist and the
        new_dataset_table_expiration_ms is specified."""
        self.mock_client.get_dataset.side_effect = exceptions.NotFound('!')
        self.bq_client.create_dataset_if_necessary(
            self.mock_dataset_ref, default_table_expiration_ms=6000)
        self.mock_client.create_dataset.assert_called()

    def test_table_exists(self) -> None:
        """Check that table_exists returns True if the table exists."""
        self.mock_client.get_table.side_effect = None
        self.assertTrue(
            self.bq_client.table_exists(self.mock_dataset_ref,
                                        self.mock_table_id))

    def test_table_exists_does_not_exist(self) -> None:
        """Check that table_exists returns False if the table does not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')
        with self.assertLogs(level='WARNING'):
            table_exists = self.bq_client.table_exists(self.mock_dataset_ref,
                                                       self.mock_table_id)
            self.assertFalse(table_exists)

    def test_create_or_update_view_creates_view(self) -> None:
        """create_or_update_view creates a View if it does not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')
        self.bq_client.create_or_update_view(self.mock_dataset_ref,
                                             self.mock_view)
        self.mock_client.create_table.assert_called()
        self.mock_client.update_table.assert_not_called()

    def test_create_or_update_view_updates_view(self) -> None:
        """create_or_update_view updates a View if it already exist."""
        self.mock_client.get_table.side_effect = None
        self.bq_client.create_or_update_view(self.mock_dataset_ref,
                                             self.mock_view)
        self.mock_client.update_table.assert_called()
        self.mock_client.create_table.assert_not_called()

    def test_export_to_cloud_storage(self) -> None:
        """export_to_cloud_storage extracts the table corresponding to the
        view."""
        self.assertIsNotNone(
            self.bq_client.export_table_to_cloud_storage_async(
                source_table_dataset_ref=self.mock_dataset_ref,
                source_table_id='source-table',
                destination_uri=
                f'gs://{self.mock_project_id}-bucket/destination_path.json',
                destination_format=bigquery.DestinationFormat.
                NEWLINE_DELIMITED_JSON))
        self.mock_client.extract_table.assert_called()

    def test_export_to_cloud_storage_no_table(self) -> None:
        """export_to_cloud_storage does not extract from a table if the table
        does not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')
        with self.assertLogs(level='WARNING'):
            self.assertIsNone(
                self.bq_client.export_table_to_cloud_storage_async(
                    source_table_dataset_ref=self.mock_dataset_ref,
                    source_table_id='source-table',
                    destination_uri=
                    f'gs://{self.mock_project_id}-bucket/destination_path.json',
                    destination_format=bigquery.DestinationFormat.
                    NEWLINE_DELIMITED_JSON))
            self.mock_client.extract_table.assert_not_called()

    def test_load_table_async_create_dataset(self) -> None:
        """Test that load_table_from_cloud_storage_async tries to create a parent dataset."""

        self.mock_client.get_dataset.side_effect = exceptions.NotFound('!')

        self.bq_client.load_table_from_cloud_storage_async(
            destination_dataset_ref=self.mock_dataset_ref,
            destination_table_id=self.mock_table_id,
            destination_table_schema=[
                SchemaField('my_column', 'STRING', 'NULLABLE', None, ())
            ],
            source_uri='gs://bucket/export-uri')

        self.mock_client.create_dataset.assert_called()
        self.mock_client.load_table_from_uri.assert_called()

    def test_load_table_async_dataset_exists(self) -> None:
        """Test that load_table_from_cloud_storage_async does not try to create a parent dataset if it already exists.
        """

        self.bq_client.load_table_from_cloud_storage_async(
            destination_dataset_ref=self.mock_dataset_ref,
            destination_table_id=self.mock_table_id,
            destination_table_schema=[
                SchemaField('my_column', 'STRING', 'NULLABLE', None, ())
            ],
            source_uri='gs://bucket/export-uri')

        self.mock_client.create_dataset.assert_not_called()
        self.mock_client.load_table_from_uri.assert_called()

    def test_export_query_results_to_cloud_storage_no_table(self) -> None:
        bucket = self.mock_project_id + '-bucket'
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')
        with self.assertLogs(level='WARNING'):
            self.bq_client.export_query_results_to_cloud_storage([
                ExportQueryConfig.from_view_query(
                    view=self.mock_view,
                    view_filter_clause='WHERE x = y',
                    intermediate_table_name=self.mock_table_id,
                    output_uri=f'gs://{bucket}/view.json',
                    output_format=bigquery.DestinationFormat.
                    NEWLINE_DELIMITED_JSON)
            ])

    def test_export_query_results_to_cloud_storage(self) -> None:
        """export_query_results_to_cloud_storage creates the table from the view query and
        exports the table."""
        bucket = self.mock_project_id + '-bucket'
        # Futures stand in for the query/extract jobs so .result() returns
        # immediately without a real BigQuery round trip.
        query_job: futures.Future = futures.Future()
        query_job.set_result([])
        extract_job: futures.Future = futures.Future()
        extract_job.set_result(None)
        self.mock_client.query.return_value = query_job
        self.mock_client.extract_table.return_value = extract_job
        self.bq_client.export_query_results_to_cloud_storage([
            ExportQueryConfig.from_view_query(
                view=self.mock_view,
                view_filter_clause='WHERE x = y',
                intermediate_table_name=self.mock_table_id,
                output_uri=f'gs://{bucket}/view.json',
                output_format=bigquery.DestinationFormat.NEWLINE_DELIMITED_JSON
            )
        ])
        self.mock_client.query.assert_called()
        self.mock_client.extract_table.assert_called()
        # The intermediate table must be cleaned up after the export.
        self.mock_client.delete_table.assert_called_with(
            bigquery.DatasetReference(self.mock_project_id,
                                      self.mock_view.dataset_id).table(
                                          self.mock_table_id))

    def test_create_table_from_query(self) -> None:
        """Tests that the create_table_from_query function calls the function to create a table from a query."""
        self.bq_client.create_table_from_query_async(
            self.mock_dataset_id,
            self.mock_table_id,
            query="SELECT * FROM some.fake.table",
            query_parameters=[])
        self.mock_client.query.assert_called()

    @mock.patch('google.cloud.bigquery.job.QueryJobConfig')
    def test_insert_into_table_from_table_async(
            self, mock_job_config: mock.MagicMock) -> None:
        """Tests that the insert_into_table_from_table_async function runs a query."""
        self.bq_client.insert_into_table_from_table_async(
            source_dataset_id=self.mock_dataset_id,
            source_table_id=self.mock_table_id,
            destination_dataset_id=self.mock_dataset_id,
            destination_table_id='fake_table_temp')
        expected_query = f"SELECT * FROM `fake-recidiviz-project.{self.mock_dataset_id}.{self.mock_table_id}`"
        self.mock_client.get_table.assert_called()
        self.mock_client.query.assert_called_with(query=expected_query,
                                                  location=self.location,
                                                  job_config=mock_job_config())

    @mock.patch('google.cloud.bigquery.job.QueryJobConfig')
    def test_insert_into_table_from_table_async_hydrate_missing_columns(
            self, mock_job_config: mock.MagicMock) -> None:
        """Tests that the insert_into_table_from_table_async generates a query with missing columns as NULL."""
        with mock.patch(
                'recidiviz.big_query.big_query_client.BigQueryClientImpl'
                '._get_schema_fields_missing_from_table') as mock_missing:
            mock_missing.return_value = [
                bigquery.SchemaField('state_code', 'STRING', 'REQUIRED'),
                bigquery.SchemaField('new_column_name', 'INTEGER', 'REQUIRED')
            ]
            self.mock_destination_id = 'fake_table_temp'
            self.bq_client.insert_into_table_from_table_async(
                source_dataset_id=self.mock_dataset_id,
                source_table_id=self.mock_table_id,
                destination_dataset_id=self.mock_dataset_id,
                destination_table_id=self.mock_destination_id,
                hydrate_missing_columns_with_null=True,
                allow_field_additions=True)
            expected_query = "SELECT *, CAST(NULL AS STRING) AS state_code, CAST(NULL AS INTEGER) AS new_column_name " \
                             f"FROM `fake-recidiviz-project.{self.mock_dataset_id}.{self.mock_table_id}`"
            self.mock_client.query.assert_called_with(
                query=expected_query,
                location=self.location,
                job_config=mock_job_config())

    def test_insert_into_table_from_table_invalid_destination(self) -> None:
        """Tests that the insert_into_table_from_table_async function does not run the query if the destination
        table does not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')

        with pytest.raises(ValueError):
            self.bq_client.insert_into_table_from_table_async(
                self.mock_dataset_id, self.mock_table_id,
                'fake_source_dataset_id', 'fake_table_id')
        self.mock_client.get_table.assert_called()
        self.mock_client.query.assert_not_called()

    def test_insert_into_table_from_table_invalid_filter_clause(self) -> None:
        """Tests that the insert_into_table_from_table_async function does not run the query if the filter clause
        does not start with a WHERE."""
        with pytest.raises(ValueError):
            self.bq_client.insert_into_table_from_table_async(
                self.mock_dataset_id,
                self.mock_table_id,
                'fake_source_dataset_id',
                'fake_table_id',
                source_data_filter_clause='bad filter clause')
        self.mock_client.query.assert_not_called()

    @mock.patch('google.cloud.bigquery.job.QueryJobConfig')
    def test_insert_into_table_from_table_with_filter_clause(
            self, mock_job_config: mock.MagicMock) -> None:
        """Tests that the insert_into_table_from_table_async generates a valid query when given a filter clause."""
        filter_clause = "WHERE state_code IN ('US_ND')"
        job_config = mock_job_config()
        self.bq_client.insert_into_table_from_table_async(
            self.mock_dataset_id,
            self.mock_table_id,
            'fake_source_dataset_id',
            'fake_table_id',
            source_data_filter_clause=filter_clause)
        expected_query = "SELECT * FROM `fake-recidiviz-project.fake-dataset.test_table` " \
                         "WHERE state_code IN ('US_ND')"
        self.mock_client.query.assert_called_with(query=expected_query,
                                                  location=self.location,
                                                  job_config=job_config)

    def test_insert_into_table_from_cloud_storage_async(self) -> None:
        self.mock_client.get_dataset.side_effect = exceptions.NotFound('!')

        self.bq_client.insert_into_table_from_cloud_storage_async(
            destination_dataset_ref=self.mock_dataset_ref,
            destination_table_id=self.mock_table_id,
            destination_table_schema=[
                SchemaField('my_column', 'STRING', 'NULLABLE', None, ())
            ],
            source_uri='gs://bucket/export-uri')

        self.mock_client.create_dataset.assert_called()
        self.mock_client.load_table_from_uri.assert_called()

    def test_delete_from_table(self) -> None:
        """Tests that the delete_from_table function runs a query."""
        self.bq_client.delete_from_table_async(self.mock_dataset_id,
                                               self.mock_table_id,
                                               filter_clause="WHERE x > y")
        self.mock_client.query.assert_called()

    def test_delete_from_table_invalid_filter_clause(self) -> None:
        """Tests that the delete_from_table function does not run a query when the filter clause is invalid."""
        with pytest.raises(ValueError):
            self.bq_client.delete_from_table_async(self.mock_dataset_id,
                                                   self.mock_table_id,
                                                   filter_clause="x > y")
        self.mock_client.query.assert_not_called()

    def test_materialize_view_to_table(self) -> None:
        """Tests that the materialize_view_to_table function calls the function to create a table from a query."""
        self.bq_client.materialize_view_to_table(self.mock_view)
        self.mock_client.query.assert_called()

    def test_materialize_view_to_table_no_materialized_view_table_id(
            self) -> None:
        """Tests that the materialize_view_to_table function does not call the function to create a table from a
        query if there is no set materialized_view_table_id on the view."""
        invalid_view = BigQueryView(dataset_id='dataset',
                                    view_id='test_view',
                                    view_query_template='SELECT NULL LIMIT 0',
                                    should_materialize=False)

        with pytest.raises(ValueError):
            self.bq_client.materialize_view_to_table(invalid_view)
        self.mock_client.query.assert_not_called()

    def test_create_table_with_schema(self) -> None:
        """Tests that the create_table_with_schema function calls the create_table function on the client."""
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')
        schema_fields = [bigquery.SchemaField('new_schema_field', 'STRING')]

        self.bq_client.create_table_with_schema(self.mock_dataset_id,
                                                self.mock_table_id,
                                                schema_fields)
        self.mock_client.create_table.assert_called()

    def test_create_table_with_schema_table_exists(self) -> None:
        """Tests that the create_table_with_schema function raises an error when the table already exists."""
        self.mock_client.get_table.side_effect = None
        schema_fields = [bigquery.SchemaField('new_schema_field', 'STRING')]

        with pytest.raises(ValueError):
            self.bq_client.create_table_with_schema(self.mock_dataset_id,
                                                    self.mock_table_id,
                                                    schema_fields)
        self.mock_client.create_table.assert_not_called()

    def test_add_missing_fields_to_schema(self) -> None:
        """Tests that the add_missing_fields_to_schema function calls the client to update the table."""
        table_ref = bigquery.TableReference(self.mock_dataset_ref,
                                            self.mock_table_id)
        schema_fields = [bigquery.SchemaField('fake_schema_field', 'STRING')]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [
            bigquery.SchemaField('new_schema_field', 'STRING')
        ]

        self.bq_client.add_missing_fields_to_schema(self.mock_dataset_id,
                                                    self.mock_table_id,
                                                    new_schema_fields)

        self.mock_client.update_table.assert_called()

    def test_add_missing_fields_to_schema_no_missing_fields(self) -> None:
        """Tests that the add_missing_fields_to_schema function does not call the client to update the table when all
        of the fields are already present."""
        table_ref = bigquery.TableReference(self.mock_dataset_ref,
                                            self.mock_table_id)
        schema_fields = [bigquery.SchemaField('fake_schema_field', 'STRING')]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [
            bigquery.SchemaField('fake_schema_field', 'STRING')
        ]

        self.bq_client.add_missing_fields_to_schema(self.mock_dataset_id,
                                                    self.mock_table_id,
                                                    new_schema_fields)

        self.mock_client.update_table.assert_not_called()

    def test_add_missing_fields_to_schema_no_table(self) -> None:
        """Tests that the add_missing_fields_to_schema function does not call the client to update the table when the
        table does not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')
        new_schema_fields = [
            bigquery.SchemaField('fake_schema_field', 'STRING')
        ]

        with pytest.raises(ValueError):
            self.bq_client.add_missing_fields_to_schema(
                self.mock_dataset_id, self.mock_table_id, new_schema_fields)

        self.mock_client.update_table.assert_not_called()

    def test_add_missing_fields_to_schema_fields_with_same_name_different_type(
            self) -> None:
        """Tests that the add_missing_fields_to_schema function raises an error when the user is trying to add a field
        with the same name but different field_type as an existing field."""
        table_ref = bigquery.TableReference(self.mock_dataset_ref,
                                            self.mock_table_id)
        schema_fields = [bigquery.SchemaField('fake_schema_field', 'STRING')]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [
            bigquery.SchemaField('fake_schema_field', 'INTEGER')
        ]

        with pytest.raises(ValueError):
            self.bq_client.add_missing_fields_to_schema(
                self.mock_dataset_id, self.mock_table_id, new_schema_fields)

        self.mock_client.update_table.assert_not_called()

    def test_add_missing_fields_to_schema_fields_with_same_name_different_mode(
            self) -> None:
        """Tests that the add_missing_fields_to_schema function raises an error when the user is trying to add a field
        with the same name but different mode as an existing field."""
        table_ref = bigquery.TableReference(self.mock_dataset_ref,
                                            self.mock_table_id)
        schema_fields = [
            bigquery.SchemaField('fake_schema_field',
                                 'STRING',
                                 mode="NULLABLE")
        ]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [
            bigquery.SchemaField('fake_schema_field',
                                 'STRING',
                                 mode="REQUIRED")
        ]

        with pytest.raises(ValueError):
            self.bq_client.add_missing_fields_to_schema(
                self.mock_dataset_id, self.mock_table_id, new_schema_fields)

        self.mock_client.update_table.assert_not_called()

    def test_delete_table(self) -> None:
        """Tests that our delete table function calls the correct client method."""
        self.bq_client.delete_table(self.mock_dataset_id, self.mock_table_id)
        self.mock_client.delete_table.assert_called()

    @mock.patch('google.cloud.bigquery.QueryJob')
    def test_paged_read_single_page_single_row(
            self, mock_query_job: mock.MagicMock) -> None:
        first_row = bigquery.table.Row(
            ['parole', 15, '10N'],
            {
                'supervision_type': 0,
                'revocations': 1,
                'district': 2
            },
        )

        # First call returns a single row, second call returns nothing
        mock_query_job.result.side_effect = [[first_row], []]

        processed_results = []

        def _process_fn(row: bigquery.table.Row) -> None:
            processed_results.append(dict(row))

        self.bq_client.paged_read_and_process(mock_query_job, 1, _process_fn)

        self.assertEqual([dict(first_row)], processed_results)
        # Paging advances start_index by the number of rows consumed so far.
        mock_query_job.result.assert_has_calls([
            call(max_results=1, start_index=0),
            call(max_results=1, start_index=1),
        ])

    @mock.patch('google.cloud.bigquery.QueryJob')
    def test_paged_read_single_page_multiple_rows(
            self, mock_query_job: mock.MagicMock) -> None:
        first_row = bigquery.table.Row(
            ['parole', 15, '10N'],
            {
                'supervision_type': 0,
                'revocations': 1,
                'district': 2
            },
        )
        second_row = bigquery.table.Row(
            ['probation', 7, '10N'],
            {
                'supervision_type': 0,
                'revocations': 1,
                'district': 2
            },
        )

        # First call returns two rows, second call returns nothing
        mock_query_job.result.side_effect = [[first_row, second_row], []]

        processed_results = []

        def _process_fn(row: bigquery.table.Row) -> None:
            processed_results.append(dict(row))

        self.bq_client.paged_read_and_process(mock_query_job, 10, _process_fn)

        self.assertEqual([dict(first_row), dict(second_row)],
                         processed_results)
        mock_query_job.result.assert_has_calls([
            call(max_results=10, start_index=0),
            call(max_results=10, start_index=2),
        ])

    @mock.patch('google.cloud.bigquery.QueryJob')
    def test_paged_read_multiple_pages(self,
                                       mock_query_job: mock.MagicMock) -> None:
        p1_r1 = bigquery.table.Row(
            ['parole', 15, '10N'],
            {
                'supervision_type': 0,
                'revocations': 1,
                'district': 2
            },
        )
        p1_r2 = bigquery.table.Row(
            ['probation', 7, '10N'],
            {
                'supervision_type': 0,
                'revocations': 1,
                'district': 2
            },
        )

        p2_r1 = bigquery.table.Row(
            ['parole', 8, '10F'],
            {
                'supervision_type': 0,
                'revocations': 1,
                'district': 2
            },
        )
        p2_r2 = bigquery.table.Row(
            ['probation', 3, '10F'],
            {
                'supervision_type': 0,
                'revocations': 1,
                'district': 2
            },
        )

        # First two calls returns results, third call returns nothing
        mock_query_job.result.side_effect = [[p1_r1, p1_r2], [p2_r1, p2_r2],
                                             []]

        processed_results = []

        def _process_fn(row: bigquery.table.Row) -> None:
            processed_results.append(dict(row))

        self.bq_client.paged_read_and_process(mock_query_job, 2, _process_fn)

        self.assertEqual(
            [dict(p1_r1), dict(p1_r2),
             dict(p2_r1), dict(p2_r2)], processed_results)
        mock_query_job.result.assert_has_calls([
            call(max_results=2, start_index=0),
            call(max_results=2, start_index=2),
            call(max_results=2, start_index=4),
        ])
# Esempio n. 7 (sample separator left over from the source scrape)
# 0
def compare_metric_view_output_to_sandbox(
    sandbox_dataset_prefix: str,
    load_sandbox_views: bool,
    check_determinism: bool,
    allow_schema_changes: bool,
    dataset_id_filters: Optional[List[str]],
) -> None:
    """Compares the output of all deployed metric views to the output of the corresponding views in the sandbox
    dataset.

    Args:
        sandbox_dataset_prefix: Prefix identifying the sandbox datasets to compare
            against; also prefixes the comparison-output dataset name.
        load_sandbox_views: If True, loads views into the sandbox datasets first.
        check_determinism: If True, compares a view's output against itself (via the
            non-materialized view) to check that the output is deterministic.
        allow_schema_changes: Forwarded to _view_output_comparison_job to control
            whether schema differences are tolerated.
        dataset_id_filters: If set, only views in these datasets are compared.

    Raises:
        ValueError: If the comparison-output dataset is non-empty, a referenced
            sandbox dataset does not exist, or a view id cannot be determined.
    """
    if load_sandbox_views:
        logging.info(
            "Loading views into sandbox datasets prefixed with %s",
            sandbox_dataset_prefix,
        )
        load_views_to_sandbox(sandbox_dataset_prefix)

    bq_client = BigQueryClientImpl()
    sandbox_comparison_output_dataset_id = (sandbox_dataset_prefix +
                                            "_metric_view_comparison_output")
    sandbox_comparison_output_dataset_ref = bq_client.dataset_ref_for_id(
        sandbox_comparison_output_dataset_id)

    # Refuse to mix results from a previous run with this one.
    if bq_client.dataset_exists(sandbox_comparison_output_dataset_ref) and any(
            bq_client.list_tables(sandbox_comparison_output_dataset_id)):
        raise ValueError(
            f"Cannot write comparison output to a non-empty dataset. Please delete tables in dataset: "
            f"{bq_client.project_id}.{sandbox_comparison_output_dataset_id}.")

    bq_client.create_dataset_if_necessary(
        sandbox_comparison_output_dataset_ref,
        TEMP_DATASET_DEFAULT_TABLE_EXPIRATION_MS)

    query_jobs: List[Tuple[QueryJob, str]] = []
    skipped_views: List[str] = []

    for view_builders in VIEW_BUILDERS_BY_NAMESPACE.values():
        for view_builder in view_builders:
            # Only compare output of metric views
            if not isinstance(view_builder, MetricBigQueryViewBuilder):
                continue

            base_dataset_id = view_builder.dataset_id

            if dataset_id_filters and base_dataset_id not in dataset_id_filters:
                continue

            if view_builder in VIEW_BUILDERS_WITH_KNOWN_NOT_DETERMINISTIC_OUTPUT:
                logging.warning(
                    "View %s.%s has known non-deterministic output. Skipping output comparison.",
                    view_builder.dataset_id,
                    view_builder.view_id,
                )
                skipped_views.append(
                    f"{view_builder.dataset_id}.{view_builder.view_id}")
                continue

            sandbox_dataset_id = sandbox_dataset_prefix + "_" + base_dataset_id

            if not bq_client.dataset_exists(
                    bq_client.dataset_ref_for_id(sandbox_dataset_id)):
                raise ValueError(
                    f"Trying to compare output to a sandbox dataset that does not exist: "
                    f"{bq_client.project_id}.{sandbox_dataset_id}")

            base_dataset_ref = bq_client.dataset_ref_for_id(base_dataset_id)
            # Use the materialized table when one exists, unless we are checking
            # determinism (where the view itself must be re-queried).
            base_view_id = (view_builder.build().materialized_view_table_id
                            if view_builder.should_materialize
                            and not check_determinism else
                            view_builder.view_id)

            if not base_view_id:
                # Fixed message: trailing space keeps "unset" and "for" from
                # running together across the implicit string concatenation.
                raise ValueError(
                    "Unexpected empty base_view_id. view_id or materialized_view_table_id unset "
                    f"for {view_builder}.")

            if not check_determinism and not bq_client.table_exists(
                    base_dataset_ref, base_view_id):
                logging.warning(
                    "View %s.%s does not exist. Skipping output comparison.",
                    base_dataset_ref.dataset_id,
                    base_view_id,
                )
                skipped_views.append(f"{base_dataset_id}.{base_view_id}")
                continue

            if not bq_client.table_exists(
                    bq_client.dataset_ref_for_id(sandbox_dataset_id),
                    base_view_id):
                logging.warning(
                    "View %s.%s does not exist in sandbox. Skipping output comparison.",
                    sandbox_dataset_id,
                    base_view_id,
                )
                skipped_views.append(f"{sandbox_dataset_id}.{base_view_id}")
                continue
            query_job, output_table_id = _view_output_comparison_job(
                bq_client,
                view_builder,
                base_view_id,
                base_dataset_id,
                sandbox_dataset_id,
                sandbox_comparison_output_dataset_id,
                check_determinism,
                allow_schema_changes,
            )

            # Add query job to the list of running jobs
            query_jobs.append((query_job, output_table_id))

    for query_job, output_table_id in query_jobs:
        # Wait for the insert job to complete before looking for the table
        query_job.result()

        output_table = bq_client.get_table(
            sandbox_comparison_output_dataset_ref, output_table_id)

        if output_table.num_rows == 0:
            # If there are no rows in the output table, then the view output was identical
            bq_client.delete_table(sandbox_comparison_output_dataset_id,
                                   output_table_id)

    views_with_different_output = bq_client.list_tables(
        sandbox_comparison_output_dataset_id)
    # peekable allows truthiness-testing the iterator without consuming it.
    views_with_different_output = peekable(views_with_different_output)

    logging.info(
        "\n*************** METRIC VIEW OUTPUT RESULTS ***************\n")

    if dataset_id_filters:
        logging.info(
            "Only compared metric view output for the following datasets: \n %s \n",
            dataset_id_filters,
        )

    logging.info(
        "Skipped output comparison for the following metric views: \n %s \n",
        skipped_views,
    )

    if views_with_different_output:
        for view in views_with_different_output:
            # Output table ids are expected to be "<base_dataset_id>--<base_view_id>".
            base_dataset_id, base_view_id = view.table_id.split("--")

            logging.warning(
                "View output differs for view %s.%s. See %s.%s for diverging rows.",
                base_dataset_id,
                base_view_id,
                sandbox_comparison_output_dataset_id,
                view.table_id,
            )
    else:
        output_message = (
            "identical between deployed views and sandbox datasets"
            if not check_determinism else "deterministic")
        logging.info(
            "View output %s. Deleting dataset %s.",
            output_message,
            sandbox_comparison_output_dataset_ref.dataset_id,
        )
        bq_client.delete_dataset(sandbox_comparison_output_dataset_ref,
                                 delete_contents=True)
# Esempio n. 8 (sample separator left over from the source scrape)
# 0
class BigQueryClientImplTest(unittest.TestCase):
    """Tests for BigQueryClientImpl"""

    def setUp(self) -> None:
        """Creates fake project metadata, patches the underlying BigQuery
        client with a mock, and builds a materialized test view shared by
        all tests."""
        self.location = "US"
        self.mock_project_id = "fake-recidiviz-project"
        self.mock_dataset_id = "fake-dataset"
        self.mock_table_id = "test_table"
        self.mock_dataset_ref = bigquery.dataset.DatasetReference(
            self.mock_project_id, self.mock_dataset_id
        )
        self.mock_table = self.mock_dataset_ref.table(self.mock_table_id)

        # Patch project-id resolution so BigQueryClientImpl points at the fake project.
        self.metadata_patcher = mock.patch("recidiviz.utils.metadata.project_id")
        self.mock_project_id_fn = self.metadata_patcher.start()
        self.mock_project_id_fn.return_value = self.mock_project_id

        # Patch the google.cloud client object; tests assert against this mock.
        self.client_patcher = mock.patch("recidiviz.big_query.big_query_client.client")
        self.mock_client = self.client_patcher.start().return_value

        self.mock_view = BigQueryView(
            dataset_id="dataset",
            view_id="test_view",
            view_query_template="SELECT NULL LIMIT 0",
            should_materialize=True,
        )

        self.bq_client = BigQueryClientImpl()

    def tearDown(self) -> None:
        """Stops the patchers started in setUp."""
        self.client_patcher.stop()
        self.metadata_patcher.stop()

    def test_create_dataset_if_necessary(self) -> None:
        """Check that a dataset is created if it does not exist."""
        # NotFound from get_dataset signals that the dataset is missing.
        self.mock_client.get_dataset.side_effect = exceptions.NotFound("!")
        self.bq_client.create_dataset_if_necessary(self.mock_dataset_ref)
        self.mock_client.create_dataset.assert_called()

    def test_create_dataset_if_necessary_dataset_exists(self) -> None:
        """Check that a dataset is not created if it already exists."""
        # get_dataset succeeding (no exception) means the dataset already exists.
        self.mock_client.get_dataset.side_effect = None
        self.bq_client.create_dataset_if_necessary(self.mock_dataset_ref)
        self.mock_client.create_dataset.assert_not_called()

    def test_create_dataset_if_necessary_table_expiration(self) -> None:
        """Check that the dataset is created with a set table expiration if the dataset does not exist and
        default_table_expiration_ms is specified."""
        self.mock_client.get_dataset.side_effect = exceptions.NotFound("!")
        self.bq_client.create_dataset_if_necessary(
            self.mock_dataset_ref, default_table_expiration_ms=6000
        )
        self.mock_client.create_dataset.assert_called()

    def test_table_exists(self) -> None:
        """Check that table_exists returns True if the table exists."""
        # get_table succeeding (no exception) means the table exists.
        self.mock_client.get_table.side_effect = None
        self.assertTrue(
            self.bq_client.table_exists(self.mock_dataset_ref, self.mock_table_id)
        )

    def test_table_exists_does_not_exist(self) -> None:
        """Check that table_exists returns False if the table does not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound("!")
        # A missing table is also expected to emit a warning log.
        with self.assertLogs(level="WARNING"):
            table_exists = self.bq_client.table_exists(
                self.mock_dataset_ref, self.mock_table_id
            )
            self.assertFalse(table_exists)

    def test_create_or_update_view_creates_view(self) -> None:
        """create_or_update_view creates a View if it does not exist."""
        # NotFound from get_table means the view table is missing, so it is created.
        self.mock_client.get_table.side_effect = exceptions.NotFound("!")
        self.bq_client.create_or_update_view(self.mock_dataset_ref, self.mock_view)
        self.mock_client.create_table.assert_called()
        self.mock_client.update_table.assert_not_called()

    def test_create_or_update_view_updates_view(self) -> None:
        """create_or_update_view updates a View if it already exists."""
        # get_table succeeding means the view table exists, so it is updated in place.
        self.mock_client.get_table.side_effect = None
        self.bq_client.create_or_update_view(self.mock_dataset_ref, self.mock_view)
        self.mock_client.update_table.assert_called()
        self.mock_client.create_table.assert_not_called()

    def test_export_to_cloud_storage(self) -> None:
        """export_to_cloud_storage extracts the table corresponding to the
        view."""
        # A non-None return value indicates that an extract job was started.
        self.assertIsNotNone(
            self.bq_client.export_table_to_cloud_storage_async(
                source_table_dataset_ref=self.mock_dataset_ref,
                source_table_id="source-table",
                destination_uri=f"gs://{self.mock_project_id}-bucket/destination_path.json",
                destination_format=bigquery.DestinationFormat.NEWLINE_DELIMITED_JSON,
                print_header=True,
            )
        )
        self.mock_client.extract_table.assert_called()

    def test_export_to_cloud_storage_no_table(self) -> None:
        """export_to_cloud_storage does not extract from a table if the table
        does not exist."""
        # Missing source table: expect a warning log, a None result, and no extract.
        self.mock_client.get_table.side_effect = exceptions.NotFound("!")
        with self.assertLogs(level="WARNING"):
            self.assertIsNone(
                self.bq_client.export_table_to_cloud_storage_async(
                    source_table_dataset_ref=self.mock_dataset_ref,
                    source_table_id="source-table",
                    destination_uri=f"gs://{self.mock_project_id}-bucket/destination_path.json",
                    destination_format=bigquery.DestinationFormat.NEWLINE_DELIMITED_JSON,
                    print_header=True,
                )
            )
            self.mock_client.extract_table.assert_not_called()

    def test_load_table_async_create_dataset(self) -> None:
        """Test that load_table_from_cloud_storage_async tries to create a parent dataset."""

        # The destination dataset is missing, so it must be created before loading.
        self.mock_client.get_dataset.side_effect = exceptions.NotFound("!")

        self.bq_client.load_table_from_cloud_storage_async(
            destination_dataset_ref=self.mock_dataset_ref,
            destination_table_id=self.mock_table_id,
            destination_table_schema=[
                SchemaField("my_column", "STRING", "NULLABLE", None, ())
            ],
            source_uri="gs://bucket/export-uri",
        )

        self.mock_client.create_dataset.assert_called()
        self.mock_client.load_table_from_uri.assert_called()

    def test_load_table_async_dataset_exists(self) -> None:
        """Test that load_table_from_cloud_storage_async does not try to create a
        parent dataset if it already exists."""

        # No NotFound side effect: get_dataset succeeds, so no create is needed.
        self.bq_client.load_table_from_cloud_storage_async(
            destination_dataset_ref=self.mock_dataset_ref,
            destination_table_id=self.mock_table_id,
            destination_table_schema=[
                SchemaField("my_column", "STRING", "NULLABLE", None, ())
            ],
            source_uri="gs://bucket/export-uri",
        )

        self.mock_client.create_dataset.assert_not_called()
        self.mock_client.load_table_from_uri.assert_called()

    def test_export_query_results_to_cloud_storage_no_table(self) -> None:
        """export_query_results_to_cloud_storage logs a warning when
        get_table raises NotFound for the intermediate table."""
        bucket = self.mock_project_id + "-bucket"
        self.mock_client.get_table.side_effect = exceptions.NotFound("!")
        with self.assertLogs(level="WARNING"):
            self.bq_client.export_query_results_to_cloud_storage(
                [
                    ExportQueryConfig.from_view_query(
                        view=self.mock_view,
                        view_filter_clause="WHERE x = y",
                        intermediate_table_name=self.mock_table_id,
                        output_uri=f"gs://{bucket}/view.json",
                        output_format=bigquery.DestinationFormat.NEWLINE_DELIMITED_JSON,
                    )
                ],
                print_header=True,
            )

    def test_export_query_results_to_cloud_storage(self) -> None:
        """export_query_results_to_cloud_storage creates the table from the view query and
        exports the table."""
        bucket = self.mock_project_id + "-bucket"
        # Pre-resolved futures stand in for already-completed BigQuery jobs.
        completed_query_job: futures.Future = futures.Future()
        completed_query_job.set_result([])
        completed_extract_job: futures.Future = futures.Future()
        completed_extract_job.set_result(None)
        self.mock_client.extract_table.return_value = completed_extract_job
        self.mock_client.query.return_value = completed_query_job

        export_config = ExportQueryConfig.from_view_query(
            view=self.mock_view,
            view_filter_clause="WHERE x = y",
            intermediate_table_name=self.mock_table_id,
            output_uri=f"gs://{bucket}/view.json",
            output_format=bigquery.DestinationFormat.NEWLINE_DELIMITED_JSON,
        )
        self.bq_client.export_query_results_to_cloud_storage(
            [export_config], print_header=True
        )

        self.mock_client.query.assert_called()
        self.mock_client.extract_table.assert_called()
        # The intermediate table created for the export is cleaned up afterwards.
        expected_intermediate_table = bigquery.DatasetReference(
            self.mock_project_id, self.mock_view.dataset_id
        ).table(self.mock_table_id)
        self.mock_client.delete_table.assert_called_with(expected_intermediate_table)

    def test_create_table_from_query(self) -> None:
        """Tests that the create_table_from_query function calls the function to create a table from a query."""
        # Creating a table from a query is implemented as a query job on the client.
        self.bq_client.create_table_from_query_async(
            self.mock_dataset_id,
            self.mock_table_id,
            query="SELECT * FROM some.fake.table",
            query_parameters=[],
        )
        self.mock_client.query.assert_called()

    @mock.patch("google.cloud.bigquery.job.QueryJobConfig")
    def test_insert_into_table_from_table_async(
        self, mock_job_config: mock.MagicMock
    ) -> None:
        """Tests that the insert_into_table_from_table_async function runs a query."""
        self.bq_client.insert_into_table_from_table_async(
            source_dataset_id=self.mock_dataset_id,
            source_table_id=self.mock_table_id,
            destination_dataset_id=self.mock_dataset_id,
            destination_table_id="fake_table_temp",
        )
        # With no filter clause, the generated query selects the whole source table.
        expected_query = f"SELECT * FROM `fake-recidiviz-project.{self.mock_dataset_id}.{self.mock_table_id}`"
        self.mock_client.get_table.assert_called()
        self.mock_client.query.assert_called_with(
            query=expected_query, location=self.location, job_config=mock_job_config()
        )

    @mock.patch("google.cloud.bigquery.job.QueryJobConfig")
    def test_insert_into_table_from_table_async_hydrate_missing_columns(
        self, mock_job_config: mock.MagicMock
    ) -> None:
        """Tests that the insert_into_table_from_table_async generates a query with missing columns as NULL."""
        with mock.patch(
            "recidiviz.big_query.big_query_client.BigQueryClientImpl"
            "._get_excess_schema_fields"
        ) as mock_missing:
            mock_missing.return_value = [
                bigquery.SchemaField("state_code", "STRING", "REQUIRED"),
                bigquery.SchemaField("new_column_name", "INTEGER", "REQUIRED"),
            ]
            # Use a local variable rather than assigning test state onto self;
            # the previous self.mock_destination_id leaked state across tests.
            mock_destination_id = "fake_table_temp"
            self.bq_client.insert_into_table_from_table_async(
                source_dataset_id=self.mock_dataset_id,
                source_table_id=self.mock_table_id,
                destination_dataset_id=self.mock_dataset_id,
                destination_table_id=mock_destination_id,
                hydrate_missing_columns_with_null=True,
                allow_field_additions=True,
            )
            # Each missing destination column is hydrated as a typed NULL.
            expected_query = (
                "SELECT *, CAST(NULL AS STRING) AS state_code, CAST(NULL AS INTEGER) AS new_column_name "
                f"FROM `fake-recidiviz-project.{self.mock_dataset_id}.{self.mock_table_id}`"
            )
            self.mock_client.query.assert_called_with(
                query=expected_query,
                location=self.location,
                job_config=mock_job_config(),
            )

    def test_insert_into_table_from_table_invalid_destination(self) -> None:
        """Tests that the insert_into_table_from_table_async function does not run the query if the destination
        table does not exist."""
        # get_table raising NotFound means the destination table is missing.
        self.mock_client.get_table.side_effect = exceptions.NotFound("!")

        with pytest.raises(ValueError):
            self.bq_client.insert_into_table_from_table_async(
                self.mock_dataset_id,
                self.mock_table_id,
                "fake_source_dataset_id",
                "fake_table_id",
            )
        self.mock_client.get_table.assert_called()
        self.mock_client.query.assert_not_called()

    def test_insert_into_table_from_table_invalid_filter_clause(self) -> None:
        """Tests that the insert_into_table_from_table_async function does not run the query if the filter clause
        does not start with a WHERE."""
        with pytest.raises(ValueError):
            self.bq_client.insert_into_table_from_table_async(
                self.mock_dataset_id,
                self.mock_table_id,
                "fake_source_dataset_id",
                "fake_table_id",
                # Invalid: filter clauses must begin with "WHERE".
                source_data_filter_clause="bad filter clause",
            )
        self.mock_client.query.assert_not_called()

    @mock.patch("google.cloud.bigquery.job.QueryJobConfig")
    def test_insert_into_table_from_table_with_filter_clause(
        self, mock_job_config: mock.MagicMock
    ) -> None:
        """Tests that the insert_into_table_from_table_async generates a valid query when given a filter clause."""
        expected_job_config = mock_job_config()
        self.bq_client.insert_into_table_from_table_async(
            self.mock_dataset_id,
            self.mock_table_id,
            "fake_source_dataset_id",
            "fake_table_id",
            source_data_filter_clause="WHERE state_code IN ('US_ND')",
        )
        # The filter clause is appended verbatim to the generated SELECT.
        self.mock_client.query.assert_called_with(
            query=(
                "SELECT * FROM `fake-recidiviz-project.fake-dataset.test_table` "
                "WHERE state_code IN ('US_ND')"
            ),
            location=self.location,
            job_config=expected_job_config,
        )

    def test_insert_into_table_from_cloud_storage_async(self) -> None:
        # The destination dataset is missing, so it is created before the load job.
        self.mock_client.get_dataset.side_effect = exceptions.NotFound("!")

        self.bq_client.insert_into_table_from_cloud_storage_async(
            destination_dataset_ref=self.mock_dataset_ref,
            destination_table_id=self.mock_table_id,
            destination_table_schema=[
                SchemaField("my_column", "STRING", "NULLABLE", None, ())
            ],
            source_uri="gs://bucket/export-uri",
        )

        self.mock_client.create_dataset.assert_called()
        self.mock_client.load_table_from_uri.assert_called()

    def test_delete_from_table(self) -> None:
        """Tests that the delete_from_table function runs a query."""
        # A valid WHERE clause results in a DELETE query job on the client.
        self.bq_client.delete_from_table_async(
            self.mock_dataset_id, self.mock_table_id, filter_clause="WHERE x > y"
        )
        self.mock_client.query.assert_called()

    def test_delete_from_table_invalid_filter_clause(self) -> None:
        """Tests that the delete_from_table function does not run a query when the filter clause is invalid."""
        # Invalid: the clause lacks the leading "WHERE".
        with pytest.raises(ValueError):
            self.bq_client.delete_from_table_async(
                self.mock_dataset_id, self.mock_table_id, filter_clause="x > y"
            )
        self.mock_client.query.assert_not_called()

    def test_materialize_view_to_table(self) -> None:
        """Tests that the materialize_view_to_table function calls the function to create a table from a query."""
        # self.mock_view has should_materialize=True, so a query job is run.
        self.bq_client.materialize_view_to_table(self.mock_view)
        self.mock_client.query.assert_called()

    def test_materialize_view_to_table_no_materialized_view_table_id(self) -> None:
        """Tests that the materialize_view_to_table function does not call the function to create a table from a
        query if there is no set materialized_view_table_id on the view."""
        # A view that is not configured to materialize has no destination table id.
        non_materialized_view = BigQueryView(
            dataset_id="dataset",
            view_id="test_view",
            view_query_template="SELECT NULL LIMIT 0",
            should_materialize=False,
        )

        with pytest.raises(ValueError):
            self.bq_client.materialize_view_to_table(non_materialized_view)
        self.mock_client.query.assert_not_called()

    def test_create_table_with_schema(self) -> None:
        """Tests that the create_table_with_schema function calls the create_table function on the client."""
        # The destination table must not already exist for creation to proceed.
        self.mock_client.get_table.side_effect = exceptions.NotFound("!")

        self.bq_client.create_table_with_schema(
            self.mock_dataset_id,
            self.mock_table_id,
            [bigquery.SchemaField("new_schema_field", "STRING")],
        )
        self.mock_client.create_table.assert_called()

    def test_create_table_with_schema_table_exists(self) -> None:
        """Tests that the create_table_with_schema function raises an error when the table already exists."""
        # get_table succeeding means the destination table already exists.
        self.mock_client.get_table.side_effect = None

        with pytest.raises(ValueError):
            self.bq_client.create_table_with_schema(
                self.mock_dataset_id,
                self.mock_table_id,
                [bigquery.SchemaField("new_schema_field", "STRING")],
            )
        self.mock_client.create_table.assert_not_called()

    def test_add_missing_fields_to_schema(self) -> None:
        """Tests that the add_missing_fields_to_schema function calls the client to update the table."""
        existing_table = bigquery.Table(
            bigquery.TableReference(self.mock_dataset_ref, self.mock_table_id),
            [bigquery.SchemaField("fake_schema_field", "STRING")],
        )
        self.mock_client.get_table.return_value = existing_table

        # Requesting a field that is absent from the table triggers an update.
        self.bq_client.add_missing_fields_to_schema(
            self.mock_dataset_id,
            self.mock_table_id,
            [bigquery.SchemaField("new_schema_field", "STRING")],
        )

        self.mock_client.update_table.assert_called()

    def test_add_missing_fields_to_schema_no_missing_fields(self) -> None:
        """Tests that the add_missing_fields_to_schema function does not call the client to update the table when all
        of the fields are already present."""
        self.mock_client.get_table.return_value = bigquery.Table(
            bigquery.TableReference(self.mock_dataset_ref, self.mock_table_id),
            [bigquery.SchemaField("fake_schema_field", "STRING")],
        )

        # Requesting a field that already exists should be a no-op.
        self.bq_client.add_missing_fields_to_schema(
            self.mock_dataset_id,
            self.mock_table_id,
            [bigquery.SchemaField("fake_schema_field", "STRING")],
        )

        self.mock_client.update_table.assert_not_called()

    def test_add_missing_fields_to_schema_no_table(self) -> None:
        """Tests that the add_missing_fields_to_schema function does not call the client to update the table when the
        table does not exist."""
        # A missing table cannot have its schema extended.
        self.mock_client.get_table.side_effect = exceptions.NotFound("!")

        with pytest.raises(ValueError):
            self.bq_client.add_missing_fields_to_schema(
                self.mock_dataset_id,
                self.mock_table_id,
                [bigquery.SchemaField("fake_schema_field", "STRING")],
            )

        self.mock_client.update_table.assert_not_called()

    def test_add_missing_fields_to_schema_fields_with_same_name_different_type(
        self,
    ) -> None:
        """Tests that the add_missing_fields_to_schema function raises an error when the user is trying to add a field
        with the same name but different field_type as an existing field."""
        self.mock_client.get_table.return_value = bigquery.Table(
            bigquery.TableReference(self.mock_dataset_ref, self.mock_table_id),
            [bigquery.SchemaField("fake_schema_field", "STRING")],
        )

        # Same field name, but conflicting type: INTEGER vs the existing STRING.
        conflicting_fields = [bigquery.SchemaField("fake_schema_field", "INTEGER")]

        with pytest.raises(ValueError):
            self.bq_client.add_missing_fields_to_schema(
                self.mock_dataset_id, self.mock_table_id, conflicting_fields
            )

        self.mock_client.update_table.assert_not_called()

    def test_add_missing_fields_to_schema_fields_with_same_name_different_mode(
        self,
    ) -> None:
        """Tests that the add_missing_fields_to_schema function raises an error when the user is trying to add a field
        with the same name but different mode as an existing field."""
        self.mock_client.get_table.return_value = bigquery.Table(
            bigquery.TableReference(self.mock_dataset_ref, self.mock_table_id),
            [bigquery.SchemaField("fake_schema_field", "STRING", mode="NULLABLE")],
        )

        # Same field name and type, but conflicting mode: REQUIRED vs NULLABLE.
        conflicting_fields = [
            bigquery.SchemaField("fake_schema_field", "STRING", mode="REQUIRED")
        ]

        with pytest.raises(ValueError):
            self.bq_client.add_missing_fields_to_schema(
                self.mock_dataset_id, self.mock_table_id, conflicting_fields
            )

        self.mock_client.update_table.assert_not_called()

    def test_remove_unused_fields_from_schema(self) -> None:
        """Tests that remove_unused_fields_from_schema() calls the client to update the table with a query."""
        self.mock_client.get_table.return_value = bigquery.Table(
            bigquery.TableReference(self.mock_dataset_ref, self.mock_table_id),
            [
                bigquery.SchemaField("field_1", "STRING"),
                bigquery.SchemaField("field_2", "STRING"),
            ],
        )

        # field_2 is absent from the desired schema, so it should be dropped via query.
        desired_schema = [bigquery.SchemaField("field_1", "STRING")]

        self.bq_client.remove_unused_fields_from_schema(
            self.mock_dataset_id, self.mock_table_id, desired_schema
        )

        self.mock_client.query.assert_called()

    def test_remove_unused_fields_from_schema_no_missing_fields(self) -> None:
        """Tests that remove_unused_fields_from_schema() does nothing if there are no missing fields."""
        self.mock_client.get_table.return_value = bigquery.Table(
            bigquery.TableReference(self.mock_dataset_ref, self.mock_table_id),
            [bigquery.SchemaField("field_1", "STRING")],
        )

        # The desired schema matches the existing one, so no query is run.
        desired_schema = [bigquery.SchemaField("field_1", "STRING")]

        self.bq_client.remove_unused_fields_from_schema(
            self.mock_dataset_id, self.mock_table_id, desired_schema
        )

        self.mock_client.query.assert_not_called()

    def test_remove_unused_fields_from_schema_ignore_excess_desired_fields(
        self,
    ) -> None:
        """Tests that remove_unused_fields_from_schema() drops columns even when there are excess desired fields."""
        self.mock_client.get_table.return_value = bigquery.Table(
            bigquery.TableReference(self.mock_dataset_ref, self.mock_table_id),
            [
                bigquery.SchemaField("field_1", "STRING"),
                bigquery.SchemaField("field_2", "STRING"),
            ],
        )

        # field_2 must still be dropped even though field_3 appears only in the
        # desired schema.
        desired_schema = [
            bigquery.SchemaField("field_1", "STRING"),
            bigquery.SchemaField("field_3", "STRING"),
        ]

        self.bq_client.remove_unused_fields_from_schema(
            self.mock_dataset_id, self.mock_table_id, desired_schema
        )

        self.mock_client.query.assert_called()

    @mock.patch(
        "recidiviz.big_query.big_query_client.BigQueryClientImpl.remove_unused_fields_from_schema"
    )
    @mock.patch(
        "recidiviz.big_query.big_query_client.BigQueryClientImpl.add_missing_fields_to_schema"
    )
    def test_update_schema(
        self, add_missing_mock: mock.MagicMock, remove_unused_mock: mock.MagicMock
    ) -> None:
        """Tests that update_schema() calls both field updaters if the inputs are valid.

        Note: mock.patch decorators are applied bottom-up, so the innermost patch
        (add_missing_fields_to_schema) is the first mock argument — the previous
        parameter names were swapped.
        """
        table_ref = bigquery.TableReference(self.mock_dataset_ref, self.mock_table_id)
        schema_fields = [
            bigquery.SchemaField("field_1", "STRING"),
            bigquery.SchemaField("field_2", "STRING"),
        ]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [
            bigquery.SchemaField("field_1", "STRING"),
            bigquery.SchemaField("field_3", "STRING"),
        ]

        self.bq_client.update_schema(
            self.mock_dataset_id, self.mock_table_id, new_schema_fields
        )

        remove_unused_mock.assert_called()
        add_missing_mock.assert_called()

    @mock.patch(
        "recidiviz.big_query.big_query_client.BigQueryClientImpl.remove_unused_fields_from_schema"
    )
    @mock.patch(
        "recidiviz.big_query.big_query_client.BigQueryClientImpl.add_missing_fields_to_schema"
    )
    def test_update_schema_fails_on_changed_type(
        self, add_missing_mock: mock.MagicMock, remove_unused_mock: mock.MagicMock
    ) -> None:
        """Tests that update_schema() throws if we try to change a field type.

        Note: mock.patch decorators are applied bottom-up, so the innermost patch
        (add_missing_fields_to_schema) is the first mock argument — the previous
        parameter names were swapped.
        """
        table_ref = bigquery.TableReference(self.mock_dataset_ref, self.mock_table_id)
        schema_fields = [
            bigquery.SchemaField("field_1", "STRING"),
            bigquery.SchemaField("field_2", "STRING"),
        ]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        # field_2 changes type from STRING to INT, which must be rejected.
        new_schema_fields = [
            bigquery.SchemaField("field_1", "STRING"),
            bigquery.SchemaField("field_2", "INT"),
        ]

        with pytest.raises(ValueError):
            self.bq_client.update_schema(
                self.mock_dataset_id, self.mock_table_id, new_schema_fields
            )

        remove_unused_mock.assert_not_called()
        add_missing_mock.assert_not_called()

    @mock.patch(
        "recidiviz.big_query.big_query_client.BigQueryClientImpl.remove_unused_fields_from_schema"
    )
    @mock.patch(
        "recidiviz.big_query.big_query_client.BigQueryClientImpl.add_missing_fields_to_schema"
    )
    def test_update_schema_fails_on_changed_mode(
        self, add_missing_mock: mock.MagicMock, remove_unused_mock: mock.MagicMock
    ) -> None:
        """Tests that update_schema() throws if we try to change a field mode.

        Note: mock.patch decorators are applied bottom-up, so the innermost patch
        (add_missing_fields_to_schema) is the first mock argument — the previous
        parameter names were swapped.
        """
        table_ref = bigquery.TableReference(self.mock_dataset_ref, self.mock_table_id)
        schema_fields = [
            bigquery.SchemaField("field_1", "STRING", "NULLABLE"),
            bigquery.SchemaField("field_2", "STRING"),
        ]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        # field_1 changes mode from NULLABLE to REQUIRED, which must be rejected.
        new_schema_fields = [
            bigquery.SchemaField("field_1", "STRING", "REQUIRED"),
            bigquery.SchemaField("field_2", "INT"),
        ]

        with pytest.raises(ValueError):
            self.bq_client.update_schema(
                self.mock_dataset_id, self.mock_table_id, new_schema_fields
            )

        remove_unused_mock.assert_not_called()
        add_missing_mock.assert_not_called()

    def test__get_excess_schema_fields_simple_excess(self) -> None:
        """Tests _get_excess_schema_fields() when extended_schema is a strict superset of base_schema."""
        base_schema = [bigquery.SchemaField("field_1", "INT")]
        extra_fields = [
            bigquery.SchemaField("field_2", "INT"),
            bigquery.SchemaField("field_3", "INT"),
        ]

        excess_fields = BigQueryClientImpl._get_excess_schema_fields(
            base_schema, base_schema + extra_fields
        )

        # Exactly the fields beyond base_schema are reported, in order.
        self.assertEqual(excess_fields, extra_fields)

    def test__get_excess_schema_fields_with_extra_base_schema(self) -> None:
        """Tests _get_excess_schema_fields() when base_schema has fields not in extended_schema."""
        base_schema = [
            bigquery.SchemaField("field_1", "INT"),
            bigquery.SchemaField("field_2", "INT"),
        ]
        fields_only_in_extended = [
            bigquery.SchemaField("field_3", "INT"),
            bigquery.SchemaField("field_4", "INT"),
        ]
        extended_schema = [
            bigquery.SchemaField("field_1", "INT")
        ] + fields_only_in_extended

        excess_fields = BigQueryClientImpl._get_excess_schema_fields(
            base_schema, extended_schema
        )

        # field_2 existing only in base_schema does not affect the result.
        self.assertEqual(excess_fields, fields_only_in_extended)

    def test__get_excess_schema_fields_with_matching_schema(self) -> None:
        """Tests _get_excess_schema_fields() when base_schema is the same as extended_schema."""
        schema = [
            bigquery.SchemaField("field_1", "INT"),
            bigquery.SchemaField("field_2", "INT"),
        ]

        # Identical schemas yield no excess fields.
        self.assertListEqual(
            BigQueryClientImpl._get_excess_schema_fields(schema, schema), []
        )

    def test__get_excess_schema_fields_no_excess(self) -> None:
        """Tests _get_excess_schema_fields() when base_schema is a superset of extended_schema."""
        base_schema = [
            bigquery.SchemaField("field_1", "INT"),
            bigquery.SchemaField("field_2", "INT"),
        ]
        subset_schema = [bigquery.SchemaField("field_2", "INT")]

        # A subset of base_schema contributes nothing new.
        self.assertListEqual(
            BigQueryClientImpl._get_excess_schema_fields(base_schema, subset_schema),
            [],
        )

    def test_delete_table(self) -> None:
        """Tests that our delete table function calls the correct client method."""
        self.bq_client.delete_table(self.mock_dataset_id, self.mock_table_id)
        self.mock_client.delete_table.assert_called()

    @mock.patch("google.cloud.bigquery.QueryJob")
    def test_paged_read_single_page_single_row(
        self, mock_query_job: mock.MagicMock
    ) -> None:
        """Paged read processes a lone row returned on the first page."""
        only_row = bigquery.table.Row(
            ["parole", 15, "10N"],
            {"supervision_type": 0, "revocations": 1, "district": 2},
        )

        # Page 1 holds the row; page 2 is empty and terminates the read.
        mock_query_job.result.side_effect = [[only_row], []]

        seen_rows = []

        def _collect(row: bigquery.table.Row) -> None:
            seen_rows.append(dict(row))

        self.bq_client.paged_read_and_process(mock_query_job, 1, _collect)

        self.assertEqual([dict(only_row)], seen_rows)
        mock_query_job.result.assert_has_calls(
            [
                call(max_results=1, start_index=0),
                call(max_results=1, start_index=1),
            ]
        )

    @mock.patch("google.cloud.bigquery.QueryJob")
    def test_paged_read_single_page_multiple_rows(
        self, mock_query_job: mock.MagicMock
    ) -> None:
        """Paged read processes multiple rows delivered on a single page."""
        row_one = bigquery.table.Row(
            ["parole", 15, "10N"],
            {"supervision_type": 0, "revocations": 1, "district": 2},
        )
        row_two = bigquery.table.Row(
            ["probation", 7, "10N"],
            {"supervision_type": 0, "revocations": 1, "district": 2},
        )

        # Page 1 holds both rows; page 2 is empty and terminates the read.
        mock_query_job.result.side_effect = [[row_one, row_two], []]

        seen_rows = []

        def _collect(row: bigquery.table.Row) -> None:
            seen_rows.append(dict(row))

        self.bq_client.paged_read_and_process(mock_query_job, 10, _collect)

        self.assertEqual([dict(row_one), dict(row_two)], seen_rows)
        mock_query_job.result.assert_has_calls(
            [
                call(max_results=10, start_index=0),
                call(max_results=10, start_index=2),
            ]
        )

    @mock.patch("google.cloud.bigquery.QueryJob")
    def test_paged_read_multiple_pages(self, mock_query_job: mock.MagicMock) -> None:
        """Paged read keeps fetching pages until an empty page is returned."""
        field_map = {"supervision_type": 0, "revocations": 1, "district": 2}
        page_one = [
            bigquery.table.Row(["parole", 15, "10N"], field_map),
            bigquery.table.Row(["probation", 7, "10N"], field_map),
        ]
        page_two = [
            bigquery.table.Row(["parole", 8, "10F"], field_map),
            bigquery.table.Row(["probation", 3, "10F"], field_map),
        ]

        # Two populated pages, then an empty page to stop the pagination loop.
        mock_query_job.result.side_effect = [page_one, page_two, []]

        seen_rows = []

        def _collect(row: bigquery.table.Row) -> None:
            seen_rows.append(dict(row))

        self.bq_client.paged_read_and_process(mock_query_job, 2, _collect)

        self.assertEqual([dict(r) for r in page_one + page_two], seen_rows)
        mock_query_job.result.assert_has_calls(
            [
                call(max_results=2, start_index=0),
                call(max_results=2, start_index=2),
                call(max_results=2, start_index=4),
            ]
        )