Exemplo n.º 1
0
    def _check_same_graph(self):
        """Ensures this module is never connected to more than one Graph.

        A Sonnet module instance owns the variables it creates and relies on
        that ownership for seamless variable sharing. Connecting the same
        instance to several Graphs would break this, so the first Graph seen
        is recorded and any later connection from a different Graph is
        rejected. The comparison is skipped when the call site executes
        eagerly.

        Raises:
          DifferentGraphError: if the module is connected to a Graph other
            than the one it was first used in.
        """
        # `init_scope` is needed in case we are running inside a defun: what
        # matters is where the function will be called, not where it is being
        # built, so both the graph and the eagerness are read from that scope.
        with ops.init_scope():
            graph_at_call_site = tf.get_default_graph()
            eager_at_call_site = tf.executing_eagerly()

        first_connection = self._graph is None
        if first_connection:
            self._graph = graph_at_call_site
            self._set_module_info()

        # In eager mode every module lives in one process-level context, so a
        # "same graph" check is only meaningful when calling from graph mode.
        if eager_at_call_site:
            return
        if self._graph != graph_at_call_site:
            raise DifferentGraphError(
                "Cannot connect module to multiple Graphs.")
Exemplo n.º 2
0
  def _check_same_graph(self):
    """Ensures this module is never connected to more than one Graph.

    A Sonnet module instance owns the variables it creates and relies on that
    ownership for seamless variable sharing. Connecting the same instance to
    several Graphs would break this. On the first call the current default
    Graph is recorded; later calls must happen under that same Graph.

    Raises:
      DifferentGraphError: if the module is connected to a Graph other than
        the one it was first used in.
    """
    active_graph = tf.get_default_graph()
    if self._graph is not None:
      # Already bound to a graph: only that exact graph is acceptable.
      if self._graph != active_graph:
        raise DifferentGraphError("Cannot connect module to multiple Graphs.")
      return
    # First connection: remember the owning graph and record module info.
    self._graph = active_graph
    self._set_module_info()