Code example #1
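The examples below are test methods exercising the negative cache library (`negative_cache`, `losses`, and `filter_fns` modules). They assume a TensorFlow test-case base class providing `assertAllEqual`/`assertAllClose`, plus imports along these lines (the module paths are an assumption, since the snippets do not show them):

    import tensorflow as tf

    from negative_cache import filter_fns
    from negative_cache import losses
    from negative_cache import negative_cache

This first test checks that `CacheManager.update_cache` leaves the data untouched when every entry of `updated_item_mask` is `False`: the cached tensors are unchanged and only the ages are incremented (2 to 3).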
  def test_masked_update_cache_with_existing_items_when_all_items_masked(
      self):
    specs = {
        '1': tf.io.FixedLenFeature(shape=[2], dtype=tf.int32),
    }
    cache_manager = negative_cache.CacheManager(specs=specs, cache_size=4)
    cache = negative_cache.NegativeCache(
        data={
            '1':
                tf.convert_to_tensor([[5, 5], [10, 10], [15, 15], [20, 20]],
                                     dtype=tf.int32)
        },
        age=tf.convert_to_tensor([2, 2, 2, 2], dtype=tf.int32))
    updated_item_indices = tf.convert_to_tensor([1, 3], dtype=tf.int32)
    updated_item_data = {
        '1': tf.convert_to_tensor([[1, 1], [2, 2]], dtype=tf.int32),
    }
    updated_item_mask = tf.convert_to_tensor([False, False])
    cache = cache_manager.update_cache(
        cache,
        updated_item_data=updated_item_data,
        updated_item_indices=updated_item_indices,
        updated_item_mask=updated_item_mask)
    self.assertEqual({'1'}, set(cache.data.keys()))
    self.assertAllEqual(
        tf.convert_to_tensor([[5, 5], [10, 10], [15, 15], [20, 20]],
                             dtype=tf.int32), cache.data['1'])
    self.assertAllEqual(
        tf.convert_to_tensor([3, 3, 3, 3], dtype=tf.int32), cache.age)
Code example #2
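With `use_lru=False`, `update_cache` writes the updated items in place but does not reset their ages to zero; every entry's age is simply incremented ([1, 0, 1, 1] becomes [2, 1, 2, 2]).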
  def test_update_cache_without_lru(self):
    specs = {
        '1': tf.io.FixedLenFeature(shape=[2], dtype=tf.int32),
    }
    cache_manager = negative_cache.CacheManager(
        specs=specs, cache_size=4, use_lru=False)
    cache = negative_cache.NegativeCache(
        data={
            '1':
                tf.convert_to_tensor([[5, 5], [10, 10], [15, 15], [20, 20]],
                                     dtype=tf.int32)
        },
        age=tf.convert_to_tensor([1, 0, 1, 1], dtype=tf.int32))
    updated_item_indices = tf.convert_to_tensor([1, 3], dtype=tf.int32)
    updated_item_data = {
        '1': tf.convert_to_tensor([[1, 1], [2, 2]], dtype=tf.int32),
    }
    cache = cache_manager.update_cache(
        cache,
        updated_item_indices=updated_item_indices,
        updated_item_data=updated_item_data)
    cache_data_expected = [[5, 5], [1, 1], [15, 15], [2, 2]]
    cache_age_expected = [2, 1, 2, 2]
    self.assertAllEqual(cache_data_expected, cache.data['1'])
    self.assertAllEqual(cache_age_expected, cache.age)
Code example #3
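Checks the `training_loss` of `CacheClassificationLoss` with a stubbed `_retrieval_fn`. Per the expected values, each example's loss is the score difference between the retrieved cache item and the positive document, weighted by `(1 - prob_pos)`; with `reducer=None` the losses are returned per example, and setting `reducer` to `tf.math.reduce_mean` averages them.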
  def test_cache_classification_loss_training_loss(self):
    cached_embeddings = tf.convert_to_tensor([[1.0], [-1.0]])
    cached_data = tf.convert_to_tensor([[1.0], [-1.0]])
    cache = {
        'cache':
            negative_cache.NegativeCache(
                data={
                    'embeddings': cached_embeddings,
                    'data': cached_data
                },
                age=[0, 0])
    }
    doc_network = lambda data: data['data']
    query_embedding = tf.convert_to_tensor([[-1.0], [1.0], [3.0]])
    pos_doc_embedding = tf.convert_to_tensor([[2.0], [2.0], [1.0]])
    # pylint: disable=g-long-lambda
    retrieval_fn = lambda scores: tf.convert_to_tensor([[0], [1], [0]],
                                                       dtype=tf.int64)
    # pylint: enable=g-long-lambda
    loss_fn = losses.CacheClassificationLoss(
        'embeddings', ['data'], reducer=None)
    loss_fn._retrieval_fn = retrieval_fn
    training_loss = loss_fn(doc_network, query_embedding, pos_doc_embedding,
                            cache).training_loss
    prob_pos = tf.convert_to_tensor([0.0420101, 0.705385, 0.499381])
    score_differences = tf.convert_to_tensor([1.0, -3.0, 0.0])
    training_loss_expected = (1.0 - prob_pos) * score_differences
    self.assertAllClose(training_loss_expected, training_loss)

    loss_fn.reducer = tf.math.reduce_mean
    training_loss = loss_fn(doc_network, query_embedding, pos_doc_embedding,
                            cache).training_loss
    training_loss_expected = tf.reduce_mean(
        (1.0 - prob_pos) * score_differences)
    self.assertAllClose(training_loss_expected, training_loss)
Code example #4
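A single `update_cache` call that both overwrites an existing slot (index 0) and inserts a new item. The new item replaces the entry with the highest remaining age, and both touched slots end with age 0 while the untouched one ages by 1.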
  def test_update_cache_with_new_items_and_existing_items(self):
    specs = {
        '1': tf.io.FixedLenFeature(shape=[1], dtype=tf.int32),
        '2': tf.io.FixedLenFeature(shape=[1], dtype=tf.int32),
    }
    cache_manager = negative_cache.CacheManager(specs=specs, cache_size=2)
    data = {
        '1': tf.convert_to_tensor([[0], [0], [3]], dtype=tf.int32),
        '2': tf.convert_to_tensor([[1], [2], [4]], dtype=tf.int32)
    }
    age = tf.convert_to_tensor([2, 1, 0])
    cache = negative_cache.NegativeCache(data=data, age=age)
    updated_item_indices = tf.convert_to_tensor([0], dtype=tf.int32)
    updated_item_data = {
        '1': tf.convert_to_tensor([[10]], dtype=tf.int32),
    }
    new_items = {
        '1': tf.convert_to_tensor([[11]], dtype=tf.int32),
        '2': tf.convert_to_tensor([[12]], dtype=tf.int32)
    }
    cache = cache_manager.update_cache(
        cache,
        new_items=new_items,
        updated_item_data=updated_item_data,
        updated_item_indices=updated_item_indices)
    self.assertEqual({'1', '2'}, set(cache.data.keys()))
    self.assertAllEqual(
        tf.convert_to_tensor([[10], [11], [3]], dtype=tf.int32),
        cache.data['1'])
    self.assertAllEqual(
        tf.convert_to_tensor([[1], [12], [4]], dtype=tf.int32),
        cache.data['2'])
    self.assertAllEqual(
        tf.convert_to_tensor([0, 0, 1], dtype=tf.int32), cache.age)
Code example #5
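Checks `interpretable_loss`, which here matches the negative log-probability of the positive document under a softmax over the positive and cached scores, first unreduced and then mean-reduced.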
  def test_cache_classification_loss_interpretable_loss(self):
    cached_embeddings = tf.convert_to_tensor([[1.0], [-1.0]])
    cached_data = tf.convert_to_tensor([[1.0], [-1.0]])
    cache = {
        'cache':
            negative_cache.NegativeCache(
                data={
                    'embeddings': cached_embeddings,
                    'data': cached_data
                },
                age=[0, 0])
    }
    doc_network = lambda data: data['data']
    query_embedding = tf.convert_to_tensor([[-1.0], [1.0], [3.0]])
    pos_doc_embedding = tf.convert_to_tensor([[2.0], [2.0], [1.0]])

    loss_fn = losses.CacheClassificationLoss(
        'embeddings', ['data'], reducer=None)
    interpretable_loss = loss_fn(doc_network, query_embedding,
                                 pos_doc_embedding, cache).interpretable_loss
    interpretable_loss_expected = [3.169846, 0.349012, 0.694385]
    self.assertAllClose(interpretable_loss_expected, interpretable_loss)

    loss_fn.reducer = tf.math.reduce_mean
    interpretable_loss = loss_fn(doc_network, query_embedding,
                                 pos_doc_embedding, cache).interpretable_loss
    interpretable_loss_expected = (3.169846 + 0.349012 + 0.694385) / 3.0
    self.assertAllClose(interpretable_loss_expected, interpretable_loss)
Code example #6
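With `top_k=2`, retrieval is restricted to each query's two highest-scoring cache items. The return value carries the refreshed item data (re-embedded through `doc_network`), the cache indices of the retrieved items, and an update mask.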
  def test_cache_classification_loss_refreshed_embeddings_with_top_k(self):
    cached_embeddings = tf.convert_to_tensor([[1.0], [-1.0], [3.0], [2.0]])
    cached_data = tf.convert_to_tensor([[1.0], [-1.0], [3.0], [2.0]])
    cache = {
        'cache':
            negative_cache.NegativeCache(
                data={
                    'embeddings': cached_embeddings,
                    'data': cached_data
                },
                age=[0, 0, 0, 0])
    }
    doc_network = lambda data: data['data']
    query_embedding = tf.convert_to_tensor([[-1.0], [1.0]])
    pos_doc_embedding = tf.convert_to_tensor([[2.0], [2.0]])

    loss_fn = losses.CacheClassificationLoss('embeddings', ['data'], top_k=2)
    # pylint: disable=g-long-lambda
    retrieval_fn = lambda scores: tf.convert_to_tensor([[0], [0]],
                                                       dtype=tf.int64)
    # pylint: enable=g-long-lambda
    loss_fn._retrieval_fn = retrieval_fn
    loss_fn_return = loss_fn(doc_network, query_embedding, pos_doc_embedding,
                             cache)
    updated_item_data = loss_fn_return.updated_item_data
    updated_item_indices = loss_fn_return.updated_item_indices
    updated_item_mask = loss_fn_return.updated_item_mask
    self.assertAllClose([[-1.0], [1.0]],
                        updated_item_data['cache']['embeddings'])
    self.assertAllEqual([1, 0], updated_item_indices['cache'])
    self.assertAllEqual([True, True], updated_item_mask['cache'])
Code example #7
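Checks `staleness`, which per the expected values is the squared difference between the cached embedding and the freshly computed one, normalized by the squared norm of the fresh embedding (e.g. (1 - 2)^2 / 2^2 = 1/4), and mean-reduced by default.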
  def test_cache_classification_loss_staleness(self):
    cached_embeddings = tf.convert_to_tensor([[1.0], [-1.0]])
    cached_data = tf.convert_to_tensor([[2.0], [-3.0]])
    cache = {
        'cache':
            negative_cache.NegativeCache(
                data={
                    'embeddings': cached_embeddings,
                    'data': cached_data
                },
                age=[0, 0])
    }
    doc_network = lambda data: data['data']
    query_embedding = tf.convert_to_tensor([[0.0], [0.0], [0.0]])
    pos_doc_embedding = tf.convert_to_tensor([[0.0], [0.0], [0.0]])
    # pylint: disable=g-long-lambda
    retrieval_fn = lambda scores: tf.convert_to_tensor([[0], [1], [0]],
                                                       dtype=tf.int64)
    # pylint: enable=g-long-lambda
    loss_fn = losses.CacheClassificationLoss('embeddings', ['data'])
    loss_fn._retrieval_fn = retrieval_fn
    staleness = loss_fn(doc_network, query_embedding, pos_doc_embedding,
                        cache).staleness
    staleness_expected = 0.31481481481
    self.assertAllClose(staleness_expected, staleness)

    loss_fn.reducer = None
    staleness = loss_fn(doc_network, query_embedding, pos_doc_embedding,
                        cache).staleness
    staleness_expected = [1.0 / 4.0, 4.0 / 9.0, 1.0 / 4.0]
    self.assertAllClose(staleness_expected, staleness)
Code example #8
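Confirms that gradients of `training_loss` flow through both the query-side and document-side variables under `tf.GradientTape`.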
  def test_cache_classification_loss_training_loss_gradient(self):
    cached_embeddings = tf.convert_to_tensor([[1.0], [-1.0]])
    cached_data = tf.convert_to_tensor([[1.0], [-1.0]])
    cache = {
        'cache':
            negative_cache.NegativeCache(
                data={
                    'embeddings': cached_embeddings,
                    'data': cached_data
                },
                age=[0, 0])
    }
    query_model = tf.Variable(1.0)
    doc_model = tf.Variable(1.0)
    doc_network = lambda data: data['data'] * doc_model

    query_embedding = tf.convert_to_tensor([[-1.0], [1.0], [3.0]])
    pos_doc_embedding = tf.convert_to_tensor([[2.0], [2.0], [1.0]])
    # pylint: disable=g-long-lambda
    retrieval_fn = lambda scores: tf.convert_to_tensor([[0], [1], [0]],
                                                       dtype=tf.int64)
    # pylint: enable=g-long-lambda
    loss_fn = losses.CacheClassificationLoss('embeddings', ['data'])
    loss_fn._retrieval_fn = retrieval_fn

    with tf.GradientTape() as tape:
      query_embedding = query_model * query_embedding
      pos_doc_embedding = doc_model * pos_doc_embedding
      training_loss = loss_fn(doc_network, query_embedding, pos_doc_embedding,
                              cache).training_loss
    gradient = tape.gradient(training_loss, [query_model, doc_model])
    gradient_expected = [0.024715006, 0.024715006]
    self.assertAllClose(gradient_expected, gradient)
Code example #9
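With multiple caches, retrieval indices address the concatenation of all caches: indices 0-2 fall in `cache1` and index 3 maps to slot 0 of `cache2`. The update mask, data, and indices in the return value are reported per cache.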
  def test_cache_classification_loss_refreshed_embeddings(self):
    cached_embeddings_1 = tf.convert_to_tensor([[1.0], [2.0], [3.0]])
    cached_embeddings_2 = tf.convert_to_tensor([[-1.0], [-2.0]])
    cached_data_1 = tf.convert_to_tensor([[10.0], [20.0], [30.0]])
    cached_data_2 = tf.convert_to_tensor([[-10.0], [-20.0]])
    cache = {
        'cache1':
            negative_cache.NegativeCache(
                data={
                    'embeddings': cached_embeddings_1,
                    'data': cached_data_1
                },
                age=[0, 0, 0]),
        'cache2':
            negative_cache.NegativeCache(
                data={
                    'embeddings': cached_embeddings_2,
                    'data': cached_data_2
                },
                age=[0, 0]),
    }
    doc_network = lambda data: data['data']
    query_embedding = tf.convert_to_tensor([[0.0], [0.0], [0.0]])
    pos_doc_embedding = tf.convert_to_tensor([[2.0], [2.0], [1.0]])
    # pylint: disable=g-long-lambda
    retrieval_fn = lambda scores: tf.convert_to_tensor([[0], [1], [3]],
                                                       dtype=tf.int64)
    # pylint: enable=g-long-lambda
    loss_fn = losses.CacheClassificationLoss('embeddings', ['data'])
    loss_fn._retrieval_fn = retrieval_fn
    cache_loss_return = loss_fn(doc_network, query_embedding,
                                pos_doc_embedding, cache)
    self.assertAllEqual(cache_loss_return.updated_item_mask['cache1'],
                        [True, True, False])
    self.assertAllEqual(cache_loss_return.updated_item_mask['cache2'],
                        [False, False, True])
    self.assertAllEqual(
        cache_loss_return.updated_item_data['cache1']['embeddings'][0:2],
        [[10.0], [20.0]])
    self.assertAllEqual(
        cache_loss_return.updated_item_data['cache2']['embeddings'][2],
        [-10.0])
    self.assertAllEqual(cache_loss_return.updated_item_indices['cache1'][0:2],
                        [0, 1])
    self.assertEqual(cache_loss_return.updated_item_indices['cache2'][2], 0)
Code example #10
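A helper (not a test) that materializes each `CacheManager`'s initial cache as mutable state. `_make_cache_variable` is defined elsewhere in the file; a minimal sketch of what such a helper could look like (an assumption, not the library's actual code):

    def _make_cache_variable(initial_tensor):
      # Non-trainable: the cache is state, not a learned parameter, and is
      # mutated in place across training steps.
      return tf.Variable(initial_tensor, trainable=False)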
  def _initialize_cache(self):
    cache = {}
    for key in self.cache_managers:
      initial_cache = self.cache_managers[key].init_cache()
      initial_cache_age = _make_cache_variable(initial_cache.age)
      initial_cache_data = {
          k: _make_cache_variable(initial_cache.data[k])
          for k in initial_cache.data
      }
      cache[key] = negative_cache.NegativeCache(
          data=initial_cache_data, age=initial_cache_age)
    return cache
Code example #11
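`IsInCacheFilterFn` restricted to `keys=('1',)` checks membership only under key `'1'`: new items whose `'1'` feature already appears in the cache are masked `False`, and key `'2'` is ignored.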
  def test_is_in_cache_filter_fn_with_missing_keys(self):

    data = {
        '1':
            tf.convert_to_tensor([[0, 0], [1, 1], [2, 2], [3, 3], [20, 20],
                                  [30, 30], [40, 40]]),
        '2':
            tf.convert_to_tensor([[4, 4], [5, 5], [6, 6], [7, 7], [50, 50],
                                  [60, 60], [70, 70]])
    }
    age = tf.zeros([7])
    cache = negative_cache.NegativeCache(data, age)
    new_items = {
        '1': tf.convert_to_tensor([[0, 0], [1, 1], [3, 3], [8, 8], [9, 9]]),
        '2': tf.convert_to_tensor([[4, 4], [7, 7], [7, 7], [10, 10], [11, 11]])
    }
    is_in_cache_filter_fn = filter_fns.IsInCacheFilterFn(keys=('1',))
    mask = is_in_cache_filter_fn(cache, new_items)
    mask_expected = tf.convert_to_tensor([False, False, False, True, True])
    self.assertAllEqual(mask_expected, mask)
Code example #12
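`new_items_mask` drops masked candidates before insertion; the surviving items ([5, 5] and [7, 7]) overwrite the cache entries with the highest ages (indices 3 and 1).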
  def test_new_items_with_mask(self):
    specs = {
        '1': tf.io.FixedLenFeature(shape=[2], dtype=tf.int32),
    }
    cache_manager = negative_cache.CacheManager(specs=specs, cache_size=4)
    cache = negative_cache.NegativeCache(
        data={
            '1':
                tf.convert_to_tensor([[1, 1], [2, 2], [3, 3], [4, 4]],
                                     dtype=tf.int32)
        },
        age=tf.convert_to_tensor([0, 2, 1, 3], dtype=tf.int32))

    new_items = {
        '1': tf.convert_to_tensor([[5, 5], [6, 6], [7, 7]], dtype=tf.int32)
    }
    new_items_mask = tf.convert_to_tensor([True, False, True])
    cache = cache_manager.update_cache(
        cache, new_items=new_items, new_items_mask=new_items_mask)
    self.assertAllEqual(
        tf.convert_to_tensor([[1, 1], [7, 7], [3, 3], [5, 5]], dtype=tf.int32),
        cache.data['1'])
Code example #13
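The unreduced `interpretable_loss` when retrieval is restricted to the `top_k=2` highest-scoring cache items per query.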
  def test_cache_classification_loss_interpretable_loss_with_top_k(self):
    cached_embeddings = tf.convert_to_tensor([[1.0], [-1.0], [3.0], [2.0]])
    cached_data = tf.convert_to_tensor([[1.0], [-1.0], [3.0], [2.0]])
    cache = {
        'cache':
            negative_cache.NegativeCache(
                data={
                    'embeddings': cached_embeddings,
                    'data': cached_data
                },
                age=[0, 0, 0, 0])
    }
    doc_network = lambda data: data['data']
    query_embedding = tf.convert_to_tensor([[-1.0], [1.0]])
    pos_doc_embedding = tf.convert_to_tensor([[2.0], [2.0]])

    loss_fn = losses.CacheClassificationLoss(
        'embeddings', ['data'], reducer=None, top_k=2)
    interpretable_loss = loss_fn(doc_network, query_embedding,
                                 pos_doc_embedding, cache).interpretable_loss
    interpretable_loss_expected = [3.0949229, 1.407605964]
    self.assertAllClose(interpretable_loss_expected, interpretable_loss)
Code example #14
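The distributed variant: `DistributedCacheClassificationLoss` run under `self.strategy`, a two-replica distribution strategy set up elsewhere in the test class; `make_distributed_tensor` and `tpuAssertAllClose` are helpers assumed defined elsewhere. Checks the per-replica `interpretable_loss` with the default reducer and with `reducer=None`.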
    def testInterpretableLoss(self):
        cache_data_1 = {
            'data': tf.convert_to_tensor([[1.0], [2.0]]),
            'embedding': tf.convert_to_tensor([[1.0], [2.0]])
        }
        cache_data_2 = {
            'data': tf.convert_to_tensor([[-1.0], [-2.0]]),
            'embedding': tf.convert_to_tensor([[-1.0], [-2.0]])
        }
        cache_data_multi_replica = {}
        for key in cache_data_1:
            cache_data_multi_replica[key] = make_distributed_tensor(
                self.strategy, [cache_data_1[key], cache_data_2[key]])
        cache_age = tf.zeros([2], dtype=tf.int32)
        cache_age_multi_replica = make_distributed_tensor(
            self.strategy, [cache_age, cache_age])
        cache = negative_cache.NegativeCache(data=cache_data_multi_replica,
                                             age=cache_age_multi_replica)
        query_embeddings_1 = tf.convert_to_tensor([[-1.0], [1.0]])
        query_embeddings_2 = tf.convert_to_tensor([[1.0], [1.0]])
        query_embeddings_multi_replica = make_distributed_tensor(
            self.strategy, [query_embeddings_1, query_embeddings_2])
        pos_doc_embeddings_1 = tf.convert_to_tensor([[2.0], [2.0]])
        pos_doc_embeddings_2 = tf.convert_to_tensor([[2.0], [2.0]])
        pos_doc_embeddings_multi_replica = make_distributed_tensor(
            self.strategy, [pos_doc_embeddings_1, pos_doc_embeddings_2])

        embedding_key = 'embedding'
        data_keys = ('data',)
        loss_obj = losses.DistributedCacheClassificationLoss(
            embedding_key=embedding_key, data_keys=data_keys)
        doc_network = lambda data: data['data']

        @tf.function
        def loss_fn_reduced(query_embedding, pos_doc_embedding, cache):
            return loss_obj(doc_network, query_embedding, pos_doc_embedding,
                            cache)

        output_reduced = self.strategy.run(
            loss_fn_reduced,
            args=(query_embeddings_multi_replica,
                  pos_doc_embeddings_multi_replica, {
                      'cache': cache
                  }))

        interpretable_loss_reduced = output_reduced.interpretable_loss.values
        interpretable_loss_reduced_expected = [(4.37452 + 0.890350) / 2.0,
                                               0.890350]
        self.tpuAssertAllClose(interpretable_loss_reduced_expected,
                               interpretable_loss_reduced)

        loss_obj.reducer = None

        @tf.function
        def loss_fn_no_reduce(query_embedding, pos_doc_embedding, cache):
            return loss_obj(doc_network, query_embedding, pos_doc_embedding,
                            cache)

        output_no_reduce = self.strategy.run(
            loss_fn_no_reduce,
            args=(query_embeddings_multi_replica,
                  pos_doc_embeddings_multi_replica, {
                      'cache': cache
                  }))
        interpretable_loss_no_reduce = output_no_reduce.interpretable_loss.values
        interpretable_loss_no_reduce_expected = [(4.37452, 0.890350),
                                                 (0.890350, 0.890350)]
        self.tpuAssertAllClose(interpretable_loss_no_reduce_expected,
                               interpretable_loss_no_reduce)
Code example #15
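The same distributed setup, checking `training_loss`. The mock retrieval function branches on the number of score rows, apparently to distinguish the cross-replica case (4 queries gathered from both replicas) from a single replica's batch (2 queries).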
    def testTrainingLoss(self):
        cache_data_1 = {
            'data': tf.convert_to_tensor([[1.0], [2.0]]),
            'embedding': tf.convert_to_tensor([[1.0], [2.0]])
        }
        cache_data_2 = {
            'data': tf.convert_to_tensor([[-1.0], [-2.0]]),
            'embedding': tf.convert_to_tensor([[-1.0], [-2.0]])
        }
        cache_data_multi_replica = {}
        for key in cache_data_1:
            cache_data_multi_replica[key] = make_distributed_tensor(
                self.strategy, [cache_data_1[key], cache_data_2[key]])
        cache_age = tf.zeros([2], dtype=tf.int32)
        cache_age_multi_replica = make_distributed_tensor(
            self.strategy, [cache_age, cache_age])
        cache = negative_cache.NegativeCache(data=cache_data_multi_replica,
                                             age=cache_age_multi_replica)
        query_embeddings_1 = tf.convert_to_tensor([[-1.0], [1.0]])
        query_embeddings_2 = tf.convert_to_tensor([[1.0], [1.0]])
        query_embeddings_multi_replica = make_distributed_tensor(
            self.strategy, [query_embeddings_1, query_embeddings_2])
        pos_doc_embeddings_1 = tf.convert_to_tensor([[2.0], [2.0]])
        pos_doc_embeddings_2 = tf.convert_to_tensor([[2.0], [2.0]])
        pos_doc_embeddings_multi_replica = make_distributed_tensor(
            self.strategy, [pos_doc_embeddings_1, pos_doc_embeddings_2])

        embedding_key = 'embedding'
        data_keys = ('data',)
        loss_obj = losses.DistributedCacheClassificationLoss(
            embedding_key=embedding_key, data_keys=data_keys)

        def mock_retrieval_fn(scores):
            if scores.shape[0] == 4:
                return tf.convert_to_tensor([[0], [1], [0], [1]],
                                            dtype=tf.int64)
            else:
                return tf.convert_to_tensor([[0], [1]], dtype=tf.int64)

        loss_obj._retrieval_fn = mock_retrieval_fn
        doc_network = lambda data: data['data']

        @tf.function
        def loss_fn_reduced(query_embedding, pos_doc_embedding, cache):
            return loss_obj(doc_network, query_embedding, pos_doc_embedding,
                            cache)

        output_reduced = self.strategy.run(
            loss_fn_reduced,
            args=(query_embeddings_multi_replica,
                  pos_doc_embeddings_multi_replica, {
                      'cache': cache
                  }))

        training_loss_reduced = output_reduced.training_loss.values
        training_loss_reduced_expected = [(0.9874058 + -2.35795) / 2.0,
                                          (-0.589488 + -2.35795) / 2.0]
        self.tpuAssertAllClose(training_loss_reduced_expected,
                               training_loss_reduced)

        loss_obj.reducer = None

        @tf.function
        def loss_fn_no_reduce(query_embedding, pos_doc_embedding, cache):
            return loss_obj(doc_network, query_embedding, pos_doc_embedding,
                            cache)

        output_no_reduce = self.strategy.run(
            loss_fn_no_reduce,
            args=(query_embeddings_multi_replica,
                  pos_doc_embeddings_multi_replica, {
                      'cache': cache
                  }))
        training_loss_no_reduce = output_no_reduce.training_loss.values
        training_loss_no_reduce_expected = [(0.9874058, -2.35795),
                                            (-0.589488, -2.35795)]
        self.tpuAssertAllClose(training_loss_no_reduce_expected,
                               training_loss_no_reduce)