def test_cache_classification_loss_refreshed_embeddings_with_top_k(self):
    cached_embeddings = tf.convert_to_tensor([[1.0], [-1.0], [3.0], [2.0]])
    cached_data = tf.convert_to_tensor([[1.0], [-1.0], [3.0], [2.0]])
    cache = {
        'cache':
            negative_cache.NegativeCache(
                data={
                    'embeddings': cached_embeddings,
                    'data': cached_data
                },
                age=[0, 0])
    }
    doc_network = lambda data: data['data']
    query_embedding = tf.convert_to_tensor([[-1.0], [1.0]])
    pos_doc_embedding = tf.convert_to_tensor([[2.0], [2.0]])

    loss_fn = losses.CacheClassificationLoss('embeddings', ['data'], top_k=2)
    # pylint: disable=g-long-lambda
    retrieval_fn = lambda scores: tf.convert_to_tensor([[0], [0]],
                                                       dtype=tf.int64)
    # pylint: enable=g-long-lambda
    loss_fn._retrieval_fn = retrieval_fn
    loss_fn_return = loss_fn(doc_network, query_embedding, pos_doc_embedding,
                             cache)
    updated_item_data = loss_fn_return.updated_item_data
    updated_item_indices = loss_fn_return.updated_item_indices
    updated_item_mask = loss_fn_return.updated_item_mask
    self.assertAllClose([[-1.0], [1.0]],
                        updated_item_data['cache']['embeddings'])
    self.assertAllEqual([1, 0], updated_item_indices['cache'])
    self.assertAllEqual([True, True], updated_item_mask['cache'])
  def test_cache_classification_loss_training_loss(self):
    cached_embeddings = tf.convert_to_tensor([[1.0], [-1.0]])
    cached_data = tf.convert_to_tensor([[1.0], [-1.0]])
    cache = {
        'cache':
            negative_cache.NegativeCache(
                data={
                    'embeddings': cached_embeddings,
                    'data': cached_data
                },
                age=[0, 0])
    }
    doc_network = lambda data: data['data']
    query_embedding = tf.convert_to_tensor([[-1.0], [1.0], [3.0]])
    pos_doc_embedding = tf.convert_to_tensor([[2.0], [2.0], [1.0]])
    # pylint: disable=g-long-lambda
    retrieval_fn = lambda scores: tf.convert_to_tensor([[0], [1], [0]],
                                                       dtype=tf.int64)
    # pylint: enable=g-long-lambda
    loss_fn = losses.CacheClassificationLoss(
        'embeddings', ['data'], reducer=None)
    loss_fn._retrieval_fn = retrieval_fn
    training_loss = loss_fn(doc_network, query_embedding, pos_doc_embedding,
                            cache).training_loss
    prob_pos = tf.convert_to_tensor([0.0420101, 0.705385, 0.499381])
    score_differences = tf.convert_to_tensor([1.0, -3.0, 0.0])
    training_loss_expected = (1.0 - prob_pos) * score_differences
    self.assertAllClose(training_loss_expected, training_loss)

    loss_fn.reducer = tf.math.reduce_mean
    training_loss = loss_fn(doc_network, query_embedding, pos_doc_embedding,
                            cache).training_loss
    training_loss_expected = tf.reduce_mean(
        (1.0 - prob_pos) * score_differences)
    self.assertAllClose(training_loss_expected, training_loss)
  def test_cache_classification_loss_staleness(self):
    cached_embeddings = tf.convert_to_tensor([[1.0], [-1.0]])
    cached_data = tf.convert_to_tensor([[2.0], [-3.0]])
    cache = {
        'cache':
            negative_cache.NegativeCache(
                data={
                    'embeddings': cached_embeddings,
                    'data': cached_data
                },
                age=[0, 0])
    }
    doc_network = lambda data: data['data']
    query_embedding = tf.convert_to_tensor([[0.0], [0.0], [0.0]])
    pos_doc_embedding = tf.convert_to_tensor([[0.0], [0.0], [0.0]])
    # pylint: disable=g-long-lambda
    retrieval_fn = lambda scores: tf.convert_to_tensor([[0], [1], [0]],
                                                       dtype=tf.int64)
    # pylint: enable=g-long-lambda
    loss_fn = losses.CacheClassificationLoss('embeddings', ['data'])
    loss_fn._retrieval_fn = retrieval_fn
    staleness = loss_fn(doc_network, query_embedding, pos_doc_embedding,
                        cache).staleness
    staleness_expected = 0.31481481481
    self.assertAllClose(staleness_expected, staleness)

    loss_fn.reducer = None
    staleness = loss_fn(doc_network, query_embedding, pos_doc_embedding,
                        cache).staleness
    staleness_expected = [1.0 / 4.0, 4.0 / 9.0, 1.0 / 4.0]
    self.assertAllClose(staleness_expected, staleness)
  def test_cache_classification_loss_interpretable_loss(self):
    cached_embeddings = tf.convert_to_tensor([[1.0], [-1.0]])
    cached_data = tf.convert_to_tensor([[1.0], [-1.0]])
    cache = {
        'cache':
            negative_cache.NegativeCache(
                data={
                    'embeddings': cached_embeddings,
                    'data': cached_data
                },
                age=[0, 0])
    }
    doc_network = lambda data: data['data']
    query_embedding = tf.convert_to_tensor([[-1.0], [1.0], [3.0]])
    pos_doc_embedding = tf.convert_to_tensor([[2.0], [2.0], [1.0]])

    loss_fn = losses.CacheClassificationLoss(
        'embeddings', ['data'], reducer=None)
    interpretable_loss = loss_fn(doc_network, query_embedding,
                                 pos_doc_embedding, cache).interpretable_loss
    interpretable_loss_expected = [3.169846, 0.349012, 0.694385]
    self.assertAllClose(interpretable_loss_expected, interpretable_loss)

    loss_fn.reducer = tf.math.reduce_mean
    interpretable_loss = loss_fn(doc_network, query_embedding,
                                 pos_doc_embedding, cache).interpretable_loss
    interpretable_loss_expected = (3.169846 + 0.349012 + 0.694385) / 3.0
    self.assertAllClose(interpretable_loss_expected, interpretable_loss)
  def test_cache_classification_loss_training_loss_gradient(self):
    cached_embeddings = tf.convert_to_tensor([[1.0], [-1.0]])
    cached_data = tf.convert_to_tensor([[1.0], [-1.0]])
    cache = {
        'cache':
            negative_cache.NegativeCache(
                data={
                    'embeddings': cached_embeddings,
                    'data': cached_data
                },
                age=[0, 0])
    }
    query_model = tf.Variable(1.0)
    doc_model = tf.Variable(1.0)
    doc_network = lambda data: data['data'] * doc_model

    query_embedding = tf.convert_to_tensor([[-1.0], [1.0], [3.0]])
    pos_doc_embedding = tf.convert_to_tensor([[2.0], [2.0], [1.0]])
    # pylint: disable=g-long-lambda
    retrieval_fn = lambda scores: tf.convert_to_tensor([[0], [1], [0]],
                                                       dtype=tf.int64)
    # pylint: enable=g-long-lambda
    loss_fn = losses.CacheClassificationLoss('embeddings', ['data'])
    loss_fn._retrieval_fn = retrieval_fn

    with tf.GradientTape() as tape:
      query_embedding = query_model * query_embedding
      pos_doc_embedding = doc_model * pos_doc_embedding
      training_loss = loss_fn(doc_network, query_embedding, pos_doc_embedding,
                              cache).training_loss
    gradient = tape.gradient(training_loss, [query_model, doc_model])
    gradient_expected = [0.024715006, 0.024715006]
    self.assertAllClose(gradient_expected, gradient)
 def test_cache_classification_loss_refreshed_embeddings(self):
   cached_embeddings_1 = tf.convert_to_tensor([[1.0], [2.0], [3.0]])
   cached_embeddings_2 = tf.convert_to_tensor([[-1.0], [-2.0]])
   cached_data_1 = tf.convert_to_tensor([[10.0], [20.0], [30.0]])
   cached_data_2 = tf.convert_to_tensor([[-10.0], [-20.0]])
   cache = {
       'cache1':
           negative_cache.NegativeCache(
               data={
                   'embeddings': cached_embeddings_1,
                   'data': cached_data_1
               },
               age=[0, 0, 0]),
       'cache2':
           negative_cache.NegativeCache(
               data={
                   'embeddings': cached_embeddings_2,
                   'data': cached_data_2
               },
               age=[0, 0]),
   }
   doc_network = lambda data: data['data']
   query_embedding = tf.convert_to_tensor([[0.0], [0.0], [0.0]])
   pos_doc_embedding = tf.convert_to_tensor([[2.0], [2.0], [1.0]])
   # pylint: disable=g-long-lambda
   retrieval_fn = lambda scores: tf.convert_to_tensor([[0], [1], [3]],
                                                      dtype=tf.int64)
   # pylint: enable=g-long-lambda
   loss_fn = losses.CacheClassificationLoss('embeddings', ['data'])
   loss_fn._retrieval_fn = retrieval_fn
   cache_loss_return = loss_fn(doc_network, query_embedding, pos_doc_embedding,
                               cache)
   self.assertAllEqual(cache_loss_return.updated_item_mask['cache1'],
                       [True, True, False])
   self.assertAllEqual(cache_loss_return.updated_item_mask['cache2'],
                       [False, False, True])
   self.assertAllEqual(
       cache_loss_return.updated_item_data['cache1']['embeddings'][0:2],
       [[10.0], [20.0]])
   self.assertAllEqual(
       cache_loss_return.updated_item_data['cache2']['embeddings'][2], [-10.0])
   self.assertAllEqual(cache_loss_return.updated_item_indices['cache1'][0:2],
                       [0, 1])
   self.assertEqual(cache_loss_return.updated_item_indices['cache2'][2], 0)
  def test_cache_classification_loss_interpretable_loss_with_top_k(self):
    cached_embeddings = tf.convert_to_tensor([[1.0], [-1.0], [3.0], [2.0]])
    cached_data = tf.convert_to_tensor([[1.0], [-1.0], [3.0], [2.0]])
    cache = {
        'cache':
            negative_cache.NegativeCache(
                data={
                    'embeddings': cached_embeddings,
                    'data': cached_data
                },
                age=[0, 0])
    }
    doc_network = lambda data: data['data']
    query_embedding = tf.convert_to_tensor([[-1.0], [1.0]])
    pos_doc_embedding = tf.convert_to_tensor([[2.0], [2.0]])

    loss_fn = losses.CacheClassificationLoss(
        'embeddings', ['data'], reducer=None, top_k=2)
    interpretable_loss = loss_fn(doc_network, query_embedding,
                                 pos_doc_embedding, cache).interpretable_loss
    interpretable_loss_expected = [3.0949229, 1.407605964]
    self.assertAllClose(interpretable_loss_expected, interpretable_loss)