def test_write_read_metadata(self, file_spec, random_access, parallelism):
    """Checks that RecordsMetadata survives a Riegeli write/read round trip.

    The written metadata carries a file comment and the record type of
    ``records_test_pb2.SimpleMessage``; one sample message is stored alongside
    it. On read-back, the metadata must compare equal, the record type must be
    recoverable from it, and the message parsed with that recovered type must
    equal the one that was written.
    """
    with contextlib.closing(
            file_spec(self.create_tempfile, random_access)) as files:
        original_metadata = riegeli.RecordsMetadata()
        original_metadata.file_comment = 'Comment'
        riegeli.set_record_type(original_metadata,
                                records_test_pb2.SimpleMessage)
        original_message = sample_message(7, 10)

        with riegeli.RecordWriter(
                files.writing_open(),
                owns_dest=files.writing_should_close,
                assumed_pos=files.writing_assumed_pos,
                options=record_writer_options(parallelism),
                metadata=original_metadata,
        ) as writer:
            writer.write_message(original_message)

        with riegeli.RecordReader(
                files.reading_open(),
                owns_src=files.reading_should_close,
                assumed_pos=files.reading_assumed_pos,
        ) as reader:
            recovered_metadata = reader.read_metadata()
            self.assertEqual(recovered_metadata, original_metadata)

            message_type = riegeli.get_record_type(recovered_metadata)
            # Bare assert (rather than assertIsNotNone) presumably exists to
            # narrow the Optional result for static type checkers.
            assert message_type is not None
            self.assertEqual(message_type.DESCRIPTOR.full_name,
                             'riegeli.tests.SimpleMessage')

            recovered_message = reader.read_message(message_type)
            assert recovered_message is not None
            # The recovered message and the compiled SimpleMessage have
            # descriptors of different origins. Re-parse the bytes into the
            # compiled class before comparing.
            reparsed_message = records_test_pb2.SimpleMessage.FromString(
                recovered_message.SerializeToString())
            self.assertEqual(reparsed_message, original_message)
def write_records(filename):
    """Writes 100 ``sample_message(i, 100)`` records to ``filename``.

    The records use the 'transpose' writer option. The file metadata is
    tagged with the ``records_test_pb2.SimpleMessage`` record type.

    Args:
      filename: Path of the file to create (opened in 'wb' mode, so any
        existing file is truncated).
    """
    print('Writing', filename)
    records_metadata = riegeli.RecordsMetadata()
    riegeli.set_record_type(records_metadata, records_test_pb2.SimpleMessage)
    # Lazy generator: each message is built only as the writer consumes it.
    messages = (sample_message(index, 100) for index in range(100))
    with riegeli.RecordWriter(
            io.FileIO(filename, mode='wb'),
            options='transpose',
            metadata=records_metadata,
    ) as writer:
        writer.write_messages(messages)
def test_reads_without_traces_or_triple_raises(self):
    """get_next_batch must raise ValueError on these malformed exposures.

    The exposures come from ``_get_exposures_missing_required_data()``. They
    are written to a temporary Riegeli file (the 'transpose' option, empty
    metadata) and then loaded with ``unconfirmed_exposures=False``.
    """
    path = _get_tmp_file_name()
    sink = io.FileIO(path, mode='wb')
    with riegeli.RecordWriter(
            sink,
            options='transpose',
            metadata=riegeli.RecordsMetadata(),
    ) as writer:
        writer.write_messages(_get_exposures_missing_required_data())

    loader = abesim_data_loader.AbesimExposureDataLoader(
        path, unconfirmed_exposures=False)
    with self.assertRaises(ValueError):
        loader.get_next_batch(batch_size=1)
def test_reads_without_traces(self):
    """Exposures without proximity traces still load into a batch.

    The exposures come from ``_get_exposures_without_proximity_traces()``.
    They are written to a temporary Riegeli file and read back with
    ``batch_size=1``. The test checks the exposures, labels and grouping
    returned by the loader.
    """
    path = _get_tmp_file_name()
    sink = io.FileIO(path, mode='wb')
    with riegeli.RecordWriter(
            sink,
            options='transpose',
            metadata=riegeli.RecordsMetadata(),
    ) as writer:
        writer.write_messages(_get_exposures_without_proximity_traces())

    loader = abesim_data_loader.AbesimExposureDataLoader(
        path, unconfirmed_exposures=False)
    batch_exposures, batch_labels, batch_grouping = loader.get_next_batch(
        batch_size=1)

    self.assertCountEqual(batch_exposures, [([1.0], 0, 30)])
    self.assertCountEqual(batch_labels, [1])
    self.assertCountEqual(batch_grouping, [1])