# Example 1 (scraped snippet; original vote count: 0)
def input_pipeline(file_pattern, mode, capacity=64):
    """Build a decoded-example input pipeline from TFRecord files.

    Args:
        file_pattern: Glob pattern (or comma-separated patterns) matching
            the TFRecord files to read.
        mode: "train" enables shuffling, up to 4 parallel readers and
            unlimited epochs; any other value reads each file once with a
            single reader.
        capacity: Minimum number of examples kept in the shuffle queue;
            the queue capacity itself is ``2 * capacity``.

    Returns:
        A dict mapping "source", "target", "source_length" and
        "target_length" to int32 tensors decoded from the records.
    """
    keys_to_features = {
        "source": tf.VarLenFeature(tf.int64),
        "target": tf.VarLenFeature(tf.int64),
        "source_length": tf.FixedLenFeature([1], tf.int64),
        "target_length": tf.FixedLenFeature([1], tf.int64)
    }

    items_to_handlers = {
        "source": tfexample_decoder.Tensor("source"),
        "target": tfexample_decoder.Tensor("target"),
        "source_length": tfexample_decoder.Tensor("source_length"),
        "target_length": tfexample_decoder.Tensor("target_length")
    }

    # Now the non-trivial case construction.
    with tf.name_scope("examples_queue"):
        training = (mode == "train")
        # Training loops over the data forever; evaluation sees each
        # record exactly once.
        num_epochs = None if training else 1
        data_files = parallel_reader.get_data_files(file_pattern)
        # Never request more readers than there are files to read.
        num_readers = min(4 if training else 1, len(data_files))
        _, examples = parallel_reader.parallel_read([file_pattern],
                                                    tf.TFRecordReader,
                                                    num_epochs=num_epochs,
                                                    shuffle=training,
                                                    capacity=2 * capacity,
                                                    min_after_dequeue=capacity,
                                                    num_readers=num_readers)

        decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
                                                     items_to_handlers)

        # Pair each decoded tensor with the item name it was requested
        # under.  The original zipped against ``keys_to_features``, which
        # only worked because both dicts happened to share insertion
        # order; zipping against the same list passed to decode() makes
        # the pairing explicit and order-safe.
        item_names = list(items_to_handlers)
        decoded = decoder.decode(examples, items=item_names)
        examples = dict(zip(item_names, decoded))

        # We do not want int64s as they are not supported on GPUs.
        return {k: tf.to_int32(v) for (k, v) in six.iteritems(examples)}
# Example 2 (scraped snippet; original vote count: 0)
def get_observation():
    """Read decoded detection examples from DATA_PATH and display them.

    Reads one epoch of serialized examples, decodes image, boxes,
    classes, PNG instance masks and filename, and shows each example
    with ``image_with_labels`` until a button press advances to the next
    one.

    The display loop runs until the queue is exhausted (num_epochs=1) or
    the process is interrupted; queue-runner threads are stopped and
    joined on any exit.

    Returns:
        1 (kept for backward compatibility with the original signature).
    """
    with tf.Session() as sess:
        _, string_tensor = parallel_reader.parallel_read(
            [DATA_PATH],
            reader_class=tf.TFRecordReader,
            num_epochs=1,
            num_readers=1,
            shuffle=True,
            dtypes=[tf.string, tf.string],
            capacity=500,
            min_after_dequeue=200)

        decoder = tf_example_decoder.TfExampleDecoder(
            load_instance_masks=True,
            instance_mask_type=input_reader_pb2.PNG_MASKS)
        decoded_data = decoder.decode(string_tensor)

        image = decoded_data['image']
        box_list = decoded_data['groundtruth_boxes']
        class_list = decoded_data['groundtruth_classes']
        mask_list = decoded_data['groundtruth_instance_masks']
        file_name = decoded_data['filename']

        # num_epochs=1 creates a local variable under the hood, so the
        # local-variable initializer is required in addition to the
        # global one.
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer(),
                           tf.tables_initializer())
        sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        try:
            while True:
                img, boxes, classes, masks, file_names = sess.run(
                    [image, box_list, class_list, mask_list, file_name])

                print("current file {} contains {} masks".format(
                    file_names.decode(), len(boxes)))
                sys.stdout.flush()

                image_with_labels(img,
                                  boxes,
                                  masks,
                                  classes,
                                  label_dict=LABEL_DICT)
                plt.tight_layout()

                plt.waitforbuttonpress()
                plt.close()
        finally:
            # The original placed these calls after the infinite loop,
            # making them unreachable; running them in a finally block
            # guarantees the queue-runner threads are stopped and joined
            # on OutOfRangeError, KeyboardInterrupt, or any other exit.
            coord.request_stop()
            coord.join(threads)

        return 1