Example #1
0
 def test_sequence_parsing(self, batch_size):
     """Round-trips sequence examples through parallel_read and parsing.

     Writes a TFRecord of test sequence examples (only when the file is not
     already present in FLAGS.test_tmpdir), reads and batches it, parses it
     against feature/label specs augmented with sequence-length specs, then
     checks both the static tensor shapes and the values produced at runtime.

     Args:
       batch_size: Number of examples per batch; incomplete final batches are
         dropped (drop_remainder=True).
     """
     file_pattern = os.path.join(FLAGS.test_tmpdir, 'test.tfrecord')
     sequence_length = 3
     # NOTE(review): an existing file is reused as-is, so it is assumed to have
     # been written with this same sequence_length — confirm no other test
     # writes a different length to the same path.
     if not os.path.exists(file_pattern):
         self._write_test_sequence_examples(sequence_length, file_pattern)
     dataset = tfdata.parallel_read(file_patterns=file_pattern)
     # Features: a JPEG-encoded image sequence and a 2-vector sequence.
     state_spec_1 = tensorspec_utils.ExtendedTensorSpec(
         shape=TEST_IMAGE_SHAPE,
         dtype=tf.uint8,
         is_sequence=True,
         name='image_sequence_feature',
         data_format='JPEG')
     state_spec_2 = tensorspec_utils.ExtendedTensorSpec(
         # `(2)` would just be the int 2; spell the 1-D shape as a tuple.
         shape=(2,),
         dtype=tf.float32,
         is_sequence=True,
         name='sequence_feature')
     feature_tspec = PoseEnvFeature(state=state_spec_1, action=state_spec_2)
     feature_tspec = tensorspec_utils.add_sequence_length_specs(
         feature_tspec)
     # Labels: a single scalar context (non-sequence) feature.
     reward_spec = tensorspec_utils.ExtendedTensorSpec(
         shape=(),
         dtype=tf.int64,
         is_sequence=False,
         name='context_feature')
     label_tspec = PoseEnvLabel(reward=reward_spec)
     label_tspec = tensorspec_utils.add_sequence_length_specs(label_tspec)
     # drop_remainder=True keeps the batch dimension statically known, which
     # the static-shape assertions below rely on.
     dataset = dataset.batch(batch_size, drop_remainder=True)
     dataset = tfdata.serialized_to_parsed(dataset, feature_tspec,
                                           label_tspec)
     features, labels = dataset.make_one_shot_iterator().get_next()
     # Check static tensor shapes: the time dimension of sequence features is
     # unknown (None) until runtime.
     self.assertAllEqual([batch_size, None] + TEST_IMAGE_SHAPE,
                         features.state.shape.as_list())
     self.assertAllEqual([batch_size, None, 2],
                         features.action.shape.as_list())
     self.assertAllEqual([batch_size],
                         features.state_length.shape.as_list())
     self.assertAllEqual([batch_size],
                         features.action_length.shape.as_list())
     self.assertAllEqual([batch_size], labels.reward.shape.as_list())
     with self.session() as session:
         features_, labels_ = session.run([features, labels])
         # Check that images are equal: step i of the first sequence is
         # expected to be TEST_IMAGE * i (presumably the pattern used by
         # _write_test_sequence_examples — confirm).
         for i in range(sequence_length):
             img = TEST_IMAGE * i
             self.assertAllEqual(img, features_.state[0, i])
         # Check that the concrete numpy shapes match the written length.
         self.assertAllEqual([batch_size, sequence_length] +
                             TEST_IMAGE_SHAPE, features_.state.shape)
         self.assertAllEqual([sequence_length] * batch_size,
                             features_.state_length)
         self.assertAllEqual([batch_size, sequence_length, 2],
                             features_.action.shape)
         self.assertAllEqual([batch_size], labels_.reward.shape)
Example #2
0
 def test_add_sequence_length_specs(self):
     """Verifies the length spec generated for the `actions` sequence entry.

     Builds a spec struct with `image1` and `actions`, runs it through
     add_sequence_length_specs, and expects an `actions_length` entry that is
     a scalar int64 spec named 'sequence_actions_length'.
     """
     augmented = utils.add_sequence_length_specs(
         utils.TensorSpecStruct(image1=D1, actions=S7))
     wanted_length_spec = utils.ExtendedTensorSpec(
         name='sequence_actions_length', dtype=tf.int64, shape=())
     self.assertEqual(augmented.actions_length, wanted_length_spec)