コード例 #1
0
    def testFunctionCacheKeyRespectsSupertype(self):
        # Every key shares one FunctionContext so that only the wrapped mock
        # type differs between them.
        shared_ctx = function_cache.FunctionContext(0)

        def make_key(value):
            return function_cache.FunctionCacheKey(
                MockSupertypes2With3(value), shared_ctx)

        key_one = make_key(1)
        key_two = make_key(2)

        # Asking the key wrapping 2 for a common supertype with the key
        # wrapping 1 is expected to yield a key wrapping 3.
        supertype_from_two = key_two.most_specific_common_supertype([key_one])
        self.assertEqual(supertype_from_two, make_key(3))

        # Asked from the key wrapping 1 instead, no supertype is expected.
        self.assertIsNone(key_one.most_specific_common_supertype([key_two]))
コード例 #2
0
    def testFunctionCacheKeyRespectsSubtype(self):
        # Three keys in one shared context: wrapping 1, wrapping 2, and a
        # second, separately constructed key wrapping 1.
        shared_ctx = function_cache.FunctionContext(0)
        key_one, key_two, key_one_again = [
            function_cache.FunctionCacheKey(MockSubtypeOf2(value), shared_ctx)
            for value in (1, 2, 1)
        ]

        # Only the 1 -> 2 direction is expected to be a subtype relation.
        one_is_subtype_of_two = key_one.is_subtype_of(key_two)
        self.assertTrue(one_is_subtype_of_two)

        two_is_subtype_of_one = key_two.is_subtype_of(key_one)
        self.assertFalse(two_is_subtype_of_one)

        # A key wrapping 1 is not expected to be a subtype of another key
        # wrapping 1 under this mock.
        one_is_subtype_of_twin = key_one.is_subtype_of(key_one_again)
        self.assertFalse(one_is_subtype_of_twin)
コード例 #3
0
    def testFunctionCacheKeyRespectsEquality(self):
        shared_ctx = function_cache.FunctionContext(0)

        def build_key(value):
            # Wrap `value` in the generic mock type and pair it with the
            # shared context.
            return function_cache.FunctionCacheKey(
                MockGenericType(value), shared_ctx)

        key_one = build_key(1)
        key_two = build_key(2)
        key_one_twin = build_key(1)

        # Keys wrapping different values must compare unequal.
        self.assertNotEqual(key_one, key_two)

        # Separately constructed keys wrapping the same value must compare
        # equal and hash identically.
        self.assertEqual(key_one, key_one_twin)
        hash_one = hash(key_one)
        hash_twin = hash(key_one_twin)
        self.assertEqual(hash_one, hash_twin)
コード例 #4
0
    def testFirstMostSpecificFunctionCacheKeyIsLookedUp(self):
        shared_ctx = function_cache.FunctionContext(0)
        cache = function_cache.FunctionCache()

        # (shape dimensions, cached value) pairs, inserted in this order.
        # NOTE(review): presumably `None` is a wildcard dimension in
        # MockShape, which is defined outside this snippet -- confirm.
        entries = (
            ((1, 2, None), "a"),
            ((1, None, 3), "b"),
        )
        for dims, value in entries:
            entry_key = function_cache.FunctionCacheKey(
                MockShape(*dims), shared_ctx)
            cache.add(entry_key, trace_type.WeakrefDeletionObserver(), value)

        # The lookup is expected to return "a", the entry added first.
        query = function_cache.FunctionCacheKey(MockShape(1, 2, 3), shared_ctx)
        found = cache.lookup(query, True)
        self.assertEqual(found, "a")
コード例 #5
0
  def testMostSpecificFunctionCacheKeyIsOrderAgnostic(self):
    ctx = function_cache.FunctionContext(0)

    def make_key(*dims):
      # Build a cache key for a MockShape with the given dimensions, using a
      # fresh empty capture snapshot and the shared context.
      return function_cache.FunctionCacheKey(
          MockShape(*dims), MockEmptyCaptureSnapshot(), ctx)

    # Cached entries, listed from the fully specified shape down to the shape
    # whose dimensions are all None.
    stored_entries = [
        (make_key(1, 1, 1), "a"),
        (make_key(1, None, 1), "b"),
        (make_key(None, None, 1), "c"),
        (make_key(None, None, None), "d"),
    ]

    # Query dimensions paired with the cached value each lookup must return.
    expected_lookups = (
        ((1, 1, 1), "a"),
        ((1, 2, 1), "b"),
        ((2, 2, 1), "c"),
        ((2, 2, 2), "d"),
    )

    # Whatever order the entries are inserted in, every lookup must return
    # the same value.
    for insertion_order in itertools.permutations(stored_entries):
      cache = function_cache.FunctionCache()
      for stored_key, stored_value in insertion_order:
        cache.add(stored_key, trace_type.WeakrefDeletionObserver(),
                  stored_value)

      for query_dims, expected_value in expected_lookups:
        self.assertEqual(
            cache.lookup(make_key(*query_dims), True), expected_value)
コード例 #6
0
def make_function_context() -> function_cache.FunctionContext:
  """Builds a FunctionContext describing the current tracing environment.

  The context bundles the parent graph (None when executing eagerly), the
  relevant device functions, the colocation stack, whether the current
  distribution strategy is in a cross-replica context, the variable policy of
  an active save context, and an id for an enclosing XLA context that requires
  unique retracing.

  Returns:
    A `function_cache.FunctionContext` wrapping an `EagerContext` built from
    the values above.
  """
  eager_ctx = context.context()

  # When already eager there is no need to enter an init_scope.
  is_eager = eager_ctx.executing_eagerly()
  outer_graph = None
  xla_id = 0
  if not is_eager:
    # Each distinct XLAControlFlowContext that asks for it must trigger its
    # own retrace, so its identity becomes part of the context.
    enclosing_xla = _enclosing_xla_context()
    needs_unique_retrace = (
        enclosing_xla is not None and
        enclosing_xla.RequiresUniqueFunctionRetracing())
    if needs_unique_retrace:
      xla_id = id(enclosing_xla)

    with ops.init_scope():
      # The graph (or the fact that execution is eager) belongs in the cache
      # key so tensors such as variables are not captured improperly.
      is_eager = eager_ctx.executing_eagerly()
      if not is_eager:
        outer_graph = ops.get_default_graph()

  # pylint: disable=protected-access
  graph = ops.get_default_graph()
  # TODO(b/117617952): The current distribution strategy will affect graph
  # building (e.g. accessing different variables from different devices) and
  # so requires retracing for each device.
  strategies = graph._distribution_strategy_stack
  retrace_per_device = (
      strategies and
      strategies[-1].strategy.extended._retrace_functions_for_each_device)

  if is_eager:
    colocations = ()
    devices = ((pydev.merge_device(eager_ctx.device_name),)
               if retrace_per_device else ())
  else:
    colocations = tuple(graph._colocation_stack.peek_objs())
    # Including the devices in the cache key makes call-site device
    # annotations take effect.
    devices_matter = (
        retrace_per_device or
        func_graph_module.device_stack_has_callable(
            graph._device_function_stack))
    devices = (tuple(graph._device_functions_outer_to_inner)
               if devices_matter else ())

  # Best effort: an empty strategy stack or a missing attribute simply means
  # "not in a cross-replica context".
  try:
    cross_replica = strategies[-1].replica_context is None  # pylint: disable=protected-access
  except (AttributeError, IndexError):
    cross_replica = False

  policy = None
  if save_context.in_save_context():
    policy = save_context.get_save_options().experimental_variable_policy

  eager_context_key = EagerContext(outer_graph, devices, colocations,
                                   cross_replica, policy, xla_id)
  return function_cache.FunctionContext(eager_context_key)