Example #1
    def test_split_points(self):
        file_name = self._write_data(count=12000)
        source = _create_avro_source(file_name, use_fastavro=self.use_fastavro)

        splits = [
            split for split in source.split(desired_bundle_size=float('inf'))
        ]
        assert len(splits) == 1

        range_tracker = splits[0].source.get_range_tracker(
            splits[0].start_position, splits[0].stop_position)

        split_points_report = []

        for _ in splits[0].source.read(range_tracker):
            split_points_report.append(range_tracker.split_points())

        # There are a total of three blocks. Each block has more than 10 records.

        # When reading records of the first block, range_tracker.split_points()
        # should return (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
        self.assertEqual(split_points_report[:10],
                         [(0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)] * 10)

        # When reading records of last block, range_tracker.split_points() should
        # return (2, 1)
        self.assertEqual(split_points_report[-10:], [(2, 1)] * 10)
Example #2
  def test_split_points(self):
    num_records = 12000
    file_name = self._write_data(count=num_records)
    source = _create_avro_source(file_name, use_fastavro=self.use_fastavro)

    splits = [
        split
        for split in source.split(desired_bundle_size=float('inf'))
    ]
    assert len(splits) == 1

    range_tracker = splits[0].source.get_range_tracker(
        splits[0].start_position, splits[0].stop_position)

    split_points_report = []

    for _ in splits[0].source.read(range_tracker):
      split_points_report.append(range_tracker.split_points())

    # The generated test file will contain num_blocks blocks, proportional
    # to the number of records in the file divided by the synchronization
    # interval used by Avro during the write. Each block has more than 10
    # records.
    num_blocks = int(math.ceil(14.5 * num_records /
                               avro.datafile.SYNC_INTERVAL))
    assert num_blocks > 1

    # When reading records of the first block, range_tracker.split_points()
    # should return (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
    self.assertEqual(
        split_points_report[:10],
        [(0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)] * 10)

    # When reading records of last block, range_tracker.split_points() should
    # return (num_blocks - 1, 1)
    self.assertEqual(split_points_report[-10:], [(num_blocks - 1, 1)] * 10)
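
For intuition, the block-count estimate above can be reproduced standalone. A minimal sketch, assuming roughly 14.5 bytes per serialized record (the constant the test itself uses) and the avro library's default SYNC_INTERVAL of 16000 bytes, the value Example #3 passes explicitly:

    import math

    num_records = 12000
    sync_interval = 16000  # avro.datafile.SYNC_INTERVAL default

    # Avro starts a new block (preceded by a 16-byte sync marker) after
    # roughly sync_interval bytes, so the block count scales with the
    # total serialized size of the records.
    num_blocks = int(math.ceil(14.5 * num_records / sync_interval))
    print(num_blocks)  # ceil(174000 / 16000) == 11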
Example #3
  def test_split_points(self):
    num_records = 12000
    sync_interval = 16000
    file_name = self._write_data(count=num_records,
                                 sync_interval=sync_interval)

    source = _create_avro_source(file_name, use_fastavro=self.use_fastavro)

    splits = [
        split
        for split in source.split(desired_bundle_size=float('inf'))
    ]
    assert len(splits) == 1
    range_tracker = splits[0].source.get_range_tracker(
        splits[0].start_position, splits[0].stop_position)

    split_points_report = []

    for _ in splits[0].source.read(range_tracker):
      split_points_report.append(range_tracker.split_points())
    # The generated test file will contain num_blocks blocks, proportional
    # to the number of records in the file divided by the synchronization
    # interval used by Avro during the write. Each block has more than 10
    # records.
    num_blocks = int(math.ceil(14.5 * num_records /
                               sync_interval))
    assert num_blocks > 1
    # When reading records of the first block, range_tracker.split_points()
    # should return (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
    self.assertEqual(
        split_points_report[:10],
        [(0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)] * 10)

    # When reading records of last block, range_tracker.split_points() should
    # return (num_blocks - 1, 1)
    self.assertEqual(split_points_report[-10:], [(num_blocks - 1, 1)] * 10)
Example #4
File: avroio_test.py  Project: zhoufek/beam
 def test_read_reentrant_with_splitting(self):
   file_name = self._write_data()
   source = _create_avro_source(file_name)
   splits = [split for split in source.split(desired_bundle_size=100000)]
   assert len(splits) == 1
   source_test_utils.assert_reentrant_reads_succeed(
       (splits[0].source, splits[0].start_position, splits[0].stop_position))
Example #5
  def test_split_points(self):
    file_name = self._write_data(count=12000)
    source = _create_avro_source(file_name, use_fastavro=self.use_fastavro)

    splits = [
        split
        for split in source.split(desired_bundle_size=float('inf'))
    ]
    assert len(splits) == 1

    range_tracker = splits[0].source.get_range_tracker(
        splits[0].start_position, splits[0].stop_position)

    split_points_report = []

    for _ in splits[0].source.read(range_tracker):
      split_points_report.append(range_tracker.split_points())

    # There are a total of three blocks. Each block has more than 10 records.

    # When reading records of the first block, range_tracker.split_points()
    # should return (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
    self.assertEqual(
        split_points_report[:10],
        [(0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)] * 10)

    # When reading records of last block, range_tracker.split_points() should
    # return (2, 1)
    self.assertEqual(split_points_report[-10:], [(2, 1)] * 10)
Example #6
 def compare_split_points(file_name):
   source = _create_avro_source(file_name,
                                use_fastavro=self.use_fastavro)
   splits = [split
             for split in source.split(desired_bundle_size=float('inf'))]
   assert len(splits) == 1
   source_test_utils.assert_split_at_fraction_exhaustive(splits[0].source)
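
This helper references self.use_fastavro without taking self as a parameter, so it is presumably a nested function defined inside a test method, closing over the enclosing test case.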
Example #7
 def test_read_reentrant_with_splitting(self):
   file_name = self._write_data()
   source = _create_avro_source(file_name, use_fastavro=self.use_fastavro)
   splits = [
       split for split in source.split(desired_bundle_size=100000)]
   assert len(splits) == 1
   source_test_utils.assert_reentrant_reads_succeed(
       (splits[0].source, splits[0].start_position, splits[0].stop_position))
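
The reentrancy property asserted above can be approximated as follows. This is a hedged simplification: it merely reads the same range twice in a row, whereas Beam's source_test_utils utility is stricter and opens the second read while the first is still in progress.

   # Two reads over the same (source, start, stop) range must yield the
   # same records.
   read_args = (
       splits[0].source, splits[0].start_position, splits[0].stop_position)
   assert (source_test_utils.read_from_source(*read_args) ==
           source_test_utils.read_from_source(*read_args))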
Example #8
 def _create_source(self, path, schema):
     if not self.use_json_exports:
         return _create_avro_source(path, use_fastavro=True)
     else:
         return _TextSource(path,
                            min_bundle_size=0,
                            compression_type=CompressionTypes.UNCOMPRESSED,
                            strip_trailing_newlines=True,
                            coder=_JsonToDictCoder(schema))
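
Judging by the use_json_exports flag and _JsonToDictCoder, this helper presumably belongs to Beam's export-based BigQuery read path, dispatching between Avro export files and newline-delimited JSON export files read via _TextSource.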
Example #9
 def test_dynamic_work_rebalancing_exhaustive(self):
   # Adjusting block size so that we can perform an exhaustive dynamic
   # work rebalancing test that completes within an acceptable amount of time.
   old_sync_interval = avro.datafile.SYNC_INTERVAL
   try:
     avro.datafile.SYNC_INTERVAL = 2
     file_name = self._write_data(count=5)
     source = _create_avro_source(file_name, use_fastavro=self.use_fastavro)
     splits = [split
               for split in source.split(desired_bundle_size=float('inf'))]
     assert len(splits) == 1
     source_test_utils.assert_split_at_fraction_exhaustive(splits[0].source)
   finally:
     avro.datafile.SYNC_INTERVAL = old_sync_interval
Example #10
  def test_source_display_data(self):
    file_name = 'some_avro_source'
    source = _create_avro_source(
        file_name, validate=False, use_fastavro=self.use_fastavro)
    dd = DisplayData.create_from(source)

    # No extra avro parameters for AvroSource.
    expected_items = [
        DisplayDataItemMatcher('compression', 'auto'),
        DisplayDataItemMatcher('file_pattern', file_name)]
    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
Example #11
File: avroio_test.py  Project: zhoufek/beam
  def test_corrupted_file(self):
    file_name = self._write_data()
    with open(file_name, 'rb') as f:
      data = f.read()

    # Corrupt the last character of the file, which is also the last
    # character of the last sync_marker.
    # https://avro.apache.org/docs/current/spec.html#Object+Container+Files
    corrupted_data = bytearray(data)
    corrupted_data[-1] = (corrupted_data[-1] + 1) % 256
    with tempfile.NamedTemporaryFile(delete=False,
                                     prefix=tempfile.template) as f:
      f.write(corrupted_data)
      corrupted_file_name = f.name

    source = _create_avro_source(corrupted_file_name)
    with self.assertRaisesRegex(ValueError, r'expected sync marker'):
      source_test_utils.read_from_source(source, None, None)
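
Note the (corrupted_data[-1] + 1) % 256 increment: it guarantees the final byte actually changes even when it is 255. The next two examples achieve the same with a conditional character swap; in Python 3 that comparison must be made on a one-byte slice (data[last_char_index:]), since indexing a bytes object yields an int.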
Example #12
  def test_corrupted_file(self):
    file_name = self._write_data()
    with open(file_name, 'rb') as f:
      data = f.read()

    # Corrupt the last character of the file, which is also the last
    # character of the last sync_marker.
    last_char_index = len(data) - 1
    corrupted_data = data[:last_char_index]
    corrupted_data += b'A' if data[last_char_index:] == b'B' else b'B'
    with tempfile.NamedTemporaryFile(
        delete=False, prefix=tempfile.template) as f:
      f.write(corrupted_data)
      corrupted_file_name = f.name

    source = _create_avro_source(
        corrupted_file_name, use_fastavro=self.use_fastavro)
    with self.assertRaises(ValueError) as exn:
      source_test_utils.read_from_source(source, None, None)
    self.assertTrue(
        str(exn.exception).startswith('Unexpected sync marker'))
Example #13
  def test_corrupted_file(self):
    file_name = self._write_data()
    with open(file_name, 'rb') as f:
      data = f.read()

    # Corrupt the last character of the file, which is also the last
    # character of the last sync_marker.
    last_char_index = len(data) - 1
    corrupted_data = data[:last_char_index]
    corrupted_data += b'A' if data[last_char_index:] == b'B' else b'B'
    with tempfile.NamedTemporaryFile(
        delete=False, prefix=tempfile.template) as f:
      f.write(corrupted_data)
      corrupted_file_name = f.name

    source = _create_avro_source(
        corrupted_file_name, use_fastavro=self.use_fastavro)
    with self.assertRaises(ValueError) as exn:
      source_test_utils.read_from_source(source, None, None)
    self.assertTrue(
        str(exn.exception).startswith('Unexpected sync marker'))
Example #14
File: avroio_test.py  Project: zhoufek/beam
  def _run_avro_test(
      self, pattern, desired_bundle_size, perform_splitting, expected_result):
    source = _create_avro_source(pattern)

    if perform_splitting:
      assert desired_bundle_size
      splits = [
          split
          for split in source.split(desired_bundle_size=desired_bundle_size)
      ]
      if len(splits) < 2:
        raise ValueError(
            'Test is trivial. Please adjust it so that at least '
            'two splits get generated')

      sources_info = [(split.source, split.start_position, split.stop_position)
                      for split in splits]
      source_test_utils.assert_sources_equal_reference_source(
          (source, None, None), sources_info)
    else:
      read_records = source_test_utils.read_from_source(source, None, None)
      self.assertCountEqual(expected_result, read_records)
Example #15
  def _run_avro_test(self, pattern, desired_bundle_size, perform_splitting,
                     expected_result):
    source = _create_avro_source(pattern, use_fastavro=self.use_fastavro)

    read_records = []
    if perform_splitting:
      assert desired_bundle_size
      splits = [
          split
          for split in source.split(desired_bundle_size=desired_bundle_size)
      ]
      if len(splits) < 2:
        raise ValueError('Test is trivial. Please adjust it so that at least '
                         'two splits get generated')

      sources_info = [
          (split.source, split.start_position, split.stop_position)
          for split in splits
      ]
      source_test_utils.assert_sources_equal_reference_source(
          (source, None, None), sources_info)
    else:
      read_records = source_test_utils.read_from_source(source, None, None)
      self.assertCountEqual(expected_result, read_records)
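
A hypothetical call site for this helper (the file name and expected records are illustrative, produced by whatever the surrounding test wrote):

    # Read the whole file without splitting, then with ~10 KB bundles.
    self._run_avro_test(file_name, None, False, expected_result)
    self._run_avro_test(file_name, 10000, True, expected_result)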
Example #16
 def test_read_reentrant_without_splitting(self):
   file_name = self._write_data()
   source = _create_avro_source(file_name, use_fastavro=self.use_fastavro)
   source_test_utils.assert_reentrant_reads_succeed((source, None, None))