Example 1
    def __init__(
        self,
        *,
        region: Region,
        fs: DirectIngestGCSFileSystem,
        ingest_bucket_path: GcsfsBucketPath,
        temp_output_directory_path: GcsfsDirectoryPath,
        big_query_client: BigQueryClient,
        region_raw_file_config: Optional[
            DirectIngestRegionRawFileConfig] = None,
        upload_chunk_size: int = _DEFAULT_BQ_UPLOAD_CHUNK_SIZE,
    ):

        self.region = region
        self.fs = fs
        self.ingest_bucket_path = ingest_bucket_path
        self.temp_output_directory_path = temp_output_directory_path
        self.big_query_client = big_query_client
        self.region_raw_file_config = (
            region_raw_file_config
            if region_raw_file_config else DirectIngestRegionRawFileConfig(
                region_code=self.region.region_code,
                region_module=self.region.region_module,
            ))
        self.upload_chunk_size = upload_chunk_size
        self.csv_reader = GcsfsCsvReader(fs)
        self.raw_table_migrations = DirectIngestRawTableMigrationCollector(
            region_code=self.region.region_code,
            regions_module_override=self.region.region_module,
        ).collect_raw_table_migration_queries()
    def __init__(self,
                 region_name: str,
                 system_level: SystemLevel,
                 ingest_directory_path: Optional[str],
                 storage_directory_path: Optional[str],
                 max_delay_sec_between_files: Optional[int] = None):
        super().__init__(region_name, system_level, ingest_directory_path,
                         storage_directory_path, max_delay_sec_between_files)
        self.csv_reader = GcsfsCsvReader(
            gcsfs.GCSFileSystem(project=metadata.project_id(),
                                cache_timeout=GCSFS_NO_CACHING))
Example 3
    def _cache_ingest_file_as_parquet_task() -> Tuple[str, int]:
        """Downloads a GCS file and stores it to our Redis cache in Parquet format

         Example:
             POST /admin/data_discovery/cache_ingest_file_as_parquet_task
         Request Body:
             gcs_file_uri: (string) The `gs://` URI of the file to cache
             file_encoding: (string) The encoding of said file
             file_separator: (string) The value delimiter of said file
             file_quoting: (int) A `csv.QUOTE_*` value for the parser, e.g. 3 (csv.QUOTE_NONE)
             file_custom_line_terminator: (string, optional) A custom line terminator for the parser
        Args:
             N/A
         Returns:
             Cache hit/miss result
        """
        cache = get_data_discovery_cache()
        body = get_cloud_task_json_body()
        path = GcsfsFilePath.from_absolute_path(body["gcs_file_uri"])
        parquet_path = SingleIngestFileParquetCache.parquet_cache_key(path)

        if not cache.exists(parquet_path):
            fs = GcsfsFactory.build()
            parquet_cache = SingleIngestFileParquetCache(
                get_data_discovery_cache(), path, expiry=DataDiscoveryTTL.PARQUET_FILES
            )
            csv_reader = GcsfsCsvReader(fs)
            csv_reader.streaming_read(
                path,
                CacheIngestFileAsParquetDelegate(parquet_cache, path),
                encodings_to_try=list(
                    {
                        body["file_encoding"],
                        *COMMON_RAW_FILE_ENCODINGS,
                    }
                ),
                delimiter=body["file_separator"],
                quoting=body["file_quoting"],
                lineterminator=body.get("file_custom_line_terminator"),
                chunk_size=75000,
                index_col=False,
                keep_default_na=False,
            )

            return CACHE_MISS, HTTPStatus.CREATED

        return CACHE_HIT, HTTPStatus.OK
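
For reference, a request body for the endpoint above might look like the following sketch (the bucket path and values are hypothetical, and `file_custom_line_terminator` is optional):

import csv

# Hypothetical payload for POST /admin/data_discovery/cache_ingest_file_as_parquet_task
request_body = {
    "gcs_file_uri": "gs://my-bucket/raw/tagA.csv",
    "file_encoding": "UTF-8",
    "file_separator": ",",
    "file_quoting": csv.QUOTE_NONE,  # serialized as the int 3
    "file_custom_line_terminator": None,  # optional; omit or leave null when unused
}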
    def __init__(self,
                 *,
                 region: Region,
                 fs: DirectIngestGCSFileSystem,
                 ingest_directory_path: GcsfsDirectoryPath,
                 temp_output_directory_path: GcsfsDirectoryPath,
                 big_query_client: BigQueryClient,
                 region_raw_file_config: Optional[
                     DirectIngestRegionRawFileConfig] = None,
                 upload_chunk_size: int = _DEFAULT_BQ_UPLOAD_CHUNK_SIZE):

        self.region = region
        self.fs = fs
        self.ingest_directory_path = ingest_directory_path
        self.temp_output_directory_path = temp_output_directory_path
        self.big_query_client = big_query_client
        self.region_raw_file_config = region_raw_file_config \
            if region_raw_file_config else DirectIngestRegionRawFileConfig(region_code=self.region.region_code)
        self.upload_chunk_size = upload_chunk_size
        self.csv_reader = GcsfsCsvReader(
            gcsfs.GCSFileSystem(project=metadata.project_id(),
                                cache_timeout=GCSFS_NO_CACHING))
    def setUp(self) -> None:
        self.project_id = "recidiviz-456"
        self.project_id_patcher = patch("recidiviz.utils.metadata.project_id")
        self.project_id_patcher.start().return_value = self.project_id
        self.test_region = fake_region(region_code="us_xx",
                                       region_module=fake_regions_module)

        self.fs = DirectIngestGCSFileSystem(FakeGCSFileSystem())
        self.ingest_bucket_path = GcsfsBucketPath(
            bucket_name="my_ingest_bucket")
        self.temp_output_path = GcsfsDirectoryPath(bucket_name="temp_bucket")

        self.region_raw_file_config = DirectIngestRegionRawFileConfig(
            region_code="us_xx", region_module=fake_regions_module)

        self.mock_big_query_client = create_autospec(BigQueryClient)
        self.num_lines_uploaded = 0

        self.mock_big_query_client.insert_into_table_from_cloud_storage_async.side_effect = (
            self.mock_import_raw_file_to_big_query)

        self.import_manager = DirectIngestRawFileImportManager(
            region=self.test_region,
            fs=self.fs,
            ingest_bucket_path=self.ingest_bucket_path,
            temp_output_directory_path=self.temp_output_path,
            region_raw_file_config=self.region_raw_file_config,
            big_query_client=self.mock_big_query_client,
        )
        self.import_manager.csv_reader = GcsfsCsvReader(
            self.fs.gcs_file_system)

        self.time_patcher = patch(
            "recidiviz.ingest.direct.controllers.direct_ingest_raw_file_import_manager.time"
        )
        self.mock_time = self.time_patcher.start()

        def fake_get_dataset_ref(dataset_id: str) -> bigquery.DatasetReference:
            return bigquery.DatasetReference(project=self.project_id,
                                             dataset_id=dataset_id)

        self.mock_big_query_client.dataset_ref_for_id = fake_get_dataset_ref
    def run_single_line_gcs_csv_reader_test(
        self,
        input_file_path: str,
        expected_result_path: str,
        encoding: str,
        delimiter: str,
        line_terminator: str,
    ) -> None:
        """Runs a test reads a single line of a normalized stream using the csv reader,
        mimicking the way we read the columns from each file.
        """
        fake_fs = FakeGCSFileSystem()
        input_gcs_path = GcsfsFilePath.from_absolute_path(
            "gs://my-bucket/input.csv")
        fake_fs.test_add_path(path=input_gcs_path, local_path=input_file_path)
        input_delegate = _FakeDfCapturingDelegate()
        csv_reader = GcsfsCsvReader(fake_fs)
        csv_reader.streaming_read(
            path=input_gcs_path,
            dtype=str,
            delegate=input_delegate,
            chunk_size=1,
            encodings_to_try=[encoding],
            nrows=1,
            sep=delimiter,
            quoting=csv.QUOTE_NONE,
            lineterminator=line_terminator,
            engine="python",
        )

        expected_gcs_path = GcsfsFilePath.from_absolute_path(
            "gs://my-bucket/expected.csv")
        fake_fs.test_add_path(path=expected_gcs_path,
                              local_path=expected_result_path)
        expected_delegate = _FakeDfCapturingDelegate()
        csv_reader.streaming_read(
            path=expected_gcs_path,
            delegate=expected_delegate,
            dtype=str,
            chunk_size=1,
            nrows=1,
        )

        self.assertEqual(len(expected_delegate.dfs), len(input_delegate.dfs))
        for i, expected_df in enumerate(expected_delegate.dfs):
            # Assert that each chunk read from the input matches the expected output.
            self.assertTrue(expected_df.equals(input_delegate.dfs[i]))
    def setUp(self) -> None:

        self.mock_gcsfs = create_autospec(gcsfs.GCSFileSystem)
        self.mock_gcsfs.open = _fake_gcsfs_open
        self.reader = GcsfsCsvReader(self.mock_gcsfs)
class GcsfsCsvReaderTest(unittest.TestCase):
    """Tests for the GcsfsCsvReader."""
    def setUp(self) -> None:

        self.mock_gcsfs = create_autospec(gcsfs.GCSFileSystem)
        self.mock_gcsfs.open = _fake_gcsfs_open
        self.reader = GcsfsCsvReader(self.mock_gcsfs)

    def _validate_empty_file_result(
            self, delegate: _TestGcsfsCsvReaderDelegate) -> None:
        self.assertEqual(1, len(delegate.encodings_attempted))
        self.assertEqual(delegate.encodings_attempted[0],
                         delegate.successful_encoding)
        self.assertEqual(0, len(delegate.dataframes))
        self.assertEqual(0, delegate.decode_errors)
        self.assertEqual(0, delegate.exceptions)

    def test_read_completely_empty_file(self) -> None:
        empty_file_path = fixtures.as_filepath("tagA.csv")

        delegate = _TestGcsfsCsvReaderDelegate()
        self.reader.streaming_read(
            GcsfsFilePath.from_absolute_path(empty_file_path),
            delegate=delegate,
            chunk_size=1,
        )
        self.assertEqual(1, len(delegate.encodings_attempted))
        self.assertEqual(delegate.encodings_attempted[0],
                         delegate.successful_encoding)
        self.assertEqual(0, len(delegate.dataframes))
        self.assertEqual(0, delegate.decode_errors)
        self.assertEqual(0, delegate.exceptions)

        delegate = _TestGcsfsCsvReaderDelegate()
        self.reader.streaming_read(
            GcsfsFilePath.from_absolute_path(empty_file_path),
            delegate=delegate,
            chunk_size=10,
        )
        self.assertEqual(1, len(delegate.encodings_attempted))
        self.assertEqual(delegate.encodings_attempted[0],
                         delegate.successful_encoding)
        self.assertEqual(0, len(delegate.dataframes))
        self.assertEqual(0, delegate.decode_errors)
        self.assertEqual(0, delegate.exceptions)

    def test_read_file_with_columns_no_contents(self) -> None:
        empty_file_path = fixtures.as_filepath("tagB.csv")

        delegate = _TestGcsfsCsvReaderDelegate()
        self.reader.streaming_read(
            GcsfsFilePath.from_absolute_path(empty_file_path),
            delegate=delegate,
            chunk_size=1,
        )
        self.assertEqual(1, len(delegate.encodings_attempted))
        self.assertEqual(delegate.encodings_attempted[0],
                         delegate.successful_encoding)
        self.assertEqual(1, len(delegate.dataframes))
        encoding, df = delegate.dataframes[0]
        self.assertEqual(encoding, delegate.successful_encoding)
        self.assertEqual(0, df.shape[0])  # No rows
        self.assertEqual(7, df.shape[1])  # 7 columns
        self.assertEqual(0, delegate.decode_errors)
        self.assertEqual(0, delegate.exceptions)

        delegate = _TestGcsfsCsvReaderDelegate()
        self.reader.streaming_read(
            GcsfsFilePath.from_absolute_path(empty_file_path),
            delegate=delegate,
            chunk_size=10,
        )
        self.assertEqual(1, len(delegate.encodings_attempted))
        self.assertEqual(delegate.encodings_attempted[0],
                         delegate.successful_encoding)
        self.assertEqual(1, len(delegate.dataframes))
        encoding, df = delegate.dataframes[0]
        self.assertEqual(encoding, delegate.successful_encoding)
        self.assertEqual(0, df.shape[0])  # No rows
        self.assertEqual(7, df.shape[1])  # 7 columns
        self.assertEqual(0, delegate.decode_errors)
        self.assertEqual(0, delegate.exceptions)

    def test_read_no_encodings_match(self) -> None:
        file_path = fixtures.as_filepath("encoded_latin_1.csv")
        delegate = _TestGcsfsCsvReaderDelegate()
        encodings_to_try = ["UTF-8", "UTF-16"]
        with self.assertRaises(ValueError):
            self.reader.streaming_read(
                GcsfsFilePath.from_absolute_path(file_path),
                delegate=delegate,
                chunk_size=10,
                encodings_to_try=encodings_to_try,
            )
        self.assertEqual(encodings_to_try, delegate.encodings_attempted)
        self.assertEqual(2, len(delegate.encodings_attempted))
        self.assertIsNone(delegate.successful_encoding)
        self.assertEqual(0, len(delegate.dataframes))
        self.assertEqual(2, delegate.decode_errors)
        self.assertEqual(0, delegate.exceptions)

    def test_read_with_failure_first(self) -> None:
        file_path = fixtures.as_filepath("encoded_latin_1.csv")
        delegate = _TestGcsfsCsvReaderDelegate()
        self.reader.streaming_read(GcsfsFilePath.from_absolute_path(file_path),
                                   delegate=delegate,
                                   chunk_size=1)

        index = COMMON_RAW_FILE_ENCODINGS.index("ISO-8859-1")
        self.assertEqual(index + 1, len(delegate.encodings_attempted))
        self.assertEqual(COMMON_RAW_FILE_ENCODINGS[:(index + 1)],
                         delegate.encodings_attempted)
        self.assertEqual("ISO-8859-1", delegate.successful_encoding)
        self.assertEqual(4, len(delegate.dataframes))
        self.assertEqual({"ISO-8859-1"},
                         {encoding
                          for encoding, df in delegate.dataframes})
        self.assertEqual(1, delegate.decode_errors)
        self.assertEqual(0, delegate.exceptions)

    def test_read_with_no_failure(self) -> None:
        file_path = fixtures.as_filepath("encoded_utf_8.csv")
        delegate = _TestGcsfsCsvReaderDelegate()
        self.reader.streaming_read(GcsfsFilePath.from_absolute_path(file_path),
                                   delegate=delegate,
                                   chunk_size=1)

        self.assertEqual(1, len(delegate.encodings_attempted))
        self.assertEqual("UTF-8", delegate.encodings_attempted[0])
        self.assertEqual("UTF-8", delegate.successful_encoding)
        self.assertEqual(4, len(delegate.dataframes))
        self.assertEqual({"UTF-8"},
                         {encoding
                          for encoding, df in delegate.dataframes})
        self.assertEqual(0, delegate.decode_errors)
        self.assertEqual(0, delegate.exceptions)

    def test_read_with_exception(self) -> None:
        class _TestException(ValueError):
            pass

        class _ExceptionDelegate(_TestGcsfsCsvReaderDelegate):
            def on_dataframe(self, encoding: str, chunk_num: int,
                             df: pd.DataFrame) -> bool:
                should_continue = super().on_dataframe(encoding, chunk_num, df)
                if chunk_num > 0:
                    raise _TestException("We crashed processing!")
                return should_continue

        file_path = fixtures.as_filepath("encoded_utf_8.csv")
        delegate = _ExceptionDelegate()

        with self.assertRaises(_TestException):
            self.reader.streaming_read(
                GcsfsFilePath.from_absolute_path(file_path),
                delegate=delegate,
                chunk_size=1,
            )

        self.assertEqual(1, len(delegate.encodings_attempted))
        self.assertEqual("UTF-8", delegate.encodings_attempted[0])
        self.assertIsNone(delegate.successful_encoding)
        self.assertEqual(2, len(delegate.dataframes))
        self.assertEqual({"UTF-8"},
                         {encoding
                          for encoding, df in delegate.dataframes})
        self.assertEqual(0, delegate.decode_errors)
        self.assertEqual(1, delegate.exceptions)
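
The tests above exercise the delegate contract: streaming_read hands each parsed chunk to on_dataframe(encoding, chunk_num, df) and stops early once that call returns False. Below is a simplified, standalone model of that contract using plain pandas; it is not the real GcsfsCsvReader (the encoding fallback and GCS handling are omitted), just an illustration of the chunk/return-value flow.

import io

import pandas as pd


class CollectFirstTwoChunksDelegate:
    """Toy delegate that keeps the first two chunks, then asks the reader to stop."""

    def __init__(self) -> None:
        self.dataframes = []

    def on_dataframe(self, encoding: str, chunk_num: int, df: pd.DataFrame) -> bool:
        self.dataframes.append((encoding, df))
        return chunk_num < 1  # False after the second chunk -> stop streaming


def fake_streaming_read(contents: str, delegate: CollectFirstTwoChunksDelegate, chunk_size: int) -> None:
    # Stand-in for GcsfsCsvReader.streaming_read over an in-memory CSV.
    for chunk_num, df in enumerate(pd.read_csv(io.StringIO(contents), chunksize=chunk_size)):
        if not delegate.on_dataframe("UTF-8", chunk_num, df):
            break


delegate = CollectFirstTwoChunksDelegate()
fake_streaming_read("col\n1\n2\n3\n4\n", delegate, chunk_size=1)
assert len(delegate.dataframes) == 2  # chunks 0 and 1 were captured before stopping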
class DirectIngestRawFileImportManager:
    """Class that stores raw data import configs for a region, with functionality for executing an import of a specific
    file.
    """
    def __init__(self,
                 *,
                 region: Region,
                 fs: DirectIngestGCSFileSystem,
                 ingest_directory_path: GcsfsDirectoryPath,
                 temp_output_directory_path: GcsfsDirectoryPath,
                 big_query_client: BigQueryClient,
                 region_raw_file_config: Optional[
                     DirectIngestRegionRawFileConfig] = None,
                 upload_chunk_size: int = _DEFAULT_BQ_UPLOAD_CHUNK_SIZE):

        self.region = region
        self.fs = fs
        self.ingest_directory_path = ingest_directory_path
        self.temp_output_directory_path = temp_output_directory_path
        self.big_query_client = big_query_client
        self.region_raw_file_config = region_raw_file_config \
            if region_raw_file_config else DirectIngestRegionRawFileConfig(region_code=self.region.region_code)
        self.upload_chunk_size = upload_chunk_size
        self.csv_reader = GcsfsCsvReader(
            gcsfs.GCSFileSystem(project=metadata.project_id(),
                                cache_timeout=GCSFS_NO_CACHING))

    def get_unprocessed_raw_files_to_import(self) -> List[GcsfsFilePath]:
        if not self.region.are_raw_data_bq_imports_enabled_in_env():
            raise ValueError(
                f'Cannot import raw files for region [{self.region.region_code}]'
            )

        unprocessed_paths = self.fs.get_unprocessed_file_paths(
            self.ingest_directory_path, GcsfsDirectIngestFileType.RAW_DATA)
        paths_to_import = []
        for path in unprocessed_paths:
            parts = filename_parts_from_path(path)
            if parts.file_tag in self.region_raw_file_config.raw_file_tags:
                paths_to_import.append(path)
            else:
                logging.warning(
                    'Unrecognized raw file tag [%s] for region [%s].',
                    parts.file_tag, self.region.region_code)

        return paths_to_import

    @classmethod
    def raw_tables_dataset_for_region(cls, region_code: str):
        return f'{region_code.lower()}_raw_data'

    def import_raw_file_to_big_query(
            self, path: GcsfsFilePath,
            file_metadata: DirectIngestFileMetadata) -> None:
        """Import a raw data file at the given path to the appropriate raw data table in BigQuery."""

        if not self.region.are_raw_data_bq_imports_enabled_in_env():
            raise ValueError(
                f'Cannot import raw files for region [{self.region.region_code}]'
            )

        parts = filename_parts_from_path(path)
        if parts.file_tag not in self.region_raw_file_config.raw_file_tags:
            raise ValueError(
                f'Attempting to import raw file with tag [{parts.file_tag}] unspecified by [{self.region.region_code}] '
                f'config.')

        if parts.file_type != GcsfsDirectIngestFileType.RAW_DATA:
            raise ValueError(
                f'Unexpected file type [{parts.file_type}] for path [{parts.file_tag}].'
            )

        logging.info('Beginning BigQuery upload of raw file [%s]',
                     path.abs_path())

        temp_output_paths = self._upload_contents_to_temp_gcs_paths(
            path, file_metadata)
        self._load_contents_to_bigquery(path, temp_output_paths)

        logging.info('Completed BigQuery import of [%s]', path.abs_path())

    def _upload_contents_to_temp_gcs_paths(
        self, path: GcsfsFilePath, file_metadata: DirectIngestFileMetadata
    ) -> List[Tuple[GcsfsFilePath, List[str]]]:
        """Uploads the contents of the file at the provided path to one or more GCS files, with whitespace stripped and
        additional metadata columns added.

        Returns a list of tuple pairs containing the destination paths and corrected CSV columns for that file.
        """

        logging.info('Starting chunked upload of contents to GCS')

        parts = filename_parts_from_path(path)
        file_config = self.region_raw_file_config.raw_file_configs[
            parts.file_tag]

        columns = self._get_validated_columns(path, file_config)

        delegate = DirectIngestRawDataSplittingGcsfsCsvReaderDelegate(
            path, self.fs, file_metadata, self.temp_output_directory_path)

        self.csv_reader.streaming_read(
            path,
            delegate=delegate,
            chunk_size=self.upload_chunk_size,
            encodings_to_try=file_config.encodings_to_try(),
            index_col=False,
            header=None,
            skiprows=1,
            usecols=columns,
            names=columns,
            keep_default_na=False,
            **self._common_read_csv_kwargs(file_config))

        return delegate.output_paths_with_columns

    def _load_contents_to_bigquery(
            self, path: GcsfsFilePath,
            temp_paths_with_columns: List[Tuple[GcsfsFilePath, List[str]]]):
        """Loads the contents in the given handle to the appropriate table in BigQuery."""

        logging.info('Starting chunked load of contents to BigQuery')
        temp_output_paths = [path for path, _ in temp_paths_with_columns]
        temp_path_to_load_job: Dict[GcsfsFilePath, bigquery.LoadJob] = {}
        dataset_id = self.raw_tables_dataset_for_region(
            self.region.region_code)

        try:
            for i, (temp_output_path,
                    columns) in enumerate(temp_paths_with_columns):
                if i > 0:
                    # Note: If this sleep becomes a serious performance issue, we could refactor to intersperse reading
                    # chunks to temp paths with starting each load job. In this case, we'd have to be careful to delete
                    # any partially uploaded portion of the file if we fail to parse a chunk in the middle.
                    logging.info(
                        'Sleeping for [%s] seconds to avoid exceeding per-table update rate quotas.',
                        _PER_TABLE_UPDATE_RATE_LIMITING_SEC)
                    time.sleep(_PER_TABLE_UPDATE_RATE_LIMITING_SEC)

                parts = filename_parts_from_path(path)
                load_job = self.big_query_client.insert_into_table_from_cloud_storage_async(
                    source_uri=temp_output_path.uri(),
                    destination_dataset_ref=self.big_query_client.
                    dataset_ref_for_id(dataset_id),
                    destination_table_id=parts.file_tag,
                    destination_table_schema=self.
                    _create_raw_table_schema_from_columns(columns),
                )
                logging.info('Load job [%s] for chunk [%d] started',
                             load_job.job_id, i)

                temp_path_to_load_job[temp_output_path] = load_job
        except Exception as e:
            logging.error('Failed to start load jobs - cleaning up temp paths')
            self._delete_temp_output_paths(temp_output_paths)
            raise e

        try:
            self._wait_for_jobs(temp_path_to_load_job)
        finally:
            self._delete_temp_output_paths(temp_output_paths)

    @staticmethod
    def _wait_for_jobs(
            temp_path_to_load_job: Dict[GcsfsFilePath,
                                        bigquery.LoadJob]) -> None:
        for temp_output_path, load_job in temp_path_to_load_job.items():
            try:
                logging.info('Waiting for load of [%s]',
                             temp_output_path.abs_path())
                load_job.result()
                logging.info('BigQuery load of [%s] complete',
                             temp_output_path.abs_path())
            except BadRequest as e:
                logging.error(
                    'Insert job [%s] for path [%s] failed with errors: [%s]',
                    load_job.job_id, temp_output_path, load_job.errors)
                raise e

    def _delete_temp_output_paths(
            self, temp_output_paths: List[GcsfsFilePath]) -> None:
        for temp_output_path in temp_output_paths:
            logging.info('Deleting temp file [%s].',
                         temp_output_path.abs_path())
            self.fs.delete(temp_output_path)

    @staticmethod
    def remove_column_non_printable_characters(
            columns: List[str]) -> List[str]:
        """Removes all non-printable characters that occasionally show up in column names. This is known to happen in
        random columns """
        fixed_columns = []
        for col in columns:
            fixed_col = ''.join([x for x in col if x in string.printable])
            if fixed_col != col:
                logging.info(
                    'Found non-printable characters in column [%s]. Original: [%s]',
                    fixed_col, col.__repr__())
            fixed_columns.append(fixed_col)
        return fixed_columns

    def _get_validated_columns(
            self, path: GcsfsFilePath,
            file_config: DirectIngestRawFileConfig) -> List[str]:
        """Returns a list of normalized column names for the raw data file at the given path."""
        # TODO(3020): We should not derive the columns from what we get in the uploaded raw data CSV - we should instead
        # define the set of columns we expect to see in each input CSV (with mandatory documentation) and update
        # this function to make sure that the columns in the CSV are a strict subset of the expected columns. This will
        # allow us to gracefully handle any raw data re-imports where a new column gets introduced in a later file.

        delegate = ReadOneGcsfsCsvReaderDelegate()
        self.csv_reader.streaming_read(
            path,
            delegate=delegate,
            chunk_size=1,
            nrows=1,
            **self._common_read_csv_kwargs(file_config))
        df = delegate.df

        if not isinstance(df, pd.DataFrame):
            raise ValueError(f'Unexpected type for DataFrame: [{type(df)}]')

        columns = self.remove_column_non_printable_characters(df.columns)

        # Strip whitespace from head/tail of column names
        columns = [c.strip() for c in columns]

        for column_name in columns:
            if not column_name:
                raise ValueError(
                    f'Found empty column name in [{file_config.file_tag}]')

            non_allowable_chars = self._get_non_allowable_bq_column_chars(
                column_name)
            if non_allowable_chars:
                # TODO(3020): Some regions (US_MO) are known to have unsupported chars in their column names - will need
                #  to implement how we reliably convert these column names.
                raise ValueError(
                    f'Column [{column_name}] for file has non-allowable characters {non_allowable_chars}.'
                )

        return columns

    @staticmethod
    def _get_non_allowable_bq_column_chars(column_name: str) -> Set[str]:
        def is_bq_allowable_column_char(x: str) -> bool:
            return x in string.ascii_letters or x in string.digits or x == '_'

        return {x for x in column_name if not is_bq_allowable_column_char(x)}

    @staticmethod
    def _create_raw_table_schema_from_columns(
            columns: List[str]) -> List[bigquery.SchemaField]:
        """Creates schema for use in `to_gbq` based on the provided columns."""
        schema = []
        for name in columns:
            typ_str = bigquery.enums.SqlTypeNames.STRING.value
            mode = 'NULLABLE'
            if name == _FILE_ID_COL_NAME:
                mode = 'REQUIRED'
                typ_str = bigquery.enums.SqlTypeNames.INTEGER.value
            if name == _UPDATE_DATETIME_COL_NAME:
                mode = 'REQUIRED'
                typ_str = bigquery.enums.SqlTypeNames.DATETIME.value
            schema.append(
                bigquery.SchemaField(name=name, field_type=typ_str, mode=mode))
        return schema

    @staticmethod
    def _common_read_csv_kwargs(
            file_config: DirectIngestRawFileConfig) -> Dict[str, Any]:
        return {
            'sep':
            file_config.separator,
            'quoting': (csv.QUOTE_NONE
                        if file_config.ignore_quotes else csv.QUOTE_MINIMAL),
        }
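
As an illustration of what `_create_raw_table_schema_from_columns` produces: assuming, hypothetically, that `_FILE_ID_COL_NAME` is "file_id" and `_UPDATE_DATETIME_COL_NAME` is "update_datetime", a raw table with a single data column would get the schema below.

from google.cloud import bigquery

# Every raw data column is a nullable STRING; the two metadata columns are REQUIRED.
schema = [
    bigquery.SchemaField(name="person_id", field_type="STRING", mode="NULLABLE"),
    bigquery.SchemaField(name="file_id", field_type="INTEGER", mode="REQUIRED"),
    bigquery.SchemaField(name="update_datetime", field_type="DATETIME", mode="REQUIRED"),
]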
class CsvGcsfsDirectIngestController(GcsfsDirectIngestController):
    """Direct ingest controller for regions that read CSV files from the
    GCSFileSystem.
    """
    def __init__(self,
                 region_name: str,
                 system_level: SystemLevel,
                 ingest_directory_path: Optional[str],
                 storage_directory_path: Optional[str],
                 max_delay_sec_between_files: Optional[int] = None):
        super().__init__(region_name, system_level, ingest_directory_path,
                         storage_directory_path, max_delay_sec_between_files)
        self.csv_reader = GcsfsCsvReader(
            gcsfs.GCSFileSystem(project=metadata.project_id(),
                                cache_timeout=GCSFS_NO_CACHING))

    @classmethod
    @abc.abstractmethod
    def get_file_tag_rank_list(cls) -> List[str]:
        pass

    def _file_meets_file_line_limit(self, line_limit: int,
                                    path: GcsfsFilePath) -> bool:
        delegate = ReadOneGcsfsCsvReaderDelegate()

        # Read a chunk up to one line bigger than the acceptable size
        try:
            self.csv_reader.streaming_read(path,
                                           delegate=delegate,
                                           chunk_size=(line_limit + 1))
        except FileNotFoundError:
            return True

        if delegate.df is None:
            # If the file is empty, it's fine.
            return True

        # If length of the only chunk is less than or equal to the acceptable
        # size, file meets line limit.
        return len(delegate.df) <= line_limit

    def _split_file(self, path: GcsfsFilePath) -> List[GcsfsFilePath]:
        parts = filename_parts_from_path(path)

        if self.region.is_raw_vs_ingest_file_name_detection_enabled() and \
                parts.file_type == GcsfsDirectIngestFileType.RAW_DATA:
            raise ValueError(
                f'Splitting raw files unsupported. Attempting to split [{path.abs_path()}]'
            )

        delegate = DirectIngestFileSplittingGcsfsCsvReaderDelegate(
            path, self.fs, self.temp_output_directory_path)
        self.csv_reader.streaming_read(
            path,
            delegate=delegate,
            chunk_size=self.ingest_file_split_line_limit)
        output_paths = [path for path, _ in delegate.output_paths_with_columns]

        return output_paths

    def _yaml_filepath(self, file_tag):
        return os.path.join(os.path.dirname(inspect.getfile(self.__class__)),
                            f'{self.region.region_code}_{file_tag}.yaml')

    @staticmethod
    def _wrap_with_tag(file_tag: str, callback: Optional[Callable]):
        if callback is None:
            return None

        def wrapped_cb(*args):
            return callback(file_tag, *args)

        return wrapped_cb

    @classmethod
    def _wrap_list_with_tag(cls, file_tag: str, callbacks: List[Callable]):
        return [
            cls._wrap_with_tag(file_tag, callback) for callback in callbacks
        ]

    def _parse(self, args: GcsfsIngestArgs,
               contents_handle: GcsfsFileContentsHandle) -> IngestInfo:
        file_tag = self.file_tag(args.file_path)

        if file_tag not in self.get_file_tag_rank_list():
            raise DirectIngestError(
                msg=f"No mapping found for tag [{file_tag}]",
                error_type=DirectIngestErrorType.INPUT_ERROR)

        file_mapping = self._yaml_filepath(file_tag)

        row_pre_processors = self._wrap_list_with_tag(
            file_tag, self._get_row_pre_processors_for_file(file_tag))
        row_post_processors = self._wrap_list_with_tag(
            file_tag, self._get_row_post_processors_for_file(file_tag))
        file_post_processors = self._wrap_list_with_tag(
            file_tag, self._get_file_post_processors_for_file(file_tag))
        # pylint: disable=assignment-from-none
        primary_key_override_callback = self._wrap_with_tag(
            file_tag, self._get_primary_key_override_for_file(file_tag))
        # pylint: disable=assignment-from-none
        ancestor_chain_overrides_callback = \
            self._wrap_with_tag(
                file_tag,
                self._get_ancestor_chain_overrides_callback_for_file(file_tag))
        should_set_with_empty_values = \
            file_tag in self._get_files_to_set_with_empty_values()

        data_extractor = CsvDataExtractor(
            file_mapping, row_pre_processors, row_post_processors,
            file_post_processors, ancestor_chain_overrides_callback,
            primary_key_override_callback, self.system_level,
            should_set_with_empty_values)

        return data_extractor.extract_and_populate_data(
            contents_handle.get_contents_iterator())

    def _are_contents_empty(self, args: GcsfsIngestArgs,
                            contents_handle: GcsfsFileContentsHandle) -> bool:
        """Returns true if the CSV file is emtpy, i.e. it contains no non-header
         rows.
         """
        delegate = ReadOneGcsfsCsvReaderDelegate()
        self.csv_reader.streaming_read(args.file_path,
                                       delegate=delegate,
                                       chunk_size=1,
                                       skiprows=1)
        return delegate.df is None

    def _get_row_pre_processors_for_file(self, _file_tag) -> List[Callable]:
        """Subclasses should override to return row_pre_processors for a given
        file tag.
        """
        return []

    def _get_row_post_processors_for_file(self, _file_tag) -> List[Callable]:
        """Subclasses should override to return row_post_processors for a given
        file tag.
        """
        return []

    def _get_file_post_processors_for_file(self, _file_tag) -> List[Callable]:
        """Subclasses should override to return file_post_processors for a given
        file tag.
        """
        return []

    def _get_ancestor_chain_overrides_callback_for_file(
            self, _file_tag: str) -> Optional[Callable]:
        """Subclasses should override to return an
        ancestor_chain_overrides_callback for a given file tag.
        """
        return None

    def _get_primary_key_override_for_file(
            self, _file_tag: str) -> Optional[Callable]:
        """Subclasses should override to return a primary_key_override for a
        given file tag.
        """
        return None

    def _get_files_to_set_with_empty_values(self) -> List[str]:
        """Subclasses should override to return which files to set with empty
        values (see CsvDataExtractor).
        """
        return []
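
`_file_meets_file_line_limit` above relies on a small trick: request one chunk that is one row larger than the acceptable size and check how large the chunk actually is. A standalone sketch of that check with plain pandas over in-memory contents (the real code streams the file from GCS via the csv reader):

import io

import pandas as pd

line_limit = 3
csv_contents = "col\n1\n2\n3\n4\n5\n"  # header plus five data rows

# Read a single chunk that is one row bigger than the acceptable size.
reader = pd.read_csv(io.StringIO(csv_contents), chunksize=line_limit + 1)
first_chunk = next(reader, None)

# An empty file is fine; otherwise the only chunk must fit within the limit.
meets_limit = first_chunk is None or len(first_chunk) <= line_limit
assert meets_limit is False  # five rows exceed the limit of three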
Example 11
def build_gcsfs_controller_for_tests(
    controller_cls: Type[CsvGcsfsDirectIngestController],
    ingest_instance: DirectIngestInstance,
    run_async: bool,
    can_start_ingest: bool = True,
    regions_module: ModuleType = fake_regions_module,
) -> BaseDirectIngestController:
    """Builds an instance of |controller_cls| for use in tests with several internal classes mocked properly. """
    fake_fs = FakeGCSFileSystem()

    def mock_build_fs() -> FakeGCSFileSystem:
        return fake_fs

    if "TestGcsfsDirectIngestController" in controller_cls.__name__:
        view_collector_cls: Type[
            BigQueryViewCollector] = FakeDirectIngestPreProcessedIngestViewCollector
    else:
        view_collector_cls = DirectIngestPreProcessedIngestViewCollector

    with patch(
            f"{BaseDirectIngestController.__module__}.DirectIngestCloudTaskManagerImpl"
    ) as mock_task_factory_cls:
        with patch(
                f"{BaseDirectIngestController.__module__}.BigQueryClientImpl"
        ) as mock_big_query_client_cls:
            with patch(
                    f"{BaseDirectIngestController.__module__}.DirectIngestRawFileImportManager",
                    FakeDirectIngestRawFileImportManager,
            ):
                with patch(
                        f"{BaseDirectIngestController.__module__}.DirectIngestPreProcessedIngestViewCollector",
                        view_collector_cls,
                ):
                    task_manager = (
                        FakeAsyncDirectIngestCloudTaskManager() if run_async
                        else FakeSynchronousDirectIngestCloudTaskManager())
                    mock_task_factory_cls.return_value = task_manager
                    mock_big_query_client_cls.return_value = (
                        FakeDirectIngestBigQueryClient(
                            project_id=metadata.project_id(),
                            fs=fake_fs,
                            region_code=controller_cls.region_code(),
                        ))
                    with patch.object(GcsfsFactory, "build",
                                      new=mock_build_fs):
                        with patch.object(
                                direct_ingest_raw_table_migration_collector,
                                "regions",
                                new=regions_module,
                        ):
                            controller = controller_cls(
                                ingest_bucket_path=
                                gcsfs_direct_ingest_bucket_for_region(
                                    region_code=controller_cls.region_code(),
                                    system_level=SystemLevel.for_region_code(
                                        controller_cls.region_code(),
                                        is_direct_ingest=True,
                                    ),
                                    ingest_instance=ingest_instance,
                                    project_id="recidiviz-xxx",
                                ))
                            controller.csv_reader = GcsfsCsvReader(fake_fs)
                            controller.raw_file_import_manager.csv_reader = (
                                controller.csv_reader)

                            task_manager.set_controller(controller)
                            fake_fs.test_set_delegate(
                                DirectIngestFakeGCSFileSystemDelegate(
                                    controller,
                                    can_start_ingest=can_start_ingest))
                            return controller
Example 12
class DirectIngestRawFileImportManager:
    """Class that stores raw data import configs for a region, with functionality for executing an import of a specific
    file.
    """
    def __init__(
        self,
        *,
        region: Region,
        fs: DirectIngestGCSFileSystem,
        ingest_bucket_path: GcsfsBucketPath,
        temp_output_directory_path: GcsfsDirectoryPath,
        big_query_client: BigQueryClient,
        region_raw_file_config: Optional[
            DirectIngestRegionRawFileConfig] = None,
        upload_chunk_size: int = _DEFAULT_BQ_UPLOAD_CHUNK_SIZE,
    ):

        self.region = region
        self.fs = fs
        self.ingest_bucket_path = ingest_bucket_path
        self.temp_output_directory_path = temp_output_directory_path
        self.big_query_client = big_query_client
        self.region_raw_file_config = (
            region_raw_file_config
            if region_raw_file_config else DirectIngestRegionRawFileConfig(
                region_code=self.region.region_code,
                region_module=self.region.region_module,
            ))
        self.upload_chunk_size = upload_chunk_size
        self.csv_reader = GcsfsCsvReader(fs)
        self.raw_table_migrations = DirectIngestRawTableMigrationCollector(
            region_code=self.region.region_code,
            regions_module_override=self.region.region_module,
        ).collect_raw_table_migration_queries()

    def get_unprocessed_raw_files_to_import(self) -> List[GcsfsFilePath]:
        unprocessed_paths = self.fs.get_unprocessed_file_paths(
            self.ingest_bucket_path, GcsfsDirectIngestFileType.RAW_DATA)
        paths_to_import = []
        unrecognized_file_tags = set()
        for path in unprocessed_paths:
            parts = filename_parts_from_path(path)
            if parts.file_tag in self.region_raw_file_config.raw_file_tags:
                paths_to_import.append(path)
            else:
                unrecognized_file_tags.add(parts.file_tag)

        for file_tag in sorted(unrecognized_file_tags):
            logging.warning(
                "Unrecognized raw file tag [%s] for region [%s].",
                file_tag,
                self.region.region_code,
            )

        return paths_to_import

    @classmethod
    def raw_tables_dataset_for_region(cls, region_code: str) -> str:
        return f"{region_code.lower()}_raw_data"

    def _raw_tables_dataset(self) -> str:
        return self.raw_tables_dataset_for_region(self.region.region_code)

    def import_raw_file_to_big_query(
            self, path: GcsfsFilePath,
            file_metadata: DirectIngestRawFileMetadata) -> None:
        """Import a raw data file at the given path to the appropriate raw data table in BigQuery."""
        parts = filename_parts_from_path(path)
        if parts.file_tag not in self.region_raw_file_config.raw_file_tags:
            raise ValueError(
                f"Attempting to import raw file with tag [{parts.file_tag}] unspecified by [{self.region.region_code}] "
                f"config.")

        if parts.file_type != GcsfsDirectIngestFileType.RAW_DATA:
            raise ValueError(
                f"Unexpected file type [{parts.file_type}] for path [{parts.file_tag}]."
            )

        logging.info("Beginning BigQuery upload of raw file [%s]",
                     path.abs_path())

        temp_output_paths = self._upload_contents_to_temp_gcs_paths(
            path, file_metadata)
        self._load_contents_to_bigquery(path, temp_output_paths)

        migration_queries = self.raw_table_migrations.get(parts.file_tag, [])
        logging.info(
            "Running [%s] migration queries for table [%s]",
            len(migration_queries),
            parts.file_tag,
        )
        for migration_query in migration_queries:
            query_job = self.big_query_client.run_query_async(
                query_str=migration_query)
            # Wait for the migration query to complete before running the next one
            query_job.result()

        logging.info("Completed BigQuery import of [%s]", path.abs_path())

    def _upload_contents_to_temp_gcs_paths(
        self, path: GcsfsFilePath, file_metadata: DirectIngestRawFileMetadata
    ) -> List[Tuple[GcsfsFilePath, List[str]]]:
        """Uploads the contents of the file at the provided path to one or more GCS files, with whitespace stripped and
        additional metadata columns added.
        Returns a list of tuple pairs containing the destination paths and corrected CSV columns for that file.
        """

        logging.info("Starting chunked upload of contents to GCS")

        parts = filename_parts_from_path(path)
        file_config = self.region_raw_file_config.raw_file_configs[
            parts.file_tag]

        columns = self._get_validated_columns(path, file_config)

        delegate = DirectIngestRawDataSplittingGcsfsCsvReaderDelegate(
            path, self.fs, file_metadata, self.temp_output_directory_path)

        self.csv_reader.streaming_read(
            path,
            delegate=delegate,
            chunk_size=self.upload_chunk_size,
            encodings_to_try=file_config.encodings_to_try(),
            index_col=False,
            header=0,
            names=columns,
            keep_default_na=False,
            **self._common_read_csv_kwargs(file_config),
        )

        return delegate.output_paths_with_columns

    def _load_contents_to_bigquery(
        self,
        path: GcsfsFilePath,
        temp_paths_with_columns: List[Tuple[GcsfsFilePath, List[str]]],
    ) -> None:
        """Loads the contents in the given handle to the appropriate table in BigQuery."""

        logging.info("Starting chunked load of contents to BigQuery")
        temp_output_paths = [path for path, _ in temp_paths_with_columns]
        temp_path_to_load_job: Dict[GcsfsFilePath, bigquery.LoadJob] = {}
        dataset_id = self._raw_tables_dataset()

        try:
            for i, (temp_output_path,
                    columns) in enumerate(temp_paths_with_columns):
                if i > 0:
                    # Note: If this sleep becomes a serious performance issue, we could refactor to intersperse reading
                    # chunks to temp paths with starting each load job. In this case, we'd have to be careful to delete
                    # any partially uploaded portion of the file if we fail to parse a chunk in the middle.
                    logging.info(
                        "Sleeping for [%s] seconds to avoid exceeding per-table update rate quotas.",
                        _PER_TABLE_UPDATE_RATE_LIMITING_SEC,
                    )
                    time.sleep(_PER_TABLE_UPDATE_RATE_LIMITING_SEC)

                parts = filename_parts_from_path(path)
                load_job = self.big_query_client.insert_into_table_from_cloud_storage_async(
                    source_uri=temp_output_path.uri(),
                    destination_dataset_ref=self.big_query_client.
                    dataset_ref_for_id(dataset_id),
                    destination_table_id=parts.file_tag,
                    destination_table_schema=self.
                    _create_raw_table_schema_from_columns(columns),
                )
                logging.info("Load job [%s] for chunk [%d] started",
                             load_job.job_id, i)

                temp_path_to_load_job[temp_output_path] = load_job
        except Exception as e:
            logging.error("Failed to start load jobs - cleaning up temp paths")
            self._delete_temp_output_paths(temp_output_paths)
            raise e

        try:
            self._wait_for_jobs(temp_path_to_load_job)
        finally:
            self._delete_temp_output_paths(temp_output_paths)

    @staticmethod
    def _wait_for_jobs(
            temp_path_to_load_job: Dict[GcsfsFilePath,
                                        bigquery.LoadJob]) -> None:
        for temp_output_path, load_job in temp_path_to_load_job.items():
            try:
                logging.info(
                    "Waiting for load of [%s] into [%s]",
                    temp_output_path.abs_path(),
                    load_job.destination,
                )
                load_job.result()
                logging.info("BigQuery load of [%s] complete",
                             temp_output_path.abs_path())
            except BadRequest as e:
                logging.error(
                    "Insert job [%s] for path [%s] failed with errors: [%s]",
                    load_job.job_id,
                    temp_output_path,
                    load_job.errors,
                )
                raise e

    def _delete_temp_output_paths(
            self, temp_output_paths: List[GcsfsFilePath]) -> None:
        for temp_output_path in temp_output_paths:
            logging.info("Deleting temp file [%s].",
                         temp_output_path.abs_path())
            self.fs.delete(temp_output_path)

    @staticmethod
    def remove_column_non_printable_characters(
            columns: List[str]) -> List[str]:
        """Removes all non-printable characters that occasionally show up in column names. This is known to happen in
        random columns"""
        fixed_columns = []
        for col in columns:
            fixed_col = "".join([x for x in col if x in string.printable])
            if fixed_col != col:
                logging.info(
                    "Found non-printable characters in column [%s]. Original: [%s]",
                    fixed_col,
                    col.__repr__(),
                )
            fixed_columns.append(fixed_col)
        return fixed_columns

    def _get_validated_columns(
            self, path: GcsfsFilePath,
            file_config: DirectIngestRawFileConfig) -> List[str]:
        """Returns a list of normalized column names for the raw data file at the given path."""
        # TODO(#3807): We should not derive the columns from what we get in the uploaded raw data CSV - we should
        # instead define the set of columns we expect to see in each input CSV (with mandatory documentation) and
        # update this function to make sure that the columns in the CSV are a strict subset of the expected columns.
        # This will allow us to gracefully handle any raw data re-imports where a new column gets introduced in a
        # later file.

        delegate = ReadOneGcsfsCsvReaderDelegate()
        self.csv_reader.streaming_read(
            path,
            delegate=delegate,
            chunk_size=1,
            encodings_to_try=file_config.encodings_to_try(),
            nrows=1,
            **self._common_read_csv_kwargs(file_config),
        )
        df = delegate.df

        if not isinstance(df, pd.DataFrame):
            raise ValueError(f"Unexpected type for DataFrame: [{type(df)}]")

        columns = self.remove_column_non_printable_characters(df.columns)

        # Strip whitespace from head/tail of column names
        columns = [c.strip() for c in columns]

        normalized_columns = set()
        for i, column_name in enumerate(columns):
            if not column_name:
                raise ValueError(
                    f"Found empty column name in [{file_config.file_tag}]")

            column_name = self._convert_non_allowable_bq_column_chars(
                column_name)

            # BQ doesn't allow column names to begin with a number, so we prepend an underscore in that case
            if column_name[0] in string.digits:
                column_name = "_" + column_name

            # If the capitalization of the column name doesn't match the capitalization
            # listed in the file config, update the capitalization.
            if column_name not in file_config.columns:
                caps_normalized_col = file_config.caps_normalized_col(
                    column_name)
                if caps_normalized_col:
                    column_name = caps_normalized_col

            if column_name in normalized_columns:
                raise ValueError(
                    f"Multiple columns with name [{column_name}] after normalization."
                )
            normalized_columns.add(column_name)
            columns[i] = column_name

        if len(normalized_columns) == 1:
            # A single-column file is almost always indicative of a parsing error. If
            # this column name is not registered in the file config, we throw.
            column = one(normalized_columns)
            if column not in file_config.columns:
                raise ValueError(
                    f"Found only one column: [{column}]. Columns likely did not "
                    f"parse properly. Are you using the correct separator and encoding "
                    f"for this file? If this file really has just one column, the "
                    f"column name must be registered in the raw file config before "
                    f"upload.")

        return columns

    @staticmethod
    def _convert_non_allowable_bq_column_chars(column_name: str) -> str:
        def is_bq_allowable_column_char(x: str) -> bool:
            return x in string.ascii_letters or x in string.digits or x == "_"

        column_name = "".join([
            c if is_bq_allowable_column_char(c) else "_" for c in column_name
        ])
        return column_name

    @staticmethod
    def _create_raw_table_schema_from_columns(
        columns: List[str], ) -> List[bigquery.SchemaField]:
        """Creates schema for use in `to_gbq` based on the provided columns."""
        schema = []
        for name in columns:
            typ_str = bigquery.enums.SqlTypeNames.STRING.value
            mode = "NULLABLE"
            if name == FILE_ID_COL_NAME:
                mode = "REQUIRED"
                typ_str = bigquery.enums.SqlTypeNames.INTEGER.value
            if name == UPDATE_DATETIME_COL_NAME:
                mode = "REQUIRED"
                typ_str = bigquery.enums.SqlTypeNames.DATETIME.value
            schema.append(
                bigquery.SchemaField(name=name, field_type=typ_str, mode=mode))
        return schema

    @staticmethod
    def _common_read_csv_kwargs(
        file_config: DirectIngestRawFileConfig, ) -> Dict[str, Any]:
        """Returns a set of arguments to be passed to the pandas.read_csv() call, based
        on the provided raw file config.
        """
        kwargs = {
            "sep":
            file_config.separator,
            "quoting": (csv.QUOTE_NONE
                        if file_config.ignore_quotes else csv.QUOTE_MINIMAL),
        }

        if file_config.custom_line_terminator:
            kwargs["lineterminator"] = file_config.custom_line_terminator

        # We get the following warning if we do not override the
        # engine in this case: "ParserWarning: Falling back to the 'python'
        # engine because the separator encoded in utf-8 is > 1 char
        # long, and the 'c' engine does not support such separators;
        # you can avoid this warning by specifying engine='python'."
        if len(file_config.separator.encode(UTF_8_ENCODING)) > 1:
            # The python engine is slower but more feature-complete.
            kwargs["engine"] = "python"

        return kwargs
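
To make the behavior concrete, this is a standalone sketch of the kwargs the helper above would produce for a hypothetical raw file config whose separator is the two-character string "||", which ignores quotes, and which has no custom line terminator (the attribute values are assumptions for illustration):

import csv

separator = "||"  # hypothetical file_config.separator
ignore_quotes = True  # hypothetical file_config.ignore_quotes
custom_line_terminator = None  # hypothetical file_config.custom_line_terminator

kwargs = {
    "sep": separator,
    "quoting": csv.QUOTE_NONE if ignore_quotes else csv.QUOTE_MINIMAL,
}
if custom_line_terminator:
    kwargs["lineterminator"] = custom_line_terminator
if len(separator.encode("utf-8")) > 1:
    kwargs["engine"] = "python"  # multi-character separators need the python engine

assert kwargs == {"sep": "||", "quoting": csv.QUOTE_NONE, "engine": "python"}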
Example 13
    def __init__(self, ingest_bucket_path: GcsfsBucketPath):
        super().__init__(ingest_bucket_path)
        self.csv_reader = GcsfsCsvReader(GcsfsFactory.build())
Example 14
class CsvGcsfsDirectIngestController(BaseDirectIngestController):
    """Direct ingest controller for regions that read CSV files from the
    GCSFileSystem.
    """
    def __init__(self, ingest_bucket_path: GcsfsBucketPath):
        super().__init__(ingest_bucket_path)
        self.csv_reader = GcsfsCsvReader(GcsfsFactory.build())

    @abc.abstractmethod
    def get_file_tag_rank_list(self) -> List[str]:
        pass

    def _file_meets_file_line_limit(self, line_limit: int,
                                    path: GcsfsFilePath) -> bool:
        delegate = ReadOneGcsfsCsvReaderDelegate()

        # Read a chunk up to one line bigger than the acceptable size
        try:
            self.csv_reader.streaming_read(path,
                                           delegate=delegate,
                                           chunk_size=(line_limit + 1))
        except GCSBlobDoesNotExistError:
            return True

        if delegate.df is None:
            # If the file is empty, it's fine.
            return True

        # If length of the only chunk is less than or equal to the acceptable
        # size, file meets line limit.
        return len(delegate.df) <= line_limit

    def _split_file(self, path: GcsfsFilePath) -> List[GcsfsFilePath]:
        parts = filename_parts_from_path(path)

        if parts.file_type == GcsfsDirectIngestFileType.RAW_DATA:
            raise ValueError(
                f"Splitting raw files unsupported. Attempting to split [{path.abs_path()}]"
            )

        delegate = DirectIngestFileSplittingGcsfsCsvReaderDelegate(
            path, self.fs, self.temp_output_directory_path)
        self.csv_reader.streaming_read(
            path,
            delegate=delegate,
            chunk_size=self.ingest_file_split_line_limit)
        output_paths = [path for path, _ in delegate.output_paths_with_columns]

        return output_paths

    def _yaml_filepath(self, file_tag: str) -> str:
        return os.path.join(
            os.path.dirname(self.region.region_module.__file__),
            self.region.region_code.lower(),
            f"{self.region.region_code.lower()}_{file_tag}.yaml",
        )

    def _parse(self, args: GcsfsIngestArgs,
               contents_handle: GcsfsFileContentsHandle) -> IngestInfo:
        file_tag = self.file_tag(args.file_path)
        gating_context = IngestGatingContext(
            file_tag=file_tag, ingest_instance=self.ingest_instance)

        if file_tag not in self.get_file_tag_rank_list():
            raise DirectIngestError(
                msg=f"No mapping found for tag [{file_tag}]",
                error_type=DirectIngestErrorType.INPUT_ERROR,
            )

        file_mapping = self._yaml_filepath(file_tag)

        row_pre_processors = self._get_row_pre_processors_for_file(file_tag)
        row_post_processors = self._get_row_post_processors_for_file(file_tag)
        file_post_processors = self._get_file_post_processors_for_file(
            file_tag)
        # pylint: disable=assignment-from-none
        primary_key_override_callback = self._get_primary_key_override_for_file(
            file_tag)
        # pylint: disable=assignment-from-none
        ancestor_chain_overrides_callback = (
            self._get_ancestor_chain_overrides_callback_for_file(file_tag))
        should_set_with_empty_values = (
            gating_context.file_tag
            in self._get_files_to_set_with_empty_values())

        data_extractor = CsvDataExtractor(
            file_mapping,
            gating_context,
            row_pre_processors,
            row_post_processors,
            file_post_processors,
            ancestor_chain_overrides_callback,
            primary_key_override_callback,
            self.system_level,
            should_set_with_empty_values,
        )

        return data_extractor.extract_and_populate_data(
            contents_handle.get_contents_iterator())

    def _are_contents_empty(self, args: GcsfsIngestArgs,
                            contents_handle: GcsfsFileContentsHandle) -> bool:
        """Returns true if the CSV file is empty, i.e. it contains no non-header
        rows.
        """
        delegate = ReadOneGcsfsCsvReaderDelegate()
        self.csv_reader.streaming_read(args.file_path,
                                       delegate=delegate,
                                       chunk_size=1,
                                       skiprows=1)
        return delegate.df is None

    def _get_row_pre_processors_for_file(
            self, _file_tag: str) -> List[IngestRowPrehookCallable]:
        """Subclasses should override to return row_pre_processors for a given
        file tag.
        """
        return []

    def _get_row_post_processors_for_file(
            self, _file_tag: str) -> List[IngestRowPosthookCallable]:
        """Subclasses should override to return row_post_processors for a given
        file tag.
        """
        return []

    def _get_file_post_processors_for_file(
            self, _file_tag: str) -> List[IngestFilePostprocessorCallable]:
        """Subclasses should override to return file_post_processors for a given
        file tag.
        """
        return []

    def _get_ancestor_chain_overrides_callback_for_file(
            self,
            _file_tag: str) -> Optional[IngestAncestorChainOverridesCallable]:
        """Subclasses should override to return an
        ancestor_chain_overrides_callback for a given file tag.
        """
        return None

    def _get_primary_key_override_for_file(
            self,
            _file_tag: str) -> Optional[IngestPrimaryKeyOverrideCallable]:
        """Subclasses should override to return a primary_key_override for a
        given file tag.
        """
        return None

    def _get_files_to_set_with_empty_values(self) -> List[str]:
        """Subclasses should override to return which files to set with empty
        values (see CsvDataExtractor).
        """
        return []
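
`_are_contents_empty` above skips the header row and asks the csv reader for a single chunk; if no dataframe comes back, the file has no data rows. A plain-pandas sketch of that idea, under the assumption that the real reader surfaces a fully-empty post-skip stream as "no dataframe" (pandas itself raises EmptyDataError in that case):

import io

import pandas as pd
from pandas.errors import EmptyDataError

header_only = "col_a,col_b\n"  # hypothetical file with a header and no data rows

try:
    chunks = pd.read_csv(io.StringIO(header_only), chunksize=1, skiprows=1, header=None)
    first_chunk = next(chunks, None)
except EmptyDataError:
    first_chunk = None  # nothing left to parse after skipping the header

assert first_chunk is None  # i.e. the contents are considered empty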
Example 15
    def setUp(self) -> None:
        self.fake_gcs = FakeGCSFileSystem()
        self.reader = GcsfsCsvReader(self.fake_gcs)