def tensorflow_random_state(seed: int) -> Generator[None, None, None]:
    """Temporarily seeds Python, NumPy and TensorFlow RNGs and enables op determinism.

    Generator body intended to be driven as a context manager (presumably
    wrapped with `contextlib.contextmanager` at the definition site — TODO
    confirm). On entry it:
      * enables TensorFlow op determinism (`config.enable_op_determinism()`),
      * sets the `TF_DETERMINISTIC_OPS` environment variable to "1",
      * seeds `random`, `np.random` and `tf.random` with `seed`.
    On exit — including when the managed body raises — every one of these is
    restored to the value captured before entry.

    Args:
      seed: integer seed applied to all three RNGs.

    Yields:
      None, once, while the seeded/deterministic state is active.
    """
    # Snapshot everything this context is about to mutate.
    orig_gpu_det = os.environ.get("TF_DETERMINISTIC_OPS", None)
    orig_random_state = random.getstate()
    orig_np_random_state = np.random.get_state()
    if context.executing_eagerly():
        tf_random_seed = context.global_seed()
    else:
        tf_random_seed = ops.get_default_graph().seed
    determinism_enabled = config.is_op_determinism_enabled()

    # try/finally guarantees restoration even if the managed body raises;
    # without it an exception would leak the seeded state into later code.
    try:
        config.enable_op_determinism()
        os.environ["TF_DETERMINISTIC_OPS"] = "1"
        random.seed(seed)
        np.random.seed(seed)
        tf.random.set_seed(seed)
        yield
    finally:
        if orig_gpu_det is not None:
            os.environ["TF_DETERMINISTIC_OPS"] = orig_gpu_det
        else:
            # Default avoids KeyError if the variable was removed meanwhile.
            os.environ.pop("TF_DETERMINISTIC_OPS", None)
        random.setstate(orig_random_state)
        np.random.set_state(orig_np_random_state)
        tf.random.set_seed(tf_random_seed)
        # Only undo determinism if it was off before we turned it on.
        if not determinism_enabled:
            config.disable_op_determinism()
def __init__(self, name, collections=None):
    """Builds an empty FuncGraph that inherits state from the enclosing graph.

    The seed and the XLA-compilation flag come from the eager context when
    executing eagerly, otherwise from the current default graph (which also
    donates a copy of its colocation stack).

    Args:
      name: the name of the function.
      collections: optional dictionary of collections to start from. When
        None, collections of the outer graph that are not listed in
        `WHITELIST_COLLECTIONS` are fetched with `get_collection` (readable,
        but not written back to the outer graph), while the whitelisted ones
        (global, local and trainable variables) are fetched with
        `get_collection_ref` so reads and writes go to the outer graph.
    """
    super(FuncGraph, self).__init__()

    self.name = name
    self.inputs = []
    self.outputs = []
    self.structured_input_signature = None
    self.structured_outputs = None
    self._weak_variables = []
    self.outer_graph = ops.get_default_graph()
    self.captures = py_collections.OrderedDict()
    self._building_function = True
    # Resource tensor name -> most recent op (in program order) that used it.
    # Consulted to keep execution order in line with program order for
    # resource tensors.
    self._last_op_using_resource_tensor = {}

    outer = self.outer_graph
    # pylint: disable=protected-access
    if not context.executing_eagerly():
        self.seed = outer.seed
        self._xla_compile = getattr(outer, "_xla_compile", False)
        # TODO(allenl): Figure out if we can remove colocation stack
        # specialization (currently used in cond_v2), here and in the cache key.
        self._colocation_stack = outer._colocation_stack.copy()
    else:
        self.seed = context.global_seed()
        # Eager TPU / XLA devices imply XLA compilation of the function body.
        self._xla_compile = context.context().device_spec.device_type in (
            "TPU", "XLA_GPU", "XLA_CPU")

    if collections is not None:
        self._collections = collections
        return

    # Non-whitelisted collections: read via get_collection (not written back).
    for key in outer.get_all_collection_keys():
        if key not in WHITELIST_COLLECTIONS:
            self._collections[key] = outer.get_collection(key)
    # Whitelisted collections: the outer graph's own lists, shared by reference.
    for key in WHITELIST_COLLECTIONS:
        self._collections[key] = outer.get_collection_ref(key)
def get_seed(op_seed):
    """Computes the (graph-level, op-level) seed pair a random op should use.

    The graph-level seed is taken from the eager context when executing
    eagerly, otherwise from the default graph. It is combined with the
    operation-specific `op_seed`. For how the two levels interact, see
    `tf.compat.v1.random.set_random_seed`.

    Args:
      op_seed: integer, or None to derive one automatically when a
        graph-level seed is set.

    Returns:
      A tuple of two integers to use as this operation's local seeds, or
      (None, None) when neither a graph-level seed nor `op_seed` is set.
      The pair (0, 0) is never returned; it is remapped to (0, _MAXINT32).

    Raises:
      RuntimeError: if the result would be (None, None) while
        `config.deterministic_ops_enabled()` is true.
    """
    eager = context.executing_eagerly()
    global_seed = (context.global_seed() if eager
                   else ops.get_default_graph().seed)

    if global_seed is None:
        if op_seed is None:
            seeds = None, None
        else:
            seeds = DEFAULT_GRAPH_SEED, _truncate_seed(op_seed)
    else:
        if op_seed is None:
            default_graph = ops.get_default_graph()
            # pylint: disable=protected-access
            # Graphs that track it get flagged that an op fell back to the
            # graph-level seed.
            if hasattr(default_graph, '_seed_used'):
                default_graph._seed_used = True
            if eager:
                op_seed = context.internal_operation_seed()
            else:
                # Per-graph counter: hand out the current value, then bump it.
                op_seed = _graph_to_seed_dict.setdefault(default_graph, 0)
                _graph_to_seed_dict[default_graph] += 1
        seeds = _truncate_seed(global_seed), _truncate_seed(op_seed)

    if seeds == (None, None) and config.deterministic_ops_enabled():
        raise RuntimeError(  # pylint: disable=g-doc-exception
            'Random ops require a seed to be set when determinism is enabled. '
            'Please set a seed before running the op, e.g. by calling '
            'tf.random.set_seed(1).')

    # The C++ ops treat (0, 0) as "nondeterministic", but the Python contract
    # reserves that meaning for (None, None), so shift it to a distinct pair.
    return (0, _MAXINT32) if seeds == (0, 0) else seeds
def get_seed(op_seed):
    """Returns the local seeds an operation should use given an op-specific seed.

    Given operation-specific seed, `op_seed`, this helper function returns two
    seeds derived from graph-level and op-level seeds. Many random operations
    internally use the two seeds to allow user to change the seed globally for a
    graph, or for only specific operations.

    For details on how the graph-level seed interacts with op seeds, see
    `tf.compat.v1.random.set_random_seed`.

    Args:
      op_seed: integer, or None to derive one automatically when a
        graph-level seed is set.

    Returns:
      A tuple of two integers that should be used for the local seed of this
      operation, or (None, None) when neither a graph-level seed nor `op_seed`
      is set. The pair (0, 0) is remapped to (0, _MAXINT32).
    """
    eager = context.executing_eagerly()
    if eager:
        global_seed = context.global_seed()
    else:
        global_seed = ops.get_default_graph().seed

    if global_seed is not None:
        if op_seed is None:
            graph = ops.get_default_graph()
            # pylint: disable=protected-access
            # Graphs that track it get flagged that an op fell back to the
            # graph-level seed.
            if hasattr(graph, '_seed_used'):
                graph._seed_used = True
            if eager:
                op_seed = context.internal_operation_seed()
            else:
                # Derive the op seed from the graph's `_last_id` counter.
                op_seed = graph._last_id
        seeds = _truncate_seed(global_seed), _truncate_seed(op_seed)
    elif op_seed is not None:
        seeds = DEFAULT_GRAPH_SEED, _truncate_seed(op_seed)
    else:
        seeds = None, None

    # Avoid (0, 0) as the C++ ops interpret it as nondeterminism, which would
    # be unexpected since Python docs say nondeterminism is (None, None).
    if seeds == (0, 0):
        return (0, _MAXINT32)
    return seeds
def get_seed(op_seed):
    """Computes the (graph-level, op-level) seed pair a random op should use.

    The graph-level seed comes from the eager context when executing eagerly,
    otherwise from the default graph. It is paired with the
    operation-specific `op_seed` so callers can control randomness either
    globally or per operation. For how the two levels interact, see
    `tf.compat.v1.random.set_random_seed`.

    Args:
      op_seed: integer, or None to derive one automatically when a
        graph-level seed is set.

    Returns:
      A tuple of two integers to use as this operation's local seeds, or
      (None, None) when neither a graph-level seed nor `op_seed` is set.
      The pair (0, 0) is never returned; it becomes (0, _MAXINT32).
    """
    eager = context.executing_eagerly()
    global_seed = (context.global_seed() if eager
                   else ops.get_default_graph().seed)

    if global_seed is None:
        seeds = ((None, None) if op_seed is None
                 else (DEFAULT_GRAPH_SEED, _truncate_seed(op_seed)))
    else:
        if op_seed is None:
            default_graph = ops.get_default_graph()
            # pylint: disable=protected-access
            # Graphs that track it get flagged that an op fell back to the
            # graph-level seed.
            if hasattr(default_graph, '_seed_used'):
                default_graph._seed_used = True
            op_seed = (context.internal_operation_seed() if eager
                       else default_graph._last_id)
        seeds = (_truncate_seed(global_seed), _truncate_seed(op_seed))

    # C++ ops read (0, 0) as "nondeterministic"; Python reserves (None, None)
    # for that, so remap the pair to keep the two conventions apart.
    return (0, _MAXINT32) if seeds == (0, 0) else seeds
def SecureGetRandomSeed(op_seed=None):
    """Returns the local seeds an operation should use given an op-specific seed.

    Follows the same graph-level / op-level seed derivation as `get_seed`,
    using helpers from the `random_seed` module. There is one addition: when
    neither a graph-level seed nor `op_seed` is set and the secure protocol
    handler is activated, both seeds are drawn from
    `py_protocol_handler.rand_seed(0)`.

    Args:
      op_seed: integer, or None to derive one automatically.

    Returns:
      A tuple of two integers that should be used for the local seed of this
      operation. The pair (0, 0) is remapped to (0, random_seed._MAXINT32).
    """
    eager = context.executing_eagerly()
    if eager:
        global_seed = context.global_seed()
    else:
        global_seed = ops.get_default_graph().seed

    # NOTE(review): `seeds` defaults to (0, 0). When no graph-level seed and
    # no op seed are set and the protocol handler is NOT activated, this
    # function therefore returns the fixed pair (0, random_seed._MAXINT32).
    # That is deterministic, whereas `get_seed` returns (None, None) in the
    # same situation. Confirm this difference is intended.
    seeds = (0, 0)
    if global_seed is not None:
        if op_seed is None:
            graph = ops.get_default_graph()
            # pylint: disable=protected-access
            # Graphs that track it get flagged that an op fell back to the
            # graph-level seed.
            if hasattr(graph, '_seed_used'):
                graph._seed_used = True
            if eager:
                op_seed = context.internal_operation_seed()
            else:
                op_seed = graph._last_id
        seeds = (random_seed._truncate_seed(global_seed),
                 random_seed._truncate_seed(op_seed))
    elif op_seed is not None:
        seeds = (random_seed.DEFAULT_GRAPH_SEED,
                 random_seed._truncate_seed(op_seed))
    elif py_protocol_handler.is_activated():
        seeds = (py_protocol_handler.rand_seed(0),
                 py_protocol_handler.rand_seed(0))

    # Avoid (0, 0) as the C++ ops interpret it as nondeterminism, which would
    # be unexpected since Python docs say nondeterminism is (None, None).
    if seeds == (0, 0):
        return (0, random_seed._MAXINT32)
    return seeds
def __init__(self, name, collections=None, capture_by_value=None):
    """Builds an empty FuncGraph that inherits state from the enclosing graph.

    The seed comes from the eager context when executing eagerly, otherwise
    from the current default graph (which also donates a copy of its
    colocation stack).

    Args:
      name: the name of the function.
      collections: optional dictionary of collections to start from. When
        None, collections of the outer graph that are not listed in
        `WHITELIST_COLLECTIONS` are fetched with `get_collection` (readable,
        but not written back to the outer graph), while the whitelisted ones
        (global, local and trainable variables) are fetched with
        `get_collection_ref` so reads and writes go to the outer graph.
      capture_by_value: optional boolean. If True, Variables are captured by
        value instead of by reference. When None, the value is inherited
        from the outer graph if that is a FuncGraph, and is False otherwise.
    """
    super(FuncGraph, self).__init__()

    self.name = name
    self.inputs = []
    self.outputs = []
    self.control_outputs = []
    self.control_captures = set()
    self.structured_input_signature = None
    self.structured_outputs = None
    self._weak_variables = []
    self._watched_variables = weakref.WeakSet()
    self.outer_graph = ops.get_default_graph()
    self.captures = py_collections.OrderedDict()

    outer = self.outer_graph
    # An explicit argument wins; otherwise inherit from an enclosing
    # FuncGraph, and fall back to False for any other kind of outer graph.
    if capture_by_value is None:
        capture_by_value = (outer.capture_by_value
                            if isinstance(outer, FuncGraph) else False)
    self.capture_by_value = capture_by_value

    self._building_function = True
    # Resource tensor name -> most recent op (in program order) that used it.
    # Consulted to keep execution order in line with program order for
    # resource tensors.
    self._last_op_using_resource_tensor = {}

    # pylint: disable=protected-access
    eager = context.executing_eagerly()
    self.seed = context.global_seed() if eager else outer.seed
    # [for tf-data user migration from TF1.0 to 2.0] Tracks whether any
    # random op in the function had a None op_seed and therefore ended up
    # using the function seed, which may be unintended for that op.
    self._seed_used = False
    if not eager:
        # TODO(allenl): Figure out if we can remove colocation stack
        # specialization (currently used in cond_v2), here and in the cache key.
        self._colocation_stack = outer._colocation_stack.copy()

    if collections is not None:
        self._collections = collections
        return

    # Non-whitelisted collections: read via get_collection (not written back).
    for key in outer.get_all_collection_keys():
        if key not in WHITELIST_COLLECTIONS:
            self._collections[key] = outer.get_collection(key)
    # Whitelisted collections: the outer graph's own lists, shared by reference.
    for key in WHITELIST_COLLECTIONS:
        self._collections[key] = outer.get_collection_ref(key)
def __init__(self, name, read_only_collections=True):
    """Construct a new FuncGraph.

    The graph inherits its seed, XLA-compilation flag, distribution strategy
    stack, collections, variable creator stack and graph key from the current
    eager context or default graph.

    Args:
      name: the name of the function.
      read_only_collections: whether to not write function graph collections
        back to default graph. Defaults to True.
    """
    super(FuncGraph, self).__init__()

    self.name = name
    self.inputs = []
    self.outputs = []
    self.structured_outputs = None
    self._read_only_collections = read_only_collections
    self._weak_variables = []
    self.outer_graph = ops.get_default_graph()
    self.captures = collections.OrderedDict()

    self._building_function = True
    # Map from resource tensor name to last op (in program order) which uses
    # this tensor. Used to enforce that execution order matches program order
    # for resource tensors.
    self._last_op_using_resource_tensor = {}

    graph = self.outer_graph

    # pylint: disable=protected-access
    # TODO(b/112906995, nareshmodi): distribution strategy depends on inheriting
    # this stack from the default graph even in eager mode. Maybe it should be
    # part of the eager context? This would also allow us to remove a
    # get_default_graph() call from the function cache lookup.
    # Copy the list instead of aliasing it. With a shared list object, any
    # push/pop on this graph's stack would also mutate the outer graph's
    # stack, and vice versa.
    self._distribution_strategy_stack = list(graph._distribution_strategy_stack)

    # We ignore device placements from any outer scopes while tracing the
    # function when possible, to avoid hard-coding them in the function
    # graph. "Default" placements come from the PartitionedCallOp's placement,
    # so that the same trace of the Python function may be placed on several
    # different devices and saved functions may be placed on new devices when
    # restored.
    if context.executing_eagerly():
        self.seed = context.global_seed()
        self._xla_compile = (context.context().device_spec.device_type == "TPU")
        if self._distribution_strategy_stack or self._xla_compile:
            self._add_device_to_stack(context.context().device_name)
    else:
        self.seed = graph.seed
        self._xla_compile = getattr(graph, "_xla_compile", False)
        # TODO(allenl): Figure out if we can remove colocation stack
        # specialization (currently used in cond_v2), here and in the cache key.
        self._colocation_stack = graph._colocation_stack.copy()
        if (self._distribution_strategy_stack or self._xla_compile or
                device_stack_has_callable(graph._device_function_stack)):
            # Hard-code devices from device functions in the function body
            self._device_function_stack = graph._device_function_stack.copy()

    if not self._read_only_collections:
        # Share the outer graph's collections dict so writes propagate back.
        self._collections = graph._collections
    else:
        for collection_name in graph.get_all_collection_keys():
            if collection_name not in WHITELIST_COLLECTIONS:
                self._collections[collection_name] = graph.get_collection(
                    collection_name)
        for collection_name in WHITELIST_COLLECTIONS:
            self._collections[collection_name] = graph.get_collection_ref(
                collection_name)

    self._variable_creator_stack = graph._variable_creator_stack
    # Inherit the graph key, since this is used for matching variables in
    # optimizers.
    self._graph_key = graph._graph_key
def __init__(self, name, read_only_collections=True):
    """Creates a FuncGraph seeded with state from the enclosing graph/context.

    Inherited from the outer (default) graph: a private copy of the
    distribution strategy stack, collections, the variable creator stack and
    the graph key. The seed and XLA flag come from the eager context when
    executing eagerly, otherwise from the outer graph.

    Args:
      name: the name of the function.
      read_only_collections: when True (the default), collections are
        populated per key and function-graph writes are not pushed back into
        the default graph. When False, the outer graph's collections dict is
        shared directly.
    """
    super(FuncGraph, self).__init__()

    self.name = name
    self.inputs = []
    self.outputs = []
    self.structured_outputs = None
    self._read_only_collections = read_only_collections
    self._weak_variables = []
    self.outer_graph = ops.get_default_graph()
    self.captures = collections.OrderedDict()
    self._building_function = True
    # Resource tensor name -> most recent op (in program order) that used it.
    # Consulted to keep execution order in line with program order for
    # resource tensors.
    self._last_op_using_resource_tensor = {}

    outer = self.outer_graph
    # pylint: disable=protected-access
    # TODO(b/112906995, nareshmodi): distribution strategy depends on inheriting
    # this stack from the default graph even in eager mode. Maybe it should be
    # part of the eager context? This would also allow us to remove a
    # get_default_graph() call from the function cache lookup.
    self._distribution_strategy_stack = list(outer._distribution_strategy_stack)

    # Outer-scope device placements are ignored while tracing where possible,
    # so one trace can be placed on several devices (placement then comes from
    # the PartitionedCallOp) and restored functions can land on new devices.
    if not context.executing_eagerly():
        self.seed = outer.seed
        self._xla_compile = getattr(outer, "_xla_compile", False)
        # TODO(allenl): Figure out if we can remove colocation stack
        # specialization (currently used in cond_v2), here and in the cache key.
        self._colocation_stack = outer._colocation_stack.copy()
        # Devices are hard-coded into the body only when something forces it.
        must_pin_devices = (
            self._distribution_strategy_stack or self._xla_compile or
            device_stack_has_callable(outer._device_function_stack))
        if must_pin_devices:
            self._device_function_stack = outer._device_function_stack.copy()
    else:
        self.seed = context.global_seed()
        eager_ctx = context.context()
        self._xla_compile = eager_ctx.device_spec.device_type == "TPU"
        if self._distribution_strategy_stack or self._xla_compile:
            self._add_device_to_stack(eager_ctx.device_name)

    if self._read_only_collections:
        for key in outer.get_all_collection_keys():
            if key not in WHITELIST_COLLECTIONS:
                self._collections[key] = outer.get_collection(key)
        for key in WHITELIST_COLLECTIONS:
            self._collections[key] = outer.get_collection_ref(key)
    else:
        self._collections = outer._collections

    self._variable_creator_stack = outer._variable_creator_stack
    # The graph key is inherited because optimizers match variables by it.
    self._graph_key = outer._graph_key