Ejemplo n.º 1
0
  def wrapper(obj, *args, **kwargs):
    """Runs `method` on `obj` inside the module's variable and name scopes.

    The variable scope obtained from `obj` is re-entered with
    `reuse=tf.AUTO_REUSE`, so variables are created when missing and reused
    otherwise. Ops are created under
    `<obj's original name scope>/<snake_case(method.__name__)>/`.

    Args:
      obj: Instance the wrapped `method` is called on. If it has an
        `_enter_variable_scope` attribute, that is called with no arguments to
        obtain a variable-scope context manager; otherwise `obj.variable_scope`
        is entered via `tf.variable_scope`.
      *args: Positional arguments forwarded to `method`.
      **kwargs: Keyword arguments forwarded to `method`.

    Returns:
      The return value of `method(obj, *args, **kwargs)`.
    """
    def default_context_manager(reuse=None):
      # Fallback used when `obj` does not define `_enter_variable_scope`.
      return tf.variable_scope(obj.variable_scope, reuse=reuse)

    variable_scope_context_manager = getattr(obj, "_enter_variable_scope",
                                             default_context_manager)

    # Enter the variable scope only briefly, to capture the scope object; the
    # ops themselves are built under the pure variable scope below.
    with variable_scope_context_manager() as variable_scope:
      pass

    # NOTE(review): `_pure_variable_scope` is a private TF helper that enters a
    # variable scope without opening a name scope; it may change between TF
    # versions.
    with variable_scope_ops._pure_variable_scope(  # pylint: disable=protected-access
        variable_scope, reuse=tf.AUTO_REUSE):

      # A trailing "/" makes `tf.name_scope` re-enter this exact scope instead
      # of creating a new, uniquified one.
      name_scope = variable_scope.original_name_scope
      if not name_scope.endswith("/"):
        name_scope += "/"

      with tf.name_scope(name_scope):
        sub_scope = snt_util.to_snake_case(method.__name__)
        with tf.name_scope(sub_scope):
          return method(obj, *args, **kwargs)
Ejemplo n.º 2
0
  def wrapper(obj, *args, **kwargs):
    """Runs `method` on `obj` inside the module's variable and name scopes.

    The variable scope obtained from `obj` is re-entered with
    `reuse=tf.AUTO_REUSE`, so variables are created when missing and reused
    otherwise. Ops are created under
    `<obj's original name scope>/<snake_case(method.__name__)>/`.

    Args:
      obj: Instance the wrapped `method` is called on. If it has an
        `_enter_variable_scope` attribute, that is called with no arguments to
        obtain a variable-scope context manager; otherwise `obj.variable_scope`
        is entered via `tf.variable_scope`.
      *args: Positional arguments forwarded to `method`.
      **kwargs: Keyword arguments forwarded to `method`.

    Returns:
      The return value of `method(obj, *args, **kwargs)`.
    """
    def default_context_manager(reuse=None):
      # Fallback used when `obj` does not define `_enter_variable_scope`.
      return tf.variable_scope(obj.variable_scope, reuse=reuse)

    variable_scope_context_manager = getattr(obj, "_enter_variable_scope",
                                             default_context_manager)

    # Enter the variable scope only briefly, to capture the scope object; the
    # ops themselves are built under the pure variable scope below.
    with variable_scope_context_manager() as variable_scope:
      pass

    # NOTE(review): `_pure_variable_scope` is a private TF helper that enters a
    # variable scope without opening a name scope; it may change between TF
    # versions.
    with variable_scope_ops._pure_variable_scope(  # pylint: disable=protected-access
        variable_scope, reuse=tf.AUTO_REUSE):

      # A trailing "/" makes `tf.name_scope` re-enter this exact scope instead
      # of creating a new, uniquified one.
      name_scope = variable_scope.original_name_scope
      if not name_scope.endswith("/"):
        name_scope += "/"

      with tf.name_scope(name_scope):
        sub_scope = snt_util.to_snake_case(method.__name__)
        with tf.name_scope(sub_scope):
          return method(obj, *args, **kwargs)
Ejemplo n.º 3
0
 def testToSnakeCase(self):
   """Checks `util.to_snake_case` against known (input, expected) pairs."""
   test_cases = [
       ("UpperCamelCase", "upper_camel_case"),
       ("lowerCamelCase", "lower_camel_case"),
       ("endsWithXYZ", "ends_with_xyz"),
       ("already_snake_case", "already_snake_case"),
       ("__private__", "private"),
       ("LSTMModule", "lstm_module"),
       ("version123p56vfxObject", "version_123p56vfx_object"),
       ("version123P56VFXObject", "version_123p56vfx_object"),
       ("versionVFX123P56Object", "version_vfx123p56_object"),
       ("versionVfx123P56Object", "version_vfx_123p56_object"),
       ("lstm1", "lstm_1"),
       ("LSTM1", "lstm1"),
   ]
   for camel_case, snake_case in test_cases:
     actual = util.to_snake_case(camel_case)
     # The failure message names the function actually under test.
     self.assertEqual(actual, snake_case, "to_snake_case(%s) -> %s != %s" %
                      (camel_case, actual, snake_case))
Ejemplo n.º 4
0
 def testToSnakeCase(self):
   """Checks `util.to_snake_case` against known (input, expected) pairs."""
   test_cases = [
       ("UpperCamelCase", "upper_camel_case"),
       ("lowerCamelCase", "lower_camel_case"),
       ("endsWithXYZ", "ends_with_xyz"),
       ("already_snake_case", "already_snake_case"),
       ("__private__", "private"),
       ("LSTMModule", "lstm_module"),
       ("version123p56vfxObject", "version_123p56vfx_object"),
       ("version123P56VFXObject", "version_123p56vfx_object"),
       ("versionVFX123P56Object", "version_vfx123p56_object"),
       ("versionVfx123P56Object", "version_vfx_123p56_object"),
       ("lstm1", "lstm_1"),
       ("LSTM1", "lstm1"),
   ]
   for camel_case, snake_case in test_cases:
     actual = util.to_snake_case(camel_case)
     # The failure message names the function actually under test.
     self.assertEqual(actual, snake_case, "to_snake_case(%s) -> %s != %s" %
                      (camel_case, actual, snake_case))
Ejemplo n.º 5
0
    def __init__(self, _sentinel=None, custom_getter=None, name=None):  # pylint: disable=invalid-name
        """Performs the initialisation necessary for all AbstractModule instances.

        Every subclass of AbstractModule must begin their constructor with a call
        to this constructor, i.e.

        `super(MySubModule, self).__init__(custom_getter=custom_getter, name=name)`.

        If you instantiate sub-modules in __init__ you must create them within
        the `_enter_variable_scope` context manager to ensure they are in the
        module's variable scope. Alternatively, instantiate sub-modules in
        `_build`.

        Args:
          _sentinel: Variable that only carries a non-None value if `__init__`
              was called without named parameters. If this is the case, a
              deprecation warning is issued in form of a `ValueError`.
          custom_getter: Callable or dictionary of callables to use as
              custom getters inside the module. If a dictionary, the keys
              correspond to regexes to match variable names. See the
              `tf.get_variable` documentation for information about the
              custom_getter API.
          name: Name of this module. Used to construct the Templated build
              function. If `None` the module's class name is used (converted
              to snake case).

        Raises:
          TypeError: If `name` is not a string.
          TypeError: If a given `custom_getter` is not callable.
          ValueError: If `__init__` was called without named arguments.
        """
        # `collections.Mapping` was removed in Python 3.10; use the
        # `collections.abc` location, falling back for Python 2.
        try:
            from collections.abc import Mapping  # pylint: disable=g-import-not-at-top
        except ImportError:
            from collections import Mapping  # pylint: disable=g-import-not-at-top

        if _sentinel is not None:
            raise ValueError("Calling AbstractModule.__init__ without named "
                             "arguments is not supported.")

        if name is None:
            name = util.to_snake_case(self.__class__.__name__)
        elif not isinstance(name, six.string_types):
            raise TypeError("Name must be a string, not {} of type {}.".format(
                name, type(name)))

        self._is_connected = False
        self._connected_subgraphs = []

        # If the given custom getter is a dictionary with a per-variable custom
        # getter, wrap it into a single custom getter.
        if isinstance(custom_getter, Mapping):
            self._custom_getter = util.custom_getter_router(
                custom_getter_map=custom_getter,
                # Drop the first len(scope_name) + 1 characters, presumably the
                # "<scope_name>/" prefix, so the regex keys are matched against
                # module-relative variable names.
                name_fn=lambda var_name: var_name[len(self.scope_name) + 1:])
        elif custom_getter is not None and not callable(custom_getter):
            raise TypeError("Given custom_getter is not callable.")
        else:
            self._custom_getter = custom_getter

        self._template = tf.make_template(name,
                                          self._build_wrapper,
                                          create_scope_now_=True,
                                          custom_getter_=self._custom_getter)

        self._original_name = name
        # Last path component of the template's variable scope name.
        self._unique_name = self._template.variable_scope.name.split("/")[-1]

        # Update __call__ and the object docstrings to enable better introspection.
        # NOTE: `__call__.__func__` is the class-level function object, so this
        # assignment is shared by every instance of the class.
        self.__doc__ = self._build.__doc__
        self.__call__.__func__.__doc__ = self._build.__doc__

        # Keep track of which graph this module has been connected to. Sonnet
        # modules cannot be connected to multiple graphs, as transparent variable
        # sharing is impossible in that case.
        self._graph = None

        # Container for all variables created in this module and its sub-modules.
        self._all_variables = set()

        # Calling `.defun()` causes the module's call method to become wrapped as
        # a graph function.
        self._defun_wrapped = False
Ejemplo n.º 6
0
  def __init__(self, _sentinel=None, custom_getter=None,
               name=None):  # pylint: disable=invalid-name
    """Performs the initialisation necessary for all AbstractModule instances.

    Every subclass of AbstractModule must begin their constructor with a call to
    this constructor, i.e.

    `super(MySubModule, self).__init__(custom_getter=custom_getter, name=name)`.

    If you instantiate sub-modules in __init__ you must create them within the
    `_enter_variable_scope` context manager to ensure they are in the module's
    variable scope. Alternatively, instantiate sub-modules in `_build`.

    Args:
      _sentinel: Variable that only carries a non-None value if `__init__` was
          called without named parameters. If this is the case, a deprecation
          warning is issued in form of a `ValueError`.
      custom_getter: Callable or dictionary of callables to use as
        custom getters inside the module. If a dictionary, the keys
        correspond to regexes to match variable names. See the `tf.get_variable`
        documentation for information about the custom_getter API.
      name: Name of this module. Used to construct the Templated build function.
          If `None` the module's class name is used (converted to snake case).

    Raises:
      TypeError: If `name` is not a string.
      TypeError: If a given `custom_getter` is not callable.
      ValueError: If `__init__` was called without named arguments.
    """
    # `collections.Mapping` was removed in Python 3.10; use the
    # `collections.abc` location, falling back for Python 2.
    try:
      from collections.abc import Mapping  # pylint: disable=g-import-not-at-top
    except ImportError:
      from collections import Mapping  # pylint: disable=g-import-not-at-top

    if _sentinel is not None:
      raise ValueError("Calling AbstractModule.__init__ without named "
                       "arguments is not supported.")

    if name is None:
      name = util.to_snake_case(self.__class__.__name__)
    elif not isinstance(name, six.string_types):
      raise TypeError("Name must be a string.")

    self._connected_subgraphs = []

    # If the given custom getter is a dictionary with a per-variable custom
    # getter, wrap it into a single custom getter.
    if isinstance(custom_getter, Mapping):
      self._custom_getter = util.custom_getter_router(
          custom_getter_map=custom_getter,
          # Drop the first len(scope_name) + 1 characters, presumably the
          # "<scope_name>/" prefix, so the regex keys are matched against
          # module-relative variable names.
          name_fn=lambda var_name: var_name[len(self.scope_name) + 1:])
    elif custom_getter is not None and not callable(custom_getter):
      raise TypeError("Given custom_getter is not callable.")
    else:
      self._custom_getter = custom_getter

    # Combine the selected getter (possibly None) with
    # `_variable_tracking_custom_getter`; presumably this lets the module
    # record the variables it creates -- confirm against
    # `_maybe_wrap_custom_getter`.
    self._custom_getter = _maybe_wrap_custom_getter(
        _variable_tracking_custom_getter, self._custom_getter)

    self._template = tf.make_template(name,
                                      self._build_wrapper,
                                      create_scope_now_=True,
                                      custom_getter_=self._custom_getter)

    self._original_name = name
    # Last path component of the template's variable scope name.
    self._unique_name = self._template.variable_scope.name.split("/")[-1]

    # Update __call__ and the object docstrings to enable better introspection.
    # NOTE: `__call__.__func__` is the class-level function object, so this
    # assignment is shared by every instance of the class.
    self.__doc__ = self._build.__doc__
    self.__call__.__func__.__doc__ = self._build.__doc__

    # Keep track of which graph this module has been connected to. Sonnet
    # modules cannot be connected to multiple graphs, as transparent variable
    # sharing is impossible in that case.
    self._graph = None

    # Container for all variables created in this module and its sub-modules.
    self._all_variables = set()