Esempio n. 1
0
    def __init__(self,
                 embedding_key,
                 data_keys,
                 score_transform=None,
                 top_k=None,
                 reducer=tf.math.reduce_mean):
        """Initializes the CacheClassificationLoss object.

        Args:
          embedding_key: Cache key under which the embedding is stored.
          data_keys: Cache keys under which the document data is stored.
          score_transform: Optional callable applied to the raw scores before
            use, i.e. scores(i, j) = score_transform(dot(query_embed_i,
            doc_embed_j)).
          top_k: If set, the top k scoring negative elements are mined and the
            remaining elements are masked before the loss is calculated.
          reducer: Function that reduces the losses to a single scalar. If
            None, the elementwise losses are returned.
        """
        # Tuple assignment binds left to right, so attributes are set in the
        # same order as individual assignments would set them.
        self.embedding_key, self.data_keys = embedding_key, data_keys
        self.score_transform = score_transform
        self.top_k, self.reducer = top_k, reducer
        # Retrieval always uses a GumbelMaxRetrievalFn built with its default
        # constructor arguments.
        self._retrieval_fn = retrieval_fns.GumbelMaxRetrievalFn()
 def test_gumbel_max_retrieval_fn_with_temperature_and_mock_randomness(self):
   """Checks that inv_temp changes which index wins the Gumbel-max argmax.

   The mocked noise is [0.0, 0.9, 0.0]. If inv_temp were ignored, the sum
   scores + noise = [1.0, 0.9, 0.0] would select index 0, so expecting
   index 1 with inv_temp=0.5 verifies that the temperature is applied.
   """
   with mock.patch.object(retrieval_fns,
                          '_sample_gumbel') as mock_sample_gumbel:
     scores = tf.convert_to_tensor([[1.0, 0.0, 0.0]])
     mock_sample_gumbel.return_value = tf.convert_to_tensor([[0.0, 0.9, 0.0]])
     gumbel_max_retrieval_fn = retrieval_fns.GumbelMaxRetrievalFn(inv_temp=0.5)
     expected = tf.convert_to_tensor([[1]])
     actual = gumbel_max_retrieval_fn(scores)
   self.assertAllEqual(expected, actual)
   # Verify the patched noise was really consumed (exactly once, for the full
   # score shape). Otherwise real random noise could presumably decide the
   # argmax above, making the result pass or fail by chance.
   expected_shape = tf.convert_to_tensor([1, 3])
   mock_sample_gumbel.assert_called_once()
   mock_args, mock_kwargs = mock_sample_gumbel.call_args
   self.assertAllEqual(mock_args, (expected_shape,))
   self.assertEqual(mock_kwargs, {})
Esempio n. 3
0
 def __init__(self,
              embedding_key,
              data_keys,
              score_transform=None,
              top_k=None,
              reducer=tf.math.reduce_mean):
     """Stores the loss configuration and builds the retrieval function.

     Args:
       embedding_key: Key of the embedding entry in the cache.
       data_keys: Keys of the document data entries in the cache.
       score_transform: Optional function applied to the scores before use;
         stored as-is here.
       top_k: Optional number of top-scoring negatives to mine; stored as-is.
       reducer: Function used to reduce the losses; stored as-is. Defaults to
         tf.math.reduce_mean.
     """
     # Left-to-right tuple assignment keeps the original attribute order.
     self.embedding_key, self.data_keys = embedding_key, data_keys
     self.score_transform, self.top_k = score_transform, top_k
     self.reducer = reducer
     # Default-configured Gumbel-max retrieval function.
     self._retrieval_fn = retrieval_fns.GumbelMaxRetrievalFn()
Esempio n. 4
0
 def test_gumbel_max_retrieval_fn_with_mock_randomness(self):
     """Pins the retrieved indices under fixed noise, plus the mock's call.

     The row-wise sums of scores and mocked noise are [1.0, 0.1, -0.1],
     [101.0, -2.0, 3.3] and [0.1, -0.1, 1.0]; the expected indices [0], [0]
     and [2] are the argmax of each of those rows.
     """
     score_rows = [[1.0, 0.0, 0.0], [100.0, 0.0, 0.0], [0.0, 0.0, -1.0]]
     noise_rows = [[0.0, 0.1, -0.1], [1.0, -2.0, 3.3], [0.1, -0.1, 2.0]]
     with mock.patch.object(retrieval_fns, '_sample_gumbel') as sample_mock:
         sample_mock.return_value = tf.convert_to_tensor(noise_rows)
         retrieval_fn = retrieval_fns.GumbelMaxRetrievalFn()
         retrieved = retrieval_fn(tf.convert_to_tensor(score_rows))
     self.assertAllEqual(tf.convert_to_tensor([[0], [0], [2]]), retrieved)
     # The noise sampler must be invoked exactly once, positionally, with the
     # [3, 3] shape of the score matrix and no keyword arguments.
     sample_mock.assert_called_once()
     call_positional, call_keywords = sample_mock.call_args
     self.assertAllEqual(call_positional, (tf.convert_to_tensor([3, 3]),))
     self.assertEqual(call_keywords, {})
 def test_gumbel_max_retrieval_fn_has_correct_output_shape(self):
   """Checks that a 3x3 score matrix yields an output of shape [3, 1]."""
   retrieval_fn = retrieval_fns.GumbelMaxRetrievalFn()
   score_matrix = tf.convert_to_tensor(
       [[1.0, 0.0, 0.0], [100.0, 0.0, 0.0], [0.0, 0.0, -1.0]])
   # No mocking here: real noise is used, but the shape must not depend on it.
   result_shape = tf.shape(retrieval_fn(score_matrix))
   self.assertAllEqual([3, 1], result_shape)