Пример #1
0
def recall_at_k(labels, predictions, k):
    '''
    Compute recall at position k.

    An example counts as a hit when its label is among the indices of its
    `k` largest logits. The result is the number of hits (true positives)
    divided by the number of labels (true positives + false negatives).

    :param labels: shape=(num_examples,), dtype=tf.int64
    :param predictions: logits of shape=(num_examples, num_classes)
    :param k: recall position
    :return: scalar float64 recall at position k (0/0, i.e. NaN, when
        tp + fn is 0)


    Example:

    labels = tf.constant([0, 1, 1], dtype=tf.int64)
    predictions = tf.constant([[0.1, 0.2, 0.3], [3, 5, 2], [0.3, 0.4, 0.7]])
    recall_at_k(labels, predictions, 2)
    # recall_at_k(labels, predictions, 2) = 0.6667

    '''
    # Turn each label into a one-element set so it can be compared with the
    # set of top-k indices of the same row.
    labels = expand_dims(labels, axis=1)
    _, predictions_idx = nn.top_k(predictions, k)
    # Set ops need both operands in the same dtype as `labels` (int64).
    # `cast`/`truediv` replace the deprecated `to_int64`/`to_double`/`div`.
    predictions_idx = math_ops.cast(predictions_idx, 'int64')

    # True positives: labels that appear among the top-k indices.
    tp = sets.set_size(sets.set_intersection(predictions_idx, labels))
    tp = math_ops.reduce_sum(math_ops.cast(tp, 'float64'))

    # False negatives: labels minus top-k indices (aminusb=False => b - a).
    fn = sets.set_size(
        sets.set_difference(predictions_idx, labels, aminusb=False))
    fn = math_ops.reduce_sum(math_ops.cast(fn, 'float64'))

    recall = math_ops.truediv(tp, math_ops.add(tp, fn), name='recall_at_k')

    return recall
Пример #2
0
 def _set_size(self, sparse_data):
     """Evaluates `sets.set_size(sparse_data)` and returns the result.

     The op is built twice, with `validate_indices=True` and `False`. Both
     must have no static dims and dtype `int32`, and both must evaluate to
     equal values; the first result is returned.
     """
     # Validate that we get the same results with or without `validate_indices`.
     ops = [
         sets.set_size(sparse_data, validate_indices=True),
         sets.set_size(sparse_data, validate_indices=False)
     ]
     for op in ops:
         self.assertIsNone(op.get_shape().dims)
         self.assertEqual(dtypes.int32, op.dtype)
     # `test_session` is deprecated; `cached_session` is its replacement.
     with self.cached_session() as sess:
         results = sess.run(ops)
     self.assertAllEqual(results[0], results[1])
     return results[0]
Пример #3
0
 def _set_size(self, sparse_data):
   """Runs `sets.set_size` on `sparse_data` with and without index checks.

   Both variants must report no static dims and an `int32` dtype, and they
   must evaluate to the same values. The validated result is returned.
   """
   # Same op built twice: first with `validate_indices=True`, then `False`.
   size_ops = [sets.set_size(sparse_data, validate_indices=check)
               for check in (True, False)]
   for size_op in size_ops:
     self.assertEqual(None, size_op.get_shape().dims)
     self.assertEqual(dtypes.int32, size_op.dtype)
   with self.cached_session() as session:
     validated, unvalidated = session.run(size_ops)
   self.assertAllEqual(validated, unvalidated)
   return validated
Пример #4
0
 def _set_difference_count(self, a, b, aminusb=True):
     """Returns the evaluated `sets.set_size` of `sets.set_difference(a, b)`.

     `aminusb` is forwarded to `sets.set_difference`: True computes `a - b`,
     False computes `b - a`.
     """
     op = sets.set_size(sets.set_difference(a, b, aminusb))
     # `test_session` is deprecated; `cached_session` is its replacement.
     with self.cached_session() as sess:
         return sess.run(op)
Пример #5
0
 def _set_intersection_count(self, a, b):
     """Returns the evaluated `sets.set_size` of `sets.set_intersection(a, b)`."""
     op = sets.set_size(sets.set_intersection(a, b))
     # `test_session` is deprecated; `cached_session` is its replacement.
     with self.cached_session() as sess:
         return sess.run(op)
Пример #6
0
 def _set_union_count(self, a, b):
     """Returns the evaluated `sets.set_size` of `sets.set_union(a, b)`."""
     union = sets.set_union(a, b)
     size_op = sets.set_size(union)
     with self.cached_session() as session:
         return session.run(size_op)
Пример #7
0
 def _set_difference_count(self, a, b, aminusb=True):
   """Returns the evaluated `sets.set_size` of the set difference of a and b.

   `aminusb` is passed straight through to `sets.set_difference`.
   """
   difference = sets.set_difference(a, b, aminusb)
   count_op = sets.set_size(difference)
   with self.cached_session() as session:
     return session.run(count_op)
Пример #8
0
 def _set_intersection_count(self, a, b):
   """Returns the evaluated `sets.set_size` of `sets.set_intersection(a, b)`."""
   common = sets.set_intersection(a, b)
   count_op = sets.set_size(common)
   with self.cached_session() as session:
     return session.run(count_op)
Пример #9
0
 def _set_union_count(self, a, b):
   """Returns the evaluated `sets.set_size` of `sets.set_union(a, b)`."""
   op = sets.set_size(sets.set_union(a, b))
   # `test_session` is deprecated; `cached_session` is its replacement.
   with self.cached_session() as sess:
     return sess.run(op)
Пример #10
0
 def _set_difference_count(self, a, b, aminusb=True):
   """Returns the evaluated `sets.set_size` of `sets.set_difference(a, b)`.

   `aminusb` is forwarded to `sets.set_difference`: True computes `a - b`,
   False computes `b - a`.
   """
   op = sets.set_size(sets.set_difference(a, b, aminusb))
   # `self.evaluate` is used, so the session handle was never needed.
   with self.cached_session():
     return self.evaluate(op)
Пример #11
0
 def _set_intersection_count(self, a, b):
   """Returns the evaluated `sets.set_size` of `sets.set_intersection(a, b)`."""
   op = sets.set_size(sets.set_intersection(a, b))
   # `self.evaluate` is used, so the session handle was never needed.
   with self.cached_session():
     return self.evaluate(op)
Пример #12
0
 def _set_union_count(self, a, b):
   """Returns the evaluated `sets.set_size` of `sets.set_union(a, b)`."""
   op = sets.set_size(sets.set_union(a, b))
   # `self.evaluate` is used, so the session handle was never needed.
   with self.cached_session():
     return self.evaluate(op)