Example 1
    # Expects module-level imports along the lines of:
    #   from tensorflow.contrib.slim.python.slim.data import parallel_reader, test_utils
    #   from tensorflow.python.framework import errors_impl
    #   from tensorflow.python.ops import io_ops
    #   from tensorflow.python.training import input as input_lib, supervisor
    def _verify_read_up_to_out(self, shared_queue):
        with self.cached_session():
            num_files = 3
            num_records_per_file = 7
            tfrecord_paths = test_utils.create_tfrecord_files(
                self.get_temp_dir(),
                num_files=num_files,
                num_records_per_file=num_records_per_file)

        p_reader = parallel_reader.ParallelReader(io_ops.TFRecordReader,
                                                  shared_queue,
                                                  num_readers=5)

        data_files = parallel_reader.get_data_files(tfrecord_paths)
        filename_queue = input_lib.string_input_producer(data_files,
                                                         num_epochs=1)
        # read_up_to dequeues at most 4 (key, value) pairs per session.run() call.
        key, value = p_reader.read_up_to(filename_queue, 4)

        count0 = 0
        count1 = 0
        count2 = 0
        all_keys_count = 0
        all_values_count = 0

        sv = supervisor.Supervisor(logdir=self.get_temp_dir())
        with sv.prepare_or_wait_for_session() as sess:
            sv.start_queue_runners(sess)
            # Drain the single epoch: keep reading until the filename queue
            # is exhausted and the readers raise OutOfRangeError.
            while True:
                try:
                    current_keys, current_values = sess.run([key, value])
                    self.assertEqual(len(current_keys), len(current_values))
                    all_keys_count += len(current_keys)
                    all_values_count += len(current_values)
                    for current_key in current_keys:
                        if '0-of-3' in str(current_key):
                            count0 += 1
                        if '1-of-3' in str(current_key):
                            count1 += 1
                        if '2-of-3' in str(current_key):
                            count2 += 1
                except errors_impl.OutOfRangeError:
                    break

        self.assertEqual(count0, num_records_per_file)
        self.assertEqual(count1, num_records_per_file)
        self.assertEqual(count2, num_records_per_file)
        self.assertEqual(all_keys_count, num_files * num_records_per_file)
        self.assertEqual(all_values_count, all_keys_count)
        self.assertEqual(count0 + count1 + count2, all_keys_count)
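
The helper above is parameterized by shared_queue, so the same assertions can run against different queue implementations. A minimal sketch of how callers might invoke it, assuming import tensorflow as tf at module level; the method names and capacities here are hypothetical, not values from the original:

    def test_read_up_to_from_fifo_queue(self):
        # The shared queue must carry (key, value) string pairs,
        # hence dtypes=[tf.string, tf.string].
        self._verify_read_up_to_out(
            tf.FIFOQueue(capacity=99, dtypes=[tf.string, tf.string]))

    def test_read_up_to_from_random_shuffle_queue(self):
        self._verify_read_up_to_out(
            tf.RandomShuffleQueue(capacity=99, min_after_dequeue=0,
                                  dtypes=[tf.string, tf.string]))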
Example 2
import tensorflow as tf
from tensorflow.contrib.slim.python.slim.data import parallel_reader
from tensorflow.python.ops import io_ops


def read_and_decode(filename_queue, tfrec_len):
    batch_size = 5
    IMAGE_WIDTH = 10
    # Shared queue through which all parallel readers funnel (key, value) records.
    common_queue = tf.RandomShuffleQueue(capacity=10 + 3 * batch_size,
                                         min_after_dequeue=10 + batch_size,
                                         dtypes=[tf.string, tf.string])

    p_reader = parallel_reader.ParallelReader(io_ops.TFRecordReader,
                                              common_queue,
                                              num_readers=tfrec_len)

    _, serialized_example = p_reader.read(filename_queue)

    features = tf.parse_single_example(
        serialized_example,
        # Defaults are not specified since both keys are required.
        features={
            'height': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string)
        })

    bucket_boundaries = [7, 11, 18, 23, 33]
    image = tf.decode_raw(features['image_raw'], tf.uint8)

    height = tf.cast(features['height'], tf.int32)

    image_shape = tf.stack([height, IMAGE_WIDTH])
    print("image shape:", image_shape)  # prints the Tensor, not its runtime value

    image = tf.reshape(image, image_shape)

    # Group images of similar height into buckets so each batch is padded
    # only to the tallest image in its bucket.
    (l, images) = tf.contrib.training.bucket_by_sequence_length(
        height, [image], batch_size, bucket_boundaries,
        capacity=1 * batch_size, dynamic_pad=True)

    # Dequeue a shuffled permutation of [0, batch_size) to reorder the batch.
    table_index = tf.train.range_input_producer(
        batch_size, shuffle=True).dequeue_many(batch_size)

    ret = tf.gather(images[0], table_index)

    return (l, ret)
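
A sketch of how read_and_decode might be wired into a session; the file names are hypothetical and the coordinator/queue-runner boilerplate follows the standard TF 1.x input-pipeline pattern:

tfrecord_files = ["train-0.tfrecord", "train-1.tfrecord"]  # hypothetical paths
filename_queue = tf.train.string_input_producer(tfrecord_files)
lengths, images = read_and_decode(filename_queue, tfrec_len=len(tfrecord_files))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # Each fetch yields one bucketed batch, padded to the bucket's max height.
    batch_lengths, batch_images = sess.run([lengths, images])
    print(batch_images.shape)  # (batch_size, padded_height, IMAGE_WIDTH)
    coord.request_stop()
    coord.join(threads)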
Example 3

    def _verify_all_data_sources_read(self, shared_queue):
        with self.cached_session():
            tfrecord_paths = test_utils.create_tfrecord_files(
                self.get_temp_dir(), num_files=3)

        num_readers = len(tfrecord_paths)
        p_reader = parallel_reader.ParallelReader(io_ops.TFRecordReader,
                                                  shared_queue,
                                                  num_readers=num_readers)

        data_files = parallel_reader.get_data_files(tfrecord_paths)
        filename_queue = input_lib.string_input_producer(data_files)
        key, value = p_reader.read(filename_queue)

        count0 = 0
        count1 = 0
        count2 = 0

        num_reads = 50

        sv = supervisor.Supervisor(logdir=self.get_temp_dir())
        with sv.prepare_or_wait_for_session() as sess:
            sv.start_queue_runners(sess)

            # With num_readers == num_files, records from all three files
            # should be interleaved across the 50 reads.
            for _ in range(num_reads):
                current_key, _ = sess.run([key, value])
                if '0-of-3' in str(current_key):
                    count0 += 1
                if '1-of-3' in str(current_key):
                    count1 += 1
                if '2-of-3' in str(current_key):
                    count2 += 1

        self.assertGreater(count0, 0)
        self.assertGreater(count1, 0)
        self.assertGreater(count2, 0)
        self.assertEqual(count0 + count1 + count2, num_reads)
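
Both test helpers pass their file list through parallel_reader.get_data_files before building the filename queue. That function also accepts glob patterns and comma-separated strings, not just explicit lists; a small illustration (all paths hypothetical):

from tensorflow.contrib.slim.python.slim.data import parallel_reader

# get_data_files expands glob patterns and flattens lists of sources.
files_from_glob = parallel_reader.get_data_files("/tmp/data/train-*.tfrecord")
files_from_list = parallel_reader.get_data_files(
    ["/tmp/data/train-0.tfrecord", "/tmp/data/train-1.tfrecord"])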