Exemple #1
0
    def test_search_for_record(self, file_spec):
        """Exercises `RecordReader.search_for_record` on a file of 23 messages.

        Messages with ids 0..22 are written and their keys recorded. The reader
        then searches with a three-way comparison callback, and after each
        search its position is compared with the key of the record carrying the
        target id, or with the end position for an id that was never written.
        """
        record_count = 23
        with contextlib.closing(
                file_spec(self.create_tempfile, random_access=True)) as files:
            with riegeli.RecordWriter(
                    files.writing_open(),
                    close=files.writing_should_close,
                    assumed_pos=files.writing_assumed_pos,
                    options=record_writer_options(parallelism=0)) as writer:
                messages = (sample_message(index, 10000)
                            for index in range(record_count))
                keys = writer.write_messages_with_keys(messages)
                # Close first, then sample the position used as the end.
                writer.close()
                end_pos = writer.pos
            with riegeli.RecordReader(
                    files.reading_open(),
                    close=files.reading_should_close,
                    assumed_pos=files.reading_assumed_pos) as reader:

                def compare_with(search_target):
                    """Builds a callback ordering a record against the target."""

                    def compare(record):
                        message = records_test_pb2.SimpleMessage.FromString(
                            record)
                        if message.id < search_target:
                            return -1
                        if message.id > search_target:
                            return 1
                        return 0

                    return compare

                # `record_count` is one past the largest id that was written.
                for target in (7, 0, 22, record_count):
                    reader.search_for_record(compare_with(target))
                    if target == record_count:
                        expected_pos = end_pos
                    else:
                        expected_pos = keys[target]
                    self.assertEqual(reader.pos, expected_pos)
def main():
    """Converts selected directives of a ledger file to riegeli records.

    Parses the command line for the ledger filename, loads the ledger, and
    writes the converted form of every Transaction, Open and Close entry to
    '/tmp/ledger.rieg'. Entries of any other type are skipped.
    """
    logging.basicConfig(level=logging.INFO,
                        format='%(levelname)-8s: %(message)s')
    parser = argparse.ArgumentParser(description=__doc__.strip())
    parser.add_argument('filename', help='Ledger filename')
    args = parser.parse_args()

    # Load before opening the output so that a failed load does not leave a
    # truncated output file behind. Errors and options are not used here.
    entries, _errors, _options_map = loader.load_file(args.filename)

    with open('/tmp/ledger.rieg', 'wb') as outfile:
        with riegeli.RecordWriter(outfile) as writer:
            for entry in entries:
                if isinstance(entry, data.Transaction):
                    pbent = convert_Transaction(entry)
                elif isinstance(entry, data.Open):
                    pbent = convert_Open(entry)
                elif isinstance(entry, data.Close):
                    pbent = convert_Close(entry)
                else:
                    # Directive type without a converter: not exported.
                    pbent = None

                if pbent is not None:
                    writer.write_message(pbent)
Exemple #3
0
 def test_write_read_record(self, file_spec, random_access, parallelism):
     """Round-trips 23 records one at a time, checking positions throughout.

     While writing, the position sampled before each record must be greater
     than the previous record's canonical position and must not exceed the
     new record's canonical position. While reading, each record's canonical
     position must equal the one recorded at write time.
     """
     record_count = 23
     with contextlib.closing(file_spec(self.create_tempfile,
                                       random_access)) as files:
         positions = []
         with riegeli.RecordWriter(
                 files.writing_open(),
                 owns_dest=files.writing_should_close,
                 assumed_pos=files.writing_assumed_pos,
                 options=record_writer_options(parallelism)) as writer:
             for index in range(record_count):
                 pos_before = writer.pos
                 writer.write_record(sample_string(index, 10000))
                 record_pos = writer.last_pos
                 if positions:
                     self.assertGreater(pos_before, positions[-1])
                 self.assertLessEqual(pos_before, record_pos)
                 positions.append(record_pos)
             writer.close()
             end_pos = writer.pos
         with riegeli.RecordReader(
                 files.reading_open(),
                 owns_src=files.reading_should_close,
                 assumed_pos=files.reading_assumed_pos) as reader:
             for index, written_pos in enumerate(positions):
                 pos_before = reader.pos
                 self.assertEqual(reader.read_record(),
                                  sample_string(index, 10000))
                 record_pos = reader.last_pos
                 self.assertEqual(record_pos, written_pos)
                 self.assertLessEqual(pos_before, record_pos)
             # No more records: the reader yields None and sits at the end.
             self.assertIsNone(reader.read_record())
             self.assertEqual(reader.pos, end_pos)
             # The reported position must be unchanged after closing.
             reader.close()
             self.assertEqual(reader.pos, end_pos)
Exemple #4
0
 def test_field_projection_existence_only(self, file_spec, random_access,
                                          parallelism):
     """Reads with a projection keeping `id` and only existence of `payload`.

     23 messages are written with the 'transpose' option appended to the
     writer options. Each message read back must equal a `SimpleMessage`
     with the original id and an empty payload.
     """
     record_count = 23
     fields = records_test_pb2.SimpleMessage.DESCRIPTOR.fields_by_name
     projection = [
         [fields['id'].number],
         [fields['payload'].number, riegeli.EXISTENCE_ONLY],
     ]
     with contextlib.closing(file_spec(self.create_tempfile,
                                       random_access)) as files:
         with riegeli.RecordWriter(
                 files.writing_open(),
                 owns_dest=files.writing_should_close,
                 assumed_pos=files.writing_assumed_pos,
                 options=f'{record_writer_options(parallelism)},transpose'
         ) as writer:
             for index in range(record_count):
                 writer.write_message(sample_message(index, 10000))
         with riegeli.RecordReader(
                 files.reading_open(),
                 owns_src=files.reading_should_close,
                 assumed_pos=files.reading_assumed_pos,
                 field_projection=projection) as reader:
             for index in range(record_count):
                 actual = reader.read_message(records_test_pb2.SimpleMessage)
                 self.assertEqual(
                     actual,
                     records_test_pb2.SimpleMessage(id=index, payload=b''))
             # Past the last record the reader yields None.
             self.assertIsNone(
                 reader.read_message(records_test_pb2.SimpleMessage))
Exemple #5
0
 def test_field_projection(self, file_spec, random_access, parallelism):
     """Reads with a projection that keeps only the `id` field.

     23 messages are written with ',transpose' appended to the writer
     options. Each message read back must equal a `SimpleMessage` that has
     only its id set.
     """
     record_count = 23
     id_field = records_test_pb2.SimpleMessage.DESCRIPTOR.fields_by_name['id']
     projection = [[id_field.number]]
     with contextlib.closing(file_spec(self.create_tempfile,
                                       random_access)) as files:
         with riegeli.RecordWriter(
                 files.writing_open(),
                 close=files.writing_should_close,
                 assumed_pos=files.writing_assumed_pos,
                 options=record_writer_options(parallelism) +
                 ',transpose') as writer:
             for index in range(record_count):
                 writer.write_message(sample_message(index, 10000))
         with riegeli.RecordReader(
                 files.reading_open(),
                 close=files.reading_should_close,
                 assumed_pos=files.reading_assumed_pos,
                 field_projection=projection) as reader:
             for index in range(record_count):
                 actual = reader.read_message(records_test_pb2.SimpleMessage)
                 self.assertEqual(actual,
                                  records_test_pb2.SimpleMessage(id=index))
             # Past the last record the reader yields None.
             self.assertIsNone(
                 reader.read_message(records_test_pb2.SimpleMessage))
Exemple #6
0
 def test_write_read_record_with_key(self, file_spec, random_access,
                                     parallelism):
     """Round-trips 23 records with keys, checking positions against keys.

     While writing, the position sampled before each record must be greater
     than the previous key and must not exceed the key returned for the new
     record. While reading, each (key, record) pair must match what was
     written, and the position sampled before reading a record must not
     exceed that record's key.
     """
     with contextlib.closing(file_spec(self.create_tempfile,
                                       random_access)) as files:
         keys = []
         with riegeli.RecordWriter(
                 files.writing_open(),
                 close=files.writing_should_close,
                 assumed_pos=files.writing_assumed_pos,
                 options=record_writer_options(parallelism)) as writer:
             for i in range(23):
                 pos = writer.pos
                 key = writer.write_record_with_key(sample_string(i, 10000))
                 if keys:
                     self.assertGreater(pos, keys[-1])
                 self.assertLessEqual(pos, key)
                 keys.append(key)
             writer.close()
             end_pos = writer.pos
         with riegeli.RecordReader(
                 files.reading_open(),
                 close=files.reading_should_close,
                 assumed_pos=files.reading_assumed_pos) as reader:
             for i in range(23):
                 pos = reader.pos
                 self.assertEqual(reader.read_record_with_key(),
                                  (keys[i], sample_string(i, 10000)))
                 # Compare with this record's own key. The leftover `key`
                 # from the writer loop is always the last key, which made
                 # the check much weaker than intended.
                 self.assertLessEqual(pos, keys[i])
             # No more records: the reader yields None and sits at the end.
             self.assertIsNone(reader.read_record_with_key())
             self.assertEqual(reader.pos, end_pos)
             # The reported position must be unchanged after closing.
             reader.close()
             self.assertEqual(reader.pos, end_pos)
Exemple #7
0
 def test_write_read_messages_with_field_projection_later(
         self, file_spec, parallelism):
     """Switches the field projection on and off in the middle of reading.

     Messages 0..3 are read in full, messages 4..13 with a projection that
     keeps only `id`, and messages 14..22 in full again after the projection
     is reset to None.
     """
     id_field = records_test_pb2.SimpleMessage.DESCRIPTOR.fields_by_name['id']

     def full_message(index):
         return sample_message(index, 10000)

     with contextlib.closing(
             file_spec(self.create_tempfile, random_access=True)) as files:
         with riegeli.RecordWriter(
                 files.writing_open(),
                 close=files.writing_should_close,
                 assumed_pos=files.writing_assumed_pos,
                 options=record_writer_options(parallelism,
                                               transpose=True)) as writer:
             writer.write_messages(
                 sample_message(index, 10000) for index in range(23))
         with riegeli.RecordReader(
                 files.reading_open(),
                 close=files.reading_should_close,
                 assumed_pos=files.reading_assumed_pos) as reader:

             def expect_messages(indices, expected_for):
                 """Reads one message per index and compares it."""
                 for index in indices:
                     self.assertEqual(
                         reader.read_message(records_test_pb2.SimpleMessage),
                         expected_for(index))

             expect_messages(range(4), full_message)
             reader.set_field_projection([[id_field.number]])
             expect_messages(range(4, 14), sample_message_id_only)
             reader.set_field_projection(None)
             expect_messages(range(14, 23), full_message)
             # Past the last record the reader yields None.
             self.assertIsNone(
                 reader.read_message(records_test_pb2.SimpleMessage))
Exemple #8
0
    def test_corruption_recovery_stop_iteration(self, file_spec, random_access,
                                                parallelism):
        """Ends reading at a corruption when recovery raises StopIteration.

        The recovery callback records the skipped region and raises
        StopIteration. The reader must then return the first 9 records,
        followed by None, and exactly one skipped region must be reported.
        """
        record_count = 23
        with contextlib.closing(file_spec(self.create_tempfile,
                                          random_access)) as files:
            positions = []
            with riegeli.RecordWriter(
                    files.writing_open(),
                    owns_dest=files.writing_should_close,
                    assumed_pos=files.writing_assumed_pos,
                    options=record_writer_options(parallelism)) as writer:
                for index in range(record_count):
                    writer.write_record(sample_string(index, 10000))
                    positions.append(writer.last_pos)
            # Corrupt the header of the chunk containing records [9..12).
            corrupted_offset = positions[9].chunk_begin + 20
            self.corrupt_at(files, corrupted_offset)
            reported_regions = []

            def recovery(skipped_region):
                reported_regions.append(skipped_region)
                raise StopIteration

            # Records [0..9) precede the corrupted chunk and read normally.
            with riegeli.RecordReader(files.reading_open(),
                                      owns_src=files.reading_should_close,
                                      assumed_pos=files.reading_assumed_pos,
                                      recovery=recovery) as reader:
                for index in range(9):
                    self.assertEqual(reader.read_record(),
                                     sample_string(index, 10000))
                self.assertIsNone(reader.read_record())
            self.assertLen(reported_regions, 1)
            region = reported_regions[0]
            self.assertEqual(region.begin, positions[9].numeric)
            self.assertEqual(region.end, positions[15].numeric)
Exemple #9
0
 def test_corruption_exception(self, file_spec, random_access, parallelism):
     """Expects `RiegeliError` when reading corrupted data without recovery.

     After the first 9 records are read, both the next `read_record()` and
     the final `close()` must raise `riegeli.RiegeliError`.
     """
     record_count = 23
     with contextlib.closing(file_spec(self.create_tempfile,
                                       random_access)) as files:
         positions = []
         with riegeli.RecordWriter(
                 files.writing_open(),
                 owns_dest=files.writing_should_close,
                 assumed_pos=files.writing_assumed_pos,
                 options=record_writer_options(parallelism)) as writer:
             for index in range(record_count):
                 writer.write_record(sample_string(index, 10000))
                 positions.append(writer.last_pos)
         # Corrupt the header of the chunk containing records [9..12).
         corrupted_offset = positions[9].chunk_begin + 20
         self.corrupt_at(files, corrupted_offset)
         # The reader is deliberately not used as a context manager, because
         # `close()` itself is asserted to raise below.
         reader = riegeli.RecordReader(
             files.reading_open(),
             owns_src=files.reading_should_close,
             assumed_pos=files.reading_assumed_pos)
         # Records [0..9) precede the corrupted chunk and read normally.
         for index in range(9):
             self.assertEqual(reader.read_record(),
                              sample_string(index, 10000))
         with self.assertRaises(riegeli.RiegeliError):
             reader.read_record()
         with self.assertRaises(riegeli.RiegeliError):
             reader.close()
Exemple #10
0
    def test_corruption_recovery_exception(self, file_spec, random_access,
                                           parallelism):
        """Expects an exception raised by the recovery callback to propagate.

        The recovery callback raises KeyboardInterrupt. After the first 9
        records are read, the next `read_record()` must raise it.
        """
        record_count = 23
        with contextlib.closing(file_spec(self.create_tempfile,
                                          random_access)) as files:
            positions = []
            with riegeli.RecordWriter(
                    files.writing_open(),
                    owns_dest=files.writing_should_close,
                    assumed_pos=files.writing_assumed_pos,
                    options=record_writer_options(parallelism)) as writer:
                for index in range(record_count):
                    writer.write_record(sample_string(index, 10000))
                    positions.append(writer.last_pos)
            # Corrupt the header of the chunk containing records [9..12).
            corrupted_offset = positions[9].chunk_begin + 20
            self.corrupt_at(files, corrupted_offset)

            def recovery(skipped_region):
                # The region is ignored; only the raised exception matters.
                raise KeyboardInterrupt

            with riegeli.RecordReader(files.reading_open(),
                                      owns_src=files.reading_should_close,
                                      assumed_pos=files.reading_assumed_pos,
                                      recovery=recovery) as reader:
                for index in range(9):
                    self.assertEqual(reader.read_record(),
                                     sample_string(index, 10000))
                with self.assertRaises(KeyboardInterrupt):
                    reader.read_record()
Exemple #11
0
 def test_write_read_metadata(self, file_spec, random_access, parallelism):
     """Round-trips file metadata and reads a message via the recorded type.

     The metadata carries a file comment and the `SimpleMessage` record type.
     The reader must return equal metadata, and the record type recovered
     from it must be usable to read back the single written message.
     """
     with contextlib.closing(file_spec(self.create_tempfile,
                                       random_access)) as files:
         written_metadata = riegeli.RecordsMetadata()
         written_metadata.file_comment = 'Comment'
         riegeli.set_record_type(written_metadata,
                                 records_test_pb2.SimpleMessage)
         written_message = sample_message(7, 10)
         with riegeli.RecordWriter(
                 files.writing_open(),
                 owns_dest=files.writing_should_close,
                 assumed_pos=files.writing_assumed_pos,
                 options=record_writer_options(parallelism),
                 metadata=written_metadata) as writer:
             writer.write_message(written_message)
         with riegeli.RecordReader(
                 files.reading_open(),
                 owns_src=files.reading_should_close,
                 assumed_pos=files.reading_assumed_pos) as reader:
             read_metadata = reader.read_metadata()
             self.assertEqual(read_metadata, written_metadata)
             record_type = riegeli.get_record_type(read_metadata)
             assert record_type is not None
             self.assertEqual(record_type.DESCRIPTOR.full_name,
                              'riegeli.tests.SimpleMessage')
             read_message = reader.read_message(record_type)
             assert read_message is not None
             # Serialize and deserialize because messages have descriptors of
             # different origins.
             reparsed = records_test_pb2.SimpleMessage.FromString(
                 read_message.SerializeToString())
             self.assertEqual(reparsed, written_message)
Exemple #12
0
 def test_corruption_recovery(self, file_spec, random_access, parallelism):
     """Skips a corrupted region and keeps reading when recovery returns.

     The recovery callback only collects the skipped regions. Records
     [0, 9) and [15, 23) must be read, followed by None, and exactly one
     skipped region spanning keys 9 to 15 must be reported.
     """
     record_count = 23
     with contextlib.closing(file_spec(self.create_tempfile,
                                       random_access)) as files:
         keys = []
         with riegeli.RecordWriter(
                 files.writing_open(),
                 close=files.writing_should_close,
                 assumed_pos=files.writing_assumed_pos,
                 options=record_writer_options(parallelism)) as writer:
             for index in range(record_count):
                 key = writer.write_record_with_key(
                     sample_string(index, 10000))
                 keys.append(key)
         # Corrupt the header of the chunk containing records [9, 12).
         corrupted_offset = keys[9].chunk_begin + 20
         self.corrupt_at(files, corrupted_offset)
         # Read records [0, 9) and [15, 23) successfully (all except the
         # corrupted chunk and the next chunk which intersects the same
         # block).
         readable_indices = list(range(9)) + list(range(15, record_count))
         reported_regions = []
         with riegeli.RecordReader(
                 files.reading_open(),
                 close=files.reading_should_close,
                 assumed_pos=files.reading_assumed_pos,
                 recovery=reported_regions.append) as reader:
             for index in readable_indices:
                 self.assertEqual(reader.read_record(),
                                  sample_string(index, 10000))
             self.assertIsNone(reader.read_record())
         self.assertLen(reported_regions, 1)
         region = reported_regions[0]
         self.assertEqual(region.begin, keys[9].numeric)
         self.assertEqual(region.end, keys[15].numeric)
Exemple #13
0
def export_v2_data(filename: str, output_filename: str,
                   num_directives: Optional[int]):
    """Exports the directives of a ledger file as converted messages.

    The input is parsed and sorted, and every entry that has a converter is
    turned into a message; entries of other types are skipped.

    Args:
      filename: Path of the input ledger file to parse.
      output_filename: Destination path. A name ending in ".pbtxt" produces a
        text file in which each message is printed after a short comment
        header giving its line number. Any other name produces a riegeli
        file containing only the messages.
      num_directives: If truthy, at most this many sorted entries are
        considered for export.
    """
    text_output = output_filename.endswith(".pbtxt")

    # Parse before opening the output so that a failed parse does not leave a
    # truncated output file behind. Errors and options are not used here.
    entries, _errors, _options_map = parser.parse_file(filename)
    entries = data.sorted(entries)
    if num_directives:
        entries = itertools.islice(entries, num_directives)

    # Checked in order with isinstance(); the first match wins.
    converters = (
        (data.Transaction, convert_Transaction),
        (data.Open, convert_Open),
        (data.Close, convert_Close),
        (data.Commodity, convert_Commodity),
        (data.Event, convert_Event),
        (data.Note, convert_Note),
        (data.Query, convert_Query),
        (data.Price, convert_Price),
        (data.Balance, convert_Balance),
        (data.Pad, convert_Pad),
    )

    def convert(entry):
        """Returns the converted message, or None for unsupported entries."""
        for entry_type, converter in converters:
            if isinstance(entry, entry_type):
                return converter(entry)
        return None

    messages = (pbdir for pbdir in map(convert, entries) if pbdir is not None)

    # Context managers close the writer and the file even if a conversion or
    # a write raises part-way through.
    with open(output_filename, 'w' if text_output else 'wb') as output:
        if text_output:
            for pbdir in messages:
                print("#---", file=output)
                print("# {}".format(pbdir.location.lineno), file=output)
                print("#", file=output)
                print(pbdir, file=output)
                print("", file=output)
        else:
            # The comment header is text-only: passing those strings to
            # write_message() would fail, since they are not messages.
            with riegeli.RecordWriter(output) as writer:
                for pbdir in messages:
                    writer.write_message(pbdir)
0
 def test_seek_numeric(self, file_spec, parallelism):
     """Exercises `RecordReader.seek_numeric` with numeric record positions.

     After seeking to the numeric value of key `k`, the reader position must
     be greater than key `k - 1` and not greater than key `k`. Some seeks
     are followed by a read that must return record `k`. Seeking to 0 and
     to the end position is checked as well, and the position after the
     last seek must still hold once the reader is closed.
     """
     record_count = 23
     with contextlib.closing(
             file_spec(self.create_tempfile, random_access=True)) as files:
         keys = []
         with riegeli.RecordWriter(
                 files.writing_open(),
                 close=files.writing_should_close,
                 assumed_pos=files.writing_assumed_pos,
                 options=record_writer_options(parallelism)) as writer:
             for index in range(record_count):
                 pos_before = writer.pos
                 new_key = writer.write_record_with_key(
                     sample_string(index, 10000))
                 if keys:
                     self.assertGreater(pos_before, keys[-1])
                 self.assertLessEqual(pos_before, new_key)
                 keys.append(new_key)
             writer.close()
             end_pos = writer.pos
         with riegeli.RecordReader(
                 files.reading_open(),
                 close=files.reading_should_close,
                 assumed_pos=files.reading_assumed_pos) as reader:

             def check_pos_for(index):
                 """Asserts keys[index - 1] < reader.pos <= keys[index]."""
                 self.assertGreater(reader.pos, keys[index - 1])
                 self.assertLessEqual(reader.pos, keys[index])

             def seek_to(index):
                 """Seeks to record `index` numerically and checks the pos."""
                 reader.seek_numeric(keys[index].numeric)
                 check_pos_for(index)

             def seek_and_read(index):
                 """Seeks to record `index` and expects to read it next."""
                 seek_to(index)
                 self.assertEqual(reader.read_record(),
                                  sample_string(index, 10000))

             # Seeking twice to the same place without reading in between.
             seek_to(9)
             seek_to(9)
             # Forward, backward, and forward again, reading each time.
             seek_and_read(11)
             seek_and_read(9)
             seek_and_read(11)
             seek_and_read(13)
             # Numeric position 0.
             reader.seek_numeric(0)
             self.assertLessEqual(reader.pos, keys[0])
             self.assertEqual(reader.read_record(), sample_string(0, 10000))
             # The end position: nothing left to read.
             reader.seek_numeric(end_pos.numeric)
             self.assertLessEqual(reader.pos, end_pos)
             self.assertIsNone(reader.read_record())
             # The position must be unchanged after closing.
             seek_to(11)
             reader.close()
             check_pos_for(11)
def write_records(filename):
    """Writes 100 sample messages to `filename` as a riegeli file.

    The file is written with the 'transpose' option and with metadata that
    has `SimpleMessage` set as the record type.

    Args:
      filename: Path of the file to create or overwrite.
    """
    print('Writing', filename)
    records_metadata = riegeli.RecordsMetadata()
    riegeli.set_record_type(records_metadata, records_test_pb2.SimpleMessage)
    message_count = 100
    with riegeli.RecordWriter(
            io.FileIO(filename, mode='wb'),
            options='transpose',
            metadata=records_metadata,
    ) as writer:
        messages = (sample_message(index, 100)
                    for index in range(message_count))
        writer.write_messages(messages)
Exemple #16
0
 def test_record_writer_exception_from_file(self, random_access, parallelism):
   """Expects NotImplementedError when writing through a `FakeFile`.

   The error presumably originates from `FakeFile`; its definition is not
   visible here.
   """
   failing_dest = FakeFile(random_access)
   # Without random access the writer is given an assumed position of 0.
   if random_access:
     assumed_pos = None
   else:
     assumed_pos = 0
   with self.assertRaises(NotImplementedError):
     with riegeli.RecordWriter(
         failing_dest,
         assumed_pos=assumed_pos,
         options=record_writer_options(parallelism)) as writer:
       writer.write_record(sample_string(0, 10000))
Exemple #17
0
 def test_record_writer_exception_from_file(self, random_access,
                                            parallelism):
     """Expects NotImplementedError when writing through a `FakeFile`.

     The error presumably originates from `FakeFile`; its definition is not
     visible here.
     """
     supports_random_access = random_access is RandomAccess.RANDOM_ACCESS
     failing_dest = FakeFile(supports_random_access)
     # Only explicit sequential access passes an assumed position of 0.
     if random_access is RandomAccess.SEQUENTIAL_ACCESS_EXPLICIT:
         assumed_pos = 0
     else:
         assumed_pos = None
     with self.assertRaises(NotImplementedError):
         with riegeli.RecordWriter(
                 failing_dest,
                 assumed_pos=assumed_pos,
                 options=record_writer_options(parallelism)) as writer:
             writer.write_record(sample_string(0, 10000))
Exemple #18
0
    def test_reads_without_traces_or_triple_raises(self):
        """Expects ValueError when loading exposures missing required data.

        The messages from `_get_exposures_missing_required_data()` are
        written to a temporary riegeli file, and fetching a batch from a data
        loader built on that file must raise ValueError.
        """
        record_path = _get_tmp_file_name()
        with riegeli.RecordWriter(
                io.FileIO(record_path, mode='wb'),
                options='transpose',
                metadata=riegeli.RecordsMetadata(),
        ) as writer:
            writer.write_messages(_get_exposures_missing_required_data())

        loader_under_test = abesim_data_loader.AbesimExposureDataLoader(
            record_path, unconfirmed_exposures=False)
        self.assertRaises(ValueError,
                          loader_under_test.get_next_batch,
                          batch_size=1)
Exemple #19
0
  def _create_files(self):
    """Creates the riegeli files used as test input.

    Writes `self._num_files` files named 'riegeli.<index>' into the temp
    directory, each holding `self._num_records` records produced by
    `self._record(file_index, record_index)`.

    Returns:
      The list of created filenames, in creation order.
    """
    created = []
    for file_index in range(self._num_files):
      path = os.path.join(self.get_temp_dir(), 'riegeli.{}'.format(file_index))
      created.append(path)

      # Note: if records were serialized proto messages, passing
      # options='transpose' to RecordWriter would make compression better.
      with riegeli.RecordWriter(tf.io.gfile.GFile(path, 'wb')) as writer:
        for record_index in range(self._num_records):
          writer.write_record(self._record(file_index, record_index))
    return created
Exemple #20
0
    def write_examples(cls, path: type_utils.PathLike,
                       iterator: Iterable[bytes]) -> Optional[Iterable[Any]]:
        """Write examples from given iterator in given path.

        The records are written to a riegeli file opened with the
        'transpose' option.

        Args:
          path: Path where to write the examples.
          iterator: Iterable of examples.

        Returns:
          The value returned by `write_records_with_keys`: the record
          positions for the records in the given iterator.
        """
        destination = tf.io.gfile.GFile(os.fspath(path), 'wb')
        with riegeli.RecordWriter(destination, options='transpose') as writer:
            record_keys = writer.write_records_with_keys(records=iterator)
        return record_keys
Exemple #21
0
    def test_reads_without_traces(self):
        """Loads one batch from exposures that have no proximity traces.

        The messages from `_get_exposures_without_proximity_traces()` are
        written to a temporary riegeli file, and the first batch of size 1
        from a data loader built on that file is compared with fixed
        expected exposures, labels and grouping.
        """
        record_path = _get_tmp_file_name()
        with riegeli.RecordWriter(
                io.FileIO(record_path, mode='wb'),
                options='transpose',
                metadata=riegeli.RecordsMetadata(),
        ) as writer:
            writer.write_messages(_get_exposures_without_proximity_traces())

        loader_under_test = abesim_data_loader.AbesimExposureDataLoader(
            record_path, unconfirmed_exposures=False)
        batch = loader_under_test.get_next_batch(batch_size=1)
        exposures, labels, grouping = batch
        self.assertCountEqual(exposures, [([1.0], 0, 30)])
        self.assertCountEqual(labels, [1])
        self.assertCountEqual(grouping, [1])
Exemple #22
0
 def test_write_read_records(self, file_spec, random_access, parallelism):
     """Round-trips 23 records using the bulk write and read methods."""
     record_count = 23
     with contextlib.closing(file_spec(self.create_tempfile,
                                       random_access)) as files:
         with riegeli.RecordWriter(
                 files.writing_open(),
                 owns_dest=files.writing_should_close,
                 assumed_pos=files.writing_assumed_pos,
                 options=record_writer_options(parallelism)) as writer:
             outgoing = (sample_string(index, 10000)
                         for index in range(record_count))
             writer.write_records(outgoing)
         with riegeli.RecordReader(
                 files.reading_open(),
                 owns_src=files.reading_should_close,
                 assumed_pos=files.reading_assumed_pos) as reader:
             # Drain the reader first, then build the expected list.
             actual = list(reader.read_records())
             expected = [
                 sample_string(index, 10000) for index in range(record_count)
             ]
             self.assertEqual(actual, expected)
Exemple #23
0
 def test_write_read_messages_with_keys(self, file_spec, random_access,
                                        parallelism):
   """Round-trips 23 messages with keys using the bulk methods.

   The (key, message) pairs read back must match the keys returned at write
   time paired with the original messages.
   """
   record_count = 23
   with contextlib.closing(file_spec(self.create_tempfile,
                                     random_access)) as files:
     with riegeli.RecordWriter(
         files.writing_open(),
         close=files.writing_should_close,
         assumed_pos=files.writing_assumed_pos,
         options=record_writer_options(parallelism)) as writer:
       outgoing = (sample_message(index, 10000)
                   for index in range(record_count))
       keys = writer.write_messages_with_keys(outgoing)
     with riegeli.RecordReader(
         files.reading_open(),
         close=files.reading_should_close,
         assumed_pos=files.reading_assumed_pos) as reader:
       # Drain the reader first, then build the expected pairs.
       actual = list(
           reader.read_messages_with_keys(records_test_pb2.SimpleMessage))
       expected = [(keys[index], sample_message(index, 10000))
                   for index in range(record_count)]
       self.assertEqual(actual, expected)
Exemple #24
0
    def test_search_for_message(self, file_spec):
        """Exercises `RecordReader.search_for_message` on 23 messages.

        Messages with ids 0..22 are written and their positions recorded.
        Searching for an id that was written must return 0 and leave the
        reader at that message's position. Searching for id 23 must return
        -1 and leave the reader at the end position.
        """
        record_count = 23
        with contextlib.closing(
                file_spec(self.create_tempfile,
                          random_access=RandomAccess.RANDOM_ACCESS)) as files:
            with riegeli.RecordWriter(
                    files.writing_open(),
                    owns_dest=files.writing_should_close,
                    assumed_pos=files.writing_assumed_pos,
                    options=record_writer_options(parallelism=0)) as writer:
                positions = []
                for index in range(record_count):
                    writer.write_message(sample_message(index, 10000))
                    positions.append(writer.last_pos)
                # Close first, then sample the position used as the end.
                writer.close()
                end_pos = writer.pos
            with riegeli.RecordReader(
                    files.reading_open(),
                    owns_src=files.reading_should_close,
                    assumed_pos=files.reading_assumed_pos) as reader:

                def compare_with(search_target):
                    """Builds a callback ordering a message against the target."""

                    def compare(message):
                        if message.id < search_target:
                            return -1
                        if message.id > search_target:
                            return 1
                        return 0

                    return compare

                # `record_count` is one past the largest id that was written.
                for target in (7, 0, 22, record_count):
                    was_written = target < record_count
                    result = reader.search_for_message(
                        records_test_pb2.SimpleMessage, compare_with(target))
                    self.assertEqual(result, 0 if was_written else -1)
                    if was_written:
                        expected_pos = positions[target]
                    else:
                        expected_pos = end_pos
                    self.assertEqual(reader.pos, expected_pos)
Exemple #25
0
 def test_seek_back(self, file_spec):
     """Reads 23 records in reverse order using `seek_back`.

     Starting from the end of the file, each step seeks back one record,
     reads it (which moves forward again), and seeks back once more. A
     final `seek_back()` must return False.
     """
     record_count = 23
     with contextlib.closing(
             file_spec(self.create_tempfile, random_access=True)) as files:
         with riegeli.RecordWriter(
                 files.writing_open(),
                 close=files.writing_should_close,
                 assumed_pos=files.writing_assumed_pos,
                 options=record_writer_options(parallelism=0)) as writer:
             for index in range(record_count):
                 writer.write_record(sample_string(index, 10000))
         with riegeli.RecordReader(
                 files.reading_open(),
                 close=files.reading_should_close,
                 assumed_pos=files.reading_assumed_pos) as reader:
             reader.seek_numeric(reader.size())
             for index in range(record_count - 1, -1, -1):
                 self.assertTrue(reader.seek_back())
                 self.assertEqual(reader.read_record(),
                                  sample_string(index, 10000))
                 # Undo the advance caused by the read above.
                 self.assertTrue(reader.seek_back())
             self.assertFalse(reader.seek_back())
Exemple #26
0
    def write_examples(
        cls,
        path: type_utils.PathLike,
        iterator: Iterable[type_utils.KeySerializedExample],
    ) -> Optional[ExamplePositions]:
        """Serialize the examples of `iterator` into a Riegeli file at `path`.

        Args:
          path: Destination file; opened for binary writing through
            `tf.io.gfile.GFile`.
          iterator: Iterable of 2-tuples. Only the second element (the record)
            of each tuple is written; the first element is ignored.

        Returns:
          The writer's `last_pos` captured right after each record is written,
          in iteration order.
        """
        import riegeli  # pylint: disable=g-import-not-at-top

        record_positions = []
        destination = os.fspath(path)
        with tf.io.gfile.GFile(destination, 'wb') as out_file, \
                riegeli.RecordWriter(out_file, options='transpose') as writer:
            for _key, serialized in iterator:
                writer.write_record(serialized)
                record_positions.append(writer.last_pos)
        return record_positions
Exemple #27
0
 def test_write_read_messages_with_field_projection(self, file_spec,
                                                    random_access,
                                                    parallelism):
     """Read transposed messages back with only the `id` field projected.

     Writes 23 sample messages with `transpose=True`, then reads them with a
     field projection containing just the field number of
     `SimpleMessage.id`, and compares the result against
     `sample_message_id_only(i)` for each index.
     """
     message_type = records_test_pb2.SimpleMessage
     with contextlib.closing(file_spec(self.create_tempfile,
                                       random_access)) as files:
         writer_ctx = riegeli.RecordWriter(
             files.writing_open(),
             owns_dest=files.writing_should_close,
             assumed_pos=files.writing_assumed_pos,
             options=record_writer_options(parallelism, transpose=True))
         with writer_ctx as writer:
             writer.write_messages(
                 sample_message(index, 10000) for index in range(23))
         # The projection is a list of field paths; the only path here is the
         # single-element path made of the `id` field number.
         id_field_number = message_type.DESCRIPTOR.fields_by_name['id'].number
         reader_ctx = riegeli.RecordReader(
             files.reading_open(),
             owns_src=files.reading_should_close,
             assumed_pos=files.reading_assumed_pos,
             field_projection=[[id_field_number]])
         with reader_ctx as reader:
             actual = list(reader.read_messages(message_type))
             expected = [sample_message_id_only(index) for index in range(23)]
             self.assertEqual(actual, expected)
Exemple #28
0
 def write_examples(cls, path: type_utils.PathLike,
                    iterator: Iterable[bytes]):
     """Write every record of `iterator` to a Riegeli file at `path`.

     Args:
       path: Destination file, opened for binary writing with
         `tf.io.gfile.GFile`.
       iterator: Serialized records, forwarded unchanged to the writer's
         `write_records`.
     """
     # NOTE(review): the file handle is never closed here directly;
     # presumably the writer closes it on exit -- confirm.
     destination = tf.io.gfile.GFile(os.fspath(path), 'wb')
     record_writer = riegeli.RecordWriter(destination, options='transpose')
     with record_writer as writer:
         writer.write_records(records=iterator)
Exemple #29
0
 def test_seek_numeric(self, file_spec, parallelism):
     """Check where `seek_numeric()` leaves the reader for various targets.

     Writes 23 sample records while recording `writer.last_pos` after each
     write, then seeks the reader to the numeric form of several recorded
     positions. After each seek the reader position must be greater than the
     previous record's position and no greater than the target's, and a
     following `read_record()` must return the target record.
     """
     with contextlib.closing(
             file_spec(self.create_tempfile,
                       random_access=RandomAccess.RANDOM_ACCESS)) as files:
         # last_pos of every written record, in write order.
         positions = []
         with riegeli.RecordWriter(
                 files.writing_open(),
                 owns_dest=files.writing_should_close,
                 assumed_pos=files.writing_assumed_pos,
                 options=record_writer_options(parallelism)) as writer:
             for i in range(23):
                 # Position reported before the write; it must lie after the
                 # previous record's last_pos and not beyond this record's.
                 pos = writer.pos
                 writer.write_record(sample_string(i, 10000))
                 canonical_pos = writer.last_pos
                 if positions:
                     self.assertGreater(pos, positions[-1])
                 self.assertLessEqual(pos, canonical_pos)
                 positions.append(canonical_pos)
             # Close explicitly so that the end position is taken after
             # closing.
             writer.close()
             end_pos = writer.pos
         with riegeli.RecordReader(
                 files.reading_open(),
                 owns_src=files.reading_should_close,
                 assumed_pos=files.reading_assumed_pos) as reader:
             # Seek to record 9; the reader must end up after record 8 and at
             # or before record 9.
             reader.seek_numeric(positions[9].numeric)
             self.assertGreater(reader.pos, positions[8])
             self.assertLessEqual(reader.pos, positions[9])
             # Seeking to the same target again must give the same bounds.
             reader.seek_numeric(positions[9].numeric)
             self.assertGreater(reader.pos, positions[8])
             self.assertLessEqual(reader.pos, positions[9])
             # Forward seek followed by a read of the target record.
             reader.seek_numeric(positions[11].numeric)
             self.assertGreater(reader.pos, positions[10])
             self.assertLessEqual(reader.pos, positions[11])
             self.assertEqual(reader.read_record(),
                              sample_string(11, 10000))
             # Backward seek followed by a read.
             reader.seek_numeric(positions[9].numeric)
             self.assertGreater(reader.pos, positions[8])
             self.assertLessEqual(reader.pos, positions[9])
             self.assertEqual(reader.read_record(), sample_string(9, 10000))
             # Forward again, to a record that was already read once.
             reader.seek_numeric(positions[11].numeric)
             self.assertGreater(reader.pos, positions[10])
             self.assertLessEqual(reader.pos, positions[11])
             self.assertEqual(reader.read_record(),
                              sample_string(11, 10000))
             # Forward to a record that has not been read yet.
             reader.seek_numeric(positions[13].numeric)
             self.assertGreater(reader.pos, positions[12])
             self.assertLessEqual(reader.pos, positions[13])
             self.assertEqual(reader.read_record(),
                              sample_string(13, 10000))
             # Numeric position 0: the next record read must be the first one.
             reader.seek_numeric(0)
             self.assertLessEqual(reader.pos, positions[0])
             self.assertEqual(reader.read_record(), sample_string(0, 10000))
             # Seeking to the end position: no record is left to read.
             reader.seek_numeric(end_pos.numeric)
             self.assertLessEqual(reader.pos, end_pos)
             self.assertIsNone(reader.read_record())
             # The position established by a seek must still be reported
             # within the same bounds after the reader is closed.
             reader.seek_numeric(positions[11].numeric)
             self.assertGreater(reader.pos, positions[10])
             self.assertLessEqual(reader.pos, positions[11])
             reader.close()
             self.assertGreater(reader.pos, positions[10])
             self.assertLessEqual(reader.pos, positions[11])