def test_split_points(self):
    """Check the split points the range tracker reports while reading.

    The file is read as a single bundle, and after every record the value of
    ``range_tracker.split_points()`` is captured and compared against the
    values expected for the first and the last block.
    """
    file_name = self._write_data(count=12000)
    source = _create_avro_source(file_name, use_fastavro=self.use_fastavro)

    # An infinite desired bundle size must yield exactly one split.
    splits = list(source.split(desired_bundle_size=float('inf')))
    assert len(splits) == 1

    range_tracker = splits[0].source.get_range_tracker(
        splits[0].start_position, splits[0].stop_position)

    # Record the tracker state once per record read.
    split_points_report = []
    for _ in splits[0].source.read(range_tracker):
        split_points_report.append(range_tracker.split_points())

    # There are a total of three blocks. Each block has more than 10 records.

    # When reading records of the first block, range_tracker.split_points()
    # should return (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN).
    # NOTE: assertEqual is used because the assertEquals alias is deprecated
    # and was removed from unittest in Python 3.12.
    self.assertEqual(
        split_points_report[:10],
        [(0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)] * 10)

    # When reading records of last block, range_tracker.split_points() should
    # return (2, 1).
    self.assertEqual(split_points_report[-10:], [(2, 1)] * 10)
def test_split_points(self):
    """Verify split-point reporting for the first and last blocks of a file.

    The expected number of blocks is derived from the record count and
    ``avro.datafile.SYNC_INTERVAL`` rather than being hard-coded.
    """
    num_records = 12000
    file_name = self._write_data(count=num_records)
    source = _create_avro_source(file_name, use_fastavro=self.use_fastavro)

    # With an unbounded bundle size the whole file must be a single split.
    splits = list(source.split(desired_bundle_size=float('inf')))
    assert len(splits) == 1
    only_split = splits[0]

    range_tracker = only_split.source.get_range_tracker(
        only_split.start_position, only_split.stop_position)

    # Capture the tracker's view of split points after each record is read.
    split_points_report = [
        range_tracker.split_points()
        for _ in only_split.source.read(range_tracker)
    ]

    # There will be a total of num_blocks in the generated test file,
    # proportional to the number of records in the file divided by the
    # synchronization interval used by avro during write. Each block has more
    # than 10 records.
    estimated_blocks = 14.5 * num_records / avro.datafile.SYNC_INTERVAL
    num_blocks = int(math.ceil(estimated_blocks))
    assert num_blocks > 1

    # While reading records of the first block the tracker should report
    # (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN).
    expected_head = [(0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)] * 10
    self.assertEqual(split_points_report[:10], expected_head)

    # While reading records of the last block the tracker should report
    # (num_blocks - 1, 1).
    expected_tail = [(num_blocks - 1, 1)] * 10
    self.assertEqual(split_points_report[-10:], expected_tail)
def test_split_points(self):
    """Verify split-point reporting using an explicit sync interval.

    The same ``sync_interval`` is passed to ``_write_data`` and used to
    compute the expected number of blocks.
    """
    num_records = 12000
    sync_interval = 16000
    file_name = self._write_data(
        count=num_records, sync_interval=sync_interval)
    source = _create_avro_source(file_name, use_fastavro=self.use_fastavro)

    # With an unbounded bundle size the whole file must be a single split.
    splits = list(source.split(desired_bundle_size=float('inf')))
    assert len(splits) == 1
    only_split = splits[0]

    range_tracker = only_split.source.get_range_tracker(
        only_split.start_position, only_split.stop_position)

    # Capture the tracker's view of split points after each record is read.
    split_points_report = [
        range_tracker.split_points()
        for _ in only_split.source.read(range_tracker)
    ]

    # There will be a total of num_blocks in the generated test file,
    # proportional to the number of records in the file divided by the
    # synchronization interval used by avro during write. Each block has more
    # than 10 records.
    estimated_blocks = 14.5 * num_records / sync_interval
    num_blocks = int(math.ceil(estimated_blocks))
    assert num_blocks > 1

    # While reading records of the first block the tracker should report
    # (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN).
    expected_head = [(0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)] * 10
    self.assertEqual(split_points_report[:10], expected_head)

    # While reading records of the last block the tracker should report
    # (num_blocks - 1, 1).
    expected_tail = [(num_blocks - 1, 1)] * 10
    self.assertEqual(split_points_report[-10:], expected_tail)
def test_read_reantrant_with_splitting(self):
    """Check that re-entrant reads succeed on the single split of a file."""
    source = _create_avro_source(self._write_data())

    # A desired bundle size of 100000 is expected to produce one split only.
    splits = list(source.split(desired_bundle_size=100000))
    assert len(splits) == 1

    only_split = splits[0]
    source_info = (
        only_split.source,
        only_split.start_position,
        only_split.stop_position,
    )
    source_test_utils.assert_reentrant_reads_succeed(source_info)
def test_split_points(self):
    """Check the split points the range tracker reports while reading.

    The file is read as a single bundle, and after every record the value of
    ``range_tracker.split_points()`` is captured and compared against the
    values expected for the first and the last block.
    """
    file_name = self._write_data(count=12000)
    source = _create_avro_source(file_name, use_fastavro=self.use_fastavro)

    # An infinite desired bundle size must yield exactly one split.
    splits = list(source.split(desired_bundle_size=float('inf')))
    assert len(splits) == 1

    range_tracker = splits[0].source.get_range_tracker(
        splits[0].start_position, splits[0].stop_position)

    # Record the tracker state once per record read.
    split_points_report = []
    for _ in splits[0].source.read(range_tracker):
        split_points_report.append(range_tracker.split_points())

    # There are a total of three blocks. Each block has more than 10 records.

    # When reading records of the first block, range_tracker.split_points()
    # should return (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN).
    # NOTE: assertEqual replaces the deprecated assertEquals alias, which
    # was removed from unittest in Python 3.12.
    self.assertEqual(
        split_points_report[:10],
        [(0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)] * 10)

    # When reading records of last block, range_tracker.split_points() should
    # return (2, 1).
    self.assertEqual(split_points_report[-10:], [(2, 1)] * 10)
def compare_split_points(file_name):
    """Run the exhaustive split-at-fraction check on the file's single split.

    Relies on ``self`` being available from the enclosing scope.
    """
    source = _create_avro_source(file_name, use_fastavro=self.use_fastavro)

    # With an unbounded bundle size the whole file must be a single split.
    bundles = list(source.split(desired_bundle_size=float('inf')))
    assert len(bundles) == 1

    only_source = bundles[0].source
    source_test_utils.assert_split_at_fraction_exhaustive(only_source)
def test_read_reantrant_with_splitting(self):
    """Check that re-entrant reads succeed on the single split of a file."""
    source = _create_avro_source(
        self._write_data(), use_fastavro=self.use_fastavro)

    # A desired bundle size of 100000 is expected to produce one split only.
    splits = list(source.split(desired_bundle_size=100000))
    assert len(splits) == 1

    only_split = splits[0]
    source_info = (
        only_split.source,
        only_split.start_position,
        only_split.stop_position,
    )
    source_test_utils.assert_reentrant_reads_succeed(source_info)
def _create_source(self, path, schema):
    """Build the source used to read exported data at ``path``.

    When ``self.use_json_exports`` is truthy, an uncompressed ``_TextSource``
    decoding each line with ``_JsonToDictCoder(schema)`` is returned;
    otherwise an Avro source (with fastavro enabled) is returned and
    ``schema`` is not used.
    """
    if self.use_json_exports:
        return _TextSource(
            path,
            min_bundle_size=0,
            compression_type=CompressionTypes.UNCOMPRESSED,
            strip_trailing_newlines=True,
            coder=_JsonToDictCoder(schema))

    return _create_avro_source(path, use_fastavro=True)
def test_dynamic_work_rebalancing_exhaustive(self):
    """Run the exhaustive split-at-fraction check on a tiny Avro file.

    ``avro.datafile.SYNC_INTERVAL`` is temporarily lowered so that the
    exhaustive dynamic work rebalancing test completes within an acceptable
    amount of time; the previous value is always restored afterwards.
    """
    saved_sync_interval = avro.datafile.SYNC_INTERVAL
    try:
        # Shrink the block size for the duration of this test only.
        avro.datafile.SYNC_INTERVAL = 2
        file_name = self._write_data(count=5)
        source = _create_avro_source(
            file_name, use_fastavro=self.use_fastavro)

        bundles = list(source.split(desired_bundle_size=float('inf')))
        assert len(bundles) == 1

        only_source = bundles[0].source
        source_test_utils.assert_split_at_fraction_exhaustive(only_source)
    finally:
        avro.datafile.SYNC_INTERVAL = saved_sync_interval
def test_source_display_data(self):
    """Check the display data items exposed by an Avro source."""
    file_name = 'some_avro_source'
    # validate=False because the file name does not refer to a written file.
    source = _create_avro_source(
        file_name, validate=False, use_fastavro=self.use_fastavro)
    display_data = DisplayData.create_from(source)

    # No extra avro parameters for AvroSource.
    expected_items = [
        DisplayDataItemMatcher('compression', 'auto'),
        DisplayDataItemMatcher('file_pattern', file_name),
    ]
    matcher = hc.contains_inanyorder(*expected_items)
    hc.assert_that(display_data.items, matcher)
def test_corrupted_file(self):
    """Reading a file whose final sync marker is corrupted raises ValueError.

    A copy of a freshly written file is made with its last byte altered, and
    reading that copy is expected to fail with a message matching
    ``expected sync marker``.
    """
    # Imported locally so this block does not depend on a module-level import.
    import os

    file_name = self._write_data()
    with open(file_name, 'rb') as f:
        data = f.read()

    # Corrupt the last character of the file which is also the last character
    # of the last sync_marker.
    # https://avro.apache.org/docs/current/spec.html#Object+Container+Files
    corrupted_data = bytearray(data)
    corrupted_data[-1] = (corrupted_data[-1] + 1) % 256

    with tempfile.NamedTemporaryFile(
            delete=False, prefix=tempfile.template) as f:
        f.write(corrupted_data)
        corrupted_file_name = f.name
    # The file is created with delete=False, so remove it explicitly once the
    # test finishes (whether it passes or fails).
    self.addCleanup(os.remove, corrupted_file_name)

    source = _create_avro_source(corrupted_file_name)
    with self.assertRaisesRegex(ValueError, r'expected sync marker'):
        source_test_utils.read_from_source(source, None, None)
def test_corrupted_file(self):
    """Reading a file whose final sync marker is corrupted raises ValueError.

    A copy of a freshly written file is made with its last byte replaced, and
    reading that copy is expected to fail with a message starting with
    ``Unexpected sync marker``.
    """
    # Imported locally so this block does not depend on a module-level import.
    import os

    file_name = self._write_data()
    with open(file_name, 'rb') as f:
        data = f.read()

    # Corrupt the last character of the file which is also the last character
    # of the last sync_marker. A bytearray is used so the replacement works on
    # binary data under both Python 2 and Python 3 (concatenating a text
    # literal to bytes raises TypeError on Python 3).
    corrupted_data = bytearray(data)
    corrupted_data[-1] = (
        ord('A') if corrupted_data[-1] == ord('B') else ord('B'))

    with tempfile.NamedTemporaryFile(
            delete=False, prefix=tempfile.template) as f:
        f.write(corrupted_data)
        corrupted_file_name = f.name
    # The file is created with delete=False, so remove it explicitly once the
    # test finishes (whether it passes or fails).
    self.addCleanup(os.remove, corrupted_file_name)

    source = _create_avro_source(
        corrupted_file_name, use_fastavro=self.use_fastavro)
    with self.assertRaises(ValueError) as exn:
        source_test_utils.read_from_source(source, None, None)
    # str(exception) is used because exceptions have no `.message` attribute
    # on Python 3. The check runs after the `with` block so it is executed.
    self.assertEqual(0, str(exn.exception).find('Unexpected sync marker'))
def test_corrupted_file(self):
    """Reading a file whose final sync marker is corrupted raises ValueError.

    A copy of a freshly written file is made with its last byte replaced, and
    reading that copy is expected to fail with a message starting with
    ``Unexpected sync marker``.
    """
    # Imported locally so this block does not depend on a module-level import.
    import os

    file_name = self._write_data()
    with open(file_name, 'rb') as f:
        data = f.read()

    # Corrupt the last character of the file which is also the last character
    # of the last sync_marker. The comparison is done on integer byte values:
    # on Python 3 indexing bytes yields an int, so comparing it with b'B' is
    # always False and a file ending in 'B' would be left uncorrupted.
    corrupted_data = bytearray(data)
    corrupted_data[-1] = (
        ord('A') if corrupted_data[-1] == ord('B') else ord('B'))

    with tempfile.NamedTemporaryFile(
            delete=False, prefix=tempfile.template) as f:
        f.write(corrupted_data)
        corrupted_file_name = f.name
    # The file is created with delete=False, so remove it explicitly once the
    # test finishes (whether it passes or fails).
    self.addCleanup(os.remove, corrupted_file_name)

    source = _create_avro_source(
        corrupted_file_name, use_fastavro=self.use_fastavro)
    with self.assertRaises(ValueError) as exn:
        source_test_utils.read_from_source(source, None, None)
    # str(exception) is used because exceptions have no `.message` attribute
    # on Python 3. The check runs after the `with` block so it is executed.
    self.assertEqual(0, str(exn.exception).find('Unexpected sync marker'))
def _run_avro_test(
        self, pattern, desired_bundle_size, perform_splitting,
        expected_result):
    """Read ``pattern`` as Avro and validate the outcome.

    Without splitting, the records read from the source are compared (order
    insensitive) with ``expected_result``. With splitting, the source is split
    using ``desired_bundle_size`` and the splits are checked against the
    unsplit source as reference.

    Raises:
        ValueError: if splitting was requested but fewer than two splits were
            generated.
    """
    source = _create_avro_source(pattern)

    if not perform_splitting:
        records = source_test_utils.read_from_source(source, None, None)
        self.assertCountEqual(expected_result, records)
        return

    assert desired_bundle_size
    bundles = list(source.split(desired_bundle_size=desired_bundle_size))
    if len(bundles) < 2:
        raise ValueError(
            'Test is trivial. Please adjust it so that at least '
            'two splits get generated')

    # (source, start, stop) triples describing every generated split.
    sources_info = []
    for bundle in bundles:
        sources_info.append(
            (bundle.source, bundle.start_position, bundle.stop_position))

    reference = (source, None, None)
    source_test_utils.assert_sources_equal_reference_source(
        reference, sources_info)
def _run_avro_test(
        self, pattern, desired_bundle_size, perform_splitting,
        expected_result):
    """Read ``pattern`` as Avro and validate the outcome.

    Without splitting, the records read from the source are compared (order
    insensitive) with ``expected_result``. With splitting, the source is split
    using ``desired_bundle_size`` and the splits are checked against the
    unsplit source as reference.

    Raises:
        ValueError: if splitting was requested but fewer than two splits were
            generated.
    """
    source = _create_avro_source(pattern, use_fastavro=self.use_fastavro)

    if perform_splitting:
        assert desired_bundle_size
        splits = list(source.split(desired_bundle_size=desired_bundle_size))
        if len(splits) < 2:
            raise ValueError(
                'Test is trivial. Please adjust it so that at least '
                'two splits get generated')

        # (source, start, stop) triples describing every generated split.
        sources_info = [
            (split.source, split.start_position, split.stop_position)
            for split in splits
        ]
        source_test_utils.assert_sources_equal_reference_source(
            (source, None, None), sources_info)
    else:
        read_records = source_test_utils.read_from_source(source, None, None)
        # assertItemsEqual only exists on Python 2; Python 3's unittest calls
        # the same check assertCountEqual. Prefer the new name and fall back
        # to the old one so the test runs on either.
        assert_same_items = (
            getattr(self, 'assertCountEqual', None) or self.assertItemsEqual)
        assert_same_items(expected_result, read_records)
def test_read_reentrant_without_splitting(self):
    """Check that re-entrant reads succeed on the unsplit source."""
    source = _create_avro_source(
        self._write_data(), use_fastavro=self.use_fastavro)

    # No start/stop positions: the whole source is read.
    whole_source = (source, None, None)
    source_test_utils.assert_reentrant_reads_succeed(whole_source)