def test_progress(self):
  file_name, expected_data = write_data(10)
  assert len(expected_data) == 10
  source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                      coders.StrUtf8Coder())
  splits = list(source.split(desired_bundle_size=100000))
  assert len(splits) == 1
  fraction_consumed_report = []
  split_points_report = []
  range_tracker = splits[0].source.get_range_tracker(
      splits[0].start_position, splits[0].stop_position)
  for _ in splits[0].source.read(range_tracker):
    fraction_consumed_report.append(range_tracker.fraction_consumed())
    split_points_report.append(range_tracker.split_points())

  self.assertEqual(
      [float(i) / 10 for i in range(0, 10)], fraction_consumed_report)
  expected_split_points_report = [
      ((i - 1), iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
      for i in range(1, 10)]
  # At the last split point, the remaining split points callback returns 1
  # since the expected position of the next record becomes equal to the stop
  # position.
  expected_split_points_report.append((9, 1))
  self.assertEqual(expected_split_points_report, split_points_report)
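# Every test here calls a write_data helper defined elsewhere in the test
# module. A minimal sketch of it and of the EOL constants follows, assuming
# records of the form 'line<i>' and a simple alternation for MIXED; both
# details are illustrative, not the module's actual definitions.
import tempfile


class EOL(object):
  LF = 1
  CRLF = 2
  MIXED = 3


def write_data(num_lines, eol=EOL.LF):
  """Writes num_lines records to a temp file; returns (path, expected lines)."""
  lines = ['line%d' % i for i in range(num_lines)]
  with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
    for i, line in enumerate(lines):
      if eol == EOL.CRLF or (eol == EOL.MIXED and i % 2 == 0):
        f.write(line + '\r\n')
      else:
        f.write(line + '\n')
  return f.name, lines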
def __init__(self, file_name, range_tracker, file_pattern, compression_type,
             allow_malformed_records, **kwargs):
  self._header_lines = []
  self._last_record = None
  self._file_name = file_name
  self._allow_malformed_records = allow_malformed_records

  text_source = TextSource(
      file_pattern,
      0,  # min_bundle_size
      compression_type,
      True,  # strip_trailing_newlines
      coders.StrUtf8Coder(),  # coder
      validate=False,
      header_processor_fns=(lambda x: x.startswith('#'),
                            self._store_header_lines),
      **kwargs)

  self._text_lines = text_source.read_records(self._file_name, range_tracker)
  try:
    self._vcf_reader = vcf.Reader(fsock=self._create_generator())
  except SyntaxError as e:
    # Throw the exception inside the generator to ensure the file is properly
    # closed (it's opened inside TextSource.read_records). Note that
    # traceback.format_exc() takes no exception argument.
    self._text_lines.throw(
        ValueError('An exception was raised when reading header from VCF '
                   'file %s: %s' % (self._file_name, traceback.format_exc())))
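# The .throw() call above relies on standard generator semantics: raising an
# exception at the suspended yield runs the generator's `with`/`finally`
# cleanup, which closes the file opened inside TextSource.read_records. A
# self-contained illustration of that pattern (all names are illustrative):
import tempfile


def _lines(path):
  with open(path) as f:  # closed whenever the generator exits
    for line in f:
      yield line


with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as tmp:
  tmp.write('#header\ndata\n')

gen = _lines(tmp.name)
next(gen)  # suspend the generator inside its `with` block
try:
  gen.throw(ValueError('bad header'))
except ValueError:
  pass  # the file object inside the generator is now closed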
def test_header_processing(self):
  file_name, expected_data = write_data(10)
  assert len(expected_data) == 10

  def header_matcher(line):
    return line in expected_data[:5]

  header_lines = []

  def store_header(lines):
    for line in lines:
      header_lines.append(line)

  source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                      coders.StrUtf8Coder(),
                      header_processor_fns=(header_matcher, store_header))
  splits = list(source.split(desired_bundle_size=100000))
  assert len(splits) == 1
  range_tracker = splits[0].source.get_range_tracker(
      splits[0].start_position, splits[0].stop_position)
  read_data = list(source.read_records(file_name, range_tracker))

  self.assertCountEqual(expected_data[:5], header_lines)
  self.assertCountEqual(expected_data[5:], read_data)
def test_read_reentrant_after_splitting(self):
  file_name, expected_data = write_data(10)
  assert len(expected_data) == 10
  source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                      coders.StrUtf8Coder())
  splits = list(source.split(desired_bundle_size=100000))
  assert len(splits) == 1
  source_test_utils.assert_reentrant_reads_succeed(
      (splits[0].source, splits[0].start_position, splits[0].stop_position))
def test_dynamic_work_rebalancing(self):
  file_name, expected_data = write_data(5)
  assert len(expected_data) == 5
  source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                      coders.StrUtf8Coder())
  splits = list(source.split(desired_bundle_size=100000))
  assert len(splits) == 1
  source_test_utils.assert_split_at_fraction_exhaustive(
      splits[0].source, splits[0].start_position, splits[0].stop_position)
def test_read_single_file_without_striping_eol_crlf(self):
  file_name, written_data = write_data(
      TextSourceTest.DEFAULT_NUM_RECORDS, eol=EOL.CRLF)
  assert len(written_data) == TextSourceTest.DEFAULT_NUM_RECORDS

  source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, False,
                      coders.StrUtf8Coder())

  range_tracker = source.get_range_tracker(None, None)
  read_data = list(source.read(range_tracker))
  self.assertCountEqual([line + '\r\n' for line in written_data], read_data)
def test_dynamic_work_rebalancing_mixed_eol(self):
  file_name, expected_data = write_data(5, eol=EOL.MIXED)
  assert len(expected_data) == 5
  source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                      coders.StrUtf8Coder())
  splits = list(source.split(desired_bundle_size=100000))
  assert len(splits) == 1
  source_test_utils.assert_split_at_fraction_exhaustive(
      splits[0].source, splits[0].start_position, splits[0].stop_position,
      perform_multi_threaded_test=False)
def test_dynamic_work_rebalancing_windows_eol(self):
  file_name, expected_data = write_data(15, eol=EOL.CRLF)
  assert len(expected_data) == 15
  source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                      coders.StrUtf8Coder())
  splits = list(source.split(desired_bundle_size=100000))
  assert len(splits) == 1
  source_test_utils.assert_split_at_fraction_exhaustive(
      splits[0].source, splits[0].start_position, splits[0].stop_position,
      perform_multi_threaded_test=False)
def test_read_reentrant_without_splitting(self):
  file_name, expected_data = write_data(10)
  assert len(expected_data) == 10
  source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                      coders.StrUtf8Coder())
  source_test_utils.assert_reentrant_reads_succeed((source, None, None))
def _run_read_test(self,
                   file_or_pattern,
                   expected_data,
                   buffer_size=DEFAULT_NUM_RECORDS,
                   compression=CompressionTypes.UNCOMPRESSED):
  # Since each record usually takes more than 1 byte, the default buffer size
  # is smaller than the total size of the file. This is done to increase test
  # coverage for cases that hit the buffer boundary.
  source = TextSource(file_or_pattern, 0, compression, True,
                      coders.StrUtf8Coder(), buffer_size)
  range_tracker = source.get_range_tracker(None, None)
  read_data = list(source.read(range_tracker))
  self.assertCountEqual(expected_data, read_data)
def _read_skip_header_lines(self, file_or_pattern, skip_header_lines):
  """Simple wrapper function for instantiating TextSource."""
  source = TextSource(file_or_pattern, 0, CompressionTypes.UNCOMPRESSED,
                      True, coders.StrUtf8Coder(),
                      skip_header_lines=skip_header_lines)
  range_tracker = source.get_range_tracker(None, None)
  return list(source.read(range_tracker))
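# A hypothetical test showing how the two helpers above are meant to be
# called; the buffer size of 10 is an arbitrarily small value chosen to force
# reads across buffer boundaries, and the skip count of 2 is illustrative.
def test_helpers_example(self):
  file_name, expected_data = write_data(10)
  self._run_read_test(file_name, expected_data, buffer_size=10)
  read_data = self._read_skip_header_lines(file_name, skip_header_lines=2)
  self.assertCountEqual(expected_data[2:], read_data)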
def test_read_after_splitting(self):
  file_name, expected_data = write_data(10)
  assert len(expected_data) == 10
  source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                      coders.StrUtf8Coder())
  splits = list(source.split(desired_bundle_size=33))
  reference_source_info = (source, None, None)
  sources_info = ([
      (split.source, split.start_position, split.stop_position)
      for split in splits])
  source_test_utils.assert_sources_equal_reference_source(
      reference_source_info, sources_info)
def create_source(self,
                  file_pattern,
                  min_bundle_size=0,
                  compression_type=CompressionTypes.AUTO,
                  strip_trailing_newlines=True,
                  coder=coders.StrUtf8Coder(),
                  validate=True,
                  skip_header_lines=0):
  # Pass the caller-supplied coder through rather than constructing a fresh
  # StrUtf8Coder, so a non-default coder actually takes effect.
  return TextSource(file_pattern=file_pattern,
                    min_bundle_size=min_bundle_size,
                    compression_type=compression_type,
                    strip_trailing_newlines=strip_trailing_newlines,
                    coder=coder,
                    validate=validate,
                    skip_header_lines=skip_header_lines)
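# A hypothetical use of the factory above: read a file while skipping one
# header line, leaving the remaining defaults in place (the path and header
# count are illustrative only).
def test_create_source_example(self):
  source = self.create_source('/tmp/input.txt', skip_header_lines=1)
  range_tracker = source.get_range_tracker(None, None)
  records = list(source.read(range_tracker))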
def test_progress(self):
  file_name, expected_data = write_data(10)
  assert len(expected_data) == 10
  source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                      coders.StrUtf8Coder())
  splits = list(source.split(desired_bundle_size=100000))
  assert len(splits) == 1
  fraction_consumed_report = []
  range_tracker = splits[0].source.get_range_tracker(
      splits[0].start_position, splits[0].stop_position)
  for _ in splits[0].source.read(range_tracker):
    fraction_consumed_report.append(range_tracker.fraction_consumed())

  self.assertEqual([float(i) / 10 for i in range(0, 10)],
                   fraction_consumed_report)
def test_read_after_splitting_skip_header(self):
  file_name, expected_data = write_data(100)
  assert len(expected_data) == 100
  source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                      coders.StrUtf8Coder(), skip_header_lines=2)
  splits = list(source.split(desired_bundle_size=33))

  reference_source_info = (source, None, None)
  sources_info = ([
      (split.source, split.start_position, split.stop_position)
      for split in splits])
  self.assertGreater(len(sources_info), 1)
  reference_lines = source_test_utils.read_from_source(*reference_source_info)
  split_lines = []
  for source_info in sources_info:
    split_lines.extend(source_test_utils.read_from_source(*source_info))

  self.assertEqual(expected_data[2:], reference_lines)
  self.assertEqual(reference_lines, split_lines)
def test_read_gzip_large_after_splitting(self):
  _, lines = write_data(10000)
  file_name = self._create_temp_file()
  with gzip.GzipFile(file_name, 'wb') as f:
    # GzipFile is opened in binary mode, so encode the joined lines.
    f.write('\n'.join(lines).encode('utf-8'))

  source = TextSource(file_name, 0, CompressionTypes.GZIP, True,
                      coders.StrUtf8Coder())
  splits = list(source.split(desired_bundle_size=1000))

  if len(splits) > 1:
    raise ValueError('FileBasedSource generated more than one initial split '
                     'for a compressed file.')

  reference_source_info = (source, None, None)
  sources_info = ([
      (split.source, split.start_position, split.stop_position)
      for split in splits])
  source_test_utils.assert_sources_equal_reference_source(
      reference_source_info, sources_info)
def split(self, desired_bundle_size, start_position=None, stop_position=None):
  if self.split_result is None:
    bq = bigquery_tools.BigQueryWrapper()

    if self.query is not None:
      self._setup_temporary_dataset(bq)
      self.table_reference = self._execute_query(bq)

    schema, metadata_list = self._export_files(bq)
    self.split_result = [
        TextSource(metadata.path, 0, CompressionTypes.UNCOMPRESSED, True,
                   self.coder(schema)) for metadata in metadata_list
    ]

    if self.query is not None:
      bq.clean_up_temporary_dataset(self.project.get())

  for source in self.split_result:
    yield SourceBundle(0, source, None, None)
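# Sketch of how a caller might drive the split() generator above; bq_source
# and process are placeholders, not names from this module. Each yielded
# SourceBundle wraps a TextSource over one exported file with None bounds, so
# get_range_tracker(None, None) reads the whole file.
def consume(bq_source, process):
  for bundle in bq_source.split(desired_bundle_size=1 << 20):
    tracker = bundle.source.get_range_tracker(
        bundle.start_position, bundle.stop_position)
    for record in bundle.source.read(tracker):
      process(record)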
def __init__(self, file_name, range_tracker, file_pattern, compression_type,
             **kwargs):
  self._header_lines = []
  self._last_record = None
  self._file_name = file_name

  text_source = TextSource(
      file_pattern,
      0,  # min_bundle_size
      compression_type,
      True,  # strip_trailing_newlines
      coders.StrUtf8Coder(),  # coder
      validate=False,
      header_processor_fns=(lambda x: x.startswith('#'),
                            self._store_header_lines),
      **kwargs)

  self._text_lines = text_source.read_records(
      self._file_name, range_tracker)
  try:
    self._vcf_reader = vcf.Reader(fsock=self._create_generator())
  except SyntaxError as e:
    raise ValueError('Invalid VCF header %s' % str(e))