def testClear(self):
    """Checks that `clear()` evicts instance-scoped and class-scoped entries.

    Several caches are built over the same storage as `self.cache`.
    - An instance-scoped cache is bound to a second bijector.
    - A class-scoped cache is built with `bijector=None`.

    Three forward calls populate the shared storage:
    - two on the same input, through different caches;
    - one on a second input.

    Clearing the second bijector's cache should drop only its entry.
    Clearing the class-scoped cache should drop everything that remains.
    """
    other_bijector = _BareBonesBijector(self.forward_impl, self.inverse_impl)
    shared_storage = self.cache.storage
    # Cache view tied to `other_bijector`, backed by the shared storage.
    other_cache = cache_util.BijectorCache(
        bijector=other_bijector,
        bijector_class=_BareBonesBijector,
        storage=shared_storage)
    # Cache view with no bijector instance, over the same shared storage.
    class_cache = cache_util.BijectorCache(
        bijector=None,
        bijector_class=_BareBonesBijector,
        storage=shared_storage)

    def forward_keys():
        return class_cache.weak_keys(direction='forward')

    # The results are bound to `_` (as before) so that object lifetimes,
    # which matter to a weakly-keyed cache, are unchanged.
    first_arg = self.test_arg()
    _ = self.forward(first_arg)
    _ = other_cache.forward(first_arg)

    second_arg = self.test_arg()
    _ = self.forward(second_arg)

    # first_arg via two caches, plus second_arg via self.cache.
    self.assertLen(forward_keys(), 3)

    # Dropping other_bijector's entries leaves the two from self.cache.
    other_cache.clear()
    self.assertLen(forward_keys(), 2)

    # A class-scoped clear empties the shared forward table entirely.
    class_cache.clear()
    self.assertLen(forward_keys(), 0)
Example #2
0
 def _setup_cache(self):
   """Builds the `BijectorCache` used by this bijector.

   The impls are small closures that look up `self._forward` and
   `self._inverse` at call time instead of binding them up front. Patching
   those instance methods later therefore still takes effect.
   """
   def forward_impl(x, **kwargs):
     return self._forward(x, **kwargs)

   def inverse_impl(y, **kwargs):
     return self._inverse(y, **kwargs)

   return cache_util.BijectorCache(
       forward_impl=forward_impl,
       inverse_impl=inverse_impl,
       cache_type=cache_util.CachedDirectedFunction)
  def setUp(self):
    """Creates a bijector and a cache bound to it; eager mode only."""
    # Caching is only exercised eagerly; skip the test otherwise.
    if not tf.executing_eagerly():
      self.skipTest('Not interesting in graph mode.')

    bijector = _BareBonesBijector(self.forward_impl, self.inverse_impl)
    self.bijector = bijector
    # Cache tied to this particular bijector instance and to its class.
    self.cache = cache_util.BijectorCache(
        bijector=bijector, bijector_class=type(bijector))
    super(CacheTestBase, self).setUp()
  def setUp(self):
    """Builds a fresh cache around the `_call_*` impls for each test.

    Eager mode only.
    """
    # Caching is only exercised eagerly; skip the test otherwise.
    if not tf.executing_eagerly():
      self.skipTest('Not interesting in graph mode.')

    # Per-test counters. Presumably they are incremented by
    # `_call_forward` / `_call_inverse` (not visible here; TODO confirm).
    self.forward_call_count = self.inverse_call_count = 0
    cache = cache_util.BijectorCache(self._call_forward, self._call_inverse)
    self.cache = cache
    # Shortcuts so tests can call the cached functions directly.
    self.forward, self.inverse = cache.forward, cache.inverse
    super(CacheTestBase, self).setUp()
  def testInstanceCache(self):
    """A bijector with its own instance cache is isolated from the default cache."""
    isolated = tfb.Exp()
    isolated._cache = cache_util.BijectorCache(bijector=isolated)
    shared = tfb.Exp()

    def forward_keys(bij):
      return bij._cache.weak_keys(direction='forward')

    zero = tf.constant(0., dtype=tf.float32)
    shared_out = shared.forward(zero)

    # The call through `shared` is recorded only in its own cache. The
    # otherwise-identical bijector with an instance cache sees no entry.
    self.assertLen(forward_keys(shared), 1)
    self.assertLen(forward_keys(isolated), 0)

    # The same transformation through `isolated` misses its cache and is
    # recomputed, so the result is a different object.
    isolated_out = isolated.forward(zero)
    self.assertIsNot(shared_out, isolated_out)
Example #6
0
 def _setup_cache(self):
     """Builds a cache whose forward/inverse calls also update attributes.

     The augmented impls are paired with the
     `CachedDirectedFunctionWithGreedyAttrs` cache type.
     """
     cache_kwargs = dict(
         forward_impl=self._augmented_forward,
         inverse_impl=self._augmented_inverse,
         cache_type=cache_util.CachedDirectedFunctionWithGreedyAttrs)
     return cache_util.BijectorCache(**cache_kwargs)