Пример #1
0
def make_example_dict(example_protos, example_weights):
  """Parses `example_protos` and packs the results into a dict of inputs.

  Args:
    example_protos: Sized sequence of protos exposing `SerializeToString()`.
      Each one is parsed for a float 'target' plus the variable-length
      'age_indices'/'age_values' and 'gender_indices'/'gender_values'
      features.
    example_weights: Stored as-is under the 'example_weights' key.

  Returns:
    A dict with the keys 'sparse_features' (one SparseFeatureColumn for age
    followed by one for gender), 'dense_features' (an empty list),
    'example_weights', 'example_labels' (the parsed 'target' reshaped to
    rank 1) and 'example_ids' (the decimal string of each proto's position).
  """
  feature_spec = {
      'target': tf.FixedLenFeature(shape=[1],
                                   dtype=tf.float32,
                                   default_value=0),
      'age_indices': tf.VarLenFeature(dtype=tf.int64),
      'age_values': tf.VarLenFeature(dtype=tf.float32),
      'gender_indices': tf.VarLenFeature(dtype=tf.int64),
      'gender_values': tf.VarLenFeature(dtype=tf.float32)
  }
  serialized = [proto.SerializeToString() for proto in example_protos]
  parsed = tf.parse_example(serialized, feature_spec)

  def _flatten(tensor):
    # Collapse to a rank-1 tensor.
    return tf.reshape(tensor, [-1])

  def _sparse_column(indices_key, values_key):
    # Build one SparseFeatureColumn from a pair of parsed sparse features.
    sparse_indices = parsed[indices_key]
    sparse_values = parsed[values_key]
    # Keep only the first of the two pieces tf.split yields from the sparse
    # index matrix.
    # NOTE(review): tf.split(1, 2, x) looks like the legacy
    # (split_dim, num_split, value) argument order -- confirm against the
    # TensorFlow version this file targets.
    first_piece = tf.split(1, 2, sparse_indices.indices)[0]
    return SparseFeatureColumn(
        _flatten(first_piece),
        _flatten(sparse_indices.values),
        _flatten(sparse_values.values))

  sparse_features = [
      _sparse_column('age_indices', 'age_values'),
      _sparse_column('gender_indices', 'gender_values'),
  ]
  example_labels = _flatten(parsed['target'])
  example_ids = ['%d' % position for position in range(len(example_protos))]
  return {
      'sparse_features': sparse_features,
      'dense_features': [],
      'example_weights': example_weights,
      'example_labels': example_labels,
      'example_ids': example_ids,
  }
Пример #2
0
 def testBasic(self):
   """Checks that SparseFeatureColumn exposes the data it was built from.

   Covers two constructions: one with `None` feature values, where the two
   index attributes must be tensors and `feature_values` must stay `None`,
   and one with explicit feature values, which must evaluate back unchanged.
   """
   expected_example_indices = [1, 1, 1, 2]
   expected_feature_indices = [0, 1, 2, 0]

   # Case 1: no feature values supplied.
   sfc = SparseFeatureColumn(expected_example_indices,
                             expected_feature_indices, None)
   # assertIsInstance / assertIsNone report the offending object on failure
   # and check identity with None instead of relying on `==`.
   self.assertIsInstance(sfc.example_indices, tf.Tensor)
   self.assertIsInstance(sfc.feature_indices, tf.Tensor)
   self.assertIsNone(sfc.feature_values)
   with self._single_threaded_test_session():
     self.assertAllEqual(expected_example_indices, sfc.example_indices.eval())
     self.assertAllEqual(expected_feature_indices, sfc.feature_indices.eval())

   # Case 2: explicit feature values; reuse the same index lists so both
   # cases are guaranteed to describe the same column layout.
   expected_feature_values = [1.0, 2.0, 3.0, 4.0]
   sfc = SparseFeatureColumn(expected_example_indices,
                             expected_feature_indices,
                             expected_feature_values)
   with self._single_threaded_test_session():
     self.assertAllEqual(expected_feature_values, sfc.feature_values.eval())