Example #1
0
    def testDecodeExampleWithBranchedLookup(self):
        """Decodes class-name text into integer ids via a lookup table.

        The example stores the class names 'cat', 'dog' and 'guinea pig'
        under 'image/object/class/text'. A `LookupTensor` handler maps them
        through an index table built from ['dog', 'guinea pig', 'cat'], so
        the decoded 'labels' item is expected to be [2, 0, 1].
        """
        example = example_pb2.Example(features=feature_pb2.Features(
            feature={
                'image/object/class/text':
                self._BytesFeatureFromList(
                    np.array(['cat', 'dog', 'guinea pig'])),
            }))
        serialized_bytes = example.SerializeToString()
        # 'dog' -> 0, 'guinea pig' -> 1, 'cat' -> 2
        table = lookup_ops.index_table_from_tensor(
            constant_op.constant(['dog', 'guinea pig', 'cat']))

        with self.test_session() as sess:
            # The table must be initialized before the lookup is evaluated.
            sess.run(lookup_ops.tables_initializer())

            # Wrap the serialized proto as a scalar (shape []) string tensor.
            serialized_tensor = array_ops.reshape(serialized_bytes,
                                                  shape=[])

            keys_to_features = {
                'image/object/class/text':
                parsing_ops.VarLenFeature(dtypes.string),
            }

            items_to_handlers = {
                'labels':
                tf_example_decoder.LookupTensor('image/object/class/text',
                                                table),
            }

            decoder = slim_example_decoder.TFExampleDecoder(
                keys_to_features, items_to_handlers)
            # decode() returns one tensor per requested item; 'labels' is
            # the only item here, hence index 0.
            obtained_class_ids = decoder.decode(serialized_tensor)[0].eval()

        # Class ids are integers: compare them exactly instead of within a
        # floating-point tolerance.
        self.assertAllEqual([2, 0, 1], obtained_class_ids)
  def testDecodeExampleWithBranchedBackupHandler(self):
    """Checks that BackupHandler prefers labels and falls back to text lookup.

    The `BackupHandler` under test reads 'image/object/class/label' with its
    primary handler and uses a `LookupTensor` over 'image/object/class/text'
    as the backup. Three examples exercise both paths:
      * example1 has both fields -> the integer labels [42, 10, 900];
      * example2 has only text   -> the looked-up ids [2, 0, 1];
      * example3 has only labels -> the integer labels [42, 10, 901].
    """
    example1 = example_pb2.Example(
        features=feature_pb2.Features(
            feature={
                'image/object/class/text':
                    self._BytesFeatureFromList(
                        np.array(['cat', 'dog', 'guinea pig'])),
                'image/object/class/label':
                    self._Int64FeatureFromList(np.array([42, 10, 900]))
            }))
    example2 = example_pb2.Example(
        features=feature_pb2.Features(
            feature={
                'image/object/class/text':
                    self._BytesFeatureFromList(
                        np.array(['cat', 'dog', 'guinea pig'])),
            }))
    example3 = example_pb2.Example(
        features=feature_pb2.Features(
            feature={
                'image/object/class/label':
                    self._Int64FeatureFromList(np.array([42, 10, 901]))
            }))
    # 'dog' -> 0, 'guinea pig' -> 1, 'cat' -> 2
    table = lookup_ops.index_table_from_tensor(
        constant_op.constant(['dog', 'guinea pig', 'cat']))
    keys_to_features = {
        'image/object/class/text': parsing_ops.VarLenFeature(dtypes.string),
        'image/object/class/label': parsing_ops.VarLenFeature(dtypes.int64),
    }
    backup_handler = tf_example_decoder.BackupHandler(
        handler=slim_example_decoder.Tensor('image/object/class/label'),
        backup=tf_example_decoder.LookupTensor('image/object/class/text',
                                               table))
    items_to_handlers = {
        'labels': backup_handler,
    }
    decoder = slim_example_decoder.TFExampleDecoder(keys_to_features,
                                                    items_to_handlers)
    obtained_class_ids_each_example = []
    with self.test_session() as sess:
      # The lookup table must be initialized before any decode is evaluated.
      sess.run(lookup_ops.tables_initializer())
      for example in [example1, example2, example3]:
        # Wrap each serialized proto as a scalar (shape []) string tensor.
        serialized_example = array_ops.reshape(
            example.SerializeToString(), shape=[])
        # 'labels' is the only requested item, hence index 0.
        obtained_class_ids_each_example.append(
            decoder.decode(serialized_example)[0].eval())

    expected_class_ids_each_example = [
        [42, 10, 900],  # labels present: primary handler wins
        [2, 0, 1],      # labels missing: backup text lookup is used
        [42, 10, 901],  # only labels present
    ]
    # Class ids are integers: compare them exactly instead of within a
    # floating-point tolerance.
    for expected, obtained in zip(expected_class_ids_each_example,
                                  obtained_class_ids_each_example):
      self.assertAllEqual(expected, obtained)