def move_old_dataflow_metrics_to_cold_storage():
    """Moves old output in Dataflow metrics tables to tables in a cold storage dataset. We only keep the
    MAX_DAYS_IN_DATAFLOW_METRICS_TABLE days worth of data in a Dataflow metric table at once. All other
    output is moved to cold storage.
    """
    bq_client = BigQueryClientImpl()
    dataflow_metrics_dataset = DATAFLOW_METRICS_DATASET
    cold_storage_dataset = DATAFLOW_METRICS_COLD_STORAGE_DATASET
    dataflow_metrics_tables = bq_client.list_tables(dataflow_metrics_dataset)

    for table_ref in dataflow_metrics_tables:
        table_id = table_ref.table_id

        filter_clause = """WHERE created_on NOT IN
                          (SELECT DISTINCT created_on FROM `{project_id}.{dataflow_metrics_dataset}.{table_id}` 
                          ORDER BY created_on DESC
                          LIMIT {day_count_limit})""".format(
            project_id=table_ref.project,
            dataflow_metrics_dataset=table_ref.dataset_id,
            table_id=table_ref.table_id,
            day_count_limit=MAX_DAYS_IN_DATAFLOW_METRICS_TABLE)

        cold_storage_dataset_ref = bq_client.dataset_ref_for_id(
            cold_storage_dataset)

        if bq_client.table_exists(cold_storage_dataset_ref, table_id):
            # Move data from the Dataflow metrics dataset into the cold storage dataset
            insert_job = bq_client.insert_into_table_from_table_async(
                source_dataset_id=dataflow_metrics_dataset,
                source_table_id=table_id,
                destination_dataset_id=cold_storage_dataset,
                destination_table_id=table_id,
                source_data_filter_clause=filter_clause,
                allow_field_additions=True)

            # Wait for the insert job to complete before running the delete job
            insert_job.result()
        else:
            # This table doesn't yet exist in cold storage. Create it.
            table_query = f"SELECT * FROM `{bq_client.project_id}.{dataflow_metrics_dataset}.{table_id}` " \
                          f"{filter_clause}"

            create_job = bq_client.create_table_from_query_async(
                cold_storage_dataset,
                table_id,
                table_query,
                query_parameters=[])

            # Wait for the create job to complete before running the delete job
            create_job.result()

        # Delete that data from the Dataflow dataset
        delete_job = bq_client.delete_from_table_async(
            dataflow_metrics_dataset, table_ref.table_id, filter_clause)

        # Wait for the delete job to complete before moving on
        delete_job.result()
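
# For illustration only (the project, dataset, table name, and day limit below
# are made up, not taken from the source): the rendered filter clause for a
# hypothetical table looks like this. Rows whose created_on date is not among
# the N most recent distinct dates match the clause, so they are copied to
# cold storage and then deleted from the Dataflow dataset.
example_filter_clause = """WHERE created_on NOT IN
                  (SELECT DISTINCT created_on FROM `my-project.dataflow_metrics.incarceration_population_metrics`
                  ORDER BY created_on DESC
                  LIMIT 2)"""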
Example #2
def compare_dataflow_output_to_sandbox(
    sandbox_dataset_prefix: str,
    job_name_to_compare: str,
    base_output_job_id: str,
    sandbox_output_job_id: str,
    additional_columns_to_compare: List[str],
    allow_overwrite: bool = False,
) -> None:
    """Compares the output for all metrics produced by the daily pipeline job with the given |job_name_to_compare|
    between the output from the |base_output_job_id| job in the dataflow_metrics dataset and the output from the
    |sandbox_output_job_id| job in the sandbox dataflow dataset."""
    bq_client = BigQueryClientImpl()
    sandbox_dataflow_dataset_id = (sandbox_dataset_prefix + "_" +
                                   DATAFLOW_METRICS_DATASET)

    sandbox_comparison_output_dataset_id = (sandbox_dataset_prefix +
                                            "_dataflow_comparison_output")
    sandbox_comparison_output_dataset_ref = bq_client.dataset_ref_for_id(
        sandbox_comparison_output_dataset_id)

    if bq_client.dataset_exists(sandbox_comparison_output_dataset_ref) and any(
            bq_client.list_tables(sandbox_comparison_output_dataset_id)):
        if not allow_overwrite:
            if __name__ == "__main__":
                logging.error(
                    "Dataset %s already exists in project %s. To overwrite, set --allow_overwrite.",
                    sandbox_comparison_output_dataset_id,
                    bq_client.project_id,
                )
                sys.exit(1)
            else:
                raise ValueError(
                    f"Cannot write comparison output to a non-empty dataset. Please delete tables in dataset: "
                    f"{bq_client.project_id}.{sandbox_comparison_output_dataset_id}."
                )
        else:
            # Clean up the existing tables in the dataset
            for table in bq_client.list_tables(
                    sandbox_comparison_output_dataset_id):
                bq_client.delete_table(table.dataset_id, table.table_id)

    bq_client.create_dataset_if_necessary(
        sandbox_comparison_output_dataset_ref,
        TEMP_DATASET_DEFAULT_TABLE_EXPIRATION_MS)

    query_jobs: List[Tuple[QueryJob, str]] = []

    pipelines = YAMLDict.from_path(PRODUCTION_TEMPLATES_PATH).pop_dicts(
        "daily_pipelines")

    for pipeline in pipelines:
        if pipeline.pop("job_name", str) == job_name_to_compare:
            pipeline_metric_types = pipeline.peek_optional("metric_types", str)

            if not pipeline_metric_types:
                raise ValueError(
                    f"Pipeline job {job_name_to_compare} missing required metric_types attribute."
                )

            metric_types_for_comparison = pipeline_metric_types.split()

            for metric_class, metric_table in DATAFLOW_METRICS_TO_TABLES.items(
            ):
                metric_type_value = DATAFLOW_TABLES_TO_METRIC_TYPES[
                    metric_table].value

                if metric_type_value in metric_types_for_comparison:
                    comparison_query = _query_for_metric_comparison(
                        bq_client,
                        base_output_job_id,
                        sandbox_output_job_id,
                        sandbox_dataflow_dataset_id,
                        metric_class,
                        metric_table,
                        additional_columns_to_compare,
                    )

                    query_job = bq_client.create_table_from_query_async(
                        dataset_id=sandbox_comparison_output_dataset_id,
                        table_id=metric_table,
                        query=comparison_query,
                        overwrite=True,
                    )

                    # Add query job to the list of running jobs
                    query_jobs.append((query_job, metric_table))

    for query_job, output_table_id in query_jobs:
        # Wait for the insert job to complete before looking for the table
        query_job.result()

        output_table = bq_client.get_table(
            sandbox_comparison_output_dataset_ref, output_table_id)

        if output_table.num_rows == 0:
            # If there are no rows in the output table, then the output was identical
            bq_client.delete_table(sandbox_comparison_output_dataset_id,
                                   output_table_id)

    metrics_with_different_output = peekable(
        bq_client.list_tables(sandbox_comparison_output_dataset_id))

    logging.info(
        "\n*************** DATAFLOW OUTPUT COMPARISON RESULTS ***************\n"
    )

    if metrics_with_different_output:
        for metric_table in metrics_with_different_output:
            # This will always be true, and is here to silence mypy warnings
            assert isinstance(metric_table, bigquery.table.TableListItem)

            logging.warning(
                "Dataflow output differs for metric %s. See %s.%s for diverging rows.",
                metric_table.table_id,
                sandbox_comparison_output_dataset_id,
                metric_table.table_id,
            )
    else:
        logging.info(
            "Dataflow output identical. Deleting dataset %s.",
            sandbox_comparison_output_dataset_ref.dataset_id,
        )
        bq_client.delete_dataset(sandbox_comparison_output_dataset_ref,
                                 delete_contents=True)
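
# A hypothetical invocation of the comparison above; every value here is a
# placeholder for illustration, not a value taken from the source.
compare_dataflow_output_to_sandbox(
    sandbox_dataset_prefix="my_username",            # prefixes the sandbox datasets
    job_name_to_compare="incarceration-daily",       # must match a daily_pipelines job_name
    base_output_job_id="2021-01-01_00_00_00-123",    # job_id in dataflow_metrics
    sandbox_output_job_id="2021-01-02_00_00_00-456", # job_id in the sandbox dataset
    additional_columns_to_compare=["person_id"],
    allow_overwrite=True,  # clear any existing comparison output tables first
)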
Example #3
class BigQueryClientImplTest(unittest.TestCase):
    """Tests for BigQueryClientImpl"""

    def setUp(self):
        self.mock_project_id = 'fake-recidiviz-project'
        self.mock_dataset_id = 'fake-dataset'
        self.mock_table_id = 'test_table'
        self.mock_dataset = bigquery.dataset.DatasetReference(
            self.mock_project_id, self.mock_dataset_id)
        self.mock_table = self.mock_dataset.table(self.mock_table_id)

        self.metadata_patcher = mock.patch('recidiviz.utils.metadata.project_id')
        self.mock_project_id_fn = self.metadata_patcher.start()
        self.mock_project_id_fn.return_value = self.mock_project_id

        self.client_patcher = mock.patch(
            'recidiviz.big_query.big_query_client.client')
        self.mock_client = self.client_patcher.start().return_value

        self.mock_view = BigQueryView(
            dataset_id='dataset',
            view_id='test_view',
            view_query_template='SELECT NULL LIMIT 0',
            materialized_view_table_id='test_view_table'
        )

        self.bq_client = BigQueryClientImpl()

    def tearDown(self):
        self.client_patcher.stop()
        self.metadata_patcher.stop()

    def test_create_dataset_if_necessary(self):
        """Check that a dataset is created if it does not exist."""
        self.mock_client.get_dataset.side_effect = exceptions.NotFound('!')
        self.bq_client.create_dataset_if_necessary(self.mock_dataset)
        self.mock_client.create_dataset.assert_called()

    def test_create_dataset_if_necessary_dataset_exists(self):
        """Check that a dataset is not created if it already exists."""
        self.mock_client.get_dataset.side_effect = None
        self.bq_client.create_dataset_if_necessary(self.mock_dataset)
        self.mock_client.create_dataset.assert_not_called()

    def test_table_exists(self):
        """Check that table_exists returns True if the table exists."""
        self.mock_client.get_table.side_effect = None
        self.assertTrue(
            self.bq_client.table_exists(self.mock_dataset, self.mock_table_id))

    def test_table_exists_does_not_exist(self):
        """Check that table_exists returns False if the table does not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')
        with self.assertLogs(level='WARNING'):
            table_exists = self.bq_client.table_exists(
                self.mock_dataset, self.mock_table_id)
            self.assertFalse(table_exists)

    def test_create_or_update_view_creates_view(self):
        """create_or_update_view creates a View if it does not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')
        self.bq_client.create_or_update_view(self.mock_dataset, self.mock_view)
        self.mock_client.create_table.assert_called()
        self.mock_client.update_table.assert_not_called()

    def test_create_or_update_view_updates_view(self):
        """create_or_update_view updates a View if it already exist."""
        self.mock_client.get_table.side_effect = None
        self.bq_client.create_or_update_view(self.mock_dataset, self.mock_view)
        self.mock_client.update_table.assert_called()
        self.mock_client.create_table.assert_not_called()

    def test_export_to_cloud_storage(self):
        """export_to_cloud_storage extracts the table corresponding to the
        view."""
        self.assertIsNotNone(self.bq_client.export_table_to_cloud_storage_async(
            source_table_dataset_ref=self.mock_dataset,
            source_table_id='source-table',
            destination_uri=f'gs://{self.mock_project_id}-bucket/destination_path.json',
            destination_format=bigquery.DestinationFormat.NEWLINE_DELIMITED_JSON
        ))
        self.mock_client.extract_table.assert_called()

    def test_export_to_cloud_storage_no_table(self):
        """export_to_cloud_storage does not extract from a table if the table
        does not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')
        with self.assertLogs(level='WARNING'):
            self.assertIsNone(self.bq_client.export_table_to_cloud_storage_async(
                source_table_dataset_ref=self.mock_dataset,
                source_table_id='source-table',
                destination_uri=f'gs://{self.mock_project_id}-bucket/destination_path.json',
                destination_format=bigquery.DestinationFormat.NEWLINE_DELIMITED_JSON))
            self.mock_client.extract_table.assert_not_called()

    def test_load_table_async_create_dataset(self):
        """Test that load_table_from_cloud_storage_async tries to create a parent dataset."""

        self.mock_client.get_dataset.side_effect = exceptions.NotFound('!')

        self.bq_client.load_table_from_cloud_storage_async(
            destination_dataset_ref=self.mock_dataset,
            destination_table_id=self.mock_table_id,
            destination_table_schema=[SchemaField('my_column', 'STRING', 'NULLABLE', None, ())],
            source_uri='gs://bucket/export-uri')

        self.mock_client.create_dataset.assert_called()
        self.mock_client.load_table_from_uri.assert_called()

    def test_load_table_async_dataset_exists(self):
        """Test that load_table_from_cloud_storage_async does not try to create a parent dataset if it already exists.
        """

        self.bq_client.load_table_from_cloud_storage_async(
            destination_dataset_ref=self.mock_dataset,
            destination_table_id=self.mock_table_id,
            destination_table_schema=[SchemaField('my_column', 'STRING', 'NULLABLE', None, ())],
            source_uri='gs://bucket/export-uri')

        self.mock_client.create_dataset.assert_not_called()
        self.mock_client.load_table_from_uri.assert_called()

    def test_export_query_results_to_cloud_storage_no_table(self):
        bucket = self.mock_project_id + '-bucket'
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')
        with self.assertLogs(level='WARNING'):
            self.bq_client.export_query_results_to_cloud_storage([
                ExportQueryConfig.from_view_query(
                    view=self.mock_view,
                    view_filter_clause='WHERE x = y',
                    intermediate_table_name=self.mock_table_id,
                    output_uri=f'gs://{bucket}/view.json',
                    output_format=bigquery.DestinationFormat.NEWLINE_DELIMITED_JSON)
            ])

    def test_export_query_results_to_cloud_storage(self):
        """export_query_results_to_cloud_storage creates the table from the view query and
        exports the table."""
        bucket = self.mock_project_id + '-bucket'
        query_job = futures.Future()
        query_job.set_result([])
        extract_job = futures.Future()
        extract_job.set_result(None)
        self.mock_client.query.return_value = query_job
        self.mock_client.extract_table.return_value = extract_job
        self.bq_client.export_query_results_to_cloud_storage([
            ExportQueryConfig.from_view_query(
                view=self.mock_view,
                view_filter_clause='WHERE x = y',
                intermediate_table_name=self.mock_table_id,
                output_uri=f'gs://{bucket}/view.json',
                output_format=bigquery.DestinationFormat.NEWLINE_DELIMITED_JSON)
            ])
        self.mock_client.query.assert_called()
        self.mock_client.extract_table.assert_called()
        self.mock_client.delete_table.assert_called_with(
            bigquery.DatasetReference(self.mock_project_id, self.mock_view.dataset_id).table(self.mock_table_id))

    def test_create_table_from_query(self):
        """Tests that the create_table_from_query function calls the function to create a table from a query."""
        self.bq_client.create_table_from_query_async(self.mock_dataset_id, self.mock_table_id,
                                                     query="SELECT * FROM some.fake.table",
                                                     query_parameters=[])
        self.mock_client.query.assert_called()

    def test_insert_into_table_from_table(self):
        """Tests that the insert_into_table_from_table function runs a query."""
        self.bq_client.insert_into_table_from_table_async('fake_source_dataset_id', 'fake_table_id',
                                                          self.mock_dataset_id, self.mock_table_id)
        self.mock_client.get_table.assert_called()
        self.mock_client.query.assert_called()

    def test_insert_into_table_from_table_invalid_destination(self):
        """Tests that the insert_into_table_from_table function does not run the query if the destination table does
        not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')

        with pytest.raises(ValueError):
            self.bq_client.insert_into_table_from_table_async(self.mock_dataset_id, self.mock_table_id,
                                                              'fake_source_dataset_id', 'fake_table_id')
        self.mock_client.get_table.assert_called()
        self.mock_client.query.assert_not_called()

    def test_insert_into_table_from_cloud_storage_async(self):
        self.mock_client.get_dataset.side_effect = exceptions.NotFound('!')

        self.bq_client.insert_into_table_from_cloud_storage_async(
            destination_dataset_ref=self.mock_dataset,
            destination_table_id=self.mock_table_id,
            destination_table_schema=[SchemaField('my_column', 'STRING', 'NULLABLE', None, ())],
            source_uri='gs://bucket/export-uri')

        self.mock_client.create_dataset.assert_called()
        self.mock_client.load_table_from_uri.assert_called()

    def test_delete_from_table(self):
        """Tests that the delete_from_table function runs a query."""
        self.bq_client.delete_from_table_async(self.mock_dataset_id, self.mock_table_id, filter_clause="WHERE x > y")
        self.mock_client.query.assert_called()

    def test_delete_from_table_invalid_filter_clause(self):
        """Tests that the delete_from_table function does not run a query when the filter clause is invalid."""
        with pytest.raises(ValueError):
            self.bq_client.delete_from_table_async(self.mock_dataset_id, self.mock_table_id, filter_clause="x > y")
        self.mock_client.query.assert_not_called()

    def test_materialize_view_to_table(self):
        """Tests that the materialize_view_to_table function calls the function to create a table from a query."""
        self.bq_client.materialize_view_to_table(self.mock_view)
        self.mock_client.query.assert_called()

    def test_materialize_view_to_table_no_materialized_view_table_id(self):
        """Tests that the materialize_view_to_table function does not call the function to create a table from a
        query if there is no set materialized_view_table_id on the view."""
        invalid_view = BigQueryView(
            dataset_id='dataset',
            view_id='test_view',
            view_query_template='SELECT NULL LIMIT 0',
            materialized_view_table_id=None
        )

        with pytest.raises(ValueError):
            self.bq_client.materialize_view_to_table(invalid_view)
        self.mock_client.query.assert_not_called()

    def test_create_table_with_schema(self):
        """Tests that the create_table_with_schema function calls the create_table function on the client."""
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')
        schema_fields = [bigquery.SchemaField('new_schema_field', 'STRING')]

        self.bq_client.create_table_with_schema(self.mock_dataset_id, self.mock_table_id, schema_fields)
        self.mock_client.create_table.assert_called()

    def test_create_table_with_schema_table_exists(self):
        """Tests that the create_table_with_schema function raises an error when the table already exists."""
        self.mock_client.get_table.side_effect = None
        schema_fields = [bigquery.SchemaField('new_schema_field', 'STRING')]

        with pytest.raises(ValueError):
            self.bq_client.create_table_with_schema(self.mock_dataset_id, self.mock_table_id, schema_fields)
        self.mock_client.create_table.assert_not_called()

    def test_add_missing_fields_to_schema(self):
        """Tests that the add_missing_fields_to_schema function calls the client to update the table."""
        table_ref = bigquery.TableReference(self.mock_dataset, self.mock_table_id)
        schema_fields = [bigquery.SchemaField('fake_schema_field', 'STRING')]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [bigquery.SchemaField('new_schema_field', 'STRING')]

        self.bq_client.add_missing_fields_to_schema(self.mock_dataset_id, self.mock_table_id, new_schema_fields)

        self.mock_client.update_table.assert_called()

    def test_add_missing_fields_to_schema_no_missing_fields(self):
        """Tests that the add_missing_fields_to_schema function does not call the client to update the table when all
        of the fields are already present."""
        table_ref = bigquery.TableReference(self.mock_dataset, self.mock_table_id)
        schema_fields = [bigquery.SchemaField('fake_schema_field', 'STRING')]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [bigquery.SchemaField('fake_schema_field', 'STRING')]

        self.bq_client.add_missing_fields_to_schema(self.mock_dataset_id, self.mock_table_id, new_schema_fields)

        self.mock_client.update_table.assert_not_called()

    def test_add_missing_fields_to_schema_no_table(self):
        """Tests that the add_missing_fields_to_schema function does not call the client to update the table when the
        table does not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')
        new_schema_fields = [bigquery.SchemaField('fake_schema_field', 'STRING')]

        with pytest.raises(ValueError):
            self.bq_client.add_missing_fields_to_schema(self.mock_dataset_id, self.mock_table_id, new_schema_fields)

        self.mock_client.update_table.assert_not_called()

    def test_add_missing_fields_to_schema_fields_with_same_name_different_type(self):
        """Tests that the add_missing_fields_to_schema function raises an error when the user is trying to add a field
        with the same name but different field_type as an existing field."""
        table_ref = bigquery.TableReference(self.mock_dataset, self.mock_table_id)
        schema_fields = [bigquery.SchemaField('fake_schema_field', 'STRING')]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [bigquery.SchemaField('fake_schema_field', 'INTEGER')]

        with pytest.raises(ValueError):
            self.bq_client.add_missing_fields_to_schema(self.mock_dataset_id, self.mock_table_id, new_schema_fields)

        self.mock_client.update_table.assert_not_called()

    def test_add_missing_fields_to_schema_fields_with_same_name_different_mode(self):
        """Tests that the add_missing_fields_to_schema function raises an error when the user is trying to add a field
        with the same name but different mode as an existing field."""
        table_ref = bigquery.TableReference(self.mock_dataset, self.mock_table_id)
        schema_fields = [bigquery.SchemaField('fake_schema_field', 'STRING', mode="NULLABLE")]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [bigquery.SchemaField('fake_schema_field', 'STRING', mode="REQUIRED")]

        with pytest.raises(ValueError):
            self.bq_client.add_missing_fields_to_schema(self.mock_dataset_id, self.mock_table_id, new_schema_fields)

        self.mock_client.update_table.assert_not_called()
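
# The setUp/tearDown pattern above can be reproduced in isolation. A minimal,
# self-contained sketch of the same mock.patch technique (the patched target
# here, os.getcwd, is illustrative and unrelated to the source):
import os
import unittest
from unittest import mock


class PatchingPatternTest(unittest.TestCase):
    """Demonstrates the patcher start/stop pattern used in the tests above."""

    def setUp(self) -> None:
        # Replace the attribute for the duration of each test; start() returns
        # the MagicMock standing in for it.
        self.getcwd_patcher = mock.patch('os.getcwd')
        self.mock_getcwd = self.getcwd_patcher.start()
        self.mock_getcwd.return_value = '/fake/dir'

    def tearDown(self) -> None:
        # Always stop patchers so later tests see the real attribute again.
        self.getcwd_patcher.stop()

    def test_patched_value(self) -> None:
        self.assertEqual('/fake/dir', os.getcwd())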
Example #4
class BigQueryClientImplTest(unittest.TestCase):
    """Tests for BigQueryClientImpl"""
    def setUp(self) -> None:
        self.location = 'US'
        self.mock_project_id = 'fake-recidiviz-project'
        self.mock_dataset_id = 'fake-dataset'
        self.mock_table_id = 'test_table'
        self.mock_dataset_ref = bigquery.dataset.DatasetReference(
            self.mock_project_id, self.mock_dataset_id)
        self.mock_table = self.mock_dataset_ref.table(self.mock_table_id)

        self.metadata_patcher = mock.patch(
            'recidiviz.utils.metadata.project_id')
        self.mock_project_id_fn = self.metadata_patcher.start()
        self.mock_project_id_fn.return_value = self.mock_project_id

        self.client_patcher = mock.patch(
            'recidiviz.big_query.big_query_client.client')
        self.mock_client = self.client_patcher.start().return_value

        self.mock_view = BigQueryView(
            dataset_id='dataset',
            view_id='test_view',
            view_query_template='SELECT NULL LIMIT 0',
            should_materialize=True)

        self.bq_client = BigQueryClientImpl()

    def tearDown(self) -> None:
        self.client_patcher.stop()
        self.metadata_patcher.stop()

    def test_create_dataset_if_necessary(self) -> None:
        """Check that a dataset is created if it does not exist."""
        self.mock_client.get_dataset.side_effect = exceptions.NotFound('!')
        self.bq_client.create_dataset_if_necessary(self.mock_dataset_ref)
        self.mock_client.create_dataset.assert_called()

    def test_create_dataset_if_necessary_dataset_exists(self) -> None:
        """Check that a dataset is not created if it already exists."""
        self.mock_client.get_dataset.side_effect = None
        self.bq_client.create_dataset_if_necessary(self.mock_dataset_ref)
        self.mock_client.create_dataset.assert_not_called()

    def test_create_dataset_if_necessary_table_expiration(self) -> None:
        """Check that the dataset is created with a set table expiration if the dataset does not exist and the
        new_dataset_table_expiration_ms is specified."""
        self.mock_client.get_dataset.side_effect = exceptions.NotFound('!')
        self.bq_client.create_dataset_if_necessary(
            self.mock_dataset_ref, default_table_expiration_ms=6000)
        self.mock_client.create_dataset.assert_called()

    def test_table_exists(self) -> None:
        """Check that table_exists returns True if the table exists."""
        self.mock_client.get_table.side_effect = None
        self.assertTrue(
            self.bq_client.table_exists(self.mock_dataset_ref,
                                        self.mock_table_id))

    def test_table_exists_does_not_exist(self) -> None:
        """Check that table_exists returns False if the table does not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')
        with self.assertLogs(level='WARNING'):
            table_exists = self.bq_client.table_exists(self.mock_dataset_ref,
                                                       self.mock_table_id)
            self.assertFalse(table_exists)

    def test_create_or_update_view_creates_view(self) -> None:
        """create_or_update_view creates a View if it does not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')
        self.bq_client.create_or_update_view(self.mock_dataset_ref,
                                             self.mock_view)
        self.mock_client.create_table.assert_called()
        self.mock_client.update_table.assert_not_called()

    def test_create_or_update_view_updates_view(self) -> None:
        """create_or_update_view updates a View if it already exist."""
        self.mock_client.get_table.side_effect = None
        self.bq_client.create_or_update_view(self.mock_dataset_ref,
                                             self.mock_view)
        self.mock_client.update_table.assert_called()
        self.mock_client.create_table.assert_not_called()

    def test_export_to_cloud_storage(self) -> None:
        """export_to_cloud_storage extracts the table corresponding to the
        view."""
        self.assertIsNotNone(
            self.bq_client.export_table_to_cloud_storage_async(
                source_table_dataset_ref=self.mock_dataset_ref,
                source_table_id='source-table',
                destination_uri=
                f'gs://{self.mock_project_id}-bucket/destination_path.json',
                destination_format=bigquery.DestinationFormat.
                NEWLINE_DELIMITED_JSON))
        self.mock_client.extract_table.assert_called()

    def test_export_to_cloud_storage_no_table(self) -> None:
        """export_to_cloud_storage does not extract from a table if the table
        does not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')
        with self.assertLogs(level='WARNING'):
            self.assertIsNone(
                self.bq_client.export_table_to_cloud_storage_async(
                    source_table_dataset_ref=self.mock_dataset_ref,
                    source_table_id='source-table',
                    destination_uri=
                    f'gs://{self.mock_project_id}-bucket/destination_path.json',
                    destination_format=bigquery.DestinationFormat.
                    NEWLINE_DELIMITED_JSON))
            self.mock_client.extract_table.assert_not_called()

    def test_load_table_async_create_dataset(self) -> None:
        """Test that load_table_from_cloud_storage_async tries to create a parent dataset."""

        self.mock_client.get_dataset.side_effect = exceptions.NotFound('!')

        self.bq_client.load_table_from_cloud_storage_async(
            destination_dataset_ref=self.mock_dataset_ref,
            destination_table_id=self.mock_table_id,
            destination_table_schema=[
                SchemaField('my_column', 'STRING', 'NULLABLE', None, ())
            ],
            source_uri='gs://bucket/export-uri')

        self.mock_client.create_dataset.assert_called()
        self.mock_client.load_table_from_uri.assert_called()

    def test_load_table_async_dataset_exists(self) -> None:
        """Test that load_table_from_cloud_storage_async does not try to create a parent dataset if it already exists.
        """

        self.bq_client.load_table_from_cloud_storage_async(
            destination_dataset_ref=self.mock_dataset_ref,
            destination_table_id=self.mock_table_id,
            destination_table_schema=[
                SchemaField('my_column', 'STRING', 'NULLABLE', None, ())
            ],
            source_uri='gs://bucket/export-uri')

        self.mock_client.create_dataset.assert_not_called()
        self.mock_client.load_table_from_uri.assert_called()

    def test_export_query_results_to_cloud_storage_no_table(self) -> None:
        bucket = self.mock_project_id + '-bucket'
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')
        with self.assertLogs(level='WARNING'):
            self.bq_client.export_query_results_to_cloud_storage([
                ExportQueryConfig.from_view_query(
                    view=self.mock_view,
                    view_filter_clause='WHERE x = y',
                    intermediate_table_name=self.mock_table_id,
                    output_uri=f'gs://{bucket}/view.json',
                    output_format=bigquery.DestinationFormat.
                    NEWLINE_DELIMITED_JSON)
            ])

    def test_export_query_results_to_cloud_storage(self) -> None:
        """export_query_results_to_cloud_storage creates the table from the view query and
        exports the table."""
        bucket = self.mock_project_id + '-bucket'
        query_job: futures.Future = futures.Future()
        query_job.set_result([])
        extract_job: futures.Future = futures.Future()
        extract_job.set_result(None)
        self.mock_client.query.return_value = query_job
        self.mock_client.extract_table.return_value = extract_job
        self.bq_client.export_query_results_to_cloud_storage([
            ExportQueryConfig.from_view_query(
                view=self.mock_view,
                view_filter_clause='WHERE x = y',
                intermediate_table_name=self.mock_table_id,
                output_uri=f'gs://{bucket}/view.json',
                output_format=bigquery.DestinationFormat.NEWLINE_DELIMITED_JSON
            )
        ])
        self.mock_client.query.assert_called()
        self.mock_client.extract_table.assert_called()
        self.mock_client.delete_table.assert_called_with(
            bigquery.DatasetReference(self.mock_project_id,
                                      self.mock_view.dataset_id).table(
                                          self.mock_table_id))

    def test_create_table_from_query(self) -> None:
        """Tests that the create_table_from_query function calls the function to create a table from a query."""
        self.bq_client.create_table_from_query_async(
            self.mock_dataset_id,
            self.mock_table_id,
            query="SELECT * FROM some.fake.table",
            query_parameters=[])
        self.mock_client.query.assert_called()

    @mock.patch('google.cloud.bigquery.job.QueryJobConfig')
    def test_insert_into_table_from_table_async(
            self, mock_job_config: mock.MagicMock) -> None:
        """Tests that the insert_into_table_from_table_async function runs a query."""
        self.bq_client.insert_into_table_from_table_async(
            source_dataset_id=self.mock_dataset_id,
            source_table_id=self.mock_table_id,
            destination_dataset_id=self.mock_dataset_id,
            destination_table_id='fake_table_temp')
        expected_query = f"SELECT * FROM `fake-recidiviz-project.{self.mock_dataset_id}.{self.mock_table_id}`"
        self.mock_client.get_table.assert_called()
        self.mock_client.query.assert_called_with(query=expected_query,
                                                  location=self.location,
                                                  job_config=mock_job_config())

    @mock.patch('google.cloud.bigquery.job.QueryJobConfig')
    def test_insert_into_table_from_table_async_hydrate_missing_columns(
            self, mock_job_config: mock.MagicMock) -> None:
        """Tests that the insert_into_table_from_table_async generates a query with missing columns as NULL."""
        with mock.patch(
                'recidiviz.big_query.big_query_client.BigQueryClientImpl'
                '._get_schema_fields_missing_from_table') as mock_missing:
            mock_missing.return_value = [
                bigquery.SchemaField('state_code', 'STRING', 'REQUIRED'),
                bigquery.SchemaField('new_column_name', 'INTEGER', 'REQUIRED')
            ]
            self.mock_destination_id = 'fake_table_temp'
            self.bq_client.insert_into_table_from_table_async(
                source_dataset_id=self.mock_dataset_id,
                source_table_id=self.mock_table_id,
                destination_dataset_id=self.mock_dataset_id,
                destination_table_id=self.mock_destination_id,
                hydrate_missing_columns_with_null=True,
                allow_field_additions=True)
            expected_query = "SELECT *, CAST(NULL AS STRING) AS state_code, CAST(NULL AS INTEGER) AS new_column_name " \
                             f"FROM `fake-recidiviz-project.{self.mock_dataset_id}.{self.mock_table_id}`"
            self.mock_client.query.assert_called_with(
                query=expected_query,
                location=self.location,
                job_config=mock_job_config())

    def test_insert_into_table_from_table_invalid_destination(self) -> None:
        """Tests that the insert_into_table_from_table_async function does not run the query if the destination
        table does not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')

        with pytest.raises(ValueError):
            self.bq_client.insert_into_table_from_table_async(
                self.mock_dataset_id, self.mock_table_id,
                'fake_source_dataset_id', 'fake_table_id')
        self.mock_client.get_table.assert_called()
        self.mock_client.query.assert_not_called()

    def test_insert_into_table_from_table_invalid_filter_clause(self) -> None:
        """Tests that the insert_into_table_from_table_async function does not run the query if the filter clause
        does not start with a WHERE."""
        with pytest.raises(ValueError):
            self.bq_client.insert_into_table_from_table_async(
                self.mock_dataset_id,
                self.mock_table_id,
                'fake_source_dataset_id',
                'fake_table_id',
                source_data_filter_clause='bad filter clause')
        self.mock_client.query.assert_not_called()

    @mock.patch('google.cloud.bigquery.job.QueryJobConfig')
    def test_insert_into_table_from_table_with_filter_clause(
            self, mock_job_config: mock.MagicMock) -> None:
        """Tests that the insert_into_table_from_table_async generates a valid query when given a filter clause."""
        filter_clause = "WHERE state_code IN ('US_ND')"
        job_config = mock_job_config()
        self.bq_client.insert_into_table_from_table_async(
            self.mock_dataset_id,
            self.mock_table_id,
            'fake_source_dataset_id',
            'fake_table_id',
            source_data_filter_clause=filter_clause)
        expected_query = "SELECT * FROM `fake-recidiviz-project.fake-dataset.test_table` " \
                         "WHERE state_code IN ('US_ND')"
        self.mock_client.query.assert_called_with(query=expected_query,
                                                  location=self.location,
                                                  job_config=job_config)

    def test_insert_into_table_from_cloud_storage_async(self) -> None:
        self.mock_client.get_dataset.side_effect = exceptions.NotFound('!')

        self.bq_client.insert_into_table_from_cloud_storage_async(
            destination_dataset_ref=self.mock_dataset_ref,
            destination_table_id=self.mock_table_id,
            destination_table_schema=[
                SchemaField('my_column', 'STRING', 'NULLABLE', None, ())
            ],
            source_uri='gs://bucket/export-uri')

        self.mock_client.create_dataset.assert_called()
        self.mock_client.load_table_from_uri.assert_called()

    def test_delete_from_table(self) -> None:
        """Tests that the delete_from_table function runs a query."""
        self.bq_client.delete_from_table_async(self.mock_dataset_id,
                                               self.mock_table_id,
                                               filter_clause="WHERE x > y")
        self.mock_client.query.assert_called()

    def test_delete_from_table_invalid_filter_clause(self) -> None:
        """Tests that the delete_from_table function does not run a query when the filter clause is invalid."""
        with pytest.raises(ValueError):
            self.bq_client.delete_from_table_async(self.mock_dataset_id,
                                                   self.mock_table_id,
                                                   filter_clause="x > y")
        self.mock_client.query.assert_not_called()

    def test_materialize_view_to_table(self) -> None:
        """Tests that the materialize_view_to_table function calls the function to create a table from a query."""
        self.bq_client.materialize_view_to_table(self.mock_view)
        self.mock_client.query.assert_called()

    def test_materialize_view_to_table_no_materialized_view_table_id(
            self) -> None:
        """Tests that the materialize_view_to_table function does not call the function to create a table from a
        query if there is no set materialized_view_table_id on the view."""
        invalid_view = BigQueryView(dataset_id='dataset',
                                    view_id='test_view',
                                    view_query_template='SELECT NULL LIMIT 0',
                                    should_materialize=False)

        with pytest.raises(ValueError):
            self.bq_client.materialize_view_to_table(invalid_view)
        self.mock_client.query.assert_not_called()

    def test_create_table_with_schema(self) -> None:
        """Tests that the create_table_with_schema function calls the create_table function on the client."""
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')
        schema_fields = [bigquery.SchemaField('new_schema_field', 'STRING')]

        self.bq_client.create_table_with_schema(self.mock_dataset_id,
                                                self.mock_table_id,
                                                schema_fields)
        self.mock_client.create_table.assert_called()

    def test_create_table_with_schema_table_exists(self) -> None:
        """Tests that the create_table_with_schema function raises an error when the table already exists."""
        self.mock_client.get_table.side_effect = None
        schema_fields = [bigquery.SchemaField('new_schema_field', 'STRING')]

        with pytest.raises(ValueError):
            self.bq_client.create_table_with_schema(self.mock_dataset_id,
                                                    self.mock_table_id,
                                                    schema_fields)
        self.mock_client.create_table.assert_not_called()

    def test_add_missing_fields_to_schema(self) -> None:
        """Tests that the add_missing_fields_to_schema function calls the client to update the table."""
        table_ref = bigquery.TableReference(self.mock_dataset_ref,
                                            self.mock_table_id)
        schema_fields = [bigquery.SchemaField('fake_schema_field', 'STRING')]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [
            bigquery.SchemaField('new_schema_field', 'STRING')
        ]

        self.bq_client.add_missing_fields_to_schema(self.mock_dataset_id,
                                                    self.mock_table_id,
                                                    new_schema_fields)

        self.mock_client.update_table.assert_called()

    def test_add_missing_fields_to_schema_no_missing_fields(self) -> None:
        """Tests that the add_missing_fields_to_schema function does not call the client to update the table when all
        of the fields are already present."""
        table_ref = bigquery.TableReference(self.mock_dataset_ref,
                                            self.mock_table_id)
        schema_fields = [bigquery.SchemaField('fake_schema_field', 'STRING')]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [
            bigquery.SchemaField('fake_schema_field', 'STRING')
        ]

        self.bq_client.add_missing_fields_to_schema(self.mock_dataset_id,
                                                    self.mock_table_id,
                                                    new_schema_fields)

        self.mock_client.update_table.assert_not_called()

    def test_add_missing_fields_to_schema_no_table(self) -> None:
        """Tests that the add_missing_fields_to_schema function does not call the client to update the table when the
        table does not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound('!')
        new_schema_fields = [
            bigquery.SchemaField('fake_schema_field', 'STRING')
        ]

        with pytest.raises(ValueError):
            self.bq_client.add_missing_fields_to_schema(
                self.mock_dataset_id, self.mock_table_id, new_schema_fields)

        self.mock_client.update_table.assert_not_called()

    def test_add_missing_fields_to_schema_fields_with_same_name_different_type(
            self) -> None:
        """Tests that the add_missing_fields_to_schema function raises an error when the user is trying to add a field
        with the same name but different field_type as an existing field."""
        table_ref = bigquery.TableReference(self.mock_dataset_ref,
                                            self.mock_table_id)
        schema_fields = [bigquery.SchemaField('fake_schema_field', 'STRING')]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [
            bigquery.SchemaField('fake_schema_field', 'INTEGER')
        ]

        with pytest.raises(ValueError):
            self.bq_client.add_missing_fields_to_schema(
                self.mock_dataset_id, self.mock_table_id, new_schema_fields)

        self.mock_client.update_table.assert_not_called()

    def test_add_missing_fields_to_schema_fields_with_same_name_different_mode(
            self) -> None:
        """Tests that the add_missing_fields_to_schema function raises an error when the user is trying to add a field
        with the same name but different mode as an existing field."""
        table_ref = bigquery.TableReference(self.mock_dataset_ref,
                                            self.mock_table_id)
        schema_fields = [
            bigquery.SchemaField('fake_schema_field',
                                 'STRING',
                                 mode="NULLABLE")
        ]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [
            bigquery.SchemaField('fake_schema_field',
                                 'STRING',
                                 mode="REQUIRED")
        ]

        with pytest.raises(ValueError):
            self.bq_client.add_missing_fields_to_schema(
                self.mock_dataset_id, self.mock_table_id, new_schema_fields)

        self.mock_client.update_table.assert_not_called()

    def test_delete_table(self) -> None:
        """Tests that our delete table function calls the correct client method."""
        self.bq_client.delete_table(self.mock_dataset_id, self.mock_table_id)
        self.mock_client.delete_table.assert_called()

    @mock.patch('google.cloud.bigquery.QueryJob')
    def test_paged_read_single_page_single_row(
            self, mock_query_job: mock.MagicMock) -> None:
        first_row = bigquery.table.Row(
            ['parole', 15, '10N'],
            {
                'supervision_type': 0,
                'revocations': 1,
                'district': 2
            },
        )

        # First call returns a single row, second call returns nothing
        mock_query_job.result.side_effect = [[first_row], []]

        processed_results = []

        def _process_fn(row: bigquery.table.Row) -> None:
            processed_results.append(dict(row))

        self.bq_client.paged_read_and_process(mock_query_job, 1, _process_fn)

        self.assertEqual([dict(first_row)], processed_results)
        mock_query_job.result.assert_has_calls([
            call(max_results=1, start_index=0),
            call(max_results=1, start_index=1),
        ])

    @mock.patch('google.cloud.bigquery.QueryJob')
    def test_paged_read_single_page_multiple_rows(
            self, mock_query_job: mock.MagicMock) -> None:
        first_row = bigquery.table.Row(
            ['parole', 15, '10N'],
            {
                'supervision_type': 0,
                'revocations': 1,
                'district': 2
            },
        )
        second_row = bigquery.table.Row(
            ['probation', 7, '10N'],
            {
                'supervision_type': 0,
                'revocations': 1,
                'district': 2
            },
        )

        # First call returns two rows, second call returns nothing
        mock_query_job.result.side_effect = [[first_row, second_row], []]

        processed_results = []

        def _process_fn(row: bigquery.table.Row) -> None:
            processed_results.append(dict(row))

        self.bq_client.paged_read_and_process(mock_query_job, 10, _process_fn)

        self.assertEqual([dict(first_row), dict(second_row)],
                         processed_results)
        mock_query_job.result.assert_has_calls([
            call(max_results=10, start_index=0),
            call(max_results=10, start_index=2),
        ])

    @mock.patch('google.cloud.bigquery.QueryJob')
    def test_paged_read_multiple_pages(self,
                                       mock_query_job: mock.MagicMock) -> None:
        p1_r1 = bigquery.table.Row(
            ['parole', 15, '10N'],
            {
                'supervision_type': 0,
                'revocations': 1,
                'district': 2
            },
        )
        p1_r2 = bigquery.table.Row(
            ['probation', 7, '10N'],
            {
                'supervision_type': 0,
                'revocations': 1,
                'district': 2
            },
        )

        p2_r1 = bigquery.table.Row(
            ['parole', 8, '10F'],
            {
                'supervision_type': 0,
                'revocations': 1,
                'district': 2
            },
        )
        p2_r2 = bigquery.table.Row(
            ['probation', 3, '10F'],
            {
                'supervision_type': 0,
                'revocations': 1,
                'district': 2
            },
        )

        # First two calls return results, third call returns nothing
        mock_query_job.result.side_effect = [[p1_r1, p1_r2], [p2_r1, p2_r2],
                                             []]

        processed_results = []

        def _process_fn(row: bigquery.table.Row) -> None:
            processed_results.append(dict(row))

        self.bq_client.paged_read_and_process(mock_query_job, 2, _process_fn)

        self.assertEqual(
            [dict(p1_r1), dict(p1_r2),
             dict(p2_r1), dict(p2_r2)], processed_results)
        mock_query_job.result.assert_has_calls([
            call(max_results=2, start_index=0),
            call(max_results=2, start_index=2),
            call(max_results=2, start_index=4),
        ])
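
# The three paged_read tests above pin down a paging contract. The following
# is a sketch inferred from those assertions, not the library's actual
# implementation: fetch pages of page_size rows at an increasing start_index,
# process each row, and stop on the first empty page.
def paged_read_and_process_sketch(query_job, page_size, process_fn):
    start_index = 0
    while True:
        rows = list(query_job.result(max_results=page_size,
                                     start_index=start_index))
        if not rows:
            break
        for row in rows:
            process_fn(row)
        start_index += len(rows)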
Example #5
def move_old_dataflow_metrics_to_cold_storage() -> None:
    """Moves old output in Dataflow metrics tables to tables in a cold storage dataset. We keep only the most
    recent MAX_DAYS_IN_DATAFLOW_METRICS_TABLE days' worth of data in a Dataflow metric table at once. All other
    output is moved to cold storage.
    """
    bq_client = BigQueryClientImpl()
    dataflow_metrics_dataset = DATAFLOW_METRICS_DATASET
    cold_storage_dataset = DATAFLOW_METRICS_COLD_STORAGE_DATASET
    dataflow_metrics_tables = bq_client.list_tables(dataflow_metrics_dataset)

    for table_ref in dataflow_metrics_tables:
        table_id = table_ref.table_id

        source_data_join_clause = """LEFT JOIN
                          (SELECT DISTINCT job_id AS keep_job_id FROM
                          `{project_id}.{reference_views_dataset}.most_recent_job_id_by_metric_and_state_code_materialized`)
                        ON job_id = keep_job_id
                        LEFT JOIN 
                          (SELECT DISTINCT created_on AS keep_created_date FROM
                          `{project_id}.{dataflow_metrics_dataset}.{table_id}`
                          ORDER BY created_on DESC
                          LIMIT {day_count_limit})
                        ON created_on = keep_created_date
                        """.format(
                            project_id=table_ref.project,
                            dataflow_metrics_dataset=table_ref.dataset_id,
                            reference_views_dataset=REFERENCE_VIEWS_DATASET,
                            table_id=table_id,
                            day_count_limit=MAX_DAYS_IN_DATAFLOW_METRICS_TABLE
                        )

        # Exclude the columns leftover from the exclusion join so they are not added to the metric tables in cold storage
        columns_to_exclude_from_transfer = ['keep_job_id', 'keep_created_date']

        # This filter will return the rows that should be moved to cold storage
        insert_filter_clause = "WHERE keep_job_id IS NULL AND keep_created_date IS NULL"

        # Query for rows to be moved to the cold storage table
        insert_query = """
            SELECT * EXCEPT({columns_to_exclude}) FROM
            `{project_id}.{dataflow_metrics_dataset}.{table_id}`
            {source_data_join_clause}
            {insert_filter_clause}
        """.format(
            columns_to_exclude=', '.join(columns_to_exclude_from_transfer),
            project_id=table_ref.project,
            dataflow_metrics_dataset=table_ref.dataset_id,
            table_id=table_id,
            source_data_join_clause=source_data_join_clause,
            insert_filter_clause=insert_filter_clause
        )

        # Move data from the Dataflow metrics dataset into the cold storage table, creating the table if necessary
        insert_job = bq_client.insert_into_table_from_query(
            destination_dataset_id=cold_storage_dataset,
            destination_table_id=table_id,
            query=insert_query,
            allow_field_additions=True,
            write_disposition=WriteDisposition.WRITE_APPEND)

        # Wait for the insert job to complete before running the replace job
        insert_job.result()

        # This will return the rows that were not moved to cold storage and should remain in the table
        replace_query = """
            SELECT * EXCEPT({columns_to_exclude}) FROM
            `{project_id}.{dataflow_metrics_dataset}.{table_id}`
            {source_data_join_clause}
            WHERE keep_job_id IS NOT NULL OR keep_created_date IS NOT NULL
        """.format(
            columns_to_exclude=', '.join(columns_to_exclude_from_transfer),
            project_id=table_ref.project,
            dataflow_metrics_dataset=table_ref.dataset_id,
            table_id=table_id,
            source_data_join_clause=source_data_join_clause,
        )

        # Replace the Dataflow table with only the rows that should remain
        replace_job = bq_client.create_table_from_query_async(
            dataflow_metrics_dataset, table_ref.table_id, query=replace_query, overwrite=True)

        # Wait for the replace job to complete before moving on
        replace_job.result()
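
# To make the join-based split concrete: for a hypothetical table, the
# rendered insert_query looks roughly like the following (project, dataset,
# table names, and the day limit are placeholders, not source values).
example_insert_query = """
    SELECT * EXCEPT(keep_job_id, keep_created_date) FROM
    `my-project.dataflow_metrics.incarceration_population_metrics`
    LEFT JOIN
      (SELECT DISTINCT job_id AS keep_job_id FROM
      `my-project.reference_views.most_recent_job_id_by_metric_and_state_code_materialized`)
    ON job_id = keep_job_id
    LEFT JOIN
      (SELECT DISTINCT created_on AS keep_created_date FROM
      `my-project.dataflow_metrics.incarceration_population_metrics`
      ORDER BY created_on DESC
      LIMIT 2)
    ON created_on = keep_created_date
    WHERE keep_job_id IS NULL AND keep_created_date IS NULL
"""
# A row stays in the hot table only if it joins to a most-recent job_id or to
# one of the most recent created_on dates; rows where both join keys are NULL
# are appended to cold storage, and the replace query keeps the complement.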
Example #6
def _view_output_comparison_job(
    bq_client: BigQueryClientImpl,
    view_builder: MetricBigQueryViewBuilder,
    base_view_id: str,
    base_dataset_id: str,
    sandbox_dataset_id: str,
    sandbox_comparison_output_dataset_id: str,
    check_determinism: bool,
    allow_schema_changes: bool,
) -> Tuple[bigquery.QueryJob, str]:
    """Builds and executes the query that compares the base and sandbox views. Returns a tuple with the the QueryJob and
    the table_id where the output will be written to in the sandbox_comparison_output_dataset_id dataset."""
    base_dataset_ref = bq_client.dataset_ref_for_id(base_dataset_id)
    sandbox_dataset_ref = bq_client.dataset_ref_for_id(sandbox_dataset_id)
    output_table_id = f"{view_builder.dataset_id}--{base_view_id}"

    if check_determinism:
        # Compare all columns
        columns_to_compare = ["*"]
        preserve_column_types = True
    else:
        # Columns in deployed view
        deployed_base_view = bq_client.get_table(base_dataset_ref,
                                                 view_builder.view_id)
        # If there are nested columns in the deployed view then we can't allow column type changes
        preserve_column_types = _table_contains_nested_columns(
            deployed_base_view)
        base_columns_to_compare = set(field.name
                                      for field in deployed_base_view.schema)

        # Columns in sandbox view
        deployed_sandbox_view = bq_client.get_table(sandbox_dataset_ref,
                                                    view_builder.view_id)
        if not preserve_column_types:
            # If there are nested columns in the sandbox view then we can't allow column type changes
            preserve_column_types = _table_contains_nested_columns(
                deployed_sandbox_view)

        sandbox_columns_to_compare = set(
            field.name for field in deployed_sandbox_view.schema)

        if allow_schema_changes:
            # Only compare columns in both views
            shared_columns = base_columns_to_compare.intersection(
                sandbox_columns_to_compare)
            columns_to_compare = list(shared_columns)
        else:
            if base_columns_to_compare != sandbox_columns_to_compare:
                raise ValueError(
                    f"Schemas of the {base_dataset_id}.{base_view_id} deployed and"
                    f" sandbox views do not match. If this is expected, please run again"
                    f"with the --allow_schema_changes flag.")
            columns_to_compare = list(base_columns_to_compare)

    # Only include dimensions that are present in both views, unless we are
    # checking the determinism of the local view
    metric_dimensions = [
        dimension for dimension in view_builder.dimensions
        if dimension in columns_to_compare or check_determinism
    ]

    if not preserve_column_types:
        # Cast all columns to strings to guard against column types that may have changed
        columns_to_compare = [
            f"CAST({col} AS STRING) as {col}" for col in columns_to_compare
        ]

    base_dataset_id_for_query = (sandbox_dataset_id if check_determinism else
                                 view_builder.dataset_id)

    diff_query = OUTPUT_COMPARISON_TEMPLATE.format(
        project_id=bq_client.project_id,
        base_dataset_id=base_dataset_id_for_query,
        sandbox_dataset_id=sandbox_dataset_id,
        view_id=base_view_id,
        columns_to_compare=", ".join(columns_to_compare),
        dimensions=", ".join(metric_dimensions),
    )

    return (
        bq_client.create_table_from_query_async(
            dataset_id=sandbox_comparison_output_dataset_id,
            table_id=output_table_id,
            query=diff_query,
            overwrite=True,
        ),
        output_table_id,
    )
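

# _table_contains_nested_columns and OUTPUT_COMPARISON_TEMPLATE are defined
# elsewhere in this module. Minimal sketches of what they might look like, as
# assumptions rather than the actual source:

def _table_contains_nested_columns(table: bigquery.Table) -> bool:
    """Returns True if any column in the table's schema is a nested column."""
    # BigQuery reports nested STRUCT columns with the legacy type name "RECORD"
    return any(field.field_type == "RECORD" for field in table.schema)


# A hypothetical shape for the diff template: rows that appear in only one of
# the two views, or whose compared columns disagree, survive the filter.
OUTPUT_COMPARISON_TEMPLATE = """
    WITH base_output AS (
        SELECT {columns_to_compare}
        FROM `{project_id}.{base_dataset_id}.{view_id}`
    ),
    sandbox_output AS (
        SELECT {columns_to_compare}
        FROM `{project_id}.{sandbox_dataset_id}.{view_id}`
    )
    SELECT *
    FROM base_output
    FULL OUTER JOIN sandbox_output
    USING ({dimensions})
    WHERE TO_JSON_STRING(base_output) != TO_JSON_STRING(sandbox_output)
"""
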
def move_old_dataflow_metrics_to_cold_storage() -> None:
    """Moves old output in Dataflow metrics tables to tables in a cold storage dataset.
    We only keep MAX_DAYS_IN_DATAFLOW_METRICS_TABLE days' worth of data in a Dataflow
    metric table at once. All other output is moved to cold storage, unless it is the
    most recent job_id for a metric in a state where that metric is regularly calculated,
    and where the year and month of the output falls into the window of what is regularly
    calculated for that metric and state. See the production_calculation_pipeline_templates.yaml
    file for a list of regularly scheduled calculations.

    If a metric has been entirely decommissioned, handles the deletion of the corresponding table.
    """
    bq_client = BigQueryClientImpl()
    dataflow_metrics_dataset = DATAFLOW_METRICS_DATASET
    cold_storage_dataset = dataflow_config.DATAFLOW_METRICS_COLD_STORAGE_DATASET
    dataflow_metrics_tables = bq_client.list_tables(dataflow_metrics_dataset)

    month_range_for_metric_and_state = _get_month_range_for_metric_and_state()

    for table_ref in dataflow_metrics_tables:
        table_id = table_ref.table_id

        if table_id not in dataflow_config.DATAFLOW_TABLES_TO_METRIC_TYPES:
            # This metric has been deprecated. Handle the deletion of the table
            _decommission_dataflow_metric_table(bq_client, table_ref)
            continue

        is_unbounded_date_pipeline = any(
            pipeline in table_id
            for pipeline in dataflow_config.ALWAYS_UNBOUNDED_DATE_PIPELINES)

        # This means there are no currently scheduled pipelines writing metrics to
        # this table with specific month ranges
        no_active_month_range_pipelines = not month_range_for_metric_and_state[
            table_id
        ]

        if is_unbounded_date_pipeline or no_active_month_range_pipelines:
            source_data_join_clause = SOURCE_DATA_JOIN_CLAUSE_STANDARD_TEMPLATE.format(
                project_id=table_ref.project,
                dataflow_metrics_dataset=table_ref.dataset_id,
                materialized_metrics_dataset=DATAFLOW_METRICS_MATERIALIZED_DATASET,
                table_id=table_id,
                day_count_limit=dataflow_config.MAX_DAYS_IN_DATAFLOW_METRICS_TABLE,
            )
        else:
            month_limit_by_state = "\nUNION ALL\n".join([
                f"SELECT '{state_code}' as state_code, {month_limit} as month_limit"
                for state_code, month_limit in
                month_range_for_metric_and_state[table_id].items()
            ])

            source_data_join_clause = (
                SOURCE_DATA_JOIN_CLAUSE_WITH_MONTH_LIMIT_TEMPLATE.format(
                    project_id=table_ref.project,
                    dataflow_metrics_dataset=table_ref.dataset_id,
                    materialized_metrics_dataset=DATAFLOW_METRICS_MATERIALIZED_DATASET,
                    table_id=table_id,
                    day_count_limit=dataflow_config.MAX_DAYS_IN_DATAFLOW_METRICS_TABLE,
                    month_limit_by_state=month_limit_by_state,
                )
            )

        # Exclude the columns left over from the exclusion join so they are not
        # added to the metric tables in cold storage
        columns_to_exclude_from_transfer = ["keep_job_id", "keep_created_date"]

        # This filter will return the rows that should be moved to cold storage
        insert_filter_clause = "WHERE keep_job_id IS NULL AND keep_created_date IS NULL"

        # Query for rows to be moved to the cold storage table
        insert_query = """
            SELECT * EXCEPT({columns_to_exclude}) FROM
            `{project_id}.{dataflow_metrics_dataset}.{table_id}`
            {source_data_join_clause}
            {insert_filter_clause}
        """.format(
            columns_to_exclude=", ".join(columns_to_exclude_from_transfer),
            project_id=table_ref.project,
            dataflow_metrics_dataset=table_ref.dataset_id,
            table_id=table_id,
            source_data_join_clause=source_data_join_clause,
            insert_filter_clause=insert_filter_clause,
        )

        # Move data from the Dataflow metrics dataset into the cold storage table, creating the table if necessary
        insert_job = bq_client.insert_into_table_from_query(
            destination_dataset_id=cold_storage_dataset,
            destination_table_id=table_id,
            query=insert_query,
            allow_field_additions=True,
            write_disposition=WriteDisposition.WRITE_APPEND,
        )

        # Wait for the insert job to complete before running the replace job
        insert_job.result()

        # This will return the rows that were not moved to cold storage and should remain in the table
        replace_query = """
            SELECT * EXCEPT({columns_to_exclude}) FROM
            `{project_id}.{dataflow_metrics_dataset}.{table_id}`
            {source_data_join_clause}
            WHERE keep_job_id IS NOT NULL OR keep_created_date IS NOT NULL
        """.format(
            columns_to_exclude=", ".join(columns_to_exclude_from_transfer),
            project_id=table_ref.project,
            dataflow_metrics_dataset=table_ref.dataset_id,
            table_id=table_id,
            source_data_join_clause=source_data_join_clause,
        )

        # Replace the Dataflow table with only the rows that should remain
        replace_job = bq_client.create_table_from_query_async(
            dataflow_metrics_dataset,
            table_ref.table_id,
            query=replace_query,
            overwrite=True,
        )

        # Wait for the replace job to complete before moving on
        replace_job.result()
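

# A minimal sketch of invoking the cleanup directly; the actual entry point
# (e.g. a scheduled task or an admin endpoint) is an assumption here:
if __name__ == "__main__":
    move_old_dataflow_metrics_to_cold_storage()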
Example #8
class BigQueryClientImplTest(unittest.TestCase):
    """Tests for BigQueryClientImpl"""

    def setUp(self) -> None:
        self.location = "US"
        self.mock_project_id = "fake-recidiviz-project"
        self.mock_dataset_id = "fake-dataset"
        self.mock_table_id = "test_table"
        self.mock_dataset_ref = bigquery.dataset.DatasetReference(
            self.mock_project_id, self.mock_dataset_id
        )
        self.mock_table = self.mock_dataset_ref.table(self.mock_table_id)

        self.metadata_patcher = mock.patch("recidiviz.utils.metadata.project_id")
        self.mock_project_id_fn = self.metadata_patcher.start()
        self.mock_project_id_fn.return_value = self.mock_project_id

        self.client_patcher = mock.patch("recidiviz.big_query.big_query_client.client")
        self.mock_client = self.client_patcher.start().return_value

        self.mock_view = BigQueryView(
            dataset_id="dataset",
            view_id="test_view",
            view_query_template="SELECT NULL LIMIT 0",
            should_materialize=True,
        )

        self.bq_client = BigQueryClientImpl()

    def tearDown(self) -> None:
        self.client_patcher.stop()
        self.metadata_patcher.stop()

    def test_create_dataset_if_necessary(self) -> None:
        """Check that a dataset is created if it does not exist."""
        self.mock_client.get_dataset.side_effect = exceptions.NotFound("!")
        self.bq_client.create_dataset_if_necessary(self.mock_dataset_ref)
        self.mock_client.create_dataset.assert_called()

    def test_create_dataset_if_necessary_dataset_exists(self) -> None:
        """Check that a dataset is not created if it already exists."""
        self.mock_client.get_dataset.side_effect = None
        self.bq_client.create_dataset_if_necessary(self.mock_dataset_ref)
        self.mock_client.create_dataset.assert_not_called()

    def test_create_dataset_if_necessary_table_expiration(self) -> None:
        """Check that the dataset is created with a set table expiration if the dataset does not exist and the
        new_dataset_table_expiration_ms is specified."""
        self.mock_client.get_dataset.side_effect = exceptions.NotFound("!")
        self.bq_client.create_dataset_if_necessary(
            self.mock_dataset_ref, default_table_expiration_ms=6000
        )
        self.mock_client.create_dataset.assert_called()

    def test_table_exists(self) -> None:
        """Check that table_exists returns True if the table exists."""
        self.mock_client.get_table.side_effect = None
        self.assertTrue(
            self.bq_client.table_exists(self.mock_dataset_ref, self.mock_table_id)
        )

    def test_table_exists_does_not_exist(self) -> None:
        """Check that table_exists returns False if the table does not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound("!")
        with self.assertLogs(level="WARNING"):
            table_exists = self.bq_client.table_exists(
                self.mock_dataset_ref, self.mock_table_id
            )
            self.assertFalse(table_exists)

    def test_create_or_update_view_creates_view(self) -> None:
        """create_or_update_view creates a View if it does not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound("!")
        self.bq_client.create_or_update_view(self.mock_dataset_ref, self.mock_view)
        self.mock_client.create_table.assert_called()
        self.mock_client.update_table.assert_not_called()

    def test_create_or_update_view_updates_view(self) -> None:
        """create_or_update_view updates a View if it already exist."""
        self.mock_client.get_table.side_effect = None
        self.bq_client.create_or_update_view(self.mock_dataset_ref, self.mock_view)
        self.mock_client.update_table.assert_called()
        self.mock_client.create_table.assert_not_called()

    def test_export_to_cloud_storage(self) -> None:
        """export_to_cloud_storage extracts the table corresponding to the
        view."""
        self.assertIsNotNone(
            self.bq_client.export_table_to_cloud_storage_async(
                source_table_dataset_ref=self.mock_dataset_ref,
                source_table_id="source-table",
                destination_uri=f"gs://{self.mock_project_id}-bucket/destination_path.json",
                destination_format=bigquery.DestinationFormat.NEWLINE_DELIMITED_JSON,
                print_header=True,
            )
        )
        self.mock_client.extract_table.assert_called()

    def test_export_to_cloud_storage_no_table(self) -> None:
        """export_to_cloud_storage does not extract from a table if the table
        does not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound("!")
        with self.assertLogs(level="WARNING"):
            self.assertIsNone(
                self.bq_client.export_table_to_cloud_storage_async(
                    source_table_dataset_ref=self.mock_dataset_ref,
                    source_table_id="source-table",
                    destination_uri=f"gs://{self.mock_project_id}-bucket/destination_path.json",
                    destination_format=bigquery.DestinationFormat.NEWLINE_DELIMITED_JSON,
                    print_header=True,
                )
            )
            self.mock_client.extract_table.assert_not_called()

    def test_load_table_async_create_dataset(self) -> None:
        """Test that load_table_from_cloud_storage_async tries to create a parent dataset."""

        self.mock_client.get_dataset.side_effect = exceptions.NotFound("!")

        self.bq_client.load_table_from_cloud_storage_async(
            destination_dataset_ref=self.mock_dataset_ref,
            destination_table_id=self.mock_table_id,
            destination_table_schema=[
                SchemaField("my_column", "STRING", "NULLABLE", None, ())
            ],
            source_uri="gs://bucket/export-uri",
        )

        self.mock_client.create_dataset.assert_called()
        self.mock_client.load_table_from_uri.assert_called()

    def test_load_table_async_dataset_exists(self) -> None:
        """Test that load_table_from_cloud_storage_async does not try to create a
        parent dataset if it already exists."""

        self.bq_client.load_table_from_cloud_storage_async(
            destination_dataset_ref=self.mock_dataset_ref,
            destination_table_id=self.mock_table_id,
            destination_table_schema=[
                SchemaField("my_column", "STRING", "NULLABLE", None, ())
            ],
            source_uri="gs://bucket/export-uri",
        )

        self.mock_client.create_dataset.assert_not_called()
        self.mock_client.load_table_from_uri.assert_called()

    def test_export_query_results_to_cloud_storage_no_table(self) -> None:
        bucket = self.mock_project_id + "-bucket"
        self.mock_client.get_table.side_effect = exceptions.NotFound("!")
        with self.assertLogs(level="WARNING"):
            self.bq_client.export_query_results_to_cloud_storage(
                [
                    ExportQueryConfig.from_view_query(
                        view=self.mock_view,
                        view_filter_clause="WHERE x = y",
                        intermediate_table_name=self.mock_table_id,
                        output_uri=f"gs://{bucket}/view.json",
                        output_format=bigquery.DestinationFormat.NEWLINE_DELIMITED_JSON,
                    )
                ],
                print_header=True,
            )

    def test_export_query_results_to_cloud_storage(self) -> None:
        """export_query_results_to_cloud_storage creates the table from the view query and
        exports the table."""
        bucket = self.mock_project_id + "-bucket"
        query_job: futures.Future = futures.Future()
        query_job.set_result([])
        extract_job: futures.Future = futures.Future()
        extract_job.set_result(None)
        self.mock_client.query.return_value = query_job
        self.mock_client.extract_table.return_value = extract_job
        self.bq_client.export_query_results_to_cloud_storage(
            [
                ExportQueryConfig.from_view_query(
                    view=self.mock_view,
                    view_filter_clause="WHERE x = y",
                    intermediate_table_name=self.mock_table_id,
                    output_uri=f"gs://{bucket}/view.json",
                    output_format=bigquery.DestinationFormat.NEWLINE_DELIMITED_JSON,
                )
            ],
            print_header=True,
        )
        self.mock_client.query.assert_called()
        self.mock_client.extract_table.assert_called()
        self.mock_client.delete_table.assert_called_with(
            bigquery.DatasetReference(
                self.mock_project_id, self.mock_view.dataset_id
            ).table(self.mock_table_id)
        )

    def test_create_table_from_query(self) -> None:
        """Tests that the create_table_from_query function calls the function to create a table from a query."""
        self.bq_client.create_table_from_query_async(
            self.mock_dataset_id,
            self.mock_table_id,
            query="SELECT * FROM some.fake.table",
            query_parameters=[],
        )
        self.mock_client.query.assert_called()

    @mock.patch("google.cloud.bigquery.job.QueryJobConfig")
    def test_insert_into_table_from_table_async(
        self, mock_job_config: mock.MagicMock
    ) -> None:
        """Tests that the insert_into_table_from_table_async function runs a query."""
        self.bq_client.insert_into_table_from_table_async(
            source_dataset_id=self.mock_dataset_id,
            source_table_id=self.mock_table_id,
            destination_dataset_id=self.mock_dataset_id,
            destination_table_id="fake_table_temp",
        )
        expected_query = f"SELECT * FROM `fake-recidiviz-project.{self.mock_dataset_id}.{self.mock_table_id}`"
        self.mock_client.get_table.assert_called()
        self.mock_client.query.assert_called_with(
            query=expected_query, location=self.location, job_config=mock_job_config()
        )

    @mock.patch("google.cloud.bigquery.job.QueryJobConfig")
    def test_insert_into_table_from_table_async_hydrate_missing_columns(
        self, mock_job_config: mock.MagicMock
    ) -> None:
        """Tests that the insert_into_table_from_table_async generates a query with missing columns as NULL."""
        with mock.patch(
            "recidiviz.big_query.big_query_client.BigQueryClientImpl"
            "._get_excess_schema_fields"
        ) as mock_missing:
            mock_missing.return_value = [
                bigquery.SchemaField("state_code", "STRING", "REQUIRED"),
                bigquery.SchemaField("new_column_name", "INTEGER", "REQUIRED"),
            ]
            self.mock_destination_id = "fake_table_temp"
            self.bq_client.insert_into_table_from_table_async(
                source_dataset_id=self.mock_dataset_id,
                source_table_id=self.mock_table_id,
                destination_dataset_id=self.mock_dataset_id,
                destination_table_id=self.mock_destination_id,
                hydrate_missing_columns_with_null=True,
                allow_field_additions=True,
            )
            expected_query = (
                "SELECT *, CAST(NULL AS STRING) AS state_code, CAST(NULL AS INTEGER) AS new_column_name "
                f"FROM `fake-recidiviz-project.{self.mock_dataset_id}.{self.mock_table_id}`"
            )
            self.mock_client.query.assert_called_with(
                query=expected_query,
                location=self.location,
                job_config=mock_job_config(),
            )

    def test_insert_into_table_from_table_invalid_destination(self) -> None:
        """Tests that the insert_into_table_from_table_async function does not run the query if the destination
        table does not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound("!")

        with pytest.raises(ValueError):
            self.bq_client.insert_into_table_from_table_async(
                self.mock_dataset_id,
                self.mock_table_id,
                "fake_source_dataset_id",
                "fake_table_id",
            )
        self.mock_client.get_table.assert_called()
        self.mock_client.query.assert_not_called()

    def test_insert_into_table_from_table_invalid_filter_clause(self) -> None:
        """Tests that the insert_into_table_from_table_async function does not run the query if the filter clause
        does not start with a WHERE."""
        with pytest.raises(ValueError):
            self.bq_client.insert_into_table_from_table_async(
                self.mock_dataset_id,
                self.mock_table_id,
                "fake_source_dataset_id",
                "fake_table_id",
                source_data_filter_clause="bad filter clause",
            )
        self.mock_client.query.assert_not_called()

    @mock.patch("google.cloud.bigquery.job.QueryJobConfig")
    def test_insert_into_table_from_table_with_filter_clause(
        self, mock_job_config: mock.MagicMock
    ) -> None:
        """Tests that the insert_into_table_from_table_async generates a valid query when given a filter clause."""
        filter_clause = "WHERE state_code IN ('US_ND')"
        job_config = mock_job_config()
        self.bq_client.insert_into_table_from_table_async(
            self.mock_dataset_id,
            self.mock_table_id,
            "fake_source_dataset_id",
            "fake_table_id",
            source_data_filter_clause=filter_clause,
        )
        expected_query = (
            "SELECT * FROM `fake-recidiviz-project.fake-dataset.test_table` "
            "WHERE state_code IN ('US_ND')"
        )
        self.mock_client.query.assert_called_with(
            query=expected_query, location=self.location, job_config=job_config
        )

    def test_insert_into_table_from_cloud_storage_async(self) -> None:
        self.mock_client.get_dataset.side_effect = exceptions.NotFound("!")

        self.bq_client.insert_into_table_from_cloud_storage_async(
            destination_dataset_ref=self.mock_dataset_ref,
            destination_table_id=self.mock_table_id,
            destination_table_schema=[
                SchemaField("my_column", "STRING", "NULLABLE", None, ())
            ],
            source_uri="gs://bucket/export-uri",
        )

        self.mock_client.create_dataset.assert_called()
        self.mock_client.load_table_from_uri.assert_called()

    def test_delete_from_table(self) -> None:
        """Tests that the delete_from_table function runs a query."""
        self.bq_client.delete_from_table_async(
            self.mock_dataset_id, self.mock_table_id, filter_clause="WHERE x > y"
        )
        self.mock_client.query.assert_called()

    def test_delete_from_table_invalid_filter_clause(self) -> None:
        """Tests that the delete_from_table function does not run a query when the filter clause is invalid."""
        with pytest.raises(ValueError):
            self.bq_client.delete_from_table_async(
                self.mock_dataset_id, self.mock_table_id, filter_clause="x > y"
            )
        self.mock_client.query.assert_not_called()
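
    # The two tests above pin down the contract of delete_from_table_async: the
    # filter clause must start with WHERE. A hypothetical rendering of the query
    # it issues (an assumption, not the actual implementation):
    #
    #   DELETE FROM `fake-recidiviz-project.fake-dataset.test_table` WHERE x > y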

    def test_materialize_view_to_table(self) -> None:
        """Tests that the materialize_view_to_table function calls the function to create a table from a query."""
        self.bq_client.materialize_view_to_table(self.mock_view)
        self.mock_client.query.assert_called()

    def test_materialize_view_to_table_no_materialized_view_table_id(self) -> None:
        """Tests that the materialize_view_to_table function does not call the function to create a table from a
        query if there is no set materialized_view_table_id on the view."""
        invalid_view = BigQueryView(
            dataset_id="dataset",
            view_id="test_view",
            view_query_template="SELECT NULL LIMIT 0",
            should_materialize=False,
        )

        with pytest.raises(ValueError):
            self.bq_client.materialize_view_to_table(invalid_view)
        self.mock_client.query.assert_not_called()

    def test_create_table_with_schema(self) -> None:
        """Tests that the create_table_with_schema function calls the create_table function on the client."""
        self.mock_client.get_table.side_effect = exceptions.NotFound("!")
        schema_fields = [bigquery.SchemaField("new_schema_field", "STRING")]

        self.bq_client.create_table_with_schema(
            self.mock_dataset_id, self.mock_table_id, schema_fields
        )
        self.mock_client.create_table.assert_called()

    def test_create_table_with_schema_table_exists(self) -> None:
        """Tests that the create_table_with_schema function raises an error when the table already exists."""
        self.mock_client.get_table.side_effect = None
        schema_fields = [bigquery.SchemaField("new_schema_field", "STRING")]

        with pytest.raises(ValueError):
            self.bq_client.create_table_with_schema(
                self.mock_dataset_id, self.mock_table_id, schema_fields
            )
        self.mock_client.create_table.assert_not_called()

    def test_add_missing_fields_to_schema(self) -> None:
        """Tests that the add_missing_fields_to_schema function calls the client to update the table."""
        table_ref = bigquery.TableReference(self.mock_dataset_ref, self.mock_table_id)
        schema_fields = [bigquery.SchemaField("fake_schema_field", "STRING")]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [bigquery.SchemaField("new_schema_field", "STRING")]

        self.bq_client.add_missing_fields_to_schema(
            self.mock_dataset_id, self.mock_table_id, new_schema_fields
        )

        self.mock_client.update_table.assert_called()

    def test_add_missing_fields_to_schema_no_missing_fields(self) -> None:
        """Tests that the add_missing_fields_to_schema function does not call the client to update the table when all
        of the fields are already present."""
        table_ref = bigquery.TableReference(self.mock_dataset_ref, self.mock_table_id)
        schema_fields = [bigquery.SchemaField("fake_schema_field", "STRING")]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [bigquery.SchemaField("fake_schema_field", "STRING")]

        self.bq_client.add_missing_fields_to_schema(
            self.mock_dataset_id, self.mock_table_id, new_schema_fields
        )

        self.mock_client.update_table.assert_not_called()

    def test_add_missing_fields_to_schema_no_table(self) -> None:
        """Tests that the add_missing_fields_to_schema function does not call the client to update the table when the
        table does not exist."""
        self.mock_client.get_table.side_effect = exceptions.NotFound("!")
        new_schema_fields = [bigquery.SchemaField("fake_schema_field", "STRING")]

        with pytest.raises(ValueError):
            self.bq_client.add_missing_fields_to_schema(
                self.mock_dataset_id, self.mock_table_id, new_schema_fields
            )

        self.mock_client.update_table.assert_not_called()

    def test_add_missing_fields_to_schema_fields_with_same_name_different_type(
        self,
    ) -> None:
        """Tests that the add_missing_fields_to_schema function raises an error when the user is trying to add a field
        with the same name but different field_type as an existing field."""
        table_ref = bigquery.TableReference(self.mock_dataset_ref, self.mock_table_id)
        schema_fields = [bigquery.SchemaField("fake_schema_field", "STRING")]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [bigquery.SchemaField("fake_schema_field", "INTEGER")]

        with pytest.raises(ValueError):
            self.bq_client.add_missing_fields_to_schema(
                self.mock_dataset_id, self.mock_table_id, new_schema_fields
            )

        self.mock_client.update_table.assert_not_called()

    def test_add_missing_fields_to_schema_fields_with_same_name_different_mode(
        self,
    ) -> None:
        """Tests that the add_missing_fields_to_schema function raises an error when the user is trying to add a field
        with the same name but different mode as an existing field."""
        table_ref = bigquery.TableReference(self.mock_dataset_ref, self.mock_table_id)
        schema_fields = [
            bigquery.SchemaField("fake_schema_field", "STRING", mode="NULLABLE")
        ]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [
            bigquery.SchemaField("fake_schema_field", "STRING", mode="REQUIRED")
        ]

        with pytest.raises(ValueError):
            self.bq_client.add_missing_fields_to_schema(
                self.mock_dataset_id, self.mock_table_id, new_schema_fields
            )

        self.mock_client.update_table.assert_not_called()
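
    # Taken together, the add_missing_fields_to_schema tests above describe the
    # expected behavior. A condensed sketch, as an assumption rather than the
    # actual implementation:
    #
    #   table = client.get_table(...)           # NotFound surfaces as ValueError
    #   existing = {f.name: f for f in table.schema}
    #   for field in desired_fields:
    #       if field.name in existing and (
    #           existing[field.name].field_type != field.field_type
    #           or existing[field.name].mode != field.mode
    #       ):
    #           raise ValueError(...)           # same name, different type/mode
    #   missing = [f for f in desired_fields if f.name not in existing]
    #   if missing:
    #       table.schema = list(table.schema) + missing
    #       client.update_table(table, ["schema"])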

    def test_remove_unused_fields_from_schema(self) -> None:
        """Tests that remove_unused_fields_from_schema() calls the client to update the table with a query."""
        table_ref = bigquery.TableReference(self.mock_dataset_ref, self.mock_table_id)
        schema_fields = [
            bigquery.SchemaField("field_1", "STRING"),
            bigquery.SchemaField("field_2", "STRING"),
        ]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [bigquery.SchemaField("field_1", "STRING")]

        self.bq_client.remove_unused_fields_from_schema(
            self.mock_dataset_id, self.mock_table_id, new_schema_fields
        )

        self.mock_client.query.assert_called()

    def test_remove_unused_fields_from_schema_no_missing_fields(self) -> None:
        """Tests that remove_unused_fields_from_schema() does nothing if there are no missing fields."""
        table_ref = bigquery.TableReference(self.mock_dataset_ref, self.mock_table_id)
        schema_fields = [bigquery.SchemaField("field_1", "STRING")]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [bigquery.SchemaField("field_1", "STRING")]

        self.bq_client.remove_unused_fields_from_schema(
            self.mock_dataset_id, self.mock_table_id, new_schema_fields
        )

        self.mock_client.query.assert_not_called()

    def test_remove_unused_fields_from_schema_ignore_excess_desired_fields(
        self,
    ) -> None:
        """Tests that remove_unused_fields_from_schema() drops columns even when there are excess desired fields."""
        table_ref = bigquery.TableReference(self.mock_dataset_ref, self.mock_table_id)
        schema_fields = [
            bigquery.SchemaField("field_1", "STRING"),
            bigquery.SchemaField("field_2", "STRING"),
        ]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [
            bigquery.SchemaField("field_1", "STRING"),
            bigquery.SchemaField("field_3", "STRING"),
        ]

        self.bq_client.remove_unused_fields_from_schema(
            self.mock_dataset_id, self.mock_table_id, new_schema_fields
        )

        self.mock_client.query.assert_called()

    @mock.patch(
        "recidiviz.big_query.big_query_client.BigQueryClientImpl.remove_unused_fields_from_schema"
    )
    @mock.patch(
        "recidiviz.big_query.big_query_client.BigQueryClientImpl.add_missing_fields_to_schema"
    )
    def test_update_schema(
        self, remove_unused_mock: mock.MagicMock, add_missing_mock: mock.MagicMock
    ) -> None:
        """Tests that update_schema() calls both field updaters if the inputs are valid."""
        table_ref = bigquery.TableReference(self.mock_dataset_ref, self.mock_table_id)
        schema_fields = [
            bigquery.SchemaField("field_1", "STRING"),
            bigquery.SchemaField("field_2", "STRING"),
        ]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [
            bigquery.SchemaField("field_1", "STRING"),
            bigquery.SchemaField("field_3", "STRING"),
        ]

        self.bq_client.update_schema(
            self.mock_dataset_id, self.mock_table_id, new_schema_fields
        )

        remove_unused_mock.assert_called()
        add_missing_mock.assert_called()

    @mock.patch(
        "recidiviz.big_query.big_query_client.BigQueryClientImpl.remove_unused_fields_from_schema"
    )
    @mock.patch(
        "recidiviz.big_query.big_query_client.BigQueryClientImpl.add_missing_fields_to_schema"
    )
    def test_update_schema_fails_on_changed_type(
        self, remove_unused_mock: mock.MagicMock, add_missing_mock: mock.MagicMock
    ) -> None:
        """Tests that update_schema() throws if we try to change a field type."""
        table_ref = bigquery.TableReference(self.mock_dataset_ref, self.mock_table_id)
        schema_fields = [
            bigquery.SchemaField("field_1", "STRING"),
            bigquery.SchemaField("field_2", "STRING"),
        ]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [
            bigquery.SchemaField("field_1", "STRING"),
            bigquery.SchemaField("field_2", "INT"),
        ]

        with pytest.raises(ValueError):
            self.bq_client.update_schema(
                self.mock_dataset_id, self.mock_table_id, new_schema_fields
            )

        remove_unused_mock.assert_not_called()
        add_missing_mock.assert_not_called()

    @mock.patch(
        "recidiviz.big_query.big_query_client.BigQueryClientImpl.remove_unused_fields_from_schema"
    )
    @mock.patch(
        "recidiviz.big_query.big_query_client.BigQueryClientImpl.add_missing_fields_to_schema"
    )
    def test_update_schema_fails_on_changed_mode(
        self, remove_unused_mock: mock.MagicMock, add_missing_mock: mock.MagicMock
    ) -> None:
        """Tests that update_schema() throws if we try to change a field mode."""
        table_ref = bigquery.TableReference(self.mock_dataset_ref, self.mock_table_id)
        schema_fields = [
            bigquery.SchemaField("field_1", "STRING", "NULLABLE"),
            bigquery.SchemaField("field_2", "STRING"),
        ]
        table = bigquery.Table(table_ref, schema_fields)
        self.mock_client.get_table.return_value = table

        new_schema_fields = [
            bigquery.SchemaField("field_1", "STRING", "REQUIRED"),
            bigquery.SchemaField("field_2", "INT"),
        ]

        with pytest.raises(ValueError):
            self.bq_client.update_schema(
                self.mock_dataset_id, self.mock_table_id, new_schema_fields
            )

        remove_unused_mock.assert_not_called()
        add_missing_mock.assert_not_called()

    def test__get_excess_schema_fields_simple_excess(self) -> None:
        """Tests _get_excess_schema_fields() when extended_schema is a strict superset of base_schema."""
        base_schema = [bigquery.SchemaField("field_1", "INT")]
        extended_schema = [
            bigquery.SchemaField("field_1", "INT"),
            bigquery.SchemaField("field_2", "INT"),
            bigquery.SchemaField("field_3", "INT"),
        ]

        excess_fields = BigQueryClientImpl._get_excess_schema_fields(
            base_schema, extended_schema
        )

        self.assertEqual(
            excess_fields,
            [
                bigquery.SchemaField("field_2", "INT"),
                bigquery.SchemaField("field_3", "INT"),
            ],
        )

    def test__get_excess_schema_fields_with_extra_base_schema(self) -> None:
        """Tests _get_excess_schema_fields() when base_schema has fields not in extended_schema."""
        base_schema = [
            bigquery.SchemaField("field_1", "INT"),
            bigquery.SchemaField("field_2", "INT"),
        ]
        extended_schema = [
            bigquery.SchemaField("field_1", "INT"),
            bigquery.SchemaField("field_3", "INT"),
            bigquery.SchemaField("field_4", "INT"),
        ]

        excess_fields = BigQueryClientImpl._get_excess_schema_fields(
            base_schema, extended_schema
        )

        self.assertEqual(
            excess_fields,
            [
                bigquery.SchemaField("field_3", "INT"),
                bigquery.SchemaField("field_4", "INT"),
            ],
        )

    def test__get_excess_schema_fields_with_matching_schema(self) -> None:
        """Tests _get_excess_schema_fields() when base_schema is the same as extended_schema."""
        base_schema = [
            bigquery.SchemaField("field_1", "INT"),
            bigquery.SchemaField("field_2", "INT"),
        ]

        excess_fields = BigQueryClientImpl._get_excess_schema_fields(
            base_schema, base_schema
        )

        self.assertListEqual(excess_fields, [])

    def test__get_excess_schema_fields_no_excess(self) -> None:
        """Tests _get_excess_schema_fields() when base_schema is a superset of extended_schema."""
        base_schema = [
            bigquery.SchemaField("field_1", "INT"),
            bigquery.SchemaField("field_2", "INT"),
        ]
        extended_schema = [bigquery.SchemaField("field_2", "INT")]

        excess_fields = BigQueryClientImpl._get_excess_schema_fields(
            base_schema, extended_schema
        )

        self.assertListEqual(excess_fields, [])
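
    # A minimal sketch of _get_excess_schema_fields consistent with the four
    # tests above (an assumption; the real implementation lives on
    # BigQueryClientImpl):
    #
    #   @staticmethod
    #   def _get_excess_schema_fields(base_schema, extended_schema):
    #       base_names = {field.name for field in base_schema}
    #       return [f for f in extended_schema if f.name not in base_names]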

    def test_delete_table(self) -> None:
        """Tests that our delete table function calls the correct client method."""
        self.bq_client.delete_table(self.mock_dataset_id, self.mock_table_id)
        self.mock_client.delete_table.assert_called()

    @mock.patch("google.cloud.bigquery.QueryJob")
    def test_paged_read_single_page_single_row(
        self, mock_query_job: mock.MagicMock
    ) -> None:
        first_row = bigquery.table.Row(
            ["parole", 15, "10N"],
            {"supervision_type": 0, "revocations": 1, "district": 2},
        )

        # First call returns a single row, second call returns nothing
        mock_query_job.result.side_effect = [[first_row], []]

        processed_results = []

        def _process_fn(row: bigquery.table.Row) -> None:
            processed_results.append(dict(row))

        self.bq_client.paged_read_and_process(mock_query_job, 1, _process_fn)

        self.assertEqual([dict(first_row)], processed_results)
        mock_query_job.result.assert_has_calls(
            [
                call(max_results=1, start_index=0),
                call(max_results=1, start_index=1),
            ]
        )

    @mock.patch("google.cloud.bigquery.QueryJob")
    def test_paged_read_single_page_multiple_rows(
        self, mock_query_job: mock.MagicMock
    ) -> None:
        first_row = bigquery.table.Row(
            ["parole", 15, "10N"],
            {"supervision_type": 0, "revocations": 1, "district": 2},
        )
        second_row = bigquery.table.Row(
            ["probation", 7, "10N"],
            {"supervision_type": 0, "revocations": 1, "district": 2},
        )

        # First call returns two rows, second call returns nothing
        mock_query_job.result.side_effect = [[first_row, second_row], []]

        processed_results = []

        def _process_fn(row: bigquery.table.Row) -> None:
            processed_results.append(dict(row))

        self.bq_client.paged_read_and_process(mock_query_job, 10, _process_fn)

        self.assertEqual([dict(first_row), dict(second_row)], processed_results)
        mock_query_job.result.assert_has_calls(
            [
                call(max_results=10, start_index=0),
                call(max_results=10, start_index=2),
            ]
        )

    @mock.patch("google.cloud.bigquery.QueryJob")
    def test_paged_read_multiple_pages(self, mock_query_job: mock.MagicMock) -> None:
        p1_r1 = bigquery.table.Row(
            ["parole", 15, "10N"],
            {"supervision_type": 0, "revocations": 1, "district": 2},
        )
        p1_r2 = bigquery.table.Row(
            ["probation", 7, "10N"],
            {"supervision_type": 0, "revocations": 1, "district": 2},
        )

        p2_r1 = bigquery.table.Row(
            ["parole", 8, "10F"],
            {"supervision_type": 0, "revocations": 1, "district": 2},
        )
        p2_r2 = bigquery.table.Row(
            ["probation", 3, "10F"],
            {"supervision_type": 0, "revocations": 1, "district": 2},
        )

        # First two calls return results, third call returns nothing
        mock_query_job.result.side_effect = [[p1_r1, p1_r2], [p2_r1, p2_r2], []]

        processed_results = []

        def _process_fn(row: bigquery.table.Row) -> None:
            processed_results.append(dict(row))

        self.bq_client.paged_read_and_process(mock_query_job, 2, _process_fn)

        self.assertEqual(
            [dict(p1_r1), dict(p1_r2), dict(p2_r1), dict(p2_r2)], processed_results
        )
        mock_query_job.result.assert_has_calls(
            [
                call(max_results=2, start_index=0),
                call(max_results=2, start_index=2),
                call(max_results=2, start_index=4),
            ]
        )
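
    # The three paged-read tests above pin down the paging contract of
    # paged_read_and_process. A sketch consistent with the asserted calls (an
    # assumption, not the actual implementation):
    #
    #   def paged_read_and_process(self, query_job, page_size, process_fn):
    #       start_index = 0
    #       while True:
    #           rows = list(query_job.result(max_results=page_size, start_index=start_index))
    #           if not rows:
    #               break
    #           for row in rows:
    #               process_fn(row)
    #           start_index += len(rows)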