コード例 #1
0
 def test_update_cache_without_lru(self):
     """Without LRU, overwriting cache rows does not reset their age.

     Rows 1 and 3 receive new data, yet every row's age (theirs included) is
     expected to grow by exactly one.
     """
     feature_specs = {'1': tf.io.FixedLenFeature(shape=[2], dtype=tf.int32)}
     manager = negative_cache.CacheManager(specs=feature_specs,
                                           cache_size=4,
                                           use_lru=False)
     starting_rows = tf.convert_to_tensor(
         [[5, 5], [10, 10], [15, 15], [20, 20]], dtype=tf.int32)
     starting_age = tf.convert_to_tensor([1, 0, 1, 1], dtype=tf.int32)
     cache = negative_cache.NegativeCache(data={'1': starting_rows},
                                          age=starting_age)

     replacement_rows = tf.convert_to_tensor([[1, 1], [2, 2]], dtype=tf.int32)
     cache = manager.update_cache(
         cache,
         updated_item_indices=tf.convert_to_tensor([1, 3], dtype=tf.int32),
         updated_item_data={'1': replacement_rows})

     self.assertAllEqual([[5, 5], [1, 1], [15, 15], [2, 2]], cache.data['1'])
     # Rows 1 and 3 were overwritten but their ages still advance by one.
     self.assertAllEqual([2, 1, 2, 2], cache.age)
コード例 #2
0
 def test_masked_update_cache_with_existing_items_not_in_index_one(self):
     """Updates whose mask entry is False are not applied.

     Updates target rows 0 and 3 of a freshly initialised cache, but only the
     first one is unmasked. Row 0 takes the new data and its age is reset to 0.
     Row 3 keeps its initial zeros and ages like the untouched rows.
     """
     specs = {
         '1': tf.io.FixedLenFeature(shape=[2], dtype=tf.int32),
         '2': tf.io.FixedLenFeature(shape=[3], dtype=tf.float32),
     }
     cache_manager = negative_cache.CacheManager(specs=specs, cache_size=4)
     cache = cache_manager.init_cache()
     updated_item_indices = tf.convert_to_tensor([0, 3], dtype=tf.int32)
     updated_item_data = {
         '1':
         tf.convert_to_tensor([[1, 1], [2, 2]], dtype=tf.int32),
         '2':
         tf.convert_to_tensor([[3.0, 3.0, 3.0], [4.0, 4.0, 4.0]],
                              dtype=tf.float32),
     }
     # Only the update aimed at row 0 is applied; the one for row 3 is masked.
     updated_item_mask = tf.convert_to_tensor([True, False])
     cache = cache_manager.update_cache(
         cache,
         updated_item_data=updated_item_data,
         updated_item_indices=updated_item_indices,
         updated_item_mask=updated_item_mask)
     self.assertEqual({'1', '2'}, set(cache.data.keys()))
     # Feature '1' is declared int32 in `specs`, so the expectation is int32.
     self.assertAllEqual(
         tf.convert_to_tensor([[1, 1], [0, 0], [0, 0], [0, 0]],
                              dtype=tf.int32), cache.data['1'])
     self.assertAllEqual(
         tf.convert_to_tensor([[3.0, 3.0, 3.0], [0.0, 0.0, 0.0],
                               [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                              dtype=tf.float32), cache.data['2'])
     self.assertAllEqual(tf.convert_to_tensor([0, 1, 1, 1], dtype=tf.int32),
                         cache.age)
コード例 #3
0
 def test_update_cache_with_new_items_and_existing_items(self):
     """Existing-item updates and new-item insertion in a single call.

     Row 0 gets a partial update covering feature '1' only, so its feature '2'
     value is kept. One new item is also inserted. The expected values place
     it in row 1, which is consistent with it replacing the row with the
     largest age once row 0 has been refreshed.
     """
     specs = {
         '1': tf.io.FixedLenFeature(shape=[1], dtype=tf.int32),
         '2': tf.io.FixedLenFeature(shape=[1], dtype=tf.int32),
     }
     # Sized to match the three-row cache built by hand below.
     cache_manager = negative_cache.CacheManager(specs=specs, cache_size=3)
     data = {
         '1': tf.convert_to_tensor([[0], [0], [3]], dtype=tf.int32),
         '2': tf.convert_to_tensor([[1], [2], [4]], dtype=tf.int32)
     }
     age = tf.convert_to_tensor([2, 1, 0])
     cache = negative_cache.NegativeCache(data=data, age=age)
     updated_item_indices = tf.convert_to_tensor([0], dtype=tf.int32)
     # Only feature '1' is updated; feature '2' of row 0 must stay unchanged.
     updated_item_data = {
         '1': tf.convert_to_tensor([[10]], dtype=tf.int32),
     }
     new_items = {
         '1': tf.convert_to_tensor([[11]], dtype=tf.int32),
         '2': tf.convert_to_tensor([[12]], dtype=tf.int32)
     }
     cache = cache_manager.update_cache(
         cache,
         new_items=new_items,
         updated_item_data=updated_item_data,
         updated_item_indices=updated_item_indices)
     self.assertEqual({'1', '2'}, set(cache.data.keys()))
     self.assertAllEqual(
         tf.convert_to_tensor([[10], [11], [3]], dtype=tf.int32),
         cache.data['1'])
     self.assertAllEqual(
         tf.convert_to_tensor([[1], [12], [4]], dtype=tf.int32),
         cache.data['2'])
     self.assertAllEqual(tf.convert_to_tensor([0, 0, 1], dtype=tf.int32),
                         cache.age)
コード例 #4
0
 def test_masked_update_cache_with_existing_items_when_all_items_masked(
         self):
     """When every update is masked out, data is unchanged and all rows age.

     Starting from ages [2, 2, 2, 2], no row is overwritten, so every age is
     expected to become 3.
     """
     specs = {
         '1': tf.io.FixedLenFeature(shape=[2], dtype=tf.int32),
     }
     cache_manager = negative_cache.CacheManager(specs=specs, cache_size=4)
     cache = negative_cache.NegativeCache(data={
         '1':
         tf.convert_to_tensor([[5, 5], [10, 10], [15, 15], [20, 20]],
                              dtype=tf.int32)
     },
                                          age=tf.convert_to_tensor(
                                              [2, 2, 2, 2], dtype=tf.int32))
     updated_item_indices = tf.convert_to_tensor([1, 3], dtype=tf.int32)
     updated_item_data = {
         '1': tf.convert_to_tensor([[1, 1], [2, 2]], dtype=tf.int32),
     }
     updated_item_mask = tf.convert_to_tensor([False, False])
     cache = cache_manager.update_cache(
         cache,
         updated_item_data=updated_item_data,
         updated_item_indices=updated_item_indices,
         updated_item_mask=updated_item_mask)
     self.assertEqual({'1'}, set(cache.data.keys()))
     # Feature '1' is declared int32 in `specs`, so the expectation is int32.
     self.assertAllEqual(
         tf.convert_to_tensor([[5, 5], [10, 10], [15, 15], [20, 20]],
                              dtype=tf.int32), cache.data['1'])
     self.assertAllEqual(tf.convert_to_tensor([3, 3, 3, 3], dtype=tf.int32),
                         cache.age)
コード例 #5
0
 def test_check_cache_after_update(self):
     """The handler applies the loss's cache updates and inserts the positive item.

     StubCacheLoss (defined elsewhere in this file) is configured to overwrite
     the embedding of cache row 0. The positive item (data [2, 2], embedding
     [2.0, 2.0, 2.0]) is expected to land in row 1 as a new entry.
     """
     specs = {
         'data': tf.io.FixedLenFeature(shape=[2], dtype=tf.int32),
         'embedding': tf.io.FixedLenFeature(shape=[3], dtype=tf.float32)
     }
     cache_manager = negative_cache.CacheManager(specs, cache_size=4)
     # The stub is configured with an update that touches row 0's embedding only.
     cache_loss = StubCacheLoss(
         updated_item_data={
             'cache': {
                 'embedding': tf.convert_to_tensor([[1.0, 1.0, 1.0]])
             }
         },
         updated_item_indices={'cache': tf.convert_to_tensor([0])},
         updated_item_mask={'cache': tf.convert_to_tensor([True])})
     handler = handlers.CacheLossHandler(cache_manager,
                                         cache_loss,
                                         embedding_key='embedding',
                                         data_keys=('data', ))
     loss_actual = handler.update_cache_and_compute_loss(
         item_network=None,
         query_embeddings=None,
         pos_item_embeddings=tf.convert_to_tensor([[2.0, 2.0, 2.0]]),
         features={'data': tf.convert_to_tensor([[2, 2]])})
     # Key sets are plain Python sets, so compare them with assertEqual rather
     # than the tensor-oriented assertAllEqual.
     self.assertEqual({'data', 'embedding'}, set(handler.cache.data.keys()))
     # Row 0's 'data' is untouched because the stub update covers only
     # 'embedding'.
     self.assertAllEqual(
         tf.convert_to_tensor([[0, 0], [2, 2], [0, 0], [0, 0]]),
         handler.cache.data['data'])
     self.assertAllEqual(
         tf.convert_to_tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0],
                               [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]),
         handler.cache.data['embedding'])
     self.assertAllEqual(tf.convert_to_tensor([0, 0, 1, 1]),
                         handler.cache.age)
     self.assertEqual(0.0, loss_actual)
コード例 #6
0
 def test_update_cache_with_existing_items(self):
     """Overwriting rows 1 and 3 of a fresh cache resets only their ages.

     The overwritten rows hold ones afterwards and have age 0. Rows 0 and 2
     keep their zeros and reach age 1.
     """
     feature_specs = {
         '1': tf.io.FixedLenFeature(shape=[2], dtype=tf.int32),
         '2': tf.io.FixedLenFeature(shape=[3], dtype=tf.float32),
     }
     manager = negative_cache.CacheManager(specs=feature_specs, cache_size=4)
     fresh_cache = manager.init_cache()
     target_rows = tf.convert_to_tensor([1, 3], dtype=tf.int32)
     replacement = {
         '1': tf.ones(shape=[2, 2], dtype=tf.int32),
         '2': tf.ones(shape=[2, 3], dtype=tf.float32)
     }
     cache = manager.update_cache(fresh_cache,
                                  updated_item_data=replacement,
                                  updated_item_indices=target_rows)

     expected_int_rows = tf.convert_to_tensor(
         [[0, 0], [1, 1], [0, 0], [1, 1]], dtype=tf.int32)
     expected_float_rows = tf.convert_to_tensor(
         [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]],
         dtype=tf.float32)
     expected_age = tf.convert_to_tensor([1, 0, 1, 0], dtype=tf.int32)

     self.assertEqual({'1', '2'}, set(cache.data.keys()))
     self.assertAllEqual(expected_int_rows, cache.data['1'])
     self.assertAllEqual(expected_float_rows, cache.data['2'])
     self.assertAllEqual(expected_age, cache.age)
コード例 #7
0
 def test_init_cache(self):
   """A fresh cache is all zeros, with cache_size rows of each spec's shape."""
   specs = {
       '1': tf.io.FixedLenFeature(shape=[2], dtype=tf.int32),
       '2': tf.io.FixedLenFeature(shape=[3], dtype=tf.float32),
       '3': tf.io.FixedLenFeature(shape=[3, 2], dtype=tf.float32)
   }
   cache = negative_cache.CacheManager(specs=specs, cache_size=6).init_cache()
   self.assertEqual({'1', '2', '3'}, set(cache.data.keys()))
   # Expected (shape, dtype) of every data tensor: 6 rows of the spec shape.
   expected_layout = {
       '1': ([6, 2], tf.int32),
       '2': ([6, 3], tf.float32),
       '3': ([6, 3, 2], tf.float32),
   }
   for key, (shape, dtype) in expected_layout.items():
     self.assertAllEqual(tf.zeros(shape, dtype=dtype), cache.data[key])
   self.assertAllEqual(tf.zeros([6], dtype=tf.int32), cache.age)
コード例 #8
0
 def test_raises_value_error_if_different_update_sizes(self):
     """New items whose features disagree on batch size are rejected.

     Feature '1' carries two rows while feature '2' carries one. Both
     init_cache and update_cache run under tf.function, and the update is
     expected to raise ValueError.
     """
     feature_specs = {
         '1': tf.io.FixedLenFeature(shape=[2], dtype=tf.int32),
         '2': tf.io.FixedLenFeature(shape=[3], dtype=tf.float32),
     }
     manager = negative_cache.CacheManager(specs=feature_specs, cache_size=4)
     cache = tf.function(manager.init_cache)()
     mismatched_items = {
         '1': tf.ones(shape=[2, 2], dtype=tf.int32),
         '2': tf.ones(shape=[1, 3], dtype=tf.float32)
     }
     traced_update = tf.function(manager.update_cache)
     with self.assertRaises(ValueError):
         traced_update(cache, new_items=mismatched_items)
コード例 #9
0
 def test_update_cache(self):
     """Successive batches of new items replace the oldest cache rows.

     Three batches of two items (filled with 1, 2 and 3) are inserted into a
     four-row cache. Each batch is expected to land in the two rows with the
     largest age. Those rows' ages reset to 0 while the other rows age by one.
     Age is checked after every batch, not just the first.
     """
     specs = {
         '1': tf.io.FixedLenFeature(shape=[2], dtype=tf.int32),
         '2': tf.io.FixedLenFeature(shape=[3], dtype=tf.float32),
     }
     cache_manager = negative_cache.CacheManager(specs=specs, cache_size=4)
     cache = cache_manager.init_cache()

     # Batch 1: all ages tie on a fresh cache; rows 0 and 1 take the items.
     updates = {
         '1': tf.ones(shape=[2, 2], dtype=tf.int32),
         '2': tf.ones(shape=[2, 3], dtype=tf.float32)
     }
     cache = cache_manager.update_cache(cache, new_items=updates)
     self.assertEqual({'1', '2'}, set(cache.data.keys()))
     self.assertAllEqual(
         tf.convert_to_tensor([[1, 1], [1, 1], [0, 0], [0, 0]],
                              dtype=tf.int32), cache.data['1'])
     self.assertAllEqual(
         tf.convert_to_tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0],
                               [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                              dtype=tf.float32), cache.data['2'])
     self.assertAllEqual(tf.convert_to_tensor([0, 0, 1, 1], dtype=tf.int32),
                         cache.age)

     # Batch 2: rows 2 and 3 are now the oldest.
     updates = {
         '1': 2 * tf.ones(shape=[2, 2], dtype=tf.int32),
         '2': 2.0 * tf.ones(shape=[2, 3], dtype=tf.float32)
     }
     cache = cache_manager.update_cache(cache, new_items=updates)
     self.assertEqual({'1', '2'}, set(cache.data.keys()))
     self.assertAllEqual(
         tf.convert_to_tensor([[1, 1], [1, 1], [2, 2], [2, 2]],
                              dtype=tf.int32), cache.data['1'])
     self.assertAllEqual(
         tf.convert_to_tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0],
                               [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]],
                              dtype=tf.float32), cache.data['2'])
     self.assertAllEqual(tf.convert_to_tensor([1, 1, 0, 0], dtype=tf.int32),
                         cache.age)

     # Batch 3: rows 0 and 1 are the oldest again.
     updates = {
         '1': 3 * tf.ones(shape=[2, 2], dtype=tf.int32),
         '2': 3.0 * tf.ones(shape=[2, 3], dtype=tf.float32)
     }
     cache = cache_manager.update_cache(cache, new_items=updates)
     self.assertEqual({'1', '2'}, set(cache.data.keys()))
     self.assertAllEqual(
         tf.convert_to_tensor([[3, 3], [3, 3], [2, 2], [2, 2]],
                              dtype=tf.int32), cache.data['1'])
     self.assertAllEqual(
         tf.convert_to_tensor([[3.0, 3.0, 3.0], [3.0, 3.0, 3.0],
                               [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]],
                              dtype=tf.float32), cache.data['2'])
     self.assertAllEqual(tf.convert_to_tensor([0, 0, 1, 1], dtype=tf.int32),
                         cache.age)
コード例 #10
0
 def test_raises_value_error_if_update_item_keys_not_in_specs(self):
     """updated_item_data containing a key absent from specs raises ValueError.

     The indices tensor has one entry per updated row, so key '3' is the only
     thing wrong with the request. With a single index for two rows, the
     ValueError could instead be triggered by the size mismatch, and the test
     would not isolate the key check.
     """
     specs = {
         '1': tf.io.FixedLenFeature(shape=[1], dtype=tf.int32),
         '2': tf.io.FixedLenFeature(shape=[1], dtype=tf.int32),
     }
     cache_manager = negative_cache.CacheManager(specs=specs, cache_size=4)
     cache = cache_manager.init_cache()
     updated_item_data = {
         '1': tf.ones(shape=[2, 1], dtype=tf.int32),
         '3': tf.ones(shape=[2, 1], dtype=tf.int32),
     }
     # Two indices for the two rows above, so only the bad key is under test.
     updated_item_indices = tf.convert_to_tensor([0, 1], dtype=tf.int32)
     with self.assertRaises(ValueError):
         cache_manager.update_cache(
             cache,
             updated_item_data=updated_item_data,
             updated_item_indices=updated_item_indices)
コード例 #11
0
 def test_initialize_cache(self):
     """A newly constructed CacheLossHandler starts with an all-zero cache.

     The handler's cache has a zero tensor for each spec key ('data' and
     'embedding') plus zero ages for its 4 rows.
     """
     specs = {
         'data': tf.io.FixedLenFeature(shape=[2], dtype=tf.int32),
         'embedding': tf.io.FixedLenFeature(shape=[3], dtype=tf.float32)
     }
     cache_manager = negative_cache.CacheManager(specs, cache_size=4)
     handler = handlers.CacheLossHandler(cache_manager,
                                         StubCacheLoss(None, None, None),
                                         embedding_key='embedding',
                                         data_keys=('data', ))
     # Key sets are plain Python sets, so compare them with assertEqual rather
     # than the tensor-oriented assertAllEqual.
     self.assertEqual({'data', 'embedding'}, set(handler.cache.data.keys()))
     self.assertAllEqual(tf.zeros(shape=[4, 2], dtype=tf.int32),
                         handler.cache.data['data'])
     self.assertAllEqual(tf.zeros(shape=[4, 3], dtype=tf.float32),
                         handler.cache.data['embedding'])
     self.assertAllEqual(tf.zeros(shape=[4], dtype=tf.int32),
                         handler.cache.age)
コード例 #12
0
 def test_raises_value_error_if_new_item_keys_not_equal_specs(self):
     """new_items must carry exactly the spec keys: no fewer and no more.

     Both a batch missing key '2' and a batch with an extra key '3' are
     expected to raise ValueError.
     """
     feature_specs = {
         '1': tf.io.FixedLenFeature(shape=[1], dtype=tf.int32),
         '2': tf.io.FixedLenFeature(shape=[1], dtype=tf.int32),
     }
     manager = negative_cache.CacheManager(specs=feature_specs, cache_size=4)
     cache = manager.init_cache()
     missing_key_items = {'1': tf.ones(shape=[2, 1], dtype=tf.int32)}
     extra_key_items = {
         key: tf.ones(shape=[2, 1], dtype=tf.int32) for key in ('1', '2', '3')
     }
     for bad_items in (missing_key_items, extra_key_items):
         with self.assertRaises(ValueError):
             manager.update_cache(cache, new_items=bad_items)
コード例 #13
0
 def test_update_caches_with_tf_function(self):
     """init_cache and update_cache give the expected results under tf.function.

     Inserting two new items into a fresh four-row cache should fill rows 0
     and 1 and reset their ages, just as in eager mode. The age check guards
     the traced path too.
     """
     specs = {
         '1': tf.io.FixedLenFeature(shape=[2], dtype=tf.int32),
         '2': tf.io.FixedLenFeature(shape=[3], dtype=tf.float32),
     }
     cache_manager = negative_cache.CacheManager(specs=specs, cache_size=4)
     init_cache_fn = tf.function(cache_manager.init_cache)
     cache = init_cache_fn()
     updates = {
         '1': tf.ones(shape=[2, 2], dtype=tf.int32),
         '2': tf.ones(shape=[2, 3], dtype=tf.float32)
     }
     update_cache_fn = tf.function(cache_manager.update_cache)
     cache = update_cache_fn(cache, new_items=updates)
     self.assertEqual({'1', '2'}, set(cache.data.keys()))
     self.assertAllEqual(
         tf.convert_to_tensor([[1, 1], [1, 1], [0, 0], [0, 0]],
                              dtype=tf.int32), cache.data['1'])
     self.assertAllEqual(
         tf.convert_to_tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0],
                               [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                              dtype=tf.float32), cache.data['2'])
     self.assertAllEqual(tf.convert_to_tensor([0, 0, 1, 1], dtype=tf.int32),
                         cache.age)
コード例 #14
0
  def test_new_items_with_mask(self):
    """Only new items whose mask entry is True are written to the cache.

    Of the three new items, [6, 6] is masked out. The expected data puts
    [5, 5] in row 3 and [7, 7] in row 1, the two rows with the largest
    starting ages (3 and 2). Rows 0 and 2 keep their original data.
    """
    specs = {
        '1': tf.io.FixedLenFeature(shape=[2], dtype=tf.int32),
    }
    cache_manager = negative_cache.CacheManager(specs=specs, cache_size=4)
    # Hand-built cache in which rows 3 and 1 carry the largest ages.
    cache = negative_cache.NegativeCache(
        data={
            '1':
                tf.convert_to_tensor([[1, 1], [2, 2], [3, 3], [4, 4]],
                                     dtype=tf.int32)
        },
        age=tf.convert_to_tensor([0, 2, 1, 3], dtype=tf.int32))

    new_items = {
        '1': tf.convert_to_tensor([[5, 5], [6, 6], [7, 7]], dtype=tf.int32)
    }
    # The middle item ([6, 6]) is masked out and must not enter the cache.
    new_items_mask = tf.convert_to_tensor([True, False, True])
    cache = cache_manager.update_cache(
        cache, new_items=new_items, new_items_mask=new_items_mask)
    self.assertAllEqual(
        tf.convert_to_tensor([[1, 1], [7, 7], [3, 3], [5, 5]], dtype=tf.int32),
        cache.data['1'])