Ejemplo n.º 1
0
    def _verify_read_up_to_out(self, shared_queue):
        """Drain `ParallelReader.read_up_to` and check the per-file record counts.

        Creates 3 TFRecord files with 7 records each, reads them through a
        `ParallelReader` built with 5 readers and a filename queue created
        with `num_epochs=1`, and repeatedly runs `read_up_to(filename_queue, 4)`
        until `OutOfRangeError` is raised (presumably once the single epoch
        is exhausted -- confirm against the reader implementation).

        The test then asserts that:
          * every fetched batch has as many keys as values,
          * each of the tags '0-of-3', '1-of-3', '2-of-3' was found in exactly
            `num_records_per_file` keys,
          * the total number of keys equals `num_files * num_records_per_file`.

        Args:
          shared_queue: Queue object forwarded unchanged as the second
            positional argument of `parallel_reader.ParallelReader`.
        """
        num_files = 3
        num_records_per_file = 7
        with self.cached_session():
            tfrecord_paths = test_utils.create_tfrecord_files(
                self.get_temp_dir(),
                num_files=num_files,
                num_records_per_file=num_records_per_file)

        reader = parallel_reader.ParallelReader(
            io_ops.TFRecordReader, shared_queue, num_readers=5)

        file_patterns = parallel_reader.get_data_files(tfrecord_paths)
        filename_queue = input_lib.string_input_producer(
            file_patterns, num_epochs=1)
        key, value = reader.read_up_to(filename_queue, 4)

        # One tally per file tag; a key is attributed to a file when the tag
        # occurs anywhere in the key's string form.
        shard_tags = ('0-of-3', '1-of-3', '2-of-3')
        hits_per_shard = {tag: 0 for tag in shard_tags}
        total_keys = 0
        total_values = 0

        sv = supervisor.Supervisor(logdir=self.get_temp_dir())
        with sv.prepare_or_wait_for_session() as sess:
            sv.start_queue_runners(sess)
            while True:
                try:
                    batch_keys, batch_values = sess.run([key, value])
                except errors_impl.OutOfRangeError:
                    # No more batches can be produced; stop draining.
                    break

                self.assertEqual(len(batch_keys), len(batch_values))
                total_keys += len(batch_keys)
                total_values += len(batch_values)

                for record_key in batch_keys:
                    key_text = str(record_key)
                    # Tags are tested independently (not mutually exclusive).
                    for tag in shard_tags:
                        if tag in key_text:
                            hits_per_shard[tag] += 1

        for tag in shard_tags:
            self.assertEqual(hits_per_shard[tag], num_records_per_file)
        self.assertEqual(total_keys, num_files * num_records_per_file)
        self.assertEqual(total_values, total_keys)
        self.assertEqual(sum(hits_per_shard.values()), total_keys)
Ejemplo n.º 2
0
    def _verify_all_data_sources_read(self, shared_queue):
        """Check that 50 single-record reads touch every one of the 3 files.

        Creates 3 TFRecord files, builds a `ParallelReader` with one reader
        per created path, and runs `read` 50 times against a filename queue
        created without an epoch limit argument.

        The test asserts that each of the tags '0-of-3', '1-of-3' and
        '2-of-3' was found in at least one returned key, and that the tag
        matches add up to the number of reads.

        Args:
          shared_queue: Queue object forwarded unchanged as the second
            positional argument of `parallel_reader.ParallelReader`.
        """
        with self.cached_session():
            tfrecord_paths = test_utils.create_tfrecord_files(
                self.get_temp_dir(), num_files=3)

        reader = parallel_reader.ParallelReader(
            io_ops.TFRecordReader,
            shared_queue,
            num_readers=len(tfrecord_paths))

        file_patterns = parallel_reader.get_data_files(tfrecord_paths)
        filename_queue = input_lib.string_input_producer(file_patterns)
        key, value = reader.read(filename_queue)

        # One tally per file tag; a key is attributed to a file when the tag
        # occurs anywhere in the key's string form.
        shard_tags = ('0-of-3', '1-of-3', '2-of-3')
        hits_per_shard = {tag: 0 for tag in shard_tags}

        num_reads = 50

        sv = supervisor.Supervisor(logdir=self.get_temp_dir())
        with sv.prepare_or_wait_for_session() as sess:
            sv.start_queue_runners(sess)

            for _ in range(num_reads):
                record_key, _unused_value = sess.run([key, value])
                key_text = str(record_key)
                # Tags are tested independently (not mutually exclusive).
                for tag in shard_tags:
                    if tag in key_text:
                        hits_per_shard[tag] += 1

        for tag in shard_tags:
            self.assertGreater(hits_per_shard[tag], 0)
        self.assertEqual(sum(hits_per_shard.values()), num_reads)