Пример #1
0
      def thread_fn():
        """Enters a save context with its own options and checks its state."""
        # Before entering: no context is reported and options are unavailable.
        self.assertFalse(save_context.in_save_context())
        with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
          save_context.get_save_options()

        thread_options = save_options.SaveOptions(save_debug_info=False)
        with save_context.save_context(thread_options):
          self.assertTrue(save_context.in_save_context())
          # save_debug_info has a different value in this thread.
          active_options = save_context.get_save_options()
          self.assertFalse(active_options.save_debug_info)
          # Signal that the context is entered, then block until released.
          entered_context_in_thread.set()
          continue_thread.wait()

        # After leaving: same checks as before entering.
        self.assertFalse(save_context.in_save_context())
        with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
          save_context.get_save_options()
Пример #2
0
  def test_multi_thread(self):
    """Checks that save-context state in one thread is not seen by another.

    The main thread enters a save context with `save_debug_info=True`; a worker
    thread then enters its own with `save_debug_info=False`. Each thread must
    only observe its own context, both while the other thread is inside a
    context and after the other thread has left it.

    Failures raised in the worker thread are collected and re-raised on the
    main thread, and both events are always set, so a failing assertion fails
    the test instead of being lost or leaving a thread blocked.
    """
    self.assertFalse(save_context.in_save_context())
    with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
      save_context.get_save_options()

    options = save_options.SaveOptions(save_debug_info=True)
    with save_context.save_context(options):
      self.assertTrue(save_context.in_save_context())
      self.assertTrue(save_context.get_save_options().save_debug_info)

      entered_context_in_thread = threading.Event()
      continue_thread = threading.Event()
      # Exceptions raised inside the worker. An exception escaping a
      # `threading.Thread` target does not fail the test by itself, so they
      # are recorded here and re-raised on the main thread.
      thread_errors = []

      def thread_fn():
        try:
          self.assertFalse(save_context.in_save_context())
          with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
            save_context.get_save_options()

          options = save_options.SaveOptions(save_debug_info=False)
          with save_context.save_context(options):
            self.assertTrue(save_context.in_save_context())
            # save_debug_info has a different value in this thread.
            self.assertFalse(save_context.get_save_options().save_debug_info)
            entered_context_in_thread.set()
            continue_thread.wait()

          self.assertFalse(save_context.in_save_context())
          with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
            save_context.get_save_options()
        except Exception as e:  # pylint: disable=broad-except
          thread_errors.append(e)
        finally:
          # Always release the main thread, even if an assertion failed before
          # the context was entered; otherwise it would wait forever.
          entered_context_in_thread.set()

      t = threading.Thread(target=thread_fn)
      t.start()

      try:
        entered_context_in_thread.wait()
        # Another thread shouldn't affect this thread.
        self.assertTrue(save_context.in_save_context())
        self.assertTrue(save_context.get_save_options().save_debug_info)
      finally:
        # Always let the worker finish, so a failure above cannot leave it
        # blocked on `continue_thread`.
        continue_thread.set()
        t.join()

      if thread_errors:
        raise thread_errors[0]

      # Another thread exiting SaveContext shouldn't affect this thread.
      self.assertTrue(save_context.in_save_context())
      self.assertTrue(save_context.get_save_options().save_debug_info)

    self.assertFalse(save_context.in_save_context())
    with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
      save_context.get_save_options()
Пример #3
0
 def resource_handle(self):
   """Returns the resource handle to use in the current context.

   When executing eagerly or inside a save context, this is the coordinator
   instance's `resource_handle`. Otherwise `_maybe_build_distributed_table()`
   is called and a call-time value is captured on the default graph, with the
   coordinator instance's handle passed as `default_value`.
   """
   use_coordinator_handle = (
       context.executing_eagerly() or save_context.in_save_context())
   if use_coordinator_handle:
     return self._coordinator_instance.resource_handle

   self._maybe_build_distributed_table()
   closure, spec = self.resource_handle_call_time_value()
   graph = ops.get_default_graph()
   return graph.capture_call_time_value(
       closure,
       spec,
       default_value=self._coordinator_instance.resource_handle)
Пример #4
0
 def restore(self, restored_tensors, restored_shapes):
     """Assigns the first restored tensor to this object's variable handle.

     Args:
       restored_tensors: sequence whose first element is the tensor to assign.
       restored_shapes: if not None, its first element is the shape the
         restored tensor is reshaped to before assignment.

     Returns:
       The result of `shape_safe_assign_variable_handle` for the handle op.
     """
     value = restored_tensors[0]
     if restored_shapes is not None:
         value = array_ops.reshape(value, restored_shapes[0])

     # Inside a save context an empty device string is used; otherwise the
     # tensor is copied to the variable's device.
     if save_context.in_save_context():
         target_device = ""
     else:
         target_device = self._var_device

     with ops.device(target_device):
         value = array_ops.identity(value)
         return resource_variable_ops.shape_safe_assign_variable_handle(
             self.handle_op, self._var_shape, value)
Пример #5
0
  def __init__(self, strategy, wrapped_creator):
    """Creates the coordinator instance and sets up distributed-table state.

    Args:
      strategy: object whose `_cluster_coordinator` attribute is stored as
        this object's coordinator.
      wrapped_creator: zero-argument callable. It is called once here to
        create the coordinator instance and is also stored on the instance.
    """
    self._coordinator_instance = wrapped_creator()
    self._wrapped_creator = wrapped_creator
    self._coordinator = strategy._cluster_coordinator
    # self._distributed_table is a RemoteValue mapping worker_index to
    # RemoteValue that wraps a resource handle on the worker
    self._distributed_table = None
    self._distributed_table_creation_lock = threading.Lock()

    # The distributed table is not built when constructed in a save context.
    if save_context.in_save_context():
      return
    self._maybe_build_distributed_table()
  def handle(self):
    """Returns a handle for this variable suited to the current context.

    In a save context or when executing eagerly, this is the handle of the
    first entry in `self._vars`. Inside a TPU context, the handles of all
    entries are combined with `tpu_partitioned_input` (`partition_dim=-1`) and
    the result is passed through `xla_sharding.replicate`.

    Raises:
      NotImplementedError: if not in a save context, not executing eagerly,
        and there is no enclosing TPU context.
    """
    if save_context.in_save_context() or context.executing_eagerly():
      return self._vars[0].handle

    if tpu_util.enclosing_tpu_context() is None:
      raise NotImplementedError('TPUReplicatedVariable.handle is not available '
                                'outside tpu context or save context')

    with tpu_util.outside_or_skip_tpu_context():
      per_var_handles = [v.handle for v in self._vars]
      partitioned = tpu_partition_ops.tpu_partitioned_input(
          per_var_handles, partition_dim=-1)
      return xla_sharding.replicate(partitioned)
Пример #7
0
    def __init__(self, strategy, wrapped_creator):
        """Records API usage, creates the coordinator instance, sets up state.

        Args:
          strategy: object whose `_cluster_coordinator` attribute is stored as
            this object's coordinator.
          wrapped_creator: zero-argument callable. It is called once here to
            create the coordinator instance and is also stored on the
            instance.
        """
        # Bump the usage counter cell keyed by this class's name.
        api_counter = distribute_lib.distribution_strategy_input_api_counter
        usage_cell = api_counter.get_cell(self.__class__.__name__,
                                          "PSSDistributedLookupTable")
        usage_cell.increase_by(1)

        self._coordinator_instance = wrapped_creator()
        self._wrapped_creator = wrapped_creator
        self._coordinator = strategy._cluster_coordinator
        # self._distributed_table is a RemoteValue mapping worker_index to
        # RemoteValue that wraps a resource handle on the worker
        self._distributed_table = None
        self._distributed_table_creation_lock = threading.Lock()

        # The distributed table is not built when constructed in a save
        # context.
        constructed_while_saving = save_context.in_save_context()
        if not constructed_while_saving:
            self._maybe_build_distributed_table()
Пример #8
0
def is_saving_non_distributed():
    """Returns whether we're saving a non-distributed version of the model.

    The result is True only when called inside a save context whose options
    are not None and whose `experimental_variable_policy` is anything other
    than `VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES`.

    Returns:
      A boolean.
    """
    if save_context.in_save_context():
        current_options = save_context.get_save_options()
        if current_options is None:
            return False
        return (current_options.experimental_variable_policy !=
                save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES)
    return False
Пример #9
0
def mark_as_unsaveable():
    """Marks the default graph as unsaveable when tracing outside a save context.

    Does nothing unless `ops.inside_function()` is true and no save context is
    active; in that case the default graph is marked unsaveable with a message
    explaining how to save via a `tf.function` with an input signature.
    """
    if not ops.inside_function() or save_context.in_save_context():
        return

    unsaveable_reason = """
ConcreteFunction that uses distributed variables in certain way cannot be saved.
If you're saving with

tf.saved_model.save(..., signatures=f.get_concrete_function())

do

@tf.function(input_signature=...)
def f_with_input_signature():
  ...

tf.saved_model.save(..., signatures=f_with_input_signature)`

instead."""
    ops.get_default_graph().mark_as_unsaveable(unsaveable_reason)
Пример #10
0
def _make_execution_context() -> ExecutionContext:
    """Generates an ExecutionContext based on current contextual info.

    Collects the ambient state that is passed to `ExecutionContext`: the
    parent graph, the device functions, the colocation stack, whether the
    innermost distribution strategy is in cross-replica context, the variable
    policy of the active save context (if any), and an id for the enclosing
    XLA context.

    Returns:
      An `ExecutionContext` built from the values described above.
    """
    ctx = context.context()

    # Don't need to open an init_scope if the _cache_key call is in eager mode
    # already.
    executing_eagerly = ctx.executing_eagerly()
    # Defaults; only overwritten below when not executing eagerly.
    parent_graph = None
    xla_context_id = 0
    if not executing_eagerly:
        # We want to force function retracing for each different
        # XLAControlFlowContext, so add `xla_context_id` to the cache key.
        xla_context = _enclosing_xla_context()
        if xla_context is not None and xla_context.RequiresUniqueFunctionRetracing(
        ):
            xla_context_id = id(xla_context)

        with ops.init_scope():
            # The graph, or whether we're executing eagerly, should be a part of the
            # cache key so we don't improperly capture tensors such as variables.
            # NOTE: `executing_eagerly` is re-read inside the init_scope, so the
            # eager/graph branch further down uses this value, not the one
            # read at the top of the function.
            executing_eagerly = ctx.executing_eagerly()
            parent_graph = None if executing_eagerly else ops.get_default_graph(
            )

    # pylint: disable=protected-access
    default_graph = ops.get_default_graph()
    # TODO(b/117617952): The current distribution strategy will affect graph
    # building (e.g. accessing different variables from different devices) and
    # so requires retracing for each device.
    strategy_stack = default_graph._distribution_strategy_stack
    # Falsy when the strategy stack is empty; otherwise the flag read from the
    # innermost (last) entry of the stack.
    uses_distribution_strategy = (strategy_stack
                                  and strategy_stack[-1].strategy.extended.
                                  _retrace_functions_for_each_device)
    if executing_eagerly:
        colocation_stack = ()
        if uses_distribution_strategy:
            device_functions = (pydev.merge_device(ctx.device_name), )
        else:
            device_functions = ()
    else:
        colocation_stack = tuple(default_graph._colocation_stack.peek_objs())
        if (uses_distribution_strategy
                or func_graph_module.device_stack_has_callable(
                    default_graph._device_function_stack)):
            # Putting the device in the cache key ensures that call-site device
            # annotations are respected.
            device_functions = tuple(
                default_graph._device_functions_outer_to_inner)
        else:
            device_functions = ()

    # Stays False when the lookup below raises, e.g. IndexError for an empty
    # strategy stack or AttributeError for a missing `replica_context`.
    in_cross_replica_context = False
    try:
        in_cross_replica_context = (strategy_stack[-1].replica_context is None)  # pylint: disable=protected-access
    except (AttributeError, IndexError):
        pass

    # The variable policy is only taken from the save options while a save
    # context is active; otherwise it is None.
    if save_context.in_save_context():
        variable_policy = (
            save_context.get_save_options().experimental_variable_policy)
    else:
        variable_policy = None

    return ExecutionContext(parent_graph, device_functions, colocation_stack,
                            in_cross_replica_context, variable_policy,
                            xla_context_id)
Пример #11
0
 def variables(self):
     """The list of `Variable`s that make up the shards of this object.

     Inside a save context this is a single-element list holding
     `self._saving_variable`; otherwise it is `self._variables`.
     """
     saving = save_context.in_save_context()
     return [self._saving_variable] if saving else self._variables
Пример #12
0
 def variables(self):
     """The list of `Variables`.

     Outside a save context this is `self._vars` itself; inside one it is a
     new single-element list holding only the first entry of `self._vars`.
     """
     if not save_context.in_save_context():
         return self._vars
     first_var = self._vars[0]
     return [first_var]