def testWeakStructRespectsContainerTypes(self):
    """WeakStructRef equality depends on container type, not only on leaves.

    Refs built from two distinct list objects with the same elements compare
    equal, while a list and a tuple holding the same elements do not.
    """
    alice = tf.constant(1., name='alice')
    bob = tf.constant(2., name='bob')
    make_ref = cache_util.WeakStructRef
    # Separate list objects with identical contents -> equal refs.
    self.assertEqual(make_ref([alice, bob]), make_ref([alice, bob]))
    # Same contents, but list vs. tuple -> refs must differ.
    self.assertNotEqual(make_ref([alice, bob]), make_ref((alice, bob)))
    def testStructRefCallbackFiresOnce(self):
        """A WeakStructRef's callback fires only for the first dead member.

        In eager mode, deleting one tensor from the struct kills the ref and
        invokes the callback once; the recorded argument is expected to equal
        the ref itself. Deleting further members, or the dict itself, must not
        invoke it again. In graph mode references stay alive, so the ref is
        still alive after the dict is deleted.
        """
        fired = []

        def _record(key):
            fired.append(key)

        tensor_struct = {
            'a': tf.constant(1., name='alice'),
            'b': tf.constant(2., name='bob'),
            'c': tf.constant(3., name='carol')
        }
        struct_ref = cache_util.WeakStructRef(tensor_struct, callback=_record)

        if not tf.executing_eagerly():
            # Graph mode: references stay alive, so the ref outlives the dict.
            self.assertTrue(struct_ref.alive)
            del tensor_struct
            self.assertTrue(struct_ref.alive)
            return

        # Eager mode: alive, and no callback yet, until a member goes away.
        self.assertTrue(struct_ref.alive)
        self.assertEqual(fired, [])
        del tensor_struct['a']  # Goodbye, Alice!
        self.assertFalse(struct_ref.alive)
        self.assertEqual(fired, [struct_ref])
        # Later deletions must not re-trigger the callback.
        del tensor_struct['b']  # Goodbye, Bob!
        self.assertEqual(fired, [struct_ref])
        del tensor_struct  # Goodbye, everybody!
        self.assertEqual(fired, [struct_ref])
    def testInputsAreCached(self, steps=4):
        """Forward/inverse round trips are served from `self.cache`.

        Applies `self.forward` `steps` times, then `self.inverse` `steps` times.
        Along the way it checks the bijector call counts and the cache sizes.
        It then checks that the original input struct is recovered by identity,
        and that the tracked keys disappear once the last strong reference to
        that struct is dropped.

        Args:
          steps: Number of forward (and then inverse) applications.
        """
        struct = self.test_arg()
        weak_start = cache_util.WeakStructRef(struct)

        # Forward pass: each step is a real forward call and adds one cache
        # entry in each direction.
        for _ in range(steps):
            struct = self.forward(struct)

        self.assertLen(self.forward_keys, steps)
        self.assertEqual(self.bijector.forward_call_count, steps)
        self.assertLen(self.cache.items(direction='forward'), steps)
        self.assertLen(self.inverse_keys, steps)
        self.assertEqual(self.bijector.inverse_call_count, 0)
        self.assertLen(self.cache.items(direction='inverse'), steps)

        # Now invert our calls; every inverse should be answered from the cache.
        for _ in range(steps):
            struct = self.inverse(struct)

        self.assertLen(self.forward_keys, 1)  # Has cached attrs.
        self.assertEqual(self.bijector.forward_call_count,
                         steps)  # No new calls.
        self.assertLen(self.inverse_keys, 0)  # Refs are all gone.
        self.assertEqual(self.bijector.inverse_call_count,
                         0)  # All cache hits.

        # Original is recoverable. Contents are referentially equal.
        self.assertTrue(weak_start.alive)
        tf.nest.map_structure(self.assertIs, struct, weak_start())

        # Dropping the last strong reference should release the cached keys.
        struct = None
        self.assertFalse(weak_start.alive)
        self.assertLen(self.forward_keys, 0)
        self.assertLen(self.inverse_keys, 0)
  def testStructRefIsWeak(self):
      """WeakStructRef does not by itself keep its referent alive (eager mode).

      In eager mode, references get cleaned up once the struct is deleted.
      In graph mode, references stay alive.
      """
      tensor_struct = {
          'a': tf.constant(1., name='alice'),
          'b': tf.constant(2., name='bob'),
          'c': tf.constant(3., name='carol')}

      weak_ref = cache_util.WeakStructRef(tensor_struct)
      another_weak_ref = cache_util.WeakStructRef(tensor_struct)

      self.assertTrue(weak_ref.alive)
      del tensor_struct

      # Eager: both refs must be dead. Graph: both must still be alive.
      check_alive = (self.assertFalse if tf.executing_eagerly()
                     else self.assertTrue)
      for ref in (weak_ref, another_weak_ref):
          check_alive(ref.alive)
    def testWeakStructCopiesContainers(self):
        """WeakStructRef captures the container layout at construction time.

        Mutating the original list in place afterwards must not change the
        structure that the ref dereferences to.
        """
        alice = tf.constant(1., name='alice')
        bob = tf.constant(2., name='bob')
        carol = tf.constant(3., name='carol')
        original = [alice, {'x': bob, 'y': carol}]

        ref = cache_util.WeakStructRef(original)

        # Independent copy of the structure, sharing the same leaf objects.
        snapshot = tf.nest.map_structure(lambda leaf: leaf, original)
        # Reverse the original list in place -> [{'x': ..., 'y': ...}, alice].
        original.reverse()

        self.assertEqual(ref(), snapshot)
        # The ref keeps the pre-mutation layout, which no longer matches.
        with self.assertRaises(ValueError):  # pylint: disable=g-error-prone-assert-raises
            tf.nest.assert_same_structure(ref(), original)
 def testStructRefChecksSubkey(self):
     """Refs to the same tensor with different subkeys are distinct keys."""
     tensor = tf.constant([1., 2., 3.], dtype=tf.float32, name='alice')
     ref_a, ref_b = (
         cache_util.WeakStructRef(tensor, subkey=subkey)
         for subkey in ('a', 'b'))
     # Both equality and hashing must take the subkey into account.
     self.assertNotEqual(ref_a, ref_b)
     self.assertNotEqual(hash(ref_a), hash(ref_b))