def testDecodeExampleWithBoundingBoxDense(self):
        """Round-trip bbox coordinates parsed as FixedLenSequenceFeatures.

        Four random coordinate columns (ymin, xmin, ymax, xmax) are encoded
        with ``_EncodedFloatFeature`` into a single serialized Example. The
        columns are parsed back with ``FixedLenSequenceFeature(allow_missing=
        True)`` and reassembled by ``tfexample_decoder.BoundingBox``. The
        decoded boxes must match the column-wise stack of the source arrays.
        """
        prefix = 'image/object/bbox/'
        coord_keys = ['ymin', 'xmin', 'ymax', 'xmax']
        num_bboxes = 10

        # One (num_bboxes, 1) column per coordinate, drawn in handler-key
        # order so the expected (num_bboxes, 4) array lines up with it.
        columns = [np.random.rand(num_bboxes, 1) for _ in coord_keys]
        expected_bboxes = np.hstack(columns)

        feature_map = {
            prefix + key: self._EncodedFloatFeature(column)
            for key, column in zip(coord_keys, columns)
        }
        example = example_pb2.Example(
            features=feature_pb2.Features(feature=feature_map))
        serialized = example.SerializeToString()

        with self.test_session():
            # Wrap the serialized bytes in a scalar (shape []) tensor for decode().
            serialized_tensor = array_ops.reshape(serialized, shape=[])

            keys_to_features = {
                prefix + key: parsing_ops.FixedLenSequenceFeature(
                    [], dtypes.float32, allow_missing=True)
                for key in coord_keys
            }
            items_to_handlers = {
                'object/bbox': tfexample_decoder.BoundingBox(coord_keys, prefix),
            }

            decoder = tfexample_decoder.TFExampleDecoder(
                keys_to_features, items_to_handlers)
            # Exactly one item is requested, so exactly one tensor comes back.
            tf_bboxes, = decoder.decode(serialized_tensor, ['object/bbox'])
            decoded_bboxes = tf_bboxes.eval()

        self.assertAllClose(expected_bboxes, decoded_bboxes)
# Example No. 2
    def test_decode_example_with_bounding_box(self):
        """Round-trip bbox coordinates parsed as VarLenFeatures.

        Four random coordinate columns (ymin, xmin, ymax, xmax) are encoded
        with ``_encode_float_feature`` into a single serialized Example. The
        columns are parsed back as float32 ``VarLenFeature`` entries and
        reassembled by ``tfexample_decoder.BoundingBox`` through
        ``TFExampleDecoder``. The decoded boxes must match the column-wise
        stack of the source arrays.
        """
        prefix = 'image/object/bbox/'
        coord_keys = ['ymin', 'xmin', 'ymax', 'xmax']
        num_bboxes = 10

        # Draw the columns in the same order the handler lists its keys.
        columns = [np.random.rand(num_bboxes, 1) for _ in coord_keys]
        expected_bboxes = np.hstack(columns)

        feature_map = {
            prefix + key: self._encode_float_feature(column)
            for key, column in zip(coord_keys, columns)
        }
        example = example_pb2.Example(
            features=feature_pb2.Features(feature=feature_map))
        serialized = example.SerializeToString()

        with self.test_session():
            # Wrap the serialized bytes in a scalar (shape []) tensor for decode().
            serialized_tensor = array_ops.reshape(serialized, shape=[])

            keys_to_features = {
                prefix + key: parsing_ops.VarLenFeature(dtypes.float32)
                for key in coord_keys
            }
            items_to_handlers = {
                'object/bbox': tfexample_decoder.BoundingBox(coord_keys, prefix),
            }

            decoder = TFExampleDecoder(keys_to_features, items_to_handlers)
            # Exactly one item is requested, so exactly one tensor comes back.
            tf_bboxes, = decoder.decode(serialized_tensor, ['object/bbox'])
            decoded_bboxes = tf_bboxes.eval()

        self.assertAllClose(expected_bboxes, decoded_bboxes)