def __init__(self,
             embedding_key,
             data_keys,
             score_transform=None,
             top_k=None,
             reducer=tf.math.reduce_mean):
    """Initializes the CacheClassificationLoss object.

    Args:
        embedding_key: Key under which the embedding is stored in the cache.
        data_keys: Keys under which the document data is stored in the cache.
        score_transform: Function applied to the scores before they are used,
            i.e. scores(i, j) = score_transform(dot(query_embed_i, doc_embed_j)).
        top_k: If set, the top k scoring negative elements are mined and the
            remaining elements are masked out before the loss is calculated.
        reducer: Function that reduces the losses to a single scalar. If None,
            the elementwise losses are returned instead.
    """
    # Cache lookup configuration.
    self.embedding_key = embedding_key
    self.data_keys = data_keys
    # Scoring / loss configuration, stored exactly as given.
    self.score_transform = score_transform
    self.top_k = top_k
    self.reducer = reducer
    # Gumbel-max retrieval function built with its default arguments.
    # NOTE(review): presumably used by the loss call to sample from the
    # cache — confirm against the rest of the class.
    self._retrieval_fn = retrieval_fns.GumbelMaxRetrievalFn()
def test_gumbel_max_retrieval_fn_with_temperature_and_mock_randomness(self):
    """Checks that `inv_temp` affects which index GumbelMaxRetrievalFn picks.

    The Gumbel noise is patched to a fixed tensor so the result is
    deterministic. Unscaled, scores + noise would be [1.0, 0.9, 0.0] (argmax 0).
    The expected argmax of 1 therefore shows the temperature took effect.
    NOTE(review): this assumes inv_temp scales the scores relative to the
    noise (e.g. [0.5, 0.9, 0.0]) — confirm against retrieval_fns.
    """
    with mock.patch.object(retrieval_fns, '_sample_gumbel') as mock_sample_gumbel:
        scores = tf.convert_to_tensor([[1.0, 0.0, 0.0]])
        mock_sample_gumbel.return_value = tf.convert_to_tensor([[0.0, 0.9, 0.0]])
        gumbel_max_retrieval_fn = retrieval_fns.GumbelMaxRetrievalFn(inv_temp=0.5)
        expected = tf.convert_to_tensor([[1]])
        actual = gumbel_max_retrieval_fn(scores)
        self.assertAllEqual(expected, actual)
        # Guard against the patch silently not being hit. Without it, the
        # assertion above would depend on real random noise (flaky).
        mock_sample_gumbel.assert_called_once()
def __init__(self,
             embedding_key,
             data_keys,
             score_transform=None,
             top_k=None,
             reducer=tf.math.reduce_mean):
    """Stores the loss configuration and creates a Gumbel-max retrieval fn.

    Args:
        embedding_key: The key containing the embedding in the cache.
        data_keys: The keys containing the document data in the cache.
        score_transform: Scores are transformed by this function before use,
            i.e. scores(i, j) = score_transform(dot(query_embed_i, doc_embed_j)).
        top_k: If set, the top k scoring negative elements will be mined and
            the rest of the elements masked before calculating the loss.
        reducer: Function that reduces the losses to a single scalar. If None,
            then the elementwise losses are returned.
    """
    self.embedding_key = embedding_key
    self.data_keys = data_keys
    self.score_transform = score_transform
    self.top_k = top_k
    self.reducer = reducer
    # Built with default arguments; the class's other methods (not visible
    # here) presumably use it for retrieval from the cache — confirm.
    self._retrieval_fn = retrieval_fns.GumbelMaxRetrievalFn()
def test_gumbel_max_retrieval_fn_with_mock_randomness(self):
    """With fixed noise, the retrieval fn returns argmax(scores + noise) per row.

    The sums are [1.0, 0.1, -0.1], [101.0, -2.0, 3.3] and [0.1, -0.1, 1.0]. That
    gives the indices [[0], [0], [2]]. The test also checks the patched
    sampler was called exactly once, with only the score shape as argument.
    """
    with mock.patch.object(retrieval_fns, '_sample_gumbel') as patched_sampler:
        score_matrix = tf.convert_to_tensor([
            [1.0, 0.0, 0.0],
            [100.0, 0.0, 0.0],
            [0.0, 0.0, -1.0],
        ])
        patched_sampler.return_value = tf.convert_to_tensor([
            [0.0, 0.1, -0.1],
            [1.0, -2.0, 3.3],
            [0.1, -0.1, 2.0],
        ])
        retrieval_fn = retrieval_fns.GumbelMaxRetrievalFn()
        want_indices = tf.convert_to_tensor([[0], [0], [2]])
        got_indices = retrieval_fn(score_matrix)
        self.assertAllEqual(want_indices, got_indices)

        # The noise must have been requested once, for a [3, 3] shape, with no
        # keyword arguments.
        want_shape = tf.convert_to_tensor([3, 3])
        patched_sampler.assert_called_once()
        positional, keyword = patched_sampler.call_args
        self.assertAllEqual(positional, (want_shape,))
        self.assertEqual(keyword, {})
def test_gumbel_max_retrieval_fn_has_correct_output_shape(self):
    """A 3x3 score matrix yields a [3, 1] output from the retrieval fn.

    The Gumbel noise is not patched here, so the chosen indices are random.
    Only the output shape is checked.
    """
    batch_scores = tf.convert_to_tensor([
        [1.0, 0.0, 0.0],
        [100.0, 0.0, 0.0],
        [0.0, 0.0, -1.0],
    ])
    retrieval_fn = retrieval_fns.GumbelMaxRetrievalFn()
    retrieved = retrieval_fn(batch_scores)
    self.assertAllEqual([3, 1], tf.shape(retrieved))