예제 #1
0
    def check_expected_structure(self, sampler):
        """Verifies the layout of the episode descriptions produced by `sampler`.

        Three episodes are pulled through a `tf.data` pipeline built from
        `reader.episode_representation_generator`. Each episode is checked for:
          * a total length equal to the sum of `sampler.compute_chunk_sizes()`;
          * a flush chunk whose last element is the padding class id;
          * every chunk (flush, support, query) listing its real class ids
            first, followed only by padding.

        The padding class id is the number of classes in `self.split` of
        `self.dataset_spec`.

        Args:
          sampler: provides `compute_chunk_sizes()` and is forwarded as the
            `sampler` argument of `reader.episode_representation_generator`.
        """
        sizes = sampler.compute_chunk_sizes()
        expected_len = sum(sizes)
        padding_id = len(self.dataset_spec.get_classes(self.split))

        # `reader.decompress_episode_representation` only accepts TF tensors, so
        # the Python generator is wrapped in a tf.data.Dataset and the result is
        # evaluated back into Python values with a session.
        episode_gen = functools.partial(
            reader.episode_representation_generator,
            dataset_spec=self.dataset_spec,
            split=self.split,
            pool=None,
            sampler=sampler)
        dataset = tf.data.Dataset.from_generator(
            episode_gen, tf.int64, tf.TensorShape([None, 2]))
        dataset = dataset.map(reader.decompress_episode_representation)
        episode_tensor = dataset.make_one_shot_iterator().get_next()

        num_episodes = 3
        for _ in range(num_episodes):
            with self.cached_session() as sess:
                episode = sess.run(episode_tensor)

            self.assertEqual(len(episode), expected_len)

            flush, support, query = split_into_chunks(episode, sizes)

            # The flush chunk is slightly oversized. If support + query examples
            # had actually remained, they would have been used, so the final
            # flush slot has to be padding.
            self.assertEqual(flush[-1], padding_id)
            # TODO(lamblinp): check more about the content of flush_chunk

            # Inside every chunk, real examples must come before all padding.
            for part in (flush, support, query):
                n_real = sum(cid != padding_id for cid in part)
                head, tail = part[:n_real], part[n_real:]
                self.assertNotIn(padding_id, head)
                self.assertTrue(all(padding_id == cid for cid in tail))
예제 #2
0
 def create_ta(s):
     """Returns an empty TensorArray that holds batched copies of `s`.

     The array has `self._train_interval + 1` slots, and each slot has dtype
     `s.dtype` and shape `[batch_size] + s.shape`. `self` and `batch_size` are
     read from the enclosing scope.

     Args:
       s: an object exposing `.dtype` and `.shape` (presumably a tensor or
         tf.TensorSpec; confirm against the caller).
     """
     dtype = s.dtype
     num_slots = self._train_interval + 1
     # Prepend the batch dimension to the per-example shape of `s`.
     slot_shape = tf.TensorShape([batch_size]).concatenate(s.shape)
     return tf.TensorArray(dtype=dtype, size=num_slots, element_shape=slot_shape)
예제 #3
0
 def create_ta(s):
     """Returns an empty TensorArray that holds batched copies of `s`.

     The array has `maximum_iterations` slots, and each slot has dtype
     `s.dtype` and shape `[batch_size] + s.shape`. `maximum_iterations` and
     `batch_size` are read from the enclosing scope.

     Args:
       s: an object exposing `.dtype` and `.shape` (presumably a tensor or
         tf.TensorSpec; confirm against the caller).
     """
     dtype = s.dtype
     num_slots = maximum_iterations
     # Prepend the batch dimension to the per-example shape of `s`.
     slot_shape = tf.TensorShape([batch_size]).concatenate(s.shape)
     return tf.TensorArray(dtype=dtype, size=num_slots, element_shape=slot_shape)