Example #1
0
  def test_corrupted_file(self):
    """Reading an Avro file whose final sync marker is damaged must fail.

    Flips the last byte of the file produced by ``_write_data`` (the last byte
    of the last sync marker), writes the result to a temporary file, and checks
    that reading it raises ``ValueError`` whose message starts with
    'Unexpected sync marker'.
    """
    import os

    file_name = self._write_data()
    with open(file_name, 'rb') as f:
      data = f.read()

    # Corrupt the last character of the file which is also the last character of
    # the last sync_marker. Slice instead of index: on Python 3 ``data[-1]`` is
    # an int and would never compare equal to a one-character string.
    corrupted_data = data[:-1]
    corrupted_data += b'A' if data[-1:] == b'B' else b'B'
    with tempfile.NamedTemporaryFile(
        delete=False, prefix=tempfile.template) as f:
      f.write(corrupted_data)
      corrupted_file_name = f.name
    # delete=False keeps the file on disk so AvroSource can reopen it by name;
    # remove it once the test finishes.
    self.addCleanup(os.remove, corrupted_file_name)

    source = AvroSource(corrupted_file_name)
    with self.assertRaises(ValueError) as exn:
      source_test_utils.readFromSource(source, None, None)
    # Checked after the ``with`` block: the read above raises, so anything
    # placed after it inside the block would never execute.
    message = str(exn.exception)
    self.assertTrue(message.startswith('Unexpected sync marker'), message)
Example #2
0
  def test_read_after_splitting_skip_header(self):
    """Splits of a header-skipping TextSource read the same lines as the whole.

    Writes 100 lines, builds a TextSource that skips the first two, splits it
    into bundles of roughly 33 bytes, and checks two things. First, reading
    the unsplit source yields every line after the headers. Second,
    concatenating the reads of all splits yields exactly the same lines.
    """
    file_name, expected_data = write_data(100)
    assert len(expected_data) == 100
    source = TextSource(
        file_name, 0, CompressionTypes.UNCOMPRESSED, True,
        coders.StrUtf8Coder(), skip_header_lines=2)
    bundles = list(source.split(desired_bundle_size=33))

    split_infos = [
        (bundle.source, bundle.start_position, bundle.stop_position)
        for bundle in bundles
    ]
    self.assertGreater(len(split_infos), 1)

    whole_read = source_test_utils.readFromSource(source, None, None)
    per_split_reads = []
    for split_source, start, stop in split_infos:
      per_split_reads.extend(
          source_test_utils.readFromSource(split_source, start, stop))

    # skip_header_lines=2 drops the first two written lines.
    self.assertEqual(expected_data[2:], whole_read)
    self.assertEqual(whole_read, per_split_reads)
Example #3
0
  def test_read_after_splitting_skip_header(self):
    """Check that splitting does not change what a header-skipping source reads.

    The reference read covers the whole source. The split reads are
    concatenated in split order and must match the reference read line for
    line. The reference read must also equal the written data minus its
    header lines.
    """
    file_name, expected_data = write_data(100)
    assert len(expected_data) == 100
    header_count = 2
    source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
                        coders.StrUtf8Coder(), skip_header_lines=header_count)
    sources_info = [(piece.source, piece.start_position, piece.stop_position)
                    for piece in source.split(desired_bundle_size=33)]
    self.assertGreater(len(sources_info), 1)

    reference_lines = source_test_utils.readFromSource(source, None, None)
    split_lines = []
    for info in sources_info:
      split_lines += source_test_utils.readFromSource(*info)

    self.assertEqual(expected_data[header_count:], reference_lines)
    self.assertEqual(reference_lines, split_lines)
Example #4
0
    def test_corrupted_file(self):
        """Reading an Avro file whose final sync marker is damaged must fail.

        Flips the last byte of the file produced by ``_write_data`` (the last
        byte of the last sync marker), writes the result to a temporary file,
        and checks that reading it raises ``ValueError`` whose message starts
        with 'Unexpected sync marker'.
        """
        import os

        file_name = self._write_data()
        with open(file_name, 'rb') as f:
            data = f.read()

        # Corrupt the last character of the file which is also the last
        # character of the last sync_marker. Slice instead of index: on
        # Python 3 ``data[-1]`` is an int and would never equal a string.
        corrupted_data = data[:-1]
        corrupted_data += b'A' if data[-1:] == b'B' else b'B'
        with tempfile.NamedTemporaryFile(delete=False,
                                         prefix=tempfile.template) as f:
            f.write(corrupted_data)
            corrupted_file_name = f.name
        # delete=False keeps the file on disk so AvroSource can reopen it by
        # name; remove it once the test finishes.
        self.addCleanup(os.remove, corrupted_file_name)

        source = AvroSource(corrupted_file_name)
        with self.assertRaises(ValueError) as exn:
            source_test_utils.readFromSource(source, None, None)
        # Checked after the ``with`` block: the read above raises, so anything
        # placed after it inside the block would never execute.
        message = str(exn.exception)
        self.assertTrue(message.startswith('Unexpected sync marker'), message)
Example #5
0
    def _run_avro_test(self, pattern, desired_bundle_size, perform_splitting,
                       expected_result):
        """Read the Avro files matched by ``pattern`` and verify the result.

        Args:
          pattern: file pattern handed to ``AvroSource``.
          desired_bundle_size: bundle size used when splitting. It must be
            truthy when ``perform_splitting`` is set.
          perform_splitting: if true, split the source and check that the
            splits read the same records as the unsplit source. Otherwise,
            read the unsplit source and compare its records, ignoring order,
            with ``expected_result``.
          expected_result: records expected from the unsplit read.
            NOTE(review): not consulted when ``perform_splitting`` is true.

        Raises:
          ValueError: if splitting yields fewer than two splits, which would
            make the split comparison trivial.
        """
        source = AvroSource(pattern)

        if not perform_splitting:
            records = source_test_utils.readFromSource(source, None, None)
            self.assertItemsEqual(expected_result, records)
            return

        assert desired_bundle_size
        splits = list(source.split(desired_bundle_size=desired_bundle_size))
        if len(splits) < 2:
            raise ValueError(
                'Test is trivial. Please adjust it so that at least '
                'two splits get generated')

        reference_info = (source, None, None)
        split_infos = [
            (split.source, split.start_position, split.stop_position)
            for split in splits
        ]
        source_test_utils.assertSourcesEqualReferenceSource(
            reference_info, split_infos)
Example #6
0
  def _run_avro_test(self, pattern, desired_bundle_size, perform_splitting,
                     expected_result):
    """Verify reading of the Avro source built from ``pattern``.

    With ``perform_splitting`` set, the source is split using
    ``desired_bundle_size``. The splits are then checked against the unsplit
    source via ``assertSourcesEqualReferenceSource``, and ValueError is raised
    if fewer than two splits are produced. Without it, the unsplit source is
    read and its records are compared, ignoring order, with
    ``expected_result``.
    NOTE(review): ``expected_result`` is unused in the splitting path.
    """
    source = AvroSource(pattern)

    if perform_splitting:
      assert desired_bundle_size
      bundles = list(source.split(desired_bundle_size=desired_bundle_size))
      if not len(bundles) >= 2:
        raise ValueError('Test is trivial. Please adjust it so that at least '
                         'two splits get generated')

      bundle_infos = []
      for bundle in bundles:
        bundle_infos.append(
            (bundle.source, bundle.start_position, bundle.stop_position))
      source_test_utils.assertSourcesEqualReferenceSource(
          (source, None, None), bundle_infos)
    else:
      actual = source_test_utils.readFromSource(source, None, None)
      self.assertItemsEqual(expected_result, actual)
Example #7
0
 def check_read(self, values, coder):
   """Check that a source created from ``values`` reads all of them back.

   The comparison ignores order: both sides are sorted before comparing.
   """
   source = Create._create_source_from_iterable(values, coder)
   # Pass start/stop positions explicitly, as every other readFromSource call
   # site does, instead of relying on the helper having parameter defaults.
   read_values = source_test_utils.readFromSource(source, None, None)
   self.assertEqual(sorted(values), sorted(read_values))
 def test_read_from_source(self):
   """Reading the unsplit source returns the data it was built from."""
   data = self._create_data(100)
   read_back = source_test_utils.readFromSource(
       self._create_source(data), None, None)
   self.assertItemsEqual(data, read_back)