def thread_fn():
  """Thread body that enters its own SaveContext and checks what it observes.

  Uses `self`, `entered_context_in_thread` and `continue_thread` from the
  enclosing scope.
  """
  # Before entering a context, none is reported as active here.
  self.assertFalse(save_context.in_save_context())
  with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
    save_context.get_save_options()

  thread_options = save_options.SaveOptions(save_debug_info=False)
  with save_context.save_context(thread_options):
    self.assertTrue(save_context.in_save_context())
    # save_debug_info has a different value in this thread.
    observed_options = save_context.get_save_options()
    self.assertFalse(observed_options.save_debug_info)
    # Announce that the context is entered, then keep it open until released.
    entered_context_in_thread.set()
    continue_thread.wait()

  # After leaving the context, none is reported as active any more.
  self.assertFalse(save_context.in_save_context())
  with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
    save_context.get_save_options()
def test_multi_thread(self):
  """Checks that SaveContext state is observed separately by each thread.

  The main thread enters a SaveContext with `save_debug_info=True` while a
  helper thread enters its own with `save_debug_info=False`. Each thread must
  only observe its own context, both while the other thread is inside its
  context and after the other thread has left it.

  Failures raised in the helper thread are collected and re-raised in the main
  thread, because an exception raised inside a `threading.Thread` target does
  not otherwise fail the test. Both events are always set on the way out so
  that neither thread can block forever when the other one fails.
  """
  self.assertFalse(save_context.in_save_context())
  with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
    save_context.get_save_options()

  options = save_options.SaveOptions(save_debug_info=True)
  with save_context.save_context(options):
    self.assertTrue(save_context.in_save_context())
    self.assertTrue(save_context.get_save_options().save_debug_info)

    entered_context_in_thread = threading.Event()
    continue_thread = threading.Event()
    # Exceptions raised by the helper thread, surfaced after it is joined.
    thread_errors = []

    def thread_fn():
      try:
        # No SaveContext has been entered in this thread yet.
        self.assertFalse(save_context.in_save_context())
        with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
          save_context.get_save_options()

        thread_options = save_options.SaveOptions(save_debug_info=False)
        with save_context.save_context(thread_options):
          self.assertTrue(save_context.in_save_context())
          # save_debug_info has a different value in this thread.
          self.assertFalse(save_context.get_save_options().save_debug_info)
          entered_context_in_thread.set()
          continue_thread.wait()

        self.assertFalse(save_context.in_save_context())
        with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
          save_context.get_save_options()
      except Exception as e:  # pylint: disable=broad-except
        # Keep the failure so the main thread can re-raise it.
        thread_errors.append(e)
      finally:
        # Unblock the main thread even if this thread failed before entering
        # its context.
        entered_context_in_thread.set()

    t = threading.Thread(target=thread_fn)
    t.start()
    try:
      entered_context_in_thread.wait()
      # Another thread shouldn't affect this thread.
      self.assertTrue(save_context.in_save_context())
      self.assertTrue(save_context.get_save_options().save_debug_info)
    finally:
      # Always release and join the helper thread, even when an assertion
      # above failed, so it is never left waiting.
      continue_thread.set()
      t.join()

    if thread_errors:
      raise thread_errors[0]

    # Another thread exiting SaveContext shouldn't affect this thread.
    self.assertTrue(save_context.in_save_context())
    self.assertTrue(save_context.get_save_options().save_debug_info)

  self.assertFalse(save_context.in_save_context())
  with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
    save_context.get_save_options()
def resource_handle(self):
  """Returns the resource handle to use for the current execution mode.

  When executing eagerly or inside a save context, the coordinator instance's
  handle is returned directly. Otherwise `_maybe_build_distributed_table()` is
  called and the handle is captured on the default graph as a call-time value,
  with the coordinator instance's handle as the default value.
  """
  if context.executing_eagerly() or save_context.in_save_context():
    return self._coordinator_instance.resource_handle

  self._maybe_build_distributed_table()
  handle_closure, handle_spec = self.resource_handle_call_time_value()
  graph = ops.get_default_graph()
  return graph.capture_call_time_value(
      handle_closure,
      handle_spec,
      default_value=self._coordinator_instance.resource_handle)
def restore(self, restored_tensors, restored_shapes):
  """Assigns the first restored tensor to this object's variable handle.

  Args:
    restored_tensors: Indexable whose first element is the tensor to restore.
    restored_shapes: Either None, or an indexable whose first element is the
      shape the tensor is reshaped to before the assignment.

  Returns:
    The result of `resource_variable_ops.shape_safe_assign_variable_handle`.
  """
  value = restored_tensors[0]
  if restored_shapes is not None:
    value = array_ops.reshape(value, restored_shapes[0])

  # Copy the restored tensor to the variable's device. Inside a save context
  # an empty device string is used instead.
  if save_context.in_save_context():
    target_device = ""
  else:
    target_device = self._var_device

  with ops.device(target_device):
    value = array_ops.identity(value)
    return resource_variable_ops.shape_safe_assign_variable_handle(
        self.handle_op, self._var_shape, value)
def __init__(self, strategy, wrapped_creator):
  """Creates the coordinator-side instance and, unless saving, the table.

  Args:
    strategy: Object whose `_cluster_coordinator` attribute is stored as the
      coordinator.
    wrapped_creator: Zero-argument callable. It is called once here to create
      the coordinator instance and is also kept on the object.
  """
  # Attribute assignments are kept in their original order.
  self._coordinator_instance = wrapped_creator()
  self._wrapped_creator = wrapped_creator
  self._coordinator = strategy._cluster_coordinator

  # self._distributed_table is a RemoteValue mapping worker_index to
  # RemoteValue that wraps a resource handle on the worker
  self._distributed_table = None
  self._distributed_table_creation_lock = threading.Lock()

  # Inside a save context the distributed table is not built here.
  if save_context.in_save_context():
    return
  self._maybe_build_distributed_table()
def handle(self):
  """Returns the handle to use for this variable in the current context.

  Inside a save context or when executing eagerly, the handle of the first
  entry in `_vars` is returned. Inside a TPU context, the handles of all
  entries in `_vars` are combined with `tpu_partitioned_input` and the result
  is passed through `xla_sharding.replicate`.

  Raises:
    NotImplementedError: If called outside eager mode, outside a save context
      and with no enclosing TPU context.
  """
  if save_context.in_save_context() or context.executing_eagerly():
    return self._vars[0].handle

  if tpu_util.enclosing_tpu_context() is None:
    raise NotImplementedError('TPUReplicatedVariable.handle is not available '
                              'outside tpu context or save context')

  with tpu_util.outside_or_skip_tpu_context():
    replica_handles = [v.handle for v in self._vars]
    partitioned = tpu_partition_ops.tpu_partitioned_input(
        replica_handles, partition_dim=-1)
    return xla_sharding.replicate(partitioned)
def __init__(self, strategy, wrapped_creator):
  """Records API usage, creates the coordinator instance and maybe the table.

  Args:
    strategy: Object whose `_cluster_coordinator` attribute is stored as the
      coordinator.
    wrapped_creator: Zero-argument callable. It is called once here to create
      the coordinator instance and is also kept on the object.
  """
  # Increment the usage counter cell for this class by one.
  usage_cell = distribute_lib.distribution_strategy_input_api_counter.get_cell(
      self.__class__.__name__, "PSSDistributedLookupTable")
  usage_cell.increase_by(1)

  # Attribute assignments are kept in their original order.
  self._coordinator_instance = wrapped_creator()
  self._wrapped_creator = wrapped_creator
  self._coordinator = strategy._cluster_coordinator

  # self._distributed_table is a RemoteValue mapping worker_index to
  # RemoteValue that wraps a resource handle on the worker
  self._distributed_table = None
  self._distributed_table_creation_lock = threading.Lock()

  # Inside a save context the distributed table is not built here.
  if save_context.in_save_context():
    return
  self._maybe_build_distributed_table()
def is_saving_non_distributed():
  """Returns whether we're saving a non-distributed version of the model.

  The result is True only when a save context is active, its save options are
  not None, and `SaveOptions.experimental_variable_policy` is anything other
  than `VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES`.

  Returns:
    A boolean.
  """
  if save_context.in_save_context():
    active_options = save_context.get_save_options()
    if active_options is None:
      return False
    return (active_options.experimental_variable_policy !=
            save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES)
  return False
def mark_as_unsaveable():
  """Flags the default graph as unsaveable when tracing outside a save context.

  Nothing happens unless `ops.inside_function()` is true and no save context
  is active.
  """
  if not ops.inside_function() or save_context.in_save_context():
    return

  graph = ops.get_default_graph()
  graph.mark_as_unsaveable(""" ConcreteFunction that uses distributed variables in certain way cannot be saved. If you're saving with tf.saved_model.save(..., signatures=f.get_concrete_function()) do @tf.function(input_signature=...) def f_with_input_signature(): ... tf.saved_model.save(..., signatures=f_with_input_signature)` instead.""")
def _make_execution_context() -> ExecutionContext:
  """Generates an ExecutionContext based on current contextual info.

  Returns:
    An `ExecutionContext` constructed from, in order:
      * the parent graph: the default graph seen inside `ops.init_scope()`, or
        None when eager execution is detected;
      * the device functions tuple (possibly empty);
      * the colocation stack tuple (empty when executing eagerly);
      * whether the innermost distribution strategy stack entry has a
        `replica_context` of None (False when that lookup fails);
      * the `experimental_variable_policy` of the active save options, or
        None outside a save context;
      * the XLA context id: `id()` of the enclosing XLA context when it
        requires unique function retracing, else 0.
  """
  ctx = context.context()

  # Don't need to open an init_scope if the _cache_key call is in eager mode
  # already.
  executing_eagerly = ctx.executing_eagerly()
  parent_graph = None
  xla_context_id = 0
  if not executing_eagerly:
    # We want to force function retracing for each different
    # XLAControlFlowContext, so add `xla_context_id` to the cache key.
    xla_context = _enclosing_xla_context()
    if xla_context is not None and xla_context.RequiresUniqueFunctionRetracing(
    ):
      xla_context_id = id(xla_context)

    with ops.init_scope():
      # The graph, or whether we're executing eagerly, should be a part of the
      # cache key so we don't improperly capture tensors such as variables.
      # Note that `executing_eagerly` is re-read here, inside the init scope,
      # and this updated value drives the branches below.
      executing_eagerly = ctx.executing_eagerly()
      parent_graph = None if executing_eagerly else ops.get_default_graph()

  # pylint: disable=protected-access
  # This is the default graph outside the init scope, which is not
  # necessarily the same object as `parent_graph`.
  default_graph = ops.get_default_graph()
  # TODO(b/117617952): The current distribution strategy will affect graph
  # building (e.g. accessing different variables from different devices) and
  # so requires retracing for each device.
  strategy_stack = default_graph._distribution_strategy_stack
  # Only used for its truthiness: falsy when the stack is empty, otherwise the
  # innermost strategy's `_retrace_functions_for_each_device` value.
  uses_distribution_strategy = (
      strategy_stack and
      strategy_stack[-1].strategy.extended._retrace_functions_for_each_device)
  if executing_eagerly:
    colocation_stack = ()
    if uses_distribution_strategy:
      device_functions = (pydev.merge_device(ctx.device_name),)
    else:
      device_functions = ()
  else:
    colocation_stack = tuple(default_graph._colocation_stack.peek_objs())
    if (uses_distribution_strategy or
        func_graph_module.device_stack_has_callable(
            default_graph._device_function_stack)):
      # Putting the device in the cache key ensures that call-site device
      # annotations are respected.
      device_functions = tuple(default_graph._device_functions_outer_to_inner)
    else:
      device_functions = ()

  # Best effort: if the stack is empty (IndexError) or the lookup raises
  # AttributeError, the flag keeps its default of False.
  in_cross_replica_context = False
  try:
    in_cross_replica_context = (strategy_stack[-1].replica_context is None)  # pylint: disable=protected-access
  except (AttributeError, IndexError):
    pass

  # The variable policy is only read while a save context is active.
  if save_context.in_save_context():
    variable_policy = (
        save_context.get_save_options().experimental_variable_policy)
  else:
    variable_policy = None

  return ExecutionContext(parent_graph, device_functions, colocation_stack,
                          in_cross_replica_context, variable_policy,
                          xla_context_id)
def variables(self):
  """Returns the `Variable`s that make up the shards of this object.

  Inside a save context, a single-element list holding `_saving_variable` is
  returned instead of the stored `_variables`.
  """
  if not save_context.in_save_context():
    return self._variables
  return [self._saving_variable]
def variables(self):
  """Returns the list of `Variables`.

  Inside a save context only the first entry of `_vars` is returned, wrapped
  in a new list; otherwise the stored `_vars` object itself is returned.
  """
  saving = save_context.in_save_context()
  return [self._vars[0]] if saving else self._vars