def input_pipeline(file_pattern, mode, capacity=64):
    """Build a queue-based reader over TFRecord files and decode examples.

    Serialized examples are read with slim's ``parallel_reader`` and decoded
    into the four fields "source", "target", "source_length" and
    "target_length".

    Args:
        file_pattern: Passed to ``parallel_reader.get_data_files`` and
            ``parallel_reader.parallel_read`` to locate the input files
            (presumably a path or glob pattern -- confirm with callers).
        mode: When equal to the string "train", examples are shuffled, read
            for an unlimited number of epochs, and up to 4 readers are used.
            Any other value gives a single unshuffled epoch with one reader.
        capacity: Used as ``min_after_dequeue`` of the examples queue; the
            queue capacity itself is ``2 * capacity``.

    Returns:
        A dict mapping each field name to its decoded tensor cast to int32.
    """
    keys_to_features = {
        "source": tf.VarLenFeature(tf.int64),
        "target": tf.VarLenFeature(tf.int64),
        "source_length": tf.FixedLenFeature([1], tf.int64),
        "target_length": tf.FixedLenFeature([1], tf.int64),
    }
    items_to_handlers = {
        "source": tfexample_decoder.Tensor("source"),
        "target": tfexample_decoder.Tensor("target"),
        "source_length": tfexample_decoder.Tensor("source_length"),
        "target_length": tfexample_decoder.Tensor("target_length"),
    }
    # A single item list drives both the decode call and the labelling of its
    # outputs, so a tensor can never be paired with the wrong field name even
    # if the two dicts above are edited to differ in key order.
    items = list(items_to_handlers)

    # Now the non-trivial case construction.
    with tf.name_scope("examples_queue"):
        training = (mode == "train")
        # Read serialized examples using slim parallel_reader.
        num_epochs = None if training else 1
        data_files = parallel_reader.get_data_files(file_pattern)
        # Never ask for more readers than there are input files.
        num_readers = min(4 if training else 1, len(data_files))
        _, serialized = parallel_reader.parallel_read(
            [file_pattern],
            tf.TFRecordReader,
            num_epochs=num_epochs,
            shuffle=training,
            capacity=2 * capacity,
            min_after_dequeue=capacity,
            num_readers=num_readers)

        decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
                                                     items_to_handlers)
        decoded = decoder.decode(serialized, items=items)

        # We do not want int64s as they are not supported on GPUs.
        return {
            field: tf.to_int32(tensor)
            for field, tensor in zip(items, decoded)
        }
def get_observation():
    """Step through the examples in DATA_PATH, plotting each one.

    Reads serialized examples from ``DATA_PATH`` for a single shuffled epoch,
    decodes image, boxes, classes, instance masks and filename, and for each
    example prints a summary line, draws it with ``image_with_labels`` and
    waits for a key or mouse press before closing the figure and moving on.

    Returns:
        1, once the single epoch of input has been consumed.
    """
    with tf.Session() as sess:
        _, string_tensor = parallel_reader.parallel_read(
            [DATA_PATH],
            reader_class=tf.TFRecordReader,
            num_epochs=1,
            num_readers=1,
            shuffle=True,
            dtypes=[tf.string, tf.string],
            capacity=500,
            min_after_dequeue=200)
        decoder = tf_example_decoder.TfExampleDecoder(
            load_instance_masks=True,
            instance_mask_type=input_reader_pb2.PNG_MASKS)
        decoded_data = decoder.decode(string_tensor)
        fetches = [
            decoded_data['image'],
            decoded_data['groundtruth_boxes'],
            decoded_data['groundtruth_classes'],
            decoded_data['groundtruth_instance_masks'],
            decoded_data['filename'],
        ]

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer(),
                           tf.tables_initializer())
        sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            while not coord.should_stop():
                img, boxes, classes, masks, file_names = sess.run(fetches)
                # NOTE(review): the message says "masks" but the count is
                # taken from the boxes array -- confirm they always match.
                print("current file {} contains {} masks".format(
                    file_names.decode(), len(boxes)))
                sys.stdout.flush()
                image_with_labels(img, boxes, masks, classes,
                                  label_dict=LABEL_DICT)
                plt.tight_layout()
                # Block until the user presses a key or mouse button.
                plt.waitforbuttonpress()
                plt.close()
        except tf.errors.OutOfRangeError:
            # num_epochs=1: the reader queue is exhausted, which is the
            # normal way for this loop to finish.
            pass
        finally:
            # Always ask the queue-runner threads to stop, even on error.
            coord.request_stop()
        coord.join(threads)
        return 1