def testDecodeExampleWithLookup(self):
  """Checks that a `LookupTensor` item handler yields lookup-table indices.

  An Example holding the class names 'cat', 'dog' and 'guinea pig' is decoded
  with a table built from ['dog', 'guinea pig', 'cat']; the decoded ids are
  expected to be [2, 0, 1].
  """
  text_key = 'image/object/class/text'

  class_names_feature = self._BytesFeature(
      np.array(['cat', 'dog', 'guinea pig']))
  features = feature_pb2.Features(feature={text_key: class_names_feature})
  example = example_pb2.Example(features=features)
  serialized_proto = example.SerializeToString()

  # Position in the vocabulary is the id:
  # 'dog' -> 0, 'guinea pig' -> 1, 'cat' -> 2
  vocabulary = constant_op.constant(['dog', 'guinea pig', 'cat'])
  table = lookup_ops.index_table_from_tensor(vocabulary)

  with self.test_session() as session:
    # Run the table initializers before the decoded tensor is evaluated.
    session.run(lookup_ops.tables_initializer())

    # The serialized proto is reshaped to a scalar before being decoded.
    serialized_tensor = array_ops.reshape(serialized_proto, shape=[])

    decoder = tfexample_decoder.TFExampleDecoder(
        {text_key: parsing_ops.VarLenFeature(dtypes.string)},
        {'labels': tfexample_decoder.LookupTensor(text_key, table)})

    decoded_items = decoder.decode(serialized_tensor)
    obtained_class_ids = decoded_items[0].eval()

  self.assertAllClose([2, 0, 1], obtained_class_ids)
def testDecodeExampleWithBackupHandlerLookup(self):
  """Checks `BackupHandler` with a `Tensor` handler and a `LookupTensor` backup.

  Three Examples go through the same decoder: one with both class text and
  numeric class labels, one with only the text, and one with only the labels.
  The test expects the stored labels whenever they are present, and the
  lookup-table indices of the class names when only the text is present.
  """
  text_key = 'image/object/class/text'
  label_key = 'image/object/class/label'

  text_and_label = example_pb2.Example(features=feature_pb2.Features(
      feature={
          text_key: self._BytesFeature(
              np.array(['cat', 'dog', 'guinea pig'])),
          label_key: self._EncodedInt64Feature(np.array([42, 10, 900])),
      }))
  text_only = example_pb2.Example(features=feature_pb2.Features(
      feature={
          text_key: self._BytesFeature(
              np.array(['cat', 'dog', 'guinea pig'])),
      }))
  label_only = example_pb2.Example(features=feature_pb2.Features(
      feature={
          label_key: self._EncodedInt64Feature(np.array([42, 10, 901])),
      }))

  # Position in the vocabulary is the id:
  # 'dog' -> 0, 'guinea pig' -> 1, 'cat' -> 2
  vocabulary = constant_op.constant(['dog', 'guinea pig', 'cat'])
  table = lookup_ops.index_table_from_tensor(vocabulary)

  keys_to_features = {
      text_key: parsing_ops.VarLenFeature(dtypes.string),
      label_key: parsing_ops.VarLenFeature(dtypes.int64),
  }
  primary_handler = tfexample_decoder.Tensor(label_key)
  fallback_handler = tfexample_decoder.LookupTensor(text_key, table)
  labels_handler = tfexample_decoder.BackupHandler(
      handler=primary_handler, backup=fallback_handler)
  decoder = tfexample_decoder.TFExampleDecoder(
      keys_to_features, {'labels': labels_handler})

  def decode_class_ids(example):
    # Serialize one Example, reshape it to a scalar and evaluate the single
    # decoded item in the default session.
    serialized = array_ops.reshape(example.SerializeToString(), shape=[])
    return decoder.decode(serialized)[0].eval()

  with self.test_session() as session:
    # Run the table initializers before any decoded tensor is evaluated.
    session.run(lookup_ops.tables_initializer())
    ids_with_both, ids_text_only, ids_label_only = [
        decode_class_ids(example)
        for example in (text_and_label, text_only, label_only)
    ]

  self.assertAllClose([42, 10, 900], ids_with_both)
  self.assertAllClose([2, 0, 1], ids_text_only)
  self.assertAllClose([42, 10, 901], ids_label_only)