예제 #1
0
    def _read_examples(self, train_test_val_split, path):
        """Read the train, eval and test examples stored under `path`.

        Args:
          train_test_val_split: a list of three floats that specify the
            relative sizes of train, test and validation sets. It is not
            used by this method; it is kept in the signature so existing
            callers keep working.
          path: the directory path where train, test and validation datasets
            can be found.

        Returns:
          A single list of tf.Example protos: all train examples, followed by
          all eval examples, followed by all test examples.
        """
        examples = []
        # The three datasets are read the same way; only the helper that maps
        # `path` to a filename differs. The tuple order fixes the order of the
        # returned list (train, eval, test).
        for get_filename in (ngs_errors.get_train_filename,
                             ngs_errors.get_eval_filename,
                             ngs_errors.get_test_filename):
            with genomics_reader.TFRecordReader(
                    get_filename(path), proto=example_pb2.Example) as fin:
                examples.extend(fin)
        return examples
예제 #2
0
def Reader(path, proto=None, compression_type=None):
  """Build a TFRecordReader for `path`, falling back to tf.Example protos.

  Args:
    path: forwarded unchanged to genomics_reader.TFRecordReader.
    proto: proto to read records as. Any falsy value, including the default
      None, selects example_pb2.Example instead.
    compression_type: forwarded as the `compression_type` keyword argument.

  Returns:
    The genomics_reader.TFRecordReader built from these arguments.
  """
  record_proto = proto or example_pb2.Example
  return genomics_reader.TFRecordReader(
      path, record_proto, compression_type=compression_type)
예제 #3
0
 def testCompressedExplicit(self):
     """Reads a gzipped GFF TFRecord with compression_type given explicitly.

     Checks the `source` of the first record and the reference name of the
     second record's range.
     """
     path = test_utils.genomics_core_testdata('test_features.gff.tfrecord.gz')
     # Use the reader as a context manager so it is released when the block
     # exits, instead of being left open for the rest of the test run.
     with genomics_reader.TFRecordReader(
             path, gff_pb2.GffRecord(), compression_type='GZIP') as reader:
         records = list(reader.iterate())
     self.assertEqual('GenBank', records[0].source)
     self.assertEqual('ctg123', records[1].range.reference_name)
예제 #4
0
 def testUncompressed(self):
     """Reads an uncompressed GFF TFRecord without naming a compression type.

     Checks the `source` of the first record, the reference name of the
     second record's range, and that the reader's `c_reader` is not 0.
     """
     path = test_utils.genomics_core_testdata('test_features.gff.tfrecord')
     # Use the reader as a context manager so it is released when the block
     # exits. All assertions stay inside the block so that `c_reader` is
     # inspected while the reader is still open.
     with genomics_reader.TFRecordReader(path, gff_pb2.GffRecord()) as reader:
         records = list(reader.iterate())
         self.assertEqual('GenBank', records[0].source)
         self.assertEqual('ctg123', records[1].range.reference_name)
         self.assertNotEqual(reader.c_reader, 0)
예제 #5
0
def main(argv):
    """Print every record of a TFRecord file in protobuf text format.

    Args:
      argv: command-line arguments; must be exactly
        [program_name, filename, proto_name], where proto_name is a key of
        PROTO_DB.

    Exits with status -1 after printing a message when the argument count is
    wrong or proto_name is not a key of PROTO_DB.
    """
    if len(argv) != 3:
        usage = 'Usage: {} <filename> <proto_name>\n'.format(argv[0])
        print(usage)
        sys.exit(-1)

    filename, proto_name = argv[1], argv[2]

    if proto_name not in PROTO_DB:
        known_names = ' '.join(PROTO_DB.keys())
        print('Unknown protocol buffer name {}\n'.format(proto_name))
        print('Known names are: {}\n'.format(known_names))
        sys.exit(-1)

    record_proto = PROTO_DB[proto_name]

    with genomics_reader.TFRecordReader(filename,
                                        proto=record_proto) as records:
        for entry in records:
            print(text_format.MessageToString(entry))
예제 #6
0
 def testTwoIteratorsAtTheSameTime(self):
     """Advances the reader and a second iterator over it in lockstep.

     Both the reader itself and the iterator obtained from `iter(reader)` are
     expected to yield '0' through '5', each independently of the other.
     """
     dreader = genomics_reader.TFRecordReader('0,1,2,3,4,5', DummyProto())
     iter2 = iter(dreader)
     for i in range(6):
         # The builtin next() works on both Python 2.6+ and Python 3, whereas
         # the `.next()` method only exists on Python 2 iterators.
         self.assertEqual(str(i), next(dreader))
         self.assertEqual(str(i), next(iter2))
예제 #7
0
 def testMock(self):
     """Iterating the reader directly yields 'a' through 'e' in order."""
     reader = genomics_reader.TFRecordReader('a,b,c,d,e', DummyProto())
     expected = ['a', 'b', 'c', 'd', 'e']
     actual = list(reader)
     self.assertEqual(expected, actual)
예제 #8
0
def _read_examples(path):
    """Return all records in the TFRecord file at `path` as a list.

    Records are read with example_pb2.Example as the proto, and the reader is
    used as a context manager for the duration of the read.
    """
    reader = genomics_reader.TFRecordReader(path, proto=example_pb2.Example)
    with reader as fin:
        examples = list(fin)
    return examples