def create_source(self, file_pattern, min_bundle_size=0,
                  compression_type=CompressionTypes.AUTO,
                  strip_trailing_newlines=True,
                  coder=coders.StrUtf8Coder(), validate=True,
                  skip_header_lines=0):
    """Build a ``TextSource`` from the given read options.

    Args:
      file_pattern: forwarded to ``TextSource`` as ``file_pattern``.
      min_bundle_size: forwarded as ``min_bundle_size``.
      compression_type: forwarded as ``compression_type``.
      strip_trailing_newlines: forwarded as ``strip_trailing_newlines``.
      coder: coder handed to the source; defaults to a ``StrUtf8Coder``.
      validate: forwarded as ``validate``.
      skip_header_lines: forwarded as ``skip_header_lines``.

    Returns:
      The ``TextSource`` constructed from these arguments.
    """
    # Forward the caller-supplied coder instead of building a new default one,
    # so that a custom coder passed to this method actually takes effect.
    return TextSource(
        file_pattern=file_pattern,
        min_bundle_size=min_bundle_size,
        compression_type=compression_type,
        strip_trailing_newlines=strip_trailing_newlines,
        coder=coder,
        validate=validate,
        skip_header_lines=skip_header_lines)
def test_header_processing(self):
    """Header lines go to the header processor, the rest are read as records.

    The source is built with ``header_processor_fns`` whose matcher accepts the
    first five written lines; the test expects exactly those five lines to be
    handed to the storing callback and the remaining five to be read.
    """
    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10

    header_lines = []

    def is_header(line):
        return line in expected_data[:5]

    def collect_header(lines):
        header_lines.extend(lines)

    source = TextSource(
        file_name,
        0,
        CompressionTypes.UNCOMPRESSED,
        True,
        coders.StrUtf8Coder(),
        header_processor_fns=(is_header, collect_header))

    splits = list(source.split(desired_bundle_size=100000))
    assert len(splits) == 1

    only_split = splits[0]
    range_tracker = only_split.source.get_range_tracker(
        only_split.start_position, only_split.stop_position)
    read_data = list(source.read_records(file_name, range_tracker))

    self.assertCountEqual(expected_data[:5], header_lines)
    self.assertCountEqual(expected_data[5:], read_data)
def test_read_gzip_empty_file(self):
    """Reading a freshly created temp file as GZIP is expected to yield nothing."""
    temp_path = self._create_temp_file()

    pipeline = TestPipeline()
    read_transform = ReadFromText(
        temp_path, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
    records = pipeline | 'Read' >> read_transform

    assert_that(records, equal_to([]))
    pipeline.run()
def test_read_gzip_empty_file(self):
    """Reading a freshly created temp file as GZIP is expected to yield nothing.

    Both the temp directory and the pipeline are managed by ``with`` blocks.
    """
    with TempDir() as tempdir:
        temp_path = tempdir.create_temp_file()
        with TestPipeline() as pipeline:
            read_transform = ReadFromText(
                temp_path,
                0,
                CompressionTypes.GZIP,
                True,
                coders.StrUtf8Coder())
            records = pipeline | 'Read' >> read_transform
            assert_that(records, equal_to([]))
def test_progress(self):
    """Check the progress values reported while reading a 10-record file.

    After each record is produced, the range tracker's consumed fraction and
    split-point report are captured and compared with the expected sequences.
    """
    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10

    source = TextSource(
        file_name, 0, CompressionTypes.UNCOMPRESSED, True,
        coders.StrUtf8Coder())
    splits = list(source.split(desired_bundle_size=100000))
    assert len(splits) == 1

    only_split = splits[0]
    range_tracker = only_split.source.get_range_tracker(
        only_split.start_position, only_split.stop_position)

    fraction_consumed_report = []
    split_points_report = []
    for _ in only_split.source.read(range_tracker):
        fraction_consumed_report.append(range_tracker.fraction_consumed())
        split_points_report.append(range_tracker.split_points())

    expected_fractions = [i / 10.0 for i in range(10)]
    self.assertEqual(expected_fractions, fraction_consumed_report)

    unknown = iobase.RangeTracker.SPLIT_POINTS_UNKNOWN
    expected_split_points_report = [(done, unknown) for done in range(9)]
    # At the last split point the remaining-split-points callback returns 1,
    # because the expected position of the next record equals the stop
    # position.
    expected_split_points_report.append((9, 1))

    self.assertEqual(expected_split_points_report, split_points_report)
def __init__(
        self,
        file_pattern=None,
        min_bundle_size=0,
        compression_type=fileio.CompressionTypes.AUTO,
        strip_trailing_newlines=True,
        coder=coders.StrUtf8Coder(),
        validate=True,
        **kwargs):
    """Initialize the ReadFromText transform.

    Args:
      file_pattern: The file path to read from, as a local file path or a GCS
        gs:// path. The path can contain glob characters (*, ?, and [...]
        sets).
      min_bundle_size: Minimum size of bundles that should be generated when
        splitting this source into bundles. See ``FileBasedSource`` for more
        details.
      compression_type: Used to handle compressed input files. The typical
        value is CompressionTypes.AUTO, in which case the extension of the
        underlying file path is used to detect the compression.
      strip_trailing_newlines: Whether the source should remove the newline
        char from each line it reads before decoding that line.
      coder: Coder used to decode each line.
      validate: Flag to verify that the files exist at pipeline creation time.
      **kwargs: Passed through unchanged to the parent initializer.
    """
    super(ReadFromText, self).__init__(**kwargs)
    self._strip_trailing_newlines = strip_trailing_newlines
    text_source = _TextSource(
        file_pattern,
        min_bundle_size,
        compression_type,
        strip_trailing_newlines,
        coder,
        validate=validate)
    self._source = text_source
def test_read_gzip_empty_file(self):
    """Reading an empty file as GZIP is expected to yield no records.

    The temporary file is created with ``delete=False`` so that it can be
    reopened by path; it is therefore closed and removed explicitly here.
    """
    import os

    # Only the path is needed: close the handle right away instead of leaving
    # it to be reclaimed whenever the anonymous file object is collected.
    with tempfile.NamedTemporaryFile(
            delete=False, prefix=tempfile.template) as tmp:
        filename = tmp.name

    try:
        pipeline = TestPipeline()
        pcoll = pipeline | 'Read' >> ReadFromText(
            filename, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
        assert_that(pcoll, equal_to([]))
        pipeline.run()
    finally:
        # delete=False means nothing else will clean this file up.
        os.remove(filename)
def test_read_reentrant_after_splitting(self):
    """Run the reentrant-read check against the single split of a 10-line file."""
    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10

    source = TextSource(
        file_name, 0, CompressionTypes.UNCOMPRESSED, True,
        coders.StrUtf8Coder())
    splits = list(source.split(desired_bundle_size=100000))
    assert len(splits) == 1

    only_split = splits[0]
    source_info = (
        only_split.source, only_split.start_position,
        only_split.stop_position)
    source_test_utils.assert_reentrant_reads_succeed(source_info)
def test_dynamic_work_rebalancing(self):
    """Run the exhaustive split-at-fraction check on a 5-line file's single split."""
    file_name, expected_data = write_data(5)
    assert len(expected_data) == 5

    source = TextSource(
        file_name, 0, CompressionTypes.UNCOMPRESSED, True,
        coders.StrUtf8Coder())
    splits = list(source.split(desired_bundle_size=100000))
    assert len(splits) == 1

    only_split = splits[0]
    source_test_utils.assert_split_at_fraction_exhaustive(
        only_split.source,
        only_split.start_position,
        only_split.stop_position)
def __init__(self,
             topic,
             subscription=None,
             id_label=None,
             coder=coders.StrUtf8Coder()):
    """Store the constructor arguments on the instance.

    Args:
      topic: saved as ``self.topic``.
      subscription: saved as ``self.subscription``; defaults to None.
      id_label: saved as ``self.id_label``; defaults to None.
      coder: saved as ``self.coder``; defaults to a ``StrUtf8Coder``.
    """
    self.topic = topic
    self.subscription = subscription
    self.id_label = id_label
    self.coder = coder
def test_dynamic_work_rebalancing_windows_eol(self):
    """Exhaustive split-at-fraction check on a 15-line file written with CRLF.

    The multi-threaded part of the check is switched off.
    """
    file_name, expected_data = write_data(15, eol=EOL.CRLF)
    assert len(expected_data) == 15

    source = TextSource(
        file_name, 0, CompressionTypes.UNCOMPRESSED, True,
        coders.StrUtf8Coder())
    splits = list(source.split(desired_bundle_size=100000))
    assert len(splits) == 1

    only_split = splits[0]
    source_test_utils.assert_split_at_fraction_exhaustive(
        only_split.source,
        only_split.start_position,
        only_split.stop_position,
        perform_multi_threaded_test=False)
def test_dynamic_work_rebalancing_mixed_eol(self):
    """Exhaustive split-at-fraction check on a 5-line file written with mixed EOLs.

    The multi-threaded part of the check is switched off.
    """
    file_name, expected_data = write_data(5, eol=EOL.MIXED)
    assert len(expected_data) == 5

    source = TextSource(
        file_name, 0, CompressionTypes.UNCOMPRESSED, True,
        coders.StrUtf8Coder())
    splits = list(source.split(desired_bundle_size=100000))
    assert len(splits) == 1

    only_split = splits[0]
    source_test_utils.assertSplitAtFractionExhaustive(
        only_split.source,
        only_split.start_position,
        only_split.stop_position,
        perform_multi_threaded_test=False)
def test_read_reentrant_without_splitting(self):
    """Run the reentrant-read check on an unsplit source over a 10-line file."""
    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10

    source = TextSource(
        file_name,
        0,
        CompressionTypes.UNCOMPRESSED,
        True,
        coders.StrUtf8Coder())

    # No start/stop positions are given: the whole source is used.
    whole_source_info = (source, None, None)
    source_test_utils.assert_reentrant_reads_succeed(whole_source_info)
def test_read_single_file_without_striping_eol_crlf(self):
    """With newline stripping off, each CRLF-terminated line keeps its '\\r\\n'."""
    num_records = TextSourceTest.DEFAULT_NUM_RECORDS
    file_name, written_data = write_data(num_records, eol=EOL.CRLF)
    assert len(written_data) == TextSourceTest.DEFAULT_NUM_RECORDS

    # Fourth positional argument False: do not strip trailing newlines.
    source = TextSource(
        file_name, 0, CompressionTypes.UNCOMPRESSED, False,
        coders.StrUtf8Coder())
    range_tracker = source.get_range_tracker(None, None)
    read_data = list(source.read(range_tracker))

    expected_lines = [line + '\r\n' for line in written_data]
    self.assertCountEqual(expected_lines, read_data)
def test_read_gzip(self):
    """Lines written to a gzip file are expected back from ``ReadFromText``.

    Fifteen lines are joined with newlines, gzip-compressed into a temp file,
    and read through a pipeline with ``CompressionTypes.GZIP``.
    """
    _, lines = write_data(15)
    file_name = self._create_temp_file()

    # GzipFile opened in 'wb' mode accepts bytes only, so encode the text
    # payload explicitly before writing it.
    with gzip.GzipFile(file_name, 'wb') as f:
        f.write('\n'.join(lines).encode('utf-8'))

    pipeline = TestPipeline()
    pcoll = pipeline | 'Read' >> ReadFromText(
        file_name, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
    assert_that(pcoll, equal_to(lines))
    pipeline.run()
def test_read_gzip(self):
    """Lines written to a gzip file are expected back from ``ReadFromText``.

    The temp directory and the pipeline are both managed by ``with`` blocks.
    """
    _, lines = write_data(15)
    with TempDir() as tempdir:
        file_name = tempdir.create_temp_file()

        payload = '\n'.join(lines).encode('utf-8')
        with gzip.GzipFile(file_name, 'wb') as gz_file:
            gz_file.write(payload)

        with TestPipeline() as pipeline:
            read_transform = ReadFromText(
                file_name,
                0,
                CompressionTypes.GZIP,
                True,
                coders.StrUtf8Coder())
            records = pipeline | 'Read' >> read_transform
            assert_that(records, equal_to(lines))
def __init__(self,
             file_pattern,
             compression_type=fileio.CompressionTypes.AUTO,
             coder=coders.StrUtf8Coder(),
             validate=True):
    """Initialize a JsonLinesFileSource.

    Args:
      file_pattern: forwarded to the parent initializer.
      compression_type: forwarded to the parent initializer.
      coder: stored on the instance as ``self._coder``.
      validate: forwarded to the parent initializer.

    The parent is always initialized with ``min_bundle_size=0`` and
    ``splittable=False``.
    """
    # NOTE(review): super(self.__class__, self) recurses forever if this class
    # is ever subclassed; naming the class explicitly would avoid that. Left
    # as-is because the class name is not visible in code here.
    super(self.__class__, self).__init__(
        file_pattern,
        min_bundle_size=0,
        compression_type=compression_type,
        validate=validate,
        splittable=False)
    self._coder = coder
def _read_skip_header_lines(self, file_or_pattern, skip_header_lines):
    """Read every record of an uncompressed TextSource, skipping header lines.

    Args:
      file_or_pattern: passed to ``TextSource`` as the file or pattern to read.
      skip_header_lines: passed to ``TextSource`` as ``skip_header_lines``.

    Returns:
      A list with everything the source yields over its full range.
    """
    source = TextSource(
        file_or_pattern,
        0,
        CompressionTypes.UNCOMPRESSED,
        True,
        coders.StrUtf8Coder(),
        skip_header_lines=skip_header_lines)
    full_range_tracker = source.get_range_tracker(None, None)
    records = source.read(full_range_tracker)
    return list(records)
def _run_read_test(self,
                   file_or_pattern,
                   expected_data,
                   buffer_size=DEFAULT_NUM_RECORDS,
                   compression=CompressionTypes.UNCOMPRESSED):
    """Read a source over its full range and compare with ``expected_data``.

    Args:
      file_or_pattern: file or pattern handed to ``TextSource``.
      expected_data: records the read is expected to produce (order ignored).
      buffer_size: sixth positional argument given to ``TextSource``.
      compression: compression type given to ``TextSource``.
    """
    # A record usually takes more than one byte, so the default buffer size is
    # smaller than the file. That is deliberate: it increases coverage of
    # reads that hit the buffer boundary.
    source = TextSource(
        file_or_pattern,
        0,
        compression,
        True,
        coders.StrUtf8Coder(),
        buffer_size)
    full_range_tracker = source.get_range_tracker(None, None)
    read_data = list(source.read(full_range_tracker))
    self.assertCountEqual(expected_data, read_data)
def test_read_after_splitting(self):
    """Splits of a 10-line file are compared against the unsplit source."""
    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10

    source = TextSource(
        file_name, 0, CompressionTypes.UNCOMPRESSED, True,
        coders.StrUtf8Coder())
    splits = list(source.split(desired_bundle_size=33))

    reference_source_info = (source, None, None)
    sources_info = [
        (piece.source, piece.start_position, piece.stop_position)
        for piece in splits
    ]
    source_test_utils.assertSourcesEqualReferenceSource(
        reference_source_info, sources_info)
def test_read_gzip_large(self):
    """10000 lines written to a gzip file are expected back from ``ReadFromText``.

    The temporary file is created with ``delete=False`` so that it can be
    reopened by path; it is therefore closed and removed explicitly here.
    """
    import os

    _, lines = write_data(10000)

    # Only the path is needed: close the handle before gzip reopens the file.
    with tempfile.NamedTemporaryFile(
            delete=False, prefix=tempfile.template) as tmp:
        file_name = tmp.name

    try:
        # GzipFile opened in 'wb' mode accepts bytes only, so encode the text
        # payload explicitly before writing it.
        with gzip.GzipFile(file_name, 'wb') as f:
            f.write('\n'.join(lines).encode('utf-8'))

        pipeline = TestPipeline()
        pcoll = pipeline | 'Read' >> ReadFromText(
            file_name, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
        assert_that(pcoll, equal_to(lines))
        pipeline.run()
    finally:
        # delete=False means nothing else will clean this file up.
        os.remove(file_name)
def test_read_deflate(self):
    """Lines zlib-compressed into a file are expected back from ``ReadFromText``.

    The pipeline is run while the temp directory is still open.
    """
    _, lines = write_data(15)
    with TempDir() as tempdir:
        file_name = tempdir.create_temp_file()

        compressed = zlib.compress('\n'.join(lines).encode('utf-8'))
        with open(file_name, 'wb') as out:
            out.write(compressed)

        pipeline = TestPipeline()
        read_transform = ReadFromText(
            file_name,
            0,
            CompressionTypes.DEFLATE,
            True,
            coders.StrUtf8Coder())
        records = pipeline | 'Read' >> read_transform
        assert_that(records, equal_to(lines))
        pipeline.run()
def test_read_gzip_with_skip_lines(self):
    """Reading a gzip file with ``skip_header_lines=2`` drops the first two lines.

    Fifteen lines are gzip-compressed into a temp file; the pipeline output is
    expected to equal ``lines[2:]``.
    """
    _, lines = write_data(15)
    with TempDir() as tempdir:
        file_name = tempdir.create_temp_file()

        # GzipFile opened in 'wb' mode accepts bytes only, so encode the text
        # payload explicitly before writing it.
        with gzip.GzipFile(file_name, 'wb') as f:
            f.write('\n'.join(lines).encode('utf-8'))

        pipeline = TestPipeline()
        pcoll = pipeline | 'Read' >> ReadFromText(
            file_name, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder(),
            skip_header_lines=2)
        assert_that(pcoll, equal_to(lines[2:]))
        pipeline.run()
def test_progress(self):
    """Check the consumed fraction reported while reading a 10-record file.

    The fraction captured after each record is produced is expected to be
    0.0, 0.1, ..., 0.9.
    """
    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10

    source = TextSource(
        file_name, 0, CompressionTypes.UNCOMPRESSED, True,
        coders.StrUtf8Coder())
    splits = list(source.split(desired_bundle_size=100000))
    assert len(splits) == 1

    only_split = splits[0]
    range_tracker = only_split.source.get_range_tracker(
        only_split.start_position, only_split.stop_position)

    fraction_consumed_report = []
    for _ in only_split.source.read(range_tracker):
        fraction_consumed_report.append(range_tracker.fraction_consumed())

    expected_fractions = [i / 10.0 for i in range(10)]
    self.assertEqual(expected_fractions, fraction_consumed_report)
def test_read_corrupted_deflate_fails(self):
    """Reading a file of junk bytes as DEFLATE is expected to raise.

    A valid zlib payload is written first and then overwritten with the bytes
    ``b'corrupt'``; running the pipeline must raise an exception.
    """
    _, lines = write_data(15)
    with TempDir() as tempdir:
        file_name = tempdir.create_temp_file()

        valid_payload = zlib.compress('\n'.join(lines).encode('utf-8'))
        with open(file_name, 'wb') as out:
            out.write(valid_payload)

        # Reopening in 'wb' truncates the file, replacing the valid payload.
        with open(file_name, 'wb') as out:
            out.write(b'corrupt')

        with self.assertRaises(Exception):
            with TestPipeline() as pipeline:
                read_transform = ReadFromText(
                    file_name,
                    0,
                    CompressionTypes.DEFLATE,
                    True,
                    coders.StrUtf8Coder())
                records = pipeline | 'Read' >> read_transform
                assert_that(records, equal_to(lines))
def test_read_corrupted_gzip_fails(self):
    """Reading a file of junk bytes as GZIP is expected to raise.

    A valid gzip payload is written first and then overwritten with the bytes
    ``b'corrupt'``; running the pipeline must raise an exception.
    """
    _, lines = write_data(15)
    with TempDir() as tempdir:
        file_name = tempdir.create_temp_file()

        # Join as text and encode once: a bytes separator cannot join text
        # lines, and the binary gzip stream accepts bytes only.
        with gzip.GzipFile(file_name, 'wb') as f:
            f.write('\n'.join(lines).encode('utf-8'))

        # Reopening in 'wb' truncates the file; the junk must be bytes because
        # the file is opened in binary mode.
        with open(file_name, 'wb') as f:
            f.write(b'corrupt')

        pipeline = TestPipeline()
        pcoll = pipeline | 'Read' >> ReadFromText(
            file_name, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
        assert_that(pcoll, equal_to(lines))
        with self.assertRaises(Exception):
            pipeline.run()
def test_read_after_splitting_skip_header(self):
    """Splitting a source that skips two header lines must not change its output.

    The unsplit source is expected to return ``expected_data[2:]``, and reading
    the splits in order is expected to return the same lines.
    """
    file_name, expected_data = write_data(100)
    assert len(expected_data) == 100

    source = TextSource(
        file_name,
        0,
        CompressionTypes.UNCOMPRESSED,
        True,
        coders.StrUtf8Coder(),
        skip_header_lines=2)
    splits = list(source.split(desired_bundle_size=33))

    sources_info = [
        (piece.source, piece.start_position, piece.stop_position)
        for piece in splits
    ]
    self.assertGreater(len(sources_info), 1)

    reference_lines = source_test_utils.read_from_source(source, None, None)
    split_lines = [
        line
        for info in sources_info
        for line in source_test_utils.read_from_source(*info)
    ]

    self.assertEqual(expected_data[2:], reference_lines)
    self.assertEqual(reference_lines, split_lines)
def test_read_gzip_large_after_splitting(self):
    """A gzip source must produce a single split that matches the whole source.

    10000 lines are gzip-compressed into a temp file. The source is split with
    a small desired bundle size; more than one split is treated as an error,
    and the split(s) are compared against the unsplit source.

    Raises:
      ValueError: if splitting the compressed file yields more than one split.
    """
    _, lines = write_data(10000)
    file_name = self._create_temp_file()

    # GzipFile opened in 'wb' mode accepts bytes only, so encode the text
    # payload explicitly before writing it.
    with gzip.GzipFile(file_name, 'wb') as f:
        f.write('\n'.join(lines).encode('utf-8'))

    source = TextSource(
        file_name, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
    splits = list(source.split(desired_bundle_size=1000))

    if len(splits) > 1:
        raise ValueError(
            'FileBasedSource generated more than one initial split '
            'for a compressed file.')

    reference_source_info = (source, None, None)
    sources_info = [
        (split.source, split.start_position, split.stop_position)
        for split in splits
    ]
    source_test_utils.assert_sources_equal_reference_source(
        reference_source_info, sources_info)
def __init__(self,
             file_patterns,
             min_bundle_size=0,
             compression_type=CompressionTypes.AUTO,
             strip_trailing_newlines=True,
             coder=coders.StrUtf8Coder(),
             validate=True,
             skip_header_lines=0,
             **kwargs):
    """Initialize the ReadFromMultiFilesText transform.

    Args:
      file_patterns: The file paths/patterns to read from, as local file paths
        or GCS files. Paths/patterns are separated by commas.
      min_bundle_size: Minimum size of bundles that should be generated when
        splitting this source into bundles. See ``FileBasedSource`` for more
        details.
      compression_type: Used to handle compressed input files. The typical
        value is CompressionTypes.AUTO, in which case the extension of the
        underlying file path is used to detect the compression.
      strip_trailing_newlines: Whether the source should remove the newline
        char from each line it reads before decoding that line.
      coder: Coder used to decode each line.
      validate: Flag to verify that the files exist at pipeline creation time.
      skip_header_lines: Number of header lines to skip. The same number is
        skipped from each source file. Must be 0 or higher. A large number of
        skipped lines might impact performance.
      **kwargs: Optional args dictionary, passed to the parent initializer.
    """
    super(ReadFromMultiFilesText, self).__init__(**kwargs)
    multi_source = _MultiTextSource(
        file_patterns,
        min_bundle_size=min_bundle_size,
        compression_type=compression_type,
        strip_trailing_newlines=strip_trailing_newlines,
        coder=coder,
        validate=validate,
        skip_header_lines=skip_header_lines)
    self._source = multi_source
def __init__(self, file_to_read, compression_type):
    """Store the file to read and its compression type.

    Args:
      file_to_read: saved as ``self.file_to_read``.
      compression_type: saved as ``self.compression_type``.

    ``self.coder`` is always set to a new ``StrUtf8Coder``.
    """
    self.file_to_read = file_to_read
    self.compression_type = compression_type
    utf8_coder = coders.StrUtf8Coder()
    self.coder = utf8_coder