Ejemplo n.º 1
0
    def test_read_after_splitting(self):
        """Check that the splits of a text source match the unsplit source.

        Obtains a file name and 10 expected records from ``write_data``,
        builds an uncompressed ``TextSource`` over that file, splits it with
        a desired bundle size of 33 and passes the unsplit source together
        with every split to ``assertSourcesEqualReferenceSource``.
        """
        path, written_lines = write_data(10)
        assert len(written_lines) == 10

        text_source = TextSource(
            path, 0, CompressionTypes.UNCOMPRESSED, True,
            coders.StrUtf8Coder())
        # Materialize the splits before building the comparison tuples.
        bundles = list(text_source.split(desired_bundle_size=33))

        # Each entry is (source, start_position, stop_position); the
        # reference entry uses None for both positions.
        split_triples = []
        for bundle in bundles:
            split_triples.append(
                (bundle.source, bundle.start_position, bundle.stop_position))

        source_test_utils.assertSourcesEqualReferenceSource(
            (text_source, None, None), split_triples)
Ejemplo n.º 2
0
  def test_source_equals_reference_source(self):
    data = self._create_data(100)
    reference_source = self._create_source(data)
    sources_info = [(split.source, split.start_position, split.stop_position)
                    for split in reference_source.split(desired_bundle_size=50)]
    if len(sources_info) < 2:
      raise ValueError('Test is too trivial since splitting only generated %d'
                       'bundles. Please adjust the test so that at least '
                       'two splits get generated.', len(sources_info))

    source_test_utils.assertSourcesEqualReferenceSource(
        (reference_source, None, None), sources_info)
Ejemplo n.º 3
0
 def check_read_with_initial_splits(self, values, coder, num_splits):
   """A test that splits the given source into `num_splits` and verifies that
   the data read from original source is equal to the union of the data read
   from the split sources.

   Args:
     values: iterable handed to ``Create._create_source_from_iterable``.
     coder: coder handed to ``Create._create_source_from_iterable``.
     num_splits: divisor used to derive the desired bundle size from the
       source's ``_total_size``; must be non-zero.
   """
   source = Create._create_source_from_iterable(values, coder)
   # Floor division keeps the bundle size integral on Python 3 as well, where
   # `/` would yield a float (assumes `_total_size` is an integer size --
   # TODO confirm).
   desired_bundle_size = source._total_size // num_splits
   splits = source.split(desired_bundle_size)
   # Each entry is (source, start_position, stop_position).
   splits_info = [
       (split.source, split.start_position, split.stop_position)
       for split in splits]
   source_test_utils.assertSourcesEqualReferenceSource((source, None, None),
                                                       splits_info)
Ejemplo n.º 4
0
  def test_read_after_splitting(self):
    """Check that the splits of a text source match the unsplit source.

    Obtains a file name and 10 expected records from ``write_data``, builds
    an uncompressed ``TextSource`` over that file, splits it with a desired
    bundle size of 33 and passes the unsplit source together with every split
    to ``assertSourcesEqualReferenceSource``.
    """
    path, written_lines = write_data(10)
    assert len(written_lines) == 10

    text_source = TextSource(
        path, 0, CompressionTypes.UNCOMPRESSED, True, coders.StrUtf8Coder())
    # Materialize the splits before building the comparison tuples.
    bundles = list(text_source.split(desired_bundle_size=33))

    # The reference entry uses None for both start and stop positions.
    reference = (text_source, None, None)
    split_triples = [
        (bundle.source, bundle.start_position, bundle.stop_position)
        for bundle in bundles
    ]
    source_test_utils.assertSourcesEqualReferenceSource(
        reference, split_triples)
Ejemplo n.º 5
0
    def test_read_gzip_large_after_splitting(self):
        """Check that a gzip-compressed text source yields a single split.

        Takes the records returned by ``write_data(10000)``, writes them
        newline-joined into a gzip file, builds a GZIP ``TextSource`` over
        it and splits it with a desired bundle size of 1000. The splits are
        then compared with the unsplit source through
        ``assertSourcesEqualReferenceSource``.

        Raises:
            ValueError: if splitting produced more than one split.
        """
        _, records = write_data(10000)
        gz_path = self._create_temp_file()
        with gzip.GzipFile(gz_path, 'wb') as gz_file:
            payload = '\n'.join(records)
            gz_file.write(payload)

        gz_source = TextSource(
            gz_path, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
        bundles = list(gz_source.split(desired_bundle_size=1000))

        # This test expects a compressed file not to be split further.
        if len(bundles) > 1:
            raise ValueError(
                'FileBasedSource generated more than one initial split '
                'for a compressed file.')

        # Each entry is (source, start_position, stop_position).
        split_triples = []
        for bundle in bundles:
            split_triples.append(
                (bundle.source, bundle.start_position, bundle.stop_position))

        source_test_utils.assertSourcesEqualReferenceSource(
            (gz_source, None, None), split_triples)
Ejemplo n.º 6
0
  def test_read_gzip_large_after_splitting(self):
    """Check that a gzip-compressed text source yields a single split.

    Takes the records returned by ``write_data(10000)``, writes them
    newline-joined into a gzip file, builds a GZIP ``TextSource`` over it and
    splits it with a desired bundle size of 1000. The splits are then compared
    with the unsplit source through ``assertSourcesEqualReferenceSource``.

    Raises:
      ValueError: if splitting produced more than one split.
    """
    _, records = write_data(10000)
    gz_path = self._create_temp_file()
    with gzip.GzipFile(gz_path, 'wb') as gz_file:
      payload = '\n'.join(records)
      gz_file.write(payload)

    gz_source = TextSource(
        gz_path, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
    bundles = list(gz_source.split(desired_bundle_size=1000))

    # This test expects a compressed file not to be split further.
    if len(bundles) > 1:
      raise ValueError('FileBasedSource generated more than one initial split '
                       'for a compressed file.')

    # The reference entry uses None for both start and stop positions.
    reference = (gz_source, None, None)
    split_triples = [
        (bundle.source, bundle.start_position, bundle.stop_position)
        for bundle in bundles
    ]
    source_test_utils.assertSourcesEqualReferenceSource(
        reference, split_triples)
Ejemplo n.º 7
0
    def _run_avro_test(self, pattern, desired_bundle_size, perform_splitting,
                       expected_result):
        """Read an ``AvroSource`` either directly or through its splits.

        Args:
            pattern: argument used to construct the ``AvroSource``.
            desired_bundle_size: bundle size passed to ``split``; must be
                truthy when ``perform_splitting`` is set.
            perform_splitting: when true, the splits are compared with the
                unsplit source; otherwise the source is read directly and
                compared with ``expected_result``.
            expected_result: records expected from the direct read. Only
                consulted when ``perform_splitting`` is false.

        Raises:
            ValueError: if splitting produced fewer than two splits.
        """
        avro_source = AvroSource(pattern)

        if not perform_splitting:
            # Direct read: compare the records regardless of order.
            records = source_test_utils.readFromSource(
                avro_source, None, None)
            self.assertItemsEqual(expected_result, records)
            return

        assert desired_bundle_size
        bundles = list(
            avro_source.split(desired_bundle_size=desired_bundle_size))
        if len(bundles) < 2:
            raise ValueError(
                'Test is trivial. Please adjust it so that at least '
                'two splits get generated')

        # Each entry is (source, start_position, stop_position).
        split_triples = []
        for bundle in bundles:
            split_triples.append(
                (bundle.source, bundle.start_position, bundle.stop_position))
        source_test_utils.assertSourcesEqualReferenceSource(
            (avro_source, None, None), split_triples)
Ejemplo n.º 8
0
  def _run_avro_test(self, pattern, desired_bundle_size, perform_splitting,
                     expected_result):
    """Read an ``AvroSource`` either directly or through its splits.

    Args:
      pattern: argument used to construct the ``AvroSource``.
      desired_bundle_size: bundle size passed to ``split``; must be truthy
        when ``perform_splitting`` is set.
      perform_splitting: when true, the splits are compared with the unsplit
        source; otherwise the source is read directly and compared with
        ``expected_result``.
      expected_result: records expected from the direct read. Only consulted
        when ``perform_splitting`` is false.

    Raises:
      ValueError: if splitting produced fewer than two splits.
    """
    avro_source = AvroSource(pattern)

    if not perform_splitting:
      # Direct read: compare the records regardless of order.
      records = source_test_utils.readFromSource(avro_source, None, None)
      self.assertItemsEqual(expected_result, records)
      return

    assert desired_bundle_size
    bundles = list(avro_source.split(desired_bundle_size=desired_bundle_size))
    if len(bundles) < 2:
      raise ValueError('Test is trivial. Please adjust it so that at least '
                       'two splits get generated')

    # The reference entry uses None for both start and stop positions.
    reference = (avro_source, None, None)
    split_triples = [
        (bundle.source, bundle.start_position, bundle.stop_position)
        for bundle in bundles
    ]
    source_test_utils.assertSourcesEqualReferenceSource(reference,
                                                        split_triples)