def test_corruption_recovery(self, file_spec, random_access, parallelism):
    """Checks that a recovery callback skips a corrupted chunk.

    Twenty-three records are written, the header of the chunk holding records
    [9, 12) is damaged, and the file is read back with a recovery callback
    that collects the skipped regions.
    """
    with contextlib.closing(file_spec(self.create_tempfile,
                                      random_access)) as files:
        keys = []
        with riegeli.RecordWriter(
                files.writing_open(),
                close=files.writing_should_close,
                assumed_pos=files.writing_assumed_pos,
                options=record_writer_options(parallelism)) as writer:
            for index in range(23):
                keys.append(
                    writer.write_record_with_key(sample_string(index, 10000)))
        # Damage the header of the chunk that holds records [9, 12).
        self.corrupt_at(files, keys[9].chunk_begin + 20)
        # Records [0, 9) and [15, 23) must still be readable. Only the damaged
        # chunk and the chunk after it (which intersects the same block) are
        # lost.
        skipped_regions = []
        with riegeli.RecordReader(
                files.reading_open(),
                close=files.reading_should_close,
                assumed_pos=files.reading_assumed_pos,
                recovery=skipped_regions.append) as reader:
            for index in (*range(9), *range(15, 23)):
                self.assertEqual(reader.read_record(),
                                 sample_string(index, 10000))
            self.assertIsNone(reader.read_record())
        # Exactly one gap is reported, spanning records [9, 15).
        self.assertLen(skipped_regions, 1)
        (skipped_region,) = skipped_regions
        self.assertEqual(skipped_region.begin, keys[9].numeric)
        self.assertEqual(skipped_region.end, keys[15].numeric)
def test_corruption_recovery_exception(self, file_spec, random_access,
                                       parallelism):
    """Checks that an exception raised by the recovery callback propagates.

    The chunk holding records [9..12) is corrupted. The recovery callback
    raises KeyboardInterrupt, and read_record() must surface that exception.
    """
    with contextlib.closing(file_spec(self.create_tempfile,
                                      random_access)) as files:
        positions = []
        with riegeli.RecordWriter(
                files.writing_open(),
                owns_dest=files.writing_should_close,
                assumed_pos=files.writing_assumed_pos,
                options=record_writer_options(parallelism)) as writer:
            for index in range(23):
                writer.write_record(sample_string(index, 10000))
                positions.append(writer.last_pos)
        # Damage the header of the chunk that holds records [9..12).
        self.corrupt_at(files, positions[9].chunk_begin + 20)

        def interrupting_recovery(skipped_region):
            # Whatever the callback raises should reach the caller of
            # read_record().
            del skipped_region  # Unused.
            raise KeyboardInterrupt

        with riegeli.RecordReader(
                files.reading_open(),
                owns_src=files.reading_should_close,
                assumed_pos=files.reading_assumed_pos,
                recovery=interrupting_recovery) as reader:
            for index in range(9):
                self.assertEqual(reader.read_record(),
                                 sample_string(index, 10000))
            with self.assertRaises(KeyboardInterrupt):
                reader.read_record()
def test_write_read_record(self, file_spec, random_access, parallelism):
    """Round-trips 23 records and checks the pos/last_pos invariants.

    Writer side: the position sampled before each write is strictly after the
    previous record's last_pos and no later than the new record's last_pos.
    Reader side: last_pos matches what the writer reported, and the reader
    ends at the writer's final position both before and after close().
    """
    with contextlib.closing(file_spec(self.create_tempfile,
                                      random_access)) as files:
        positions = []
        with riegeli.RecordWriter(
                files.writing_open(),
                owns_dest=files.writing_should_close,
                assumed_pos=files.writing_assumed_pos,
                options=record_writer_options(parallelism)) as writer:
            for index in range(23):
                before = writer.pos
                writer.write_record(sample_string(index, 10000))
                written_at = writer.last_pos
                if positions:
                    self.assertGreater(before, positions[-1])
                self.assertLessEqual(before, written_at)
                positions.append(written_at)
            writer.close()
            end_pos = writer.pos
        with riegeli.RecordReader(
                files.reading_open(),
                owns_src=files.reading_should_close,
                assumed_pos=files.reading_assumed_pos) as reader:
            for index, expected_pos in enumerate(positions):
                before = reader.pos
                self.assertEqual(reader.read_record(),
                                 sample_string(index, 10000))
                read_at = reader.last_pos
                self.assertEqual(read_at, expected_pos)
                self.assertLessEqual(before, read_at)
            self.assertIsNone(reader.read_record())
            self.assertEqual(reader.pos, end_pos)
            reader.close()
            self.assertEqual(reader.pos, end_pos)
def test_corruption_exception(self, file_spec, random_access, parallelism):
    """Checks that a corrupted chunk raises RiegeliError when there is no recovery.

    Records before the damaged chunk read back normally. The next
    read_record() and the following close() both raise riegeli.RiegeliError.
    """
    with contextlib.closing(file_spec(self.create_tempfile,
                                      random_access)) as files:
        positions = []
        with riegeli.RecordWriter(
                files.writing_open(),
                owns_dest=files.writing_should_close,
                assumed_pos=files.writing_assumed_pos,
                options=record_writer_options(parallelism)) as writer:
            for index in range(23):
                writer.write_record(sample_string(index, 10000))
                positions.append(writer.last_pos)
        # Damage the header of the chunk that holds records [9..12).
        self.corrupt_at(files, positions[9].chunk_begin + 20)
        # Records [0..9) precede the damaged chunk and must read back intact.
        # The reader is closed explicitly below (not via `with`) so that the
        # error raised by close() can be asserted.
        reader = riegeli.RecordReader(
            files.reading_open(),
            owns_src=files.reading_should_close,
            assumed_pos=files.reading_assumed_pos)
        for index in range(9):
            self.assertEqual(reader.read_record(), sample_string(index, 10000))
        for failing_call in (reader.read_record, reader.close):
            with self.assertRaises(riegeli.RiegeliError):
                failing_call()
def test_field_projection_existence_only(self, file_spec, random_access,
                                         parallelism):
    """Checks a projection that keeps `id` and marks `payload` EXISTENCE_ONLY.

    Each message is expected to read back with its `id` and with `payload`
    equal to b''.
    """
    with contextlib.closing(file_spec(self.create_tempfile,
                                      random_access)) as files:
        with riegeli.RecordWriter(
                files.writing_open(),
                owns_dest=files.writing_should_close,
                assumed_pos=files.writing_assumed_pos,
                options=f'{record_writer_options(parallelism)},transpose'
        ) as writer:
            for index in range(23):
                writer.write_message(sample_message(index, 10000))
        fields = records_test_pb2.SimpleMessage.DESCRIPTOR.fields_by_name
        projection = [
            [fields['id'].number],
            [fields['payload'].number, riegeli.EXISTENCE_ONLY],
        ]
        with riegeli.RecordReader(
                files.reading_open(),
                owns_src=files.reading_should_close,
                assumed_pos=files.reading_assumed_pos,
                field_projection=projection) as reader:
            for index in range(23):
                self.assertEqual(
                    reader.read_message(records_test_pb2.SimpleMessage),
                    records_test_pb2.SimpleMessage(id=index, payload=b''))
            self.assertIsNone(
                reader.read_message(records_test_pb2.SimpleMessage))
def test_corruption_recovery_stop_iteration(self, file_spec, random_access,
                                            parallelism):
    """Checks that StopIteration from the recovery callback ends reading.

    The callback records the skipped region and raises StopIteration. After
    the readable records [0..9), read_record() must return None.
    """
    with contextlib.closing(file_spec(self.create_tempfile,
                                      random_access)) as files:
        positions = []
        with riegeli.RecordWriter(
                files.writing_open(),
                owns_dest=files.writing_should_close,
                assumed_pos=files.writing_assumed_pos,
                options=record_writer_options(parallelism)) as writer:
            for index in range(23):
                writer.write_record(sample_string(index, 10000))
                positions.append(writer.last_pos)
        # Damage the header of the chunk that holds records [9..12).
        self.corrupt_at(files, positions[9].chunk_begin + 20)
        # Records [0..9) come before the damaged chunk and read back fine.
        skipped_regions = []

        def record_then_stop(skipped_region):
            skipped_regions.append(skipped_region)
            raise StopIteration

        with riegeli.RecordReader(
                files.reading_open(),
                owns_src=files.reading_should_close,
                assumed_pos=files.reading_assumed_pos,
                recovery=record_then_stop) as reader:
            for index in range(9):
                self.assertEqual(reader.read_record(),
                                 sample_string(index, 10000))
            self.assertIsNone(reader.read_record())
        self.assertLen(skipped_regions, 1)
        (skipped_region,) = skipped_regions
        self.assertEqual(skipped_region.begin, positions[9].numeric)
        self.assertEqual(skipped_region.end, positions[15].numeric)
def test_search_for_record(self, file_spec):
    """Checks search_for_record() with a three-way id comparator.

    Searching for ids 7, 0 and 22 must land on those records' keys. Searching
    for 23, which is past every id, must land at the end of the file.
    """
    with contextlib.closing(
            file_spec(self.create_tempfile, random_access=True)) as files:
        with riegeli.RecordWriter(
                files.writing_open(),
                close=files.writing_should_close,
                assumed_pos=files.writing_assumed_pos,
                options=record_writer_options(parallelism=0)) as writer:
            keys = writer.write_messages_with_keys(
                sample_message(index, 10000) for index in range(23))
            writer.close()
            end_pos = writer.pos
        with riegeli.RecordReader(
                files.reading_open(),
                close=files.reading_should_close,
                assumed_pos=files.reading_assumed_pos) as reader:

            def compare_with(target):
                """Returns a callable giving the sign of (record id - target)."""
                def compare(record):
                    record_id = records_test_pb2.SimpleMessage.FromString(
                        record).id
                    return (record_id > target) - (record_id < target)
                return compare

            expectations = ((7, keys[7]), (0, keys[0]), (22, keys[22]),
                            (23, end_pos))
            for target, expected_pos in expectations:
                reader.search_for_record(compare_with(target))
                self.assertEqual(reader.pos, expected_pos)
def test_field_projection(self, file_spec, random_access, parallelism):
    """Checks that projecting onto `id` yields messages with only `id` set."""
    with contextlib.closing(file_spec(self.create_tempfile,
                                      random_access)) as files:
        with riegeli.RecordWriter(
                files.writing_open(),
                close=files.writing_should_close,
                assumed_pos=files.writing_assumed_pos,
                options=record_writer_options(parallelism) + ',transpose'
        ) as writer:
            for index in range(23):
                writer.write_message(sample_message(index, 10000))
        id_number = (records_test_pb2.SimpleMessage.DESCRIPTOR
                     .fields_by_name['id'].number)
        with riegeli.RecordReader(
                files.reading_open(),
                close=files.reading_should_close,
                assumed_pos=files.reading_assumed_pos,
                field_projection=[[id_number]]) as reader:
            for index in range(23):
                self.assertEqual(
                    reader.read_message(records_test_pb2.SimpleMessage),
                    records_test_pb2.SimpleMessage(id=index))
            self.assertIsNone(
                reader.read_message(records_test_pb2.SimpleMessage))
def test_write_read_messages_with_field_projection_later(
        self, file_spec, parallelism):
    """Checks that set_field_projection() can be changed while reading.

    Records [0, 4) are read in full and records [4, 14) with only `id`
    projected. Records [14, 23) are read in full again after the projection
    is reset to None.
    """
    with contextlib.closing(
            file_spec(self.create_tempfile, random_access=True)) as files:
        with riegeli.RecordWriter(
                files.writing_open(),
                close=files.writing_should_close,
                assumed_pos=files.writing_assumed_pos,
                options=record_writer_options(parallelism,
                                              transpose=True)) as writer:
            writer.write_messages(
                sample_message(index, 10000) for index in range(23))
        with riegeli.RecordReader(
                files.reading_open(),
                close=files.reading_should_close,
                assumed_pos=files.reading_assumed_pos) as reader:

            def expect_messages(indices, make_expected):
                """Reads one message per index and compares it with the expected one."""
                for index in indices:
                    self.assertEqual(
                        reader.read_message(records_test_pb2.SimpleMessage),
                        make_expected(index))

            expect_messages(range(4), lambda index: sample_message(index, 10000))
            reader.set_field_projection([[
                records_test_pb2.SimpleMessage.DESCRIPTOR.fields_by_name['id']
                .number
            ]])
            expect_messages(range(4, 14), sample_message_id_only)
            reader.set_field_projection(None)
            expect_messages(range(14, 23),
                            lambda index: sample_message(index, 10000))
            self.assertIsNone(
                reader.read_message(records_test_pb2.SimpleMessage))
def test_write_read_record_with_key(self, file_spec, random_access,
                                    parallelism):
    """Round-trips 23 records with keys and checks the position invariants.

    Writer side: the position sampled before each write is strictly after the
    previous key and no later than the new record's key. Reader side:
    read_record_with_key() returns the key the writer reported, the position
    sampled before each read is no later than that record's key, and the
    reader ends at the writer's final position both before and after close().
    """
    with contextlib.closing(file_spec(self.create_tempfile,
                                      random_access)) as files:
        keys = []
        with riegeli.RecordWriter(
                files.writing_open(),
                close=files.writing_should_close,
                assumed_pos=files.writing_assumed_pos,
                options=record_writer_options(parallelism)) as writer:
            for i in range(23):
                pos = writer.pos
                key = writer.write_record_with_key(sample_string(i, 10000))
                if keys:
                    self.assertGreater(pos, keys[-1])
                self.assertLessEqual(pos, key)
                keys.append(key)
            writer.close()
            end_pos = writer.pos
        with riegeli.RecordReader(
                files.reading_open(),
                close=files.reading_should_close,
                assumed_pos=files.reading_assumed_pos) as reader:
            for i in range(23):
                pos = reader.pos
                self.assertEqual(reader.read_record_with_key(),
                                 (keys[i], sample_string(i, 10000)))
                # Compare with this record's own key. The variable `key` left
                # over from the writing loop always holds the last record's
                # key, which would make this check meaningless.
                self.assertLessEqual(pos, keys[i])
            self.assertIsNone(reader.read_record_with_key())
            self.assertEqual(reader.pos, end_pos)
            reader.close()
            self.assertEqual(reader.pos, end_pos)
def test_record_reader_exception_from_file(self, random_access):
    """Checks that NotImplementedError from the source propagates out of the reader.

    The source is a FakeFile, which presumably raises NotImplementedError when
    read.
    """
    byte_reader = FakeFile(random_access)
    # With random access the position is left for the reader to determine.
    # Otherwise it is pinned to 0.
    assumed_pos = None if random_access else 0
    with self.assertRaises(NotImplementedError), riegeli.RecordReader(
            byte_reader, assumed_pos=assumed_pos, close=False) as reader:
        reader.read_record()
def test_write_read_metadata(self, file_spec, random_access, parallelism):
    """Round-trips RecordsMetadata (comment plus record type) and one message.

    The record type recovered from the metadata is used to parse the message,
    which is then compared with the original through its serialized form.
    """
    with contextlib.closing(file_spec(self.create_tempfile,
                                      random_access)) as files:
        metadata_written = riegeli.RecordsMetadata()
        metadata_written.file_comment = 'Comment'
        riegeli.set_record_type(metadata_written,
                                records_test_pb2.SimpleMessage)
        message_written = sample_message(7, 10)
        with riegeli.RecordWriter(
                files.writing_open(),
                owns_dest=files.writing_should_close,
                assumed_pos=files.writing_assumed_pos,
                options=record_writer_options(parallelism),
                metadata=metadata_written) as writer:
            writer.write_message(message_written)
        with riegeli.RecordReader(
                files.reading_open(),
                owns_src=files.reading_should_close,
                assumed_pos=files.reading_assumed_pos) as reader:
            metadata_read = reader.read_metadata()
            self.assertEqual(metadata_read, metadata_written)
            record_type = riegeli.get_record_type(metadata_read)
            assert record_type is not None  # Also narrows the type for checkers.
            self.assertEqual(record_type.DESCRIPTOR.full_name,
                             'riegeli.tests.SimpleMessage')
            message_read = reader.read_message(record_type)
            assert message_read is not None
            # The two messages have descriptors of different origins, so
            # compare them after re-parsing through the wire format.
            reparsed = records_test_pb2.SimpleMessage.FromString(
                message_read.SerializeToString())
            self.assertEqual(reparsed, message_written)
def __init__(self,
             file_path,
             unconfirmed_exposures=False,
             window_around_infection_onset_time=False,
             selection_window_left=10,
             selection_window_right=0,
             file_io=io.FileIO):
    """Initialize the Abesim exposure data loader.

    Opens `file_path` immediately and keeps a riegeli.RecordReader over it in
    `self.index_reader`.

    Args:
      file_path: The path of the Riegeli file that stores ExposureResult
        protos.
      unconfirmed_exposures: Whether to query unconfirmed exposures.
      window_around_infection_onset_time: Whether to center the exposure
        selection window on the infection onset time (True) or on the test
        administered time (False).
      selection_window_left: Days from the left selection bound to the center.
      selection_window_right: Days from the right selection bound to the
        center.
      file_io: A callable that constructs a file object for reading. It is
        called as `file_io(file_path, mode='rb')`.
    """
    self.file_path = file_path
    self.unconfirmed_exposures = unconfirmed_exposures
    self.window_around_infection_onset_time = window_around_infection_onset_time
    self.selection_window_left = selection_window_left
    self.selection_window_right = selection_window_right
    self.file_io = file_io
    self.index_reader = riegeli.RecordReader(
        self.file_io(file_path, mode='rb'))
def test_seek_numeric(self, file_spec, parallelism):
    """Checks that seek_numeric() lands between neighbouring record keys.

    Seeking to keys[k].numeric must leave the reader strictly after
    keys[k - 1] and no later than keys[k]. A read right after such a seek
    returns record k. Seeking to 0 and to the end of the file is also covered,
    and the position reported after close() is still that of the last seek.
    """
    with contextlib.closing(
            file_spec(self.create_tempfile, random_access=True)) as files:
        keys = []
        with riegeli.RecordWriter(
                files.writing_open(),
                close=files.writing_should_close,
                assumed_pos=files.writing_assumed_pos,
                options=record_writer_options(parallelism)) as writer:
            for index in range(23):
                before = writer.pos
                key = writer.write_record_with_key(sample_string(index, 10000))
                if keys:
                    self.assertGreater(before, keys[-1])
                self.assertLessEqual(before, key)
                keys.append(key)
            writer.close()
            end_pos = writer.pos
        with riegeli.RecordReader(
                files.reading_open(),
                close=files.reading_should_close,
                assumed_pos=files.reading_assumed_pos) as reader:

            def assert_at_record(index):
                # Strictly after record index-1, not after record index.
                self.assertGreater(reader.pos, keys[index - 1])
                self.assertLessEqual(reader.pos, keys[index])

            def seek_to_record(index):
                reader.seek_numeric(keys[index].numeric)
                assert_at_record(index)

            # Repeating the same seek must be harmless.
            seek_to_record(9)
            seek_to_record(9)
            seek_to_record(11)
            self.assertEqual(reader.read_record(), sample_string(11, 10000))
            seek_to_record(9)
            self.assertEqual(reader.read_record(), sample_string(9, 10000))
            for index in (11, 13):
                seek_to_record(index)
                self.assertEqual(reader.read_record(),
                                 sample_string(index, 10000))
            reader.seek_numeric(0)
            self.assertLessEqual(reader.pos, keys[0])
            self.assertEqual(reader.read_record(), sample_string(0, 10000))
            reader.seek_numeric(end_pos.numeric)
            self.assertLessEqual(reader.pos, end_pos)
            self.assertIsNone(reader.read_record())
            seek_to_record(11)
            reader.close()
            assert_at_record(11)
def test_record_reader_exception_from_file(self, random_access):
    """Checks that NotImplementedError from the source propagates out of the reader.

    The source is a FakeFile, which presumably raises NotImplementedError when
    read. It is not owned by the reader.
    """
    byte_reader = FakeFile(random_access is RandomAccess.RANDOM_ACCESS)
    # Only the explicit sequential-access mode pins the starting position.
    if random_access is RandomAccess.SEQUENTIAL_ACCESS_EXPLICIT:
        assumed_pos = 0
    else:
        assumed_pos = None
    with self.assertRaises(NotImplementedError), riegeli.RecordReader(
            byte_reader, owns_src=False, assumed_pos=assumed_pos) as reader:
        reader.read_record()
def read_records(filename):
    """Prints the ids of all SimpleMessage records stored in `filename`.

    A field projection restricts decoding to the `id` field.
    """
    print('Reading', filename)
    id_field = records_test_pb2.SimpleMessage.DESCRIPTOR.fields_by_name['id']
    with riegeli.RecordReader(io.FileIO(filename, mode='rb'),
                              field_projection=[[id_field.number]]) as reader:
        messages = reader.read_messages(records_test_pb2.SimpleMessage)
        print(' '.join(str(message.id) for message in messages))
def test_write_read_records(self, file_spec, random_access, parallelism):
    """Round-trips 23 records through write_records()/read_records()."""
    with contextlib.closing(file_spec(self.create_tempfile,
                                      random_access)) as files:
        with riegeli.RecordWriter(
                files.writing_open(),
                owns_dest=files.writing_should_close,
                assumed_pos=files.writing_assumed_pos,
                options=record_writer_options(parallelism)) as writer:
            writer.write_records(
                sample_string(index, 10000) for index in range(23))
        with riegeli.RecordReader(
                files.reading_open(),
                owns_src=files.reading_should_close,
                assumed_pos=files.reading_assumed_pos) as reader:
            records_read = list(reader.read_records())
            expected = [sample_string(index, 10000) for index in range(23)]
            self.assertEqual(records_read, expected)
def test_search_for_message(self, file_spec):
    """Checks search_for_message() results and resulting reader positions.

    Searching for ids 7, 0 and 22 returns 0 and lands on those records.
    Searching for 23, which is past every id, returns -1 and lands at the end
    of the file.
    """
    with contextlib.closing(
            file_spec(self.create_tempfile,
                      random_access=RandomAccess.RANDOM_ACCESS)) as files:
        with riegeli.RecordWriter(
                files.writing_open(),
                owns_dest=files.writing_should_close,
                assumed_pos=files.writing_assumed_pos,
                options=record_writer_options(parallelism=0)) as writer:
            positions = []
            for index in range(23):
                writer.write_message(sample_message(index, 10000))
                positions.append(writer.last_pos)
            writer.close()
            end_pos = writer.pos
        with riegeli.RecordReader(
                files.reading_open(),
                owns_src=files.reading_should_close,
                assumed_pos=files.reading_assumed_pos) as reader:

            def compare_with(target):
                """Returns a callable giving the sign of (message.id - target)."""
                def compare(message):
                    return (message.id > target) - (message.id < target)
                return compare

            cases = (
                (7, 0, positions[7]),
                (0, 0, positions[0]),
                (22, 0, positions[22]),
                (23, -1, end_pos),
            )
            for target, expected_result, expected_pos in cases:
                self.assertEqual(
                    reader.search_for_message(records_test_pb2.SimpleMessage,
                                              compare_with(target)),
                    expected_result)
                self.assertEqual(reader.pos, expected_pos)
def test_write_read_messages_with_keys(self, file_spec, random_access,
                                       parallelism):
    """Checks that read_messages_with_keys() yields the keys the writer returned."""
    with contextlib.closing(file_spec(self.create_tempfile,
                                      random_access)) as files:
        with riegeli.RecordWriter(
                files.writing_open(),
                close=files.writing_should_close,
                assumed_pos=files.writing_assumed_pos,
                options=record_writer_options(parallelism)) as writer:
            keys = writer.write_messages_with_keys(
                sample_message(index, 10000) for index in range(23))
        with riegeli.RecordReader(
                files.reading_open(),
                close=files.reading_should_close,
                assumed_pos=files.reading_assumed_pos) as reader:
            pairs_read = list(
                reader.read_messages_with_keys(records_test_pb2.SimpleMessage))
            expected_pairs = [(keys[index], sample_message(index, 10000))
                              for index in range(23)]
            self.assertEqual(pairs_read, expected_pairs)
def test_seek_back(self, file_spec):
    """Checks reading the file backwards with seek_back().

    Starting from the end of the file, records are visited from last to
    first. For each record, seek_back() succeeds before reading it and again
    after reading it. Once the first record is reached, seek_back() returns
    False.
    """
    with contextlib.closing(
            file_spec(self.create_tempfile, random_access=True)) as files:
        with riegeli.RecordWriter(
                files.writing_open(),
                close=files.writing_should_close,
                assumed_pos=files.writing_assumed_pos,
                options=record_writer_options(parallelism=0)) as writer:
            for index in range(23):
                writer.write_record(sample_string(index, 10000))
        with riegeli.RecordReader(
                files.reading_open(),
                close=files.reading_should_close,
                assumed_pos=files.reading_assumed_pos) as reader:
            reader.seek_numeric(reader.size())
            for index in range(22, -1, -1):
                self.assertTrue(reader.seek_back())
                self.assertEqual(reader.read_record(),
                                 sample_string(index, 10000))
                # Step back over the record just read.
                self.assertTrue(reader.seek_back())
            self.assertFalse(reader.seek_back())
def test_write_read_messages_with_field_projection(self, file_spec,
                                                   random_access, parallelism):
    """Checks that read_messages() with an `id`-only projection drops other fields."""
    with contextlib.closing(file_spec(self.create_tempfile,
                                      random_access)) as files:
        with riegeli.RecordWriter(
                files.writing_open(),
                owns_dest=files.writing_should_close,
                assumed_pos=files.writing_assumed_pos,
                options=record_writer_options(parallelism,
                                              transpose=True)) as writer:
            writer.write_messages(
                sample_message(index, 10000) for index in range(23))
        id_number = (records_test_pb2.SimpleMessage.DESCRIPTOR
                     .fields_by_name['id'].number)
        with riegeli.RecordReader(
                files.reading_open(),
                owns_src=files.reading_should_close,
                assumed_pos=files.reading_assumed_pos,
                field_projection=[[id_number]]) as reader:
            messages_read = list(
                reader.read_messages(records_test_pb2.SimpleMessage))
            self.assertEqual(
                messages_read,
                [sample_message_id_only(index) for index in range(23)])
def reset_file(self):
    """Restarts reading from the beginning of the Riegeli file.

    Closes the current `index_reader`, then opens `file_path` again through
    `file_io` and wraps it in a new riegeli.RecordReader.
    """
    self.index_reader.close()
    reopened = self.file_io(self.file_path, mode='rb')
    self.index_reader = riegeli.RecordReader(reopened)
def test_seek_numeric(self, file_spec, parallelism):
    """Checks that seek_numeric() lands between neighbouring records.

    Seeking to positions[k].numeric must leave the reader strictly after
    positions[k - 1] and no later than positions[k]. A read right after such a
    seek returns record k. Seeking to 0 and to the end of the file is also
    covered, and the position reported after close() is still that of the
    last seek.
    """
    with contextlib.closing(
            file_spec(self.create_tempfile,
                      random_access=RandomAccess.RANDOM_ACCESS)) as files:
        positions = []
        with riegeli.RecordWriter(
                files.writing_open(),
                owns_dest=files.writing_should_close,
                assumed_pos=files.writing_assumed_pos,
                options=record_writer_options(parallelism)) as writer:
            for index in range(23):
                before = writer.pos
                writer.write_record(sample_string(index, 10000))
                written_at = writer.last_pos
                if positions:
                    self.assertGreater(before, positions[-1])
                self.assertLessEqual(before, written_at)
                positions.append(written_at)
            writer.close()
            end_pos = writer.pos
        with riegeli.RecordReader(
                files.reading_open(),
                owns_src=files.reading_should_close,
                assumed_pos=files.reading_assumed_pos) as reader:

            def assert_at_record(index):
                # Strictly after record index-1, not after record index.
                self.assertGreater(reader.pos, positions[index - 1])
                self.assertLessEqual(reader.pos, positions[index])

            def seek_to_record(index):
                reader.seek_numeric(positions[index].numeric)
                assert_at_record(index)

            # Repeating the same seek must be harmless.
            seek_to_record(9)
            seek_to_record(9)
            seek_to_record(11)
            self.assertEqual(reader.read_record(), sample_string(11, 10000))
            seek_to_record(9)
            self.assertEqual(reader.read_record(), sample_string(9, 10000))
            for index in (11, 13):
                seek_to_record(index)
                self.assertEqual(reader.read_record(),
                                 sample_string(index, 10000))
            reader.seek_numeric(0)
            self.assertLessEqual(reader.pos, positions[0])
            self.assertEqual(reader.read_record(), sample_string(0, 10000))
            reader.seek_numeric(end_pos.numeric)
            self.assertLessEqual(reader.pos, end_pos)
            self.assertIsNone(reader.read_record())
            seek_to_record(11)
            reader.close()
            assert_at_record(11)