Example #1
 def create_source(self,
                   file_pattern,
                   min_bundle_size=0,
                   compression_type=CompressionTypes.AUTO,
                   strip_trailing_newlines=True,
                   coder=coders.StrUtf8Coder(),
                   validate=True,
                   skip_header_lines=0):
     return TextSource(file_pattern=file_pattern,
                       min_bundle_size=min_bundle_size,
                       compression_type=compression_type,
                       strip_trailing_newlines=strip_trailing_newlines,
                       coder=coder,
                       validate=validate,
                       skip_header_lines=skip_header_lines)
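The factory above simply forwards its arguments to TextSource. As a rough sketch, the resulting source can also be read directly through a range tracker, mirroring the test patterns in the examples below; the import paths and the sample.txt file are assumptions based on the Apache Beam Python SDK, where TextSource is exposed internally as _TextSource:

from apache_beam import coders
from apache_beam.io.filesystem import CompressionTypes
from apache_beam.io.textio import _TextSource as TextSource

# Build a source the same way create_source does, then read the whole file
# through a range tracker spanning the full range.
source = TextSource('sample.txt', 0, CompressionTypes.AUTO,
                    True, coders.StrUtf8Coder())
range_tracker = source.get_range_tracker(None, None)
for line in source.read(range_tracker):
    print(line)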
Example #2
    def test_header_processing(self):
        file_name, expected_data = write_data(10)
        assert len(expected_data) == 10

        def header_matcher(line):
            return line in expected_data[:5]

        header_lines = []

        def store_header(lines):
            for line in lines:
                header_lines.append(line)

        source = TextSource(file_name,
                            0,
                            CompressionTypes.UNCOMPRESSED,
                            True,
                            coders.StrUtf8Coder(),
                            header_processor_fns=(header_matcher,
                                                  store_header))
        splits = list(source.split(desired_bundle_size=100000))
        assert len(splits) == 1
        range_tracker = splits[0].source.get_range_tracker(
            splits[0].start_position, splits[0].stop_position)
        read_data = list(source.read_records(file_name, range_tracker))

        self.assertCountEqual(expected_data[:5], header_lines)
        self.assertCountEqual(expected_data[5:], read_data)
Example #3
 def test_read_gzip_empty_file(self):
     file_name = self._create_temp_file()
     pipeline = TestPipeline()
     pcoll = pipeline | 'Read' >> ReadFromText(
         file_name, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
     assert_that(pcoll, equal_to([]))
     pipeline.run()
Example #4
 def test_read_gzip_empty_file(self):
   with TempDir() as tempdir:
     file_name = tempdir.create_temp_file()
     with TestPipeline() as pipeline:
       pcoll = pipeline | 'Read' >> ReadFromText(
           file_name, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
       assert_that(pcoll, equal_to([]))
Example #5
  def test_progress(self):
    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10
    source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                        coders.StrUtf8Coder())
    splits = list(source.split(desired_bundle_size=100000))
    assert len(splits) == 1
    fraction_consumed_report = []
    split_points_report = []
    range_tracker = splits[0].source.get_range_tracker(
        splits[0].start_position, splits[0].stop_position)
    for _ in splits[0].source.read(range_tracker):
      fraction_consumed_report.append(range_tracker.fraction_consumed())
      split_points_report.append(range_tracker.split_points())

    self.assertEqual(
        [float(i) / 10 for i in range(0, 10)], fraction_consumed_report)
    expected_split_points_report = [
        ((i - 1), iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
        for i in range(1, 10)]

    # At the last split point, the split points callback reports 1 remaining
    # split point, since the expected position of the next record equals the
    # stop position.
    expected_split_points_report.append((9, 1))

    self.assertEqual(
        expected_split_points_report, split_points_report)
Example #6
  def __init__(
      self,
      file_pattern=None,
      min_bundle_size=0,
      compression_type=fileio.CompressionTypes.AUTO,
      strip_trailing_newlines=True,
      coder=coders.StrUtf8Coder(),
      validate=True,
      **kwargs):
    """Initialize the ReadFromText transform.

    Args:
      file_pattern: The file path to read from as a local file path or a GCS
        gs:// path. The path can contain glob characters (*, ?, and [...]
        sets).
      min_bundle_size: Minimum size of bundles that should be generated when
                       splitting this source into bundles. See
                       ``FileBasedSource`` for more details.
      compression_type: Used to handle compressed input files. Typical value
          is CompressionTypes.AUTO, in which case the underlying file_path's
          extension will be used to detect the compression.
      strip_trailing_newlines: Indicates whether this source should remove
                               the newline char in each line it reads before
                               decoding that line.
      validate: flag to verify that the files exist during the pipeline
                creation time.
      coder: Coder used to decode each line.
    """

    super(ReadFromText, self).__init__(**kwargs)
    self._strip_trailing_newlines = strip_trailing_newlines
    self._source = _TextSource(file_pattern, min_bundle_size, compression_type,
                               strip_trailing_newlines, coder,
                               validate=validate)
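A minimal pipeline sketch for the transform documented above, with the documented arguments spelled out as keywords; the glob path is hypothetical:

import apache_beam as beam
from apache_beam import coders
from apache_beam.io import ReadFromText

with beam.Pipeline() as p:
    lines = p | 'Read' >> ReadFromText(
        'gs://my-bucket/input-*.txt',  # hypothetical path with a glob
        min_bundle_size=0,
        strip_trailing_newlines=True,
        coder=coders.StrUtf8Coder(),
        validate=True)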
Example #7
 def test_read_gzip_empty_file(self):
     filename = tempfile.NamedTemporaryFile(delete=False,
                                            prefix=tempfile.template).name
     pipeline = TestPipeline()
     pcoll = pipeline | 'Read' >> ReadFromText(
         filename, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
     assert_that(pcoll, equal_to([]))
     pipeline.run()
Example #8
 def test_read_reentrant_after_splitting(self):
   file_name, expected_data = write_data(10)
   assert len(expected_data) == 10
   source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                       coders.StrUtf8Coder())
   splits = list(source.split(desired_bundle_size=100000))
   assert len(splits) == 1
   source_test_utils.assert_reentrant_reads_succeed(
       (splits[0].source, splits[0].start_position, splits[0].stop_position))
Example #9
 def test_dynamic_work_rebalancing(self):
   file_name, expected_data = write_data(5)
   assert len(expected_data) == 5
   source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                       coders.StrUtf8Coder())
   splits = list(source.split(desired_bundle_size=100000))
   assert len(splits) == 1
   source_test_utils.assert_split_at_fraction_exhaustive(
       splits[0].source, splits[0].start_position, splits[0].stop_position)
Example #10
 def __init__(self,
              topic,
              subscription=None,
              id_label=None,
              coder=coders.StrUtf8Coder()):
     self.topic = topic
     self.subscription = subscription
     self.id_label = id_label
     self.coder = coder
Example #11
 def test_dynamic_work_rebalancing_windows_eol(self):
   file_name, expected_data = write_data(15, eol=EOL.CRLF)
   assert len(expected_data) == 15
   source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                       coders.StrUtf8Coder())
   splits = list(source.split(desired_bundle_size=100000))
   assert len(splits) == 1
   source_test_utils.assert_split_at_fraction_exhaustive(
       splits[0].source, splits[0].start_position, splits[0].stop_position,
       perform_multi_threaded_test=False)
Example #12
 def test_dynamic_work_rebalancing_mixed_eol(self):
   file_name, expected_data = write_data(5, eol=EOL.MIXED)
   assert len(expected_data) == 5
   source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                       coders.StrUtf8Coder())
   splits = list(source.split(desired_bundle_size=100000))
   assert len(splits) == 1
   source_test_utils.assert_split_at_fraction_exhaustive(
       splits[0].source, splits[0].start_position, splits[0].stop_position,
       perform_multi_threaded_test=False)
Example #13
 def test_read_reentrant_without_splitting(self):
   file_name, expected_data = write_data(10)
   assert len(expected_data) == 10
   source = TextSource(
       file_name,
       0,
       CompressionTypes.UNCOMPRESSED,
       True,
       coders.StrUtf8Coder())
   source_test_utils.assert_reentrant_reads_succeed((source, None, None))
Example #14
  def test_read_single_file_without_stripping_eol_crlf(self):
    file_name, written_data = write_data(TextSourceTest.DEFAULT_NUM_RECORDS,
                                         eol=EOL.CRLF)
    assert len(written_data) == TextSourceTest.DEFAULT_NUM_RECORDS
    source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED,
                        False, coders.StrUtf8Coder())

    range_tracker = source.get_range_tracker(None, None)
    read_data = list(source.read(range_tracker))
    self.assertCountEqual([line + '\r\n' for line in written_data], read_data)
Example #15
    def test_read_gzip(self):
        _, lines = write_data(15)
        file_name = self._create_temp_file()
        with gzip.GzipFile(file_name, 'wb') as f:
            f.write('\n'.join(lines).encode('utf-8'))

        pipeline = TestPipeline()
        pcoll = pipeline | 'Read' >> ReadFromText(
            file_name, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
        assert_that(pcoll, equal_to(lines))
        pipeline.run()
Example #16
  def test_read_gzip(self):
    _, lines = write_data(15)
    with TempDir() as tempdir:
      file_name = tempdir.create_temp_file()
      with gzip.GzipFile(file_name, 'wb') as f:
        f.write('\n'.join(lines).encode('utf-8'))

      with TestPipeline() as pipeline:
        pcoll = pipeline | 'Read' >> ReadFromText(
            file_name, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
        assert_that(pcoll, equal_to(lines))
Example #17
  def __init__(self, file_pattern,
               compression_type=fileio.CompressionTypes.AUTO,
               coder=coders.StrUtf8Coder(), validate=True):
    """ Initialize a JsonLinesFileSource.
    """

    super(self.__class__, self).__init__(file_pattern, min_bundle_size=0,
                                         compression_type=compression_type,
                                         validate=validate,
                                         splittable=False)
    self._coder = coder
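JsonLinesFileSource appears to be a user-defined FileBasedSource rather than part of the Beam SDK, and this example omits its read logic. A minimal sketch of how such a non-splittable source is typically completed, assuming one JSON object per line (the read_records body here is hypothetical):

import json

from apache_beam import coders
from apache_beam.io import filebasedsource

class JsonLinesFileSource(filebasedsource.FileBasedSource):
    def __init__(self, file_pattern, coder=coders.StrUtf8Coder()):
        super(JsonLinesFileSource, self).__init__(
            file_pattern, min_bundle_size=0, splittable=False)
        self._coder = coder

    def read_records(self, file_name, range_tracker):
        # Hypothetical: decode each raw line with the configured coder and
        # parse it as a single JSON object.
        with self.open_file(file_name) as f:
            for line in f:
                yield json.loads(self._coder.decode(line.rstrip(b'\r\n')))

It would then be read with beam.io.Read(JsonLinesFileSource('data.jsonl')), where the path is again hypothetical.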
Example #18
    def _read_skip_header_lines(self, file_or_pattern, skip_header_lines):
        """Simple wrapper function for instantiating TextSource."""
        source = TextSource(file_or_pattern,
                            0,
                            CompressionTypes.UNCOMPRESSED,
                            True,
                            coders.StrUtf8Coder(),
                            skip_header_lines=skip_header_lines)

        range_tracker = source.get_range_tracker(None, None)
        return list(source.read(range_tracker))
Example #19
 def _run_read_test(self, file_or_pattern, expected_data,
                    buffer_size=DEFAULT_NUM_RECORDS,
                    compression=CompressionTypes.UNCOMPRESSED):
   # Since each record usually takes more than one byte, the default buffer
   # size is smaller than the total size of the file. This increases test
   # coverage for cases that hit the buffer boundary.
   source = TextSource(file_or_pattern, 0, compression,
                       True, coders.StrUtf8Coder(), buffer_size)
   range_tracker = source.get_range_tracker(None, None)
   read_data = list(source.read(range_tracker))
   self.assertCountEqual(expected_data, read_data)
Example #20
    def test_read_after_splitting(self):
        file_name, expected_data = write_data(10)
        assert len(expected_data) == 10
        source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                            coders.StrUtf8Coder())
        splits = list(source.split(desired_bundle_size=33))

        reference_source_info = (source, None, None)
        sources_info = ([(split.source, split.start_position,
                          split.stop_position) for split in splits])
        source_test_utils.assert_sources_equal_reference_source(
            reference_source_info, sources_info)
Example #21
    def test_read_gzip_large(self):
        _, lines = write_data(10000)
        file_name = tempfile.NamedTemporaryFile(delete=False,
                                                prefix=tempfile.template).name
        with gzip.GzipFile(file_name, 'wb') as f:
            f.write('\n'.join(lines).encode('utf-8'))

        pipeline = TestPipeline()
        pcoll = pipeline | 'Read' >> ReadFromText(
            file_name, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
        assert_that(pcoll, equal_to(lines))
        pipeline.run()
Example #22
    def test_read_deflate(self):
        _, lines = write_data(15)
        with TempDir() as tempdir:
            file_name = tempdir.create_temp_file()
            with open(file_name, 'wb') as f:
                f.write(zlib.compress('\n'.join(lines).encode('utf-8')))

            pipeline = TestPipeline()
            pcoll = pipeline | 'Read' >> ReadFromText(
                file_name, 0, CompressionTypes.DEFLATE, True,
                coders.StrUtf8Coder())
            assert_that(pcoll, equal_to(lines))
            pipeline.run()
Example #23
  def test_read_gzip_with_skip_lines(self):
    _, lines = write_data(15)
    with TempDir() as tempdir:
      file_name = tempdir.create_temp_file()
      with gzip.GzipFile(file_name, 'wb') as f:
        f.write('\n'.join(lines).encode('utf-8'))

      pipeline = TestPipeline()
      pcoll = pipeline | 'Read' >> ReadFromText(
          file_name, 0, CompressionTypes.GZIP,
          True, coders.StrUtf8Coder(), skip_header_lines=2)
      assert_that(pcoll, equal_to(lines[2:]))
      pipeline.run()
Example #24
    def test_progress(self):
        file_name, expected_data = write_data(10)
        assert len(expected_data) == 10
        source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                            coders.StrUtf8Coder())
        splits = list(source.split(desired_bundle_size=100000))
        assert len(splits) == 1
        fraction_consumed_report = []
        range_tracker = splits[0].source.get_range_tracker(
            splits[0].start_position, splits[0].stop_position)
        for _ in splits[0].source.read(range_tracker):
            fraction_consumed_report.append(range_tracker.fraction_consumed())

        self.assertEqual([float(i) / 10 for i in range(0, 10)],
                         fraction_consumed_report)
Example #25
    def test_read_corrupted_deflate_fails(self):
        _, lines = write_data(15)
        with TempDir() as tempdir:
            file_name = tempdir.create_temp_file()
            with open(file_name, 'wb') as f:
                f.write(zlib.compress('\n'.join(lines).encode('utf-8')))

            with open(file_name, 'wb') as f:
                f.write(b'corrupt')

            with self.assertRaises(Exception):
                with TestPipeline() as pipeline:
                    pcoll = pipeline | 'Read' >> ReadFromText(
                        file_name, 0, CompressionTypes.DEFLATE, True,
                        coders.StrUtf8Coder())
                    assert_that(pcoll, equal_to(lines))
Example #26
    def test_read_corrupted_gzip_fails(self):
        _, lines = write_data(15)
        with TempDir() as tempdir:
            file_name = tempdir.create_temp_file()
            with gzip.GzipFile(file_name, 'wb') as f:
                f.write('\n'.join(lines).encode('utf-8'))

            with open(file_name, 'wb') as f:
                f.write(b'corrupt')

            pipeline = TestPipeline()
            pcoll = pipeline | 'Read' >> ReadFromText(
                file_name, 0, CompressionTypes.GZIP, True,
                coders.StrUtf8Coder())
            assert_that(pcoll, equal_to(lines))

            with self.assertRaises(Exception):
                pipeline.run()
Example #27
  def test_read_after_splitting_skip_header(self):
    file_name, expected_data = write_data(100)
    assert len(expected_data) == 100
    source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                        coders.StrUtf8Coder(), skip_header_lines=2)
    splits = list(source.split(desired_bundle_size=33))

    reference_source_info = (source, None, None)
    sources_info = ([
        (split.source, split.start_position, split.stop_position) for
        split in splits])
    self.assertGreater(len(sources_info), 1)
    reference_lines = source_test_utils.read_from_source(*reference_source_info)
    split_lines = []
    for source_info in sources_info:
      split_lines.extend(source_test_utils.read_from_source(*source_info))

    self.assertEqual(expected_data[2:], reference_lines)
    self.assertEqual(reference_lines, split_lines)
Example #28
    def test_read_gzip_large_after_splitting(self):
        _, lines = write_data(10000)
        file_name = self._create_temp_file()
        with gzip.GzipFile(file_name, 'wb') as f:
            f.write('\n'.join(lines).encode('utf-8'))

        source = TextSource(file_name, 0, CompressionTypes.GZIP, True,
                            coders.StrUtf8Coder())
        splits = list(source.split(desired_bundle_size=1000))

        if len(splits) > 1:
            raise ValueError(
                'FileBasedSource generated more than one initial split '
                'for a compressed file.')

        reference_source_info = (source, None, None)
        sources_info = ([(split.source, split.start_position,
                          split.stop_position) for split in splits])
        source_test_utils.assert_sources_equal_reference_source(
            reference_source_info, sources_info)
Example #29
    def __init__(self,
                 file_patterns,
                 min_bundle_size=0,
                 compression_type=CompressionTypes.AUTO,
                 strip_trailing_newlines=True,
                 coder=coders.StrUtf8Coder(),
                 validate=True,
                 skip_header_lines=0,
                 **kwargs):
        """Initialize the ReadFromText transform.

    Args:
      file_patterns: The file paths/patterns to read from as local file paths
      or GCS files. Paths/patterns seperated by commas.
      min_bundle_size: Minimum size of bundles that should be generated when
        splitting this source into bundles. See ``FileBasedSource`` for more
        details.
      compression_type: Used to handle compressed input files. Typical value
        is CompressionTypes.AUTO, in which case the underlying file_path's
        extension will be used to detect the compression.
      strip_trailing_newlines: Indicates whether this source should remove
        the newline char in each line it reads before decoding that line.
      coder: Coder used to decode each line.
      validate: flag to verify that the files exist during the pipeline
        creation time.
      skip_header_lines: Number of header lines to skip. Same number is skipped
        from each source file. Must be 0 or higher. Large number of skipped
        lines might impact performance.
       **kwargs: optional args dictionary.
    """

        super(ReadFromMultiFilesText, self).__init__(**kwargs)
        self._source = _MultiTextSource(
            file_patterns,
            min_bundle_size=min_bundle_size,
            compression_type=compression_type,
            strip_trailing_newlines=strip_trailing_newlines,
            coder=coder,
            validate=validate,
            skip_header_lines=skip_header_lines)
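ReadFromMultiFilesText and _MultiTextSource belong to this example's own codebase rather than the Beam SDK; under that assumption, a usage sketch with hypothetical comma-separated patterns:

import apache_beam as beam

with beam.Pipeline() as p:
    lines = p | 'ReadMulti' >> ReadFromMultiFilesText(
        'gs://bucket-a/logs-*.txt,gs://bucket-b/logs-*.txt',  # hypothetical
        skip_header_lines=1)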
Example #30
 def __init__(self, file_to_read, compression_type):
     self.file_to_read = file_to_read
     self.compression_type = compression_type
     self.coder = coders.StrUtf8Coder()