def testClear(self):
  """Clearing one bijector's cache leaves other entries of its class intact.

  Three caches share the storage of `self.cache`:
    * `self.cache` itself (driven through `self.forward`),
    * a cache bound to a second `_BareBonesBijector` instance, and
    * a bijector-less view over the class, used here only to count entries.
  The test expects the class view to see the entries from both bijectors.
  """
  other_bijector = _BareBonesBijector(self.forward_impl, self.inverse_impl)
  other_cache = cache_util.BijectorCache(
      bijector=other_bijector,
      bijector_class=_BareBonesBijector,
      storage=self.cache.storage)
  class_view = cache_util.BijectorCache(
      bijector=None,
      bijector_class=_BareBonesBijector,
      storage=self.cache.storage)

  def assert_forward_entries(expected_count):
    self.assertLen(class_view.weak_keys(direction='forward'), expected_count)

  first_input = self.test_arg()
  self.forward(first_input)
  other_cache.forward(first_input)
  second_input = self.test_arg()
  self.forward(second_input)

  # Two inputs through `self.cache` plus one through `other_cache`.
  assert_forward_entries(3)

  # Dropping `other_bijector`'s cache removes only its single entry.
  other_cache.clear()
  assert_forward_entries(2)

  # Clearing the class-level view empties everything that remains.
  class_view.clear()
  assert_forward_entries(0)
def _setup_cache(self):
  """Defines the cache for this bijector.

  The cache receives small wrapper functions instead of the bound methods
  `self._forward` / `self._inverse`. Each wrapper looks the method up on
  `self` at call time, so patching those methods on the instance takes
  effect for cached calls too.
  """
  def forward_impl(x, **kwargs):
    return self._forward(x, **kwargs)

  def inverse_impl(y, **kwargs):
    return self._inverse(y, **kwargs)

  return cache_util.BijectorCache(
      forward_impl=forward_impl,
      inverse_impl=inverse_impl,
      cache_type=cache_util.CachedDirectedFunction)
def setUp(self):
  """Builds a bare-bones bijector and a cache bound to it for each test.

  The test is skipped unless TensorFlow is executing eagerly.
  """
  if not tf.executing_eagerly():
    self.skipTest('Not interesting in graph mode.')
  bijector = _BareBonesBijector(self.forward_impl, self.inverse_impl)
  self.bijector = bijector
  # The cache is keyed to this specific bijector and its concrete class.
  self.cache = cache_util.BijectorCache(
      bijector=bijector,
      bijector_class=type(bijector))
  super(CacheTestBase, self).setUp()
def setUp(self):
  """Zeroes the call counters and builds a fresh cache for every test case.

  The test is skipped unless TensorFlow is executing eagerly.
  """
  if not tf.executing_eagerly():
    self.skipTest('Not interesting in graph mode.')
  self.forward_call_count = self.inverse_call_count = 0
  cache = cache_util.BijectorCache(self._call_forward, self._call_inverse)
  self.cache = cache
  # Expose the cached entry points directly on the test for convenience.
  self.forward, self.inverse = cache.forward, cache.inverse
  super(CacheTestBase, self).setUp()
def testInstanceCache(self):
  """A bijector with its own instance-level cache does not see global entries.

  Two otherwise identical `tfb.Exp` bijectors are compared. One has its
  `_cache` replaced by a `BijectorCache` bound to that instance; the other
  keeps its default (global) cache.
  """
  isolated = tfb.Exp()
  isolated._cache = cache_util.BijectorCache(bijector=isolated)
  shared = tfb.Exp()

  x = tf.constant(0., dtype=tf.float32)
  y = shared.forward(x)

  # Only the globally-cached bijector recorded the forward call; the
  # instance-level cache of the identical bijector stays empty.
  self.assertLen(shared._cache.weak_keys(direction='forward'), 1)
  self.assertLen(isolated._cache.weak_keys(direction='forward'), 0)

  # Repeating the transformation through the instance-cached bijector is a
  # cache miss, so the result is not the object the global cache returned.
  z = isolated.forward(x)
  self.assertIsNot(y, z)
def _setup_cache(self):
  """Replaces the default bijector cache so forward/inverse update attrs.

  Cached calls are routed through `_augmented_forward` and
  `_augmented_inverse`, and the cache is built with the
  `CachedDirectedFunctionWithGreedyAttrs` cache type.
  """
  impls = dict(
      forward_impl=self._augmented_forward,
      inverse_impl=self._augmented_inverse)
  return cache_util.BijectorCache(
      cache_type=cache_util.CachedDirectedFunctionWithGreedyAttrs, **impls)