def testFunctionCacheKeyRespectsSupertype(self):
  """Common-supertype queries delegate to the wrapped trace type."""
  context_obj = function_cache.FunctionContext(0)
  lesser = function_cache.FunctionCacheKey(MockSupertypes2With3(1), context_obj)
  greater = function_cache.FunctionCacheKey(MockSupertypes2With3(2), context_obj)

  # Asking from the "2" key against the "1" key yields a key wrapping "3".
  expected = function_cache.FunctionCacheKey(
      MockSupertypes2With3(3), context_obj)
  self.assertEqual(greater.most_specific_common_supertype([lesser]), expected)
  # The reverse direction produces no common supertype.
  self.assertIsNone(lesser.most_specific_common_supertype([greater]))
def testFunctionCacheKeyRespectsSubtype(self):
  """is_subtype_of follows the wrapped trace type's subtype relation."""
  context_obj = function_cache.FunctionContext(0)
  one = function_cache.FunctionCacheKey(MockSubtypeOf2(1), context_obj)
  two = function_cache.FunctionCacheKey(MockSubtypeOf2(2), context_obj)
  another_one = function_cache.FunctionCacheKey(MockSubtypeOf2(1), context_obj)

  # Subtyping is directional and not reflexive for these mocks.
  self.assertTrue(one.is_subtype_of(two))
  self.assertFalse(two.is_subtype_of(one))
  self.assertFalse(one.is_subtype_of(another_one))
def testFunctionCacheKeyRespectsEquality(self):
  """Keys wrapping equal trace types compare equal and hash equal."""
  context_obj = function_cache.FunctionContext(0)

  def build_key(value):
    return function_cache.FunctionCacheKey(MockGenericType(value), context_obj)

  first = build_key(1)
  second = build_key(2)
  first_again = build_key(1)

  self.assertNotEqual(first, second)
  self.assertEqual(first, first_again)
  # __eq__ and __hash__ must agree for cache-dict correctness.
  self.assertEqual(hash(first), hash(first_again))
def testFirstMostSpecificFunctionCacheKeyIsLookedUp(self):
  """When multiple cached keys match, lookup returns the first one added."""
  context_obj = function_cache.FunctionContext(0)
  cache = function_cache.FunctionCache()

  # Both (1, 2, None) and (1, None, 3) are compatible with (1, 2, 3).
  for dims, value in (((1, 2, None), "a"), ((1, None, 3), "b")):
    cache.add(
        function_cache.FunctionCacheKey(MockShape(*dims), context_obj),
        trace_type.WeakrefDeletionObserver(), value)

  query = function_cache.FunctionCacheKey(MockShape(1, 2, 3), context_obj)
  self.assertEqual(cache.lookup(query, True), "a")
def testMostSpecificFunctionCacheKeyIsOrderAgnostic(self):
  """Lookup picks the most specific entry regardless of insertion order."""
  context_obj = function_cache.FunctionContext(0)

  def build_key(*dims):
    return function_cache.FunctionCacheKey(
        MockShape(*dims), MockEmptyCaptureSnapshot(), context_obj)

  entries = [(build_key(1, 1, 1), "a"),
             (build_key(1, None, 1), "b"),
             (build_key(None, None, 1), "c"),
             (build_key(None, None, None), "d")]

  # Every insertion order must produce the same lookup results.
  for ordering in itertools.permutations(entries):
    cache = function_cache.FunctionCache()
    for key, value in ordering:
      cache.add(key, trace_type.WeakrefDeletionObserver(), value)

    for dims, expected in (((1, 1, 1), "a"),
                           ((1, 2, 1), "b"),
                           ((2, 2, 1), "c"),
                           ((2, 2, 2), "d")):
      self.assertEqual(cache.lookup(build_key(*dims), True), expected)
def make_function_context() -> function_cache.FunctionContext:
  """Generates a FunctionContext based on current contextual info.

  Snapshots every piece of ambient state that should be part of the
  tf.function cache key: the parent graph (when tracing inside a graph),
  device and colocation stacks, whether we are in a cross-replica
  distribution-strategy context, the XLA control-flow context id, and the
  variable policy active during saving.

  Returns:
    A `function_cache.FunctionContext` wrapping an `EagerContext` snapshot
    of the current tracing-relevant state.
  """
  ctx = context.context()

  # Don't need to open an init_scope if the tf.function call is in eager mode
  # already.
  executing_eagerly = ctx.executing_eagerly()
  parent_graph = None
  xla_context_id = 0

  if not executing_eagerly:
    # We want to force function retracing for each different
    # XLAControlFlowContext, so add `xla_context_id` to the context.
    xla_context = _enclosing_xla_context()
    if xla_context is not None and xla_context.RequiresUniqueFunctionRetracing(
    ):
      xla_context_id = id(xla_context)

    with ops.init_scope():
      # The graph, or whether we're executing eagerly, should be a part of the
      # cache key so we don't improperly capture tensors such as variables.
      # NOTE: re-read eagerness inside init_scope — it can differ from the
      # value observed outside the scope.
      executing_eagerly = ctx.executing_eagerly()
      parent_graph = None if executing_eagerly else ops.get_default_graph()

  # pylint: disable=protected-access
  default_graph = ops.get_default_graph()
  # TODO(b/117617952): The current distribution strategy will affect graph
  # building (e.g. accessing different variables from different devices) and
  # so requires retracing for each device.
  strategy_stack = default_graph._distribution_strategy_stack
  uses_distribution_strategy = (
      strategy_stack and
      strategy_stack[-1].strategy.extended._retrace_functions_for_each_device)
  if executing_eagerly:
    # Eager execution has no colocation stack; the device only matters when a
    # distribution strategy demands per-device retracing.
    colocation_stack = ()
    if uses_distribution_strategy:
      device_functions = (pydev.merge_device(ctx.device_name),)
    else:
      device_functions = ()
  else:
    colocation_stack = tuple(default_graph._colocation_stack.peek_objs())
    if (uses_distribution_strategy or
        func_graph_module.device_stack_has_callable(
            default_graph._device_function_stack)):
      # Putting the device in the cache key ensures that call-site device
      # annotations are respected.
      device_functions = tuple(default_graph._device_functions_outer_to_inner)
    else:
      device_functions = ()

  # Detect a cross-replica context: present strategy with no replica context.
  # EAFP: an empty strategy stack or a strategy without the attribute simply
  # leaves the flag False.
  in_cross_replica_context = False
  try:
    in_cross_replica_context = (strategy_stack[-1].replica_context is None)  # pylint: disable=protected-access
  except (AttributeError, IndexError):
    pass

  if save_context.in_save_context():
    variable_policy = (
        save_context.get_save_options().experimental_variable_policy)
  else:
    variable_policy = None

  return function_cache.FunctionContext(
      EagerContext(parent_graph, device_functions, colocation_stack,
                   in_cross_replica_context, variable_policy, xla_context_id))