Пример #1
0
    def test_header_processing(self):
        """Lines matching the header predicate are routed to the header
        callback, and only the remaining records come back from the read."""
        file_name, expected_data = write_data(10)
        assert len(expected_data) == 10

        captured_headers = []

        def is_header(line):
            # A line counts as a header iff it is one of the first five records.
            return line in expected_data[:5]

        def collect_headers(lines):
            captured_headers.extend(lines)

        source = TextSource(file_name,
                            0,
                            CompressionTypes.UNCOMPRESSED,
                            True,
                            coders.StrUtf8Coder(),
                            header_processor_fns=(is_header, collect_headers))
        splits = list(source.split(desired_bundle_size=100000))
        assert len(splits) == 1
        tracker = splits[0].source.get_range_tracker(
            splits[0].start_position, splits[0].stop_position)
        read_data = list(source.read_records(file_name, tracker))

        self.assertCountEqual(expected_data[:5], captured_headers)
        self.assertCountEqual(expected_data[5:], read_data)
Пример #2
0
  def test_progress(self):
    """Verifies fraction_consumed and split_points reported while reading."""
    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10
    source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                        coders.StrUtf8Coder())
    splits = list(source.split(desired_bundle_size=100000))
    assert len(splits) == 1
    tracker = splits[0].source.get_range_tracker(
        splits[0].start_position, splits[0].stop_position)
    fractions = []
    points = []
    for _ in splits[0].source.read(tracker):
      fractions.append(tracker.fraction_consumed())
      points.append(tracker.split_points())

    self.assertEqual([i / 10.0 for i in range(10)], fractions)

    expected_points = [
        (i, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN) for i in range(9)]
    # At the last record the remaining-split-points callback returns 1: the
    # expected position of the next record equals the stop position.
    expected_points.append((9, 1))

    self.assertEqual(expected_points, points)
Пример #3
0
  def test_progress(self):
    """Checks the progress values the range tracker reports during a read."""
    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10
    source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                        coders.StrUtf8Coder())
    splits = list(source.split(desired_bundle_size=100000))
    assert len(splits) == 1
    consumed = []
    points = []
    tracker = splits[0].source.get_range_tracker(
        splits[0].start_position, splits[0].stop_position)
    for _ in splits[0].source.read(tracker):
      consumed.append(tracker.fraction_consumed())
      points.append(tracker.split_points())

    self.assertEqual([i / 10.0 for i in range(10)], consumed)
    expected_points = [(i - 1, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
                       for i in range(1, 10)]
    # The final record reports exactly one remaining split point since the
    # expected position of the next record equals the stop position.
    expected_points.append((9, 1))
    self.assertEqual(expected_points, points)
Пример #4
0
 def test_read_reentrant_after_splitting(self):
   """Reading from the post-split source must be reentrant."""
   file_name, expected_data = write_data(10)
   assert len(expected_data) == 10
   source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                       coders.StrUtf8Coder())
   bundles = list(source.split(desired_bundle_size=100000))
   assert len(bundles) == 1
   only = bundles[0]
   source_test_utils.assert_reentrant_reads_succeed(
       (only.source, only.start_position, only.stop_position))
Пример #5
0
 def test_read_reentrant_after_splitting(self):
   """The single split produced by a large bundle size reads reentrantly."""
   file_name, expected_data = write_data(10)
   assert len(expected_data) == 10
   source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                       coders.StrUtf8Coder())
   all_splits = list(source.split(desired_bundle_size=100000))
   assert len(all_splits) == 1
   split = all_splits[0]
   source_test_utils.assert_reentrant_reads_succeed(
       (split.source, split.start_position, split.stop_position))
Пример #6
0
 def test_dynamic_work_rebalancing(self):
   """Exhaustively checks dynamic splitting of the single produced split."""
   file_name, expected_data = write_data(5)
   assert len(expected_data) == 5
   source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                       coders.StrUtf8Coder())
   bundles = list(source.split(desired_bundle_size=100000))
   assert len(bundles) == 1
   only = bundles[0]
   source_test_utils.assert_split_at_fraction_exhaustive(
       only.source, only.start_position, only.stop_position)
Пример #7
0
 def test_dynamic_work_rebalancing(self):
   """Dynamic split-at-fraction must succeed at every position."""
   file_name, expected_data = write_data(5)
   assert len(expected_data) == 5
   source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                       coders.StrUtf8Coder())
   all_splits = list(source.split(desired_bundle_size=100000))
   assert len(all_splits) == 1
   split = all_splits[0]
   source_test_utils.assert_split_at_fraction_exhaustive(
       split.source, split.start_position, split.stop_position)
Пример #8
0
 def test_dynamic_work_rebalancing_windows_eol(self):
   """Exhaustive dynamic-splitting check over a CRLF-terminated file."""
   file_name, expected_data = write_data(15, eol=EOL.CRLF)
   assert len(expected_data) == 15
   source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                       coders.StrUtf8Coder())
   bundles = list(source.split(desired_bundle_size=100000))
   assert len(bundles) == 1
   only = bundles[0]
   source_test_utils.assert_split_at_fraction_exhaustive(
       only.source, only.start_position, only.stop_position,
       perform_multi_threaded_test=False)
Пример #9
0
 def test_dynamic_work_rebalancing_mixed_eol(self):
   """Exhaustive dynamic-splitting check over a file with mixed EOLs."""
   file_name, expected_data = write_data(5, eol=EOL.MIXED)
   assert len(expected_data) == 5
   source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                       coders.StrUtf8Coder())
   splits = list(source.split(desired_bundle_size=100000))
   assert len(splits) == 1
   # Fixed: call the snake_case helper assert_split_at_fraction_exhaustive
   # used by the sibling rebalancing tests; the camelCase name
   # assertSplitAtFractionExhaustive does not match it.
   source_test_utils.assert_split_at_fraction_exhaustive(
       splits[0].source, splits[0].start_position, splits[0].stop_position,
       perform_multi_threaded_test=False)
Пример #10
0
 def test_dynamic_work_rebalancing_windows_eol(self):
   """Dynamic split-at-fraction over CRLF data (single-threaded check)."""
   file_name, expected_data = write_data(15, eol=EOL.CRLF)
   assert len(expected_data) == 15
   source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                       coders.StrUtf8Coder())
   all_splits = list(source.split(desired_bundle_size=100000))
   assert len(all_splits) == 1
   split = all_splits[0]
   source_test_utils.assert_split_at_fraction_exhaustive(
       split.source, split.start_position, split.stop_position,
       perform_multi_threaded_test=False)
Пример #11
0
 def test_dynamic_work_rebalancing_mixed_eol(self):
   """Dynamic split-at-fraction over mixed-EOL data (single-threaded)."""
   file_name, expected_data = write_data(5, eol=EOL.MIXED)
   assert len(expected_data) == 5
   source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                       coders.StrUtf8Coder())
   splits = list(source.split(desired_bundle_size=100000))
   assert len(splits) == 1
   # Fixed: use the snake_case assert_split_at_fraction_exhaustive helper,
   # consistent with the other rebalancing tests in this file; the camelCase
   # assertSplitAtFractionExhaustive name does not match it.
   source_test_utils.assert_split_at_fraction_exhaustive(
       splits[0].source, splits[0].start_position, splits[0].stop_position,
       perform_multi_threaded_test=False)
Пример #12
0
    def test_read_after_splitting(self):
        """Reading the union of all splits must equal reading the source."""
        file_name, expected_data = write_data(10)
        assert len(expected_data) == 10
        source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                            coders.StrUtf8Coder())
        splits = list(source.split(desired_bundle_size=33))

        reference_source_info = (source, None, None)
        sources_info = [(split.source, split.start_position,
                         split.stop_position) for split in splits]
        # Fixed: call the snake_case helper
        # assert_sources_equal_reference_source used by the other tests in
        # this file; the camelCase assertSourcesEqualReferenceSource name
        # does not match it.
        source_test_utils.assert_sources_equal_reference_source(
            reference_source_info, sources_info)
Пример #13
0
  def test_read_after_splitting(self):
    """Records read from all splits together must match the full source."""
    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10
    source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                        coders.StrUtf8Coder())
    splits = list(source.split(desired_bundle_size=33))

    reference_source_info = (source, None, None)
    sources_info = [(s.source, s.start_position, s.stop_position)
                    for s in splits]
    source_test_utils.assert_sources_equal_reference_source(
        reference_source_info, sources_info)
Пример #14
0
    def test_progress(self):
        """fraction_consumed should advance by a tenth per record."""
        file_name, expected_data = write_data(10)
        assert len(expected_data) == 10
        source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                            coders.StrUtf8Coder())
        splits = list(source.split(desired_bundle_size=100000))
        assert len(splits) == 1
        tracker = splits[0].source.get_range_tracker(
            splits[0].start_position, splits[0].stop_position)
        progress = []
        for _ in splits[0].source.read(tracker):
            progress.append(tracker.fraction_consumed())

        self.assertEqual([i / 10.0 for i in range(10)], progress)
Пример #15
0
  def test_progress(self):
    """fraction_consumed must step through 0.0 .. 0.9 over ten records."""
    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10
    source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                        coders.StrUtf8Coder())
    splits = list(source.split(desired_bundle_size=100000))
    assert len(splits) == 1
    tracker = splits[0].source.get_range_tracker(
        splits[0].start_position, splits[0].stop_position)
    consumed = []
    for _ in splits[0].source.read(tracker):
      consumed.append(tracker.fraction_consumed())

    self.assertEqual([i / 10.0 for i in range(10)], consumed)
Пример #16
0
  def test_read_after_splitting_skip_header(self):
    """Splits of a skip-header source must agree with the unsplit source."""
    file_name, expected_data = write_data(100)
    assert len(expected_data) == 100
    source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                        coders.StrUtf8Coder(), skip_header_lines=2)
    splits = list(source.split(desired_bundle_size=33))

    reference_source_info = (source, None, None)
    sources_info = [(s.source, s.start_position, s.stop_position)
                    for s in splits]
    self.assertGreater(len(sources_info), 1)
    reference_lines = source_test_utils.read_from_source(*reference_source_info)
    split_lines = []
    for info in sources_info:
      split_lines.extend(source_test_utils.read_from_source(*info))

    # The first two records are skipped as headers in both read paths.
    self.assertEqual(expected_data[2:], reference_lines)
    self.assertEqual(reference_lines, split_lines)
Пример #17
0
  def test_read_after_splitting_skip_header(self):
    """Reading every split must reproduce the reference read, minus headers."""
    file_name, expected_data = write_data(100)
    assert len(expected_data) == 100
    source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                        coders.StrUtf8Coder(), skip_header_lines=2)
    splits = list(source.split(desired_bundle_size=33))

    reference_source_info = (source, None, None)
    sources_info = [
        (split.source, split.start_position, split.stop_position)
        for split in splits]
    self.assertGreater(len(sources_info), 1)
    reference_lines = source_test_utils.read_from_source(*reference_source_info)
    split_lines = []
    for source_info in sources_info:
      split_lines.extend(source_test_utils.read_from_source(*source_info))

    # Two header lines are dropped, so only expected_data[2:] comes back.
    self.assertEqual(expected_data[2:], reference_lines)
    self.assertEqual(reference_lines, split_lines)
Пример #18
0
    def test_read_gzip_large_after_splitting(self):
        """A compressed source must yield a single split whose contents
        match a read of the whole source."""
        _, lines = write_data(10000)
        file_name = self._create_temp_file()
        with gzip.GzipFile(file_name, 'wb') as f:
            # Fixed: GzipFile opened in binary mode requires bytes; encode
            # the joined text (writing str raises TypeError on Python 3).
            f.write('\n'.join(lines).encode('utf-8'))

        source = TextSource(file_name, 0, CompressionTypes.GZIP, True,
                            coders.StrUtf8Coder())
        splits = list(source.split(desired_bundle_size=1000))

        if len(splits) > 1:
            raise ValueError(
                'FileBasedSource generated more than one initial split '
                'for a compressed file.')

        reference_source_info = (source, None, None)
        sources_info = ([(split.source, split.start_position,
                          split.stop_position) for split in splits])
        source_test_utils.assert_sources_equal_reference_source(
            reference_source_info, sources_info)
Пример #19
0
  def test_read_gzip_large_after_splitting(self):
    """A gzip-compressed source must not split; the split must equal the
    reference read of the whole source."""
    _, lines = write_data(10000)
    file_name = self._create_temp_file()
    with gzip.GzipFile(file_name, 'wb') as f:
      # Fixed: GzipFile opened in binary mode requires bytes; encode the
      # joined text (writing str raises TypeError on Python 3).
      f.write('\n'.join(lines).encode('utf-8'))

    source = TextSource(file_name, 0, CompressionTypes.GZIP, True,
                        coders.StrUtf8Coder())
    splits = list(source.split(desired_bundle_size=1000))

    if len(splits) > 1:
      raise ValueError('FileBasedSource generated more than one initial split '
                       'for a compressed file.')

    reference_source_info = (source, None, None)
    sources_info = ([
        (split.source, split.start_position, split.stop_position) for
        split in splits])
    source_test_utils.assert_sources_equal_reference_source(
        reference_source_info, sources_info)
Пример #20
0
  def test_header_processing(self):
    """Matching header lines reach the header callback; only the remaining
    records are returned by read_records."""
    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10

    captured_headers = []

    def is_header(line):
      return line in expected_data[:5]

    def collect_headers(lines):
      captured_headers.extend(lines)

    source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                        coders.StrUtf8Coder(),
                        header_processor_fns=(is_header, collect_headers))
    splits = list(source.split(desired_bundle_size=100000))
    assert len(splits) == 1
    tracker = splits[0].source.get_range_tracker(
        splits[0].start_position, splits[0].stop_position)
    read_data = list(source.read_records(file_name, tracker))

    self.assertCountEqual(expected_data[:5], captured_headers)
    self.assertCountEqual(expected_data[5:], read_data)