def testDecodeExampleWithSparseTensorToDense(self):
        """A densified SparseTensor handler decodes to the dense label vector.

        The example stores three entries as separate 'indices' and 'values'
        features. Decoding the 'labels' item through a SparseTensor handler
        with an explicit shape of [6] and ``densify=True`` should produce the
        length-6 array ``[0.0, 0.1, 0.2, 0.0, 0.0, 0.6]``.
        """
        sparse_idx = np.array([1, 2, 5])
        sparse_vals = np.array([0.1, 0.2, 0.6], dtype=np.float32)
        dense_shape = np.array([6])
        expected_dense = np.array(
            [0.0, 0.1, 0.2, 0.0, 0.0, 0.6], dtype=np.float32)

        feature_map = {
            'indices': self._EncodedInt64Feature(sparse_idx),
            'values': self._EncodedFloatFeature(sparse_vals),
        }
        proto = example_pb2.Example(
            features=feature_pb2.Features(feature=feature_map))
        raw_bytes = proto.SerializeToString()

        with self.test_session():
            # The decoder is fed a scalar string tensor holding the serialized proto.
            scalar_input = array_ops.reshape(raw_bytes, shape=[])
            parse_spec = {
                'indices': parsing_ops.VarLenFeature(dtype=dtypes.int64),
                'values': parsing_ops.VarLenFeature(dtype=dtypes.float32),
            }
            handlers = {
                'labels': tfexample_decoder.SparseTensor(
                    shape=dense_shape, densify=True),
            }
            decoder = tfexample_decoder.TFExampleDecoder(parse_spec, handlers)
            (labels_tensor,) = decoder.decode(scalar_input, ['labels'])
            self.assertAllClose(labels_tensor.eval(), expected_dense)
Example #2
0
    def test_decode_example_with_sparse_tensor_with_given_shape(self):
        """Decode a SparseTensor item whose handler is given an explicit shape.

        The handler receives shape [6]. The decoded sparse labels should carry
        the encoded indices and values unchanged and report [6] as their
        dense_shape.
        """
        expected_indices = np.array([[1], [2], [5]])
        expected_values = np.array([0.1, 0.2, 0.6], dtype=np.float32)
        given_shape = np.array([6])

        proto = example_pb2.Example(features=feature_pb2.Features(feature={
            'indices': self._encode_int64_feature(expected_indices),
            'values': self._encode_float_feature(expected_values),
        }))
        raw_bytes = proto.SerializeToString()

        with self.test_session():
            # The decoder is fed a scalar string tensor holding the serialized proto.
            scalar_input = array_ops.reshape(raw_bytes, shape=[])
            parse_spec = dict(
                indices=parsing_ops.VarLenFeature(dtype=dtypes.int64),
                values=parsing_ops.VarLenFeature(dtype=dtypes.float32),
            )
            handlers = dict(
                labels=tfexample_decoder.SparseTensor(shape=given_shape))
            decoder = TFExampleDecoder(parse_spec, handlers)
            (labels_tensor,) = decoder.decode(scalar_input, ['labels'])
            decoded = labels_tensor.eval()
            # Each attribute is read immediately before its own assertion.
            expectations = (
                ('indices', expected_indices),
                ('values', expected_values),
                ('dense_shape', given_shape),
            )
            for attr_name, expected in expectations:
                self.assertAllEqual(getattr(decoded, attr_name), expected)