def read_tfr(tfr_pattern, target_len):
    """Read every example matching a TFRecord pattern into NumPy arrays.

    The files are listed in the order returned by
    ``tfrecord_batcher.order_tfrecords`` when it returns a non-empty
    result; otherwise ``tf.data.Dataset.list_files`` is applied to the
    raw pattern.

    Args:
      tfr_pattern: pattern passed to ``order_tfrecords`` /
        ``tf.data.Dataset.list_files``.
      target_len: size of the first axis each example's ``"targets"``
        entry is reshaped to.

    Returns:
      A ``(seqs_1hot, targets)`` tuple of ``np.array`` objects built from
      the per-example ``"sequence"`` entries reshaped to ``(-1, 4)`` and
      the per-example ``"targets"`` entries reshaped to
      ``(target_len, -1)``. Both are empty arrays when the pattern yields
      no records.
    """
    tfr_files = tfrecord_batcher.order_tfrecords(tfr_pattern)
    if tfr_files:
        # shuffle=False keeps the explicit file ordering intact.
        dataset = tf.data.Dataset.list_files(tf.constant(tfr_files), shuffle=False)
    else:
        dataset = tf.data.Dataset.list_files(tfr_pattern)
    dataset = dataset.flat_map(file_to_records)
    dataset = dataset.batch(1)
    dataset = dataset.map(parse_proto)

    iterator = dataset.make_one_shot_iterator()
    next_op = iterator.get_next()

    seqs_1hot = []
    targets = []

    with tf.Session() as sess:
        while True:
            # Every fetch, including the first, is guarded so that an
            # empty dataset ends the loop instead of raising.
            try:
                next_datum = sess.run(next_op)
            except tf.errors.OutOfRangeError:
                break

            seqs_1hot.append(next_datum["sequence"].reshape((-1, 4)))
            targets.append(next_datum["targets"].reshape(target_len, -1))

    seqs_1hot = np.array(seqs_1hot)
    targets = np.array(targets)

    return seqs_1hot, targets
def make_next_op(tfr_pattern):
    """Build the iterator op that yields parsed examples from TFRecords.

    The files are listed in the order returned by ``order_tfrecords``
    when it returns a non-empty result; otherwise a warning is printed
    to stderr and ``tf.data.Dataset.list_files`` is applied to the raw
    pattern. Records are read as ZLIB-compressed TFRecords, batched one
    at a time, and mapped through ``parse_proto``.

    Args:
      tfr_pattern: pattern passed to ``order_tfrecords`` /
        ``tf.data.Dataset.list_files``.

    Returns:
      The result of ``get_next()`` on a one-shot iterator over the
      dataset.

    Raises:
      SystemExit: with status 1 if ``get_next()`` raises
        ``tf.errors.OutOfRangeError``.
    """
    # read TF Records
    tfr_files = order_tfrecords(tfr_pattern)
    if tfr_files:
        # shuffle=False keeps the explicit file ordering intact.
        dataset = tf.data.Dataset.list_files(tf.constant(tfr_files), shuffle=False)
    else:
        print("Cannot order TFRecords %s" % tfr_pattern, file=sys.stderr)
        dataset = tf.data.Dataset.list_files(tfr_pattern)

    def file_to_records(filename):
        return tf.data.TFRecordDataset(filename, compression_type="ZLIB")

    dataset = dataset.flat_map(file_to_records)
    dataset = dataset.batch(1)
    dataset = dataset.map(parse_proto)

    iterator = dataset.make_one_shot_iterator()
    try:
        next_op = iterator.get_next()
    except tf.errors.OutOfRangeError:
        # This is a plain function, so the pattern is the local argument;
        # there is no `self` to read it from.
        print("TFRecord pattern %s is empty" % tfr_pattern, file=sys.stderr)
        sys.exit(1)

    return next_op