def test_masked_update_cache_with_existing_items_when_all_items_masked(self):
  specs = {
      '1': tf.io.FixedLenFeature(shape=[2], dtype=tf.int32),
  }
  cache_manager = negative_cache.CacheManager(specs=specs, cache_size=4)
  cache = negative_cache.NegativeCache(
      data={
          '1':
              tf.convert_to_tensor([[5, 5], [10, 10], [15, 15], [20, 20]],
                                   dtype=tf.int32)
      },
      age=tf.convert_to_tensor([2, 2, 2, 2], dtype=tf.int32))
  updated_item_indices = tf.convert_to_tensor([1, 3], dtype=tf.int32)
  updated_item_data = {
      '1': tf.convert_to_tensor([[1, 1], [2, 2]], dtype=tf.int32),
  }
  updated_item_mask = tf.convert_to_tensor([False, False])
  cache = cache_manager.update_cache(
      cache,
      updated_item_data=updated_item_data,
      updated_item_indices=updated_item_indices,
      updated_item_mask=updated_item_mask)
  self.assertEqual({'1'}, set(cache.data.keys()))
  self.assertAllEqual(
      tf.convert_to_tensor([[5, 5], [10, 10], [15, 15], [20, 20]],
                           dtype=tf.float32), cache.data['1'])
  self.assertAllEqual(
      tf.convert_to_tensor([3, 3, 3, 3], dtype=tf.int32), cache.age)

def test_update_cache_without_lru(self):
  specs = {
      '1': tf.io.FixedLenFeature(shape=[2], dtype=tf.int32),
  }
  cache_manager = negative_cache.CacheManager(
      specs=specs, cache_size=4, use_lru=False)
  cache = negative_cache.NegativeCache(
      data={
          '1':
              tf.convert_to_tensor([[5, 5], [10, 10], [15, 15], [20, 20]],
                                   dtype=tf.int32)
      },
      age=tf.convert_to_tensor([1, 0, 1, 1], dtype=tf.int32))
  updated_item_indices = tf.convert_to_tensor([1, 3], dtype=tf.int32)
  updated_item_data = {
      '1': tf.convert_to_tensor([[1, 1], [2, 2]], dtype=tf.int32),
  }
  cache = cache_manager.update_cache(
      cache,
      updated_item_indices=updated_item_indices,
      updated_item_data=updated_item_data)
  cache_data_expected = [[5, 5], [1, 1], [15, 15], [2, 2]]
  cache_age_expected = [2, 1, 2, 2]
  self.assertAllEqual(cache_data_expected, cache.data['1'])
  self.assertAllEqual(cache_age_expected, cache.age)

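# Note on the expected ages in the test above (inferred from the expected
# values, not from separate documentation): with use_lru=False the updated
# rows are not treated as freshly used -- every age, including those of the
# rows that were just overwritten, simply advances by one. Contrast this with
# the LRU tests in this file, where updated or newly inserted rows end up with
# age 0.
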
def test_cache_classification_loss_training_loss(self):
  cached_embeddings = tf.convert_to_tensor([[1.0], [-1.0]])
  cached_data = tf.convert_to_tensor([[1.0], [-1.0]])
  cache = {
      'cache':
          negative_cache.NegativeCache(
              data={
                  'embeddings': cached_embeddings,
                  'data': cached_data
              },
              age=[0, 0])
  }
  doc_network = lambda data: data['data']
  query_embedding = tf.convert_to_tensor([[-1.0], [1.0], [3.0]])
  pos_doc_embedding = tf.convert_to_tensor([[2.0], [2.0], [1.0]])
  # pylint: disable=g-long-lambda
  retrieval_fn = lambda scores: tf.convert_to_tensor([[0], [1], [0]],
                                                     dtype=tf.int64)
  # pylint: enable=g-long-lambda
  loss_fn = losses.CacheClassificationLoss(
      'embeddings', ['data'], reducer=None)
  loss_fn._retrieval_fn = retrieval_fn
  training_loss = loss_fn(doc_network, query_embedding, pos_doc_embedding,
                          cache).training_loss
  prob_pos = tf.convert_to_tensor([0.0420101, 0.705385, 0.499381])
  score_differences = tf.convert_to_tensor([1.0, -3.0, 0.0])
  training_loss_expected = (1.0 - prob_pos) * score_differences
  self.assertAllClose(training_loss_expected, training_loss)
  loss_fn.reducer = tf.math.reduce_mean
  training_loss = loss_fn(doc_network, query_embedding, pos_doc_embedding,
                          cache).training_loss
  training_loss_expected = tf.reduce_mean(
      (1.0 - prob_pos) * score_differences)
  self.assertAllClose(training_loss_expected, training_loss)

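# Where the constants in the test above appear to come from (a hand derivation
# under the assumption that the loss softmaxes the positive score against the
# cached scores): for each query q_i,
#   prob_pos_i = softmax([q_i * pos_i, q_i * 1.0, q_i * (-1.0)])[0],
# e.g. softmax([-2, -1, 1])[0] ~= 0.0420101 for the first query. The
# score_differences are retrieved_score - positive_score for the mocked
# retrieval indices [0, 1, 0], e.g. (-1) - (-2) = 1.0 for the first query.
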
def test_update_cache_with_new_items_and_existing_items(self):
  specs = {
      '1': tf.io.FixedLenFeature(shape=[1], dtype=tf.int32),
      '2': tf.io.FixedLenFeature(shape=[1], dtype=tf.int32),
  }
  cache_manager = negative_cache.CacheManager(specs=specs, cache_size=2)
  data = {
      '1': tf.convert_to_tensor([[0], [0], [3]], dtype=tf.int32),
      '2': tf.convert_to_tensor([[1], [2], [4]], dtype=tf.int32)
  }
  age = tf.convert_to_tensor([2, 1, 0])
  cache = negative_cache.NegativeCache(data=data, age=age)
  updated_item_indices = tf.convert_to_tensor([0], dtype=tf.int32)
  updated_item_data = {
      '1': tf.convert_to_tensor([[10]], dtype=tf.int32),
  }
  new_items = {
      '1': tf.convert_to_tensor([[11]], dtype=tf.int32),
      '2': tf.convert_to_tensor([[12]], dtype=tf.int32)
  }
  cache = cache_manager.update_cache(
      cache,
      new_items=new_items,
      updated_item_data=updated_item_data,
      updated_item_indices=updated_item_indices)
  self.assertEqual({'1', '2'}, set(cache.data.keys()))
  self.assertAllEqual(
      tf.convert_to_tensor([[10], [11], [3]], dtype=tf.int32),
      cache.data['1'])
  self.assertAllEqual(
      tf.convert_to_tensor([[1], [12], [4]], dtype=tf.int32), cache.data['2'])
  self.assertAllEqual(
      tf.convert_to_tensor([0, 0, 1], dtype=tf.int32), cache.age)

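# A reading of the expected values above (inferred from the assertions): the
# explicitly updated row (index 0) and the newly inserted row both end up with
# age 0, the new item overwrites the oldest row that was not just updated
# (index 1, age 1), and the untouched row (index 2) ages by one.
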
def test_cache_classification_loss_interpretable_loss(self):
  cached_embeddings = tf.convert_to_tensor([[1.0], [-1.0]])
  cached_data = tf.convert_to_tensor([[1.0], [-1.0]])
  cache = {
      'cache':
          negative_cache.NegativeCache(
              data={
                  'embeddings': cached_embeddings,
                  'data': cached_data
              },
              age=[0, 0])
  }
  doc_network = lambda data: data['data']
  query_embedding = tf.convert_to_tensor([[-1.0], [1.0], [3.0]])
  pos_doc_embedding = tf.convert_to_tensor([[2.0], [2.0], [1.0]])
  loss_fn = losses.CacheClassificationLoss(
      'embeddings', ['data'], reducer=None)
  interpretable_loss = loss_fn(doc_network, query_embedding,
                               pos_doc_embedding, cache).interpretable_loss
  interpretable_loss_expected = [3.169846, 0.349012, 0.694385]
  self.assertAllClose(interpretable_loss_expected, interpretable_loss)
  loss_fn.reducer = tf.math.reduce_mean
  interpretable_loss = loss_fn(doc_network, query_embedding,
                               pos_doc_embedding, cache).interpretable_loss
  interpretable_loss_expected = (3.169846 + 0.349012 + 0.694385) / 3.0
  self.assertAllClose(interpretable_loss_expected, interpretable_loss)

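# The expected interpretable losses above match the cross-entropy of the
# positive document under the same softmax as in the training-loss test, i.e.
# -log(prob_pos): -log(0.0420101) ~= 3.169846, -log(0.705385) ~= 0.349012,
# -log(0.499381) ~= 0.694385.
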
def test_cache_classification_loss_refreshed_embeddings_with_top_k(self):
  cached_embeddings = tf.convert_to_tensor([[1.0], [-1.0], [3.0], [2.0]])
  cached_data = tf.convert_to_tensor([[1.0], [-1.0], [3.0], [2.0]])
  cache = {
      'cache':
          negative_cache.NegativeCache(
              data={
                  'embeddings': cached_embeddings,
                  'data': cached_data
              },
              age=[0, 0])
  }
  doc_network = lambda data: data['data']
  query_embedding = tf.convert_to_tensor([[-1.0], [1.0]])
  pos_doc_embedding = tf.convert_to_tensor([[2.0], [2.0]])
  loss_fn = losses.CacheClassificationLoss('embeddings', ['data'], top_k=2)
  # pylint: disable=g-long-lambda
  retrieval_fn = lambda scores: tf.convert_to_tensor([[0], [0]],
                                                     dtype=tf.int64)
  # pylint: enable=g-long-lambda
  loss_fn._retrieval_fn = retrieval_fn
  loss_fn_return = loss_fn(doc_network, query_embedding, pos_doc_embedding,
                           cache)
  updated_item_data = loss_fn_return.updated_item_data
  updated_item_indices = loss_fn_return.updated_item_indices
  updated_item_mask = loss_fn_return.updated_item_mask
  self.assertAllClose([[-1.0], [1.0]],
                      updated_item_data['cache']['embeddings'])
  self.assertAllEqual([1, 0], updated_item_indices['cache'])
  self.assertAllEqual([True, True], updated_item_mask['cache'])

def test_cache_classification_loss_staleness(self):
  cached_embeddings = tf.convert_to_tensor([[1.0], [-1.0]])
  cached_data = tf.convert_to_tensor([[2.0], [-3.0]])
  cache = {
      'cache':
          negative_cache.NegativeCache(
              data={
                  'embeddings': cached_embeddings,
                  'data': cached_data
              },
              age=[0, 0])
  }
  doc_network = lambda data: data['data']
  query_embedding = tf.convert_to_tensor([[0.0], [0.0], [0.0]])
  pos_doc_embedding = tf.convert_to_tensor([[0.0], [0.0], [0.0]])
  # pylint: disable=g-long-lambda
  retrieval_fn = lambda scores: tf.convert_to_tensor([[0], [1], [0]],
                                                     dtype=tf.int64)
  # pylint: enable=g-long-lambda
  loss_fn = losses.CacheClassificationLoss('embeddings', ['data'])
  loss_fn._retrieval_fn = retrieval_fn
  staleness = loss_fn(doc_network, query_embedding, pos_doc_embedding,
                      cache).staleness
  staleness_expected = 0.31481481481
  self.assertAllClose(staleness_expected, staleness)
  loss_fn.reducer = None
  staleness = loss_fn(doc_network, query_embedding, pos_doc_embedding,
                      cache).staleness
  staleness_expected = [1.0 / 4.0, 4.0 / 9.0, 1.0 / 4.0]
  self.assertAllClose(staleness_expected, staleness)

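# The expected staleness values above match the relative squared error between
# the cached embedding and the embedding recomputed by doc_network on the
# cached data for each retrieved item: (1 - 2)^2 / 2^2 = 1/4 for the first
# query, (-1 - (-3))^2 / (-3)^2 = 4/9 for the second, and the reduced value is
# the batch mean, (1/4 + 4/9 + 1/4) / 3 ~= 0.31481.
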
def test_cache_classification_loss_training_loss_gradient(self):
  cached_embeddings = tf.convert_to_tensor([[1.0], [-1.0]])
  cached_data = tf.convert_to_tensor([[1.0], [-1.0]])
  cache = {
      'cache':
          negative_cache.NegativeCache(
              data={
                  'embeddings': cached_embeddings,
                  'data': cached_data
              },
              age=[0, 0])
  }
  query_model = tf.Variable(1.0)
  doc_model = tf.Variable(1.0)
  doc_network = lambda data: data['data'] * doc_model
  query_embedding = tf.convert_to_tensor([[-1.0], [1.0], [3.0]])
  pos_doc_embedding = tf.convert_to_tensor([[2.0], [2.0], [1.0]])
  # pylint: disable=g-long-lambda
  retrieval_fn = lambda scores: tf.convert_to_tensor([[0], [1], [0]],
                                                     dtype=tf.int64)
  # pylint: enable=g-long-lambda
  loss_fn = losses.CacheClassificationLoss('embeddings', ['data'])
  loss_fn._retrieval_fn = retrieval_fn
  with tf.GradientTape() as tape:
    query_embedding = query_model * query_embedding
    pos_doc_embedding = doc_model * pos_doc_embedding
    training_loss = loss_fn(doc_network, query_embedding, pos_doc_embedding,
                            cache).training_loss
  gradient = tape.gradient(training_loss, [query_model, doc_model])
  gradient_expected = [0.024715006, 0.024715006]
  self.assertAllClose(gradient_expected, gradient)

def test_cache_classification_loss_refreshed_embeddings(self):
  cached_embeddings_1 = tf.convert_to_tensor([[1.0], [2.0], [3.0]])
  cached_embeddings_2 = tf.convert_to_tensor([[-1.0], [-2.0]])
  cached_data_1 = tf.convert_to_tensor([[10.0], [20.0], [30.0]])
  cached_data_2 = tf.convert_to_tensor([[-10.0], [-20.0]])
  cache = {
      'cache1':
          negative_cache.NegativeCache(
              data={
                  'embeddings': cached_embeddings_1,
                  'data': cached_data_1
              },
              age=[0, 0, 0]),
      'cache2':
          negative_cache.NegativeCache(
              data={
                  'embeddings': cached_embeddings_2,
                  'data': cached_data_2
              },
              age=[0, 0]),
  }
  doc_network = lambda data: data['data']
  query_embedding = tf.convert_to_tensor([[0.0], [0.0], [0.0]])
  pos_doc_embedding = tf.convert_to_tensor([[2.0], [2.0], [1.0]])
  # pylint: disable=g-long-lambda
  retrieval_fn = lambda scores: tf.convert_to_tensor([[0], [1], [3]],
                                                     dtype=tf.int64)
  # pylint: enable=g-long-lambda
  loss_fn = losses.CacheClassificationLoss('embeddings', ['data'])
  loss_fn._retrieval_fn = retrieval_fn
  cache_loss_return = loss_fn(doc_network, query_embedding, pos_doc_embedding,
                              cache)
  self.assertAllEqual(cache_loss_return.updated_item_mask['cache1'],
                      [True, True, False])
  self.assertAllEqual(cache_loss_return.updated_item_mask['cache2'],
                      [False, False, True])
  self.assertAllEqual(
      cache_loss_return.updated_item_data['cache1']['embeddings'][0:2],
      [[10.0], [20.0]])
  self.assertAllEqual(
      cache_loss_return.updated_item_data['cache2']['embeddings'][2], [-10.0])
  self.assertAllEqual(cache_loss_return.updated_item_indices['cache1'][0:2],
                      [0, 1])
  self.assertEqual(cache_loss_return.updated_item_indices['cache2'][2], 0)

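# In the multi-cache test above, the mocked retrieval indices appear to be
# global over the concatenated caches: indices 0..2 address 'cache1' and 3..4
# address 'cache2', so index 3 for the third query lands on element 0 of
# 'cache2'. The refreshed embeddings are doc_network applied to each cache's
# stored data ([[10.], [20.]] and [-10.] respectively), and the per-cache
# masks mark which rows of the per-query update belong to that cache.
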
def _initialize_cache(self):
  """Builds a tf.Variable-backed cache for every configured cache manager."""
  cache = {}
  for key in self.cache_managers:
    initial_cache = self.cache_managers[key].init_cache()
    initial_cache_age = _make_cache_variable(initial_cache.age)
    initial_cache_data = {
        k: _make_cache_variable(initial_cache.data[k])
        for k in initial_cache.data
    }
    cache[key] = negative_cache.NegativeCache(
        data=initial_cache_data, age=initial_cache_age)
  return cache

def test_is_in_cache_filter_fn_with_missing_keys(self):
  data = {
      '1':
          tf.convert_to_tensor([[0, 0], [1, 1], [2, 2], [3, 3], [20, 20],
                                [30, 30], [40, 40]]),
      '2':
          tf.convert_to_tensor([[4, 4], [5, 5], [6, 6], [7, 7], [50, 50],
                                [60, 60], [70, 70]])
  }
  age = tf.zeros([4])
  cache = negative_cache.NegativeCache(data, age)
  new_items = {
      '1': tf.convert_to_tensor([[0, 0], [1, 1], [3, 3], [8, 8], [9, 9]]),
      '2': tf.convert_to_tensor([[4, 4], [7, 7], [7, 7], [10, 10], [11, 11]])
  }
  is_in_cache_filter_fn = filter_fns.IsInCacheFilterFn(keys=('1',))
  mask = is_in_cache_filter_fn(cache, new_items)
  mask_expected = tf.convert_to_tensor([False, False, False, True, True])
  self.assertAllEqual(mask_expected, mask)

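# Because the filter above is constructed with keys=('1',), membership in the
# cache is decided on feature '1' alone and feature '2' is ignored: the first
# three new items match cached rows of feature '1' and receive mask value
# False, while the two unseen items receive True.
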
def test_new_items_with_mask(self):
  specs = {
      '1': tf.io.FixedLenFeature(shape=[2], dtype=tf.int32),
  }
  cache_manager = negative_cache.CacheManager(specs=specs, cache_size=4)
  cache = negative_cache.NegativeCache(
      data={
          '1':
              tf.convert_to_tensor([[1, 1], [2, 2], [3, 3], [4, 4]],
                                   dtype=tf.int32)
      },
      age=tf.convert_to_tensor([0, 2, 1, 3], dtype=tf.int32))
  new_items = {
      '1': tf.convert_to_tensor([[5, 5], [6, 6], [7, 7]], dtype=tf.int32)
  }
  new_items_mask = tf.convert_to_tensor([True, False, True])
  cache = cache_manager.update_cache(
      cache, new_items=new_items, new_items_mask=new_items_mask)
  self.assertAllEqual(
      tf.convert_to_tensor([[1, 1], [7, 7], [3, 3], [5, 5]], dtype=tf.int32),
      cache.data['1'])

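# A reading of the expected cache above: the masked-out new item [6, 6] is
# dropped, and the two surviving new items overwrite the rows with the largest
# ages ([4, 4] at age 3 and [2, 2] at age 2), leaving the younger rows [1, 1]
# and [3, 3] in place.
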
def test_cache_classification_loss_interpretable_loss_with_top_k(self):
  cached_embeddings = tf.convert_to_tensor([[1.0], [-1.0], [3.0], [2.0]])
  cached_data = tf.convert_to_tensor([[1.0], [-1.0], [3.0], [2.0]])
  cache = {
      'cache':
          negative_cache.NegativeCache(
              data={
                  'embeddings': cached_embeddings,
                  'data': cached_data
              },
              age=[0, 0])
  }
  doc_network = lambda data: data['data']
  query_embedding = tf.convert_to_tensor([[-1.0], [1.0]])
  pos_doc_embedding = tf.convert_to_tensor([[2.0], [2.0]])
  loss_fn = losses.CacheClassificationLoss(
      'embeddings', ['data'], reducer=None, top_k=2)
  interpretable_loss = loss_fn(doc_network, query_embedding,
                               pos_doc_embedding, cache).interpretable_loss
  interpretable_loss_expected = [3.0949229, 1.407605964]
  self.assertAllClose(interpretable_loss_expected, interpretable_loss)

def testInterpretableLoss(self):
  cache_data_1 = {
      'data': tf.convert_to_tensor([[1.0], [2.0]]),
      'embedding': tf.convert_to_tensor([[1.0], [2.0]])
  }
  cache_data_2 = {
      'data': tf.convert_to_tensor([[-1.0], [-2.0]]),
      'embedding': tf.convert_to_tensor([[-1.0], [-2.0]])
  }
  cache_data_multi_replica = {}
  for key in cache_data_1:
    cache_data_multi_replica[key] = make_distributed_tensor(
        self.strategy, [cache_data_1[key], cache_data_2[key]])
  cache_age = tf.zeros([0], dtype=tf.int32)
  cache_age_multi_replica = make_distributed_tensor(self.strategy,
                                                    [cache_age, cache_age])
  cache = negative_cache.NegativeCache(
      data=cache_data_multi_replica, age=cache_age_multi_replica)
  query_embeddings_1 = tf.convert_to_tensor([[-1.0], [1.0]])
  query_embeddings_2 = tf.convert_to_tensor([[1.0], [1.0]])
  query_embeddings_multi_replica = make_distributed_tensor(
      self.strategy, [query_embeddings_1, query_embeddings_2])
  pos_doc_embeddings_1 = tf.convert_to_tensor([[2.0], [2.0]])
  pos_doc_embeddings_2 = tf.convert_to_tensor([[2.0], [2.0]])
  pos_doc_embeddings_multi_replica = make_distributed_tensor(
      self.strategy, [pos_doc_embeddings_1, pos_doc_embeddings_2])
  embedding_key = 'embedding'
  data_keys = ('data',)
  loss_obj = losses.DistributedCacheClassificationLoss(
      embedding_key=embedding_key, data_keys=data_keys)
  doc_network = lambda data: data['data']

  @tf.function
  def loss_fn_reduced(query_embedding, pos_doc_embedding, cache):
    return loss_obj(doc_network, query_embedding, pos_doc_embedding, cache)

  output_reduced = self.strategy.run(
      loss_fn_reduced,
      args=(query_embeddings_multi_replica, pos_doc_embeddings_multi_replica,
            {'cache': cache}))
  interpretable_loss_reduced = output_reduced.interpretable_loss.values
  interpretable_loss_reduced_expected = [(4.37452 + 0.890350) / 2.0, 0.890350]
  self.tpuAssertAllClose(interpretable_loss_reduced_expected,
                         interpretable_loss_reduced)
  loss_obj.reducer = None

  @tf.function
  def loss_fn_no_reduce(query_embedding, pos_doc_embedding, cache):
    return loss_obj(doc_network, query_embedding, pos_doc_embedding, cache)

  output_no_reduce = self.strategy.run(
      loss_fn_no_reduce,
      args=(query_embeddings_multi_replica, pos_doc_embeddings_multi_replica,
            {'cache': cache}))
  interpretable_loss_no_reduce = output_no_reduce.interpretable_loss.values
  interpretable_loss_no_reduce_expected = [(4.37452, 0.890350),
                                           (0.890350, 0.890350)]
  self.tpuAssertAllClose(interpretable_loss_no_reduce_expected,
                         interpretable_loss_no_reduce)

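# The expected values above are consistent with each replica scoring its
# queries against the union of the caches on all replicas: for the first query
# on the first replica (embedding -1.0, positive score -2.0) the cache scores
# are [-1, -2, 1, 2], and -log(softmax([-2, -1, -2, 1, 2])[0]) ~= 4.37452; the
# remaining queries all give ~0.890350.
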
def testTrainingLoss(self):
  cache_data_1 = {
      'data': tf.convert_to_tensor([[1.0], [2.0]]),
      'embedding': tf.convert_to_tensor([[1.0], [2.0]])
  }
  cache_data_2 = {
      'data': tf.convert_to_tensor([[-1.0], [-2.0]]),
      'embedding': tf.convert_to_tensor([[-1.0], [-2.0]])
  }
  cache_data_multi_replica = {}
  for key in cache_data_1:
    cache_data_multi_replica[key] = make_distributed_tensor(
        self.strategy, [cache_data_1[key], cache_data_2[key]])
  cache_age = tf.zeros([0], dtype=tf.int32)
  cache_age_multi_replica = make_distributed_tensor(self.strategy,
                                                    [cache_age, cache_age])
  cache = negative_cache.NegativeCache(
      data=cache_data_multi_replica, age=cache_age_multi_replica)
  query_embeddings_1 = tf.convert_to_tensor([[-1.0], [1.0]])
  query_embeddings_2 = tf.convert_to_tensor([[1.0], [1.0]])
  query_embeddings_multi_replica = make_distributed_tensor(
      self.strategy, [query_embeddings_1, query_embeddings_2])
  pos_doc_embeddings_1 = tf.convert_to_tensor([[2.0], [2.0]])
  pos_doc_embeddings_2 = tf.convert_to_tensor([[2.0], [2.0]])
  pos_doc_embeddings_multi_replica = make_distributed_tensor(
      self.strategy, [pos_doc_embeddings_1, pos_doc_embeddings_2])
  embedding_key = 'embedding'
  data_keys = ('data',)
  loss_obj = losses.DistributedCacheClassificationLoss(
      embedding_key=embedding_key, data_keys=data_keys)

  def mock_retrieval_fn(scores):
    if scores.shape[0] == 4:
      return tf.convert_to_tensor([[0], [1], [0], [1]], dtype=tf.int64)
    else:
      return tf.convert_to_tensor([[0], [1]], dtype=tf.int64)

  loss_obj._retrieval_fn = mock_retrieval_fn
  doc_network = lambda data: data['data']

  @tf.function
  def loss_fn_reduced(query_embedding, pos_doc_embedding, cache):
    return loss_obj(doc_network, query_embedding, pos_doc_embedding, cache)

  output_reduced = self.strategy.run(
      loss_fn_reduced,
      args=(query_embeddings_multi_replica, pos_doc_embeddings_multi_replica,
            {'cache': cache}))
  training_loss_reduced = output_reduced.training_loss.values
  training_loss_reduced_expected = [(0.9874058 + -2.35795) / 2.0,
                                    (-0.589488 + -2.35795) / 2.0]
  self.tpuAssertAllClose(training_loss_reduced_expected,
                         training_loss_reduced)
  loss_obj.reducer = None

  @tf.function
  def loss_fn_no_reduce(query_embedding, pos_doc_embedding, cache):
    return loss_obj(doc_network, query_embedding, pos_doc_embedding, cache)

  output_no_reduce = self.strategy.run(
      loss_fn_no_reduce,
      args=(query_embeddings_multi_replica, pos_doc_embeddings_multi_replica,
            {'cache': cache}))
  training_loss_no_reduce = output_no_reduce.training_loss.values
  training_loss_no_reduce_expected = [(0.9874058, -2.35795),
                                      (-0.589488, -2.35795)]
  self.tpuAssertAllClose(training_loss_no_reduce_expected,
                         training_loss_no_reduce)