Пример #1
0
 def _set_intersection(self, a, b):
     """Build four equivalent intersection ops for `a` and `b` and run them.

     The variants cover both operand orders, each with `validate_indices`
     enabled and disabled. The intent is that neither the operand order nor
     index validation changes the result.

     Args:
       a: first operand passed to `sets.set_intersection`; also the reference
         for the static-shape check.
       b: second operand passed to `sets.set_intersection`.

     Returns:
       Whatever `self._run_equivalent_set_ops` returns for the tuple of ops.
     """
     operand_orders = ((a, b), (b, a))
     # Order matters for the caller: (a, b, True), (a, b, False),
     # (b, a, True), (b, a, False).
     ops = tuple(
         sets.set_intersection(lhs, rhs, validate_indices=validate)
         for lhs, rhs in operand_orders
         for validate in (True, False)
     )
     for intersection_op in ops:
         self._assert_static_shapes(a, intersection_op)
     return self._run_equivalent_set_ops(ops)
Пример #2
0
 def _set_intersection(self, a, b):
   """Intersect `a` and `b` in four equivalent ways and run the ops together.

   Each operand order ((a, b) and (b, a)) is combined with
   `validate_indices` set to True and then False. The goal is to confirm that
   neither choice affects the result.

   Args:
     a: first operand for `sets.set_intersection`; also used as the reference
       for `self._assert_static_shapes`.
     b: second operand for `sets.set_intersection`.

   Returns:
     The result of `self._run_equivalent_set_ops` on the tuple of ops.
   """
   variants = []
   for first, second in ((a, b), (b, a)):
     for validate in (True, False):
       variants.append(
           sets.set_intersection(first, second, validate_indices=validate))
   # The helper receives a tuple, exactly as a tuple literal would give it.
   ops = tuple(variants)
   for variant in ops:
     self._assert_static_shapes(a, variant)
   return self._run_equivalent_set_ops(ops)
Пример #3
0
def recall_at_k(labels, predictions, k):
    """Compute recall at position k.

    Each example contributes a single label. The label counts as recalled
    when it is among the `k` highest-scoring class indices of `predictions`.
    The result is total hits divided by (hits + misses).

    :param labels: shape=(num_examples,), dtype=tf.int64
    :param predictions: logits of shape=(num_examples, num_classes)
    :param k: recall position
    :return: scalar float64 tensor named 'recall_at_k' holding recall at k

    NOTE(review): if the batch is empty the denominator is 0, so the result
    is presumably NaN -- confirm whether callers can pass empty batches.

    Example:

    labels = tf.constant([0, 1, 1], dtype=tf.int64)
    predictions = tf.constant([[0.1, 0.2, 0.3], [3, 5, 2], [0.3, 0.4, 0.7]])
    recall_at_k(labels, predictions, 2)
    # recall_at_k(labels, predictions, 2) = 0.6667
    """
    def _total_size(set_op):
        # Per-row set sizes, cast to float64 and summed into one scalar.
        return math_ops.reduce_sum(math_ops.to_double(sets.set_size(set_op)))

    # Give every label its own trailing axis so each row is a one-element set.
    label_sets = expand_dims(labels, axis=1)
    top_k_idx = math_ops.to_int64(nn.top_k(predictions, k)[1])

    # Hits: labels that appear among the top-k indices.
    true_positives = _total_size(sets.set_intersection(top_k_idx, label_sets))
    # Misses: labels not among the top-k (aminusb=False gives labels - top_k).
    false_negatives = _total_size(
        sets.set_difference(top_k_idx, label_sets, aminusb=False))

    return math_ops.div(
        true_positives,
        math_ops.add(true_positives, false_negatives),
        name='recall_at_k')
Пример #4
0
 def _set_intersection(self, a, b):
     """Intersect `a` and `b` in four equivalent ways and check the results match.

     The variants cover both operand orders, each with `validate_indices`
     enabled and disabled. All four are run in one `sess.run` call. Every
     variant's `indices`, `values` and `dense_shape` must equal those of the
     first variant.

     Args:
       a: first operand for `sets.set_intersection`; also the reference for
         `self._assert_shapes`.
       b: second operand for `sets.set_intersection`.

     Returns:
       The evaluated result of `sets.set_intersection(a, b,
       validate_indices=True)`, the first variant.
     """
     # Validate that we get the same results with or without `validate_indices`,
     # and with a & b swapped.
     ops = (
         sets.set_intersection(a, b, validate_indices=True),
         sets.set_intersection(a, b, validate_indices=False),
         sets.set_intersection(b, a, validate_indices=True),
         sets.set_intersection(b, a, validate_indices=False),
     )
     for op in ops:
         self._assert_shapes(a, op)
     with self.test_session() as sess:
         results = sess.run(ops)
     expected = results[0]
     # Compare every other variant against the first. Iterating the slice,
     # rather than a hard-coded `range(1, 4)`, keeps this in step with `ops`.
     for actual in results[1:]:
         self.assertAllEqual(expected.indices, actual.indices)
         self.assertAllEqual(expected.values, actual.values)
         self.assertAllEqual(expected.dense_shape, actual.dense_shape)
     return expected
Пример #5
0
 def _set_intersection(self, a, b):
   """Intersect `a` and `b` in four equivalent ways and check the results match.

   The variants cover both operand orders, each with `validate_indices`
   enabled and disabled. All four are run in one `sess.run` call. Every
   variant's `indices`, `values` and `dense_shape` must equal those of the
   first variant.

   Args:
     a: first operand for `sets.set_intersection`; also the reference for
       `self._assert_shapes`.
     b: second operand for `sets.set_intersection`.

   Returns:
     The evaluated result of `sets.set_intersection(a, b,
     validate_indices=True)`, the first variant.
   """
   # Validate that we get the same results with or without `validate_indices`,
   # and with a & b swapped.
   ops = (
       sets.set_intersection(
           a, b, validate_indices=True),
       sets.set_intersection(
           a, b, validate_indices=False),
       sets.set_intersection(
           b, a, validate_indices=True),
       sets.set_intersection(
           b, a, validate_indices=False),)
   for op in ops:
     self._assert_shapes(a, op)
   with self.test_session() as sess:
     results = sess.run(ops)
   expected = results[0]
   # Compare every other variant against the first. Iterating the slice,
   # rather than a hard-coded `range(1, 4)`, keeps this in step with `ops`.
   for actual in results[1:]:
     self.assertAllEqual(expected.indices, actual.indices)
     self.assertAllEqual(expected.values, actual.values)
     self.assertAllEqual(expected.dense_shape, actual.dense_shape)
   return expected
Пример #6
0
 def _set_intersection_count(self, a, b):
     """Evaluate `sets.set_size` of the intersection of `a` and `b`.

     Args:
       a: first operand for `sets.set_intersection`.
       b: second operand for `sets.set_intersection`.

     Returns:
       The value obtained by running the size op in the test session.
     """
     intersection = sets.set_intersection(a, b)
     size_op = sets.set_size(intersection)
     with self.test_session() as session:
         return session.run(size_op)
Пример #7
0
 def _set_intersection_count(self, a, b):
   """Evaluate `sets.set_size` of the intersection of `a` and `b`.

   Args:
     a: first operand for `sets.set_intersection`.
     b: second operand for `sets.set_intersection`.

   Returns:
     The value obtained by running the size op in the cached test session.
   """
   intersection = sets.set_intersection(a, b)
   size_op = sets.set_size(intersection)
   with self.cached_session() as session:
     return session.run(size_op)
Пример #8
0
 def _set_intersection_count(self, a, b):
   """Evaluate `sets.set_size` of the intersection of `a` and `b`.

   Args:
     a: first operand for `sets.set_intersection`.
     b: second operand for `sets.set_intersection`.

   Returns:
     The value of the size op as produced by `self.evaluate`.
   """
   op = sets.set_size(sets.set_intersection(a, b))
   # The session object is not needed because `self.evaluate` is used. The
   # context manager is kept because it presumably provides the default
   # session that `self.evaluate` relies on in graph mode -- confirm before
   # removing it.
   with self.cached_session():
     return self.evaluate(op)