示例#1
0
      def thread_fn():
        """Worker body: enters its own SaveContext and checks it in isolation."""

        def assert_outside_save_context():
          # Outside any SaveContext there are no options to fetch.
          self.assertFalse(save_context.in_save_context())
          with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
            save_context.get_save_options()

        # The thread starts with no active SaveContext.
        assert_outside_save_context()

        thread_options = save_options.SaveOptions(save_debug_info=False)
        with save_context.save_context(thread_options):
          self.assertTrue(save_context.in_save_context())
          # save_debug_info has a different value in this thread.
          self.assertFalse(save_context.get_save_options().save_debug_info)
          # Announce that the context is entered, then block until released.
          entered_context_in_thread.set()
          continue_thread.wait()

        # Exiting the context returns this thread to the "no context" state.
        assert_outside_save_context()
示例#2
0
def is_saving_non_distributed():
    """Returns whether we're saving a non-distributed version of the model.

    It returns True iff we are in a saving context and the active SaveOptions
    do not request expanding distributed variables. That is,
    `SaveOptions.experimental_variable_policy` is any value other than
    `VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES` (for example NONE).

    Returns:
      A boolean. False whenever no SaveContext is active.
    """
    if not save_context.in_save_context():
        return False
    options = save_context.get_save_options()
    # Defensive: a missing options object is treated as "not saving a
    # non-distributed model" rather than dereferenced.
    return (options is not None and options.experimental_variable_policy !=
            save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES)
示例#3
0
  def test_multi_thread(self):
    """Checks that SaveContext state is isolated per thread.

    The main thread and a worker thread each enter their own SaveContext with
    different `save_debug_info` values. Neither may observe the other's
    context, and the worker entering or exiting its context must not change
    the main thread's state.

    Exceptions raised inside the worker are captured and re-raised in the main
    thread so they fail the test. Both threads are always released, so a
    failing assertion cannot leave the test deadlocked.
    """
    self.assertFalse(save_context.in_save_context())
    with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
      save_context.get_save_options()

    options = save_options.SaveOptions(save_debug_info=True)
    with save_context.save_context(options):
      self.assertTrue(save_context.in_save_context())
      self.assertTrue(save_context.get_save_options().save_debug_info)

      entered_context_in_thread = threading.Event()
      continue_thread = threading.Event()
      # Exceptions raised in the worker. Failures in non-main threads go to
      # threading.excepthook, not to unittest, unless propagated manually.
      thread_errors = []

      def thread_fn():
        try:
          self.assertFalse(save_context.in_save_context())
          with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
            save_context.get_save_options()

          thread_options = save_options.SaveOptions(save_debug_info=False)
          with save_context.save_context(thread_options):
            self.assertTrue(save_context.in_save_context())
            # save_debug_info has a different value in this thread.
            self.assertFalse(save_context.get_save_options().save_debug_info)
            entered_context_in_thread.set()
            continue_thread.wait()

          self.assertFalse(save_context.in_save_context())
          with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
            save_context.get_save_options()
        except Exception as e:  # pylint: disable=broad-except
          thread_errors.append(e)
        finally:
          # Unblock the main thread even if an assertion failed before the
          # context was entered; otherwise it would wait forever.
          entered_context_in_thread.set()

      t = threading.Thread(target=thread_fn)
      t.start()
      try:
        entered_context_in_thread.wait()
        # Another thread shouldn't affect this thread.
        self.assertTrue(save_context.in_save_context())
        self.assertTrue(save_context.get_save_options().save_debug_info)
      finally:
        # Always release and join the worker, even if an assertion above
        # failed, so the non-daemon thread is never left blocked on
        # `continue_thread`.
        continue_thread.set()
        t.join()

      # Surface any failure from the worker as a failure of this test.
      if thread_errors:
        raise thread_errors[0]
      # Another thread exiting SaveContext shouldn't affect this thread.
      self.assertTrue(save_context.in_save_context())
      self.assertTrue(save_context.get_save_options().save_debug_info)

    self.assertFalse(save_context.in_save_context())
    with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
      save_context.get_save_options()
示例#4
0
def _make_execution_context() -> ExecutionContext:
    """Captures the current contextual info as an ExecutionContext.

    The fields are collected in this order:
      * the default graph as seen under `ops.init_scope()`, or None if that
        scope (or the caller) is executing eagerly;
      * the device functions relevant to the call site;
      * the colocation stack;
      * whether the top of the distribution strategy stack has no replica
        context (False if that cannot be determined);
      * the active SaveOptions `experimental_variable_policy`, or None outside
        a SaveContext;
      * the id of an enclosing XLA context that requires unique function
        retracing, or 0.

    Returns:
      A new `ExecutionContext`.
    """
    eager_ctx = context.context()

    # When the caller is already eager there is no need to open an
    # init_scope, so these defaults stand.
    is_eager = eager_ctx.executing_eagerly()
    outer_graph = None
    xla_id = 0

    if not is_eager:
        # Force a retrace for each distinct XLAControlFlowContext by folding
        # its identity into the cache key.
        enclosing_xla = _enclosing_xla_context()
        if (enclosing_xla is not None
                and enclosing_xla.RequiresUniqueFunctionRetracing()):
            xla_id = id(enclosing_xla)

        # Eagerness and the graph as seen from init_scope belong in the cache
        # key so tensors such as variables aren't improperly captured.
        # `is_eager` is deliberately overwritten here: the device and
        # colocation logic below keys off the init_scope value.
        with ops.init_scope():
            is_eager = eager_ctx.executing_eagerly()
            if not is_eager:
                outer_graph = ops.get_default_graph()

    # pylint: disable=protected-access
    graph = ops.get_default_graph()
    # TODO(b/117617952): The current distribution strategy will affect graph
    # building (e.g. accessing different variables from different devices) and
    # so requires retracing for each device.
    strategies = graph._distribution_strategy_stack
    retrace_per_device = (
        strategies
        and strategies[-1].strategy.extended._retrace_functions_for_each_device)

    if is_eager:
        colocation = ()
        devices = ((pydev.merge_device(eager_ctx.device_name),)
                   if retrace_per_device else ())
    else:
        colocation = tuple(graph._colocation_stack.peek_objs())
        # Putting the device in the cache key ensures that call-site device
        # annotations are respected.
        keep_devices = (
            retrace_per_device
            or func_graph_module.device_stack_has_callable(
                graph._device_function_stack))
        devices = (tuple(graph._device_functions_outer_to_inner)
                   if keep_devices else ())

    # An empty strategy stack (IndexError) or an entry lacking the expected
    # attributes (AttributeError) is treated as "not cross-replica".
    try:
        cross_replica = strategies[-1].replica_context is None
    except (AttributeError, IndexError):
        cross_replica = False

    variable_policy = None
    if save_context.in_save_context():
        variable_policy = (
            save_context.get_save_options().experimental_variable_policy)

    return ExecutionContext(outer_graph, devices, colocation, cross_replica,
                            variable_policy, xla_id)