Esempio n. 1
0
    def testFunctionCacheKeyRespectsSupertype(self):
        """Checks most_specific_common_supertype on FunctionCacheKey.

        With MockSupertypes2With3 signatures, the key wrapping 2 merges with
        the key wrapping 1 into a key wrapping 3. The reverse query (key
        wrapping 1 against key wrapping 2) yields None.
        """
        context = function_cache.FunctionContext(0)

        def make_key(value):
            return function_cache.FunctionCacheKey(
                MockSupertypes2With3(value), context)

        first = make_key(1)
        second = make_key(2)

        merged = second.most_specific_common_supertype([first])
        self.assertEqual(merged, make_key(3))
        self.assertIsNone(first.most_specific_common_supertype([second]))
Esempio n. 2
0
    def testFunctionCacheKeyRespectsSubtype(self):
        """Checks FunctionCacheKey.is_subtype_of against MockSubtypeOf2 keys.

        A key wrapping 1 is a subtype of a key wrapping 2. It is not the
        other way round, and a key wrapping 1 is not a subtype of another key
        wrapping 1.
        """
        context = function_cache.FunctionContext(0)
        one, two, other_one = (
            function_cache.FunctionCacheKey(MockSubtypeOf2(value), context)
            for value in (1, 2, 1))

        self.assertTrue(one.is_subtype_of(two))
        self.assertFalse(two.is_subtype_of(one))
        self.assertFalse(one.is_subtype_of(other_one))
Esempio n. 3
0
    def testFunctionCacheKeyRespectsEquality(self):
        """Checks that key equality and hashing follow the wrapped signature."""
        context = function_cache.FunctionContext(0)
        keys = [
            function_cache.FunctionCacheKey(MockGenericType(value), context)
            for value in (1, 2, 1)
        ]
        one, two, other_one = keys

        # Different signatures -> unequal; equal signatures -> equal + same hash.
        self.assertNotEqual(one, two)
        self.assertEqual(one, other_one)
        self.assertEqual(hash(one), hash(other_one))
Esempio n. 4
0
    def testFirstMostSpecificFunctionCacheKeyIsLookedUp(self):
        """Checks which entry wins when two cached shapes could match.

        Shapes (1, 2, None) -> "a" and (1, None, 3) -> "b" are cached in that
        order. A subtyping lookup of (1, 2, 3) is expected to return "a", the
        entry added first (per the test name, the first most specific match).
        """
        context = function_cache.FunctionContext(0)
        cache = function_cache.FunctionCache()
        for dims, value in (((1, 2, None), "a"), ((1, None, 3), "b")):
            cached_key = function_cache.FunctionCacheKey(MockShape(*dims),
                                                         context)
            cache.add(cached_key, trace_type.WeakrefDeletionObserver(), value)

        request = function_cache.FunctionCacheKey(MockShape(1, 2, 3), context)
        self.assertEqual(cache.lookup(request, True), "a")
Esempio n. 5
0
 def setup():
   # Repopulates `cache` (from the enclosing scope) with every tuple in
   # `keys` (also from the enclosing scope, not visible here) plus one mock
   # entry. All entries are stored under the value "testing".
   cache.clear()
   for entry in keys:
     cache.add(*entry, "testing")
   mock_key = function_cache.FunctionCacheKey(
       MockSubtypeOf2(3), MockEmptyCaptureSnapshot(), None)
   cache.add(mock_key, trace_type.WeakrefDeletionObserver(), "testing")
Esempio n. 6
0
    def testDeleteRemovesConcreteFunctions(self):
        """Checks that delete() removes an entry from exact and subtype lookups."""
        cache = function_cache.FunctionCache()

        # An entry keyed from real arguments: added, found, deleted, gone.
        real_key, real_observer = function_context.make_cache_key(1)
        cache.add(real_key, real_observer, "test_1")
        self.assertEqual(cache.lookup(real_key, False), "test_1")
        cache.delete(real_key)
        self.assertIsNone(cache.lookup(real_key, False))

        # A mock entry that also serves a subtyping lookup until deleted.
        supertype_key = function_cache.FunctionCacheKey(MockSubtypeOf2(2), None)
        cache.add(supertype_key, trace_type.WeakrefDeletionObserver(), "test_2")
        self.assertEqual(cache.lookup(supertype_key, False), "test_2")

        subtype_key = function_cache.FunctionCacheKey(MockSubtypeOf2(3), None)
        self.assertEqual(cache.lookup(subtype_key, True), "test_2")

        # After deletion neither the exact nor the subtype lookup succeeds.
        cache.delete(supertype_key)
        self.assertIsNone(cache.lookup(supertype_key, False))
        self.assertIsNone(cache.lookup(subtype_key, True))
Esempio n. 7
0
def make_cache_key(
    args
) -> Tuple[function_cache.FunctionCacheKey, trace_type.WeakrefDeletionObserver]:
  """Computes the cache key given the function arguments.

  Args:
    args: The function arguments to trace into a signature.

  Returns:
    A `(cache_key, deletion_observer)` pair.
    - `cache_key` combines the traced signature with `make_function_context()`.
    - `deletion_observer` is taken from the tracing context used to trace
      `args`.
  """
  tracing_context = trace_type.InternalTracingContext()
  signature = trace_type.from_object(args, tracing_context)
  cache_key = function_cache.FunctionCacheKey(signature,
                                              make_function_context())
  return cache_key, tracing_context.deletion_observer
Esempio n. 8
0
def make_cache_key(
    args,
    include_tensor_ranks_only: bool = False
) -> Tuple[function_cache.FunctionCacheKey, trace_type.WeakrefDeletionObserver]:
  """Computes the cache key given the function arguments.

  Args:
    args: The function arguments to build a signature from.
    include_tensor_ranks_only: Forwarded to `trace_type.SignatureContext`.
      Presumably it restricts tensor signatures to ranks; confirm against
      trace_type.

  Returns:
    A `(cache_key, deletion_observer)` pair.
    - `cache_key` combines the signature with `make_function_context()`.
    - `deletion_observer` is taken from the signature context.
  """
  signature_ctx = trace_type.SignatureContext(include_tensor_ranks_only)
  signature = trace_type.make_function_signature(args, signature_ctx)
  cache_key = function_cache.FunctionCacheKey(signature,
                                              make_function_context())
  return cache_key, signature_ctx.deletion_observer
Esempio n. 9
0
  def benchmarkCacheHit50thKeyKnownSubtype(self):
    """Benchmarks a subtyping lookup whose dispatch was already memorized.

    The cache holds 49 keys built from real tensor arguments plus one mock
    key, MockSubtypeOf2(2). The subtype key MockSubtypeOf2(3) is looked up
    once to let the cache memorize its dispatch. The timed loop then repeats
    the lookup of that same subtype key.
    """
    cache = function_cache.FunctionCache()
    args_per_call = 5
    num_total_checks = 50

    keys = []
    for i in range(num_total_checks - 1):
      args = [array_ops.zeros([i, j]) for j in range(args_per_call)]
      keys.append(function_context.make_cache_key(args))

    for key in keys:
      cache.add(*key, "testing")
    cache.add(
        function_cache.FunctionCacheKey(MockSubtypeOf2(2),
                                        MockEmptyCaptureSnapshot(), None),
        trace_type.WeakrefDeletionObserver(), "testing")

    # Warm up and time with the SAME subtype key. Timing MockSubtypeOf2(2)
    # would only measure an exact-match hit, because that exact key was
    # added to the cache above.
    lookup_key = function_cache.FunctionCacheKey(MockSubtypeOf2(3),
                                                 MockEmptyCaptureSnapshot(),
                                                 None)
    cache.lookup(lookup_key, True)

    iterations = 10000
    subtyping_time = timeit.timeit(
        lambda: cache.lookup(lookup_key, True), number=iterations)

    self.report_benchmark(
        name="cache_hit_50th_key_known_subtype",
        iters=iterations,
        wall_time=subtyping_time,
        metrics=[{
            "name": "cache_hit_50th_key_known_subtype_avg_ms",
            "value": subtyping_time / iterations * 1000
        }])
Esempio n. 10
0
  def testMostSpecificFunctionCacheKeyIsOrderAgnostic(self):
    """Checks that the most specific match wins regardless of insert order.

    Four entries with progressively more unknown (None) dimensions are
    inserted in every possible order. Each request must resolve to the same
    entry for every ordering.
    """
    ctx = function_cache.FunctionContext(0)

    def make_key(*dims):
      return function_cache.FunctionCacheKey(MockShape(*dims),
                                             MockEmptyCaptureSnapshot(), ctx)

    entries = [(make_key(1, 1, 1), "a"),
               (make_key(1, None, 1), "b"),
               (make_key(None, None, 1), "c"),
               (make_key(None, None, None), "d")]
    # (requested dims, value of the entry expected to serve the request)
    expectations = [((1, 1, 1), "a"),
                    ((1, 2, 1), "b"),
                    ((2, 2, 1), "c"),
                    ((2, 2, 2), "d")]

    for ordering in itertools.permutations(entries):
      cache = function_cache.FunctionCache()
      for cached_key, value in ordering:
        cache.add(cached_key, trace_type.WeakrefDeletionObserver(), value)

      for dims, expected in expectations:
        self.assertEqual(cache.lookup(make_key(*dims), True), expected)
Esempio n. 11
0
def make_cache_key(
    args: Any,
    captures: Any = None,
) -> Tuple[function_cache.FunctionCacheKey,
           trace_type.WeakrefDeletionObserver]:
    """Computes the cache key given the function arguments.

    Args:
        args: The function arguments to trace into a signature.
        captures: Captured values to trace alongside `args`; `None` is
            treated as an empty dict.

    Returns:
        A `(cache_key, deletion_observer)` pair.
        - `cache_key` combines the args signature, a CaptureSnapshot of the
          traced captures' `mapping`, and `make_function_context()`.
        - `deletion_observer` is taken from the shared tracing context.
    """
    captures = {} if captures is None else captures
    tracing_context = trace_type.InternalTracingContext()
    args_signature = trace_type.from_object(args, tracing_context)
    # Both args and captures are traced with the same context.
    captures_type = trace_type.from_object(captures, tracing_context)
    captures_snapshot = function_cache.CaptureSnapshot(captures_type.mapping)

    cache_key = function_cache.FunctionCacheKey(
        args_signature, captures_snapshot, make_function_context())
    return cache_key, tracing_context.deletion_observer
Esempio n. 12
0
    def benchmarkCacheHit50thKeyUnknownSubtype(self):
        """Benchmarks a subtyping cache hit the cache has not dispatched before.

        The cache holds 49 keys built from real tensor arguments plus one mock
        key, MockSubtypeOf2(2). It is rebuilt before every timed run. Each run
        looks up the subtype key MockSubtypeOf2(3), which with subtyping
        enabled is served by the MockSubtypeOf2(2) entry.
        """
        cache = function_cache.FunctionCache()
        args_per_call = 5
        num_total_checks = 50

        keys = []
        for i in range(num_total_checks - 1):
            args = [array_ops.zeros([i, j]) for j in range(args_per_call)]
            keys.append(function_context.make_cache_key(args))

        def setup():
            # Rebuild the cache before every timed lookup. This assumes
            # clear() also drops any memorized dispatch results.
            cache.clear()
            for key in keys:
                cache.add(*key, "testing")
            # Supertype entry the lookup below should dispatch to. Caching
            # MockSubtypeOf2(3) and looking up MockSubtypeOf2(2) (the reverse)
            # would not match, so the "cache_hit" benchmark would time a miss.
            cache.add(function_cache.FunctionCacheKey(MockSubtypeOf2(2), None),
                      trace_type.WeakrefDeletionObserver(), "testing")

        iterations = 10000
        lookup_key = function_cache.FunctionCacheKey(MockSubtypeOf2(3), None)
        subtyping_time = sum(
            timeit.repeat(stmt=lambda: cache.lookup(lookup_key, True),
                          setup=setup,
                          repeat=iterations,
                          number=1))

        self.report_benchmark(name="cache_hit_50th_key_unknown_subtype",
                              iters=iterations,
                              wall_time=subtyping_time,
                              metrics=[{
                                  "name":
                                  "cache_hit_50th_key_unknown_subtype_avg_ms",
                                  "value":
                                  subtyping_time / iterations * 1000
                              }])