Example no. 1
import os
import random
from contextlib import contextmanager
from typing import Generator

import numpy as np
import tensorflow as tf
# Internal TF modules used below; exact paths may differ across TF versions.
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import ops


@contextmanager
def tensorflow_random_state(seed: int) -> Generator[None, None, None]:
    # Save values
    origin_gpu_det = os.environ.get("TF_DETERMINISTIC_OPS", None)
    orig_random_state = random.getstate()
    orig_np_random_state = np.random.get_state()
    if context.executing_eagerly():
        tf_random_seed = context.global_seed()
    else:
        tf_random_seed = ops.get_default_graph().seed

    determinism_enabled = config.is_op_determinism_enabled()
    config.enable_op_determinism()

    # Set values
    os.environ["TF_DETERMINISTIC_OPS"] = "1"
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)

    yield

    # Reset values
    if origin_gpu_det is not None:
        os.environ["TF_DETERMINISTIC_OPS"] = origin_gpu_det
    else:
        os.environ.pop("TF_DETERMINISTIC_OPS")
    random.setstate(orig_random_state)
    np.random.set_state(orig_np_random_state)
    tf.random.set_seed(tf_random_seed)
    if not determinism_enabled:
        config.disable_op_determinism()
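
A minimal usage sketch (not from the source; assumes the imports and @contextmanager decorator added above). Two blocks entered with the same seed start from identical RNG state and therefore reproduce the same draws:

with tensorflow_random_state(seed=42):
    a = tf.random.uniform([3])
with tensorflow_random_state(seed=42):
    b = tf.random.uniform([3])
tf.debugging.assert_equal(a, b)  # identical seeds -> identical draws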
Example no. 2
    def __init__(self, name, collections=None):
        """Construct a new FuncGraph.

    The graph will inherit its graph key, collections, seed, and distribution
    strategy stack from the current context or graph.

    Args:
      name: the name of the function.
      collections: a dictionary of collections this FuncGraph should start
        with. If not specified (None), the FuncGraph will read (but not write
        to) the outer graph's collections that are not whitelisted, and both
        read and write to the outer graph's collections that are whitelisted.
        The current whitelisted collections are the global variables, the
        local variables, and the trainable variables.
        Defaults to None.
    """
        super(FuncGraph, self).__init__()

        self.name = name
        self.inputs = []
        self.outputs = []
        self.structured_input_signature = None
        self.structured_outputs = None
        self._weak_variables = []
        self.outer_graph = ops.get_default_graph()
        self.captures = py_collections.OrderedDict()

        self._building_function = True
        # Map from resource tensor name to last op (in program order) which uses
        # this tensor. Used to enforce that execution order matches program order
        # for resource tensors.
        self._last_op_using_resource_tensor = {}

        graph = self.outer_graph

        if context.executing_eagerly():
            self.seed = context.global_seed()
            device_type = context.context().device_spec.device_type
            self._xla_compile = (device_type == "TPU"
                                 or device_type == "XLA_GPU"
                                 or device_type == "XLA_CPU")
        else:
            self.seed = graph.seed
            self._xla_compile = getattr(graph, "_xla_compile", False)
            # TODO(allenl): Figure out if we can remove colocation stack
            # specialization (currently used in cond_v2), here and in the cache key.
            self._colocation_stack = graph._colocation_stack.copy()  # pylint: disable=protected-access

        if collections is None:
            for collection_name in graph.get_all_collection_keys():
                if collection_name not in WHITELIST_COLLECTIONS:
                    self._collections[collection_name] = graph.get_collection(
                        collection_name)
            for collection_name in WHITELIST_COLLECTIONS:
                self._collections[collection_name] = graph.get_collection_ref(
                    collection_name)
        else:
            self._collections = collections
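
A hedged illustration (not part of the source): FuncGraph is internal, but the seed inheritance implemented above is observable through tf.function, whose traced graphs pick up the eager context's global seed.

import tensorflow as tf

tf.random.set_seed(1234)  # becomes context.global_seed(), inherited at trace time

@tf.function
def sample():
    # Ops created here live in a FuncGraph whose .seed was copied from the
    # eager context by the constructor above.
    return tf.random.uniform([2])

print(sample())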
Example no. 3
  def __init__(self, name, collections=None):
    """Construct a new FuncGraph.

    The graph will inherit its graph key, collections, seed, and distribution
    strategy stack from the current context or graph.

    Args:
      name: the name of the function.
      collections: a dictionary of collections this FuncGraph should start
        with. If not specified (None), the FuncGraph will read (but not write
        to) the outer graph's collections that are not whitelisted, and both
        read and write to the outer graph's collections that are whitelisted.
        The current whitelisted collections are the global variables, the
        local variables, and the trainable variables.
        Defaults to None.
    """
    super(FuncGraph, self).__init__()

    self.name = name
    self.inputs = []
    self.outputs = []
    self.structured_input_signature = None
    self.structured_outputs = None
    self._weak_variables = []
    self.outer_graph = ops.get_default_graph()
    self.captures = py_collections.OrderedDict()

    self._building_function = True
    # Map from resource tensor name to last op (in program order) which uses
    # this tensor. Used to enforce that execution order matches program order
    # for resource tensors.
    self._last_op_using_resource_tensor = {}

    graph = self.outer_graph

    if context.executing_eagerly():
      self.seed = context.global_seed()
      device_type = context.context().device_spec.device_type
      self._xla_compile = (device_type == "TPU" or device_type == "XLA_GPU"
                           or device_type == "XLA_CPU")
    else:
      self.seed = graph.seed
      self._xla_compile = getattr(graph, "_xla_compile", False)
      # TODO(allenl): Figure out if we can remove colocation stack
      # specialization (currently used in cond_v2), here and in the cache key.
      self._colocation_stack = graph._colocation_stack.copy()  # pylint: disable=protected-access

    if collections is None:
      for collection_name in graph.get_all_collection_keys():
        if collection_name not in WHITELIST_COLLECTIONS:
          self._collections[collection_name] = graph.get_collection(
              collection_name)
      for collection_name in WHITELIST_COLLECTIONS:
        self._collections[collection_name] = graph.get_collection_ref(
            collection_name)
    else:
      self._collections = collections
Example no. 4
def get_seed(op_seed):
    """Returns the local seeds an operation should use given an op-specific seed.

  Given operation-specific seed, `op_seed`, this helper function returns two
  seeds derived from graph-level and op-level seeds. Many random operations
  internally use the two seeds to allow user to change the seed globally for a
  graph, or for only specific operations.

  For details on how the graph-level seed interacts with op seeds, see
  `tf.compat.v1.random.set_random_seed`.

  Args:
    op_seed: integer.

  Returns:
    A tuple of two integers that should be used for the local seed of this
    operation.
  """
    eager = context.executing_eagerly()

    if eager:
        global_seed = context.global_seed()
    else:
        global_seed = ops.get_default_graph().seed

    if global_seed is not None:
        if op_seed is None:
            # pylint: disable=protected-access
            if hasattr(ops.get_default_graph(), '_seed_used'):
                ops.get_default_graph()._seed_used = True
            if eager:
                op_seed = context.internal_operation_seed()
            else:
                op_seed = _graph_to_seed_dict.setdefault(
                    ops.get_default_graph(), 0)
                _graph_to_seed_dict[ops.get_default_graph()] += 1

        seeds = _truncate_seed(global_seed), _truncate_seed(op_seed)
    else:
        if op_seed is not None:
            seeds = DEFAULT_GRAPH_SEED, _truncate_seed(op_seed)
        else:
            seeds = None, None

    if seeds == (None, None) and config.deterministic_ops_enabled():
        raise RuntimeError(  # pylint: disable=g-doc-exception
            'Random ops require a seed to be set when determinism is enabled. '
            'Please set a seed before running the op, e.g. by calling '
            'tf.random.set_seed(1).')

    # Avoid (0, 0) as the C++ ops interpret it as nondeterminism, which would
    # be unexpected since Python docs say nondeterminism is (None, None).
    if seeds == (0, 0):
        return (0, _MAXINT32)
    return seeds
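
A usage sketch (hedged: random_seed is an internal TF module, so the import path may shift across versions). With a global seed set, both seeds are returned after _truncate_seed; with determinism enabled and no seeds at all, the variant above raises instead.

import tensorflow as tf
from tensorflow.python.framework import random_seed

tf.random.set_seed(42)             # sets the global seed read by get_seed
print(random_seed.get_seed(7))     # (42, 7): both values pass through _truncate_seed
print(random_seed.get_seed(None))  # op seed derived internally, per the eager branch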
Example no. 5
def get_seed(op_seed):
    """Returns the local seeds an operation should use given an op-specific seed.

  Given operation-specific seed, `op_seed`, this helper function returns two
  seeds derived from graph-level and op-level seeds. Many random operations
  internally use the two seeds to allow user to change the seed globally for a
  graph, or for only specific operations.

  For details on how the graph-level seed interacts with op seeds, see
  `tf.random.set_random_seed`.

  Args:
    op_seed: integer.

  Returns:
    A tuple of two integers that should be used for the local seed of this
    operation.
  """
    eager = context.executing_eagerly()

    if eager:
        global_seed = context.global_seed()
    else:
        global_seed = ops.get_default_graph().seed

    if global_seed is not None:
        if op_seed is None:
            # pylint: disable=protected-access
            if hasattr(ops.get_default_graph(), '_seed_used'):
                ops.get_default_graph()._seed_used = True
            if eager:
                op_seed = context.internal_operation_seed()
            else:
                op_seed = ops.get_default_graph()._last_id

        seeds = _truncate_seed(global_seed), _truncate_seed(op_seed)
    else:
        if op_seed is not None:
            seeds = DEFAULT_GRAPH_SEED, _truncate_seed(op_seed)
        else:
            seeds = None, None
    # Avoid (0, 0) as the C++ ops interpret it as nondeterminism, which would
    # be unexpected since Python docs say nondeterminism is (None, None).
    if seeds == (0, 0):
        return (0, _MAXINT32)
    return seeds
Example no. 6
def get_seed(op_seed):
  """Returns the local seeds an operation should use given an op-specific seed.

  Given an operation-specific seed, `op_seed`, this helper function returns two
  seeds derived from the graph-level and op-level seeds. Many random operations
  internally use the two seeds to allow the user to change the seed globally for
  a graph, or for only specific operations.

  For details on how the graph-level seed interacts with op seeds, see
  `tf.compat.v1.random.set_random_seed`.

  Args:
    op_seed: integer.

  Returns:
    A tuple of two integers that should be used for the local seed of this
    operation.
  """
  eager = context.executing_eagerly()

  if eager:
    global_seed = context.global_seed()
  else:
    global_seed = ops.get_default_graph().seed

  if global_seed is not None:
    if op_seed is None:
      # pylint: disable=protected-access
      if hasattr(ops.get_default_graph(), '_seed_used'):
        ops.get_default_graph()._seed_used = True
      if eager:
        op_seed = context.internal_operation_seed()
      else:
        op_seed = ops.get_default_graph()._last_id

    seeds = _truncate_seed(global_seed), _truncate_seed(op_seed)
  else:
    if op_seed is not None:
      seeds = DEFAULT_GRAPH_SEED, _truncate_seed(op_seed)
    else:
      seeds = None, None
  # Avoid (0, 0) as the C++ ops interpret it as nondeterminism, which would
  # be unexpected since Python docs say nondeterminism is (None, None).
  if seeds == (0, 0):
    return (0, _MAXINT32)
  return seeds
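
A worked check of the final branch (illustrative; same internal import as above). When both seeds truncate to zero, the pair is remapped so the C++ kernels do not misread it as the nondeterminism sentinel:

import tensorflow as tf
from tensorflow.python.framework import random_seed

tf.random.set_seed(0)           # global seed 0
print(random_seed.get_seed(0))  # (0, 2147483647): (0, 0) is remapped to
                                # (0, _MAXINT32), avoiding the C++ sentinel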
Example no. 7
def SecureGetRandomSeed(op_seed=None):
    """Returns the local seeds an operation should use given an op-specific seed.
    Args:
      op_seed: integer.

    Returns:
      A tuple of two integers that should be used for the local seed of this
      operation.
    """
    eager = context.executing_eagerly()

    if eager:
        global_seed = context.global_seed()
    else:
        global_seed = ops.get_default_graph().seed

    seeds = (0, 0)
    if global_seed is not None:
        if op_seed is None:
            # pylint: disable=protected-access
            if hasattr(ops.get_default_graph(), '_seed_used'):
                ops.get_default_graph()._seed_used = True
            if eager:
                op_seed = context.internal_operation_seed()
            else:
                op_seed = ops.get_default_graph()._last_id

        seeds = random_seed._truncate_seed(
            global_seed), random_seed._truncate_seed(op_seed)
    else:
        if op_seed is not None:
            seeds = random_seed.DEFAULT_GRAPH_SEED, random_seed._truncate_seed(
                op_seed)
        else:
            if py_protocol_handler.is_activated():
                seeds = py_protocol_handler.rand_seed(
                    0), py_protocol_handler.rand_seed(0)

    # Avoid (0, 0) as the C++ ops interpret it as nondeterminism, which would
    # be unexpected since Python docs say nondeterminism is (None, None).
    if seeds == (0, 0):
        return (0, random_seed._MAXINT32)

    return seeds
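
For contrast with the secure variant above (a hedged check of stock TF, run in a fresh session with no seeds set): the standard helper returns the (None, None) sentinel that SecureGetRandomSeed replaces with protocol-drawn seeds.

from tensorflow.python.framework import random_seed

print(random_seed.get_seed(None))  # (None, None) when no global or op seed is set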
Example no. 8
    def __init__(self, name, collections=None, capture_by_value=None):
        """Construct a new FuncGraph.

    The graph will inherit its graph key, collections, seed, and distribution
    strategy stack from the current context or graph.

    Args:
      name: the name of the function.
      collections: a dictionary of collections this FuncGraph should start
        with. If not specified (None), the FuncGraph will read (but not write
        to) the outer graph's collections that are not whitelisted, and both
        read and write to the outer graph's collections that are whitelisted.
        The current whitelisted collections are the global variables, the
        local variables, and the trainable variables.
        Defaults to None.
      capture_by_value: An optional boolean. If True, the func graph will
        capture Variables by value instead of reference. By default inherit
        from outer graphs, and failing that will default to False.
    """
        super(FuncGraph, self).__init__()

        self.name = name
        self.inputs = []
        self.outputs = []
        self.control_outputs = []
        self.control_captures = set()
        self.structured_input_signature = None
        self.structured_outputs = None
        self._weak_variables = []
        self._watched_variables = weakref.WeakSet()
        self.outer_graph = ops.get_default_graph()
        self.captures = py_collections.OrderedDict()
        # Inherit capture-by-value from outer graph.
        if capture_by_value is not None:
            self.capture_by_value = capture_by_value
        elif self.outer_graph is not None and isinstance(
                self.outer_graph, FuncGraph):
            self.capture_by_value = self.outer_graph.capture_by_value
        else:
            self.capture_by_value = False

        self._building_function = True
        # Map from resource tensor name to last op (in program order) which uses
        # this tensor. Used to enforce that execution order matches program order
        # for resource tensors.
        self._last_op_using_resource_tensor = {}

        graph = self.outer_graph

        if context.executing_eagerly():
            self.seed = context.global_seed()
            # [for tf-data user migration from TF1.0 to 2.0] seed_used keeps track
            # of any None op_seed for a random op in the function, in which case we
            # end up using the function seed, which could be unintended behavior
            # for the op.
            self._seed_used = False
        else:
            self.seed = graph.seed
            self._seed_used = False
            # TODO(allenl): Figure out if we can remove colocation stack
            # specialization (currently used in cond_v2), here and in the cache key.
            self._colocation_stack = graph._colocation_stack.copy()  # pylint: disable=protected-access

        if collections is None:
            for collection_name in graph.get_all_collection_keys():
                if collection_name not in WHITELIST_COLLECTIONS:
                    self._collections[collection_name] = graph.get_collection(
                        collection_name)
            for collection_name in WHITELIST_COLLECTIONS:
                self._collections[collection_name] = graph.get_collection_ref(
                    collection_name)
        else:
            self._collections = collections
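
An illustration of the default capture behavior (hedged; capture_by_value has no public toggle): tf.function captures Variables by reference, so the traced FuncGraph observes later updates without retracing.

import tensorflow as tf

v = tf.Variable(1.0)

@tf.function
def read_v():
    return v + 0.0  # v is captured into the FuncGraph by reference

print(read_v())  # 1.0
v.assign(5.0)
print(read_v())  # 5.0: reference capture sees the new value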
Example no. 9
  def __init__(self, name, read_only_collections=True):
    """Construct a new FuncGraph.

    The graph will inherit its graph key, collections, seed, and distribution
    strategy stack from the current context or graph.

    Args:
      name: the name of the function.
      read_only_collections: whether to avoid writing function graph collections
        back to the default graph. Defaults to True.
    """
    super(FuncGraph, self).__init__()

    self.name = name
    self.inputs = []
    self.outputs = []
    self.structured_outputs = None
    self._read_only_collections = read_only_collections
    self._weak_variables = []
    self.outer_graph = ops.get_default_graph()
    self.captures = collections.OrderedDict()

    self._building_function = True
    # Map from resource tensor name to last op (in program order) which uses
    # this tensor. Used to enforce that execution order matches program order
    # for resource tensors.
    self._last_op_using_resource_tensor = {}

    graph = self.outer_graph

    # pylint: disable=protected-access
    # TODO(b/112906995, nareshmodi): distribution strategy depends on inheriting
    # this stack from the default graph even in eager mode. Maybe it should be
    # part of the eager context? This would also allow us to remove a
    # get_default_graph() call from the function cache lookup.
    self._distribution_strategy_stack = graph._distribution_strategy_stack
    # We ignore device placements from any outer scopes while tracing the
    # function when possible, to avoid hard-coding them in the function
    # graph. "Default" placements come from the PartitionedCallOp's placement,
    # so that the same trace of the Python function may be placed on several
    # different devices and saved functions may be placed on new devices when
    # restored.
    if context.executing_eagerly():
      self.seed = context.global_seed()
      self._xla_compile = (context.context().device_spec.device_type == "TPU")
      if self._distribution_strategy_stack or self._xla_compile:
        self._add_device_to_stack(context.context().device_name)
    else:
      self.seed = graph.seed
      self._xla_compile = getattr(graph, "_xla_compile", False)
      # TODO(allenl): Figure out if we can remove colocation stack
      # specialization (currently used in cond_v2), here and in the cache key.
      self._colocation_stack = graph._colocation_stack.copy()
      if (self._distribution_strategy_stack
          or self._xla_compile
          or device_stack_has_callable(graph._device_function_stack)):
        # Hard-code devices from device functions in the function body
        self._device_function_stack = graph._device_function_stack.copy()
    if not self._read_only_collections:
      self._collections = graph._collections
    else:
      for collection_name in graph.get_all_collection_keys():
        if collection_name not in WHITELIST_COLLECTIONS:
          self._collections[collection_name] = graph.get_collection(
              collection_name)
      for collection_name in WHITELIST_COLLECTIONS:
        self._collections[collection_name] = graph.get_collection_ref(
            collection_name)

    self._variable_creator_stack = graph._variable_creator_stack
    # Inherit the graph key, since this is used for matching variables in
    # optimizers.
    self._graph_key = graph._graph_key
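
A sketch of the collection split this constructor implements (hedged, TF1-style API): whitelisted collections are shared by reference via get_collection_ref, while everything else is only copied via get_collection.

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    v = tf.compat.v1.get_variable("v", shape=[])    # joins GLOBAL_VARIABLES (whitelisted)
    tf.compat.v1.add_to_collection("my_things", v)  # custom, non-whitelisted
# A FuncGraph built under g would share GLOBAL_VARIABLES by reference but read
# "my_things" as a copy, so writes to it would not reach g.
print(g.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES))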
Example no. 10
    def __init__(self, name, read_only_collections=True):
        """Construct a new FuncGraph.

    The graph will inherit its graph key, collections, seed, and distribution
    strategy stack from the current context or graph.

    Args:
      name: the name of the function.
      read_only_collections: whether to not write function graph collections
        back to default graph. Defaults to True.
    """
        super(FuncGraph, self).__init__()

        self.name = name
        self.inputs = []
        self.outputs = []
        self.structured_outputs = None
        self._read_only_collections = read_only_collections
        self._weak_variables = []
        self.outer_graph = ops.get_default_graph()
        self.captures = collections.OrderedDict()

        self._building_function = True
        # Map from resource tensor name to last op (in program order) which uses
        # this tensor. Used to enforce that execution order matches program order
        # for resource tensors.
        self._last_op_using_resource_tensor = {}

        graph = self.outer_graph

        # pylint: disable=protected-access
        # TODO(b/112906995, nareshmodi): distribution strategy depends on inheriting
        # this stack from the default graph even in eager mode. Maybe it should be
        # part of the eager context? This would also allow us to remove a
        # get_default_graph() call from the function cache lookup.
        self._distribution_strategy_stack = list(
            graph._distribution_strategy_stack)
        # We ignore device placements from any outer scopes while tracing the
        # function when possible, to avoid hard-coding them in the function
        # graph. "Default" placements come from the PartitionedCallOp's placement,
        # so that the same trace of the Python function may be placed on several
        # different devices and saved functions may be placed on new devices when
        # restored.
        if context.executing_eagerly():
            self.seed = context.global_seed()
            self._xla_compile = (
                context.context().device_spec.device_type == "TPU")
            if self._distribution_strategy_stack or self._xla_compile:
                self._add_device_to_stack(context.context().device_name)
        else:
            self.seed = graph.seed
            self._xla_compile = getattr(graph, "_xla_compile", False)
            # TODO(allenl): Figure out if we can remove colocation stack
            # specialization (currently used in cond_v2), here and in the cache key.
            self._colocation_stack = graph._colocation_stack.copy()
            if (self._distribution_strategy_stack or self._xla_compile or
                    device_stack_has_callable(graph._device_function_stack)):
                # Hard-code devices from device functions in the function body
                self._device_function_stack = graph._device_function_stack.copy(
                )
        if not self._read_only_collections:
            self._collections = graph._collections
        else:
            for collection_name in graph.get_all_collection_keys():
                if collection_name not in WHITELIST_COLLECTIONS:
                    self._collections[collection_name] = graph.get_collection(
                        collection_name)
            for collection_name in WHITELIST_COLLECTIONS:
                self._collections[collection_name] = graph.get_collection_ref(
                    collection_name)

        self._variable_creator_stack = graph._variable_creator_stack
        # Inherit the graph key, since this is used for matching variables in
        # optimizers.
        self._graph_key = graph._graph_key
Example no. 11
  def __init__(self, name, collections=None, capture_by_value=None):
    """Construct a new FuncGraph.

    The graph will inherit its graph key, collections, seed, and distribution
    strategy stack from the current context or graph.

    Args:
      name: the name of the function.
      collections: a dictionary of collections this FuncGraph should start
        with. If not specified (None), the FuncGraph will read (but not write
        to) the outer graph's collections that are not whitelisted, and both
        read and write to the outer graph's collections that are whitelisted.
        The current whitelisted collections are the global variables, the
        local variables, and the trainable variables.
        Defaults to None.
      capture_by_value: An optional boolean. If True, the func graph will
        capture Variables by value instead of reference. By default inherit
        from outer graphs, and failing that will default to False.
    """
    super(FuncGraph, self).__init__()

    self.name = name
    self.inputs = []
    self.outputs = []
    self.control_outputs = []
    self.control_captures = set()
    self.structured_input_signature = None
    self.structured_outputs = None
    self._weak_variables = []
    self._watched_variables = weakref.WeakSet()
    self.outer_graph = ops.get_default_graph()
    self.captures = py_collections.OrderedDict()
    # Inherit capture-by-value from outer graph.
    if capture_by_value is not None:
      self.capture_by_value = capture_by_value
    elif self.outer_graph is not None and isinstance(
        self.outer_graph, FuncGraph):
      self.capture_by_value = self.outer_graph.capture_by_value
    else:
      self.capture_by_value = False

    self._building_function = True
    # Map from resource tensor name to last op (in program order) which uses
    # this tensor. Used to enforce that execution order matches program order
    # for resource tensors.
    self._last_op_using_resource_tensor = {}

    graph = self.outer_graph

    if context.executing_eagerly():
      self.seed = context.global_seed()
      # [for tf-data user migration from TF1.0 to 2.0] seed_used keeps track of
      # any None op_seed for a random op in the function, in which case we end up
      # using the function seed, which could be unintended behavior for the op.
      self._seed_used = False
    else:
      self.seed = graph.seed
      self._seed_used = False
      # TODO(allenl): Figure out if we can remove colocation stack
      # specialization (currently used in cond_v2), here and in the cache key.
      self._colocation_stack = graph._colocation_stack.copy()  # pylint: disable=protected-access

    if collections is None:
      for collection_name in graph.get_all_collection_keys():
        if collection_name not in WHITELIST_COLLECTIONS:
          self._collections[collection_name] = graph.get_collection(
              collection_name)
      for collection_name in WHITELIST_COLLECTIONS:
        self._collections[collection_name] = graph.get_collection_ref(
            collection_name)
    else:
      self._collections = collections