Beispiel #1
0
    def variables(self):
        """Tuple of every `tf.Variable` captured while the module was connected.

        Unlike `get_variables` and `get_all_variables`, this accessor does not
        consult global collections, so it is generally the one to prefer. Refer
        to `AbstractModule._capture_variables()` for details on which variables
        end up captured.

        Returns:
          A tuple of `tf.Variable` objects, sorted by variable name.

        Raises:
          NotConnectedError: If the module is not connected to the Graph.
        """
        # Guard first: per the documented contract this raises
        # NotConnectedError when the module has not been connected.
        self._ensure_is_connected()
        captured = self._all_variables
        return util.sort_by_name(captured)
Beispiel #2
0
  def variables(self):
    """Tuple of every `tf.Variable` captured while the module was connected.

    Unlike `get_variables` and `get_all_variables`, this accessor does not
    consult global collections, so it is generally the one to prefer. Refer to
    `AbstractModule._capture_variables()` for details on which variables end
    up captured.

    Returns:
      A tuple of `tf.Variable` objects, sorted by variable name.

    Raises:
      NotConnectedError: If the module is not connected to the Graph.
    """
    # Guard first: per the documented contract this raises NotConnectedError
    # when the module has not been connected.
    self._ensure_is_connected()
    captured = self._all_variables
    return util.sort_by_name(captured)
Beispiel #3
0
    def get_all_variables(self, collection=tf.GraphKeys.TRAINABLE_VARIABLES):
        """Returns all `tf.Variable`s used when the module is connected.

        See the documentation for `AbstractModule._capture_variables()` for
        more information.

        Args:
          collection: Collection to restrict query to. By default this is
            `tf.GraphKeys.TRAINABLE_VARIABLES`, which doesn't include
            non-trainable variables such as moving averages.

        Returns:
          A sorted (by variable name) tuple of `tf.Variable` objects.

        Raises:
          NotConnectedError: If the module is not connected to the Graph.
        """
        self._ensure_is_connected()
        # NOTE(review): `tf.get_collection` reads the *default* graph, so this
        # assumes the module was connected in the current default graph —
        # confirm for callers that switch graphs.
        collection_variables = set(tf.get_collection(collection))
        # Keep only the module's captured variables that are also members of
        # `collection` (assumes `self._all_variables` is a set-like object
        # supporting `&` — TODO confirm).
        return util.sort_by_name(self._all_variables & collection_variables)
Beispiel #4
0
  def get_all_variables(self, collection=tf.GraphKeys.TRAINABLE_VARIABLES):
    """Returns all `tf.Variable`s used when the module is connected.

    See the documentation for `AbstractModule._capture_variables()` for more
    information.

    Args:
      collection: Collection to restrict query to. By default this is
        `tf.GraphKeys.TRAINABLE_VARIABLES`, which doesn't include non-trainable
        variables such as moving averages.

    Returns:
      A sorted (by variable name) tuple of `tf.Variable` objects.

    Raises:
      NotConnectedError: If the module is not connected to the Graph.
    """
    self._ensure_is_connected()
    # NOTE(review): `tf.get_collection` reads the *default* graph, so this
    # assumes the module was connected in the current default graph — confirm
    # for callers that switch graphs.
    collection_variables = set(tf.get_collection(collection))
    # Keep only the module's captured variables that are also members of
    # `collection` (assumes `self._all_variables` is a set-like object
    # supporting `&` — TODO confirm).
    return util.sort_by_name(self._all_variables & collection_variables)